In [2]:
from torch.utils.data import DataLoader, random_split

def create_data_loaders(dataset_path, batch_size, val_split=0.2, seed=42):
    torch.manual_seed(seed)

    transform = transforms.Compose([
        transforms.Resize((224, 224)), 
        transforms.ToTensor(),        
        # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # transforms.Resize((256, 256))
    ])
    
    dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    print(f"Train size: {train_size}, Validation size: {val_size}")
    print(f"Class names: {dataset.classes}")
    
    return train_loader, val_loader

In [27]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

def load_fine_tuned_model(model_name, model_path, num_classes=16):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if model_name == "VGG11":
        model = VGG11(num_classes=num_classes, pretrained=False)
    elif model_name == "VGG16":
        model = VGG16Model(num_classes=num_classes, pretrained=False)
    elif model_name == "ResNet50":
        model = ResNet50Model(num_classes=num_classes, pretrained=False)
    elif model_name == "DenseNet121":
        model = DenseNet121Model(num_classes=num_classes, pretrained=False)
    elif model_name == "ViT":
        model = ViTModel(num_classes=num_classes, pretrained=False)
    elif model_name == "ViTsmall":
        model = SmallViTModel(num_classes=num_classes, pretrained=False)    
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print(f'{model_name} model loaded and ready for evaluation.')
    return model

**Teacher Models**

In [3]:
import torch
from torch import nn
from torchvision.models import vit_b_16

class ViTModel(nn.Module): 
    def __init__(self, num_classes, pretrained=True):
        super(ViTModel, self).__init__()
        self.model = vit_b_16(pretrained=pretrained)
        
        for param in self.model.parameters():
            param.requires_grad = False
            
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Linear(in_features, num_classes)
        for param in self.model.heads.head.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)

    def get_model(self):
        return self.model

# num_classes = 10 
# vit_model = ViTModel(num_classes=num_classes, pretrained=True)
# model_instance = vit_model.get_model()
# print(model_instance)

In [4]:
import torch
import torchvision.models as models
import torch.nn as nn

class VGG16Model(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(VGG16Model, self).__init__()
        self.model = models.vgg16(pretrained=pretrained)

        for param in self.model.features.parameters():
            param.requires_grad = False
        in_features = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(in_features, num_classes)
        for param in self.model.classifier[6].parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)
        
    def get_model(self):
        return self.model
        
# num_classes = 10  
# vgg_model = VGG16Model(num_classes=num_classes, pretrained=True)
# model_instance = vgg_model.get_model()
# print(model_instance)

**Student Model**

In [5]:
import torch
import torch.nn as nn
import timm

class SmallViTModel(nn.Module):
    def __init__(self, num_classes, pretrained=False):
        super(SmallViTModel, self).__init__()
        
        # Create a ViT model using a variant that works for 224x224 images
        # Ensure the model matches the number of classes used in the teacher models
        self.model = timm.create_model('vit_small_patch16_224', pretrained=pretrained)

        # Freeze backbone parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Modify the classifier to match num_classes (ensure it's consistent with the teacher model)
        in_features = self.model.head.in_features
        self.model.reset_classifier(num_classes=num_classes)
        
        # Ensure the new classifier is trainable
        for param in self.model.head.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)

    def get_model(self):
        return self.model


**Fine Tuning**

In [9]:
import torch
import torch.optim as optim
import torch.nn as nn
import os
import zipfile
from tqdm import tqdm
from torchvision import datasets, transforms

