# Homework 3: Knowledge Distillation for AI Dermatologist

## CS 4774 Machine Learning - University of Virginia

In this notebook, you'll implement knowledge distillation to improve your skin disease classifier by learning from **MedSigLIP** (from Google), a powerful medical imaging model.

**Key Requirements:**
- Student model must be < **25 MB** on disk
- Use MedSigLIP as frozen teacher model (inference only)
- Implement temperature-scaled knowledge distillation following Hinton et al. (2015)

**Recommended Starting Point:** Use ShuffleNetV2 for your student model (~5 MB)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import os
import requests
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


In [None]:
# =============================
# CONFIGURATION - Change these values to tune your model
# =============================

# Dataset Configuration
DATASET_PATH = 'train_dataset'
NUM_CLASSES = 10

# Image Processing
IMAGE_SIZE = 224  # Image dimensions (224x224)
NORMALIZE_MEAN = [0.485, 0.456, 0.406]  # ImageNet mean
NORMALIZE_STD = [0.229, 0.224, 0.225]   # ImageNet std

# Training Parameters
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
NUM_WORKERS = 2

# Data Split
# Only TRAIN_SPLIT sizes the split; the validation set is whatever remains, so the
# two fractions must add up to 1 for VAL_SPLIT to describe what actually happens.
TRAIN_SPLIT = 0.9
VAL_SPLIT = 0.1
if abs(TRAIN_SPLIT + VAL_SPLIT - 1.0) > 1e-9:
    raise ValueError(f'TRAIN_SPLIT + VAL_SPLIT must equal 1.0, got {TRAIN_SPLIT + VAL_SPLIT}')

# Knowledge Distillation Parameters
TEMPERATURE = 4.0   # Temperature for softening distributions
ALPHA = 0.3         # Weight for hard loss (1-alpha for soft loss)

# Model Configuration
TEACHER_MODEL_NAME = "google/medsiglip-448"
STUDENT_MODEL_PATH = "student_model_hw3.pt"

# Server Configuration
SERVER_URL = 'http://hadi.cs.virginia.edu:8000'
MY_TOKEN = 'your_token_here'  # Replace with your actual token

print("Configuration loaded ✓")

## Step 1: Load Data (Same as HW1)

In [None]:
# Define dataset class
class SkinDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Class indices follow the sorted sub-directory names. Files whose extension
    (case-insensitive) is not in ``VALID_EXTS`` are skipped.

    Args:
        root_dir: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    VALID_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.jfif')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []
        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            # os.listdir order is arbitrary (filesystem-dependent); sorting makes sample
            # indices, and hence any seeded random_split, reproducible across machines.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(self.VALID_EXTS):
                    self.image_paths.append(os.path.join(cls_dir, fname))
                    self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed when a transform is set."""
        # The context manager closes the file handle; convert() returns a loaded copy
        # that stays valid after the source image is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Load data with image size
# Training transform (Do not change)
# Pipeline: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a float tensor via ToTensor,
# then normalize per channel with the ImageNet mean/std from the configuration cell.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
])

# Validation transform (Do not change)
# Identical to train_transform (neither pipeline applies augmentation).
# NOTE(review): val_transform is not referenced in the code visible here; the split
# cell divides `dataset` (built with train_transform) into train/val subsets, so
# validation images go through train_transform. Harmless while the two match.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
])

# Build the full labelled dataset; it is split into train/val subsets later.
dataset = SkinDataset(DATASET_PATH, transform=train_transform)
print(f'Dataset loaded with {len(dataset)} images and {len(dataset.classes)} classes')

## Step 2: Load Teacher Model (MedSigLIP from Google)

**Important:** Load the pre-trained MedSigLIP model for inference only. Do NOT fine-tune it.

In [None]:
# Load MedSigLIP teacher model
from transformers import AutoModel, AutoProcessor

def load_teacher_model():
    """Load the frozen MedSigLIP-448 teacher and its processor from HuggingFace.

    Returns:
        ``(teacher_model, processor)``: the model moved to ``device``, in eval
        mode, with every parameter frozen, plus the matching ``AutoProcessor``.
    """
    print("Loading MedSigLIP-448 teacher model...")

    # NOTE(review): trust_remote_code=True executes Python shipped with the model repo;
    # keep it only if the installed transformers version cannot load the model natively.
    teacher_model = AutoModel.from_pretrained(TEACHER_MODEL_NAME, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(TEACHER_MODEL_NAME, trust_remote_code=True)

    # The teacher is used for inference only, so put it in eval mode right away
    teacher_model = teacher_model.to(device)
    teacher_model.eval()

    # Freeze all parameters so no gradients are ever computed or stored for the teacher
    teacher_model.requires_grad_(False)

    print("✅ MedSigLIP loaded successfully!")
    return teacher_model, processor

# Load teacher
teacher_model, teacher_processor = load_teacher_model()

# Define student model: ShuffleNetV2 (Recommended, ~5MB)
from torchvision.models import shufflenet_v2_x0_5

def create_student_shufflenet(num_classes, weights=None):
    """Create a ShuffleNetV2 (x0.5) student model (~5 MB) with a ``num_classes``-way head.

    Args:
        num_classes: Number of outputs of the replacement final layer.
        weights: Optional torchvision weights (e.g. ``ShuffleNet_V2_X0_5_Weights.DEFAULT``)
            for ImageNet initialisation. ``None`` (default) keeps random init, as before.

    Returns:
        The ShuffleNetV2 ``nn.Module`` with its ``fc`` layer replaced.
    """
    # `weights=` replaces the `pretrained=` flag deprecated in torchvision 0.13
    model = shufflenet_v2_x0_5(weights=weights)
    # Replace final classifier
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Create student model
student_model = create_student_shufflenet(num_classes=NUM_CLASSES).to(device)

print(f'Student model created with {sum(p.numel() for p in student_model.parameters()):,} parameters')

## Step 3: Define Distillation Loss

Implement the knowledge distillation loss following Hinton et al. (2015):
- **Hard loss**: Cross-entropy with ground truth labels
- **Soft loss**: KL divergence between teacher and student soft predictions
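- **Temperature scaling**: Soften distributions for better knowledge transfer

The combined objective is $\mathcal{L} = \alpha \, \mathrm{CE}(y, z_s) + (1 - \alpha) \, T^2 \, \mathrm{KL}\big(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\big)$, where $z_s$ and $z_t$ are the student and teacher logits. The $T^2$ factor keeps the soft-loss gradients on the same scale as the hard loss.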

In [None]:
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss of Hinton et al. (2015).

    ``total = alpha * CE(student, labels) + (1 - alpha) * T**2 * KL(p_t || p_s)``,
    where ``p = softmax(logits / T)`` for teacher (``p_t``) and student (``p_s``).

    Args:
        temperature: Softening temperature ``T``; must be > 0.
        alpha: Weight of the hard (ground-truth) loss in [0, 1]; the soft loss
            is weighted by ``1 - alpha``.

    Raises:
        ValueError: If ``temperature`` or ``alpha`` is out of range.
    """

    def __init__(self, temperature=4.0, alpha=0.3):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f'temperature must be > 0, got {temperature}')
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'alpha must be in [0, 1], got {alpha}')
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """Compute the distillation loss.

        Args:
            student_logits: Raw student scores; assumed ``(batch, num_classes)``.
            teacher_logits: Raw teacher scores of the same shape; treated as constants.
            labels: Integer class indices, assumed ``(batch,)``.

        Returns:
            ``(total_loss, hard_loss, soft_loss)`` as scalar tensors.
        """
        T = self.temperature

        # Hard loss: ordinary cross-entropy against the ground-truth labels
        hard_loss = self.ce_loss(student_logits, labels)

        # Soft loss: F.kl_div expects log-probabilities as input and probabilities as target
        student_soft = F.log_softmax(student_logits / T, dim=1)
        teacher_soft = F.softmax(teacher_logits.detach() / T, dim=1)
        # Multiplying by T**2 keeps the soft-loss gradient magnitude comparable to the hard loss
        soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (T ** 2)

        total_loss = self.alpha * hard_loss + (1.0 - self.alpha) * soft_loss
        return total_loss, hard_loss, soft_loss

# Distillation criterion built from the TEMPERATURE / ALPHA configuration values
distillation_loss = DistillationLoss(alpha=ALPHA, temperature=TEMPERATURE)

## Step 4: Train with Knowledge Distillation

Implement a training loop that:
1. Gets the teacher's soft predictions (inside `torch.no_grad()`)
2. Gets the student's predictions
3. Computes the distillation loss
4. Updates only the student model's parameters

In [None]:
# Prepare data loaders
# A fixed seed keeps train/validation membership identical across runs, so validation
# scores from different hyperparameter settings are actually comparable.
SPLIT_SEED = 42
train_size = int(TRAIN_SPLIT * len(dataset))
val_size = len(dataset) - train_size
# NOTE: both subsets share `dataset`, so validation images also go through train_transform
# (currently identical to val_transform).
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Setup training: only the student's parameters are handed to the optimizer
optimizer = optim.Adam(student_model.parameters(), lr=LEARNING_RATE)
criterion = distillation_loss

# Training function
def _resolve_class_names(dataloader, class_names=None):
    """Return class names (label-index order) used to build the teacher's text prompts.

    Falls back to the dataset's ``classes`` attribute, looking through one level of
    ``Subset`` (what ``random_split`` returns) via its ``.dataset`` attribute.
    """
    if class_names is not None:
        return list(class_names)
    ds = dataloader.dataset
    base = getattr(ds, 'dataset', ds)
    names = getattr(base, 'classes', None)
    if names is None:
        raise ValueError('Pass class_names: the dataset exposes no `classes` attribute')
    return list(names)


def _teacher_logits(teacher, teacher_proc, images, prompts):
    """Zero-shot MedSigLIP logits, shape (batch, len(prompts)), for a student-normalized batch."""
    # Undo the student's Normalize so the teacher's own processor receives ordinary RGB images
    mean = torch.tensor(NORMALIZE_MEAN, device=images.device).view(1, -1, 1, 1)
    std = torch.tensor(NORMALIZE_STD, device=images.device).view(1, -1, 1, 1)
    pixels = (images * std + mean).clamp(0.0, 1.0).cpu()
    pil_images = [transforms.functional.to_pil_image(img) for img in pixels]
    # NOTE(review): these images were already resized to IMAGE_SIZE for the student; the
    # processor presumably upsamples them to the teacher's 448 px input, losing detail.
    # Since the teacher is frozen and the transforms are deterministic, caching teacher
    # logits once per image (from the original files) would be both faster and better.
    inputs = teacher_proc(text=prompts, images=pil_images,
                          padding='max_length', return_tensors='pt').to(device)
    outputs = teacher(**inputs)
    return outputs.logits_per_image.float()


def train_epoch(student, teacher, teacher_proc, dataloader, criterion, optimizer,
                class_names=None, prompt_template='a photo of {}'):
    """Run one epoch of knowledge distillation and return the mean training loss.

    The frozen teacher scores each batch zero-shot against one text prompt per
    class; the student is trained on ``criterion(student_logits, teacher_logits, labels)``.

    Args:
        student: Model being trained; its parameters must be in ``optimizer``.
        teacher: Frozen MedSigLIP model.
        teacher_proc: Processor matching ``teacher``.
        dataloader: Yields ``(images, labels)``; images are expected to be normalized
            with NORMALIZE_MEAN / NORMALIZE_STD (they are de-normalized for the teacher).
        criterion: Callable returning ``(total_loss, hard_loss, soft_loss)``.
        optimizer: Optimizer over the student's parameters.
        class_names: Class names in label-index order; defaults to the dataset's ``classes``.
        prompt_template: ``str.format`` template turning a class name into a text prompt.

    Returns:
        Average total loss over all batches (0.0 for an empty loader).

    Raises:
        ValueError: If teacher and student logits disagree in shape (number of
            prompts != number of student outputs).
    """
    student.train()
    teacher.eval()
    names = _resolve_class_names(dataloader, class_names)
    # Folder names often use underscores; spaces read more naturally as text prompts.
    # NOTE(review): prompt wording strongly affects zero-shot quality; tune prompt_template.
    prompts = [prompt_template.format(name.replace('_', ' ')) for name in names]

    total_loss = 0.0
    num_batches = 0
    for images, labels in tqdm(dataloader, desc='Training'):
        images, labels = images.to(device), labels.to(device)

        # Teacher inference only: no autograd graph, and the teacher is never updated
        with torch.no_grad():
            teacher_logits = _teacher_logits(teacher, teacher_proc, images, prompts)

        student_logits = student(images)
        if teacher_logits.shape != student_logits.shape:
            raise ValueError(
                f'teacher logits {tuple(teacher_logits.shape)} do not match student logits '
                f'{tuple(student_logits.shape)}; check class_names vs NUM_CLASSES'
            )

        loss, _hard_loss, _soft_loss = criterion(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

# Validation function
def validate(student, dataloader):
    """Evaluate ``student`` on ``dataloader`` without tracking gradients.

    Returns:
        ``(accuracy, macro_f1)`` computed over every sample the loader yields.
    """
    student.eval()
    predictions = []
    targets = []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc='Validation'):
            logits = student(batch_images.to(device))
            # Highest-scoring class per row is the predicted label
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.numpy())

    return (
        accuracy_score(targets, predictions),
        f1_score(targets, predictions, average='macro'),
    )

# Training loop
# Keeps a copy of the weights from the epoch with the best validation macro-F1 and
# restores them at the end, so the model saved below is the best one, not the last one.
import copy

best_f1 = 0
best_state = None

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

    # Train
    train_loss = train_epoch(student_model, teacher_model, teacher_processor,
                             train_loader, criterion, optimizer)

    # Validate
    val_acc, val_f1 = validate(student_model, val_loader)

    print(f'Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}')

    if val_f1 > best_f1:
        best_f1 = val_f1
        # deepcopy: state_dict() holds references to the live tensors, which later
        # optimizer steps would overwrite in place
        best_state = copy.deepcopy(student_model.state_dict())
        print(f'✅ New best F1: {best_f1:.4f}')

if best_state is not None:
    student_model.load_state_dict(best_state)
    print('Restored weights from the best validation epoch')

print(f'\nTraining complete! Best F1: {best_f1:.4f}')

## Step 5: Save and Submit

Save your student model (it must be < 25 MB on disk) and submit it to the HW3 leaderboard.

**Important:** Only submit the student model, NOT the teacher!

In [None]:
# Save student model
# TorchScript serializes code and weights together, so the file can be loaded with
# torch.jit.load without this notebook's Python class definitions.
MAX_MODEL_SIZE_MB = 25.0  # leaderboard size limit

student_model.eval()
student_model.cpu()
scripted_model = torch.jit.script(student_model)
scripted_model.save(STUDENT_MODEL_PATH)

# Check model size
import os
# NOTE(review): "MB" here is 1024*1024 bytes (MiB); confirm the server uses the same unit.
size_mb = os.path.getsize(STUDENT_MODEL_PATH) / (1024 * 1024)
print(f'✅ Model saved: {STUDENT_MODEL_PATH}')
print(f'📦 Model size: {size_mb:.2f} MB')

if size_mb >= MAX_MODEL_SIZE_MB:
    print('❌ WARNING: Model exceeds 25 MB limit!')
else:
    print('✅ Model size is within the 25 MB limit')

# Submit to HW3 leaderboard
def submit_model(token, model_path, server_url, timeout=60):
    """Upload a saved model file to the HW3 leaderboard and print the server's reply.

    Args:
        token: Your leaderboard token, sent as form field ``token``.
        model_path: Path of the model file, sent as multipart field ``file``.
        server_url: Base URL of the leaderboard server.
        timeout: Seconds to wait for the server before giving up.
    """
    with open(model_path, 'rb') as f:
        # Without a timeout, requests can block forever on an unresponsive server
        response = requests.post(
            f'{server_url}/submit',
            data={'token': token},
            files={'file': f},
            timeout=timeout,
        )

    # Error pages (e.g. an HTML 5xx) are not JSON; report them instead of crashing.
    # requests' JSON decode errors subclass ValueError in all versions.
    try:
        resp_json = response.json()
    except ValueError:
        print(f'❌ Server returned a non-JSON response (HTTP {response.status_code})')
        return

    if 'message' in resp_json:
        print(f"✅ {resp_json['message']}")
    else:
        print(f"❌ {resp_json.get('error', 'Unknown error')}")

# Check submission status
def check_status(token, server_url, timeout=30):
    """Print every leaderboard submission attempt recorded for ``token``.

    Args:
        token: Your leaderboard token.
        server_url: Base URL of the leaderboard server.
        timeout: Seconds to wait for the server before giving up.
    """
    from urllib.parse import quote

    # Percent-encode the token so characters such as '/' or '?' cannot alter the URL path
    url = f"{server_url}/submission-status/{quote(token, safe='')}"
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return

    attempts = response.json()
    if not attempts:
        print("No submissions yet.")
    for a in attempts:
        # Non-numeric score / size values are shown as placeholders
        score = f"{a['score']:.4f}" if isinstance(a['score'], (float, int)) else "Pending"
        size = f"{a['model_size']:.2f}" if isinstance(a['model_size'], (float, int)) else "N/A"
        print(f"Attempt {a['attempt']}: Score={score}, Size={size} MB, Status={a['status']}")

# Uncomment to submit:
# submit_model(MY_TOKEN, STUDENT_MODEL_PATH, SERVER_URL)
# check_status(MY_TOKEN, SERVER_URL)

print(f'\n🎯 View the HW3 leaderboard at: {SERVER_URL}/leaderboard3')