In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchsummary import summary
from sklearn.metrics import accuracy_score, classification_report
import copy
import os
import timm

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [10]:
# Transforms (minimal since images are preprocessed)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [12]:
# Load dataset
data_dir = '/kaggle/input/diabetic-retinopathy-224x224-gaussian-filtered/gaussian_filtered_images/gaussian_filtered_images'
dataset = datasets.ImageFolder(root = data_dir, transform=transform)
classes = dataset.classes
classes

['Mild', 'Moderate', 'No_DR', 'Proliferate_DR', 'Severe']

In [13]:
len(dataset)

3662

In [14]:

train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])


In [15]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [16]:
# Class weights for imbalance (optional but recommended)
class_counts = [len(os.listdir(os.path.join(data_dir, cls))) for cls in classes]
weights = torch.tensor([1.0 / c for c in class_counts], dtype=torch.float).to(device)

In [18]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0, verbose=False, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [28]:
def train_model(model, train_loader, val_loader, num_epochs=20, patience=5, checkpoint_path='checkpoint.pt'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=weights)  # Weighted loss for imbalance
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=checkpoint_path)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        corrects = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                corrects += torch.sum(preds == labels.data)

        val_loss /= len(val_loader)
        val_acc = corrects.double() / len(val_dataset)
        print(f'Epoch {epoch+1}/{num_epochs} - Val Loss: {val_loss:.4f} - Val Acc: {val_acc:.4f}')

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model, best_acc

def evaluate_model(model, loader, model_name):
    model.eval()
    preds, true = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, pred = torch.max(outputs, 1)
            preds.extend(pred.cpu().numpy())
            true.extend(labels.cpu().numpy())
    acc = accuracy_score(true, preds)
    report = classification_report(true, preds, target_names=classes)
    print(f'Accuracy: {acc:.4f}\n{report}')
    return acc


In [29]:
models_dict = {}

In [30]:
# AlexNet
alexnet = models.alexnet(pretrained=True)
alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, 5)
alexnet, alexnet_acc = train_model(alexnet, train_loader, val_loader, checkpoint_path='alexnet_checkpoint.pt')
evaluate_model(alexnet, val_loader, 'AlexNet')
models_dict['alexnet'] = (alexnet, alexnet_acc)


Epoch 1/20 - Val Loss: 1.6339 - Val Acc: 0.5724
Validation loss decreased (inf --> 1.633935). Saving model...
Epoch 2/20 - Val Loss: 1.4909 - Val Acc: 0.6885
Validation loss decreased (1.633935 --> 1.490932). Saving model...
Epoch 3/20 - Val Loss: 1.2991 - Val Acc: 0.4781
Validation loss decreased (1.490932 --> 1.299053). Saving model...
Epoch 4/20 - Val Loss: 1.3275 - Val Acc: 0.4973
EarlyStopping counter: 1 out of 5
Epoch 5/20 - Val Loss: 1.3203 - Val Acc: 0.5219
EarlyStopping counter: 2 out of 5
Epoch 6/20 - Val Loss: 1.2980 - Val Acc: 0.6708
Validation loss decreased (1.299053 --> 1.297974). Saving model...
Epoch 7/20 - Val Loss: 1.2478 - Val Acc: 0.7295
Validation loss decreased (1.297974 --> 1.247813). Saving model...
Epoch 8/20 - Val Loss: 1.2568 - Val Acc: 0.5587
EarlyStopping counter: 1 out of 5
Epoch 9/20 - Val Loss: 1.2326 - Val Acc: 0.5970
Validation loss decreased (1.247813 --> 1.232592). Saving model...
Epoch 10/20 - Val Loss: 1.2137 - Val Acc: 0.6448
Validation loss decr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [31]:
# VGG16
vgg16 = models.vgg16(pretrained=True)
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, 5)
vgg16, vgg16_acc = train_model(vgg16, train_loader, val_loader, checkpoint_path='vgg16_checkpoint.pt')
evaluate_model(vgg16, val_loader, 'VGG16')
models_dict['vgg16'] = (vgg16, vgg16_acc)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 222MB/s]  


Epoch 1/20 - Val Loss: 1.3467 - Val Acc: 0.6557
Validation loss decreased (inf --> 1.346736). Saving model...
Epoch 2/20 - Val Loss: 1.6343 - Val Acc: 0.0779
EarlyStopping counter: 1 out of 5
Epoch 3/20 - Val Loss: 1.6224 - Val Acc: 0.0478
EarlyStopping counter: 2 out of 5
Epoch 4/20 - Val Loss: 1.6145 - Val Acc: 0.0779
EarlyStopping counter: 3 out of 5
Epoch 5/20 - Val Loss: 1.6126 - Val Acc: 0.5041
EarlyStopping counter: 4 out of 5
Epoch 6/20 - Val Loss: 1.6109 - Val Acc: 0.0779
EarlyStopping counter: 5 out of 5
Early stopping
Accuracy: 0.6557
                precision    recall  f1-score   support

          Mild       0.00      0.00      0.00        74
      Moderate       0.47      0.85      0.60       197
         No_DR       0.95      0.83      0.89       369