def fine_tune_model(model_class, dataset_path, output_path, num_classes, epochs=10, batch_size=32, learning_rate=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader = create_data_loaders(dataset_path, batch_size)
    model = model_class(num_classes=num_classes, pretrained=True)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Optimizer configuration for ViT and ResNet
    if isinstance(model, ViTModel):  # For ViT
        optimizer = optim.Adam(model.get_model().heads.head.parameters(), lr=learning_rate)
    elif isinstance(model, SmallViTModel):  # For SmallViTModel
        optimizer = optim.Adam(model.get_model().get_classifier().parameters(), lr=learning_rate)
    # elif isinstance(model, ResNet50Model):  # For ResNet
    #     optimizer = optim.Adam(model.get_model().fc.parameters(), lr=learning_rate)  # Use 'fc' for ResNet
    else:  # For other models (e.g., DenseNet, etc.)
        optimizer = optim.Adam(model.get_model().classifier.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        total_samples = 0

        print(f"Epoch {epoch+1}/{epochs}: Training started...")
        for images, labels in tqdm(train_loader, desc="Training", unit="batch"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += torch.sum(preds == labels)
            total_samples += images.size(0)

        train_loss_avg = train_loss / total_samples
        train_accuracy = (train_correct.double() / total_samples) * 100
        model.eval()
        val_loss = 0.0
        val_correct = 0
        with torch.no_grad():
            print("Validation started...")
            for images, labels in tqdm(val_loader, desc="Validating", unit="batch"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += torch.sum(preds == labels)

        val_loss_avg = val_loss / len(val_loader.dataset)
        val_accuracy = (val_correct.double() / len(val_loader.dataset)) * 100
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Training Accuracy: {train_accuracy:.4f}, "
              f"Training Loss: {train_loss_avg:.4f} | "
              f"Validation Accuracy: {val_accuracy:.4f}, "
              f"Validation Loss: {val_loss_avg:.4f}")

    model_name = model_class.__name__.lower()
    model_path = os.path.join(output_path, f"{model_name}_fine_tuned_{num_classes}_classes.pth")
    torch.save(model.state_dict(), model_path)

    # torch.save(model.get_model().state_dict(), model_path)
    print(f"Model saved at {model_path}")
    
    zip_path = model_path.replace(".pth", ".zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(model_path, os.path.basename(model_path))
    # os.remove(model_path)
    print(f"Model saved and zipped to {zip_path}")

In [8]:
dataset_path = "/kaggle/input/cue-conflict-splitdata/train" 
output_path = "/kaggle/working/FinetunedModels"
os.makedirs(output_path, exist_ok=True)
model_class = SmallViTModel
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=2, batch_size=128, learning_rate=0.001)

Train size: 820, Validation size: 204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']
Epoch 1/2: Training started...


Training: 100%|██████████| 7/7 [00:04<00:00,  1.57batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.19batch/s]


Epoch 1/2: Training Accuracy: 6.0976, Training Loss: 2.8025 | Validation Accuracy: 8.3333, Validation Loss: 2.7480
Epoch 2/2: Training started...


Training: 100%|██████████| 7/7 [00:02<00:00,  2.69batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.90batch/s]


Epoch 2/2: Training Accuracy: 12.9268, Training Loss: 2.6774 | Validation Accuracy: 12.2549, Validation Loss: 2.7249
Model saved at /kaggle/working/FinetunedModels/smallvitmodel_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModels/smallvitmodel_fine_tuned_16_classes.zip


In [10]:
dataset_path = "/kaggle/input/labellledimagenet/train" 
output_path = "/kaggle/working/FinetunedModelsCC"  
os.makedirs(output_path, exist_ok=True)
model_class = VGG16Model  
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=5, batch_size=128, learning_rate=0.001)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 198MB/s] 


Epoch 1/5: Training started...


Training: 100%|██████████| 38/38 [00:28<00:00,  1.32batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:07<00:00,  1.33batch/s]


Epoch 1/5: Training Accuracy: 82.2370, Training Loss: 0.7517 | Validation Accuracy: 96.4286, Validation Loss: 0.1306
Epoch 2/5: Training started...


Training: 100%|██████████| 38/38 [00:18<00:00,  2.09batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.95batch/s]


Epoch 2/5: Training Accuracy: 96.5346, Training Loss: 0.1493 | Validation Accuracy: 95.3488, Validation Loss: 0.2568
Epoch 3/5: Training started...


Training: 100%|██████████| 38/38 [00:18<00:00,  2.05batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.93batch/s]


Epoch 3/5: Training Accuracy: 97.7589, Training Loss: 0.1181 | Validation Accuracy: 95.5980, Validation Loss: 0.3114
Epoch 4/5: Training started...


Training: 100%|██████████| 38/38 [00:18<00:00,  2.03batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.97batch/s]


Epoch 4/5: Training Accuracy: 98.2984, Training Loss: 0.1213 | Validation Accuracy: 95.5150, Validation Loss: 0.4698
Epoch 5/5: Training started...


Training: 100%|██████████| 38/38 [00:18<00:00,  2.03batch/s]


Validation started...


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.87batch/s]


