In [None]:
# # IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# # THEN FEEL FREE TO DELETE THIS CELL.
# # NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# # ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# # NOTEBOOK.
# import kagglehub
# ifigotin_imagenetmini_1000_path = kagglehub.dataset_download('ifigotin/imagenetmini-1000')
# sibgatulislam_distilled_resnet101_pytorch_default_1_path = kagglehub.model_download('sibgatulislam/distilled_resnet101/PyTorch/default/1')

# print('Data source import complete.')

In [None]:
# !ls -a /root/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini

In [None]:
# best_model_path = "/root/.cache/kagglehub/models/sibgatulislam/distilled_resnet101/PyTorch/default/1/best_model.pth"
# dataset_path = "/root/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini"

In [None]:
!pip install tqdm -q
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import tqdm

In [None]:
# Pretrained ResNet-101 weight set; it is reused below for the teacher model.
weights101 = models.ResNet101_Weights.DEFAULT

# Preprocessing preset that ships with the weights. It is stored under its own
# name so the `transforms` module imported from torchvision is not overwritten.
preprocess = weights101.transforms()

dataset_path = "/kaggle/input/imagenetmini-1000/imagenet-mini"

# Training split: one sub-directory per class under <dataset_path>/train
train_dataset = datasets.ImageFolder(
    dataset_path + '/train',
    transform=preprocess
)

# Validation split: same layout and same preprocessing under <dataset_path>/val
val_dataset = datasets.ImageFolder(
    dataset_path + '/val',
    transform=preprocess
)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Distillation hyperparameters: softmax temperature and the weight given to the
# soft-target (teacher) term of the loss
TEMPERATURE = 3.0
ALPHA = 0.95

# Shuffled mini-batches of the training split
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)

# Teacher: ResNet-101 initialised from the pretrained weight set chosen earlier
teacher_model = models.resnet101(weights=weights101).to(device)

# Student: ResNet-34 created without pretrained weights
student_model = models.resnet34(weights=None).to(device)

# The teacher is never updated, so gradient tracking is switched off for all of
# its parameters
teacher_model.requires_grad_(False)

# Hard-label loss and soft-target loss used by the distillation objective
criterion = nn.CrossEntropyLoss()
distillation_loss = nn.KLDivLoss(reduction='batchmean')

# Only the student's parameters are handed to the optimizer
optimizer = optim.SGD(
    student_model.parameters(),
    momentum=0.9,
    nesterov=True,
    lr=0.1,
)


In [None]:
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k listed in ``topk``.

    Args:
        output: score tensor; the k best entries are taken along dim 1, so it
            is expected to be laid out as (batch, classes).
        target: tensor of class indices with one entry per sample.
        topk: iterable of k values to report.

    Returns:
        A list with one single-element float tensor per requested k, in the
        same order as ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the best `largest_k` scores per sample, transposed so that
        # row r holds every sample's (r + 1)-th guess.
        guesses = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = guesses.eq(target.view(1, -1).expand_as(guesses))

        to_percent = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent)
            for k in topk
        ]

# Knowledge Distillation Function
def distillation_loss_fn(student_logits, teacher_logits, labels):
    """Blend the soft-target distillation loss with the hard-label loss.

    Uses the module-level ``criterion``, ``distillation_loss``, ``TEMPERATURE``
    and ``ALPHA``.

    Args:
        student_logits: raw student outputs; softmax is taken along dim 1.
        teacher_logits: raw teacher outputs; softmax is taken along dim 1.
        labels: ground-truth targets passed to ``criterion``.

    Returns:
        ``ALPHA * soft_term + (1 - ALPHA) * hard_term``.
    """
    # Temperature-softened distributions: log-probabilities for the student,
    # probabilities for the teacher.
    student_log_probs = nn.functional.log_softmax(student_logits / TEMPERATURE, dim=1)
    teacher_probs = nn.functional.softmax(teacher_logits / TEMPERATURE, dim=1)

    # The soft term is scaled by TEMPERATURE ** 2.
    soft_term = distillation_loss(student_log_probs, teacher_probs) * (TEMPERATURE ** 2)

    # Ordinary loss of the student against the true labels.
    hard_term = criterion(student_logits, labels)

    return ALPHA * soft_term + (1 - ALPHA) * hard_term

