In [None]:
%pip install tensorflow 

In [2]:
from torchvision.models import resnet18# type: ignore
import medmnist
from medmnist import DermaMNIST
from medmnist import INFO, Evaluator
from matplotlib import transforms
from torchvision import transforms
import torch.utils.data as data
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from collections import Counter

In [18]:
# Constants / configuration for the DermaMNIST experiment.
data_flag = 'dermamnist'
download = False  # passed to every DermaMNIST(...) constructor; set True to fetch the data

NUM_EPOCHS = 12
BATCH_SIZE = 128
lr = 0.001  # SGD learning rate (lower-case name is referenced by the optimizer cell)

# Seed the RNGs used for augmentation, loader shuffling and new-layer initialisation
# so runs are repeatable (on CPU; GPU kernels may still be non-deterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset metadata from medmnist's INFO registry.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

In [21]:
# ImageNet channel statistics: the pretrained ResNet-18 backbone expects inputs normalised this way.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def _to_rgb(img):
    """Convert a single-channel ('L') PIL image to RGB; return any other image unchanged.

    A named function (rather than a lambda) keeps the transform picklable, which
    DataLoader needs if num_workers > 0 is ever used.
    """
    return img.convert('RGB') if img.mode == 'L' else img


train_transform = transforms.Compose([
    transforms.Lambda(_to_rgb),
    # Small random zoom-in: crop 90-100% of the image area and resize back to 28x28.
    # RandomResizedCrop only crops *inside* the image, so a scale above 1.0 cannot
    # zoom out; such draws are rejected and retried (falling back to a center crop).
    transforms.RandomResizedCrop(size=28, scale=(0.9, 1.0)),
    transforms.RandomRotation(degrees=10, fill=0),            # random rotation, black fill
    transforms.ColorJitter(brightness=0.2, contrast=0.2),     # random brightness/contrast
    transforms.ToTensor(),
    normalize,
])

# Deterministic pipeline for evaluation: no augmentation.
val_transform = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.CenterCrop(28),
    transforms.ToTensor(),
    normalize,
])

test_transform = val_transform

# Untransformed view of the training split (items are the raw images).
pil_dataset = DermaMNIST(split='train', download=download)
train_dataset = DermaMNIST(split='train', transform=train_transform, download=download)
# Training split with the evaluation transform, for measuring train metrics without augmentation.
train_dataset_at_eval = DermaMNIST(split='train', transform=val_transform, download=download)
val_dataset   = DermaMNIST(split='val',   transform=val_transform, download=download)
test_dataset  = DermaMNIST(split='test',  transform=test_transform, download=download)

train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True) # type: ignore
train_loader_at_eval = data.DataLoader(dataset=train_dataset_at_eval, batch_size=2*BATCH_SIZE, shuffle=False) # type: ignore
val_loader = data.DataLoader(dataset=val_dataset, batch_size=2*BATCH_SIZE, shuffle=False) # type: ignore
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False) # type: ignore

In [22]:
# Inverse-frequency class weights to counter the class imbalance of the training split.
# Labels are read straight from `train_dataset.labels` instead of iterating the dataset,
# which would decode and randomly augment every training image just to read its label.
raw_train_labels = np.asarray(train_dataset.labels)
if raw_train_labels.ndim > 1 and raw_train_labels.shape[1] != 1:
    raise ValueError("Class weights assume a single integer label per sample.")
train_labels = raw_train_labels.reshape(-1).astype(np.int64)

label_counts = Counter(train_labels.tolist())

# Take the class count from the dataset metadata so a class absent from the training
# split still gets a slot and the weight vector matches the model's output size.
num_classes = n_classes
counts = np.bincount(train_labels, minlength=num_classes).astype(np.float32)

# w_c proportional to 1 / count_c, rescaled so the weights average to 1. Clamping counts to >= 1
# keeps an absent class from receiving a huge weight that would swamp all the others.
class_weights = 1.0 / np.maximum(counts, 1.0)
class_weights = class_weights / class_weights.sum() * num_classes
class_weights = torch.tensor(class_weights, dtype=torch.float32)

In [26]:
pretrained_model = resnet18(weights='IMAGENET1K_V1')

