In [7]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_processing import get_loaders, class_cols
from utils import get_device
from resnet import Bottleneck, ResNet, ResNet50
from tqdm import tqdm

In [2]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

In [3]:
train_loader, val_loader, test_loader, train_df, val_df, test_df = get_loaders(
    image_size=(384, 384),
    num_workers=0,
)
class_cols_order = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
num_classes = len(class_cols_order)

In [4]:
# Check device
get_device()

device(type='mps')

In [5]:
net = ResNet50(num_classes).to('mps')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.1, patience=5)

In [9]:
EPOCHS = 10
for epoch in range(EPOCHS):
    print(f'Starting epoch {epoch+1}/{EPOCHS}')
    losses = []
    running_loss = 0
    for i, inp in tqdm(enumerate(train_loader)):

        inputs, labels = inp
        inputs, labels = inputs.to('mps'), labels.to('mps')
        optimizer.zero_grad()
    
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if i%100 == 0 and i > 0:
            print(f'Loss [{epoch+1}, {i}](epoch, minibatch): ', running_loss / 100)
            running_loss = 0.0

    avg_loss = sum(losses)/len(losses)
    scheduler.step(avg_loss)
            
print('Training Done')

Starting epoch 1/10


101it [00:52,  1.94it/s]

Loss [1, 100](epoch, minibatch):  2.3411827969551084


201it [01:43,  1.94it/s]

Loss [1, 200](epoch, minibatch):  1.1834919220209121


301it [02:36,  1.89it/s]

Loss [1, 300](epoch, minibatch):  1.1410216993093492


401it [03:28,  1.84it/s]

Loss [1, 400](epoch, minibatch):  1.1110780653357506


501it [04:21,  1.93it/s]

Loss [1, 500](epoch, minibatch):  1.073774127960205


601it [05:13,  1.92it/s]

Loss [1, 600](epoch, minibatch):  1.103025357723236


626it [05:27,  1.91it/s]


Starting epoch 2/10


101it [00:53,  1.91it/s]

Loss [2, 100](epoch, minibatch):  1.103046836256981


201it [01:45,  1.84it/s]

Loss [2, 200](epoch, minibatch):  1.0907086592912674


301it [02:38,  1.90it/s]

Loss [2, 300](epoch, minibatch):  1.0482691663503647


401it [03:31,  1.88it/s]

Loss [2, 400](epoch, minibatch):  1.0657124829292297


501it [04:24,  1.93it/s]

Loss [2, 500](epoch, minibatch):  0.9685556423664093


601it [05:16,  1.92it/s]

Loss [2, 600](epoch, minibatch):  1.0343643087148666


626it [05:29,  1.90it/s]


Starting epoch 3/10


101it [00:52,  1.93it/s]

Loss [3, 100](epoch, minibatch):  1.0378424137830735


201it [01:44,  1.92it/s]

Loss [3, 200](epoch, minibatch):  1.0548190522193908


301it [02:36,  1.89it/s]

Loss [3, 300](epoch, minibatch):  1.0375571748614312


401it [03:28,  1.92it/s]

Loss [3, 400](epoch, minibatch):  1.014284301996231


501it [04:20,  1.90it/s]

Loss [3, 500](epoch, minibatch):  1.019305123090744


601it [05:13,  1.91it/s]

Loss [3, 600](epoch, minibatch):  1.011924531161785


626it [05:26,  1.91it/s]


Starting epoch 4/10


101it [00:52,  1.89it/s]

Loss [4, 100](epoch, minibatch):  1.0487035429477691


201it [01:45,  1.91it/s]

Loss [4, 200](epoch, minibatch):  1.024999424815178


301it [02:38,  1.87it/s]

Loss [4, 300](epoch, minibatch):  0.9726478189229966


401it [03:31,  1.93it/s]

Loss [4, 400](epoch, minibatch):  1.0098513227701187


501it [04:24,  1.90it/s]

Loss [4, 500](epoch, minibatch):  0.9952134004235268


601it [05:16,  1.93it/s]

Loss [4, 600](epoch, minibatch):  1.0443688857555389


626it [05:29,  1.90it/s]


Starting epoch 5/10


101it [00:52,  1.92it/s]

Loss [5, 100](epoch, minibatch):  1.041561549603939


201it [01:44,  1.92it/s]

Loss [5, 200](epoch, minibatch):  0.9966088211536408


301it [02:37,  1.92it/s]

Loss [5, 300](epoch, minibatch):  0.9914626893401146


401it [03:29,  1.90it/s]

Loss [5, 400](epoch, minibatch):  0.9882300728559494


501it [04:21,  1.93it/s]

Loss [5, 500](epoch, minibatch):  0.9418884354829788


601it [05:13,  1.93it/s]

Loss [5, 600](epoch, minibatch):  0.9887402927875519


626it [05:25,  1.92it/s]


Starting epoch 6/10


101it [00:52,  1.92it/s]

Loss [6, 100](epoch, minibatch):  1.0561358019709588


201it [01:44,  1.92it/s]

Loss [6, 200](epoch, minibatch):  0.9713498073816299


301it [02:36,  1.93it/s]

Loss [6, 300](epoch, minibatch):  0.9454820850491523


401it [03:28,  1.93it/s]

Loss [6, 400](epoch, minibatch):  0.9599829000234604