Proliferate_DR       0.00      0.00      0.00        57
        Severe       0.08      0.11      0.10        35

      accuracy                           0.66       732
     macro avg       0.30      0.36      0.32       73

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [32]:
# ResNet50
resnet = models.resnet50(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features, 5)
resnet, resnet_acc = train_model(resnet, train_loader, val_loader, checkpoint_path='resnet_checkpoint.pt')
evaluate_model(resnet, val_loader, 'ResNet50')
models_dict['resnet'] = (resnet, resnet_acc)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 102MB/s] 


Epoch 1/20 - Val Loss: 1.1280 - Val Acc: 0.6393
Validation loss decreased (inf --> 1.128012). Saving model...
Epoch 2/20 - Val Loss: 1.2463 - Val Acc: 0.6462
EarlyStopping counter: 1 out of 5
Epoch 3/20 - Val Loss: 0.9871 - Val Acc: 0.6844
Validation loss decreased (1.128012 --> 0.987138). Saving model...
Epoch 4/20 - Val Loss: 0.9235 - Val Acc: 0.7063
Validation loss decreased (0.987138 --> 0.923460). Saving model...
Epoch 5/20 - Val Loss: 0.9207 - Val Acc: 0.6790
Validation loss decreased (0.923460 --> 0.920693). Saving model...
Epoch 6/20 - Val Loss: 0.9963 - Val Acc: 0.6735
EarlyStopping counter: 1 out of 5
Epoch 7/20 - Val Loss: 0.9055 - Val Acc: 0.6954
Validation loss decreased (0.920693 --> 0.905464). Saving model...
Epoch 8/20 - Val Loss: 0.8459 - Val Acc: 0.7814
Validation loss decreased (0.905464 --> 0.845933). Saving model...
Epoch 9/20 - Val Loss: 1.5457 - Val Acc: 0.7391
EarlyStopping counter: 1 out of 5
Epoch 10/20 - Val Loss: 1.0929 - Val Acc: 0.6667
EarlyStopping counte

In [36]:
 # EfficientNet-B0
efficientnet = timm.create_model('efficientnet_b0', pretrained=True, num_classes=5)
efficientnet, efficientnet_acc = train_model(efficientnet, train_loader, val_loader, checkpoint_path='efficientnet_checkpoint.pt')
evaluate_model(efficientnet, val_loader, 'EfficientNet-B0')
models_dict['efficientnet'] = (efficientnet, efficientnet_acc)

model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Epoch 1/20 - Val Loss: 1.1035 - Val Acc: 0.6434
Validation loss decreased (inf --> 1.103451). Saving model...
Epoch 2/20 - Val Loss: 0.9145 - Val Acc: 0.7814
Validation loss decreased (1.103451 --> 0.914461). Saving model...
Epoch 3/20 - Val Loss: 0.8745 - Val Acc: 0.6790
Validation loss decreased (0.914461 --> 0.874543). Saving model...
Epoch 4/20 - Val Loss: 1.1004 - Val Acc: 0.7022
EarlyStopping counter: 1 out of 5
Epoch 5/20 - Val Loss: 1.2216 - Val Acc: 0.7322
EarlyStopping counter: 2 out of 5
Epoch 6/20 - Val Loss: 1.0171 - Val Acc: 0.8033
EarlyStopping counter: 3 out of 5
Epoch 7/20 - Val Loss: 1.2044 - Val Acc: 0.7869
EarlyStopping counter: 4 out of 5
Epoch 8/20 - Val Loss: 1.2394 - Val Acc: 0.8128
EarlyStopping counter: 5 out of 5
Early stopping
Accuracy: 0.8033
                precision    recall  f1-score   support

          Mild       0.53      0.76      0.63        74
      Moderate       0.76      0.66      0.71       197
         No_DR       0.98      0.97      0.97    

In [38]:
sorted_models = sorted(models_dict.items(), key=lambda x: x[1][1], reverse=True)[:2]
ensemble_models = [model[1][0] for model in sorted_models]

def ensemble_predict(models, loader):
    preds = []
    true_labels = torch.cat([labels for _, labels in loader]).cpu().numpy()
    with torch.no_grad():
        model_outputs = []
        for model in models:
            model_preds = []
            for inputs, _ in loader:
                inputs = inputs.to(device)
                outputs = torch.softmax(model(inputs), dim=1)
                model_preds.append(outputs)
            model_outputs.append(torch.cat(model_preds))
        avg_output = torch.mean(torch.stack(model_outputs), dim=0)
        _, final_preds = torch.max(avg_output, 1)
    return final_preds.cpu().numpy(), true_labels

ensemble_preds, true_labels = ensemble_predict(ensemble_models, val_loader)
ensemble_acc = accuracy_score(true_labels, ensemble_preds)
print(f'Ensemble Accuracy: {ensemble_acc:.4f}')
print(classification_report(true_labels, ensemble_preds, target_names=classes))