# Partial fine-tuning: freeze every pretrained parameter, then re-enable gradients
# for the two deepest residual stages (layer3, layer4). The stem, layer1 and layer2
# stay frozen.
for param in pretrained_model.parameters():
    param.requires_grad = False

for stage in (pretrained_model.layer3, pretrained_model.layer4):
    for param in stage.parameters():
        param.requires_grad = True

# Replace the ImageNet head with an n_classes-way linear layer
# (torchvision's ResNet names its head `.fc`, not `.classifier`).
pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, n_classes)

# A freshly built nn.Linear is already trainable; set it explicitly for clarity.
for param in pretrained_model.fc.parameters():
    param.requires_grad = True

model = pretrained_model

# Loss, weighted per class with the inverse-frequency weights computed earlier.
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss(weight=class_weights)
else:
    criterion = nn.CrossEntropyLoss(weight=class_weights)

# Give the optimizer only the parameters that are actually being trained.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9)

# Halve the learning rate when the validation loss has not improved for 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)


In [27]:
# Training loop with validation, LR scheduling, best-checkpointing and early stopping.
store_train = []  # mean training loss per epoch
store_val = []    # mean validation loss per epoch

patience = 3              # epochs without val-loss improvement before stopping
best_val = float('inf')
patience_ctr = 0
best_path = "best_resnet18.pth"
best_epoch = None         # 1-based epoch whose weights are stored in best_path


def _loss_targets(targets):
    """Convert a batch of labels from the DataLoader into the form `criterion` expects."""
    if task == 'multi-label, binary-class':
        return targets.to(torch.float32)
    # Flatten to 1-D class indices; unlike squeeze(), a batch of size 1 stays 1-D.
    return targets.reshape(-1).long()


for epoch in range(NUM_EPOCHS):
    running_train_loss = 0.0
    running_train_count = 0
    running_val_loss = 0.0
    running_val_count = 0

    model.train()
    for inputs, targets in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, _loss_targets(targets))
        loss.backward()
        optimizer.step()
        bs = inputs.size(0)
        running_train_loss += loss.item() * bs  # weight by batch size for a per-sample mean
        running_train_count += bs

    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            outputs = model(inputs)
            loss_val = criterion(outputs, _loss_targets(targets))
            bs = inputs.size(0)
            running_val_loss += loss_val.item() * bs
            running_val_count += bs

    epoch_train_loss = running_train_loss / running_train_count
    epoch_val_loss = running_val_loss / running_val_count

    scheduler.step(epoch_val_loss)

    store_train.append(epoch_train_loss)
    store_val.append(epoch_val_loss)

    # Report before the early-stopping check so the final epoch is always logged.
    print('Epoch [{}/{}], Loss: {:.4f}, Val Loss: {:.4f}'
          .format(epoch+1, NUM_EPOCHS, epoch_train_loss, epoch_val_loss))

    if epoch_val_loss < best_val:
        best_val = epoch_val_loss
        best_epoch = epoch + 1
        patience_ctr = 0
        torch.save(model.state_dict(), best_path)  # keep best model
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

# The loop leaves `model` holding the *last* epoch's weights. Restore the best
# checkpoint so downstream evaluation uses the lowest-validation-loss model.
if best_epoch is not None:
    model.load_state_dict(torch.load(best_path))
    print(f"Restored best weights from epoch {best_epoch} (val loss {best_val:.4f})")


100%|██████████| 55/55 [00:30<00:00,  1.79it/s]
100%|██████████| 4/4 [00:00<00:00,  6.51it/s]


Epoch [1/12], Loss: 1.7642, Val Loss: 1.7513


100%|██████████| 55/55 [00:31<00:00,  1.73it/s]
100%|██████████| 4/4 [00:00<00:00,  6.12it/s]


Epoch [2/12], Loss: 1.4522, Val Loss: 1.3216


100%|██████████| 55/55 [00:31<00:00,  1.75it/s]
100%|██████████| 4/4 [00:00<00:00,  4.91it/s]


Epoch [3/12], Loss: 1.3326, Val Loss: 1.4204


100%|██████████| 55/55 [00:32<00:00,  1.69it/s]
100%|██████████| 4/4 [00:00<00:00,  5.33it/s]