Epoch 5/5: Training Accuracy: 97.8004, Training Loss: 0.1709 | Validation Accuracy: 93.6877, Validation Loss: 0.8586
Model saved at /kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.zip


In [11]:
dataset_path = "/kaggle/input/cue-conflict-splitdata/train" 
output_path = "/kaggle/working/FinetunedModels"
os.makedirs(output_path, exist_ok=True)
model_class = ViTModel 
fine_tune_model(model_class, dataset_path, output_path, num_classes=16, epochs=5, batch_size=128, learning_rate=0.001)

Train size: 820, Validation size: 204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 226MB/s] 


Epoch 1/5: Training started...


Training: 100%|██████████| 7/7 [00:05<00:00,  1.18batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.02batch/s]


Epoch 1/5: Training Accuracy: 11.4634, Training Loss: 2.7557 | Validation Accuracy: 22.0588, Validation Loss: 2.5802
Epoch 2/5: Training started...


Training: 100%|██████████| 7/7 [00:05<00:00,  1.20batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.04batch/s]


Epoch 2/5: Training Accuracy: 36.0976, Training Loss: 2.3423 | Validation Accuracy: 33.8235, Validation Loss: 2.3797
Epoch 3/5: Training started...


Training: 100%|██████████| 7/7 [00:05<00:00,  1.21batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.02batch/s]


Epoch 3/5: Training Accuracy: 53.0488, Training Loss: 2.0276 | Validation Accuracy: 39.7059, Validation Loss: 2.2170
Epoch 4/5: Training started...


Training: 100%|██████████| 7/7 [00:05<00:00,  1.21batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.03batch/s]


Epoch 4/5: Training Accuracy: 63.9024, Training Loss: 1.7726 | Validation Accuracy: 42.6471, Validation Loss: 2.0902
Epoch 5/5: Training started...


Training: 100%|██████████| 7/7 [00:05<00:00,  1.20batch/s]


Validation started...


Validating: 100%|██████████| 2/2 [00:01<00:00,  1.01batch/s]


Epoch 5/5: Training Accuracy: 69.7561, Training Loss: 1.5711 | Validation Accuracy: 43.6275, Validation Loss: 1.9823
Model saved at /kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth
Model saved and zipped to /kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.zip


In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import zipfile
import os

def logit_matching_loss(student_logits, teacher_logits, student_temperature=1.0, teacher_temperature=1.0):
    student_logits = student_logits - student_logits.mean(dim=-1, keepdim=True)
    teacher_logits = teacher_logits - teacher_logits.mean(dim=-1, keepdim=True)
    student_logits = student_logits / student_temperature
    teacher_logits = teacher_logits / teacher_temperature
    loss = F.mse_loss(student_logits, teacher_logits)
    return loss

def distill(model_class, teacher_model, teacher_model_name, train_loader, val_loader, num_classes, 
            epochs=10, batch_size=32, learning_rate=0.001, 
            student_temperature=2.0, teacher_temperature=2.0, alpha=0.5, save_dir="/kaggle/working/DistilledModels"):

    student_model = model_class(num_classes=num_classes)
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student_model = student_model.to(device)
    teacher_model = teacher_model.to(device)
    teacher_model.eval()  
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                student_outputs = student_model(images)
                with torch.no_grad():
                    teacher_outputs = teacher_model(images)
                ce_loss = criterion(student_outputs, labels)
                distillation_loss = logit_matching_loss(student_outputs, teacher_outputs, 
                                                        student_temperature=student_temperature,
                                                        teacher_temperature=teacher_temperature)

                total_loss = (1 - alpha) * ce_loss + alpha * distillation_loss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                running_loss += total_loss.item()

                pbar.set_postfix(loss=running_loss / len(pbar))  

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = student_model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f"Validation Accuracy at epoch {epoch+1}: {accuracy:.2f}%")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model_filename = f"{teacher_model_name}_distilled_model.pth" 
    model_path = os.path.join(save_dir, model_filename)
    torch.save(student_model.state_dict(), model_path)
    print(f"Distilled model saved to {model_path}")

    zip_filename = model_filename.replace(".pth", ".zip")
    zip_path = os.path.join(save_dir, zip_filename)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(model_path, os.path.basename(model_path))
        
    print(f"Distilled model zip file saved to {zip_path}")
    return student_model