501it [04:21,  1.85it/s]

Loss [6, 500](epoch, minibatch):  0.9215902832150459


601it [05:13,  1.93it/s]

Loss [6, 600](epoch, minibatch):  0.9634536477923393


626it [05:26,  1.92it/s]


Starting epoch 7/10


101it [00:52,  1.92it/s]

Loss [7, 100](epoch, minibatch):  0.9240616008639335


201it [01:44,  1.94it/s]

Loss [7, 200](epoch, minibatch):  0.9582858249545098


301it [02:36,  1.93it/s]

Loss [7, 300](epoch, minibatch):  0.9539153555035591


401it [03:28,  1.91it/s]

Loss [7, 400](epoch, minibatch):  0.8753782993555069


501it [04:21,  1.91it/s]

Loss [7, 500](epoch, minibatch):  0.9140063500404358


601it [05:14,  1.89it/s]

Loss [7, 600](epoch, minibatch):  0.9158892267942429


626it [05:27,  1.91it/s]


Starting epoch 8/10


101it [00:52,  1.92it/s]

Loss [8, 100](epoch, minibatch):  0.922256760597229


201it [01:44,  1.92it/s]

Loss [8, 200](epoch, minibatch):  0.9079651689529419


301it [02:37,  1.91it/s]

Loss [8, 300](epoch, minibatch):  0.9050401133298874


401it [03:30,  1.90it/s]

Loss [8, 400](epoch, minibatch):  0.9495175546407699


501it [04:22,  1.93it/s]

Loss [8, 500](epoch, minibatch):  0.8896061384677887


601it [05:14,  1.91it/s]

Loss [8, 600](epoch, minibatch):  0.9129515400528908


626it [05:27,  1.91it/s]


Starting epoch 9/10


101it [00:53,  1.90it/s]

Loss [9, 100](epoch, minibatch):  0.9322891533374786


201it [01:46,  1.90it/s]

Loss [9, 200](epoch, minibatch):  0.8697421339154243


301it [02:39,  1.90it/s]

Loss [9, 300](epoch, minibatch):  0.8680654579401016


401it [03:32,  1.90it/s]

Loss [9, 400](epoch, minibatch):  0.8969259896874427


501it [04:25,  1.91it/s]

Loss [9, 500](epoch, minibatch):  0.8884210386872291


601it [05:17,  1.89it/s]

Loss [9, 600](epoch, minibatch):  0.9061165529489518


626it [05:31,  1.89it/s]


Starting epoch 10/10


101it [00:53,  1.90it/s]

Loss [10, 100](epoch, minibatch):  0.8836296138167381


201it [01:46,  1.89it/s]

Loss [10, 200](epoch, minibatch):  0.9070937097072601


301it [02:39,  1.88it/s]

Loss [10, 300](epoch, minibatch):  0.8661537957191467


401it [03:32,  1.82it/s]

Loss [10, 400](epoch, minibatch):  0.891879109442234


501it [04:26,  1.91it/s]

Loss [10, 500](epoch, minibatch):  0.9014206299185753


601it [05:18,  1.84it/s]

Loss [10, 600](epoch, minibatch):  0.9160161578655243


626it [05:32,  1.88it/s]

Training Done





In [11]:
# Compute balanced accuracy (mean per-class recall)
import torch
from tqdm import tqdm
import numpy as np

# choose device robustly
device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu')
print('Evaluation device:', device)

# Make sure model is on device and in eval mode
net.to(device)
net.eval()

# Prepare per-class counters on CPU to avoid backend-specific tensor issues
num_classes = int(num_classes)  # ensure it's a plain int
tp = torch.zeros(num_classes, dtype=torch.long)
support = torch.zeros(num_classes, dtype=torch.long)

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = net(inputs)
        preds = outputs.argmax(dim=1)

        # Move to CPU for reliable counting
        preds_cpu = preds.cpu()
        labels_cpu = labels.cpu()

        # accumulate true positives and supports
        for i in range(labels_cpu.size(0)):
            lbl = int(labels_cpu[i].item())
            pr = int(preds_cpu[i].item())
            support[lbl] += 1
            if pr == lbl:
                tp[lbl] += 1

# compute per-class recall and balanced accuracy
recalls = torch.zeros(num_classes, dtype=torch.float)
mask = support > 0
if mask.any():
    recalls[mask] = tp[mask].float() / support[mask].float()
    balanced_acc = recalls[mask].mean().item()
else:
    balanced_acc = 0.0

# report
for c in range(num_classes):
    print(f'Class {c}: support={int(support[c].item()):4d}  recall={recalls[c].item():.4f}')

balanced_error = 1.0 - balanced_acc
print(f'Balanced accuracy: {balanced_acc:.4f} | Balanced error: {balanced_error:.4f}')


Evaluation device: mps


Testing: 100%|██████████| 95/95 [00:19<00:00,  4.94it/s]

Class 0: support= 171  recall=0.1345
Class 1: support= 909  recall=0.9835
Class 2: support=  93  recall=0.0430
Class 3: support=  43  recall=0.2326
Class 4: support= 217  recall=0.1429
Class 5: support=  44  recall=0.0000
Class 6: support=  35  recall=0.0286
Balanced accuracy: 0.2236 | Balanced error: 0.7764



