# Knowledge Distillation Project

## Install Dependencies and Set Environment

In [None]:
# Install the third-party libraries used in this notebook (IPython shell magic; --quiet suppresses most pip output).
!pip install pytorch-lightning wandb torchvision timm detectors torchmetrics --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.6/51.6 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.2/69.2 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.3/819.3 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m113.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m92.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m616.8/616.8 kB[0m [31m43.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m927.3/927.3 kB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import wandb
# Authenticate with Weights & Biases; relogin=True asks for the API key again even if one is already stored.
wandb.login(relogin=True)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Standard library imports
import os
from typing import List, Optional
from tqdm import tqdm

# Third-party imports
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
from torch.utils.data import Subset
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from torch.nn import CosineSimilarity, CrossEntropyLoss, KLDivLoss
from torch.utils.data import DataLoader, random_split
from torchmetrics.classification import Accuracy, MulticlassAUROC
from torchvision import datasets, models, transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
import timm

# Local application/library-specific imports
import detectors




## Getting the Dataset

- We train on the CIFAR-100 dataset, which has 100 classes and 60K images: 45K for training, 5K for validation (a stratified split of the official 50K training set), and 10K for testing.
- Since we will be implementing Teacher-Student Knowledge Distillation, we have to handle transforms and augmentations differently for the teacher and student. Therefore, we have created a wrapper `DualTransformDataset` to handle it appropriately.
- Batch size: 128 throughout the project.

In [None]:
class DualTransformDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        dataset,
        teacher_transforms: Optional[List[transforms.Compose]] = None,
        student_transform: Optional[transforms.Compose] = None
    ):
        """
        Wrap a dataset so every sample yields one view per teacher plus one student view.

        Args:
            dataset (torch.utils.data.Dataset): Underlying dataset; indexing it must return
                an ``(image, label)`` pair.
            teacher_transforms (list of torchvision.transforms.Compose, optional):
                One pipeline per teacher; the i-th pipeline produces key ``teacher_input_{i}``.
                ``None`` is treated as an empty list (no teacher views).
            student_transform (torchvision.transforms.Compose, optional):
                Pipeline producing key ``student_input``; when ``None`` that key is omitted.
        """
        self.dataset = dataset
        self.teacher_transforms = [] if teacher_transforms is None else teacher_transforms
        self.student_transform = student_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a dict with ``teacher_input_{i}`` views, optional ``student_input``, and ``label``."""
        image, target = self.dataset[idx]

        # The pipelines end in ToTensor, which expects a PIL image (or ndarray), so tensor
        # images are converted back to PIL first.
        if isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image)

        sample = {
            f'teacher_input_{i}': pipeline(image)
            for i, pipeline in enumerate(self.teacher_transforms)
        }
        if self.student_transform is not None:
            sample['student_input'] = self.student_transform(image)
        sample['label'] = target
        return sample


In [None]:
class CIFAR100DataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = '/content/data',
        batch_size: int = 128,
        num_workers: int = -1,
        teacher_models: Optional[List[str]] = None,  # List of teacher model types
        student_model: str = 'mobilenet',
        pre_trained: bool = False,
        val_size: int = 5000,
        seed: int = 42
    ):
        """
        DataModule for CIFAR-100 with support for multiple teacher transforms.

        Args:
            data_dir (str): Path to store CIFAR-100 data.
            batch_size (int): Batch size for the train, validation and test loaders.
            num_workers (int): Number of subprocesses for data loading. A negative value
                (the default, -1) means "use all CPUs" (``os.cpu_count()``); DataLoader
                itself rejects negative values.
            teacher_models (list of str, optional):
                Model types for teacher transforms (e.g., ['resnet', 'densenet']). Teachers
                always receive their deterministic (non-augmented) pipeline, in every split.
            student_model (str): Model type for student transforms.
            pre_trained (bool): Use ImageNet mean/std (True) instead of 0.5 mean/std (False)
                in the generic 224x224 pipelines.
            val_size (int): Number of training images held out (stratified) for validation.
            seed (int): Random seed for the stratified train/validation split.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = self._resolve_num_workers(num_workers)
        self.teacher_models = teacher_models if teacher_models is not None else []
        self.student_model = student_model
        self.val_size = val_size
        self.seed = seed
        self.pre_trained = pre_trained

        # Student: augmented pipeline for training, deterministic pipeline for val/test.
        self.student_transform = self._get_transform(self.student_model, self.pre_trained, is_train=True)
        self.student_transform_val = self._get_transform(self.student_model, self.pre_trained, is_train=False)

        # Teachers: NO augmentation in any split (is_train=False), so their outputs on the
        # training set are computed from the deterministic evaluation view of each image.
        self.teacher_transforms = [
            self._get_transform(model_type, self.pre_trained, is_train=False)
            for model_type in self.teacher_models
        ]
        self.teacher_transforms_val = [
            self._get_transform(model_type, self.pre_trained, is_train=False)
            for model_type in self.teacher_models
        ]

    @staticmethod
    def _resolve_num_workers(num_workers: int) -> int:
        """Map a negative worker count to the number of CPUs (0 if it cannot be determined)."""
        if num_workers < 0:
            return os.cpu_count() or 0
        return num_workers

    def _get_transform(self, model_type: str, pre_trained: bool, is_train: bool = True) -> transforms.Compose:
        """
        Build the preprocessing pipeline for ``model_type``.

        Args:
            model_type (str): One of 'mobilenet', 'resnet', 'densenet', 'vit'.
            pre_trained (bool): Selects ImageNet (True) or 0.5 (False) mean/std for the
                training pipeline and the 'mobilenet' evaluation pipeline.
            is_train (bool): Return the augmented training pipeline instead of the
                deterministic evaluation pipeline.

        Returns:
            transforms.Compose: The preprocessing pipeline.

        Raises:
            ValueError: If ``is_train`` is False and ``model_type`` is not recognized.
        """
        bicubic = transforms.InterpolationMode.BICUBIC
        if pre_trained:
            normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        else:
            normalizer = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                              std=[0.5, 0.5, 0.5])

        if is_train:
            # NOTE(review): the training pipeline does not depend on model_type; it matches the
            # 'mobilenet' evaluation pipeline, but not the 32x32 'resnet'/'densenet' ones.
            return transforms.Compose([
                transforms.Resize((256, 256), interpolation=bicubic),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalizer
            ])

        if model_type == 'mobilenet':
            return transforms.Compose([
                transforms.Resize((256, 256), interpolation=bicubic),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalizer
            ])
        elif model_type == 'resnet':
            # Native 32x32 input with CIFAR-100 channel statistics.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                     std=[0.2673, 0.2564, 0.2762])
            ])
        elif model_type == 'densenet':
            return transforms.Compose([
                transforms.Resize((36, 36), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                     std=[0.2675, 0.2565, 0.2761])
            ])
        elif model_type == 'vit':
            return transforms.Compose([
                transforms.Resize((248, 248), interpolation=bicubic),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
            ])
        else:
            raise ValueError(f"Invalid model type: {model_type}")

    def prepare_data(self):
        """Download CIFAR-100 dataset if not already present."""
        datasets.CIFAR100(root=self.data_dir, train=True, download=True)
        datasets.CIFAR100(root=self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Create the train/val datasets (stratified split) and/or the test dataset."""
        if stage in ('fit', 'validate', None):
            # Raw PIL images; all preprocessing happens inside DualTransformDataset.
            base_train_dataset = datasets.CIFAR100(
                root=self.data_dir,
                train=True,
                transform=None
            )

            # Stratified hold-out keeps the class balance identical in train and val.
            targets = base_train_dataset.targets
            strat_split = StratifiedShuffleSplit(
                n_splits=1,
                test_size=self.val_size,
                random_state=self.seed
            )
            train_idx, val_idx = next(strat_split.split(np.arange(len(targets)), targets))

            self.train_dataset = DualTransformDataset(
                Subset(base_train_dataset, train_idx),
                teacher_transforms=self.teacher_transforms,
                student_transform=self.student_transform
            )
            self.val_dataset = DualTransformDataset(
                Subset(base_train_dataset, val_idx),
                teacher_transforms=self.teacher_transforms_val,
                student_transform=self.student_transform_val
            )

        if stage in ('test', None):
            self.test_dataset = DualTransformDataset(
                datasets.CIFAR100(
                    root=self.data_dir,
                    train=False,
                    transform=None
                ),
                teacher_transforms=self.teacher_transforms_val,
                student_transform=self.student_transform_val
            )

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader using this module's batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
            pin_memory=torch.cuda.is_available()
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

## Getting Pre-Trained Teacher Model

In [None]:
# Teacher: ResNet-50 trained on CIFAR-100, created through timm with pretrained weights.
# NOTE(review): the "resnet50_cifar100" name is presumably registered by the `detectors` import — confirm.
model = timm.create_model("resnet50_cifar100", pretrained=True)

Downloading: "https://huggingface.co/edadaltocg/resnet50_cifar100/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet50_cifar100.pth
100%|██████████| 90.7M/90.7M [00:00<00:00, 217MB/s]


In [None]:
# Place the teacher on the GPU when one is available, otherwise keep it on the CPU.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

## Evaluating Pre-Trained Model

In [None]:
class EvaluationLightningModule(pl.LightningModule):
    def __init__(self, model, input_key: str = 'student_input'):
        """
        LightningModule that only measures top-1 accuracy of an already trained model.

        Args:
            model (torch.nn.Module): The pre-trained model to evaluate.
            input_key (str): Key of the image tensor when batches are dicts (as produced by
                ``DualTransformDataset``); labels are read from ``'label'``. Tuple/list
                batches of ``(images, labels)`` are also accepted.
        """
        super().__init__()
        self.model = model
        self.input_key = input_key

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # No optimizer needed for evaluation
        return None

    def _unpack_batch(self, batch):
        """Return ``(images, labels)`` on this module's device for dict or tuple batches."""
        if isinstance(batch, dict):
            # Unpacking a dict directly would yield its keys (strings), not tensors.
            images, labels = batch[self.input_key], batch['label']
        else:
            images, labels = batch
        return images.to(self.device), labels.to(self.device)

    def _shared_step(self, batch, stage: str):
        """Compute and log batch accuracy under ``{stage}/accuracy``."""
        images, labels = self._unpack_batch(batch)
        logits = self.model(images)
        acc = (logits.argmax(dim=1) == labels).float().mean()
        # Explicit batch_size so the epoch mean is correctly weighted for a short last batch.
        self.log(f'{stage}/accuracy', acc, on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=labels.size(0))
        return acc

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')


In [None]:
# Initialize W&B logger for evaluation
# W&B logger for the evaluation run (project 'Model-Evaluation-CIFAR100', run 'ResNet50-Evaluation').
wandb_logger_eval = WandbLogger(project='Model-Evaluation-CIFAR100', name='ResNet50-Evaluation')


In [None]:
# Initialize Data Module
# The model evaluated below is the ResNet-50 CIFAR-100 teacher, so its images use the 'resnet'
# evaluation pipeline (native 32x32 input, CIFAR-100 mean/std) instead of the default 'mobilenet'
# pipeline (224x224, 0.5 mean/std). Only the validation/test transforms are used for evaluation.
data_module = CIFAR100DataModule(data_dir='/content/data', batch_size=128, num_workers=2, student_model='resnet')


In [None]:
# Initialize PyTorch Lightning Trainer for evaluation with W&B logger
use_cuda = torch.cuda.is_available()

# Trainer for evaluation only, logging to W&B. One device on either accelerator: `devices=None`
# is not accepted by the CPU accelerator in Lightning 2.x. '16-mixed' is the current name for
# 16-bit automatic mixed precision (a plain `16` is the legacy alias).
trainer_eval = pl.Trainer(
    devices=1,
    accelerator='gpu' if use_cuda else 'cpu',
    logger=wandb_logger_eval,
    precision='16-mixed' if use_cuda else '32-true',
    log_every_n_steps=50,
)

# Wrap the teacher so Lightning can run it over the validation split
eval_pl_model = EvaluationLightningModule(model=model)

# Perform evaluation on the validation set
val_results = trainer_eval.validate(eval_pl_model, datamodule=data_module, verbose=True)

# Extract and print validation accuracy
val_accuracy = val_results[0]['val/accuracy']
print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")


INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

Validation Accuracy: 80.91%


## Model Structure

Below are two different `DynamicKDLitModel` variants, each for a different purpose. Their overall structure is similar, but they combine (ensemble) the teachers differently.

### DynamicKDLitModel with weights for each teacher and dynamic alpha

In [None]:
class DynamicKDLitModel(pl.LightningModule):
    def __init__(
        self,
        teacher_models: List[torch.nn.Module],
        student_model: torch.nn.Module,
        temperature: float = 4.0,
        gamma: float = 10.0,
        threshold: float = 0.5,
        learning_rate: float = 1e-3,
        use_soft_loss: bool = True,
        use_hard_loss: bool = True,
        alpha: float = 0.5
    ):
        """
        Initializes the Dynamic Knowledge Distillation Lightning Module with flexible loss components.

        Args:
            teacher_models (list of torch.nn.Module): List of pre-trained teacher models.
                Required if use_soft_loss=True. They are frozen and kept in eval mode.
            student_model (torch.nn.Module): The student model to be trained.
            temperature (float, optional): Temperature for softening probabilities (only relevant if use_soft_loss=True).
            gamma (float, optional): Slope of the sigmoid in the confidence gate.
            threshold (float, optional): Average true-class teacher confidence at which the gate is 0.5.
            learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
            use_soft_loss (bool, optional): Use KL Divergence loss for KD. Defaults to True.
            use_hard_loss (bool, optional): Use Cross-Entropy loss with ground truth. Defaults to True.
            alpha (float, optional): Base weighting factor between soft and hard losses. Defaults to 0.5.

        Raises:
            ValueError: If both losses are disabled, the student is missing, or teachers are
                missing while use_soft_loss=True.
        """
        super().__init__()

        # Save hyperparameters for checkpointing, excluding large model objects
        self.save_hyperparameters(ignore=["teacher_models", "student_model"])

        # Configuration flags
        self.use_soft_loss = use_soft_loss
        self.use_hard_loss = use_hard_loss
        self.alpha = alpha
        self.temperature = temperature
        self.gamma = gamma
        self.threshold = threshold
        self.learning_rate = learning_rate

        if not (self.use_soft_loss or self.use_hard_loss):
            raise ValueError("At least one of 'use_soft_loss' or 'use_hard_loss' must be True.")

        if student_model is None:
            raise ValueError("A student model must be provided.")
        self.student = student_model

        if self.use_soft_loss:
            if not teacher_models:
                raise ValueError("Teacher models must be provided if 'use_soft_loss' is True.")
            # Teachers are held in a plain list, not an nn.ModuleList: they stay out of the
            # checkpoint and are not switched to train mode by Lightning. The flip side is that
            # Lightning does not move them to the device either; see _prepare_teachers().
            self.teachers = teacher_models
            self.num_teachers = len(self.teachers)
            for teacher in self.teachers:
                teacher.eval()
                for param in teacher.parameters():
                    param.requires_grad = False
            self.kl_div_loss = KLDivLoss(reduction='batchmean')
        else:
            self.teachers = None
            self.num_teachers = 0

        if self.use_hard_loss:
            self.ce_loss = CrossEntropyLoss()

        # Metrics
        self.num_classes = 100  # CIFAR-100
        self.val_auroc = MulticlassAUROC(num_classes=self.num_classes, average='macro')
        self.test_auroc = MulticlassAUROC(num_classes=self.num_classes, average='macro')

    def forward(self, x):
        return self.student(x)

    # ------------------------------------------------------------------ teacher handling

    def _prepare_teachers(self) -> None:
        """Move the (unregistered) teachers onto this module's device and keep them in eval mode."""
        for teacher in self.teachers or []:
            teacher.to(self.device)
            teacher.eval()

    def on_fit_start(self) -> None:
        self._prepare_teachers()

    def on_validation_start(self) -> None:
        self._prepare_teachers()

    def on_test_start(self) -> None:
        self._prepare_teachers()

    def _teacher_forward(self, batch) -> List[torch.Tensor]:
        """
        Run every teacher on its own view of the batch (``teacher_input_{i}``) without gradients.

        Raises:
            KeyError: If a teacher's input key is missing from the batch.
        """
        teacher_logits = []
        with torch.no_grad():
            for i, teacher in enumerate(self.teachers or []):
                teacher_input_key = f'teacher_input_{i}'
                if teacher_input_key not in batch:
                    raise KeyError(f"Batch is missing key: {teacher_input_key}")
                teacher_logits.append(teacher(batch[teacher_input_key].to(self.device)))
        return teacher_logits

    # ------------------------------------------------------------------ optimization

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.
        """
        # Only the student is optimized; teachers are frozen.
        optimizer = torch.optim.AdamW(
            self.student.parameters(),
            lr=self.learning_rate,
            weight_decay=1e-4  # Regularization strength
        )

        # Cosine annealing over T_max=100 scheduler steps (stepped per epoch by default).
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=100,
            eta_min=1e-6  # Minimum learning rate after annealing
        )

        return [optimizer], [scheduler]

    # ------------------------------------------------------------------ loss building blocks

    def compute_weights(self, teacher_logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
        """
        Computes per-sample weights for each teacher from its cross-entropy with the ground truth:

            w_k = (1 / (K - 1)) * (1 - softmax(CE)_k)   for K > 1,
            w_1 = 1                                      for K = 1.

        Teachers with a lower loss get a larger weight. For K > 1, each row sums to 1.

        Args:
            teacher_logits (List[torch.Tensor]): List of logits from each teacher. Each tensor has shape (batch_size, C).
            labels (torch.Tensor): Ground truth labels. Shape: (batch_size,).

        Returns:
            torch.Tensor: Weights for each teacher. Shape: (batch_size, K).

        Raises:
            ValueError: If no teacher logits are given.
        """
        num_teachers = len(teacher_logits)
        if num_teachers == 0:
            raise ValueError("compute_weights requires at least one teacher.")

        ce_losses = torch.stack(
            [F.cross_entropy(logits, labels, reduction='none') for logits in teacher_logits],
            dim=1
        )  # Shape: (batch_size, K)

        if num_teachers == 1:
            # The (K - 1) normalizer is undefined for a single teacher; it gets the full weight.
            return torch.ones_like(ce_losses)

        # softmax(CE) == exp(CE) / sum_j exp(CE_j), computed without overflow for large losses.
        return (1.0 - F.softmax(ce_losses, dim=1)) / (num_teachers - 1)

    def aggregate_teacher_logits(self, teacher_logits: List[torch.Tensor]) -> torch.Tensor:
        """
        Aggregates logits from multiple teacher models by averaging their softened probabilities.

        Args:
            teacher_logits (list of torch.Tensor): List of logits from each teacher. Each tensor has shape (batch_size, C).

        Returns:
            torch.Tensor: Averaged softened probabilities. Shape: (batch_size, C).
        """
        teacher_softmax = [F.softmax(logit / self.temperature, dim=1) for logit in teacher_logits]
        return torch.mean(torch.stack(teacher_softmax, dim=0), dim=0)

    def compute_gate(self, avg_teacher_conf: torch.Tensor) -> torch.Tensor:
        """
        Computes the batch gate value from average teacher confidence.

        Args:
            avg_teacher_conf (torch.Tensor): Average confidence of teachers for each sample. Shape: (batch_size,)

        Returns:
            torch.Tensor: Scalar gate, mean over the batch of sigmoid(gamma * (conf - threshold)).
        """
        gate = torch.sigmoid(self.gamma * (avg_teacher_conf - self.threshold))  # Shape: (batch_size,)
        return torch.mean(gate)

    def compute_average_teacher_confidence(self, teacher_logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
        """
        Computes the average probability (at temperature 1) that the teachers assign to the true class.

        Args:
            teacher_logits (list of torch.Tensor): List of logits from each teacher. Each tensor has shape (batch_size, C).
            labels (torch.Tensor): Ground truth labels. Shape: (batch_size,)

        Returns:
            torch.Tensor: Average teacher confidence per sample. Shape: (batch_size,)
        """
        true_class = labels.view(-1, 1)
        per_teacher = torch.stack(
            [F.softmax(logits, dim=1).gather(1, true_class).squeeze(1) for logits in teacher_logits],
            dim=1
        )  # Shape: (batch_size, K)
        return per_teacher.mean(dim=1)

    def _weighted_kd_loss(self, student_logits, labels, teacher_logits) -> torch.Tensor:
        """
        Sum over teachers of batch-averaged weight * KL(teacher_T || student_T) * T^2.

        The T^2 factor keeps the soft-loss gradient scale comparable across temperatures.
        """
        teacher_weights = self.compute_weights(teacher_logits, labels).mean(dim=0)  # Shape: (K,)
        # The student's softened log-probabilities are shared by all teachers.
        log_student = F.log_softmax(student_logits / self.temperature, dim=1)
        per_teacher = torch.stack([
            self.kl_div_loss(log_student, F.softmax(logits / self.temperature, dim=1))  # batchmean
            for logits in teacher_logits
        ])  # Shape: (K,)
        return (teacher_weights * per_teacher).sum() * (self.temperature ** 2)

    def _compute_loss_terms(self, student_logits, labels, teacher_logits):
        """
        Build the total loss and its components, shared by training and evaluation.

        Combination rules:
            - soft + hard, K > 1: gate * soft + (1 - gate) * alpha * hard (confidence gate)
            - soft + hard, K = 1: alpha * soft + (1 - alpha) * hard
            - soft only:          soft
            - hard only:          alpha * hard

        Returns:
            tuple: ``(loss_total, terms)`` where ``terms`` maps component names
            ('loss_soft', 'loss_hard', 'gate', 'avg_teacher_conf', and for the gated case
            'loss_kd_scaled', 'loss_hard_scaled') to tensors.
        """
        terms = {}
        kd_loss = None
        loss_hard = None
        if self.use_soft_loss:
            kd_loss = self._weighted_kd_loss(student_logits, labels, teacher_logits)
            terms['loss_soft'] = kd_loss
        if self.use_hard_loss:
            loss_hard = self.ce_loss(student_logits, labels)
            terms['loss_hard'] = loss_hard

        if not self.use_soft_loss:
            # Hard loss only: teachers are not involved, the gate is fully closed.
            terms['gate'] = torch.tensor(0.0, device=self.device)
            terms['avg_teacher_conf'] = torch.tensor(0.0, device=self.device)
            return self.alpha * loss_hard, terms

        avg_teacher_conf = self.compute_average_teacher_confidence(teacher_logits, labels)  # (batch_size,)
        terms['avg_teacher_conf'] = avg_teacher_conf.mean()

        if not self.use_hard_loss:
            loss_total = kd_loss
            terms['gate'] = torch.tensor(1.0, device=self.device)  # Gate is fully open
        elif self.num_teachers > 1:
            gate = self.compute_gate(avg_teacher_conf)
            scaled_kd = gate * kd_loss
            scaled_hard = (1 - gate) * (self.alpha * loss_hard)
            loss_total = scaled_kd + scaled_hard
            terms['gate'] = gate
            terms['loss_kd_scaled'] = scaled_kd
            terms['loss_hard_scaled'] = scaled_hard
        else:
            # Single teacher: fixed combination without gate.
            loss_total = self.alpha * kd_loss + (1 - self.alpha) * loss_hard
            terms['gate'] = torch.tensor(1.0, device=self.device)
        return loss_total, terms

    def compute_losses(self, student_logits, labels, teacher_logits):
        """
        Compute the total training loss and log its components under ``train/``.

        Args:
            student_logits (torch.Tensor): Logits from the student model. Shape: (batch_size, C)
            labels (torch.Tensor): Ground truth labels. Shape: (batch_size,)
            teacher_logits (list of torch.Tensor): List of logits from each teacher. Each tensor has shape: (batch_size, C)

        Returns:
            torch.Tensor: Total loss.
        """
        loss_total, terms = self._compute_loss_terms(student_logits, labels, teacher_logits)
        zero = torch.tensor(0.0, device=self.device)
        self.log_dict({
            'train/loss_soft': terms.get('loss_soft', zero),
            'train/loss_hard': terms.get('loss_hard', zero),
            'train/gate': terms.get('gate', zero),
            'train/loss_total': loss_total,
            'train/avg_teacher_conf': terms.get('avg_teacher_conf', zero)
        }, on_epoch=True, prog_bar=True)
        return loss_total

    # ------------------------------------------------------------------ steps

    def training_step(self, batch, batch_idx):
        """
        Training step for the Lightning Module.

        Args:
            batch (dict): Batch of data containing teacher inputs, student inputs, and labels.
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Total loss.
        """
        labels = batch['label'].to(self.device)
        student_logits = self.student(batch['student_input'].to(self.device))
        teacher_logits = self._teacher_forward(batch)
        return self.compute_losses(student_logits, labels, teacher_logits)

    def shared_eval_step(self, batch, stage: str):
        """
        Shared validation/test step: logs the same loss terms as training (under ``{stage}/``),
        plus accuracy and macro AUROC.

        Args:
            batch (dict): Batch of data.
            stage (str): 'val' or 'test'.

        Returns:
            torch.Tensor: Accuracy for the batch.
        """
        labels = batch['label'].to(self.device)
        student_logits = self.student(batch['student_input'].to(self.device))
        teacher_logits = self._teacher_forward(batch)

        loss_total, terms = self._compute_loss_terms(student_logits, labels, teacher_logits)
        log_dict = {f'{stage}/{name}': value for name, value in terms.items()}
        log_dict[f'{stage}/loss_total'] = loss_total
        self.log_dict(log_dict, on_epoch=True, prog_bar=True)

        preds = student_logits.argmax(dim=1)
        acc = (preds == labels).float().mean()

        if stage == 'val':
            self.val_auroc(student_logits, labels)
            self.log('val/auroc', self.val_auroc, on_epoch=True, prog_bar=True)
        elif stage == 'test':
            self.test_auroc(student_logits, labels)
            self.log('test/auroc', self.test_auroc, on_epoch=True, prog_bar=True)

        self.log(f'{stage}/accuracy', acc, on_epoch=True, prog_bar=True)
        return acc

    def validation_step(self, batch, batch_idx):
        return self.shared_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self.shared_eval_step(batch, 'test')

### DynamicKDLitModel with the most confident teacher and its confidence (probability of the correct class) as alpha

In [None]:
class DynamicKDLitModel(pl.LightningModule):
    """Multi-teacher knowledge distillation with a confidence-driven loss weight.

    For each sample, the teacher that assigns the highest softmax probability to
    the ground-truth label is used as the distillation target. When both loss
    terms are enabled they are combined as::

        L_total = alpha * KD + (1 - alpha) * CE

    where ``alpha`` is the batch mean of that per-sample maximum true-class
    probability, ``KD`` is the temperature-softened KL divergence multiplied by
    ``T**2`` and ``CE`` is cross-entropy against the labels. Training and
    validation/test use the same formula, so the logged losses are comparable.

    Batches are dicts with ``'label'``, ``'student_input'`` and one
    ``'teacher_input_{i}'`` entry per teacher.
    """

    def __init__(
        self,
        teacher_models: Optional[List[torch.nn.Module]],
        student_model: torch.nn.Module,
        temperature: float = 4.0,
        gamma: float = 10.0,
        threshold: float = 0.5,
        learning_rate: float = 1e-3,
        use_soft_loss: bool = True,
        use_hard_loss: bool = True,
        alpha: float = 0.5,
        num_classes: int = 100,
        scheduler_t_max: int = 100,
    ):
        """
        Initializes the Dynamic Knowledge Distillation Lightning Module.

        Args:
            teacher_models (list of torch.nn.Module or None): Pre-trained teacher
                models (required if use_soft_loss=True). They are frozen and put in
                eval mode here, but are NOT moved to any device.
            student_model (torch.nn.Module): The student model to be trained.
            temperature (float): Temperature for softening probabilities (KD).
            gamma (float): Unused (was for sigmoid gating); kept for config compatibility.
            threshold (float): Unused (was the gating threshold); kept for compatibility.
            learning_rate (float): Initial AdamW learning rate.
            use_soft_loss (bool): Whether to use KD (KL divergence).
            use_hard_loss (bool): Whether to use cross-entropy with ground truth.
            alpha (float): Stored only; the loss weight is computed dynamically from
                teacher confidence. Kept for config compatibility.
            num_classes (int): Number of classes, used by the AUROC metrics.
            scheduler_t_max (int): ``T_max`` of the cosine LR schedule, in scheduler
                steps. Match it to the trainer's ``max_epochs``; past ``T_max`` the
                cosine curve makes the learning rate rise again.

        Raises:
            ValueError: If both losses are disabled, no student is given, or
                use_soft_loss=True without teachers.
        """
        super().__init__()

        self.save_hyperparameters(ignore=["teacher_models", "student_model"])

        # Flags / hyperparameters
        self.use_soft_loss = use_soft_loss
        self.use_hard_loss = use_hard_loss
        self.alpha = alpha  # stored only; the weight actually used is the dynamic alpha
        self.temperature = temperature
        self.gamma = gamma
        self.threshold = threshold
        self.learning_rate = learning_rate
        self.scheduler_t_max = scheduler_t_max

        if not (self.use_soft_loss or self.use_hard_loss):
            raise ValueError("At least one of 'use_soft_loss' or 'use_hard_loss' must be True.")

        # Student model
        if student_model is None:
            raise ValueError("A student model must be provided.")
        self.student = student_model

        # Teacher models
        if self.use_soft_loss:
            if not teacher_models:
                raise ValueError("Teacher models must be provided if 'use_soft_loss' is True.")
            # Kept as a plain list (not nn.ModuleList), so the teachers are not
            # registered as submodules: they stay out of the state_dict/checkpoints
            # and must already live on the training device.
            self.teachers = teacher_models
            self.num_teachers = len(self.teachers)
            # Freeze teacher params
            for teacher in self.teachers:
                teacher.eval()
                for param in teacher.parameters():
                    param.requires_grad = False
            self.kl_div_loss = KLDivLoss(reduction='batchmean')
        else:
            self.teachers = None
            self.num_teachers = 0

        # Hard loss
        if self.use_hard_loss:
            self.ce_loss = CrossEntropyLoss()

        # Metrics
        self.num_classes = num_classes
        self.val_auroc = MulticlassAUROC(num_classes=self.num_classes, average='macro')
        self.test_auroc = MulticlassAUROC(num_classes=self.num_classes, average='macro')

    def forward(self, x):
        return self.student(x)

    def configure_optimizers(self):
        # Only the student's parameters are optimized.
        optimizer = torch.optim.AdamW(
            self.student.parameters(),
            lr=self.learning_rate,
            weight_decay=1e-4
        )
        # Cosine decay from learning_rate down to 1e-6 over scheduler_t_max steps.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.scheduler_t_max,
            eta_min=1e-6
        )
        return [optimizer], [scheduler]

    # ---------------------------------------------------------------
    # Teacher helpers
    # ---------------------------------------------------------------
    def _teacher_forward(self, batch):
        """Run every teacher on its own input view without tracking gradients.

        Returns:
            A list with one logits tensor per teacher (empty if there are none).

        Raises:
            KeyError: If ``teacher_input_{i}`` is missing from the batch.
        """
        teacher_logits = []
        with torch.no_grad():
            for i in range(self.num_teachers):
                tkey = f'teacher_input_{i}'
                if tkey not in batch:
                    raise KeyError(f"Missing {tkey} in batch")
                teacher_logits.append(self.teachers[i](batch[tkey].to(self.device)))
        return teacher_logits

    def _true_class_confidences(self, teacher_logits, labels):
        """Return a (B, num_teachers) tensor of each teacher's softmax probability on the true label."""
        index = labels.view(-1, 1)
        return torch.stack(
            [F.softmax(logits, dim=1).gather(1, index).squeeze(1) for logits in teacher_logits],
            dim=1,
        )

    def _select_teacher(self, teacher_logits, labels):
        """Select, per sample, the teacher most confident on the true label.

        Returns:
            (best_teacher_logits, best_teacher_idx, max_confidence): the (B, C)
            logits of the selected teacher, its (B,) index (argmax, so the first
            teacher wins ties) and the (B,) true-class probability it assigned.
        """
        confidences = self._true_class_confidences(teacher_logits, labels)  # (B, T)
        best_teacher_idx = torch.argmax(confidences, dim=1)  # (B,)
        max_confidence = confidences.max(dim=1).values  # (B,)
        # Row-wise selection via advanced indexing instead of a Python loop over the batch.
        stacked = torch.stack(list(teacher_logits), dim=1)  # (B, T, C)
        rows = torch.arange(stacked.size(0), device=stacked.device)
        return stacked[rows, best_teacher_idx], best_teacher_idx, max_confidence

    # ---------------------------------------------------------------
    # Returns (best_teacher_logits, best_teacher_idx)
    # where "best" is the teacher with highest probability on correct label.
    # ---------------------------------------------------------------
    def get_most_confident_teacher_logits(self, teacher_logits, labels):
        best_teacher_logits, best_teacher_idx, _ = self._select_teacher(teacher_logits, labels)
        return best_teacher_logits, best_teacher_idx

    # ---------------------------------------------------------------
    # Returns the (B,) maximum, over teachers, of the probability each
    # teacher assigns to the correct label.
    # ---------------------------------------------------------------
    def compute_max_teacher_confidence(self, teacher_logits, labels):
        return self._true_class_confidences(teacher_logits, labels).max(dim=1).values

    # ---------------------------------------------------------------
    # Loss computation (shared by training and evaluation)
    # ---------------------------------------------------------------
    def _distillation_loss(self, student_logits, teacher_logits, labels):
        """KD loss against the per-sample most confident teacher.

        Returns:
            (kd_loss, best_teacher_idx, max_confidence). ``kd_loss`` is the
            batch-mean KL divergence between the temperature-softened teacher and
            student distributions, multiplied by ``T**2`` (Hinton et al.'s
            correction for the 1/T**2 gradient shrinkage).
        """
        best_teacher_logits, best_teacher_idx, max_confidence = self._select_teacher(teacher_logits, labels)
        log_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(best_teacher_logits / self.temperature, dim=1)
        kd_loss = self.kl_div_loss(log_student, soft_teacher) * (self.temperature ** 2)
        return kd_loss, best_teacher_idx, max_confidence

    def _log_teacher_usage(self, stage, best_teacher_idx):
        """Log, per teacher, the fraction of samples for which it was selected."""
        if self.num_teachers > 1:
            for i in range(self.num_teachers):
                usage_fraction = (best_teacher_idx == i).float().mean()
                self.log(f"{stage}/teacher_{i}_usage_fraction", usage_fraction, on_step=False, on_epoch=True)

    def _loss_terms(self, student_logits, labels, teacher_logits, stage):
        """Compute soft/hard/total losses and the dynamic alpha for one batch.

        Teacher-usage fractions are logged under ``stage`` as a side effect.

        Returns:
            Dict keyed ``{stage}/loss_soft``, ``{stage}/loss_hard``,
            ``{stage}/loss_total``, ``{stage}/dynamic_alpha`` and
            ``{stage}/max_teacher_conf_mean``. Disabled terms are logged as 0.
        """
        device = student_logits.device
        zero = torch.tensor(0.0, device=device)
        kd_loss = hard_loss = max_conf_mean = zero

        # 1) KD against the most confident teacher
        if self.use_soft_loss:
            kd_loss, best_teacher_idx, max_conf = self._distillation_loss(student_logits, teacher_logits, labels)
            max_conf_mean = max_conf.mean()
            self._log_teacher_usage(stage, best_teacher_idx)

        # 2) Cross-entropy with ground truth
        if self.use_hard_loss:
            hard_loss = self.ce_loss(student_logits, labels)

        # 3) L_total = alpha * KD + (1 - alpha) * Hard, alpha = mean max teacher confidence
        if self.use_soft_loss and self.use_hard_loss:
            alpha = max_conf_mean
            loss_total = alpha * kd_loss + (1.0 - alpha) * hard_loss
        elif self.use_soft_loss:
            alpha = torch.tensor(1.0, device=device)
            loss_total = kd_loss
        else:
            # Only hard loss (the constructor guarantees at least one term is enabled)
            alpha = zero
            loss_total = hard_loss

        return {
            f'{stage}/loss_soft': kd_loss,
            f'{stage}/loss_hard': hard_loss,
            f'{stage}/loss_total': loss_total,
            f'{stage}/dynamic_alpha': alpha,
            f'{stage}/max_teacher_conf_mean': max_conf_mean,
        }

    def compute_losses(self, student_logits, labels, teacher_logits):
        """Compute and log the training losses; return the total loss to optimize."""
        terms = self._loss_terms(student_logits, labels, teacher_logits, 'train')
        self.log_dict(terms, on_epoch=True, prog_bar=True)
        return terms['train/loss_total']

    def training_step(self, batch, batch_idx):
        labels = batch['label'].to(self.device)
        student_images = batch['student_input'].to(self.device)
        student_logits = self.student(student_images)

        teacher_logits = self._teacher_forward(batch)

        loss_total = self.compute_losses(student_logits, labels, teacher_logits)
        return loss_total

    def shared_eval_step(self, batch, stage: str):
        """Log losses, accuracy and AUROC for a validation/test batch; return accuracy.

        Losses use exactly the same formula (including the ``T**2`` KD scaling)
        as training, so ``{stage}/loss_*`` is comparable with ``train/loss_*``.
        """
        labels = batch['label'].to(self.device)
        student_images = batch['student_input'].to(self.device)
        student_logits = self.student(student_images)

        teacher_logits = self._teacher_forward(batch)

        log_dict = self._loss_terms(student_logits, labels, teacher_logits, stage)
        self.log_dict(log_dict, on_epoch=True, prog_bar=True)

        # Accuracy & AUROC
        preds = student_logits.argmax(dim=1)
        acc = (preds == labels).float().mean()
        self.log(f'{stage}/accuracy', acc, on_epoch=True, prog_bar=True)

        if stage == 'val':
            self.val_auroc(student_logits, labels)
            self.log('val/auroc', self.val_auroc, on_epoch=True, prog_bar=True)
        elif stage == 'test':
            self.test_auroc(student_logits, labels)
            self.log('test/auroc', self.test_auroc, on_epoch=True, prog_bar=True)

        return acc

    def validation_step(self, batch, batch_idx):
        return self.shared_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self.shared_eval_step(batch, 'test')


## Training the Baseline Student Model

### Training Student Model from Scratch

In [None]:
# Student network: MobileNetV2 with random weights (no pre-training), placed on
# the GPU when one is available.
student_model = timm.create_model('mobilenetv2_100', pretrained=False).to(
    'cuda' if torch.cuda.is_available() else 'cpu'
)


In [None]:
# Modify the classifier to match CIFAR-100 classes
num_ftrs_student = student_model.get_classifier().in_features
student_model.reset_classifier(num_classes=100)
# reset_classifier builds a brand-new head on the CPU, even though the model was
# already moved to the GPU above; move the whole model again so every parameter
# lives on the same device.
student_model = student_model.to('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# W&B logger for the baseline run (student trained on ground-truth labels only).
wandb_logger_kd = WandbLogger(
    project='Knowledge-Distillation-CIFAR100-Baseline-Accuracy-Val',
    name='KD-MobileNetV2-From-Scratch',
    resume='allow',
    log_model='all',  # upload every checkpoint to W&B
)

# Baseline: cross-entropy on hard labels only, no teachers. The KD-specific
# hyperparameters are passed explicitly as None because they are never read in
# this mode; save_hyperparameters still records them.
kd_model_student_mobilnet_untrained = DynamicKDLitModel(
    student_model=student_model,
    teacher_models=None,
    use_soft_loss=False,
    use_hard_loss=True,
    learning_rate=0.01,
    temperature=None,
    gamma=None,
    threshold=None,
    alpha=None,
)


In [None]:
# Initialize PyTorch Lightning Trainer
trainer_kd = pl.Trainer(
    max_epochs=100,
    # A single device in both cases: with accelerator='cpu', Lightning rejects
    # devices=None (it requires an int > 0).
    devices=1,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    logger=wandb_logger_kd,
    # '16-mixed' is the non-deprecated spelling of precision=16 (AMP).
    precision='16-mixed' if torch.cuda.is_available() else 32,
    log_every_n_steps=50,
    callbacks=[
        pl.callbacks.EarlyStopping(monitor='val/accuracy', patience=15, mode='max', verbose=True),
    ],
)

# Initialize the Data Module
# NOTE(review): the KD experiment later in this notebook builds this module with
# `student_model='mobilenet'` instead of `model='mobilenet'`; confirm which
# keyword the current CIFAR100DataModule accepts.
data_module = CIFAR100DataModule(data_dir='/content/data', batch_size=128, num_workers=2, model='mobilenet')

/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train the hard-label-only baseline student.
trainer_kd.fit(model=kd_model_student_mobilnet_untrained, datamodule=data_module)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /content/data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:05<00:00, 30.0MB/s]


Extracting /content/data/cifar-100-python.tar.gz to /content/data
Files already downloaded and verified


[34m[1mwandb[0m: Currently logged in as: [33maidakhmetov-2115331[0m ([33maidakhmetov-2115331-sapienza-universit-di-roma[0m). Use [1m`wandb login --relogin`[0m to force relogin


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/drive/MyDrive exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type             | Params | Mode 
--------------------------------------------------------
0 | student    | EfficientNet     | 2.4 M  | train
1 | ce_loss    | CrossEntropyLoss | 0      | train
2 | cos_sim    | CosineSimilarity | 0      | train
3 | val_auroc  | MulticlassAUROC  | 0      | train
4 | test_auroc | MulticlassAUROC  | 0      | train
--------------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.408     Total estimated model params size (MB)
293       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved. New best score: 0.067


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.055 >= min_delta = 0.0. New best score: 0.122


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.055 >= min_delta = 0.0. New best score: 0.177


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.058 >= min_delta = 0.0. New best score: 0.235


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.045 >= min_delta = 0.0. New best score: 0.279


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.319


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.033 >= min_delta = 0.0. New best score: 0.353


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.030 >= min_delta = 0.0. New best score: 0.383


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.024 >= min_delta = 0.0. New best score: 0.407


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.024 >= min_delta = 0.0. New best score: 0.431


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.013 >= min_delta = 0.0. New best score: 0.443


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.025 >= min_delta = 0.0. New best score: 0.468


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.019 >= min_delta = 0.0. New best score: 0.487


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.012 >= min_delta = 0.0. New best score: 0.499


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.504


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.015 >= min_delta = 0.0. New best score: 0.518


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.022 >= min_delta = 0.0. New best score: 0.540


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.015 >= min_delta = 0.0. New best score: 0.555


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.016 >= min_delta = 0.0. New best score: 0.571


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.011 >= min_delta = 0.0. New best score: 0.583


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.005 >= min_delta = 0.0. New best score: 0.588


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.002 >= min_delta = 0.0. New best score: 0.590


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.002 >= min_delta = 0.0. New best score: 0.592


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.008 >= min_delta = 0.0. New best score: 0.599


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.601


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.000 >= min_delta = 0.0. New best score: 0.601


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.011 >= min_delta = 0.0. New best score: 0.612


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.046 >= min_delta = 0.0. New best score: 0.657


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.003 >= min_delta = 0.0. New best score: 0.660


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.661


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.002 >= min_delta = 0.0. New best score: 0.663


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.664


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.665


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.003 >= min_delta = 0.0. New best score: 0.668


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Monitored metric val/accuracy did not improve in the last 15 records. Best score: 0.668. Signaling Trainer to stop.


## Training with Knowledge Distillation

Define the names of the teachers used during training. For a single teacher, list just one name.

In [None]:
# Filled with the loaded, frozen teacher networks by the next cell.
teacher_models = []
# Teachers to distil from. Other configurations tried:
#   ['resnet50_cifar100', 'densenet121_cifar100']
#   ['vit']
teacher_names = ['resnet50_cifar100', 'resnet18_cifar100', 'resnet34_cifar100']

The code below loads the pre-trained teacher models from Hugging Face.

In [None]:
# Load each requested teacher, move it to the compute device and switch it to
# eval mode. NOTE: re-running this cell without re-running the previous one
# appends duplicate teachers to `teacher_models`.
for name in teacher_names:
    if name == 'vit':
        # timm ViT-B/16 whose head is replaced by a 100-way linear layer, then
        # loaded with the weights at the URL below (a CIFAR-100 fine-tune, per
        # its file name).
        teacher = timm.create_model(
            "timm/vit_base_patch16_224.orig_in21k_ft_in1k",
            pretrained=False,
        )
        teacher.head = nn.Linear(teacher.head.in_features, 100)
        teacher.load_state_dict(
            torch.hub.load_state_dict_from_url(
                "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar100/resolve/main/pytorch_model.bin",
                map_location="cpu",
                file_name="vit_base_patch16_224_in21k_ft_cifar100.pth",
            )
        )
    else:
        # Other names go through timm.create_model with pretrained weights.
        # NOTE(review): the *_cifar100 names appear to be registered by the
        # `detectors` package installed above; confirm it is imported.
        try:
            teacher = timm.create_model(name, pretrained=True)
        except Exception as err:
            # Was a bare `except:`, which also trapped KeyboardInterrupt/SystemExit
            # and discarded the original error; now the cause is chained.
            raise ValueError(f"Invalid teacher model name: {name}") from err
    teacher = teacher.to('cuda' if torch.cuda.is_available() else 'cpu')
    teacher.eval()  # Set to evaluation mode
    teacher_models.append(teacher)

Downloading: "https://huggingface.co/edadaltocg/resnet50_cifar100/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet50_cifar100.pth
100%|██████████| 90.7M/90.7M [00:00<00:00, 160MB/s]
Downloading: "https://huggingface.co/edadaltocg/resnet18_cifar100/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet18_cifar100.pth
100%|██████████| 42.9M/42.9M [00:02<00:00, 19.2MB/s]
Downloading: "https://huggingface.co/edadaltocg/resnet34_cifar100/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet34_cifar100.pth
100%|██████████| 81.5M/81.5M [00:00<00:00, 126MB/s]


#### Experiments

In [None]:
def experiment(teacher_models, teacher_names, alpha, threshold, run_id, ckpth_path):
    """Distill ``teacher_models`` into an untrained MobileNetV2 on CIFAR-100.

    Args:
        teacher_models: Teacher networks handed to ``DynamicKDLitModel``.
        teacher_names: Labels forwarded to ``CIFAR100DataModule`` as its
            ``teacher_models`` argument.
        alpha: Forwarded to ``DynamicKDLitModel``; its ``str`` form is also
            embedded in the W&B project and run names.
        threshold: Forwarded to ``DynamicKDLitModel``.
        run_id: W&B run id. Used as the logger ``id`` (with ``resume='allow'``)
            and appended to the project name.
        ckpth_path: Lightning checkpoint to resume from. Any falsy value
            (``''`` or ``None``) starts training from scratch.
    """
    use_gpu = torch.cuda.is_available()

    # alpha is stringified as-is (e.g. "0.5", "None"). The names below must
    # stay stable so a resumed run (same run_id) logs to the same project.
    alpha_str = f"{alpha}"
    project = f'Knowledge-Distillation-CIFAR100-Multiple-Teachers-three-resnet-alpha_{alpha_str}_{run_id}'
    name = f'KD-Multiple-Teachers-MobilNetV2-Untrained_alpha_{alpha_str}'

    wandb_logger_kd = WandbLogger(
        project=project,
        name=name,
        log_model='all',
        resume='allow',
        id=run_id,
    )

    # Untrained MobileNetV2 student with a 100-way CIFAR-100 head. Replace the
    # head *before* moving the model, so the new classifier is moved too.
    student_model = timm.create_model('mobilenetv2_100', pretrained=False)
    student_model.reset_classifier(num_classes=100)
    student_model = student_model.to('cuda' if use_gpu else 'cpu')

    kd_model = DynamicKDLitModel(
        teacher_models=teacher_models,
        student_model=student_model,
        temperature=4.0,
        gamma=5.0,
        threshold=threshold,
        learning_rate=0.01,
        use_soft_loss=True,
        use_hard_loss=False,
        alpha=alpha
    )

    data_module = CIFAR100DataModule(data_dir='/content/data', batch_size=128, num_workers=2, teacher_models=teacher_names, student_model='mobilenet')

    trainer = pl.Trainer(
        max_epochs=150,
        # One device on either accelerator; Lightning expects an int/str/list here.
        devices=1,
        accelerator='gpu' if use_gpu else 'cpu',
        logger=wandb_logger_kd,
        # '16-mixed' is the non-deprecated spelling of `precision=16` (AMP).
        precision='16-mixed' if use_gpu else 32,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor='val/accuracy', patience=30, mode='max', verbose=True),
        ]
    )

    # ckpt_path=None (Lightning's default) means "train from scratch".
    trainer.fit(kd_model, datamodule=data_module, ckpt_path=ckpth_path or None)

Begin training from a checkpoint:

In [None]:
# To continue training you can run the code below
# experiment(teacher_models, ['resnet', 'densenet'], None, None, 'x4b4e1o1', '/content/model (6).ckpt')

Begin training from scratch:

In [None]:
# A fresh W&B id; `experiment` uses it both as the run id and in the project name.
run_id = wandb.util.generate_id()
print(f"Starting experiment with run_id={run_id}")
experiment(
    teacher_models,
    ['resnet'] * 3,  # forwarded to CIFAR100DataModule as `teacher_models`
    alpha=None,
    threshold=None,
    run_id=run_id,
    ckpth_path='',  # empty path: train from scratch
)

/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:572: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Starting experiment with run_id=19vavgzn
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /content/data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:13<00:00, 12.9MB/s]


Extracting /content/data/cifar-100-python.tar.gz to /content/data
Files already downloaded and verified


[34m[1mwandb[0m: Currently logged in as: [33maidakhmetov-2115331[0m ([33maidakhmetov-2115331-sapienza-universit-di-roma[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type            | Params | Mode 
--------------------------------------------------------
0 | student     | EfficientNet    | 2.4 M  | train
1 | kl_div_loss | KLDivLoss       | 0      | train
2 | val_auroc   | MulticlassAUROC | 0      | train
3 | test_auroc  | MulticlassAUROC | 0      | train
--------------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.408     Total estimated model params size (MB)
292       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved. New best score: 0.068


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.084 >= min_delta = 0.0. New best score: 0.152


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.062 >= min_delta = 0.0. New best score: 0.215


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.072 >= min_delta = 0.0. New best score: 0.287


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.065 >= min_delta = 0.0. New best score: 0.352


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.056 >= min_delta = 0.0. New best score: 0.408


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.011 >= min_delta = 0.0. New best score: 0.419


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.079 >= min_delta = 0.0. New best score: 0.498


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.003 >= min_delta = 0.0. New best score: 0.501


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.037 >= min_delta = 0.0. New best score: 0.538


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.543


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.030 >= min_delta = 0.0. New best score: 0.573


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.024 >= min_delta = 0.0. New best score: 0.597


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.017 >= min_delta = 0.0. New best score: 0.613


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.617


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.005 >= min_delta = 0.0. New best score: 0.622


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.009 >= min_delta = 0.0. New best score: 0.631


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.011 >= min_delta = 0.0. New best score: 0.643


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.646


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.009 >= min_delta = 0.0. New best score: 0.655


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.000 >= min_delta = 0.0. New best score: 0.655


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.005 >= min_delta = 0.0. New best score: 0.660


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.008 >= min_delta = 0.0. New best score: 0.667


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.007 >= min_delta = 0.0. New best score: 0.675


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.009 >= min_delta = 0.0. New best score: 0.683


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.006 >= min_delta = 0.0. New best score: 0.689


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.005 >= min_delta = 0.0. New best score: 0.694


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.698


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.003 >= min_delta = 0.0. New best score: 0.701


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.002 >= min_delta = 0.0. New best score: 0.703


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.002 >= min_delta = 0.0. New best score: 0.705


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.008 >= min_delta = 0.0. New best score: 0.714


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.714


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.006 >= min_delta = 0.0. New best score: 0.720


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.000 >= min_delta = 0.0. New best score: 0.720


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.724


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.006 >= min_delta = 0.0. New best score: 0.731


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.002 >= min_delta = 0.0. New best score: 0.732


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.004 >= min_delta = 0.0. New best score: 0.736


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.000 >= min_delta = 0.0. New best score: 0.736


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.737


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

To release the Colab runtime as soon as training ends, so GPU resources are not wasted, run the code below.

In [None]:
# Hand the Colab runtime back as soon as this cell runs, so the GPU is not held idle.
import google.colab.runtime as runtime

runtime.unassign()

## Testing Models

To test the model locally without `weights&biases`, first download the artifact of the model that performed best on the validation set.

In [None]:
import wandb

# Consuming an artifact requires an active run. The reference has the form
# entity/project/artifact:version.
run = wandb.init()
artifact = run.use_artifact(
    'aidakhmetov-2115331-sapienza-universit-di-roma/'
    'Knowledge-Distillation-CIFAR100-Multiple-Teachers_alpha_1_9tucinkr/'
    'model-9tucinkr:v49',
    type='model',
)
# Local directory the artifact files were downloaded to.
artifact_dir = artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Then load the checkpoint:

In [None]:
# Lightning checkpoints also pickle non-tensor state (hyper_parameters, loops,
# callbacks — see the printed keys), so request full unpickling explicitly.
# This keeps loading working regardless of torch's `weights_only` default and
# silences its FutureWarning. Only do this for checkpoints you trust:
# unpickling can execute arbitrary code.
checkpoint = torch.load(
    '/content/artifacts/model-9tucinkr:v49/model.ckpt',
    map_location=torch.device('cpu'),
    weights_only=False,
)
print(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision', 'hparams_name', 'hyper_parameters'])


  checkpoint = torch.load('/content/artifacts/model-9tucinkr:v49/model.ckpt', map_location=torch.device('cpu'))


All parameter names in the checkpoint saved to `weights&biases` start with the `student.` prefix, so we have to strip that prefix before loading the weights into a bare MobileNetV2. There is probably an easier way to load a model from `weights&biases`, but we could not find one, so we rename the keys manually.

In [None]:
# The LightningModule holds the network under its `student` attribute (model
# summary: `student | EfficientNet`), so its parameters are saved as
# "student.<param>". Keep only those entries and strip the leading prefix so the
# keys match a bare timm MobileNetV2. Any non-student entries are dropped.
# Slicing off the prefix, rather than using str.replace, cannot alter a
# "student." that happens to appear later in a key.
student_prefix = 'student.'
new_state_dict = {
    key[len(student_prefix):]: value
    for key, value in checkpoint['state_dict'].items()
    if key.startswith(student_prefix)
}


In [None]:
# Rebuild the student exactly as `experiment` does (untrained MobileNetV2 with
# a 100-way head), so its parameter names and shapes match the checkpoint.
student_model = timm.create_model('mobilenetv2_100', pretrained=False)  # Untrained
student_model.reset_classifier(num_classes=100)
# strict=True: fail loudly on any missing or unexpected key.
student_model.load_state_dict(new_state_dict, strict=True)
student_model = student_model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Evaluate the loaded student on the CIFAR-100 test split and report the
# mean cross-entropy loss and top-1 accuracy.
data_module = CIFAR100DataModule(data_dir='/content/data', batch_size=128, num_workers=2, teacher_models=None, student_model='mobilenet')
data_module.prepare_data()
data_module.setup(stage='test')
test_loader = data_module.test_dataloader()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Modules are created in training mode. Without eval(), BatchNorm would normalise
# with per-batch statistics (and overwrite its running statistics) and dropout
# would be active, which skews the test metrics.
student_model.to(device)
student_model.eval()

# Running totals
test_loss = 0.0
correct = 0
total = 0

criterion = nn.CrossEntropyLoss()

# Gradients are not needed for evaluation; skipping them saves memory and time.
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        # Batches are dicts; keys as produced by CIFAR100DataModule.
        data = batch['student_input'].to(device)
        target = batch['label'].to(device)

        outputs = student_model(data)

        # CrossEntropyLoss averages over the batch; re-weight by batch size so
        # the final value is an exact per-sample mean over the whole test set.
        loss = criterion(outputs, target)
        test_loss += loss.item() * data.size(0)

        predicted = outputs.argmax(dim=1)
        correct += (predicted == target).sum().item()
        total += target.size(0)

if total == 0:
    raise RuntimeError("Test loader produced no samples; cannot compute metrics.")

average_loss = test_loss / total
accuracy = correct / total

print(f"Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy*100:.2f}%")


Files already downloaded and verified
Files already downloaded and verified


Testing: 100%|██████████| 79/79 [01:49<00:00,  1.39s/it]

Test Loss: 1.1681, Test Accuracy: 71.37%