dataset_path = "/kaggle/input/labellledimagenet/train"
batch_size = 128

train_loader, val_loader = create_data_loaders(dataset_path, batch_size)
teacher_model_name = "VGG16"  
teacher_model_path = "/kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth"  
teacher_model = load_fine_tuned_model(teacher_model_name, teacher_model_path, num_classes=16)
student_model_class = SmallViTModel

distilled_model = distill(
    model_class=student_model_class, 
    teacher_model=teacher_model, 
    teacher_model_name=teacher_model_name,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=16,
    epochs=8,
    batch_size=32,
    learning_rate=1e-4,
    student_temperature=1.0,
    teacher_temperature=1.0,
    alpha=0.5
)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


  model.load_state_dict(torch.load(model_path, map_location=device))


VGG16 model loaded and ready for evaluation.


Epoch 1/8: 100%|██████████| 38/38 [00:21<00:00,  1.79batch/s, loss=419]

Epoch 1/8, Loss: 418.6106928775185





Validation Accuracy at epoch 1: 18.85%


Epoch 2/8: 100%|██████████| 38/38 [00:21<00:00,  1.79batch/s, loss=412]

Epoch 2/8, Loss: 411.6780925549959





Validation Accuracy at epoch 2: 31.23%


Epoch 3/8: 100%|██████████| 38/38 [00:21<00:00,  1.79batch/s, loss=405]

Epoch 3/8, Loss: 404.75341074090255





Validation Accuracy at epoch 3: 39.78%


Epoch 4/8: 100%|██████████| 38/38 [00:21<00:00,  1.79batch/s, loss=399]

Epoch 4/8, Loss: 398.9636021664268





Validation Accuracy at epoch 4: 45.35%


Epoch 5/8: 100%|██████████| 38/38 [00:21<00:00,  1.78batch/s, loss=393]

Epoch 5/8, Loss: 393.1274470279091





Validation Accuracy at epoch 5: 48.59%


Epoch 6/8: 100%|██████████| 38/38 [00:21<00:00,  1.79batch/s, loss=388]

Epoch 6/8, Loss: 387.51149388363484





Validation Accuracy at epoch 6: 49.92%


Epoch 7/8: 100%|██████████| 38/38 [00:21<00:00,  1.78batch/s, loss=382]

Epoch 7/8, Loss: 381.97849394145766





Validation Accuracy at epoch 7: 50.75%


Epoch 8/8: 100%|██████████| 38/38 [00:21<00:00,  1.79batch/s, loss=377] 

Epoch 8/8, Loss: 376.99485136333266





Validation Accuracy at epoch 8: 51.16%
Distilled model saved to /kaggle/working/DistilledModels/VGG16_distilled_model.pth
Distilled model zip file saved to /kaggle/working/DistilledModels/VGG16_distilled_model.zip


In [30]:
model_name = "ViTsmall"  
model_path = "/kaggle/working/DistilledModels/VGG16_distilled_model.pth" 
eval_dataset_path1 = "/kaggle/input/labellledimagenet/test" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

  model.load_state_dict(torch.load(model_path, map_location=device))


ViTsmall model loaded and ready for evaluation.
Evaluating on dataset with 1513 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 48/48 [00:05<00:00,  8.54it/s]


Accuracy on the evaluation dataset: 53.0734 %
Evaluating on dataset with 256 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.32it/s]


Final Results:
  Shape Accuracy: 0.0586
  Texture Accuracy: 0.0664
  Cue Accuracy: 0.1250
  Shape Bias: 0.4688
  Texture Bias: 0.5312





In [31]:
dataset_path = "/kaggle/input/labellledimagenet/train"
batch_size = 128

train_loader, val_loader = create_data_loaders(dataset_path, batch_size)
teacher_model_name = "ViT"  
teacher_model_path = "/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth"  
teacher_model = load_fine_tuned_model(teacher_model_name, teacher_model_path, num_classes=16)
student_model_class = SmallViTModel

distilled_model = distill(
    model_class=student_model_class, 
    teacher_model=teacher_model, 
    teacher_model_name=teacher_model_name,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=16,
    epochs=8,
    batch_size=32,
    learning_rate=1e-4,
    student_temperature=1.0,
    teacher_temperature=1.0,
    alpha=0.5
)

Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


  model.load_state_dict(torch.load(model_path, map_location=device))


