In [None]:
"""
Colab Notebook Set Up

Use this cell to upload your kaggle.json file as well as the `download_data.sh`
and `preprocess.py` scripts.
"""

from google.colab import files, drive
import os

# Mount google drive
drive.mount('/content/drive')

# Upload files
kaggle = files.upload()
data_script = files.upload()
preprocess_script = files.upload()
config = files.upload()

# Verify uploads
for file in ["kaggle.json", "download_data.sh", "preprocess.py", "oct.yaml"]:
    assert file in os.listdir(), f"Make sure you upload the {file} file"

# Shell commands (files)
!mkdir -p ~/.kaggle/ data/ models/ config/ scripts/ net/
!mv kaggle.json ~/.kaggle/
!mv download_data.sh preprocess.py scripts/
!mv oct.yaml config/
!chmod 600 ~/.kaggle/kaggle.json
!chmod +x scripts/download_data.sh scripts/preprocess.py
!sed -i -e 's/\r$//' scripts/download_data.sh
!pip install -q kaggle pretrainedmodels rich
!touch net/__init__.py net/train.py net/utils.py

# Run shell commands
!scripts/download_data.sh
!python scripts/preprocess.py --config config/oct.yaml --kw batch-size=64

In [None]:
# Inline TensorBoard viewer for the runs that SummaryWriter writes under logs/
%load_ext tensorboard
# Reload edited modules (e.g. net/train.py, net/utils.py) before each cell executes
%load_ext autoreload
%autoreload 2

%tensorboard --logdir logs

## Finetuning InceptionV3 for Retinal OCT Images

### Context

- Retinal Optical Coherence Tomography (OCT) is an imaging technique used to
  capture high-resolution cross-sections of the retina
- ~84,495 OCT images in total

### Content

- Images in JPEG format with 3 channels, i.e., RGB
- 4 categories: CNV, DME, DRUSEN, NORMAL

### This Notebook

- Fine-tune InceptionV3 by training only its final linear layer on the new data
- The images are pre-processed by running the forward pass through the
  InceptionV3 network and saving the output of the last pooling layer
  (a 2048-dimensional vector) to disk.
  - These feature vectors are then used to train a single-layer linear
    classifier on the new data, so the rest of the network is never updated

In [None]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torch import nn, optim
import torch

from datetime import datetime
from tqdm.auto import tqdm
from rich import print
from glob import glob

import sys; sys.path.append(".")
from net import train, utils

import numpy as np
import copy
import yaml
import os
import re


config = "config/oct.yaml"
with open(config, 'r') as f:
    config = yaml.safe_load(f)

In [None]:
class FeatureDataset(Dataset):
    """
    Custom dataset to load InceptionV3 features for fine-tuning.

    Each item corresponds to one ``FL*.npy`` file in ``feature_dir``. Files are
    ordered by the integer index at the end of their name, so ``FL_10.npy``
    sorts after ``FL_2.npy``. Each file must hold a pickled
    ``(label, feature)`` pair of numpy arrays (presumably written by
    ``scripts/preprocess.py`` — confirm).

    Args:
        feature_dir: Directory containing the ``FL*.npy`` feature files.

    Raises:
        FileNotFoundError: If ``feature_dir`` contains no ``FL*.npy`` files.
        ValueError: If a matching file name does not end in ``<digits>.npy``.
    """

    def __init__(self, feature_dir):
        self.files = sorted(glob(f"{feature_dir}/FL*.npy"), key=self._extract_idx)
        if not self.files:
            # Fail fast: an empty split would otherwise surface later as a
            # confusing sampler error (shuffle=True) or a silent zero-batch epoch.
            raise FileNotFoundError(f"No FL*.npy feature files found in {feature_dir!r}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(feature, label)`` tensors loaded from the ``idx``-th file."""
        # allow_pickle is needed because each file stores an object array.
        # Unpickling can execute arbitrary code: only load files produced by
        # this project's own preprocessing, never untrusted .npy files.
        label, feature = np.load(self.files[idx], allow_pickle=True)
        return torch.from_numpy(feature), torch.from_numpy(label)

    def _extract_idx(self, filename):
        """Extract batch index from filename (the digits right before ``.npy``)."""
        match = re.search(r"(\d+)\.npy$", filename)
        if match is None:
            raise ValueError(f"Invalid filename {filename}")
        return int(match.group(1))


# Create datasets: one FeatureDataset per split, built in train/test/val order
data = {
    split: FeatureDataset(config["features"][split])
    for split in ("train", "test", "val")
}

# Create dataloaders; only the training split is shuffled.
# NOTE(review): if each FL*.npy file already holds a whole preprocessing batch
# (FeatureDataset._extract_idx calls its index a "batch index"), this
# batch_size groups that many *files* per step — confirm against preprocess.py.
batch_size = 1024
dataloaders = {
    split: DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"))
    for split, dataset in data.items()
}

In [None]:
class FineTuned(nn.Module):
    """Fine-tuned output layer for InceptionV3.

    A single linear layer mapping pooled backbone feature vectors to class logits.

    Args:
        config: Mapping with at least ``"num_classes"`` (number of output classes).
        in_features: Length of the input feature vector. Defaults to 2048, the
            size of InceptionV3's last pooling-layer output.
    """

    def __init__(self, config, in_features=2048):
        super().__init__()
        self.fc = nn.Linear(in_features, config["num_classes"])

    def forward(self, x):
        """Return raw (unnormalized) logits of shape ``(..., num_classes)``."""
        return self.fc(x)
    
# Linear head trained on the cached InceptionV3 feature vectors
model = FineTuned(config)
model.to(train.device())

criterion = nn.CrossEntropyLoss()

# Optimizer is selectable via an optional "optimizer" config key
# ("adam" by default, or "sgd" with momentum 0.9).
optimizer_name = str(config.get("optimizer", "adam")).lower()
if optimizer_name == "adam":
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
elif optimizer_name == "sgd":
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
else:
    raise ValueError(f"Unsupported optimizer {optimizer_name!r}; expected 'adam' or 'sgd'")

In [None]:
# One TensorBoard run directory per session, tagged with a month-day_time stamp
timestamp = f"{datetime.now():%m%d_%H%M%S}"
log_dir = f"logs/inception/oct-{timestamp}"
writer = SummaryWriter(log_dir)

# Cumulative epoch counter; the training cell advances it on every run
EPOCH = 0

In [None]:
# Train for another `num_epochs` epochs. EPOCH is advanced *before* the call,
# so train_model receives the cumulative epoch count after this run
# (presumably used for TensorBoard step offsets — confirm in net/train.py).
num_epochs = 10
EPOCH = EPOCH + num_epochs

model, history = train.train_model(
    model,
    dataloaders,
    criterion,
    optimizer,
    num_epochs,
    EPOCH,
    writer,
    is_inception=False,
)

# Persist the trained head's weights, tagged with this session's timestamp
checkpoint_path = f"models/oct-preprocess-{timestamp}.pth"
torch.save(model.state_dict(), checkpoint_path)