In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import medmnist
from medmnist import INFO
from tqdm import tqdm
from cnn import CNN
from resnet import ResNet18
from vit import VisionTransformer
from training_evalutation_utils import * 
from plotting import * 

In [7]:
# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
# Dataset configuration: every value below is read from MedMNIST's INFO entry
# for the chosen dataset flag.
data_flag = 'tissuemnist'
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
class_names = list(info['label'].values())
n_classes = len(info['label'])

print(f"\nDataset: {data_flag}")
print(f"Task: {task}")
print(f"Input channels: {n_channels}")
print(f"Number of classes: {n_classes}")
print(f"Classes: {info['label']}")

# Resolve the dataset class by the name recorded in INFO.
DataClass = getattr(medmnist, info['python_class'])

# Preprocessing for models trained from scratch: tensor conversion followed by
# single-channel normalization (mean 0.5, std 0.5).
transform_scratch = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])

# Preprocessing for transfer learning: replicate the grayscale image to three
# channels, resize to 224, and normalize with the ImageNet channel statistics.
transform_transfer = T.Compose([
    T.Grayscale(num_output_channels=3),
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Train/val/test splits for the from-scratch models (64x64 variant of the data).
train_dataset_scratch, val_dataset_scratch, test_dataset_scratch = (
    DataClass(split=split, transform=transform_scratch, download=True, size=64)
    for split in ('train', 'val', 'test')
)

# The same three splits, preprocessed for transfer learning.
train_dataset_transfer, val_dataset_transfer, test_dataset_transfer = (
    DataClass(split=split, transform=transform_transfer, download=True, size=64)
    for split in ('train', 'val', 'test')
)


Dataset: tissuemnist
Task: multi-class
Input channels: 1
Number of classes: 8
Classes: {'0': 'Collecting Duct, Connecting Tubule', '1': 'Distal Convoluted Tubule', '2': 'Glomerular endothelial cells', '3': 'Interstitial endothelial cells', '4': 'Leukocytes', '5': 'Podocytes', '6': 'Proximal Tubule Segments', '7': 'Thick Ascending Limb'}


In [9]:
# Create data loaders: only the training loaders shuffle; validation and test
# loaders iterate in dataset order.
batch_size = 32

train_loader_scratch, val_loader_scratch, test_loader_scratch = (
    DataLoader(dataset, batch_size=batch_size, shuffle=reshuffle)
    for dataset, reshuffle in (
        (train_dataset_scratch, True),
        (val_dataset_scratch, False),
        (test_dataset_scratch, False),
    )
)

train_loader_transfer, val_loader_transfer, test_loader_transfer = (
    DataLoader(dataset, batch_size=batch_size, shuffle=reshuffle)
    for dataset, reshuffle in (
        (train_dataset_transfer, True),
        (val_dataset_transfer, False),
        (test_dataset_transfer, False),
    )
)

# Split sizes are the same for both preprocessing variants, so report them once.
print(f"\nTraining samples: {len(train_dataset_scratch)}")
print(f"Validation samples: {len(val_dataset_scratch)}")
print(f"Test samples: {len(test_dataset_scratch)}")


Training samples: 165466
Validation samples: 23640
Test samples: 47280


In [None]:
# Training hyperparameters shared by every model trained below
epochs, learning_rate = 50, 1e-4
# Loss function: cross-entropy, used for all models
criterion = nn.CrossEntropyLoss()

In [None]:
print("\n" + "=" * 70)
print("PART 1: TRAINING FROM SCRATCH")
print("=" * 70)

# ---------------------- Custom CNN ----------------------
print("\n### Model 1: Custom CNN ###")

# Input channels and class count both come from the dataset INFO read in the
# configuration cell, so changing `data_flag` cannot leave a stale hard-coded
# channel count here.
cnn_model = CNN(
    in_channels=n_channels,
    hidden_units1=128,
    hidden_units2=256,
    output_shape=n_classes,
).to(device)
cnn_optimizer = torch.optim.AdamW(cnn_model.parameters(), lr=learning_rate, weight_decay=1e-4)

# Train with early stopping (patience 5); the best model is also written to
# cnn_best.pth. fit() returns three per-epoch histories, unpacked below.
train_losses_cnn, val_losses_cnn, val_accs_cnn = fit(
    cnn_model, train_loader_scratch, val_loader_scratch,
    criterion, cnn_optimizer, epochs, device,
    early_stopping=True, patience=5, model_name="cnn",
    save_best=True, checkpoint_path="cnn_best.pth",
    use_mixed_precision=True
)

# Evaluate CNN on the held-out test split.
# NOTE(review): this evaluates `cnn_model` as fit() left it. Confirm that fit()
# restores the best checkpoint before returning; otherwise these are last-epoch
# metrics rather than best-validation metrics.
cnn_acc, cnn_preds, cnn_labels, cnn_probs = evaluate(cnn_model, test_loader_scratch, device)
print(f"\n✓ Custom CNN Test Accuracy: {cnn_acc:.4f}")
plot_training_history(train_losses_cnn, val_losses_cnn, val_accs_cnn, "Custom CNN")
plot_confusion_matrix(cnn_labels, cnn_preds, class_names, "Custom CNN")
# The value returned by plot_roc_curves is kept for later use (presumably an AUC
# summary; verify against plotting.py).
cnn_auc = plot_roc_curves(cnn_labels, cnn_probs, n_classes, class_names, "Custom CNN")

In [None]:
# ---------------------- ResNet-18 ----------------------
print("\n### Model 2: ResNet-18 ###")
# Input channels and class count both come from the dataset INFO read in the
# configuration cell rather than being hard-coded for one dataset.
resnet_model = ResNet18(in_channels=n_channels, num_classes=n_classes).to(device)
resnet_optimizer = torch.optim.AdamW(resnet_model.parameters(), lr=learning_rate,
                                     weight_decay=1e-3, amsgrad=True)

# Train with early stopping (patience 5); the best model is also written to
# resnet18_best.pth. fit() returns three per-epoch histories, unpacked below.
train_losses_resnet, val_losses_resnet, val_accs_resnet = fit(
    resnet_model, train_loader_scratch, val_loader_scratch,
    criterion, resnet_optimizer, epochs, device,
    early_stopping=True, patience=5, model_name="ResNet18",
    save_best=True, checkpoint_path="resnet18_best.pth",
    use_mixed_precision=True
)

# Evaluate ResNet on the held-out test split.
# NOTE(review): this evaluates `resnet_model` as fit() left it. Confirm that
# fit() restores the best checkpoint before returning; otherwise these are
# last-epoch metrics rather than best-validation metrics.
resnet_acc, resnet_preds, resnet_labels, resnet_probs = evaluate(resnet_model, test_loader_scratch, device)
print(f"\n✓ ResNet-18 Test Accuracy: {resnet_acc:.4f}")

plot_training_history(train_losses_resnet, val_losses_resnet, val_accs_resnet, "ResNet-18")
plot_confusion_matrix(resnet_labels, resnet_preds, class_names, "ResNet-18")
# The value returned by plot_roc_curves is kept for later use (presumably an AUC
# summary; verify against plotting.py).
resnet_auc = plot_roc_curves(resnet_labels, resnet_probs, n_classes, class_names, "ResNet-18")

In [None]:
# ---------------------- ViT ----------------------
print("\n### Model 3: ViT ###")
# in_ch and num_classes are taken from the dataset INFO read in the
# configuration cell, matching how the CNN and ResNet cells pick up n_classes.
# img_size must stay equal to the size=64 the datasets were loaded with.
vit = VisionTransformer(
    img_size=64,
    patch_size=8,
    in_ch=n_channels,
    num_classes=n_classes,
    embed_dim=256,
    depth=8,
    num_heads=8,
    mlp_ratio=4.0,
    dropout=0.25,
).to(device)
vit_optimizer = torch.optim.AdamW(vit.parameters(), lr=learning_rate,
                                  weight_decay=1e-2, amsgrad=True)

# Train with early stopping (patience 5); the best model is also written to
# ViT_best.pth. fit() returns three per-epoch histories, unpacked below.
train_losses_vit, val_losses_vit, val_accs_vit = fit(
    vit, train_loader_scratch, val_loader_scratch,
    criterion, vit_optimizer, epochs, device,
    early_stopping=True, patience=5, model_name="ViT",
    save_best=True, checkpoint_path="ViT_best.pth",
    use_mixed_precision=True
)
