## Importing all necessary libraries

In [15]:
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

## Loading data and making a dataset

In [None]:
df = pd.read_csv("../data/aptos2019/train.csv")
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# Dataset class
class RetinopathyDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = f"{self.img_dir}/{self.df.iloc[idx, 0]}.png"
        img = Image.open(img_path)
        label = self.df.iloc[idx, 1]
        
        if self.transform:
            img = self.transform(img)
        return img, label

# Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# DataLoaders
train_dataset = RetinopathyDataset(train_df, "data/aptos2019/train_images", transform)
val_dataset = RetinopathyDataset(val_df, "data/aptos2019/train_images", transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

## Model setup and training

In [None]:
class DRModel(nn.Module):
    def __init__(self, num_classes=5):
        super(DRModel, self).__init__()
        
        # Loading pre-trained VGG19
        self.vgg = models.vgg19(weights='IMAGENET1K_V1')
        
        # Removing the classifier
        self.features = self.vgg.features
        
        # Freezing VGG19 weights
        for param in self.features.parameters():
            param.requires_grad = False
            
        # Additional layers
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        # VGG19 last conv has 512 channels
        self.fc1 = nn.Linear(512, 256)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)
        
        # Activation
        self.elu = nn.ELU()
        
    def forward(self, x):
        # Feature extraction
        x = self.features(x)
        x = self.global_avg_pool(x)
        x = self.flatten(x)
        
        # Classification
        x = self.fc1(x)
        x = self.elu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Model setup
model = DRModel(num_classes=5)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=2, verbose=True)
best_val_loss = float("inf")
early_stop_patience = 5
epochs_no_improve = 0
num_epochs = 20



In [None]:
# Training loop
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    train_correct = 0
    
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_correct += (predicted == labels).sum().item()

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_correct += (predicted == labels).sum().item()
            
    # LR scheduling & early stopping
    scheduler.step(val_loss)
    
	# Saving best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == early_stop_patience:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {train_loss/len(train_loader):.4f} | Train Acc: {train_correct/len(train_loader):.4f}%")
            print(f"Val Loss: {val_loss/len(val_loader):.4f} | Val Acc: {val_correct/len(val_loader):.4f}%")
            print("Early stopping triggered!")
            break

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss/len(train_loader):.4f} | Train Acc: {train_correct/len(train_loader):.4f}%")
    print(f"Val Loss: {val_loss/len(val_loader):.4f} | Val Acc: {val_correct/len(val_loader):.4f}%")

100%|██████████| 92/92 [06:19<00:00,  4.12s/it]


Epoch 1/20
Train Loss: 1.2371 | Train Acc: 16.8696%
Val Loss: 1.1014 | Val Acc: 17.7826%


100%|██████████| 92/92 [05:16<00:00,  3.44s/it]


Epoch 2/20
Train Loss: 0.9925 | Train Acc: 21.4022%
Val Loss: 0.9211 | Val Acc: 22.5217%


100%|██████████| 92/92 [05:16<00:00,  3.43s/it]


Epoch 3/20
Train Loss: 0.8546 | Train Acc: 22.6304%
Val Loss: 0.8258 | Val Acc: 23.0870%


100%|██████████| 92/92 [05:15<00:00,  3.43s/it]


Epoch 4/20
Train Loss: 0.7813 | Train Acc: 22.9239%
Val Loss: 0.7751 | Val Acc: 23.0870%


100%|██████████| 92/92 [05:17<00:00,  3.45s/it]


Epoch 5/20
Train Loss: 0.7366 | Train Acc: 23.0978%
Val Loss: 0.7417 | Val Acc: 23.3043%


100%|██████████| 92/92 [05:17<00:00,  3.45s/it]


Epoch 6/20
Train Loss: 0.7038 | Train Acc: 23.2935%
Val Loss: 0.7156 | Val Acc: 23.4348%


100%|██████████| 92/92 [05:19<00:00,  3.47s/it]


Epoch 7/20
Train Loss: 0.6800 | Train Acc: 23.5217%
Val Loss: 0.7006 | Val Acc: 23.5652%


100%|██████████| 92/92 [05:15<00:00,  3.43s/it]


Epoch 8/20
Train Loss: 0.6548 | Train Acc: 23.7826%
Val Loss: 0.6841 | Val Acc: 24.0870%


100%|██████████| 92/92 [05:15<00:00,  3.43s/it]


Epoch 9/20
Train Loss: 0.6396 | Train Acc: 23.9457%
Val Loss: 0.6669 | Val Acc: 23.9565%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 10/20
Train Loss: 0.6229 | Train Acc: 24.1957%
Val Loss: 0.6556 | Val Acc: 24.2174%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 11/20
Train Loss: 0.6109 | Train Acc: 24.3587%
Val Loss: 0.6469 | Val Acc: 24.3478%


100%|██████████| 92/92 [05:16<00:00,  3.43s/it]


Epoch 12/20
Train Loss: 0.6007 | Train Acc: 24.4130%
Val Loss: 0.6404 | Val Acc: 24.4348%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 13/20
Train Loss: 0.5919 | Train Acc: 24.6522%
Val Loss: 0.6320 | Val Acc: 24.3913%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 14/20
Train Loss: 0.5845 | Train Acc: 24.9130%
Val Loss: 0.6243 | Val Acc: 24.3913%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 15/20
Train Loss: 0.5724 | Train Acc: 25.0435%
Val Loss: 0.6246 | Val Acc: 24.5652%


100%|██████████| 92/92 [05:14<00:00,  3.41s/it]


Epoch 16/20
Train Loss: 0.5674 | Train Acc: 25.0870%
Val Loss: 0.6152 | Val Acc: 24.2609%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 17/20
Train Loss: 0.5566 | Train Acc: 25.1413%
Val Loss: 0.6086 | Val Acc: 24.3913%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 18/20
Train Loss: 0.5528 | Train Acc: 25.0109%
Val Loss: 0.6093 | Val Acc: 24.3043%


100%|██████████| 92/92 [05:14<00:00,  3.41s/it]


Epoch 19/20
Train Loss: 0.5518 | Train Acc: 25.1087%
Val Loss: 0.6015 | Val Acc: 24.5652%


100%|██████████| 92/92 [05:13<00:00,  3.41s/it]


Epoch 20/20
Train Loss: 0.5402 | Train Acc: 25.3696%
Val Loss: 0.6045 | Val Acc: 24.6087%


I made an error while calculating accuracy, so the numbers after "Train Acc:" and "Val Acc:" are mean number of correct predintions in one batch (with size 32), so it should be `number / 32 * 100` %

In [27]:
print(val_correct/len(val_loader.dataset))

0.772169167803547