ViT model loaded and ready for evaluation.


Epoch 1/8: 100%|██████████| 38/38 [00:37<00:00,  1.02batch/s, loss=1.45] 

Epoch 1/8, Loss: 1.4545410024492365





Validation Accuracy at epoch 1: 40.37%


Epoch 2/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.25] 

Epoch 2/8, Loss: 1.2458825864289935





Validation Accuracy at epoch 2: 58.55%


Epoch 3/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.18] 

Epoch 3/8, Loss: 1.1766453542207416





Validation Accuracy at epoch 3: 58.97%


Epoch 4/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.16] 

Epoch 4/8, Loss: 1.1560211840428805





Validation Accuracy at epoch 4: 58.97%


Epoch 5/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.15] 

Epoch 5/8, Loss: 1.1489908005061902





Validation Accuracy at epoch 5: 59.05%


Epoch 6/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.14] 

Epoch 6/8, Loss: 1.1437419935276634





Validation Accuracy at epoch 6: 59.05%


Epoch 7/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.14] 

Epoch 7/8, Loss: 1.1400173212352551





Validation Accuracy at epoch 7: 59.05%


Epoch 8/8: 100%|██████████| 38/38 [00:37<00:00,  1.01batch/s, loss=1.14] 

Epoch 8/8, Loss: 1.1377655236344588





Validation Accuracy at epoch 8: 59.05%
Distilled model saved to /kaggle/working/DistilledModels/ViT_distilled_model.pth
Distilled model zip file saved to /kaggle/working/DistilledModels/ViT_distilled_model.zip


In [32]:
model_name = "ViTsmall"  
model_path = "/kaggle/working/DistilledModels/ViT_distilled_model.pth" 
eval_dataset_path1 = "/kaggle/input/labellledimagenet/test" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

  model.load_state_dict(torch.load(model_path, map_location=device))


ViTsmall model loaded and ready for evaluation.
Evaluating on dataset with 1513 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 48/48 [00:05<00:00,  8.60it/s]


Accuracy on the evaluation dataset: 57.1051 %
Evaluating on dataset with 256 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  9.01it/s]


Final Results:
  Shape Accuracy: 0.0586
  Texture Accuracy: 0.0586
  Cue Accuracy: 0.1172
  Shape Bias: 0.5000
  Texture Bias: 0.5000





**Weighted Ensembling**

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import zipfile
import os

def logit_matching_loss(student_logits, ensembled_teacher_logits, student_temperature=1.0, teacher_temperature=1.0):
    """Improved KL divergence loss with temperature scaling"""
    student_scaled = student_logits / student_temperature
    teacher_scaled = ensembled_teacher_logits / teacher_temperature
    
    student_log_probs = F.log_softmax(student_scaled, dim=-1)
    teacher_probs = F.softmax(teacher_scaled, dim=-1)
    
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (student_temperature ** 2)

def weighted_ensemble_logits(teacher_logits_list, confidence_list):
    """Enhanced ensemble weighting with entropy-based confidence"""
    stacked_logits = torch.stack(teacher_logits_list, dim=0)
    
    # Calculate entropy-based confidence
    probs = [F.softmax(logits, dim=-1) for logits in teacher_logits_list]
    entropies = [-torch.sum(p * torch.log(p + 1e-10), dim=-1) for p in probs]
    confidence_scores = [1.0 / (entropy + 1e-10) for entropy in entropies]
    
    # Normalize confidence scores
    stacked_conf = torch.stack(confidence_scores, dim=0).unsqueeze(-1)
    normalized_conf = (stacked_conf - stacked_conf.mean()) / (stacked_conf.std() + 1e-6)
    weights = F.softmax(normalized_conf * 1.5, dim=0)
    
    return (stacked_logits * weights).sum(dim=0)

