In [3]:
import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Root directory of the images; it holds one sub-folder per entry in `classes`.
data_dir = '/kaggle/input/diabetic-retinopathy-224x224-gaussian-filtered/gaussian_filtered_images/gaussian_filtered_images'
# Class folder names. A sample's integer label is its index in this list,
# so the order here defines the label encoding.
classes = ['Mild', 'Moderate', 'No_DR', 'Proliferate_DR', 'Severe']

# Custom Dataset class
class DiabeticRetinopathyDataset(Dataset):
    """Map-style dataset pairing image files with class labels.

    Args:
        file_paths: Indexable sequence of image file paths.
        labels: Indexable sequence of labels, aligned index-for-index with
            ``file_paths``.
        transform: Optional callable applied to each loaded RGB PIL image
            (for example a ``torchvision.transforms.Compose``).

    Raises:
        ValueError: If ``file_paths`` and ``labels`` differ in length.
    """

    def __init__(self, file_paths, labels, transform=None):
        # Fail fast: a length mismatch would otherwise surface later as an
        # IndexError (or silently drop labels) deep inside a DataLoader worker.
        if len(file_paths) != len(labels):
            raise ValueError(
                f"file_paths and labels must have the same length "
                f"(got {len(file_paths)} and {len(labels)})"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The file is decoded as RGB and passed through ``transform`` when one
        was provided; otherwise the PIL image itself is returned.
        """
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        # Use a context manager so the file handle is released as soon as the
        # pixels are decoded rather than whenever the image is garbage
        # collected (relevant with several DataLoader worker processes).
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Data preprocessing and augmentation
# Standard ImageNet per-channel statistics, shared by both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipelines keyed by split. 'train' adds random flips, small
# rotations and colour jitter as augmentation; 'test' only resizes and
# normalises, so evaluation inputs are deterministic.
transform = dict(
    train=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]),
    test=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]),
)

# Load dataset
# Parallel lists: file_paths[i] is an image path and labels[i] is the index of
# its class folder within `classes`.
file_paths = []
labels = []

for label, cls in enumerate(classes):
    cls_folder = os.path.join(data_dir, cls)
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sort them so the seeded train/validation split is reproducible.
    for img_name in sorted(os.listdir(cls_folder)):
        file_paths.append(os.path.join(cls_folder, img_name))
        labels.append(label)

# Train-test split
# Stratified 80/20 split so both partitions keep the same class proportions;
# the fixed seed makes the split repeatable for a given input ordering.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    file_paths,
    labels,
    random_state=42,
    stratify=labels,
    test_size=0.2,
)

# Wrap each split in a Dataset: the augmenting pipeline for training, the
# deterministic 'test' pipeline for validation.
train_dataset = DiabeticRetinopathyDataset(
    train_paths, train_labels, transform=transform['train']
)
val_dataset = DiabeticRetinopathyDataset(
    val_paths, val_labels, transform=transform['test']
)

# Both loaders share batch size and worker count; only training is shuffled.
_loader_opts = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

# Load pretrained model (ResNet50)
# `pretrained=True` is deprecated in torchvision >= 0.13. It is equivalent to
# the explicit IMAGENET1K_V1 weights enum used here (the resnet50-0676ba61
# checkpoint), so the downloaded weights are unchanged.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Replace the final fully-connected layer so it emits one logit per entry in
# `classes`, keeping the backbone's feature width.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training and validation loop
def train_model(model, criterion, optimizer, num_epochs=20, dataloaders=None, run_device=None):
    """Train *model*, then restore the weights from its best validation epoch.

    Each epoch runs one pass over the training loader with gradient updates,
    followed by one pass over the validation loader without gradients. The
    mean loss and accuracy of each phase are printed.

    Args:
        model: The ``nn.Module`` to train; its parameters are updated in place.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of train+validation epochs to run.
        dataloaders: Optional ``{'train': loader, 'val': loader}`` mapping.
            Defaults to the module-level ``train_loader`` / ``val_loader``.
        run_device: Device each batch is moved to. Defaults to the
            module-level ``device``.

    Returns:
        The same ``model`` object, with the parameters recorded at the epoch
        of highest validation accuracy loaded back in.
    """
    if dataloaders is None:
        dataloaders = {'train': train_loader, 'val': val_loader}
    if run_device is None:
        run_device = device

    # state_dict() returns references to the live parameter tensors, which
    # optimizer.step() keeps updating in place. Deep-copy so the snapshot
    # actually freezes the best weights instead of tracking the final ones.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is the same as eval()
            dataloader = dataloaders[phase]

            running_loss = 0.0
            running_corrects = 0

            for inputs, targets in dataloader:
                inputs = inputs.to(run_device)
                targets = targets.to(run_device)

                optimizer.zero_grad()

                # Build the autograd graph only during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, targets)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Assumes criterion returns a per-batch mean (the
                # CrossEntropyLoss default); re-weight by batch size so the
                # epoch figure is a per-sample mean over the whole dataset.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == targets).sum().item()

            n_samples = len(dataloader.dataset)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

# Train the model
# Fine-tune for 20 epochs; train_model returns the model it trained.
model = train_model(model, criterion, optimizer, num_epochs=20)

# Persist only the learned parameters (the state dict), not the module object.
checkpoint_path = "diabetic_retinopathy_model.pth"
torch.save(model.state_dict(), checkpoint_path)

# Evaluate the model
model.eval()
all_preds = []
all_labels = []
all_probs = []  # one array of softmax class scores per validation batch

with torch.no_grad():
    # Named batch_labels so the module-level `labels` list is not overwritten.
    for inputs, batch_labels in val_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1)
        # ROC AUC needs continuous class scores. Compute softmax in float64 so
        # every row sums to 1 within scikit-learn's probability check.
        all_probs.append(torch.softmax(outputs.double(), dim=1).cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

# Score against every class index so the report and AUC stay aligned with
# `classes` even if some class never appears among the predictions.
class_ids = list(range(len(classes)))

# Classification report
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=classes))

# Compute one-vs-rest ROC AUC from predicted probabilities. Hard 0/1
# predictions would reduce each ROC curve to a single operating point.
roc_auc = roc_auc_score(
    all_labels,
    np.concatenate(all_probs, axis=0),
    multi_class='ovr',
    labels=class_ids,
)
print(f"ROC AUC Score: {roc_auc:.4f}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 185MB/s]


Epoch 1/20
----------
train Loss: 0.7178 Acc: 0.7453
val Loss: 0.6562 Acc: 0.7503
Epoch 2/20
----------
train Loss: 0.5723 Acc: 0.7876
val Loss: 0.5745 Acc: 0.7885
Epoch 3/20
----------
train Loss: 0.5139 Acc: 0.7982
val Loss: 0.5261 Acc: 0.7899
Epoch 4/20
----------
train Loss: 0.4859 Acc: 0.8197
val Loss: 0.4833 Acc: 0.8049
Epoch 5/20
----------
train Loss: 0.4276 Acc: 0.8371
val Loss: 0.5310 Acc: 0.7804
Epoch 6/20
----------
train Loss: 0.3957 Acc: 0.8491
val Loss: 0.5270 Acc: 0.8022
Epoch 7/20
----------
train Loss: 0.3563 Acc: 0.8641
val Loss: 0.5594 Acc: 0.7858
Epoch 8/20
----------
train Loss: 0.3372 Acc: 0.8767
val Loss: 0.5802 Acc: 0.8035
Epoch 9/20
----------
train Loss: 0.3048 Acc: 0.8832
val Loss: 0.7548 Acc: 0.8008
Epoch 10/20
----------
train Loss: 0.2687 Acc: 0.9013
val Loss: 0.5863 Acc: 0.8117
Epoch 11/20
----------
train Loss: 0.2586 Acc: 0.8996
val Loss: 0.6786 Acc: 0.7831
Epoch 12/20
----------
train Loss: 0.2312 Acc: 0.9184
val Loss: 0.7573 Acc: 0.7967
Epoch 13/20
-