In [1]:
from torch.utils.data import DataLoader, random_split

def create_data_loaders(dataset_path, batch_size, val_split=0.2, seed=42):
    torch.manual_seed(seed)

    transform = transforms.Compose([
        transforms.Resize((224, 224)), 
        transforms.ToTensor(),        
        # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # transforms.Resize((256, 256))
    ])
    
    dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    print(f"Train size: {train_size}, Validation size: {val_size}")
    print(f"Class names: {dataset.classes}")
    
    return train_loader, val_loader

In [2]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

def load_fine_tuned_model(model_name, model_path, num_classes=16):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if model_name == "VGG11":
        model = VGG11(num_classes=num_classes, pretrained=False)
    elif model_name == "VGG16":
        model = VGG16Model(num_classes=num_classes, pretrained=False)
    elif model_name == "ResNet50":
        model = ResNet50Model(num_classes=num_classes, pretrained=False)
    elif model_name == "DenseNet121":
        model = DenseNet121Model(num_classes=num_classes, pretrained=False)
    elif model_name == "ViT":
        model = ViTModel(num_classes=num_classes, pretrained=False)
    elif model_name == "ViTsmall":
        model = SmallViTModel(num_classes=num_classes, pretrained=False)    
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print(f'{model_name} model loaded and ready for evaluation.')
    return model

**Teacher Models**

In [3]:
import torch
from torch import nn
from torchvision.models import vit_b_16

class ViTModel(nn.Module): 
    def __init__(self, num_classes, pretrained=True):
        super(ViTModel, self).__init__()
        self.model = vit_b_16(pretrained=pretrained)
        
        for param in self.model.parameters():
            param.requires_grad = False
            
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Linear(in_features, num_classes)
        for param in self.model.heads.head.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)

    def get_model(self):
        return self.model

# num_classes = 10 
# vit_model = ViTModel(num_classes=num_classes, pretrained=True)
# model_instance = vit_model.get_model()
# print(model_instance)

In [4]:
import torch
import torchvision.models as models
import torch.nn as nn

class VGG16Model(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(VGG16Model, self).__init__()
        self.model = models.vgg16(pretrained=pretrained)

        for param in self.model.features.parameters():
            param.requires_grad = False
        in_features = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(in_features, num_classes)
        for param in self.model.classifier[6].parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)
        
    def get_model(self):
        return self.model


# num_classes = 10  
# vgg_model = VGG16Model(num_classes=num_classes, pretrained=True)
# model_instance = vgg_model.get_model()
# print(model_instance)

**Student Model**

In [5]:
import torch
import torch.nn as nn
import timm

class SmallViTModel(nn.Module):
    def __init__(self, num_classes, pretrained=False):
        super(SmallViTModel, self).__init__()
        
        # Create a ViT model using a variant that works for 224x224 images
        # Ensure the model matches the number of classes used in the teacher models
        self.model = timm.create_model('vit_small_patch16_224', pretrained=pretrained)

        # Freeze backbone parameters
        for param in self.model.parameters():
            param.requires_grad = True

        # Modify the classifier to match num_classes (ensure it's consistent with the teacher model)
        in_features = self.model.head.in_features
        self.model.reset_classifier(num_classes=num_classes)
        
        # Ensure the new classifier is trainable
        for param in self.model.head.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)

    def get_model(self):
        return self.model


**Fine Tuning**

In [6]:
import torch
import torch.optim as optim
import torch.nn as nn
import os
import zipfile
from tqdm import tqdm
from torchvision import datasets, transforms