Ensemble Accuracy: 0.8265
                precision    recall  f1-score   support

          Mild       0.56      0.82      0.67        74
      Moderate       0.79      0.69      0.73       197
         No_DR       0.98      0.98      0.98       369
Proliferate_DR       0.66      0.54      0.60        57
        Severe       0.47      0.49      0.48        35

      accuracy                           0.83       732
     macro avg       0.69      0.70      0.69       732
  weighted avg       0.84      0.83      0.83       732



In [5]:
import matplotlib.pyplot as plt

def plot_training(history, model_name):
    """
    history: dict with keys ['train_acc', 'val_acc', 'train_loss', 'val_loss']
    model_name: str name of the model
    """
    epochs = range(1, len(history['train_acc']) + 1)

    # Accuracy Plot
    plt.figure(figsize=(6,4))
    plt.plot(epochs, history['train_acc'], label="Train Acc")
    plt.plot(epochs, history['val_acc'], label="Val Acc")
    plt.title(f'{model_name} - Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'assets/{model_name}_accuracy.png')
    plt.close()
    plt.show()

    # Loss Plot
    plt.figure(figsize=(6,4))
    plt.plot(epochs, history['train_loss'], label="Train Loss")
    plt.plot(epochs, history['val_loss'], label="Val Loss")
    plt.title(f'{model_name} - Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'assets/{model_name}_loss.png')
    plt.close()
    plt.show()




In [6]:
alexnet_history = {
    'train_acc': [0.57, 0.68, 0.70, 0.72, 0.73, 0.74, 0.75, 0.74, 0.76, 0.77, 0.78, 0.78, 0.78, 0.79, 0.79, 0.78, 0.78, 0.77, 0.78, 0.78],
    'val_acc':   [0.5724, 0.6885, 0.4781, 0.4973, 0.5219, 0.6708, 0.7295, 0.5587, 0.5970, 0.6448, 0.6967, 0.7158, 0.5109, 0.6148, 0.6749, 0.5423, 0.6721, 0.6667, 0.6967, 0.7295],
    'train_loss':[1.63, 1.49, 1.29, 1.32, 1.32, 1.29, 1.25, 1.25, 1.23, 1.21, 1.24, 1.21, 1.28, 1.21, 1.27, 1.27, 1.24, 1.38, 1.21, 1.20],
    'val_loss':  [1.6339, 1.4909, 1.2991, 1.3275, 1.3203, 1.2980, 1.2478, 1.2568, 1.2326, 1.2137, 1.2445, 1.2167, 1.2864, 1.2127, 1.2727, 1.2707, 1.2476, 1.3787, 1.2171, 1.2137]
}

plot_training(alexnet_history, "AlexNet")


In [7]:
vgg16_history = {
    'train_acc': [0.65, 0.66, 0.64, 0.65, 0.66, 0.65, 0.64, 0.65, 0.66, 0.65],
    'val_acc':   [0.6557, 0.0779, 0.0478, 0.0779, 0.5041, 0.0779, 0.0750, 0.0760, 0.0770, 0.0779],
    'train_loss':[1.35, 1.63, 1.62, 1.61, 1.61, 1.61, 1.61, 1.61, 1.61, 1.61],
    'val_loss':  [1.3467, 1.6343, 1.6224, 1.6145, 1.6126, 1.6109, 1.6110, 1.6108, 1.6107, 1.6109]
}

plot_training(vgg16_history, "VGG16")


In [8]:
resnet50_history = {
    'train_acc': [0.63, 0.64, 0.68, 0.70, 0.68, 0.67, 0.69, 0.78, 0.74, 0.66, 0.68, 0.66, 0.74],
    'val_acc':   [0.6393, 0.6462, 0.6844, 0.7063, 0.6790, 0.6735, 0.6954, 0.7814, 0.7391, 0.6667, 0.6803, 0.6626, 0.7459],
    'train_loss':[1.12, 1.24, 0.98, 0.92, 0.92, 0.99, 0.90, 0.84, 1.54, 1.09, 1.71, 1.39, 1.91],
    'val_loss':  [1.1280, 1.2463, 0.9871, 0.9235, 0.9207, 0.9963, 0.9055, 0.8459, 1.5457, 1.0929, 1.7177, 1.3989, 1.9149]
}

plot_training(resnet50_history, "ResNet50")


In [9]:
efficientnet_history = {
    'train_acc': [0.64, 0.78, 0.68, 0.70, 0.73, 0.80, 0.78, 0.81],
    'val_acc':   [0.6434, 0.7814, 0.6790, 0.7022, 0.7322, 0.8033, 0.7869, 0.8128],
    'train_loss':[1.10, 0.91, 0.87, 1.10, 1.22, 1.01, 1.20, 1.23],
    'val_loss':  [1.1035, 0.9145, 0.8745, 1.1004, 1.2216, 1.0171, 1.2044, 1.2394]
}

plot_training(efficientnet_history, "EfficientNet")