def distill_with_confidence_aware_ensemble(model_class, teacher_models, teacher_model_names, train_loader, val_loader, 
                                         num_classes, epochs=10, batch_size=32, learning_rate=0.001, 
                                         student_temperature=4.0, teacher_temperature=4.0, alpha=0.7,
                                         save_dir='/kaggle/working/'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    student_model = model_class(num_classes=num_classes).to(device)
    teacher_models = [model.to(device).eval() for model in teacher_models]
    
    # Improved optimizer and learning rate scheduler
    optimizer = optim.AdamW(student_model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0
    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
            for batch_idx, (images, labels) in enumerate(pbar):
                images, labels = images.to(device), labels.to(device)
                
                student_outputs = student_model(images)
                teacher_logits_list = []
                confidence_list = []
                
                with torch.no_grad():
                    for teacher in teacher_models:
                        logits = teacher(images)
                        probs = F.softmax(logits, dim=-1)
                        confidence = probs.max(dim=-1)[0]
                        teacher_logits_list.append(logits)
                        confidence_list.append(confidence)

                ensemble_logits = weighted_ensemble_logits(teacher_logits_list, confidence_list)
                
                ce_loss = criterion(student_outputs, labels)
                distill_loss = logit_matching_loss(student_outputs, ensemble_logits,
                                                 student_temperature, teacher_temperature)
                
                # Dynamic loss weighting
                loss = (1 - alpha) * ce_loss + alpha * distill_loss
                
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
                optimizer.step()
                
                running_loss += loss.item()
                pbar.set_postfix(loss=running_loss/(batch_idx + 1))

        scheduler.step()
        
        # Validation
        student_model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = student_model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                val_loss += criterion(outputs, labels).item()

        accuracy = 100.0 * correct / total
        print(f"Epoch {epoch+1}: Train Loss={running_loss/len(train_loader):.4f}, "
              f"Val Loss={val_loss/len(val_loader):.4f}, Val Acc={accuracy:.2f}%")
        
        # Save best model
        if accuracy > best_val_acc:
            best_val_acc = accuracy
            model_path = os.path.join(save_dir, f"{'+'.join(teacher_model_names)}_distilled_best.pth")
            torch.save(student_model.state_dict(), model_path)
    
    return student_model

In [36]:
# Example usage
teacher_model1 = load_fine_tuned_model('VGG16', '/kaggle/working/FinetunedModelsCC/vgg16model_fine_tuned_16_classes.pth', num_classes=16)
teacher_model2 = load_fine_tuned_model('ViT', '/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth', num_classes=16)

teacher_models = [teacher_model1, teacher_model2]
teacher_model_names = ['VGG16', 'ViT']

# Create data loaders
train_loader, val_loader = create_data_loaders('/kaggle/input/labellledimagenet/train', batch_size=32)

# Call the distillation function with multiple teachers
distilled_model = distill_with_confidence_aware_ensemble(
    model_class=SmallViTModel,  # Student model
    teacher_models=teacher_models,
    teacher_model_names=teacher_model_names,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=16,
    epochs=8,
    batch_size=32,
    learning_rate=0.001,
    student_temperature=2.0,
    teacher_temperature=2.0,
    alpha=0.5,
    save_dir='/kaggle/working/'
)

  model.load_state_dict(torch.load(model_path, map_location=device))


VGG16 model loaded and ready for evaluation.
ViT model loaded and ready for evaluation.
Train size: 4819, Validation size: 1204
Class names: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Epoch 1/8: 100%|██████████| 151/151 [00:49<00:00,  3.03it/s, loss=3.99]


Epoch 1: Train Loss=3.9921, Val Loss=1.7168, Val Acc=58.14%


Epoch 2/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.66]


Epoch 2: Train Loss=3.6572, Val Loss=1.7391, Val Acc=57.64%


Epoch 3/8: 100%|██████████| 151/151 [00:49<00:00,  3.03it/s, loss=3.59]


Epoch 3: Train Loss=3.5928, Val Loss=1.7189, Val Acc=58.72%


Epoch 4/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.55]


Epoch 4: Train Loss=3.5542, Val Loss=1.7016, Val Acc=58.22%


Epoch 5/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.52]


Epoch 5: Train Loss=3.5241, Val Loss=1.6905, Val Acc=57.89%


Epoch 6/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.5] 


Epoch 6: Train Loss=3.5005, Val Loss=1.6965, Val Acc=58.39%


Epoch 7/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.48]


Epoch 7: Train Loss=3.4809, Val Loss=1.6863, Val Acc=58.39%


Epoch 8/8: 100%|██████████| 151/151 [00:49<00:00,  3.04it/s, loss=3.47]


