# EC2 Training Tutorial: Medical Image Classification

Train DenseNet121 or ViT models on medical images using MONAI.

## Step 1: Setup

Import libraries and configure logging. Check GPU availability.

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import monai
from monai.data import DataLoader
from monai.networks.nets import DenseNet121, ViT
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, Resize, ScaleIntensity
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
import logging
import os
import sys
import datetime
from PIL import ImageFile

# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Logging: every run gets its own timestamped folder under ./logs, and each
# record is written both to app.log in that folder and to stdout.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(os.getcwd(), 'logs', f'log_{timestamp}')
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(log_dir, 'app.log')),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 2: Configure Training

Set data paths and hyperparameters. Choose DenseNet121 or ViT architecture.

In [None]:
# Paths -- change these to match your environment.
data_dir = '/home/ubuntu/data/vindr-spinexr-subset'   # root holding train/valid/test
model_dir = '/home/ubuntu/data/spine-model'           # best-model checkpoints
output_dir = '/home/ubuntu/data/spine-output'         # final weights + TensorBoard logs

# Hyperparameters
model_name = 'DenseNet121'  # or 'ViT'
learning_rate = 1e-3
batch_size = 32
num_epochs = 3
# Validation runs on epochs where `epoch % val_interval == 0`.
# NOTE(review): with num_epochs = 3 and val_interval = 10 only epoch 0 is validated.
val_interval = 10
# Compared against the count of validations that did not improve the best loss.
early_stopping_rounds = 10
# Target size handed to the Resize transform; the trailing 1 adds a singleton axis.
img_size = (256, 256, 1)

# Make sure every directory written to later already exists.
for _required_dir in (os.path.join(model_dir, model_name), output_dir):
    os.makedirs(_required_dir, exist_ok=True)

## Step 3: Define Model

Create ModelDef class to instantiate DenseNet121 or ViT from MONAI.

In [None]:
class ModelDef:
    """Small factory that builds one of the supported MONAI classifiers by name.

    The network is not created at construction time; call ``get_model`` to
    instantiate it.
    """

    def __init__(self, num_classes, model_name):
        # Size of the classification output.
        self.num_classes = num_classes
        # Architecture selector: "DenseNet121" or "ViT".
        self.model_name = model_name

    def get_model(self):
        """Instantiate and return the network selected by ``self.model_name``.

        Raises:
            ValueError: if ``self.model_name`` is neither "DenseNet121" nor "ViT".
        """
        selected = self.model_name

        if selected == "DenseNet121":
            # 2-D, single-channel DenseNet.
            return DenseNet121(
                spatial_dims=2,
                in_channels=1,
                out_channels=self.num_classes,
            )

        if selected == "ViT":
            # Single-channel ViT in classification mode, fixed 256x256x1 input
            # split into 16x16x1 patches.
            vit_config = {
                "in_channels": 1,
                "img_size": (256, 256, 1),
                "patch_size": (16, 16, 1),
                "hidden_size": 768,
                "mlp_dim": 3072,
                "num_layers": 12,
                "num_heads": 12,
                "classification": True,
                "num_classes": self.num_classes,
            }
            return ViT(**vit_config)

        raise ValueError(f"Model {self.model_name} not supported.")

## Step 4: Create Dataset & Loaders

Build custom dataset from label_dict.json. Apply MONAI transforms and create data loaders.

In [None]:
class MedNISTDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs image file entries with their labels.

    ``transforms`` is a callable applied to the stored image entry each time
    an item is fetched; nothing is loaded or cached up front.
    """

    def __init__(self, image_files, labels, transforms):
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        # One sample per entry in image_files.
        return len(self.image_files)

    def __getitem__(self, index):
        """Return ``(transforms(image_files[index]), labels[index])``."""
        sample = self.transforms(self.image_files[index])
        target = self.labels[index]
        return sample, target


def generate_list_labels(data_folder, label_dict_path="../../ec2/label_dict.json"):
    """Collect shuffled (image path, label) pairs for one dataset split.

    The label dictionary maps sub-folder names to label values. Every entry
    found inside ``data_folder/<name>`` is paired with that name's label.
    Names without a matching directory are skipped.

    Args:
        data_folder: Directory that contains one sub-folder per label name.
        label_dict_path: JSON file holding the name -> label mapping. The
            default is the relative path the tutorial has always used, so it
            is resolved against the current working directory.

    Returns:
        ``(image_list, label_list, num_classes)`` where the first two are
        equally long tuples in matching (shuffled) order and ``num_classes``
        is the number of entries in the label dictionary. Both tuples are
        empty when no images are found.
    """
    with open(label_dict_path, "r", encoding="utf-8") as fid:
        label_dict = json.load(fid)

    pairs = []
    for label_name, label_value in label_dict.items():
        label_data = os.path.join(data_folder, label_name)
        # isdir (not exists) so a stray file named like a label is skipped
        # instead of crashing os.listdir.
        if not os.path.isdir(label_data):
            continue
        pairs.extend(
            (os.path.join(label_data, file_name), label_value)
            for file_name in os.listdir(label_data)
        )

    num_classes = len(label_dict)
    if not pairs:
        # zip(*[]) yields nothing to unpack, so return empty tuples explicitly.
        return (), (), num_classes

    # Shuffle paths and labels together so they stay aligned.
    random.shuffle(pairs)
    image_list, label_list = zip(*pairs)
    return image_list, label_list, num_classes


def create_data_loaders(data, batch_size, num_workers=16):
    """Build the train/test/validation loaders for the dataset rooted at ``data``.

    Each split is read from a sub-folder of ``data`` (``train``, ``valid``,
    ``test``) and shares one transform pipeline: load, channel-first,
    resize to the module-level ``img_size``, intensity scaling.

    Args:
        data: Root directory containing the three split folders.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per loader (defaults to the previously
            hard-coded value of 16).

    Returns:
        ``(train_loader, test_loader, val_loader, num_classes)``. Note the
        order: the test loader comes before the validation loader.
        ``num_classes`` is taken from the training split.
    """
    transforms = Compose([LoadImage(), EnsureChannelFirst(),
                          Resize(spatial_size=img_size), ScaleIntensity()])

    def _make_loader(split, shuffle):
        # One split folder -> (DataLoader, number of classes).
        images, labels, n_classes = generate_list_labels(os.path.join(data, split))
        dataset = MedNISTDataset(images, labels, transforms)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                            num_workers=num_workers)
        return loader, n_classes

    # Only training benefits from reshuffling each epoch. Evaluation metrics
    # are averaged over the full split, so a fixed order is enough.
    train_loader, num_classes = _make_loader('train', shuffle=True)
    val_loader, _ = _make_loader('valid', shuffle=False)
    test_loader, _ = _make_loader('test', shuffle=False)

    return train_loader, test_loader, val_loader, num_classes


# create_data_loaders returns the test loader *before* the validation loader,
# so the unpacking order below must stay as written.
train_loader, test_loader, validation_loader, num_classes = create_data_loaders(
    data=data_dir,
    batch_size=batch_size,
)
logger.info(f'Classes: {num_classes}, Train: {len(train_loader.dataset)}, Val: {len(validation_loader.dataset)}')

## Step 5: Define Training Functions

Implement test() for validation and train() with early stopping and TensorBoard logging.

In [None]:
def test(model, test_loader, criterion, device, model_name):
    """Evaluate ``model`` on ``test_loader`` and return ``(mean_loss, accuracy)``.

    The model is switched to eval mode for the duration of the call and then
    restored to whatever mode the caller left it in. This keeps mode-dependent
    layers (dropout, batch norm) from running in training mode during
    evaluation, even if the caller forgot to call ``model.eval()``.

    Args:
        model: Network to evaluate.
        test_loader: Loader yielding ``(inputs, labels)`` batches. Its
            ``dataset`` length is used as the averaging denominator.
        criterion: Loss function called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        model_name: ``'DenseNet121'`` selects the 2-D path; any other value
            takes the path that reads the first element of the model output.

    Returns:
        Tuple ``(total_loss, total_acc)``, both averaged over the dataset.
    """
    running_loss = 0.0
    running_corrects = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in tqdm(test_loader, desc="Validation"):
                inputs, labels = inputs.to(device), labels.to(device)

                if model_name == 'DenseNet121':
                    # Batches are 5-D here; take index 0 of the last axis to
                    # get the 4-D input the 2-D network expects.
                    logits = model(inputs[:, :, :, :, 0])
                else:
                    # This path expects the model to return a sequence whose
                    # first element holds the class logits.
                    logits = model(inputs)[0]

                loss = criterion(logits, labels)
                preds = logits.argmax(dim=1)

                # Weight by batch size so the final mean is per-sample.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
    finally:
        # Hand the model back in the mode it arrived in.
        model.train(was_training)

    total_loss = running_loss / len(test_loader.dataset)
    total_acc = running_corrects / len(test_loader.dataset)
    logger.info(f"Val Loss: {total_loss:.4f}, Val Acc: {total_acc:.4f}")
    return total_loss, total_acc


def train(model, train_loader, validation_loader, criterion, optimizer, device, 
          model_name, num_epochs, val_interval, early_stopping_rounds, model_dir, output_dir):
    """Train ``model`` with periodic validation, checkpointing and early stopping.

    On every epoch where ``epoch % val_interval == 0`` the model is scored on
    ``validation_loader``. When the validation loss improves, the weights are
    saved as ``best_model_<epoch>.pth`` under ``model_dir/model_name``.
    Training stops once ``early_stopping_rounds`` consecutive validations have
    failed to improve on the best loss. Scalars are written to TensorBoard
    under ``output_dir/model_name/logs``.

    Args:
        model: Network to optimise (updated in place).
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        validation_loader: Loader passed to ``test`` for validation.
        criterion: Loss function called as ``criterion(logits, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the batches are moved to.
        model_name: ``'DenseNet121'`` selects the 2-D path; any other value
            takes the path that reads the first element of the model output.
        num_epochs: Maximum number of epochs.
        val_interval: Validate on epochs divisible by this value.
        early_stopping_rounds: Allowed consecutive non-improving validations.
        model_dir: Root directory for best-model checkpoints.
        output_dir: Root directory for TensorBoard event files.

    Returns:
        The same ``model`` object after training.
    """
    best_loss = float('inf')
    loss_counter = 0  # consecutive validations without improvement
    writer = SummaryWriter(log_dir=os.path.join(output_dir, model_name, 'logs'))
    steps_per_epoch = len(train_loader)

    # try/finally so the event file is flushed and closed even if a step raises.
    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss, running_corrects, running_samples = 0.0, 0, 0
            print(f"\nEpoch {epoch}/{num_epochs}")

            for step, (inputs, labels) in enumerate(tqdm(train_loader)):
                inputs, labels = inputs.to(device), labels.to(device)
                global_step = epoch * steps_per_epoch + step

                if model_name == 'DenseNet121':
                    # Batches are 5-D here; take index 0 of the last axis to
                    # get the 4-D input the 2-D network expects.
                    logits = model(inputs[:, :, :, :, 0])
                else:
                    # This path expects the model to return a sequence whose
                    # first element holds the class logits.
                    logits = model(inputs)[0]

                loss = criterion(logits, labels)
                preds = logits.argmax(dim=1)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight by batch size so the epoch mean is per-sample.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
                running_samples += len(inputs)

                # Fires only when the cumulative sample count lands exactly on
                # a multiple of 4000, which depends on the batch size.
                if running_samples % 4000 == 0:
                    accuracy = running_corrects / running_samples
                    logger.info(f"[{running_samples}/{len(train_loader.dataset)}] Loss: {loss.item():.2f}, Acc: {accuracy:.4f}")
                    writer.add_scalar('train/accuracy', accuracy, global_step)

                writer.add_scalar('train/loss', loss.item(), global_step)

            epoch_loss = running_loss / running_samples
            epoch_acc = running_corrects / running_samples

            if epoch % val_interval == 0:
                model.eval()
                val_loss, val_acc = test(model, validation_loader, criterion, device, model_name)

                if val_loss < best_loss:
                    best_loss = val_loss
                    # Improvement: restart the patience window so that only
                    # consecutive misses count toward early stopping.
                    loss_counter = 0
                    torch.save(model.state_dict(), os.path.join(model_dir, model_name, f'best_model_{epoch}.pth'))
                    logger.info(f"✓ Saved best model (loss: {best_loss:.4f})")
                else:
                    loss_counter += 1

                logger.info(f'Epoch {epoch}: Train Loss={epoch_loss:.4f}, Train Acc={epoch_acc:.4f}, Val Loss={val_loss:.4f}')
                writer.add_scalar('val/loss', val_loss, epoch)
                writer.add_scalar('val/accuracy', val_acc, epoch)

            # >= rather than == so a threshold of 0 or below cannot be skipped.
            if loss_counter >= early_stopping_rounds:
                logger.info('Early stopping triggered')
                break
    finally:
        writer.close()

    return model

## Step 6: Train Model

Initialize model, optimizer, and loss function. Run training loop.

In [None]:
# Initialize: build the selected network, move it to the target device,
# then create the loss and the optimizer over its parameters.
model = ModelDef(num_classes, model_name).get_model()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

logger.info(f"Starting training: {model_name} with {num_classes} classes")

# Train -- keyword arguments keep the long call readable.
model = train(
    model=model,
    train_loader=train_loader,
    validation_loader=validation_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    model_name=model_name,
    num_epochs=num_epochs,
    val_interval=val_interval,
    early_stopping_rounds=early_stopping_rounds,
    model_dir=model_dir,
    output_dir=output_dir,
)

## Step 7: Save Model

Save final trained model weights to disk.

In [None]:
# Final weights go to <output_dir>/<model_name>/model.pth; this is separate
# from the best_model_<epoch>.pth checkpoints written during training.
out_dir = os.path.join(output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
final_weights = model.state_dict()
torch.save(final_weights, os.path.join(out_dir, "model.pth"))
logger.info(f"Model saved to {os.path.join(out_dir, 'model.pth')}")

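## View TensorBoard

Training scalars are written to `<output_dir>/<model_name>/logs` by `train()`. The local `./logs` folder only holds `app.log`, so point TensorBoard at `output_dir` instead:

```bash
tensorboard --logdir=/home/ubuntu/data/spine-output
```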