# Function to evaluate model accuracy
def evaluate_model(model, dataloader, device):
    """Measure top-1 and top-5 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left in that mode. The number of
    batches is printed before the pass starts.

    Args:
        model: callable module mapping a batch of inputs to per-class scores.
        dataloader: sized iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(top1_accuracy, top5_accuracy)`` as percentages.
    """
    model.eval()
    print("total batches: ", len(dataloader))

    hits_top1 = 0
    hits_top5 = 0
    seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            scores = model(batch_inputs)

            # Top-1: the single highest-scoring class must equal the label
            best_class = scores.max(1).indices
            hits_top1 += (best_class == batch_labels).sum().item()

            # Top-5: the label must appear among the five highest-scoring classes
            best_five = scores.topk(5, dim=1).indices
            hits_top5 += (best_five == batch_labels.view(-1, 1)).sum().item()

            seen += batch_labels.size(0)

    return 100 * hits_top1 / seen, 100 * hits_top5 / seen

# Validation batches are served in dataset order (no shuffling)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=128)

# Training Loop
def train_student(teacher_model, student_model, train_dataloader, val_loader, optimizer, epochs=10, eval_every=10):
    """Train ``student_model`` with knowledge distillation from ``teacher_model``.

    Each batch is run through both networks; the student is optimised on
    ``distillation_loss_fn(student_logits, teacher_logits, labels)``. Batches
    are moved to the module-level ``device``.

    Every ``eval_every`` epochs the student is validated on ``val_loader`` and
    a checkpoint ``./best_model-<epoch>.pth`` is written. ``./best_model.pth``
    is additionally written when, at such a validation point, the training
    loss is lower than at the previous best checkpoint and both validation
    accuracies exceed the best seen so far (the starting bar is 50% for each).

    Args:
        teacher_model: network providing the soft targets; kept in eval mode.
        student_model: network being trained.
        train_dataloader: sized iterable of ``(images, labels)`` training batches.
        val_loader: loader passed to ``evaluate_model`` for validation.
        optimizer: optimizer stepping the student's parameters.
        epochs: number of passes over ``train_dataloader``.
        eval_every: validate and checkpoint every this many epochs; a falsy
            value disables validation and checkpointing.

    Returns:
        ``student_model`` after training.

    Raises:
        ValueError: if ``train_dataloader`` yields no samples in an epoch.
    """
    teacher_model.eval()

    # Bars a validation result has to beat before ./best_model.pth is written.
    best_loss = float('inf')
    best_acc1 = 50.0
    best_acc5 = 50.0

    num_batches = len(train_dataloader)

    for epoch in range(epochs):
        # evaluate_model() leaves the student in eval mode, so training mode
        # has to be re-enabled at the start of every epoch.
        student_model.train()

        epoch_loss = 0.0
        epoch_acc1 = 0.0
        epoch_acc5 = 0.0
        num_samples = 0

        for i, (images, labels) in enumerate(train_dataloader):
            inputs, labels = images.to(device), labels.to(device)

            # Teacher forward pass without building an autograd graph
            with torch.no_grad():
                teacher_logits = teacher_model(inputs)

            # Student forward pass
            student_logits = student_model(inputs)

            # Combined distillation / hard-label loss
            loss = distillation_loss_fn(student_logits, teacher_logits, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc1, acc5 = accuracy(student_logits, labels, (1, 5))

            # Accumulate sample-weighted sums so the epoch averages stay exact
            # even when the last batch is smaller than the others.
            batch_size = images.size(0)
            num_samples += batch_size
            epoch_loss += loss.item() * batch_size
            epoch_acc1 += acc1.item() * batch_size
            epoch_acc5 += acc5.item() * batch_size

            if i % 100 == 0:
                print(
                    f"epoch: {epoch+1} || batch: {i+1}/{num_batches} || loss: {loss.item():.4f} || acc1: {acc1.item():.2f} || acc5: {acc5.item():.2f}"
                )

        if num_samples == 0:
            raise ValueError("train_dataloader yielded no samples")

        # Average over the samples actually seen through `train_dataloader`,
        # not over whatever loader happens to exist at module level.
        epoch_loss /= num_samples
        epoch_acc1 /= num_samples
        epoch_acc5 /= num_samples

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Loss: {epoch_loss:.4f}')
        print(f'Accuracy@1: {epoch_acc1:.2f}%')
        print(f'Accuracy@5: {epoch_acc5:.2f}%')

        if eval_every and (epoch + 1) % eval_every == 0:
            top1_acc, top5_acc = evaluate_model(
                student_model,
                val_loader,
                device
            )

            # Periodic checkpoint, written at every validation point
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'acc@1': top1_acc,
                'acc@5': top5_acc,
            }, f'./best_model-{epoch+1}.pth')

            print(f"Student Model with Distillation - Top-1 Accuracy: {top1_acc:.2f}%, Top-5 Accuracy: {top5_acc:.2f}%")

            # Best checkpoint: only judged against validation numbers measured
            # on the weights being saved, never against stale ones.
            if epoch_loss < best_loss and best_acc1 < top1_acc and best_acc5 < top5_acc:
                best_acc1 = top1_acc
                best_acc5 = top5_acc
                best_loss = epoch_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': student_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_loss,
                    'best_acc@1': best_acc1,
                    'best_acc@5': best_acc5,
                    'accuracy': epoch_acc1
                }, './best_model.pth')

    return student_model

