In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

### Loading the Augmented Dataset

In [1]:
from torchvision.transforms import v2

In [5]:
# Deterministic preprocessing (no augmentation):
# resize -> tensor image -> float scaled to [0, 1] -> per-channel standardisation.
# NOTE(review): the mean/std constants are presumably statistics of this dataset -- confirm.
transform = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(
            mean=(0.4203, 0.2800, 0.1714),
            std=(0.2932, 0.2165, 0.1632),
        ),
    ]
)

In [9]:
# Training-time augmentation pipeline.
# Geometric / colour augmentations run first, then the image is converted to a float
# tensor scaled to [0, 1].  Gaussian noise is added *before* normalisation:
# v2.GaussianNoise expects float input in [0, 1] and, with clip=True, clamps its output
# back to [0, 1].  Applied after Normalize it would clamp away every standardised value
# below 0 or above 1, i.e. most of the signal.
transform_augment = v2.Compose([
    v2.Resize((224,224)),
    v2.RandomRotation(degrees=15, expand=False, fill=0),
    v2.ColorJitter(brightness=0.1, contrast=0.1),
    v2.RandomHorizontalFlip(),
    v2.ToImage(),                                 # convert PIL -> tensor
    v2.ToDtype(torch.float32, scale=True),        # scale to [0,1]
    v2.GaussianNoise(mean=0.0, sigma=0.01, clip=True),  # noise in [0,1] space, clipped to stay valid
    v2.Normalize(mean=(0.4203, 0.2800, 0.1714),
                std=(0.2932, 0.2165, 0.1632)),    # same statistics as the plain `transform`
    ])

In [12]:
# Root folder of the images.  Defaults to the Kaggle input path; set the EYE_DATASET_DIR
# environment variable to run the notebook elsewhere without editing this cell.
path_to_dataset = os.environ.get("EYE_DATASET_DIR", "/kaggle/input/eye-diseases-classification/dataset")
if not os.path.isdir(path_to_dataset):
    # Fail early with an actionable message instead of a bare errno from the loader.
    raise FileNotFoundError(
        f"Dataset directory not found: {path_to_dataset!r}. "
        "Attach the dataset or point EYE_DATASET_DIR at a local copy."
    )

# No transform here: the per-split transforms are attached to copies of this dataset
# further down in the notebook.
dataset= ImageFolder(root=path_to_dataset, transform=None)
# NOTE(review): with transform=None the samples are PIL images, which the default collate
# function cannot batch -- iterating this loader as-is is expected to fail.  Confirm
# whether it is still needed.
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/eye-diseases-classification/dataset'

### Separating train-validation-test Dataset

In [13]:
from sklearn.model_selection import train_test_split

In [None]:
# Sanity check: how many images the dataset object holds in total.
total_items = len(dataset)
print(f"Total number of data items: {total_items}")

In [None]:
# The split below works on sample positions rather than on the images themselves,
# so pair every position with its class label.
labels = dataset.targets            # class label of each image, in dataset order
indices = np.arange(len(dataset))   # positions 0 .. len(dataset) - 1

In [None]:
# Stratified 80/20 split of the sample positions; the fixed seed keeps the partition
# identical across runs.
train_idx, test_idx, y_train, y_test = train_test_split(
    indices,
    labels,
    train_size=0.8,
    stratify=labels,
    random_state=42,
)

In [None]:
# Report how many samples ended up in each split.
train_size = train_idx.shape[0]
test_size = test_idx.shape[0]
print(f"Train size = {train_size}")
print(f"Test size = {test_size}")

- Below: `torch.utils.data.Subset()` is used to build the final split datasets from the index arrays.

In [None]:
# Index-based views over the single ImageFolder.  The training indices are wrapped twice so
# that one view can carry the augmenting transform and the other the plain one.
# (No validation subset is built: the split only yields train and test indices.)
train_dataset_original = torch.utils.data.Subset(dataset, train_idx)
train_dataset_augment = torch.utils.data.Subset(dataset, train_idx)
test_dataset = torch.utils.data.Subset(dataset, test_idx)

In [None]:
from copy import copy

# Every Subset starts out pointing at the same parent dataset, so assigning `.transform`
# on that parent would change all of them at once.  Give each subset its own shallow copy
# of the parent and attach the matching transform to that copy.

# un-augmented training view
train_dataset_original.dataset = copy(dataset)
train_dataset_original.dataset.transform = transform

# augmented training view
train_dataset_augment.dataset = copy(dataset)
train_dataset_augment.dataset.transform = transform_augment

# test view (never augmented)
test_dataset.dataset = copy(dataset)
test_dataset.dataset.transform = transform

In [None]:
# Final training set: the augmented view followed by the un-augmented view of the same images.
train_dataset_complete = torch.utils.data.ConcatDataset(
    datasets=[train_dataset_augment, train_dataset_original]
)

In [None]:
# Size of the concatenated training set.  Both concatenated subsets are built from the
# same train indices, so this should be twice the train split size.
len(train_dataset_complete)

### Training the ResNet model