Epoch 8: Train Loss=3.4688, Val Loss=1.6827, Val Acc=58.39%


In [38]:
model_name = "ViTsmall"  
model_path = "/kaggle/working/VGG16+ViT_distilled_best.pth" 
eval_dataset_path1 = "/kaggle/input/labellledimagenet/test" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)

  model.load_state_dict(torch.load(model_path, map_location=device))


ViTsmall model loaded and ready for evaluation.
Evaluating on dataset with 1513 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 48/48 [00:05<00:00,  8.65it/s]


Accuracy on the evaluation dataset: 55.7171 %
Evaluating on dataset with 256 samples and classes: ['airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck']


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  9.25it/s]


Final Results:
  Shape Accuracy: 0.0586
  Texture Accuracy: 0.0586
  Cue Accuracy: 0.1172
  Shape Bias: 0.5000
  Texture Bias: 0.5000





**Evaluation**

In [24]:

def evaluate_model(model, dataset_path, batch_size=32):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)), 
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    eval_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    print(f"Evaluating on dataset with {len(eval_dataset)} samples and classes: {eval_dataset.classes}")

    total_samples = 0
    correct_predictions = 0
    
    with torch.no_grad():
        for images, labels in tqdm(eval_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            correct_predictions += torch.sum(preds == labels).item()
            total_samples += labels.size(0)
    
    accuracy = (correct_predictions / total_samples) * 100
    print(f"Accuracy on the evaluation dataset: {accuracy:.4f} %")
    
    return accuracy

In [23]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import re

def strip_label_suffix(label):
    return re.sub(r'\d+$', '', label)

def evaluate_cue_conflict_dataset(model, dataset_path, batch_size=32):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    eval_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    print(f"Evaluating on dataset with {len(eval_dataset)} samples and classes: {eval_dataset.classes}")

    shape_correct = 0
    texture_correct = 0
    total_samples = 0

    cue_correct = 0
    
    with torch.no_grad():
        for images, labels in tqdm(eval_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
    
            for i, pred in enumerate(preds):
                filename = eval_dataset.samples[total_samples + i][0]
                shape_texture = os.path.basename(filename).split("-")
                shape_label = strip_label_suffix(shape_texture[0]) 
                texture_label = strip_label_suffix(shape_texture[1][:-4])  
                pred_class = eval_dataset.classes[pred]
    
                if pred_class == shape_label:
                    shape_correct += 1
                    cue_correct += 1  # Count once for shape match
                elif pred_class == texture_label:  # Use elif to avoid double-counting
                    texture_correct += 1
                    cue_correct += 1
            total_samples += len(preds)
    
    shape_accuracy = shape_correct / total_samples if total_samples > 0 else 0
    texture_accuracy = texture_correct / total_samples if total_samples > 0 else 0
    cue_accuracy = cue_correct / total_samples if total_samples > 0 else 0
    shape_bias = shape_accuracy / cue_accuracy if cue_accuracy > 0 else 0
    texture_bias = 1 - shape_bias


    print("\nFinal Results:")
    print(f"  Shape Accuracy: {shape_accuracy:.4f}")
    print(f"  Texture Accuracy: {texture_accuracy:.4f}")
    print(f"  Cue Accuracy: {cue_accuracy:.4f}")
    print(f"  Shape Bias: {shape_bias:.4f}")
    print(f"  Texture Bias: {texture_bias:.4f}")

    # return {
    #     "Shape Accuracy": shape_accuracy,
    #     "Texture Accuracy": texture_accuracy,
    #     "Cue Accuracy": cue_accuracy,
    #     "Shape Bias": shape_bias,
    #     "Texture Bias": texture_bias
    # }

In [None]:
model_name = "ViT"  
model_path = "/kaggle/working/FinetunedModels/vitmodel_fine_tuned_16_classes.pth" 
# eval_dataset_path1 = "/kaggle/input/labelledimagenet" 
eval_dataset_path2 = "/kaggle/input/cue-conflict-splitdata/test" 
model = load_fine_tuned_model(model_name=model_name, model_path=model_path, num_classes=16)
# evaluate_model(model, dataset_path=eval_dataset_path1, batch_size=32)
evaluate_cue_conflict_dataset(model, eval_dataset_path2, batch_size=32)