Epoch [4/12], Loss: 1.2426, Val Loss: 1.2969


100%|██████████| 55/55 [00:32<00:00,  1.69it/s]
100%|██████████| 4/4 [00:00<00:00,  5.12it/s]


Epoch [5/12], Loss: 1.1536, Val Loss: 1.2632


100%|██████████| 55/55 [00:31<00:00,  1.73it/s]
100%|██████████| 4/4 [00:00<00:00,  5.03it/s]


Epoch [6/12], Loss: 1.1215, Val Loss: 1.2660


100%|██████████| 55/55 [00:32<00:00,  1.67it/s]
100%|██████████| 4/4 [00:00<00:00,  6.86it/s]


Epoch [7/12], Loss: 1.0839, Val Loss: 1.2076


100%|██████████| 55/55 [00:32<00:00,  1.69it/s]
100%|██████████| 4/4 [00:00<00:00,  4.89it/s]


Epoch [8/12], Loss: 1.0208, Val Loss: 1.2603


100%|██████████| 55/55 [00:31<00:00,  1.73it/s]
100%|██████████| 4/4 [00:00<00:00,  6.76it/s]


Epoch [9/12], Loss: 1.0094, Val Loss: 1.2176


100%|██████████| 55/55 [00:32<00:00,  1.70it/s]
100%|██████████| 4/4 [00:00<00:00,  4.83it/s]


Epoch [10/12], Loss: 1.0074, Val Loss: 1.2018


100%|██████████| 55/55 [00:33<00:00,  1.66it/s]
100%|██████████| 4/4 [00:00<00:00,  5.41it/s]


Epoch [11/12], Loss: 0.9638, Val Loss: 1.1879


100%|██████████| 55/55 [00:33<00:00,  1.65it/s]
100%|██████████| 4/4 [00:00<00:00,  4.43it/s]

Epoch [12/12], Loss: 0.9453, Val Loss: 1.1354





In [None]:
def test(split):
    """Evaluate `model` on one dataset split and print its AUC and accuracy.

    Args:
        split: 'train', 'val' or 'test'. The train split is read through
            `train_loader_at_eval`, which uses the non-augmenting `val_transform`.

    Returns:
        The metrics returned by medmnist's `Evaluator.evaluate`, printed as (AUC, accuracy).

    Raises:
        ValueError: if `split` is not one of the names above.
    """
    loaders = {
        'train': train_loader_at_eval,
        'val': val_loader,
        'test': test_loader,
    }
    if split not in loaders:
        raise ValueError(f"unknown split {split!r}; expected one of {sorted(loaders)}")

    model.eval()
    batch_scores = []
    with torch.no_grad():
        for inputs, _ in loaders[split]:
            logits = model(inputs)
            if task == 'multi-label, binary-class':
                # Labels are independent: one probability per label.
                batch_scores.append(torch.sigmoid(logits))
            else:
                # Classes are mutually exclusive: a distribution over classes.
                batch_scores.append(logits.softmax(dim=-1))

    # Concatenate once instead of growing a tensor inside the loop.
    y_score = torch.cat(batch_scores, dim=0).numpy()

    # Only the scores are passed in; the Evaluator is constructed for (data_flag, split)
    # and compares them with that split's labels itself.
    evaluator = Evaluator(data_flag, split)
    metrics = evaluator.evaluate(y_score)

    print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))
    return metrics


print('==> Evaluating ...')
train_metrics = test('train')
test_metrics = test('test')

==> Evaluating ...
train  auc: 0.916  acc:0.594
test  auc: 0.865  acc:0.549


In [None]:
# Fine-tuning notes:
# - Learning-rate scheduling: helps improve accuracy.
# - Increase the number of epochs: not feasible without a GPU.
# - Decrease the batch size: same limitation (no GPU).
# - Use a better backbone: same limitation (no GPU).
# - Data augmentation transforms: done.
# - Better optimizer: still need to compare SGD against AdamW.
# - Early stopping: extremely useful, since we don't know the right number of epochs in advance.
# - Class weights: applied to handle the class imbalance in this dataset.
# - Better comparisons.
# - Freeze more layers: not needed; freezing one block is enough.