In [None]:
%pip install tensorflow 

In [None]:
from torchvision.models import efficientnet_b0# type: ignore
import medmnist
from medmnist import DermaMNIST
from medmnist import INFO, Evaluator
from matplotlib import transforms
from torchvision import transforms
import torch.utils.data as data
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [None]:
# Experiment configuration
data_flag = "dermamnist"   # key used to look up the dataset entry in medmnist.INFO
download = False           # forwarded as the `download` flag of the dataset constructor

# Optimisation hyper-parameters
NUM_EPOCHS, BATCH_SIZE = 3, 128
lr = 1e-3

# Metadata for the selected dataset, read from the MedMNIST registry
info = INFO[data_flag]
task, n_channels = info["task"], info["n_channels"]
n_classes = len(info["label"])

# Dataset class whose name the registry entry stores under 'python_class'
DataClass = getattr(medmnist, info["python_class"])

In [None]:
# Per-channel statistics for the Normalize step (the values commonly paired
# with ImageNet-pretrained torchvision models).
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Preprocessing: make sure the image has 3 channels, resize to 224x224,
# convert to a tensor, then normalise each channel.
data_transform = transforms.Compose([
    transforms.Lambda(lambda img: img if img.mode != 'L' else img.convert('RGB')),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Untransformed training split; this is the only constructor call that
# honours the `download` setting.
pil_dataset = DermaMNIST(split='train', download=download)

# Transformed splits used for training and evaluation.
_split_kwargs = dict(transform=data_transform, download=False)
train_dataset = DermaMNIST(split='train', **_split_kwargs)
val_dataset = DermaMNIST(split='val', **_split_kwargs)
test_dataset = DermaMNIST(split='test', **_split_kwargs)

# One shuffled loader for optimisation; the remaining loaders keep dataset
# order and use a doubled batch size.
_eval_kwargs = dict(batch_size=2 * BATCH_SIZE, shuffle=False)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)  # type: ignore
train_loader_at_eval = data.DataLoader(dataset=train_dataset, **_eval_kwargs)  # type: ignore
val_loader = data.DataLoader(dataset=val_dataset, **_eval_kwargs)  # type: ignore
test_loader = data.DataLoader(dataset=test_dataset, **_eval_kwargs)  # type: ignore

In [None]:
# Load EfficientNet-B0 with pretrained weights and put it in inference mode.
pretrained_model = efficientnet_b0(pretrained=True)
pretrained_model.eval()

# Freeze every parameter of the loaded network.
pretrained_model.requires_grad_(False)

# Build a replacement for the last classifier layer with one output per class;
# it is the only part left trainable.
num_features = pretrained_model.classifier[1].in_features
_new_head = nn.Linear(num_features, n_classes)
_new_head.requires_grad_(True)
pretrained_model.classifier[1] = _new_head

# Alias used by the training and evaluation cells.
model = pretrained_model

In [None]:
# `pretrained_model` already received its n_classes-way head in the model
# construction cell. It is deliberately NOT rebuilt here: doing so would
# re-initialise the head every time this cell is re-run (e.g. to change the
# learning rate) and silently throw away any weights learned so far.
model = pretrained_model

# Loss: independent per-label logits for multi-label tasks, otherwise
# standard cross-entropy over the class logits.
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

# Only hand the optimiser the parameters that are still trainable; the frozen
# ones never receive gradients, so tracking them is pointless.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9)

In [None]:
# Training loop (same as CNN.ipynb)
store_train = []  # mean training loss of each epoch
store_val = []    # mean validation loss of each epoch

is_multi_label = task == 'multi-label, binary-class'


def _prepare_targets(targets):
    """Cast one batch of labels to the form the active criterion expects.

    Multi-label targets are converted to float32. Otherwise the labels are
    flattened to a 1-D int64 vector (assumes they arrive as (batch, 1) --
    TODO confirm for other datasets).
    """
    if is_multi_label:
        return targets.to(torch.float32)
    # reshape(-1) instead of squeeze(): squeeze() also removes the batch axis
    # of a single-sample batch, yielding a 0-dim target that the loss rejects.
    return targets.reshape(-1).long()


for epoch in range(NUM_EPOCHS):
    # Loss sums are weighted by batch size so the epoch mean is per sample.
    running_train_loss = 0.0
    running_train_count = 0
    running_val_loss = 0.0
    running_val_count = 0

    # NOTE(review): train() switches every submodule to training mode,
    # including those whose parameters are frozen -- confirm this is intended.
    model.train()
    for inputs, targets in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, _prepare_targets(targets))

        loss.backward()
        optimizer.step()
        bs = inputs.size(0)
        running_train_loss += loss.item() * bs
        running_train_count += bs

    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            outputs = model(inputs)
            loss_val = criterion(outputs, _prepare_targets(targets))
            bs = inputs.size(0)
            running_val_loss += loss_val.item() * bs
            running_val_count += bs

    epoch_train_loss = running_train_loss / running_train_count
    epoch_val_loss = running_val_loss / running_val_count

    store_train.append(epoch_train_loss)
    store_val.append(epoch_val_loss)

    print('Epoch [{}/{}], Loss: {:.4f}, Val Loss: {:.4f}'
          .format(epoch+1, NUM_EPOCHS, epoch_train_loss, epoch_val_loss))


In [None]:
def test(split):
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])
    
    data_loader = test_loader if split == 'test' else train_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.numpy()
        y_score = y_score.detach().numpy()
        
        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)
    
        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))

# Report metrics for the training split first, then for the test split.
print('==> Evaluating ...')
for split_name in ('train', 'test'):
    test(split_name)