# 7. Vision Transformer (ViT) Model

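This notebook fine-tunes an ImageNet-pretrained Vision Transformer (ViT-B/16) to classify chest X-rays as Normal or Pneumonia. Instead of convolutions, ViT splits each image into fixed-size patches and relates them to one another with self-attention; it is among the strongest architectures for image classification.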

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from tqdm import tqdm
import os
from config import *

# Seed every RNG the pipeline may draw from. Python's `random` is seeded too
# because albumentations may sample augmentation parameters from it
# (depends on the installed albumentations version — TODO confirm).
import random

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print("="*60)
print("VISION TRANSFORMER (ViT) ARCHITECTURE")
print("="*60)

# Load ImageNet-pretrained ViT-B/16. The `weights=` argument replaces the
# deprecated `pretrained=True` flag and selects the IMAGENET1K_V1 checkpoint.
try:
    vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
except Exception as e:
    # Every later cell uses `vit`. Continuing without it would only turn this
    # failure into a confusing NameError further down, so stop here instead.
    print(f"\n✗ Error loading ViT: {e}")
    print("\nNote: ViT-B/16 with the `weights=` API requires torchvision >= 0.13")
    raise

# Replace the ImageNet classifier with a NUM_CLASSES-way linear head.
num_features = vit.heads.head.in_features
vit.heads.head = nn.Linear(num_features, NUM_CLASSES)

# Move to device
vit = vit.to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in vit.parameters())
trainable_params = sum(p.numel() for p in vit.parameters() if p.requires_grad)

print(f"\n✓ Model: Vision Transformer Base (ViT-B/16)")
print(f"✓ Pretrained on ImageNet")
print(f"✓ Modified head: {num_features} → {NUM_CLASSES} classes")

print(f"\n✓ Model Parameters:")
print(f"  • Total parameters: {total_params:,}")
print(f"  • Trainable parameters: {trainable_params:,}")

print(f"\n✓ Model loaded and moved to {DEVICE}")

VISION TRANSFORMER (ViT) ARCHITECTURE


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to C:\Users\steve/.cache\torch\hub\checkpoints\vit_b_16-c867db91.pth
100%|███████████████████████████████████████████████████████████████████████████████| 330M/330M [00:06<00:00, 57.1MB/s]



✓ Model: Vision Transformer Base (ViT-B/16)
✓ Pretrained on ImageNet
✓ Modified head: 768 → 2 classes

✓ Model Parameters:
  • Total parameters: 85,800,194
  • Trainable parameters: 85,800,194

✓ Model loaded and moved to cpu


## Loss Function and Optimizer Setup

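Configure a class-weighted cross-entropy loss (to offset the Normal/Pneumonia class imbalance) and an Adam optimizer that uses a lower learning rate for the pretrained layers than for the newly added classification head.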

In [2]:
print("="*60)
print("LOSS FUNCTION & OPTIMIZER SETUP")
print("="*60)

# Load class weights written by an earlier step. Index order is presumably
# 0 = Normal, 1 = Pneumonia (matching the prints below) — confirm against the
# notebook that produced class_weights.pt.
class_weights = torch.load('class_weights.pt', weights_only=True)
class_weights = class_weights.to(DEVICE)

print(f"\n✓ Class weights loaded:")
print(f"  • Normal (minority): {class_weights[0]:.4f}")
print(f"  • Pneumonia (majority): {class_weights[1]:.4f}")

# Loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Split parameters into "new head" vs "everything else". Selecting only
# `vit.encoder` would silently leave the patch embedding (`conv_proj`) and
# the class token out of the optimizer, so they would never be updated.
head_params = list(vit.heads.parameters())
head_param_ids = {id(p) for p in head_params}
backbone_params = [p for p in vit.parameters() if id(p) not in head_param_ids]

# ViT typically needs lower learning rates than CNNs; the pretrained backbone
# gets a much smaller step than the randomly initialised head.
BACKBONE_LR = LEARNING_RATE * 0.05
HEAD_LR = LEARNING_RATE * 0.5

optimizer = optim.Adam([
    {'params': backbone_params, 'lr': BACKBONE_LR},  # Very low LR for pretrained backbone
    {'params': head_params, 'lr': HEAD_LR},          # Moderate LR for new head
])

# Halve the learning rates when validation loss stops improving for 3 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3
)

print(f"\n✓ Loss Function: Weighted CrossEntropyLoss")
print(f"✓ Optimizer: Adam with differential learning rates")
print(f"  • Pretrained backbone (patch embedding, class token, encoder): lr={BACKBONE_LR}")
print(f"  • New head: lr={HEAD_LR}")
print(f"✓ Scheduler: ReduceLROnPlateau (factor=0.5, patience=3)")
print("\nNote: ViT uses lower learning rates than CNNs for stability")

LOSS FUNCTION & OPTIMIZER SETUP

✓ Class weights loaded:
  • Normal (minority): 1.9448
  • Pneumonia (majority): 0.6730

✓ Loss Function: Weighted CrossEntropyLoss
✓ Optimizer: Adam with differential learning rates
  • Transformer encoder: lr=5e-05
  • New head: lr=0.0005
✓ Scheduler: ReduceLROnPlateau (factor=0.5, patience=3)

Note: ViT uses lower learning rates than CNNs for stability


## Training and Validation Functions

Define the per-epoch training and validation routines used by the training loop below.

In [3]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: network to train; switched to training mode here.
        train_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            use ``reduction='mean'`` (the ``nn.CrossEntropyLoss`` default).
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean loss per sample (each batch's
        loss is weighted by its size, so a short final batch is not
        over-counted) and the accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training', leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = labels.size(0)
        # Scale the per-batch mean back to a sum so the epoch average is per-sample.
        running_loss += batch_loss * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if total == 0:
        raise ValueError("train_loader yielded no samples; check the dataset path and split")

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: network to evaluate; switched to eval mode here.
        val_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            use ``reduction='mean'`` (the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean loss per sample (each batch's
        loss is weighted by its size) and the accuracy in percent.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Scale the per-batch mean back to a sum so the epoch average is per-sample.
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples; check the dataset path and split")

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

# Summary banner confirming which helpers this cell defined.
print(
    "=" * 60,
    "TRAINING FUNCTIONS DEFINED",
    "=" * 60,
    "\n✓ train_one_epoch() - Trains model for one epoch",
    "✓ validate() - Evaluates model on validation set",
    "\n✓ Functions ready for training loop",
    sep="\n",
)

TRAINING FUNCTIONS DEFINED

✓ train_one_epoch() - Trains model for one epoch
✓ validate() - Evaluates model on validation set

✓ Functions ready for training loop


## Training Loop with Early Stopping

Train the Vision Transformer with early stopping on validation loss, saving the checkpoint with the lowest validation loss.

In [None]:
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

# Recreate transforms. Both pipelines share the same frame:
# resize → ImageNet mean/std normalization → tensor conversion.
# NOTE(review): torchvision's ViT-B/16 expects 224×224 input — confirm IMAGE_SIZE == 224 in config.
train_transform = A.Compose(
    [A.Resize(IMAGE_SIZE, IMAGE_SIZE)]
    + [
        # Light geometric / photometric augmentation, each applied with p=0.5.
        # NOTE(review): Rotate and ShiftScaleRotate both rotate by up to ±10°,
        # so the combined rotation can reach ±20° — confirm this is intended.
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
    ]
    + [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)

# Validation: deterministic — no augmentation, only the shared resize/normalize/tensor steps.
val_transform = A.Compose(
    [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)

# Custom Dataset class
class ChestXrayDataset(torch.utils.data.Dataset):
    """Chest X-ray images laid out as ``root_dir/<split>/<class_name>/<image>``.

    Each image is labelled with the index of its class folder in
    ``CLASS_NAMES``. Class folders that do not exist are skipped, so a split
    with no matching folders yields an empty dataset.

    Args:
        root_dir: dataset root directory.
        split: sub-directory to read (e.g. ``'train'`` or ``'val'``).
        transform: optional albumentations-style callable invoked as
            ``transform(image=...)`` and returning a dict with key ``'image'``.
    """

    # Matched case-insensitively so files such as "IMG.JPEG" are not silently dropped.
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []

        for class_idx, class_name in enumerate(CLASS_NAMES):
            class_path = os.path.join(root_dir, split, class_name)
            if not os.path.isdir(class_path):
                continue
            # os.listdir order is filesystem-dependent; sort so the sample
            # order (and therefore seeded shuffling) is reproducible.
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(class_path, img_name))
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; raises IOError if the file cannot be decoded."""
        img_path = self.images[idx]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None rather than raising.
            raise IOError(f"Could not read image: {img_path}")
        # OpenCV decodes as BGR; convert so downstream normalization sees RGB order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image=image)['image']

        return image, label

# Recreate datasets and dataloaders
train_dataset = ChestXrayDataset(DATASET_PATH, split='train', transform=train_transform)
val_dataset = ChestXrayDataset(DATASET_PATH, split='val', transform=val_transform)

# Fail fast with a readable message. Otherwise an empty split (e.g. a wrong
# DATASET_PATH) only fails later, deep inside DataLoader sampling or when the
# epoch averages are computed.
for split_name, dataset in (('train', train_dataset), ('val', val_dataset)):
    if len(dataset) == 0:
        raise FileNotFoundError(
            f"No images found for split '{split_name}' under "
            f"{os.path.join(DATASET_PATH, split_name)}"
        )

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print("="*60)
print("TRAINING VISION TRANSFORMER")
print("="*60)

# torch.save fails if the parent directory is missing, so create it up front
# rather than crashing after the first (slow) epoch.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
best_model_path = os.path.join(MODEL_SAVE_DIR, 'vit_best.pth')

best_val_loss = float('inf')
patience_counter = 0
train_losses = []
val_losses = []
train_accs = []
val_accs = []

print(f"\nTraining for {NUM_EPOCHS} epochs...")
print(f"Early stopping patience: {PATIENCE} epochs")
print("Note: ViT may take longer per epoch due to larger model size\n")

for epoch in range(NUM_EPOCHS):
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")

    # Train
    train_loss, train_acc = train_one_epoch(vit, train_loader, criterion, optimizer, DEVICE)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validate
    val_loss, val_acc = validate(vit, val_loader, criterion, DEVICE)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # The plateau scheduler is driven by validation loss.
    scheduler.step(val_loss)

    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Keep only the checkpoint with the lowest validation loss so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': vit.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }, best_model_path)
        print(f"  ✓ Best model saved! (Val Loss: {val_loss:.4f})")
    else:
        patience_counter += 1
        print(f"  No improvement ({patience_counter}/{PATIENCE})")

    # Early stopping
    if patience_counter >= PATIENCE:
        print(f"\n⚠ Early stopping triggered after {epoch+1} epochs")
        break

    print()

print("="*60)
print("TRAINING COMPLETE")
print("="*60)
if best_val_loss < float('inf'):
    print(f"\n✓ Best validation loss: {best_val_loss:.4f}")
    print(f"✓ Best model saved to: {best_model_path}")
else:
    # Reached when no epoch ran or no validation loss ever improved on +inf.
    print("\n✗ No checkpoint was saved")

  original_init(self, **validated_kwargs)


TRAINING VISION TRANSFORMER

Training for 10 epochs...
Early stopping patience: 5 epochs
Note: ViT may take longer per epoch due to larger model size

Epoch [1/10]


Training:  31%|█████████████████▊                                        | 50/163 [17:30<39:14, 20.84s/it, loss=0.0452]