In [0]:
!pip install torchinfo

In [0]:
import torch
import torch.nn as nn
from torchinfo import summary
from torchvision import datasets
import torchvision.transforms as T
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import torch.optim as optim



In [0]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, current CUDA device and all CUDA devices), then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    # Python and NumPy generators
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # PyTorch generators: CPU, current CUDA device, every CUDA device
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Prefer repeatable cuDNN kernels over auto-tuned ones
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [0]:
def prepare_models(num_classes):
    """Build the pretrained teacher/student pair with new classifier heads.

    The student is a ShuffleNetV2 x0.5 and the teacher a ResNet-34, both
    created with their ``DEFAULT`` torchvision weights. Each model's ``fc``
    layer is replaced by a new ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for both replacement heads.

    Returns:
        Tuple ``(teacher_model, student_model)``.
    """
    from torchvision.models import (
        ResNet34_Weights,
        ShuffleNet_V2_X0_5_Weights,
        resnet34,
        shufflenet_v2_x0_5,
    )

    # Seed before any layer is created so RNG-dependent initialisation is repeatable
    set_seed()

    def _with_new_head(network):
        # Swap the final fully connected layer for one sized to num_classes
        network.fc = nn.Linear(network.fc.in_features, num_classes)
        return network

    # Student is built before the teacher; keep this order so RNG consumption is unchanged
    student_model = _with_new_head(shufflenet_v2_x0_5(weights=ShuffleNet_V2_X0_5_Weights.DEFAULT))
    teacher_model = _with_new_head(resnet34(weights=ResNet34_Weights.DEFAULT))

    return teacher_model, student_model

