In [84]:
import os
import numpy as np
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from torch.utils.data import WeightedRandomSampler
import torch
import torch.nn as nn
import timm
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [2]:
data_dir = "HyperKvasirDataset"

In [61]:
all_data = ImageFolder(root=data_dir)

In [63]:
# Stratified split of dataset due to large class imbalance
labels = np.array([all_data.targets[i] for i in range(len(all_data))])

# 70% train, 20% val, 10% test
train_idx, temp_idx, train_labels, temp_labels = train_test_split(
    np.arange(len(all_data)), labels, test_size=0.3, stratify=labels, random_state=42
)

val_idx, test_idx, val_labels, test_labels = train_test_split(
    temp_idx, temp_labels, test_size=0.33, stratify=temp_labels, random_state=42
)

In [78]:
# Create the data transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = val_transform  # No augmentation for test data

In [79]:
# Create Subsets
train_dataset = Subset(all_data, train_idx)
val_dataset = Subset(all_data, val_idx)
test_dataset = Subset(all_data, test_idx)

# Set the augmentations
train_dataset.dataset.transform = train_transform
val_dataset.dataset.transform = val_transform
test_dataset.dataset.transform = test_transform

In [67]:
# Compute class weights for use with weighted loss function
class_counts = np.bincount(train_labels)
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float32)
class_weights_normalized = class_weights / class_weights.sum()

In [68]:
# Create WeightedRandomSampler for oversampling minority classes in training batches
sampler = WeightedRandomSampler(weights=class_weights_normalized, num_samples=len(class_weights_normalized), replacement=True)

In [83]:
# Get the pretrained Swin Transformer model
num_classes = len(np.unique(train_labels))
model = timm.create_model("swin_base_patch4_window7_224", pretrained=True, num_classes=num_classes, drop_rate=0.3)

In [76]:
# Freeze early layers
for param in model.patch_embed.parameters():
    param.requires_grad = False

In [86]:
# Learning rate warmup function
def warmup_lr(epoch, warmup_epochs=3):
    if epoch < warmup_epochs:
        return (epoch+1) / warmup_epochs # Gradually increase LR

    return 1 # Keep LR at base value after warmup

In [87]:
# Define hyperparameters
batch = 32
criterion = torch.nn.CrossEntropyLoss(weight=class_weights_normalized)
learning_rate = 0.0003
decay = 0.01
lr_warmup_epochs = 3
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=decay)

# Add a learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
warmup_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_lr(epoch, lr_warmup_epochs))

num_epochs = 15
best_val_loss = float("inf")
best_val_path = "best_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [81]:
# Create the DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=batch, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch, shuffle=False)

In [82]:
for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    # Training loop
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backprop
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
    train_loss /= len(train_loader)
    # Adjust learning rate initially
    if epoch < lr_warmup_epochs:
        warmup_scheduler.step()
    
    # Validation loop
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    val_loss /= len(val_loader)

    # Adjust learning rate based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}: Learning Rate: {scheduler.optimizer.param_groups[0]['lr']}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")

    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        print("New best model saved.")
        torch.save(model.state_dict(), best_val_path)

Epoch 1: Learning Rate: 0.0005
Train Loss: 3.3600
Validation Loss: 3.6735
New best model saved.
Epoch 2: Learning Rate: 0.0005
Train Loss: 1.1104
Validation Loss: 6.4124
Epoch 3: Learning Rate: 0.0005
Train Loss: 1.8704
Validation Loss: 3.2980
New best model saved.
Epoch 4: Learning Rate: 0.0005
Train Loss: 1.9925
Validation Loss: 5.0247
Epoch 5: Learning Rate: 0.0005
Train Loss: 1.6285
Validation Loss: 4.3483
Epoch 6: Learning Rate: 0.0005
Train Loss: 0.7345
Validation Loss: 5.1635
Epoch 7: Learning Rate: 0.00025
Train Loss: 0.4282
Validation Loss: 4.5313
Epoch 8: Learning Rate: 0.00025
Train Loss: 0.6540
Validation Loss: 3.9253
Epoch 9: Learning Rate: 0.00025
Train Loss: 0.3104
Validation Loss: 3.4770
Epoch 10: Learning Rate: 0.00025
Train Loss: 0.2321
Validation Loss: 3.4565
Epoch 11: Learning Rate: 0.000125
Train Loss: 2.3520
Validation Loss: 3.3919
Epoch 12: Learning Rate: 0.000125
Train Loss: 0.8579
Validation Loss: 3.4770
Epoch 13: Learning Rate: 0.000125
Train Loss: 0.5129
Vali