# Homework 3: Knowledge Distillation for AI Dermatologist

## CS 4774 Machine Learning - University of Virginia

In this notebook, you'll implement knowledge distillation to improve your skin disease classifier by learning from **MedSigLIP** (from Google), a powerful medical imaging model.

**Key Requirements:**
- Student model must be < **25 MB** on disk
- Use MedSigLIP as frozen teacher model (inference only)
- Implement temperature-scaled knowledge distillation following Hinton et al. (2015)

**Recommended Starting Point:** Use ShuffleNetV2 for your student model (~5 MB)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import os
import requests
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cpu


In [None]:
# HuggingFace Login - Run this cell first!
# (The cell previously contained this snippet twice, which prompted for the
# token two times in a row.)
from huggingface_hub import login

# Option 1: Interactive login (will prompt for token)
login()

# Option 2: Direct login (paste your token here instead of Option 1)
# login(token="hf_YOUR_TOKEN_HERE")

print("✅ Logged in to HuggingFace")

In [None]:
# =============================
# CONFIGURATION - Change these values to tune your model
# =============================

# Dataset Configuration
DATASET_PATH = 'train_dataset'  # root folder containing one sub-folder per class
NUM_CLASSES = 10                # must equal the number of class sub-folders

# Image Processing (student input pipeline)
IMAGE_SIZE = 224  # Image dimensions (224x224)
NORMALIZE_MEAN = [0.485, 0.456, 0.406]  # ImageNet mean
NORMALIZE_STD = [0.229, 0.224, 0.225]   # ImageNet std

# Training Parameters
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
NUM_WORKERS = 2

# Data Split
TRAIN_SPLIT = 0.9
# Derived instead of hard-coded so the two fractions can never disagree; the
# pipelines only read TRAIN_SPLIT and give the remainder to validation.
VAL_SPLIT = 1.0 - TRAIN_SPLIT

# Knowledge Distillation Parameters
TEMPERATURE = 4.0   # Temperature for softening distributions
ALPHA = 0.3         # Weight for hard loss (1-alpha for soft loss)

# Model Configuration
TEACHER_MODEL_NAME = "google/medsiglip-448"
STUDENT_MODEL_PATH = "student_model_hw3.pt"

# Server Configuration
SERVER_URL = 'http://hadi.cs.virginia.edu:8000'
# Prefer exporting HW3_TOKEN in your environment so the real token never ends
# up saved in a shared notebook; the literal is only a placeholder fallback.
MY_TOKEN = os.environ.get('HW3_TOKEN', 'your_token_here')

print("Configuration loaded ✓")

## Part 1: Class and Function Definitions

All classes and functions are defined here. Run these cells first before executing the main workflow.

In [None]:
# ============================================================
# CLASS DEFINITIONS
# ============================================================

class SkinDataset(Dataset):
    """Image-folder dataset for skin disease classification.

    ``root_dir`` must contain one sub-directory per class. Every image file in a
    class directory is labelled with that class's index, and classes are indexed
    in sorted order of their directory names.

    Args:
        root_dir: Dataset root directory.
        transform: Optional callable applied to each PIL image in ``__getitem__``.

    Attributes:
        classes: Sorted list of class directory names.
        class_to_idx: Mapping from class name to integer label.
        image_paths: Paths of all discovered images (sorted within each class).
        labels: Integer label for each entry of ``image_paths``.
    """
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Hidden entries such as Jupyter's ".ipynb_checkpoints" would otherwise be
        # treated as an extra class, shifting every label index after it.
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if not d.startswith('.') and os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []
        valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.jfif')
        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            # os.listdir order is filesystem-dependent; sorting keeps dataset
            # indices stable across machines (needed for reproducible splits).
            for fname in sorted(os.listdir(cls_dir)):
                # Hidden files include macOS "._*" AppleDouble files, which carry
                # an image extension but are not decodable images.
                if fname.startswith('.') or not fname.lower().endswith(valid_exts):
                    continue
                path = os.path.join(cls_dir, fname)
                if os.path.isfile(path):
                    self.image_paths.append(path)
                    self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``, with ``transform`` applied to the RGB image."""
        # The context manager closes the file; convert() yields a new image whose
        # pixel data is loaded before the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


class DistillationLoss(nn.Module):
    """Knowledge-distillation loss (Hinton et al., 2015).

    ``total = alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(p_teacher || p_student)``
    where ``p = softmax(logits / T)``. The ``T^2`` factor keeps the soft-loss
    gradient magnitude roughly independent of the temperature.

    Args:
        temperature: Softening temperature ``T`` (must be > 0).
        alpha: Weight of the hard (ground-truth) loss, in [0, 1]; the soft
            (teacher) loss is weighted by ``1 - alpha``.

    Raises:
        ValueError: if ``temperature <= 0`` or ``alpha`` is outside [0, 1].
    """
    def __init__(self, temperature=4.0, alpha=0.3):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """Compute the combined distillation loss.

        Args:
            student_logits: Student outputs of shape (batch, num_classes).
            teacher_logits: Teacher outputs of the same shape; gradients are
                never propagated into them.
            labels: Ground-truth class indices of shape (batch,).

        Returns:
            Tuple ``(total_loss, hard_loss, soft_loss)`` of scalar tensors.
        """
        hard_loss = self.ce_loss(student_logits, labels)

        T = self.temperature
        student_soft = F.log_softmax(student_logits / T, dim=1)
        # The teacher is frozen: detach it, and match the student's dtype in case
        # the teacher runs in reduced precision.
        teacher_soft = F.softmax(teacher_logits.detach().to(student_logits.dtype) / T, dim=1)
        # F.kl_div takes log-probabilities as input and probabilities as target;
        # 'batchmean' is the reduction that yields the true KL averaged per sample.
        soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (T ** 2)

        total_loss = self.alpha * hard_loss + (1.0 - self.alpha) * soft_loss
        return total_loss, hard_loss, soft_loss


print("✓ Classes defined")

In [None]:
# ============================================================
# DATA TRANSFORM FUNCTIONS
# ============================================================

def create_transforms(image_size=None, normalize_mean=None, normalize_std=None):
    """Create the training and validation transforms for the student model.

    Each argument falls back to its config constant (IMAGE_SIZE, NORMALIZE_MEAN,
    NORMALIZE_STD) when left as ``None``, so ``create_transforms()`` behaves
    exactly as before.

    Args:
        image_size: Side length of the square output image.
        normalize_mean: Per-channel mean for ``transforms.Normalize``.
        normalize_std: Per-channel std for ``transforms.Normalize``.

    Returns:
        ``(train_transform, val_transform)``. Both resize to
        ``image_size x image_size``, convert to a tensor and normalize. They are
        currently identical (no augmentation). If you add random augmentation to
        ``train_transform``, make sure validation data does not receive it.
    """
    size = IMAGE_SIZE if image_size is None else image_size
    mean = NORMALIZE_MEAN if normalize_mean is None else normalize_mean
    std = NORMALIZE_STD if normalize_std is None else normalize_std

    train_transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    return train_transform, val_transform


print("✓ Transform functions defined")

In [None]:
# ============================================================
# MODEL LOADING FUNCTIONS
# ============================================================

from transformers import AutoModel, AutoProcessor
from torchvision.models import shufflenet_v2_x0_5

def load_teacher_model(model_name=None):
    """Load the frozen teacher model and its processor from HuggingFace.

    Args:
        model_name: HuggingFace model id; defaults to ``TEACHER_MODEL_NAME``.

    Returns:
        ``(teacher_model, processor)``. The model is moved to ``device``, put in
        eval mode, and all of its parameters are frozen (``requires_grad=False``).
    """
    name = TEACHER_MODEL_NAME if model_name is None else model_name
    print(f"Loading teacher model {name}...")

    # NOTE(review): trust_remote_code=True lets the hub repository execute
    # arbitrary Python locally. MedSigLIP appears to be a standard SigLIP
    # checkpoint that transformers loads natively; verify and drop the flag if so.
    teacher_model = AutoModel.from_pretrained(name, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(name, trust_remote_code=True)

    teacher_model = teacher_model.to(device)
    teacher_model.eval()  # inference-only: disable dropout and similar train-time behavior
    # Freeze all parameters so no gradients are ever computed or stored for them.
    teacher_model.requires_grad_(False)

    print(f"✅ Teacher model {name} loaded successfully!")
    return teacher_model, processor


def create_student_shufflenet(num_classes, weights=None):
    """Create a ShuffleNetV2 (x0.5) student model (~5 MB) with a ``num_classes`` head.

    Args:
        num_classes: Number of output logits.
        weights: Optional torchvision weights enum for the backbone, e.g.
            ``torchvision.models.ShuffleNet_V2_X0_5_Weights.DEFAULT``. These are
            ImageNet weights and expect the ImageNet normalization used in
            NORMALIZE_MEAN/NORMALIZE_STD. ``None`` (the default) trains from
            scratch, as before.

    Returns:
        The ``nn.Module`` with its final ``fc`` layer replaced.
    """
    # `weights=` is torchvision's replacement for the deprecated `pretrained=` flag.
    model = shufflenet_v2_x0_5(weights=weights)
    # Replace the ImageNet classifier with one sized for our classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


print("✓ Model functions defined")

In [None]:
# ============================================================
# TRAINING FUNCTIONS
# ============================================================

def _find_class_names(dataloader):
    """Return the ``classes`` list of the dataset wrapped by ``dataloader``."""
    # random_split / Subset wrap the SkinDataset in one or more `.dataset`
    # layers, so unwrap until an object exposing `.classes` is found.
    ds = dataloader.dataset
    while not hasattr(ds, 'classes') and hasattr(ds, 'dataset'):
        ds = ds.dataset
    if not hasattr(ds, 'classes'):
        raise ValueError(
            "train_epoch needs class names to build the teacher's text prompts, "
            "but the dataloader's dataset exposes no `classes` attribute."
        )
    return list(ds.classes)


def _teacher_pixel_values(images, teacher_proc):
    """Convert a batch of student-normalized images into teacher ``pixel_values``.

    Student batches are normalized with NORMALIZE_MEAN/NORMALIZE_STD (see
    create_transforms). This undoes that normalization, resizes to the teacher
    processor's input size, and applies the processor's mean/std.
    """
    image_proc = getattr(teacher_proc, 'image_processor', teacher_proc)
    size = getattr(image_proc, 'size', None)
    if isinstance(size, dict) and 'height' in size and 'width' in size:
        out_hw = (size['height'], size['width'])
    else:
        out_hw = (448, 448)  # assumed fallback: MedSigLIP-448's input resolution
    teacher_mean = getattr(image_proc, 'image_mean', None) or [0.5, 0.5, 0.5]
    teacher_std = getattr(image_proc, 'image_std', None) or [0.5, 0.5, 0.5]

    def _per_channel(values):
        return torch.as_tensor(values, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)

    pixels = (images * _per_channel(NORMALIZE_STD) + _per_channel(NORMALIZE_MEAN)).clamp(0.0, 1.0)
    # NOTE(review): upsampling the already-downsized student input loses detail
    # compared with feeding the teacher the original image; precomputing teacher
    # logits from full-resolution images would be both better and faster.
    # Bicubic interpolation can overshoot [0, 1], hence the second clamp.
    pixels = F.interpolate(pixels, size=out_hw, mode='bicubic', align_corners=False).clamp(0.0, 1.0)
    return (pixels - _per_channel(teacher_mean)) / _per_channel(teacher_std)


def train_epoch(student, teacher, teacher_proc, dataloader, criterion, optimizer):
    """Train ``student`` for one epoch with knowledge distillation from ``teacher``.

    Teacher logits come from MedSigLIP zero-shot classification. Each class name
    of the underlying dataset is turned into a text prompt, and the teacher's
    image-text scores (``logits_per_image``) serve as per-class logits. Column
    ``j`` corresponds to label ``j`` because the prompts follow ``classes`` order.

    Args:
        student: Model being trained.
        teacher: Frozen MedSigLIP model (inference only).
        teacher_proc: Processor returned by ``load_teacher_model``.
        dataloader: Yields ``(images, labels)`` batches of student-normalized images.
        criterion: Callable ``(student_logits, teacher_logits, labels)`` returning
            ``(loss, hard_loss, soft_loss)``.
        optimizer: Optimizer over the student's parameters.

    Returns:
        Mean training loss over the batches (0.0 for an empty dataloader).
    """
    student.train()
    teacher.eval()
    total_loss = 0.0

    # NOTE(review): the prompt template is a guess; tune it (or replace zero-shot
    # with a linear probe on teacher image embeddings) and check the teacher's own
    # validation accuracy before trusting its soft labels.
    class_names = _find_class_names(dataloader)
    prompts = [f"a photo of {name.replace('_', ' ')}" for name in class_names]
    # Tokenize once per epoch; HF's SigLIP examples use padding="max_length".
    text_inputs = teacher_proc(text=prompts, padding="max_length", return_tensors="pt")
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()
                   if k in ('input_ids', 'attention_mask')}
    teacher_dtype = next(teacher.parameters()).dtype

    for images, labels in tqdm(dataloader, desc='Training'):
        images, labels = images.to(device), labels.to(device)

        # Teacher predictions (no gradients)
        with torch.no_grad():
            pixel_values = _teacher_pixel_values(images, teacher_proc).to(teacher_dtype)
            outputs = teacher(pixel_values=pixel_values, **text_inputs)
            teacher_logits = outputs.logits_per_image.float()  # (batch, num_classes)

        # Exactly one student forward per batch: a second forward in train mode
        # would double-update the BatchNorm running statistics.
        student_logits = student(images)

        loss, hard_loss, soft_loss = criterion(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / max(len(dataloader), 1)


def validate(student, dataloader):
    """Evaluate ``student`` on ``dataloader``.

    Args:
        student: Model to evaluate; it is left in eval mode afterwards.
        dataloader: Yields ``(images, labels)`` batches.

    Returns:
        ``(accuracy, macro_f1)``, or ``(0.0, 0.0)`` if the dataloader is empty.
    """
    student.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validation'):
            outputs = student(images.to(device))
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    if not all_labels:
        return 0.0, 0.0

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 makes sklearn's default scoring for undefined per-class
    # F1 explicit and silences the UndefinedMetricWarning.
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    return accuracy, f1


print("✓ Training functions defined")

In [None]:
# ============================================================
# SUBMISSION FUNCTIONS
# ============================================================

def submit_model(token, model_path, server_url, timeout=120):
    """Upload a saved student model to the HW3 leaderboard and print the result.

    Args:
        token: Leaderboard token.
        model_path: Path of the TorchScript file to upload.
        server_url: Base URL of the leaderboard server.
        timeout: Seconds to wait on the server before giving up (default 120).

    Returns:
        The server's parsed JSON reply, or ``None`` if the body was not JSON.
    """
    with open(model_path, 'rb') as f:
        response = requests.post(
            f'{server_url}/submit',
            data={'token': token},
            files={'file': f},
            # Without a timeout a stalled server would hang the notebook forever.
            timeout=timeout,
        )

    try:
        resp_json = response.json()
    except ValueError:  # e.g. an HTML error page from a proxy or crashed server
        print(f"❌ Server returned HTTP {response.status_code} with a non-JSON body")
        return None

    if isinstance(resp_json, dict) and 'message' in resp_json:
        print(f"✅ {resp_json['message']}")
    else:
        error = resp_json.get('error', 'Unknown error') if isinstance(resp_json, dict) else 'Unknown error'
        print(f"❌ {error}")
    return resp_json


def check_status(token, server_url, timeout=30):
    """Print every submission attempt recorded for ``token`` on the leaderboard.

    Args:
        token: Leaderboard token.
        server_url: Base URL of the leaderboard server.
        timeout: Seconds to wait on the server before giving up (default 30).
    """
    from urllib.parse import quote

    # Percent-encode the token so characters such as '/' or '?' cannot change
    # which endpoint is requested.
    url = f'{server_url}/submission-status/{quote(str(token), safe="")}'
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return

    for a in response.json():
        score = f"{a['score']:.4f}" if isinstance(a['score'], (float, int)) else "Pending"
        size = f"{a['model_size']:.2f}" if isinstance(a['model_size'], (float, int)) else "N/A"
        print(f"Attempt {a['attempt']}: Score={score}, Size={size} MB, Status={a['status']}")


print("✓ Submission functions defined")

In [None]:
# ============================================================
# QUICK TEST PIPELINE (for rapid testing)
# ============================================================

def quick_test_pipeline(
    # Quick test parameters
    num_samples=500,      # Use only 500 samples total
    num_epochs=2,         # Just 2 epochs for quick testing

    # Use config defaults for everything else
    dataset_path=DATASET_PATH,
    num_classes=NUM_CLASSES,
    image_size=IMAGE_SIZE,
    normalize_mean=NORMALIZE_MEAN,
    normalize_std=NORMALIZE_STD,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    num_workers=NUM_WORKERS,
    train_split=0.8,      # 80/20 split for quick test
    temperature=TEMPERATURE,
    alpha=ALPHA,
    teacher_model_name=TEACHER_MODEL_NAME,
    student_model_path="quick_test_model.pt"
):
    """
    Quick test pipeline for rapid iteration and debugging.

    Trains on a random subset of at most ``num_samples`` images for
    ``num_epochs`` epochs. It then saves, as TorchScript, the weights from the
    epoch with the best validation macro-F1.

    Args:
        num_samples: Maximum number of samples to use (default: 500)
        num_epochs: Number of training epochs (default: 2)
        (other args same as run_training_pipeline)

    NOTE(review): ``image_size``, ``normalize_mean``, ``normalize_std`` and
    ``teacher_model_name`` are not used in this body. ``create_transforms()``
    and ``load_teacher_model()`` are called without arguments, so edit the
    config constants to change preprocessing or the teacher.

    Returns:
        dict: best_f1, model_path, model_size_mb, train_size, val_size,
        num_samples (images actually used) and num_epochs.

    Raises:
        ValueError: if the dataset's class count differs from ``num_classes`` or
            the split leaves the training or validation set empty.
    """
    import random

    print("="*70)
    print("QUICK TEST PIPELINE (Small Dataset)")
    print("="*70)
    print(f"Using {num_samples} samples, {num_epochs} epochs")
    print("="*70)

    # Load full dataset first
    print("\n[1/5] Loading dataset...")
    # Only train_transform is used: it is attached before the split, so
    # validation images get it too (identical to val_transform today).
    train_transform, _val_transform = create_transforms()
    full_dataset = SkinDataset(dataset_path, transform=train_transform)
    if len(full_dataset.classes) != num_classes:
        raise ValueError(
            f"Found {len(full_dataset.classes)} class folders in {dataset_path!r}, "
            f"but num_classes={num_classes}; labels and the student head would disagree."
        )

    # Create a small random subset
    indices = list(range(len(full_dataset)))
    random.shuffle(indices)
    dataset = Subset(full_dataset, indices[:num_samples])

    print(f'✅ Dataset subset created: {len(dataset)} images (from {len(full_dataset)} total)')

    # Load models
    print("\n[2/5] Loading models...")
    teacher_model, teacher_processor = load_teacher_model()
    student_model = create_student_shufflenet(num_classes=num_classes).to(device)
    print(f'✅ Student model created: {sum(p.numel() for p in student_model.parameters()):,} parameters')

    # Setup training
    print("\n[3/5] Setting up training...")
    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"train_split={train_split} on {len(dataset)} images leaves an empty split "
            f"(train={train_size}, val={val_size})"
        )
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    criterion = DistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

    print(f'✅ Setup complete: Train={len(train_dataset)} | Val={len(val_dataset)}')

    # Train model
    print(f"\n[4/5] Training model ({num_epochs} epochs)...")
    print("-"*70)
    best_f1 = 0.0
    best_state = None

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('='*50)

        train_loss = train_epoch(student_model, teacher_model, teacher_processor,
                                 train_loader, criterion, optimizer)
        val_acc, val_f1 = validate(student_model, val_loader)

        print(f'Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}')

        if best_state is None or val_f1 > best_f1:
            best_f1 = val_f1
            # Clone the tensors: state_dict() returns live references that later
            # optimizer steps would overwrite.
            best_state = {k: v.detach().clone() for k, v in student_model.state_dict().items()}
            print(f'✅ New best F1: {best_f1:.4f}')

    print(f'\n{"="*50}')
    print(f'Quick test complete! Best F1: {best_f1:.4f}')
    print(f'{"="*50}')

    # Save model
    print("\n[5/5] Saving model...")
    if best_state is not None:
        # Export the best-validation weights so the file matches best_f1.
        student_model.load_state_dict(best_state)
    student_model.eval()
    student_model.cpu()
    scripted_model = torch.jit.script(student_model)
    scripted_model.save(student_model_path)

    size_mb = os.path.getsize(student_model_path) / (1024 * 1024)
    print(f'✅ Model saved: {student_model_path}')
    print(f'📦 Model size: {size_mb:.2f} MB')
    if size_mb >= 25.0:
        print('❌ WARNING: Model exceeds 25 MB limit!')

    print("\n" + "="*70)
    print("⚡ Quick test finished! Use run_training_pipeline() for full training.")
    print("="*70)

    return {
        'best_f1': best_f1,
        'model_path': student_model_path,
        'model_size_mb': size_mb,
        'train_size': len(train_dataset),
        'val_size': len(val_dataset),
        'num_samples': len(dataset),
        'num_epochs': num_epochs
    }


print("✓ Quick test pipeline function defined")

---

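## Part 2A: Quick Test (Optional)

Run this cell first to quickly check that everything works before running the full training pipeline. It uses only 500 samples and 2 epochs, so it finishes in a few minutes.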

In [None]:
# Quick test with small dataset - runs in a few minutes
test_results = quick_test_pipeline(
    num_samples=500,    # Use only 500 images
    num_epochs=2        # Just 2 epochs
)

# Display results
print("\n" + "="*70)
print("QUICK TEST RESULTS")
print("="*70)
print(f"Best F1 Score: {test_results['best_f1']:.4f}")
print(f"Model Size: {test_results['model_size_mb']:.2f} MB")
print(f"Samples Used: {test_results['num_samples']}")
print(f"Train/Val: {test_results['train_size']}/{test_results['val_size']}")
print("="*70)
print("\n✅ If this worked, you're ready to run the full pipeline!")

---

## Part 2B: Full Training Pipeline

The next cell defines `run_training_pipeline()`, which trains on all of the data for the full number of epochs. Run it to define the function, then start the complete training with the cell after it.

In [None]:
# ============================================================
# MAIN TRAINING PIPELINE
# ============================================================

def run_training_pipeline(
    # Dataset parameters
    dataset_path=DATASET_PATH,
    num_classes=NUM_CLASSES,

    # Image processing parameters
    image_size=IMAGE_SIZE,
    normalize_mean=tuple(NORMALIZE_MEAN),
    normalize_std=tuple(NORMALIZE_STD),

    # Training parameters
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    num_workers=NUM_WORKERS,

    # Data split parameters
    train_split=TRAIN_SPLIT,

    # Distillation parameters
    temperature=TEMPERATURE,
    alpha=ALPHA,

    # Model parameters
    teacher_model_name=TEACHER_MODEL_NAME,
    student_model_path=STUDENT_MODEL_PATH,

    # Submission parameters
    submit=False,
    my_token=MY_TOKEN,
    server_url=SERVER_URL
):
    """
    Complete training pipeline for knowledge distillation.

    Defaults are taken from the CONFIGURATION cell when this cell is run (they
    previously duplicated those values as literals). Re-run this cell after
    editing the config if you rely on the defaults.

    Args:
        dataset_path: Path to training dataset directory
        num_classes: Number of output classes (must match the class folders)
        image_size: Size to resize images to (square)
        normalize_mean: Mean values for normalization
        normalize_std: Std values for normalization
        batch_size: Batch size for training
        num_epochs: Number of training epochs
        learning_rate: Learning rate for optimizer
        num_workers: Number of workers for data loading
        train_split: Fraction of data to use for training (rest is validation)
        temperature: Temperature for knowledge distillation
        alpha: Weight for hard loss (1-alpha for soft loss)
        teacher_model_name: HuggingFace model name for teacher
        student_model_path: Path to save student model
        submit: Whether to submit model to leaderboard (default: False)
        my_token: Token for leaderboard submission
        server_url: Server URL for submission

    NOTE(review): ``image_size``, ``normalize_mean``, ``normalize_std`` and
    ``teacher_model_name`` are not used in this body. ``create_transforms()``
    and ``load_teacher_model()`` are called without arguments, so edit the
    config constants to change preprocessing or the teacher.

    Returns:
        dict: best_f1, model_path, model_size_mb, train_size and val_size. The
        saved model holds the weights of the epoch that achieved best_f1.

    Raises:
        ValueError: if the dataset's class count differs from ``num_classes`` or
            the split leaves the training or validation set empty.
    """
    print("="*70)
    print("KNOWLEDGE DISTILLATION TRAINING PIPELINE")
    print("="*70)

    # ==================== STEP 1: LOAD DATASET ====================
    print("\n[1/5] Loading dataset...")
    # Only train_transform is used: it is attached before the split, so
    # validation images get it too (identical to val_transform today).
    train_transform, _val_transform = create_transforms()
    dataset = SkinDataset(dataset_path, transform=train_transform)
    if len(dataset.classes) != num_classes:
        raise ValueError(
            f"Found {len(dataset.classes)} class folders in {dataset_path!r}, "
            f"but num_classes={num_classes}; labels and the student head would disagree."
        )
    print(f'✅ Dataset loaded: {len(dataset)} images, {len(dataset.classes)} classes')

    # ==================== STEP 2: LOAD MODELS ====================
    print("\n[2/5] Loading models...")
    teacher_model, teacher_processor = load_teacher_model()
    student_model = create_student_shufflenet(num_classes=num_classes).to(device)
    print(f'✅ Student model created: {sum(p.numel() for p in student_model.parameters()):,} parameters')

    # ==================== STEP 3: SETUP TRAINING ====================
    print("\n[3/5] Setting up training...")
    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"train_split={train_split} on {len(dataset)} images leaves an empty split "
            f"(train={train_size}, val={val_size})"
        )
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    criterion = DistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

    print(f'✅ Setup complete: Train={len(train_dataset)} | Val={len(val_dataset)}')

    # ==================== STEP 4: TRAIN MODEL ====================
    print("\n[4/5] Training model...")
    print("-"*70)
    best_f1 = 0.0
    best_state = None

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('='*50)

        # Train for one epoch
        train_loss = train_epoch(student_model, teacher_model, teacher_processor,
                                 train_loader, criterion, optimizer)

        # Validate
        val_acc, val_f1 = validate(student_model, val_loader)

        print(f'Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}')

        # Remember the best model (cloned: state_dict() returns live references
        # that later optimizer steps would overwrite)
        if best_state is None or val_f1 > best_f1:
            best_f1 = val_f1
            best_state = {k: v.detach().clone() for k, v in student_model.state_dict().items()}
            print(f'✅ New best F1: {best_f1:.4f}')

    print(f'\n{"="*50}')
    print(f'Training complete! Best F1: {best_f1:.4f}')
    print(f'{"="*50}')

    # ==================== STEP 5: SAVE MODEL ====================
    print("\n[5/5] Saving model...")
    if best_state is not None:
        # Export the best-validation weights so the submitted file matches best_f1.
        student_model.load_state_dict(best_state)
    student_model.eval()
    student_model.cpu()
    scripted_model = torch.jit.script(student_model)
    scripted_model.save(student_model_path)

    size_mb = os.path.getsize(student_model_path) / (1024 * 1024)
    print(f'✅ Model saved: {student_model_path}')
    print(f'📦 Model size: {size_mb:.2f} MB')

    if size_mb >= 25.0:
        print('❌ WARNING: Model exceeds 25 MB limit!')
    else:
        print('✅ Model size is within the 25 MB limit')

    # ==================== OPTIONAL: SUBMIT ====================
    if submit:
        print("\n[BONUS] Submitting model to leaderboard...")
        submit_model(my_token, student_model_path, server_url)
        check_status(my_token, server_url)
    else:
        print('\n💡 To submit, set submit=True or run manually:')
        print(f'   submit_model("{my_token}", "{student_model_path}", "{server_url}")')

    print(f'\n🎯 View the HW3 leaderboard at: {server_url}/leaderboard3')
    print("="*70)

    # Return results
    return {
        'best_f1': best_f1,
        'model_path': student_model_path,
        'model_size_mb': size_mb,
        'train_size': len(train_dataset),
        'val_size': len(val_dataset)
    }


print("✓ Main pipeline function defined")

---

## Part 3: Run the Full Training Pipeline

Call the `run_training_pipeline()` function with your desired parameters.
The call below passes the values from the CONFIGURATION cell; change them there, or override individual arguments here.

In [None]:
# Collect every pipeline argument from the configuration cell in one mapping,
# then launch the full training run with it.
_pipeline_kwargs = dict(
    # Dataset parameters
    dataset_path=DATASET_PATH,
    num_classes=NUM_CLASSES,

    # Image processing parameters
    image_size=IMAGE_SIZE,
    normalize_mean=NORMALIZE_MEAN,
    normalize_std=NORMALIZE_STD,

    # Training parameters
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    num_workers=NUM_WORKERS,

    # Data split parameters
    train_split=TRAIN_SPLIT,

    # Distillation parameters
    temperature=TEMPERATURE,
    alpha=ALPHA,

    # Model parameters
    teacher_model_name=TEACHER_MODEL_NAME,
    student_model_path=STUDENT_MODEL_PATH,

    # Submission parameters (set submit=True to auto-submit)
    submit=False,  # Change to True to submit automatically
    my_token=MY_TOKEN,
    server_url=SERVER_URL,
)
results = run_training_pipeline(**_pipeline_kwargs)

# Summarize the run
_rule = "=" * 70
print("\n" + _rule)
print("FINAL RESULTS")
print(_rule)
print(f"Best F1 Score: {results['best_f1']:.4f}")
print(f"Model Path: {results['model_path']}")
print(f"Model Size: {results['model_size_mb']:.2f} MB")
print(f"Training Set: {results['train_size']} samples")
print(f"Validation Set: {results['val_size']} samples")
print(_rule)