def get_cifar100_dataloaders(batch_size=256, root='/dbfs/cache', num_workers=None):
    """Create the CIFAR-100 train and test dataloaders.

    The training pipeline applies a padded random crop and a random
    horizontal flip before normalisation; the test pipeline only converts
    to a tensor and normalises. The dataset is downloaded into ``root`` if
    it is not already there.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory where the CIFAR-100 files are stored/downloaded.
            Defaults to the original hard-coded cache location.
        num_workers: Number of loader worker processes. ``None`` (default)
            uses ``os.cpu_count()``, falling back to 0 (load in the main
            process) when the CPU count cannot be determined.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # Per-channel statistics used to normalise the images
    # (presumably computed on the CIFAR-100 training split - TODO confirm)
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.26733428587941854, 0.25643846292120615, 0.2761504713263903)

    if num_workers is None:
        # os.cpu_count() may return None; DataLoader requires a non-negative int
        num_workers = os.cpu_count() or 0

    transform_train = T.Compose([
        T.ToTensor(),
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.Normalize(mean, std)
    ])
    transform_test = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std)
    ])

    # Creating datasets
    train_set = datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train)
    test_set = datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test)

    # Creating dataloaders (only the training data is shuffled)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader


In [0]:
def _evaluate_accuracy(model, data_loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``data_loader``."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
    return (correct / len(data_loader.dataset)) * 100


def train(model, train_loader, test_loader, model_path, load=False, epochs=100):
    """Train ``model`` with cross-entropy and checkpoint the best weights.

    Uses Adam (lr=1e-4, weight_decay=1e-3) with a StepLR schedule that
    multiplies the learning rate by 0.1 every 50 epochs. After each epoch
    the model is evaluated on ``test_loader`` and its ``state_dict`` is
    written to ``model_path`` whenever the accuracy is strictly better than
    every accuracy recorded so far (the first recorded accuracy is always
    saved when no checkpoint was loaded).

    Args:
        model: Network to train; it is moved to CUDA when available.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        test_loader: Loader yielding ``(inputs, labels)`` evaluation batches;
            must expose ``.dataset`` so the sample count is known.
        model_path: File the best ``state_dict`` is saved to / loaded from.
        load: When True and ``model_path`` exists, start from the stored
            weights and use their test accuracy as the baseline to beat.
        epochs: Number of training epochs. Defaults to 100.

    Returns:
        Tuple ``(avg_losses, avg_accuracies)``. ``avg_losses`` has one entry
        per epoch. ``avg_accuracies`` has one entry per epoch, preceded by
        the loaded checkpoint's accuracy when a checkpoint was loaded.

    Raises:
        RuntimeError: If the checkpoint exists but cannot be loaded into
            ``model``; the original error is chained as ``__cause__``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    avg_losses, avg_accuracies = [], []

    if load and os.path.exists(model_path):
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host
        try:
            model.load_state_dict(torch.load(model_path, map_location=device))
        except Exception as exc:
            # The file exists, so report the real failure instead of "not found"
            raise RuntimeError(f"Failed to load model weights from {model_path}") from exc
        print("Model loaded")

        # Baseline accuracy: later epochs must beat it to overwrite the checkpoint
        avg_accuracies.append(_evaluate_accuracy(model, test_loader, device))
    else:
        # A bare filename has no directory component; makedirs('') would raise
        checkpoint_dir = os.path.dirname(model_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        if os.path.exists(model_path):
            print("Existing checkpoint ignored (load=False)")
        else:
            print("Model not found")

    # Model training
    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        # Testing
        avg_loss = running_loss / len(train_loader)
        avg_accuracy = _evaluate_accuracy(model, test_loader, device)

        # Best accuracy seen before this epoch (None when nothing is recorded yet)
        best_so_far = max(avg_accuracies, default=None)

        avg_losses.append(avg_loss)
        avg_accuracies.append(avg_accuracy)

        print(f'Epoch [{epoch + 1}/ {epochs}]: Loss: {avg_loss:.4f} | Accuracy: {avg_accuracy:.2f}')

        if best_so_far is None or avg_accuracy > best_so_far:
            torch.save(model.state_dict(), model_path)
            print("Model saved")

    return avg_losses, avg_accuracies



In [0]:
# Teacher/student networks with 100-way heads, one output per CIFAR-100 class
teacher_model, student_model = prepare_models(num_classes=100)
# CIFAR-100 train/test loaders; batch size spelled out (same as the function default)
train_loader, test_loader = get_cifar100_dataloaders(batch_size=256)

In [0]:
# Checkpoint file that receives the teacher's best-scoring weights
model_path = '/dbfs/research2/best_model.pth'
# Train the teacher; `load` is left at its default of False
avg_losses, avg_accuracies = train(
    model=teacher_model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_path=model_path,
)


In [0]:
def train_student_with_kd(student, teacher, train_loader, test_loader, device, epochs=30,
                          save_path='student_kd_model.pth'):
    """Train student with knowledge distillation.

    The teacher is kept in eval mode and only queried under ``no_grad``;
    the student is optimised with Adam (lr=1e-3) and a StepLR schedule
    (x0.1 every 10 epochs). The student is validated every 5th epoch and
    after the final epoch, and its ``state_dict`` is saved to ``save_path``
    whenever the validation accuracy improves.

    Args:
        student: Network to be trained.
        teacher: Network providing the target logits (not updated).
        train_loader: Loader yielding ``(data, labels)`` training batches.
        test_loader: Loader passed to ``evaluate_model`` for validation.
        device: Device on which both models and all batches are placed.
        epochs: Number of training epochs. Defaults to 30.
        save_path: File receiving the best student weights. Defaults to the
            original ``'student_kd_model.pth'``.

    Returns:
        Best validation accuracy observed, as returned by ``evaluate_model``
        (presumably a percentage, since it is printed with ``%`` - confirm
        against that helper).
    """
    # Batches are moved to `device` below, so both networks must live there too
    student.to(device)
    teacher.to(device)

    # NOTE(review): KnowledgeDistillationLoss and evaluate_model are defined
    # outside this block; the criterion is expected to return
    # (total_loss, kd_loss, ce_loss) as unpacked in the loop below.
    kd_criterion = KnowledgeDistillationLoss(alpha=0.7, temperature=4.0)
    optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    teacher.eval()  # Teacher in eval mode

    # Make sure the checkpoint directory exists (a bare filename has none)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print("Training student with knowledge distillation...")
    best_acc = 0

    for epoch in range(epochs):
        student.train()
        running_loss = 0.0
        running_kd_loss = 0.0
        running_ce_loss = 0.0
        num_batches = 0

        for batch_idx, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()

            # Get predictions; the teacher is frozen, so skip its autograd graph
            with torch.no_grad():
                teacher_logits = teacher(data)
            student_logits = student(data)

            # Compute KD loss
            total_loss, kd_loss, ce_loss = kd_criterion(student_logits, teacher_logits, labels)

            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            running_kd_loss += kd_loss.item()
            running_ce_loss += ce_loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f'KD Epoch {epoch}, Batch {batch_idx}, Total: {total_loss.item():.4f}, '
                      f'KD: {kd_loss.item():.4f}, CE: {ce_loss.item():.4f}')

        scheduler.step()

        # Report the epoch averages that the running sums were collected for
        if num_batches:
            print(f'KD Epoch {epoch} average, Total: {running_loss / num_batches:.4f}, '
                  f'KD: {running_kd_loss / num_batches:.4f}, CE: {running_ce_loss / num_batches:.4f}')

        # Validation: every 5th epoch, plus the last one so the final weights
        # are always scored and can be saved
        if epoch % 5 == 0 or epoch == epochs - 1:
            acc = evaluate_model(student, test_loader, device)
            print(f'Student Epoch {epoch}, Test Accuracy: {acc:.2f}%')
            if acc > best_acc:
                best_acc = acc
                torch.save(student.state_dict(), save_path)

    print(f'Student KD training complete. Best accuracy: {best_acc:.2f}%')
    return best_acc