In [None]:
# Train the student for 50 epochs with the teacher providing soft targets
trained_student_model = train_student(teacher_model, student_model, train_loader, val_loader, optimizer, epochs=50)

In [None]:
# trained_student_model = models.resnet34(weights=None)
# trained_student_model.load_state_dict(torch.load(best_model_path, weights_only=True)["model_state_dict"])
# trained_student_model.to(device)

In [None]:
# Teacher on the validation split
top1_teacher, top5_teacher = evaluate_model(
    model=teacher_model,
    dataloader=val_loader,
    device=device,
)
print(f"Teacher Model - Top-1 Accuracy: {top1_teacher:.2f}%, Top-5 Accuracy: {top5_teacher:.2f}%")

# Baseline: a ResNet-34 loaded with torchvision's default pretrained weights.
# It is evaluated as loaded; no training happens in this cell.
student_model_wokd = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
student_model_wokd = student_model_wokd.to(device)
top1_student_baseline, top5_student_baseline = evaluate_model(
    model=student_model_wokd,
    dataloader=val_loader,
    device=device,
)
print(f"Student Model Baseline - Top-1 Accuracy: {top1_student_baseline:.2f}%, Top-5 Accuracy: {top5_student_baseline:.2f}%")

# Student returned by the distillation training run
top1_student_distilled, top5_student_distilled = evaluate_model(
    model=trained_student_model,
    dataloader=val_loader,
    device=device,
)
print(f"Student Model with Distillation - Top-1 Accuracy: {top1_student_distilled:.2f}%, Top-5 Accuracy: {top5_student_distilled:.2f}%")

In [None]:
def tr_evaluate_model(model, dataloader, device):
    """Measure top-1 and top-5 accuracy on at most the first 31 batches.

    Works like a regular evaluation pass but stops when the 32nd batch is
    reached: that batch is reported (its number and label count are printed)
    and the loop ends without scoring it. The model is switched to eval mode
    and left in that mode.

    Args:
        model: callable module mapping a batch of inputs to per-class scores.
        dataloader: sized iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(top1_accuracy, top5_accuracy)`` as percentages over the scored batches.
    """
    model.eval()
    print("total batches: ", len(dataloader))

    hits_top1 = 0
    hits_top5 = 0
    seen = 0

    with torch.no_grad():
        for batch_number, (batch_inputs, batch_labels) in enumerate(dataloader, start=1):
            # Stop at the 32nd batch; it is announced but not evaluated
            if batch_number == 32:
                print(f"batch: {batch_number}")
                print("label size: ", batch_labels.size(0))
                break

            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            scores = model(batch_inputs)

            # Top-1: the single highest-scoring class must equal the label
            best_class = scores.max(1).indices
            hits_top1 += (best_class == batch_labels).sum().item()

            # Top-5: the label must appear among the five highest-scoring classes
            best_five = scores.topk(5, dim=1).indices
            hits_top5 += (best_five == batch_labels.view(-1, 1)).sum().item()

            seen += batch_labels.size(0)

    return 100 * hits_top1 / seen, 100 * hits_top5 / seen


In [None]:
# top1_student_distilled, top5_student_distilled = tr_evaluate_model(trained_student_model, DataLoader(train_dataset, batch_size=512, shuffle=True), device)
# print(f"Student Model with Distillation - Top-1 Accuracy: {top1_student_distilled:.2f}%, Top-5 Accuracy: {top5_student_distilled:.2f}%")

In [None]:
# Ask PyTorch's CUDA caching allocator to release the cached memory it is not using.
torch.cuda.empty_cache()