def fine_tune_model(model_class, dataset_path, output_path, num_classes, epochs=10, batch_size=32, learning_rate=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader = create_data_loaders(dataset_path, batch_size)
    model = model_class(num_classes=num_classes, pretrained=True)
    # model = model_class(num_classes=num_classes, pretrained=False)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Optimizer configuration for ViT and ResNet
    if isinstance(model, ViTModel):  # For ViT
        optimizer = optim.Adam(model.get_model().heads.head.parameters(), lr=learning_rate)
    elif isinstance(model, SmallViTModel):  # For SmallViTModel
        optimizer = optim.Adam(model.get_model().get_classifier().parameters(), lr=learning_rate)
    # elif isinstance(model, ResNet50Model):  # For ResNet
    #     optimizer = optim.Adam(model.get_model().fc.parameters(), lr=learning_rate)  # Use 'fc' for ResNet
    else:  # For other models (e.g., DenseNet, etc.)
        optimizer = optim.Adam(model.get_model().classifier.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        total_samples = 0

        print(f"Epoch {epoch+1}/{epochs}: Training started...")
        for images, labels in tqdm(train_loader, desc="Training", unit="batch"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += torch.sum(preds == labels)
            total_samples += images.size(0)

        train_loss_avg = train_loss / total_samples
        train_accuracy = (train_correct.double() / total_samples) * 100
        model.eval()
        val_loss = 0.0
        val_correct = 0
        with torch.no_grad():
            print("Validation started...")
            for images, labels in tqdm(val_loader, desc="Validating", unit="batch"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += torch.sum(preds == labels)

        val_loss_avg = val_loss / len(val_loader.dataset)
        val_accuracy = (val_correct.double() / len(val_loader.dataset)) * 100
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Training Accuracy: {train_accuracy:.4f}, "
              f"Training Loss: {train_loss_avg:.4f} | "
              f"Validation Accuracy: {val_accuracy:.4f}, "
              f"Validation Loss: {val_loss_avg:.4f}")

    model_name = model_class.__name__.lower()
    model_path = os.path.join(output_path, f"{model_name}_fine_tuned_{num_classes}_classes.pth")
    torch.save(model.state_dict(), model_path)

    # torch.save(model.get_model().state_dict(), model_path)
    print(f"Model saved at {model_path}")
    
    zip_path = model_path.replace(".pth", ".zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(model_path, os.path.basename(model_path))
    # os.remove(model_path)
    print(f"Model saved and zipped to {zip_path}")

In [13]:
dataset_path = "/kaggle/input/labellledimagenet/train" 
output_path = "/kaggle/working/FinetunedModels"
os.makedirs(output_path, exist_ok=True)
model_class = SmallViTModel
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=2, batch_size=128, learning_rate=1e-4)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']
Epoch 1/2: Training started...


Training: 100%|██████████| 38/38 [00:27<00:00,  1.36batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:04<00:00,  2.13batch/s]


Epoch 1/2: Training Accuracy: 16.9122, Training Loss: 2.5553 | Validation Accuracy: 34.8837, Validation Loss: 2.2592
Epoch 2/2: Training started...


Training: 100%|██████████| 38/38 [00:27<00:00,  1.36batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:04<00:00,  2.02batch/s]


Epoch 2/2: Training Accuracy: 44.2415, Training Loss: 2.0526 | Validation Accuracy: 49.4186, Validation Loss: 1.8675
Model saved at /kaggle/working/FinetunedModels/smallvitmodel_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModels/smallvitmodel_fine_tuned_16_classes.zip


In [7]:
dataset_path = "/kaggle/input/labellledimagenet/train" 
output_path = "/kaggle/working/FinetunedModelsCC"  
os.makedirs(output_path, exist_ok=True)
model_class = VGG16Model  
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=1, batch_size=128, learning_rate=1e-4)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 233MB/s] 


Epoch 1/1: Training started...


Training: 100%|██████████| 38/38 [00:28<00:00,  1.33batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:07<00:00,  1.27batch/s]


Epoch 1/1: Training Accuracy: 84.3121, Training Loss: 0.5777 | Validation Accuracy: 97.0930, Validation Loss: 0.1115
Model saved at /kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.zip


In [9]:
dataset_path = "/kaggle/input/labellledimagenet/train"   
output_path = "/kaggle/working/FinetunedModels"

os.makedirs(output_path, exist_ok=True)
model_class = ViTModel 
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=1, batch_size=128, learning_rate=1e-4)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']
Epoch 1/1: Training started...


Training: 100%|██████████| 38/38 [00:30<00:00,  1.24batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:08<00:00,  1.20batch/s]


Epoch 1/1: Training Accuracy: 12.3055, Training Loss: 2.7227 | Validation Accuracy: 21.6777, Validation Loss: 2.6096
Model saved at /kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.zip


In [24]:
dataset_path = "/kaggle/input/cue-conflict-splitdata/train" 
output_path = "/kaggle/working/FinetunedModels"
os.makedirs(output_path, exist_ok=True)
model_class = ViTModel 
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=5, batch_size=128, learning_rate=1e-4)

Train size: 705, Validation size: 176
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']
Epoch 1/5: Training started...


Training: 100%|██████████| 6/6 [00:06<00:00,  1.11s/batch]


Validation started...


Validating: 100%|██████████| 2/2 [00:02<00:00,  1.43s/batch]


Epoch 1/5: Training Accuracy: 5.6738, Training Loss: 2.8552 | Validation Accuracy: 4.5455, Validation Loss: 2.8262
Epoch 2/5: Training started...


Training: 100%|██████████| 6/6 [00:05<00:00,  1.15batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.08batch/s]


Epoch 2/5: Training Accuracy: 7.8014, Training Loss: 2.7790 | Validation Accuracy: 8.5227, Validation Loss: 2.7688
Epoch 3/5: Training started...


Training: 100%|██████████| 6/6 [00:05<00:00,  1.15batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.07batch/s]


Epoch 3/5: Training Accuracy: 11.9149, Training Loss: 2.7083 | Validation Accuracy: 11.9318, Validation Loss: 2.7136
Epoch 4/5: Training started...


Training: 100%|██████████| 6/6 [00:05<00:00,  1.13batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.06batch/s]


Epoch 4/5: Training Accuracy: 15.8865, Training Loss: 2.6396 | Validation Accuracy: 18.7500, Validation Loss: 2.6602
Epoch 5/5: Training started...


Training: 100%|██████████| 6/6 [00:05<00:00,  1.14batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.08batch/s]


Epoch 5/5: Training Accuracy: 19.5745, Training Loss: 2.5731 | Validation Accuracy: 22.7273, Validation Loss: 2.6087
Model saved at /kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.zip


**Shape Biased Student Tuning**

In [11]:
import torch
import torch.optim as optim
import torch.nn as nn
import os
import zipfile
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset

def create_curriculum_data_loaders(dataset_path_original, dataset_path_shape, batch_size, val_split=0.2, seed=42, curriculum_stage=1):
    """
    Create data loaders for curriculum learning.
    - Stage 1: Only shape-biased data.
    - Stage 2: Mix of shape-biased and original data.
    """
    torch.manual_seed(seed)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset_shape = datasets.ImageFolder(root=dataset_path_shape, transform=transform)
    dataset_original = datasets.ImageFolder(root=dataset_path_original, transform=transform)

    if curriculum_stage == 1:
        combined_dataset = dataset_shape
    else:
        combined_dataset = ConcatDataset([dataset_shape, dataset_original])

    val_size = int(len(combined_dataset) * val_split)
    train_size = len(combined_dataset) - val_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    print(f"Train size: {train_size}, Validation size: {val_size}")
    print(f"Shape Class Names: {dataset_shape.classes}")
    if curriculum_stage > 1:
        print(f"Original Class Names: {dataset_original.classes}")

    return train_loader, val_loader

def fine_tune_model_curriculum(model_class, dataset_path_original, dataset_path_shape, output_path, num_classes, epochs=10, batch_size=32, learning_rate=0.001):
    """
    Fine-tune the model using curriculum learning.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    curriculum_epochs = [int(epochs * 0.5), epochs]  

    model = model_class(num_classes=num_classes, pretrained=True)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.get_model().heads.head.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        curriculum_stage = 1 if epoch < curriculum_epochs[0] else 2
        train_loader, val_loader = create_curriculum_data_loaders(
            dataset_path_original, dataset_path_shape, batch_size, curriculum_stage=curriculum_stage
        )

        model.train()
        train_loss = 0.0
        train_correct = 0
        total_samples = 0

        print(f"Epoch {epoch + 1}/{epochs}: Training started (Curriculum Stage {curriculum_stage})...")
        for images, labels in tqdm(train_loader, desc="Training", unit="batch"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += torch.sum(preds == labels)
            total_samples += images.size(0)

        train_loss_avg = train_loss / total_samples
        train_accuracy = (train_correct.double() / total_samples) * 100

        model.eval()
        val_loss = 0.0
        val_correct = 0
        with torch.no_grad():
            print("Validation started...")
            for images, labels in tqdm(val_loader, desc="Validating", unit="batch"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += torch.sum(preds == labels)

        val_loss_avg = val_loss / len(val_loader.dataset)
        val_accuracy = (val_correct.double() / len(val_loader.dataset)) * 100
        print(f"Epoch {epoch + 1}/{epochs}: "
              f"Training Accuracy: {train_accuracy:.4f}, "
              f"Training Loss: {train_loss_avg:.4f} | "
              f"Validation Accuracy: {val_accuracy:.4f}, "
              f"Validation Loss: {val_loss_avg:.4f}")

    model_name = model_class.__name__.lower()
    model_path = os.path.join(output_path, f"{model_name}_fine_tuned_{num_classes}_classes.pth")
    torch.save(model.state_dict(), model_path)

    zip_path = model_path.replace(".pth", ".zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(model_path, os.path.basename(model_path))

    print(f"Model saved and zipped to {zip_path}")

dataset_path_original = "/kaggle/input/labellledimagenet/train"
dataset_path_shape = "/kaggle/input/cue-conflict-splitdata/train"
output_path = "/kaggle/working/FinetunedModels"
os.makedirs(output_path, exist_ok=True)
model_class = ViTModel  
fine_tune_model_curriculum(model_class, dataset_path_original, dataset_path_shape, output_path, num_classes=16, epochs=5, batch_size=128, learning_rate=1e-4)


Train size: 820, Validation size: 204
Shape Class Names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']
Epoch 1/5: Training started (Curriculum Stage 1)...


Training:  29%|██▊       | 2/7 [00:03<00:09,  1.91s/batch]


KeyboardInterrupt: 

**Basic LM**

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import zipfile
import os

def logit_matching_loss(student_logits, teacher_logits, student_temperature=1.0, teacher_temperature=1.0):
    student_logits = student_logits - student_logits.mean(dim=-1, keepdim=True)
    teacher_logits = teacher_logits - teacher_logits.mean(dim=-1, keepdim=True)
    student_logits = student_logits / student_temperature
    teacher_logits = teacher_logits / teacher_temperature
    loss = F.mse_loss(student_logits, teacher_logits)
    return loss

def distill(model_class, teacher_model, teacher_model_name, train_loader, val_loader, num_classes, 
            epochs=10, batch_size=32, learning_rate=0.001, 
            student_temperature=2.0, teacher_temperature=2.0, alpha=0.5, save_dir="/kaggle/working/DistilledModels"):

    student_model = model_class(num_classes=num_classes)
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student_model = student_model.to(device)
    teacher_model = teacher_model.to(device)
    teacher_model.eval()  
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                student_outputs = student_model(images)
                with torch.no_grad():
                    teacher_outputs = teacher_model(images)
                ce_loss = criterion(student_outputs, labels)
                distillation_loss = logit_matching_loss(student_outputs, teacher_outputs, 
                                                        student_temperature=student_temperature,
                                                        teacher_temperature=teacher_temperature)

                total_loss = (1 - alpha) * ce_loss + alpha * distillation_loss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                running_loss += total_loss.item()

                pbar.set_postfix(loss=running_loss / len(pbar))  

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = student_model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f"Validation Accuracy at epoch {epoch+1}: {accuracy:.2f}%")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model_filename = f"{teacher_model_name}_distilled_model.pth" 
    model_path = os.path.join(save_dir, model_filename)
    torch.save(student_model.state_dict(), model_path)
    print(f"Distilled model saved to {model_path}")

    zip_filename = model_filename.replace(".pth", ".zip")
    zip_path = os.path.join(save_dir, zip_filename)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(model_path, os.path.basename(model_path))
        
    print(f"Distilled model zip file saved to {zip_path}")
    return student_model


In [13]:
dataset_path = "/kaggle/input/labellledimagenet/train"
batch_size = 128

train_loader, val_loader = create_data_loaders(dataset_path, batch_size)
teacher_model_name = "VGG16"  
teacher_model_path = "/kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth"  
teacher_model = load_fine_tuned_model(teacher_model_name, teacher_model_path, num_classes=16)
student_model_class = SmallViTModel

distilled_model = distill(
    model_class=student_model_class, 
    teacher_model=teacher_model, 
    teacher_model_name=teacher_model_name,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=16,
    epochs=8,
    batch_size=32,
    learning_rate=1e-4,
    student_temperature=1.0,
    teacher_temperature=1.0,
    alpha=0.5
)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


  model.load_state_dict(torch.load(model_path, map_location=device))


KeyboardInterrupt: 

In [30]:
model_name = "ViTsmall"  
model_path = "/kaggle/working/DistilledModels/VGG16_distilled_model.pth" 
eval_dataset_path1 = "/kaggle/input/labellledimagenet/test" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

  model.load_state_dict(torch.load(model_path, map_location=device))


ViTsmall model loaded and ready for evaluation.
Evaluating on dataset with 1513 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 48/48 [00:05<00:00,  8.54it/s]


Accuracy on the evaluation dataset: 53.0734 %
Evaluating on dataset with 256 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.32it/s]


Final Results:
  Shape Accuracy: 0.0586
  Texture Accuracy: 0.0664
  Cue Accuracy: 0.1250
  Shape Bias: 0.4688
  Texture Bias: 0.5312





In [31]:
dataset_path = "/kaggle/input/labellledimagenet/train"
batch_size = 128

train_loader, val_loader = create_data_loaders(dataset_path, batch_size)
teacher_model_name = "ViT"  
teacher_model_path = "/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth"  
teacher_model = load_fine_tuned_model(teacher_model_name, teacher_model_path, num_classes=16)
student_model_class = SmallViTModel

distilled_model = distill(
    model_class=student_model_class, 
    teacher_model=teacher_model, 
    teacher_model_name=teacher_model_name,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=16,
    epochs=8,
    batch_size=32,
    learning_rate=1e-4,
    student_temperature=1.0,
    teacher_temperature=1.0,
    alpha=0.5
)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


  model.load_state_dict(torch.load(model_path, map_location=device))


ViT model loaded and ready for evaluation.


Epoch 1/8: 100%|██████████| 38/38 [00:37<00:00,  1.02batch/s, loss=1.45] 

Epoch 1/8, Loss: 1.4545410024492365





Validation Accuracy at epoch 1: 40.37%


Epoch 2/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.25] 

Epoch 2/8, Loss: 1.2458825864289935





Validation Accuracy at epoch 2: 58.55%


Epoch 3/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.18] 

Epoch 3/8, Loss: 1.1766453542207416





Validation Accuracy at epoch 3: 58.97%


Epoch 4/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.16] 

Epoch 4/8, Loss: 1.1560211840428805





Validation Accuracy at epoch 4: 58.97%


Epoch 5/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.15] 

Epoch 5/8, Loss: 1.1489908005061902





Validation Accuracy at epoch 5: 59.05%


Epoch 6/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.14] 

Epoch 6/8, Loss: 1.1437419935276634





Validation Accuracy at epoch 6: 59.05%


Epoch 7/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.14] 

Epoch 7/8, Loss: 1.1400173212352551





Validation Accuracy at epoch 7: 59.05%


Epoch 8/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.14] 

Epoch 8/8, Loss: 1.1377655236344588





Validation Accuracy at epoch 8: 59.05%
Distilled model saved to /kaggle/working/DistilledModels/ViT_distilled_model.pth
Distilled model zip file saved to /kaggle/working/DistilledModels/ViT_distilled_model.zip


In [32]:
model_name = "ViTsmall"  
model_path = "/kaggle/working/DistilledModels/ViT_distilled_model.pth" 
eval_dataset_path1 = "/kaggle/input/labellledimagenet/test" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

  model.load_state_dict(torch.load(model_path, map_location=device))


ViTsmall model loaded and ready for evaluation.
Evaluating on dataset with 1513 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 48/48 [00:05<00:00,  8.60it/s]


Accuracy on the evaluation dataset: 57.1051 %
Evaluating on dataset with 256 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  9.01it/s]


Final Results:
  Shape Accuracy: 0.0586
  Texture Accuracy: 0.0586
  Cue Accuracy: 0.1172
  Shape Bias: 0.5000
  Texture Bias: 0.5000





**Weighted Ensembling**

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import zipfile
import os

def logit_matching_loss(student_logits, ensembled_teacher_logits, student_temperature=1.0, teacher_temperature=1.0):
    """Improved KL divergence loss with temperature scaling"""
    student_scaled = student_logits / student_temperature
    teacher_scaled = ensembled_teacher_logits / teacher_temperature
    
    student_log_probs = F.log_softmax(student_scaled, dim=-1)
    teacher_probs = F.softmax(teacher_scaled, dim=-1)
    
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (student_temperature ** 2)

def weighted_ensemble_logits(teacher_logits_list, confidence_list):
    """Enhanced ensemble weighting with entropy-based confidence"""
    stacked_logits = torch.stack(teacher_logits_list, dim=0)
    
    # Calculate entropy-based confidence
    probs = [F.softmax(logits, dim=-1) for logits in teacher_logits_list]
    entropies = [-torch.sum(p * torch.log(p + 1e-10), dim=-1) for p in probs]
    confidence_scores = [1.0 / (entropy + 1e-10) for entropy in entropies]
    
    # Normalize confidence scores
    stacked_conf = torch.stack(confidence_scores, dim=0).unsqueeze(-1)
    normalized_conf = (stacked_conf - stacked_conf.mean()) / (stacked_conf.std() + 1e-6)
    weights = F.softmax(normalized_conf * 1.5, dim=0)
    
    return (stacked_logits * weights).sum(dim=0)

def distill_with_confidence_aware_ensemble(model_class, teacher_models, teacher_model_names, train_loader, val_loader, 
                                         num_classes, epochs=10, batch_size=32, learning_rate=0.001, 
                                         student_temperature=4.0, teacher_temperature=4.0, alpha=0.7,
                                         save_dir='/kaggle/working/'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    student_model = model_class(num_classes=num_classes).to(device)
    teacher_models = [model.to(device).eval() for model in teacher_models]
    
    # Improved optimizer and learning rate scheduler
    optimizer = optim.AdamW(student_model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0
    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
            for batch_idx, (images, labels) in enumerate(pbar):
                images, labels = images.to(device), labels.to(device)
                
                student_outputs = student_model(images)
                teacher_logits_list = []
                confidence_list = []
                
                with torch.no_grad():
                    for teacher in teacher_models:
                        logits = teacher(images)
                        probs = F.softmax(logits, dim=-1)
                        confidence = probs.max(dim=-1)[0]
                        teacher_logits_list.append(logits)
                        confidence_list.append(confidence)

                ensemble_logits = weighted_ensemble_logits(teacher_logits_list, confidence_list)
                
                ce_loss = criterion(student_outputs, labels)
                distill_loss = logit_matching_loss(student_outputs, ensemble_logits,
                                                 student_temperature, teacher_temperature)
                
                # Dynamic loss weighting
                loss = (1 - alpha) * ce_loss + alpha * distill_loss
                
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
                optimizer.step()
                
                running_loss += loss.item()
                pbar.set_postfix(loss=running_loss/(batch_idx + 1))

        scheduler.step()
        
        # Validation
        student_model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = student_model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                val_loss += criterion(outputs, labels).item()

        accuracy = 100.0 * correct / total
        print(f"Epoch {epoch+1}: Train Loss={running_loss/len(train_loader):.4f}, "
              f"Val Loss={val_loss/len(val_loader):.4f}, Val Acc={accuracy:.2f}%")
        
        # Save best model
        if accuracy > best_val_acc:
            best_val_acc = accuracy
            model_path = os.path.join(save_dir, f"{'+'.join(teacher_model_names)}_distilled_best.pth")
            torch.save(student_model.state_dict(), model_path)
    
    return student_model

In [36]:
# Example usage
teacher_model1 = load_fine_tuned_model('VGG16', '/kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth', num_classes=16)
teacher_model2 = load_fine_tuned_model('ViT', '/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth', num_classes=16)

teacher_models = [teacher_model1, teacher_model2]
teacher_model_names = ['VGG16', 'ViT']

# Create data loaders
train_loader, val_loader = create_data_loaders('/kaggle/input/labellledimagenet/train', batch_size=32)

# Call the distillation function with multiple teachers
distilled_model = distill_with_confidence_aware_ensemble(
    model_class=SmallViTModel,  # Student model
    teacher_models=teacher_models,
    teacher_model_names=teacher_model_names,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=16,
    epochs=8,
    batch_size=32,
    learning_rate=0.001,
    student_temperature=2.0,
    teacher_temperature=2.0,
    alpha=0.5,
    save_dir='/kaggle/working/'
)

  model.load_state_dict(torch.load(model_path, map_location=device))


VGG16 model loaded and ready for evaluation.
ViT model loaded and ready for evaluation.
Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Epoch 1/8: 100%|██████████| 151/151 [00:49<00:00,  3.03it/s, loss=3.99]


Epoch 1: Train Loss=3.9921, Val Loss=1.7168, Val Acc=58.14%


Epoch 2/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.66]


Epoch 2: Train Loss=3.6572, Val Loss=1.7391, Val Acc=57.64%


Epoch 3/8: 100%|██████████| 151/151 [00:49<00:00,  3.03it/s, loss=3.59]


Epoch 3: Train Loss=3.5928, Val Loss=1.7189, Val Acc=58.72%


Epoch 4/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.55]


Epoch 4: Train Loss=3.5542, Val Loss=1.7016, Val Acc=58.22%


Epoch 5/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.52]


Epoch 5: Train Loss=3.5241, Val Loss=1.6905, Val Acc=57.89%


Epoch 6/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.5] 


Epoch 6: Train Loss=3.5005, Val Loss=1.6965, Val Acc=58.39%


Epoch 7/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.48]


Epoch 7: Train Loss=3.4809, Val Loss=1.6863, Val Acc=58.39%


Epoch 8/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.47]


Epoch 8: Train Loss=3.4688, Val Loss=1.6827, Val Acc=58.39%


In [38]:
model_name = "ViTsmall"  
model_path = "/kaggle/working/VGG16+ViT_distilled_best.pth" 
eval_dataset_path1 = "/kaggle/input/labellledimagenet/test" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

  model.load_state_dict(torch.load(model_path, map_location=device))


ViTsmall model loaded and ready for evaluation.
Evaluating on dataset with 1513 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 48/48 [00:05<00:00,  8.65it/s]


Accuracy on the evaluation dataset: 55.7171 %
Evaluating on dataset with 256 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  9.25it/s]


Final Results:
  Shape Accuracy: 0.0586
  Texture Accuracy: 0.0586
  Cue Accuracy: 0.1172
  Shape Bias: 0.5000
  Texture Bias: 0.5000





**Evaluation**

In [14]:

def evaluate_model(model, dataset_path, batch_size=32):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)), 
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    eval_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    print(f"Evaluating on dataset with {len(eval_dataset)} samples and classes: {eval_dataset.classes}")

    total_samples = 0
    correct_predictions = 0
    
    with torch.no_grad():
        for images, labels in tqdm(eval_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            correct_predictions += torch.sum(preds == labels).item()
            total_samples += labels.size(0)
    
    accuracy = (correct_predictions / total_samples) * 100
    print(f"Accuracy on the evaluation dataset: {accuracy:.4f} %")
    
    return accuracy

In [15]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import re

def strip_label_suffix(label):
    return re.sub(r'\d+$', '', label)

def evaluate_cue_conflict_dataset(model, dataset_path, batch_size=32):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    eval_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    print(f"Evaluating on dataset with {len(eval_dataset)} samples and classes: {eval_dataset.classes}")

    shape_correct = 0
    texture_correct = 0
    total_samples = 0

    cue_correct = 0
    
    with torch.no_grad():
        for images, labels in tqdm(eval_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
    
            for i, pred in enumerate(preds):
                filename = eval_dataset.samples[total_samples + i][0]
                shape_texture = os.path.basename(filename).split("-")
                shape_label = strip_label_suffix(shape_texture[0]) 
                texture_label = strip_label_suffix(shape_texture[1][:-4])  
                pred_class = eval_dataset.classes[pred]
    
                if pred_class == shape_label:
                    shape_correct += 1
                    cue_correct += 1  # Count once for shape match
                elif pred_class == texture_label:  # Use elif to avoid double-counting
                    texture_correct += 1
                    cue_correct += 1
            total_samples += len(preds)
    
    shape_accuracy = shape_correct / total_samples if total_samples > 0 else 0
    texture_accuracy = texture_correct / total_samples if total_samples > 0 else 0
    cue_accuracy = cue_correct / total_samples if total_samples > 0 else 0
    shape_bias = shape_accuracy / cue_accuracy if cue_accuracy > 0 else 0
    texture_bias = 1 - shape_bias


    print("\nFinal Results:")
    print(f"  Shape Accuracy: {shape_accuracy:.4f}")
    print(f"  Texture Accuracy: {texture_accuracy:.4f}")
    print(f"  Cue Accuracy: {cue_accuracy:.4f}")
    print(f"  Shape Bias: {shape_bias:.4f}")
    print(f"  Texture Bias: {texture_bias:.4f}")

    # return {
    #     "Shape Accuracy": shape_accuracy,
    #     "Texture Accuracy": texture_accuracy,
    #     "Cue Accuracy": cue_accuracy,
    #     "Shape Bias": shape_bias,
    #     "Texture Bias": texture_bias
    # }

In [None]:
model_name = "ViT"  
model_path = "/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth" 
# eval_dataset_path1 = "/kaggle/input/labelledimagenet" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
# evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

**Distilling Intermediate Loss**

In [28]:
import torch
import torch.nn as nn
import timm
from torchvision import models

class FeatureExtractor(nn.Module):
    def __init__(self, model, layers_to_extract):
        super(FeatureExtractor, self).__init__()
        self.model = model
        self.layers_to_extract = layers_to_extract
        
    def forward(self, x):
        features = []
        for name, module in self.model.named_children():
            try:
                x = module(x)
                if name in self.layers_to_extract:
                    print(f"Extracted features from layer {name}, shape: {x.shape}")
                    features.append(x)
            except Exception as e:
                print(f"Error in layer {name}: {str(e)}")
                raise
        return features

class VGG16FeatureExtractor(FeatureExtractor):
    def __init__(self, model, layers_to_extract=['features']):
        super(VGG16FeatureExtractor, self).__init__(model, layers_to_extract)
        
class ViTFeatureExtractor(FeatureExtractor):
    def __init__(self, model, layers_to_extract=['blocks']):
        super(ViTFeatureExtractor, self).__init__(model, layers_to_extract)

def intermediate_distillation_loss(student_features, teacher_features, alpha=0.5, temperature=2.0):
    total_loss = 0.0
    for idx, (student_feature, teacher_feature) in enumerate(zip(student_features, teacher_features)):
        try:
            print(f"\nFeature pair {idx}:")
            print(f"Student feature shape: {student_feature.shape}")
            print(f"Teacher feature shape: {teacher_feature.shape}")
            
            student_feature = student_feature / student_feature.norm(dim=1, keepdim=True)
            teacher_feature = teacher_feature / teacher_feature.norm(dim=1, keepdim=True)
            
            if student_feature.shape != teacher_feature.shape:
                print(f"Resizing teacher features from {teacher_feature.shape} to match student shape {student_feature.shape}")
                if len(student_feature.shape) == 4:  
                    teacher_feature = nn.functional.interpolate(
                        teacher_feature,
                        size=student_feature.shape[2:],
                        mode='bilinear',
                        align_corners=False
                    )
                elif len(student_feature.shape) == 3: 
                    teacher_feature = teacher_feature[:, :student_feature.shape[1], :]
                    if teacher_feature.shape[-1] != student_feature.shape[-1]:
                        projection = nn.Linear(teacher_feature.shape[-1], student_feature.shape[-1]).to(teacher_feature.device)
                        teacher_feature = projection(teacher_feature)
            
            student_feature = student_feature / temperature
            teacher_feature = teacher_feature / temperature
            distill_loss = nn.functional.mse_loss(student_feature, teacher_feature)
            total_loss += distill_loss * alpha
            print(f"Distillation loss for feature pair {idx}: {distill_loss.item()}")
            
        except Exception as e:
            print(f"Error in distillation loss calculation for feature pair {idx}: {str(e)}")
            raise
            
    return total_loss

def intermediate_distillation_loss(student_features, teacher_features, alpha=0.5, temperature=2.0):
    total_loss = torch.tensor(0.0).to(student_features[0].device)  
    
    for idx, (student_feature, teacher_feature) in enumerate(zip(student_features, teacher_features)):
        try:
            print(f"\nFeature pair {idx}:")
            print(f"Student feature shape: {student_feature.shape}")
            print(f"Teacher feature shape: {teacher_feature.shape}")
            if not isinstance(student_feature, torch.Tensor):
                print(f"Warning: Converting student feature to tensor")
                student_feature = torch.tensor(student_feature)
            if not isinstance(teacher_feature, torch.Tensor):
                print(f"Warning: Converting teacher feature to tensor")
                teacher_feature = torch.tensor(teacher_feature)
            student_feature = student_feature.to(total_loss.device)
            teacher_feature = teacher_feature.to(total_loss.device)
            
            # student_feature = student_feature / (student_feature.norm(dim=1, keepdim=True) + 1e-6)
            # teacher_feature = teacher_feature / (teacher_feature.norm(dim=1, keepdim=True) + 1e-6)
            
            if student_feature.shape != teacher_feature.shape:
                print(f"Resizing teacher features from {teacher_feature.shape} to match student shape {student_feature.shape}")
                if len(student_feature.shape) == 4: 
                    teacher_feature = nn.functional.interpolate(
                        teacher_feature,
                        size=student_feature.shape[2:],
                        mode='bilinear',
                        align_corners=False
                    )
                elif len(student_feature.shape) == 3: 
                    teacher_feature = teacher_feature[:, :student_feature.shape[1], :]
                    if teacher_feature.shape[-1] != student_feature.shape[-1]:
                        projection = nn.Linear(teacher_feature.shape[-1], student_feature.shape[-1]).to(teacher_feature.device)
                        teacher_feature = projection(teacher_feature)
            
            student_feature = student_feature / temperature
            teacher_feature = teacher_feature / temperature
            
            distill_loss = nn.functional.mse_loss(student_feature, teacher_feature)
            total_loss += distill_loss * alpha
            print(f"Distillation loss for feature pair {idx}: {distill_loss.item()}")
            
        except Exception as e:
            print(f"Error in distillation loss calculation for feature pair {idx}: {str(e)}")
            print(f"Student feature type: {type(student_feature)}")
            print(f"Teacher feature type: {type(teacher_feature)}")
            raise
            
    return total_loss

def compute_accuracy(output, labels):
    _, predicted = torch.max(output, 1)
    correct = (predicted == labels).sum().item()  
    accuracy = correct / labels.size(0)  
    return accuracy

def distill_with_intermediate_distillation(student_model, teacher_models, train_loader, val_loader, epochs=10, 
                                           learning_rate=0.001, alpha=0.5, temperature=2.0, device='cuda'):
    print(f"\nStarting distillation process:")
    print(f"Device: {device}")
    print(f"Learning rate: {learning_rate}")
    print(f"Alpha: {alpha}")
    print(f"Temperature: {temperature}")
    
    student_model.to(device)
    
    optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
    
    teacher_extractors = []
    for idx, teacher_model in enumerate(teacher_models):
        print(f"\nInitializing teacher model {idx + 1}")
        teacher_model.eval()
        teacher_model.to(device)
        
        if isinstance(teacher_model, models.vgg.VGG):
            extractor = VGG16FeatureExtractor(teacher_model)
            print("Using VGG16 feature extractor")
        elif 'ViT' in str(type(teacher_model)):
            extractor = ViTFeatureExtractor(teacher_model)
            print("Using ViT feature extractor")
        else:
            extractor = FeatureExtractor(teacher_model, ['features', 'blocks'])
            print("Using default feature extractor")
            
        teacher_extractors.append(extractor)
    
    student_model.train()
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        print(f"\nEpoch [{epoch + 1}/{epochs}]")
        total_train_loss = torch.tensor(0.0).to(device)  # Initialize as tensor
        total_train_accuracy = 0.0
        num_batches = 0
        
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            try:
                inputs = inputs.to(device)
                labels = labels.to(device).long()
                
                print(f"\nBatch {batch_idx + 1}:")
                print(f"Input shape: {inputs.shape}")
                print(f"Labels shape: {labels.shape}")
                optimizer.zero_grad()
                
                student_features = student_model(inputs)
                if not isinstance(student_features, list):
                    student_features = [student_features]
                
                teacher_features_list = []
                for teacher_idx, extractor in enumerate(teacher_extractors):
                    print(f"\nExtracting features from teacher {teacher_idx + 1}")
                    with torch.no_grad():
                        teacher_features = extractor(inputs)
                    teacher_features_list.append(teacher_features)
                
                distill_loss = torch.tensor(0.0).to(device)  # Initialize as tensor
                for teacher_idx, teacher_features in enumerate(teacher_features_list):
                    print(f"\nComputing distillation loss for teacher {teacher_idx + 1}")
                    current_distill_loss = intermediate_distillation_loss(
                        student_features, 
                        teacher_features, 
                        alpha, 
                        temperature
                    )
                    distill_loss += current_distill_loss
                
                student_output = student_features[-1]
                classification_loss = nn.CrossEntropyLoss()(student_output, labels)
                
                total_loss = distill_loss + classification_loss
                print(f"\nLosses for batch {batch_idx + 1}:")
                print(f"Distillation loss: {distill_loss.item():.4f}")
                print(f"Classification loss: {classification_loss.item():.4f}")
                print(f"Total loss: {total_loss.item():.4f}")
                
                total_loss.backward()
                optimizer.step()
                
                total_train_loss += total_loss
                total_train_accuracy += compute_accuracy(student_output, labels)
                num_batches += 1
                
            except Exception as e:
                print(f"Error in batch {batch_idx + 1}: {str(e)}")
                continue
        
        avg_train_loss = (total_train_loss / num_batches).item()
        avg_train_accuracy = (total_train_accuracy / num_batches) * 100  
        print(f"\nEpoch {epoch + 1} average training loss: {avg_train_loss:.4f}")
        print(f"Epoch {epoch + 1} average training accuracy: {avg_train_accuracy:.2f}%")
        
        student_model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        num_val_batches = 0
        
        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                try:
                    val_inputs = val_inputs.to(device)
                    val_labels = val_labels.to(device).long()
                    
                    val_student_features = student_model(val_inputs)
                    if not isinstance(val_student_features, list):
                        val_student_features = [val_student_features]
                    
                    val_distill_loss = 0.0
                    for extractor in teacher_extractors:
                        teacher_features = extractor(val_inputs)
                        val_distill_loss += intermediate_distillation_loss(
                            val_student_features,
                            teacher_features,
                            alpha,
                            temperature
                        )
                    
                    val_classification_loss = nn.CrossEntropyLoss()(val_student_features[-1], val_labels)
                    val_total_loss = val_distill_loss + val_classification_loss
                    val_loss += val_total_loss.item()
                    
                    val_accuracy += compute_accuracy(val_student_features[-1], val_labels)
                    num_val_batches += 1
                    
                except Exception as e:
                    print(f"Error in validation batch: {str(e)}")
                    continue
        
        avg_val_loss = val_loss / num_val_batches
        avg_val_accuracy = (val_accuracy / num_val_batches) * 100  # Percentage
        print(f"Validation loss: {avg_val_loss:.4f}")
        print(f"Validation accuracy: {avg_val_accuracy:.2f}%")
        scheduler.step(avg_val_loss)
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(student_model.state_dict(), 'best_student_model.pth')
            print("Saved best model checkpoint")
        
        student_model.train()
    
    return student_model


train_loader, val_loader = create_data_loaders('/kaggle/input/labellledimagenet/train', batch_size=16)
teacher_model1 = load_fine_tuned_model('VGG16', '/kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth', num_classes=16)
teacher_model2 = load_fine_tuned_model('ViT', '/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth', num_classes=16)
student_model = SmallViTModel(num_classes=16, pretrained=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

distilled_model = distill_with_intermediate_distillation(
    student_model=student_model,
    teacher_models=[teacher_model1, teacher_model2],
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=8,
    learning_rate=0.001,
    alpha=0.5,
    temperature=2.0,
    device=device 
)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


  model.load_state_dict(torch.load(model_path, map_location=device))


VGG16 model loaded and ready for evaluation.
ViT model loaded and ready for evaluation.

Starting distillation process:
Device: cuda
Learning rate: 0.001
Alpha: 0.5
Temperature: 2.0

Initializing teacher model 1
Using default feature extractor

Initializing teacher model 2
Using ViT feature extractor

Epoch [1/8]

Batch 1:
Input shape: torch.Size([16, 3, 224, 224])
Labels shape: torch.Size([16])

Extracting features from teacher 1

Extracting features from teacher 2

Computing distillation loss for teacher 1

Computing distillation loss for teacher 2

Losses for batch 1:
Distillation loss: 0.0000
Classification loss: 3.0366
Total loss: 3.0366

Batch 2:
Input shape: torch.Size([16, 3, 224, 224])
Labels shape: torch.Size([16])

Extracting features from teacher 1

Extracting features from teacher 2

Computing distillation loss for teacher 1

Computing distillation loss for teacher 2

Losses for batch 2:
Distillation loss: 0.0000
Classification loss: 3.1179
Total loss: 3.1179

Batch 3:
Inp

KeyboardInterrupt: 