In [1]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
# import kagglehub
# ifigotin_imagenetmini_1000_path = kagglehub.dataset_download('ifigotin/imagenetmini-1000')

# print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/datasets/download/ifigotin/imagenetmini-1000?dataset_version_number=1...


100%|██████████| 3.92G/3.92G [02:21<00:00, 29.8MB/s]

Extracting files...





Data source import complete.


In [2]:
# !ls -a /root/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Root of the ImageNet-mini dataset; expected to contain `train/` and `val/`
# ImageFolder-style sub-directories. If the data was fetched with kagglehub
# instead of being attached as a Kaggle input, point this at the directory
# that kagglehub.dataset_download(...) returned.
dataset_path = "/kaggle/input/imagenetmini-1000/imagenet-mini"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

class KnowledgeDistillationLoss(nn.Module):
    """
    Knowledge-distillation loss (soft teacher targets + hard labels).

    With temperature T and weight alpha:

        loss = alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))
               + (1 - alpha) * CrossEntropy(student, labels)
    """

    def __init__(self, temperature=3.0, alpha=0.85):
        """
        Args:
            temperature (float): Temperature T used to soften both the student
                and the teacher logits before comparing them.
            alpha (float): Weight of the soft (distillation) term; the hard
                cross-entropy term receives weight ``1 - alpha``.
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        # 'batchmean' (sum over classes, mean over the batch) is the reduction
        # that matches the mathematical definition of KL divergence.
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        """
        Compute the combined distillation loss.

        Args:
            student_logits (torch.Tensor): Logits from the student model.
            teacher_logits (torch.Tensor): Logits from the teacher model.
            labels (torch.Tensor): Ground-truth class indices.

        Returns:
            torch.Tensor: Scalar loss.
        """
        temperature = self.temperature

        # The teacher is a fixed target: never let gradients reach it, even if
        # the caller did not compute its logits under torch.no_grad().
        teacher_logits = teacher_logits.detach()

        # KLDivLoss expects log-probabilities as input and probabilities as target.
        # The T**2 factor keeps the soft-term gradient scale comparable across
        # temperatures (softening by T shrinks its gradients by ~1/T**2).
        soft_target_loss = self.kl_div_loss(
            torch.log_softmax(student_logits / temperature, dim=1),
            torch.softmax(teacher_logits / temperature, dim=1),
        ) * (temperature ** 2)

        hard_target_loss = self.ce_loss(student_logits, labels)

        return self.alpha * soft_target_loss + (1 - self.alpha) * hard_target_loss

def create_model(num_classes, is_teacher=False):
    """
    Create a ResNet model for distillation (teacher or student).

    The teacher is a ResNet-101 loaded with torchvision's IMAGENET1K_V2
    weights; the student is a ResNet-34 with random initialisation.

    Args:
        num_classes (int): Number of output classes.
        is_teacher (bool): Build the teacher (True) or the student (False).

    Returns:
        nn.Module: Configured ResNet model.
    """
    if is_teacher:
        model = models.resnet101(weights="IMAGENET1K_V2")
        # Keep the pretrained classifier when the label space matches it.
        # Replacing it with a fresh nn.Linear throws away the learned head and
        # turns the teacher's logits into noise, which defeats distillation.
        # NOTE(review): this assumes the dataset's class indices follow
        # torchvision's ImageNet order (sorted WordNet-ID folder names, as
        # ImageFolder produces for wnid-named directories) — confirm.
        keep_head = model.fc.out_features == num_classes
        if not keep_head:
            import warnings
            warnings.warn(
                f"Teacher classifier head replaced for {num_classes} classes; "
                "it is randomly initialised, so the teacher must be fine-tuned "
                "before its predictions are useful for distillation."
            )
    else:
        model = models.resnet34(weights=None)
        keep_head = False

    if not keep_head:
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

def accuracy(output, target, topk=(1,)):
    """
    Compute the top-k accuracy (in percent) for each requested k.

    Args:
        output (torch.Tensor): Model scores/logits, shape (batch, num_classes).
        target (torch.Tensor): Ground-truth class indices, shape (batch,).
        topk (tuple): Values of k for which to report accuracy.

    Returns:
        list: One 1-element tensor per k with the percentage of samples whose
        target is among the k highest-scoring classes. A k larger than the
        number of classes is clamped (so it reports 100%). An empty batch
        yields 0.0 for every k.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        if batch_size == 0:
            # Avoid dividing by zero; nothing to be correct about.
            return [torch.zeros(1, device=output.device) for _ in topk]

        # Tensor.topk raises if k exceeds the class dimension, so clamp it.
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row r holds each sample's r-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def _train_one_epoch(student_model, teacher_model, train_loader, optimizer,
                     kd_loss_fn, device, epoch, log_every=50):
    """Run one distillation epoch over ``train_loader``; return the mean batch loss."""
    student_model.train()
    # Keep the frozen teacher in inference mode (e.g. BatchNorm uses running stats).
    teacher_model.eval()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # The teacher only supplies targets, so don't build an autograd graph for it.
        with torch.no_grad():
            teacher_logits = teacher_model(inputs)
        student_logits = student_model(inputs)

        loss = kd_loss_fn(student_logits, teacher_logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if batch_idx % log_every == 0:
            # Batch accuracy is only used for this log line, so compute it lazily.
            acc1, acc5 = accuracy(student_logits, labels, (1, 5))
            print(f"epoch: {epoch} || batch: {batch_idx}, loss: {batch_loss}, acc@1: {acc1.item()}, acc@5: {acc5.item()}")

    # An empty loader would otherwise divide by zero.
    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy (%) of ``model`` over ``loader`` (0.0 if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_knowledge_distillation(train_loader, test_loader, num_classes, device='cuda',
                                 num_epochs=100, eval_every=10, lr=0.001,
                                 save_path='best_student_model.pth'):
    """
    Perform knowledge-distillation training of the student from the teacher.

    Builds both models with ``create_model``, freezes the teacher, and trains
    the student with SGD on ``KnowledgeDistillationLoss``. The student is
    validated every ``eval_every`` epochs and after the final epoch. Its
    weights are saved to ``save_path`` whenever validation accuracy improves.

    Args:
        train_loader (DataLoader): Training data loader.
        test_loader (DataLoader): Validation data loader.
        num_classes (int): Number of output classes.
        device (str): Computing device.
        num_epochs (int): Number of training epochs.
        eval_every (int): Validate every this many epochs (positive integer).
        lr (float): SGD learning rate.
        save_path (str): Where to save the best student ``state_dict``.

    Returns:
        float: Best student top-1 validation accuracy (%) observed.
    """
    # Initialize teacher and student models
    teacher_model = create_model(num_classes, is_teacher=True).to(device)
    student_model = create_model(num_classes, is_teacher=False).to(device)

    # Freeze teacher model parameters
    for param in teacher_model.parameters():
        param.requires_grad = False
    teacher_model.eval()

    optimizer = optim.SGD(student_model.parameters(), lr=lr, momentum=0.9,
                          nesterov=False, weight_decay=2e-05)
    kd_loss_fn = KnowledgeDistillationLoss()

    best_accuracy = 0.0
    # The teacher is frozen, so its validation accuracy is constant; compute it once.
    teacher_accuracy = None

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(student_model, teacher_model, train_loader,
                                    optimizer, kd_loss_fn, device, epoch)

        # Also validate after the last epoch so the final weights are considered.
        is_last_epoch = epoch == num_epochs - 1
        if epoch % eval_every != 0 and not is_last_epoch:
            continue

        v_accuracy = _evaluate(student_model, test_loader, device)
        if teacher_accuracy is None:
            teacher_accuracy = _evaluate(teacher_model, test_loader, device)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {v_accuracy:.2f}%, Teacher Accuracy: {teacher_accuracy:.2f}%')

        # Save best model
        if v_accuracy > best_accuracy:
            best_accuracy = v_accuracy
            torch.save(student_model.state_dict(), save_path)

    return best_accuracy
                
                
def main():
    """
    Build ImageNet-mini train/val loaders and run knowledge distillation.

    Reads ``train`` and ``val`` ImageFolder directories under the module-level
    ``dataset_path``. It uses the preprocessing bundled with the teacher's
    ResNet-101 weights and trains on GPU when one is available, else on CPU.
    """
    import os

    weights101 = models.ResNet101_Weights.IMAGENET1K_V2
    # Use the preprocessing torchvision ships with the teacher's pretrained
    # weights, so the teacher sees inputs the way those weights expect.
    # NOTE(review): the same deterministic transform is used for training, so
    # the student gets no data augmentation — consider adding some.
    transform = weights101.transforms()

    train_dataset = ImageFolder(os.path.join(dataset_path, 'train'), transform=transform)
    test_dataset = ImageFolder(os.path.join(dataset_path, 'val'), transform=transform)

    # ImageFolder assigns label indices per directory from its sorted class
    # list; if the splits disagree, validation labels are silently misaligned.
    if test_dataset.classes != train_dataset.classes:
        raise ValueError(
            f"train/val class mismatch: {len(train_dataset.classes)} train classes "
            f"vs {len(test_dataset.classes)} val classes"
        )

    # Fall back to CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Image decoding dominates the loop; parallel workers keep the GPU fed.
    num_workers = min(4, os.cpu_count() or 1)
    pin_memory = device == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    num_classes = len(train_dataset.classes)

    train_knowledge_distillation(train_loader, test_loader, num_classes, device=device)

In [None]:
# Launch the full distillation run (long-running: trains for many epochs).
if __name__ == "__main__":
    main()