In [None]:
class ResBlock(nn.Module):
    '''Basic residual block: two 3x3 conv + batch-norm layers plus a skip connection.

    The first convolution applies ``stride`` and maps ``in_channels`` to
    ``out_channels``; the second keeps the shape.  The block input is added to
    the result before the final ReLU.  When the input and output shapes differ
    (channel count changes or ``stride != 1``) the skip path is a 1x1
    convolution + batch-norm projection; otherwise it is an empty
    ``nn.Sequential`` that passes its input through unchanged.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first convolution and of the projection shortcut.
    '''
    def __init__(self, in_channels:int, out_channels:int, stride:int):
        super().__init__()
        # Convolutions carry no bias: each is followed directly by a batch-norm layer.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(num_features=out_channels)
        # The ReLU between the two conv layers has no parameters; it is applied in forward().
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(num_features=out_channels)

        # Skip path: project the input only when its shape would not match the main path.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_features=out_channels)
            )
        else:
            self.shortcut = nn.Sequential()  # empty Sequential == identity

    def forward(self, x:torch.Tensor)->torch.Tensor:
        '''Apply the block to a feature map with ``in_channels`` channels.

        Returns:
            Feature map with ``out_channels`` channels; non-negative because of the final ReLU.
        '''
        out = torch.relu(self.batchnorm1(self.conv1(x)))
        out = self.batchnorm2(self.conv2(out))
        out += self.shortcut(x)  # residual connection
        return torch.relu(out)


In [None]:
class ResNet18(nn.Module):
    '''ResNet-18 style image classifier assembled from ``ResBlock`` stages.

    Layout: 7x7 stride-2 stem convolution + batch-norm + ReLU + 2x2 max-pool,
    four stages of two residual blocks each (64, 128, 256, 512 channels; stages
    2-4 down-sample with stride 2), global average pooling and a linear
    classifier.  Dropout is applied after the stem, after stages 1-3 and on the
    pooled features, and only while the module is in training mode.

    Args:
        num_classes: number of output logits.
        dropout_rates: optional iterable of exactly five probabilities in
            [0, 1], in the order the dropouts are applied (stem, layer1,
            layer2, layer3, pooled features).  ``None`` keeps the class
            defaults ``(0.1, 0.1, 0.2, 0.3, 0.5)``.

    Raises:
        ValueError: if ``dropout_rates`` is given but is not five values in [0, 1].
    '''
    # Default dropout probabilities, in application order.  Kept at class level so that an
    # instance which never sets its own value still resolves `self.dropout_rates`.
    dropout_rates = (0.1, 0.1, 0.2, 0.3, 0.5)

    def __init__(self, num_classes:int=4, dropout_rates=None):
        super().__init__()
        if dropout_rates is not None:
            rates = tuple(float(p) for p in dropout_rates)
            if len(rates) != 5 or not all(0.0 <= p <= 1.0 for p in rates):
                raise ValueError(
                    "dropout_rates must contain exactly 5 probabilities in [0, 1], "
                    f"got {dropout_rates!r}"
                )
            self.dropout_rates = rates

        # Running channel count; make_blocks() reads and updates it while stacking stages.
        self.in_channels=64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(num_features=64)
        # ReLU and max-pool have no parameters; they are applied in forward().

        self.layer1 = self.make_blocks(ResBlock, 64, 2, 1)
        self.layer2 = self.make_blocks(ResBlock, 128, 2, 2)
        self.layer3 = self.make_blocks(ResBlock, 256, 2, 2)
        self.layer4 = self.make_blocks(ResBlock, 512, 2, 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512, num_classes)

    def make_blocks(self, block:"type[ResBlock]", out_channels:int, num_blocks:int, stride:int)->nn.Sequential:
        '''Build one stage: a sequence of residual blocks with ``out_channels`` channels.

        Only the first block of the stage uses ``stride``; the remaining blocks
        use stride 1.  ``self.in_channels`` is updated to ``out_channels`` so the
        next stage starts from the right channel count.

        Args:
            block: residual block class, called as ``block(in_channels, out_channels, stride=...)``.
            out_channels: channel count produced by every block in the stage.
            num_blocks: number of blocks in the stage.
            stride: stride of the first block.
        '''
        layers = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            layers.append(block(self.in_channels, out_channels, stride=block_stride))
            self.in_channels=out_channels
        return nn.Sequential(*layers)

    def forward(self, x:torch.Tensor)->torch.Tensor:
        '''Compute class logits for a batch of 3-channel images.

        Returns:
            Tensor of shape (batch size, ``num_classes``) with unnormalised logits.
        '''
        p_stem, p_layer1, p_layer2, p_layer3, p_head = self.dropout_rates

        out = self.batchnorm1(self.conv1(x))
        out = F.max_pool2d(torch.relu(out), 2)
        out = F.dropout(out, p=p_stem, training=self.training)

        out = self.layer1(out)
        out = F.dropout(out, p=p_layer1, training=self.training)
        out = self.layer2(out)
        out = F.dropout(out, p=p_layer2, training=self.training)
        out = self.layer3(out)
        out = F.dropout(out, p=p_layer3, training=self.training)
        out = self.layer4(out)

        out = self.avg_pool(out) # tensor of shape (B,512,1,1)
        out = out.view(out.shape[0], -1) # flatten to (no. of images in a batch, 512)
        out = F.dropout(out, p=p_head, training=self.training)
        out = self.fc(out)
        return out
