# Loading Libraries, Models, Datasets

In [1]:
!pip install timm



In [2]:
import os
import math
import torch
import random
import shutil
import zipfile
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from google.colab import drive,files


import torch.nn as nn
import torch.optim as optim
from timm import create_model
import torch.nn.functional as F
from torchvision.models import resnet50
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, ImageFolder
from torch.utils.data import TensorDataset, DataLoader

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def create_vit_model(num_classes=10):
    """Create a ViT model with a dynamic number of output classes."""
    model = create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)

    #Freeze all layers in the VIT model
    for param in model.parameters():
        param.requires_grad = False

    model.head = torch.nn.Linear(model.head.in_features, num_classes)  # Replace the classifier head

    # Get features model (backbone)
    # print(nn.Sequential(*list(model.children())))
    # print("--------------------------------(--------------------------------------)------------")
    vit_features_model = nn.Sequential(*list(model.children())[:-1])  # Exclude the classifier head
    # print(vit_features_model)
    # print("-----------------------------------(--------------------------------------)---------")

    return model, vit_features_model

def create_resnet_model(num_classes=10):
    """Create a RESNET model with a dynamic number of output classes."""
    model = resnet50(pretrained=True)

    # Freeze all layers in the RESNET model
    for param in model.parameters():
        param.requires_grad = False

    num_features = model.fc.in_features       # Get the number of input features of the last layer
    model.fc = nn.Linear(num_features, num_classes)      # Replace the last fully connected layer with a new one with `num_classes` outputs

    # Get features model (backbone)
    # print(nn.Sequential(*list(model.children())))
    # print("-----------------------(--------------------------------------)---------------")

    features_model = nn.Sequential(*list(model.children())[:-2])  # Extract all layers except the last fully connected layer and the global adaptive average pool 2d layer
    # print(features_model)
    # print("----------------------------(--------------------------------------)----------")
    return model, features_model


# def get_model(name):
#   if name == 'student':
#     resnet_model, resnet_features_model = create_resnet_model()
#     return resnet_model, resnet_features_model
#   elif name == 'teacher':
#     vit_model, vit_features_model = create_vit_model()
#     return vit_model, vit_features_model
#   else:
#     raise ValueError(f"Unsupported model name: {name}")

def get_model(name):
  if name == 'student':
    # Student is now ViT
    vit_model, vit_features_model = create_vit_model()
    return vit_model, vit_features_model
  
  elif name == 'teacher':
    # Teacher is now ResNet
    resnet_model, resnet_features_model = create_resnet_model()
    return resnet_model, resnet_features_model
  
  else:
    raise ValueError(f"Unsupported model name: {name}")



def finetune_model(model, train_loader, num_epochs=10, alpha=1e-3):
    """Fine-tune the model on the training data."""
    print("Finetuning the model on this data")
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=alpha)
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        running_loss = 0.0
        # running_accuracy = 0.0
        correct = 0
        total = 0
        for batch_index, (inputs, labels) in enumerate(train_loader):

            batch_size = inputs.size(0)
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            if batch_index % 300 == 0 or batch_index == 750:
                print(f'Batch {batch_index + 1}/{num_batches} - Loss: {loss.item():.4f} Accuracy: {(correct/total) * 100:.2f}%')

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy =  correct / total
        print(f'------------------------->Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f} Accuracy : {epoch_accuracy * 100:.2f}%')

    return epoch_loss, epoch_accuracy

def evaluate_model(model, data_loader,device):
    """Evaluate the models using a basic testing loop"""
    print("Evaluating the model")
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    num_batches = len(data_loader)  # Get total number of batches

    with torch.no_grad():
        for batch_index, (images, labels) in enumerate(data_loader):

            batch_size = images.size(0)
            if batch_size < 32:
              print(f"skipping last batch of size {batch_size} --- gives shape error")
              break

            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Print progress every 10 batches (you can adjust this as needed)
            if batch_index % 10 == 0:
                print(f'Batch {batch_index + 1}/{num_batches} - Accuracy: { (correct/total) * 100:.2f}%')

    accuracy = correct / total
    print(f'Final Accuracy: {accuracy * 100:.2f}%')
    return accuracy

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # works for all models
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.2435, 0.2616))  # CIFAR-10 mean and std
])

mean, std = (0.4914, 0.4822, 0.4465), (0.247, 0.2435, 0.2616)

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# print(f'Total number of images in CIFAR 10 test_loader: {len(test_dataset)}')

cifar10_classes = [ 'airplane', 'automobile',  'ship', 'truck' ,'bird', 'cat', 'deer', 'dog', 'frog', 'horse']
num_classes = len(cifar10_classes)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 103265683.04it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Knowledge Distillation Utils

In [5]:
def get_features_and_logits(model,feature_model,inputs,student=True, crd=False):
    #will return the same logits as long as model.eval()  is set (i.e. no learning is happening hence same LOGITs)
    if not student:
      model.eval()
      feature_model.eval()
      # vit_transform = transforms.Compose([
      #     transforms.Resize((224, 224)),
          # transforms.ToTensor()
      # ])
      # inputs = vit_transform(inputs)
      features = feature_model(inputs)
      # if crd:
      return features,None    #teacher logits NOT required in CRD / Hint
      # logits = model.classifier(features.view(features.size(0), -1))  #to save on time (as teacher doesnt need to LEARN)
    else:
      model.train()
      features = feature_model(inputs)
      logits = model(inputs) #actual student model does need to LEARN hence repeating it twice

    # print("feature shape",features.shape,"logits shape",logits.shape)
    return features, logits

def get_channel_num(model,name):


    if name == 'resnet':
      return 2048
      channel_nums = []
      for layer in model.modules():        # find all Conv2d layers in the model
          if isinstance(layer, nn.Conv2d):
              channel_nums.append(layer.out_channels)

      # Check if there are enough Conv2d layers
      if len(channel_nums) < 2:
          raise ValueError("Model has fewer than two Conv2d layers.")

      # print(channel_nums,name)
      return channel_nums[-2]      # Return the output channels of the second-to-last Conv2d layer

    elif name == 'vit':
        return 150528
        # For ViT, return the embedding dimension (projection size)
        for layer in model.modules():
            # print(layer)
            if isinstance(layer, nn.Linear):
                # Typically, the first Linear layer defines the embedding size
                return layer.out_features

        raise ValueError("No Linear layers found in the ViT model.")

    else:
        raise ValueError(f"Unsupported model type: {name}")




In [None]:
student_model, features_student = get_model('student')
teacher_model, features_teacher = get_model('teacher')
x,y = get_channel_num(student_model,"vit") , get_channel_num(teacher_model,"resnet")
x,"resnet", y, "vit"

In [6]:
def train_student(student, teacher, train_loader, kd, num_epochs=10, alpha=1e-3,  features_student=None, features_teacher=None):
    """Train the model on the training data."""
    student.train()
    teacher.eval()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(student.parameters(), lr=alpha)
    num_batches = len(train_loader)

    if kd is None:
      print("Training the student on this data independently")
    else:
      print("Distilling the knowledge and training the student on this data",kd.__class__.__name__)

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_accuracy = 0.0
        correct = 0
        total = 0
        for batch_index, (inputs, labels) in enumerate(train_loader):

            batch_size = inputs.size(0)
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            if kd is None:      #independent student
              outputs = student(inputs)
              distillation_loss = 0.0
            else:
              if kd.__class__.__name__ in [ 'LogitMatching' , 'LabelSmoothingRegularization', 'DecoupledLogitMatching'] :

                  #get teacher's logits and student's logits
                  with torch.no_grad():
                    teacher_outputs = teacher(inputs)
                  outputs = student(inputs)

                  if kd.__class__.__name__ == 'LogitMatching':
                    distillation_loss = kd(outputs, teacher_outputs)
                  elif kd.__class__.__name__ == 'LabelSmoothingRegularization':
                   distillation_loss = kd(outputs,labels)
                  else:     #DecoupledLogitMatching
                    distillation_loss = kd(outputs, teacher_outputs, labels)

              elif kd.__class__.__name__ == 'FeatureMatching':
                  #get features of student , teacher and logits of student
                  feat_s, outputs = get_features_and_logits(student,features_student,inputs)
                  feat_t, _ = get_features_and_logits(teacher,features_teacher,inputs,student=False)
                  distillation_loss = kd(feat_s, feat_t)

            classification_loss = criterion(outputs, labels)
            total_loss = classification_loss + distillation_loss
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            if batch_index % 200 == 0 or batch_index == 750:
                print(f'Batch {batch_index + 1}/{num_batches} - Loss: {total_loss.item():.4f} Accuracy: {(correct/total) * 100:.2f}%')

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct/total
        print(f'------------------------->Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f} Accuracy : {epoch_accuracy * 100:.2f}%')

    return epoch_loss, epoch_accuracy

In [7]:
def plot_accuracies(accuracies, model_labels, title='Model Accuracies'):
    """
    Plots the accuracies of different models.

    Parameters:
    accuracies (list of float): A list of accuracy values for each model.
    model_labels (list of str): A list of labels corresponding to each model.

    """
    # Ensure the inputs are valid
    if len(accuracies) != len(model_labels):
        raise ValueError("The number of accuracies must match the number of model labels.")

    # Create the plot
    plt.figure(figsize=(8, 6))
    plt.bar(model_labels, accuracies)

    # Adding title and labels
    plt.title(title)
    plt.xlabel('Models')
    plt.ylabel('Accuracy')

    # Optionally, add accuracy values on top of the bars
    for i, acc in enumerate(accuracies):
        plt.text(i, acc, f'{acc:.2f}', ha='center', va='bottom')

    # Show the plot
    plt.xticks(rotation=45)  # Rotate x labels for better visibility
    plt.tight_layout()
    plt.show()


In [11]:
class CIFAR10IdxSample(CIFAR10):          #loaded for Knowledge Distillation with CRD
	def __init__(self, root, train=True, transform=None, target_transform=None, download=False, n=4096, mode='exact', percent=1.0):
		super().__init__(root=root, train=train, download=download,transform=transform, target_transform=target_transform)
		self.n = n
		self.mode = mode

		num_classes = 10
		num_samples = len(self.data)
		labels = self.targets

		self.cls_positive = [[] for _ in range(num_classes)]
		for i in range(num_samples):
			self.cls_positive[labels[i]].append(i)

		self.cls_negative = [[] for _ in range(num_classes)]
		for i in range(num_classes):
			for j in range(num_classes):
				if j == i:
					continue
				self.cls_negative[i].extend(self.cls_positive[j])

		self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
		self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]

		if 0 < percent < 1:
			num = int(len(self.cls_negative[0]) * percent)
			self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:num]
								 for i in range(num_classes)]

		self.cls_positive = np.asarray(self.cls_positive)
		self.cls_negative = np.asarray(self.cls_negative)

	def __getitem__(self, index):
		img, target = self.data[index], self.targets[index]

		img = Image.fromarray(img)
		if self.transform is not None:
			img = self.transform(img)

		if self.target_transform is not None:
			target = self.target_transform(target)

		if self.mode == 'exact':
			pos_idx = index
		elif self.mode == 'relax':
			pos_idx = np.random.choice(self.cls_positive[target], 1)[0]
		else:
			raise NotImplementedError(self.mode)
		replace = True if self.n > len(self.cls_negative[target]) else False
		neg_idx = np.random.choice(self.cls_negative[target], self.n, replace=replace)
		sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))

		return img, target, index, sample_idx

# Logit Matching and Contrastive

In [12]:
class LogitMatching(nn.Module):  #smoothes teacher logits as well (going from Peaky dist to Uniform dist by increasing Temperature T)
	def __init__(self, T=3.0):
		super(LogitMatching, self).__init__()
		self.T = T      #Temperature for Soft Targets ... lower the T, less the smoothing, less the error

	def forward(self, out_s, out_t):
    # loss = F.mse_loss(out_s, out_t)
		loss = self.T * self.T * F.kl_div(
                                        F.log_softmax(out_s/self.T, dim=1),
                                        F.softmax(out_t/self.T, dim=1),
                                        reduction='batchmean'
          )  #cross entropy loss of softened probabilities

		return loss

class CRD(nn.Module):
	'''
	Contrastive Representation Distillation
	https://openreview.net/pdf?id=SkgpBJrtvS

	includes two symmetric parts:
	(a) using teacher as anchor, choose positive and negatives over the student side
	(b) using student as anchor, choose positive and negatives over the teacher side

	Args:
		s_dim: the dimension of student's feature
		t_dim: the dimension of teacher's feature
		feat_dim: the dimension of the projection space
		nce_n: number of negatives paired with each positive
		nce_t: the temperature
		nce_mom: the momentum for updating the memory buffer
		n_data: the number of samples in the training set, which is the M in Eq.(19)
	'''
	def __init__(self, s_dim, t_dim, n_data, feat_dim=128, nce_n=4096, nce_t=0.1, nce_mom=0.5):
		super(CRD, self).__init__()
		self.embed_s = Embed(s_dim, feat_dim)
		self.embed_t = Embed(t_dim, feat_dim)
		self.contrast = ContrastMemory(feat_dim, n_data, nce_n, nce_t, nce_mom)
		self.criterion_s = ContrastLoss(n_data)
		self.criterion_t = ContrastLoss(n_data)

	def forward(self, feat_s, feat_t, idx, sample_idx):
		feat_s = self.embed_s(feat_s)
		feat_t = self.embed_t(feat_t)
		out_s, out_t = self.contrast(feat_s, feat_t, idx, sample_idx)
		loss_s = self.criterion_s(out_s)
		loss_t = self.criterion_t(out_t)
		loss = loss_s + loss_t

		return loss


class Embed(nn.Module):
	def __init__(self, in_dim, out_dim):
		super(Embed, self).__init__()
		self.linear = nn.Linear(in_dim, out_dim)

	def forward(self, x):
		x = x.view(x.size(0), -1).to(self.linear.weight.device)
	 	# x = x.view(x.size(0), -1).to(self.linear.weight.device)
		x = self.linear(x)
		x = F.normalize(x, p=2, dim=1)

		return x


class ContrastLoss(nn.Module):
	'''
	contrastive loss, corresponding to Eq.(18)
	'''
	def __init__(self, n_data, eps=1e-7):
		super(ContrastLoss, self).__init__()
		self.n_data = n_data
		self.eps = eps

	def forward(self, x):
		bs = x.size(0)
		N  = x.size(1) - 1
		M  = float(self.n_data)

		# loss for positive pair
		pos_pair = x.select(1, 0)
		log_pos  = torch.div(pos_pair, pos_pair.add(N / M + self.eps)).log_()

		# loss for negative pair
		neg_pair = x.narrow(1, 1, N)
		log_neg  = torch.div(neg_pair.clone().fill_(N / M), neg_pair.add(N / M + self.eps)).log_()

		loss = -(log_pos.sum() + log_neg.sum()) / bs

		return loss


class ContrastMemory(nn.Module):
    def __init__(self, feat_dim, n_data, nce_n, nce_t, nce_mom):
        super(ContrastMemory, self).__init__()
        self.N = nce_n
        self.T = nce_t
        self.momentum = nce_mom
        self.Z_t = None
        self.Z_s = None

        stdv = 1. / math.sqrt(feat_dim / 3.)
        self.register_buffer('memory_t', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))
        self.register_buffer('memory_s', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))

    def forward(self, feat_s, feat_t, idx, sample_idx):
        bs = feat_s.size(0)
        feat_dim = self.memory_s.size(1)
        n_data = self.memory_s.size(0)
        sample_idx = sample_idx.to(self.memory_s.device)

        # using teacher as anchor
        weight_s = torch.index_select(self.memory_s, 0, sample_idx.view(-1)).detach()
        weight_s = weight_s.view(bs, self.N + 1, feat_dim)
        out_t = torch.bmm(weight_s, feat_t.view(bs, feat_dim, 1)).squeeze().contiguous()
        out_t = torch.exp(out_t / self.T)

        # using student as anchor
        weight_t = torch.index_select(self.memory_t, 0, sample_idx.view(-1)).detach()
        weight_t = weight_t.view(bs, self.N + 1, feat_dim)
        out_s = torch.bmm(weight_t, feat_s.view(bs, feat_dim, 1)).squeeze().contiguous()
        out_s = torch.exp(out_s / self.T)

        # set Z if haven't been set yet
        if self.Z_t is None:
            self.Z_t = (out_t.mean() * n_data).detach().item()
        if self.Z_s is None:
            self.Z_s = (out_s.mean() * n_data).detach().item()

        out_t = torch.div(out_t, self.Z_t)
        out_s = torch.div(out_s, self.Z_s)

        # update memory
        with torch.no_grad():
            idx = idx.to(self.memory_t.device)
            pos_mem_t = torch.index_select(self.memory_t, 0, idx.view(-1))
            pos_mem_t.mul_(self.momentum)
            pos_mem_t.add_(torch.mul(feat_t, 1 - self.momentum))
            pos_mem_t = F.normalize(pos_mem_t, p=2, dim=1)
            self.memory_t.index_copy_(0, idx, pos_mem_t)

            pos_mem_s = torch.index_select(self.memory_s, 0, idx.view(-1))
            pos_mem_s.mul_(self.momentum)
            pos_mem_s.add_(torch.mul(feat_s, 1 - self.momentum))
            pos_mem_s = F.normalize(pos_mem_s, p=2, dim=1)
            self.memory_s.index_copy_(0, idx, pos_mem_s)

        return out_s, out_t


In [13]:
train_crd_transform = transforms.Compose([
    transforms.Resize((224, 224)),   # <-- add this
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

train_crd_dataset = CIFAR10IdxSample

train_crd_loader = torch.utils.data.DataLoader(
			train_crd_dataset(
              root  = './datasets',
						  transform = train_crd_transform,
						  train     = True,
						  download  = True,
						  n         = 4096,
						  mode      = 'exact'
      ),
			batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 14751851.82it/s]


Extracting ./datasets/cifar-10-python.tar.gz to ./datasets


In [14]:
def train_student_crd(student, features_student, teacher, features_teacher, train_loader, kd, num_epochs=10, alpha=1e-3):
    """Train the model on the training data for CRD """
    print("Distilling the knowledge and training the student on this data for CRD")
    student.train()
    teacher.eval()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(student.parameters(), lr=alpha)
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
      running_loss = 0.0
      correct = 0
      total = 0
      for batch_index, (img, target, idx, sample_idx ) in enumerate(train_loader, start=1):
        # print(img.shape, target.shape, idx.shape, sample_idx.shape)
        img = img.to(device)
        target = target.to(device)
        idx = idx.to(device)
        sample_idx = sample_idx.to(device)

        #get features and logits
        # print(img.shape)
        feat_s, logit_s = get_features_and_logits(student,features_student,img)
        feat_t, _ = get_features_and_logits(teacher,features_teacher,img,student=False, crd=True)


        feat_s = feat_s.to(device)
        feat_t = feat_t.to(device)
        logit_s = logit_s.to(device)


        #get total loss
        kd_lambda = 1   # 0.2   #for CRD
        # print(feat_t.shape, feat_s.shape)
        kd_loss  = kd(feat_s, feat_t, idx, sample_idx) * kd_lambda
        cls_loss = criterion(logit_s, target)
        total_loss = cls_loss + kd_loss
        # print("total loss",total_loss, "crd + cls", kd_loss , cls_loss)


        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

        # # Calculate accuracy
        # _, predicted = torch.max(logit_s, 1)
        # correct += (predicted == labels).sum().item()
        # total += labels.size(0)

        if batch_index % 200 == 0 or batch_index in [1,750]:
            print(f'Batch {batch_index }/{num_batches} - Loss: {total_loss.item():.4f} ')
            # Accuracy: {(correct/total) * 100:.2f}%')

      epoch_loss = running_loss / len(train_loader)
      # epoch_accuracy = correct / total
      print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    # return epoch_loss, epoch_accuracy
    return epoch_loss


# criterion_CRD  = CRD(s_dim=get_channel_num(student_model,'vit'), t_dim=get_channel_num(teacher_model,'resnet'), n_data=len(train_loader.dataset))
# train_student_crd(student_model, features_student, teacher_model, features_teacher, train_crd_loader, criterion_CRD, num_epochs=5)

# Texture Bias Dataset

In [5]:
# Define the dataset path
base_dir = '/kaggle/input/texture-bias-dataset/'  # Update this path based on your Kaggle dataset location

# Function to get all partition paths
def get_partition_paths(base_dir):
    # Navigate to the deeper subdirectory containing images and labels.txt
    return [os.path.join(base_dir, f"partition{i}", f"partition{i}") for i in range(1, 11)]

# Function to load images and labels from a single partition
def load_partition_data(partition_path):
    images = []
    labels = []
    
    # Path to the labels.txt file
    label_file = os.path.join(partition_path, 'labels.txt')
    
    # Read the labels.txt file
    with open(label_file, 'r') as f:
        for line in f:
            img_name, label = line.strip().split()  # Split filename and label
            img_path = os.path.join(partition_path, img_name)  # Full path to image
            label = int(label)  # Convert label to integer
            
            # Load the image
            img = Image.open(img_path).convert('RGB')  # Ensure RGB format
            
            images.append((img_path, img))  # Store tuple of (image path, image object)
            labels.append(label)
    
    return images, labels

# Load data from all partitions
all_images = []  # To store tuples of (image path, image data)
all_labels = []  # To store corresponding labels

for partition_path in get_partition_paths(base_dir):
    images, labels = load_partition_data(partition_path)
    all_images.extend(images)
    all_labels.extend(labels)

# Print some debug info
print(f"Loaded {len(all_images)} images and {len(all_labels)} labels.")
print(f"Example Image Path: {all_images[0][0]}, Label: {all_labels[0]}")

# # Preprocessing for PyTorch
# transform = transforms.Compose([
#     transforms.Resize((32, 32)),  # CIFAR-10 images are 32x32
#     transforms.ToTensor(),       # Convert image to tensor
# ])

# Apply transformations to images
processed_images = [transform(img[1]) for img in all_images]  # img[1] contains the actual image data
all_labels = torch.tensor(all_labels)  # Convert labels to tensor

# Create a PyTorch DataLoader
texture_bias_dataset = TensorDataset(torch.stack(processed_images), all_labels)
texture_bias_loader = DataLoader(texture_bias_dataset, batch_size=32, shuffle=False)

Loaded 10000 images and 10000 labels.
Example Image Path: /kaggle/input/texture-bias-dataset/partition1/partition1/img_0000.png, Label: 3


# Finetuning Teacher(Resnet50) on Cifar10 dataset


In [15]:
# 1. Get the teacher model (ResNet) and its features
teacher_model, features_teacher = get_model('teacher')

# Move the teacher model to the correct device
teacher_model = teacher_model.to(device)


# 2. Finetune the teacher on CIFAR-10
teacher_loss, teacher_accuracy = finetune_model(
    teacher_model, 
    train_loader,     # CIFAR-10 train_loader
    num_epochs=10, 
    alpha=1e-3
)

print(f"Teacher finetuned on CIFAR-10 with final loss={teacher_loss:.4f}, accuracy={teacher_accuracy*100:.2f}%")

torch.save(teacher_model.state_dict(), 'teacher_model_weights_after_finetuning.pth')

print("Teacher model weights saved successfully!")






Finetuning the model on this data
Batch 1/1563 - Loss: 2.3532 Accuracy: 12.50%
Batch 301/1563 - Loss: 1.1407 Accuracy: 65.32%
Batch 601/1563 - Loss: 0.7251 Accuracy: 70.28%
Batch 751/1563 - Loss: 0.6909 Accuracy: 71.39%
Batch 901/1563 - Loss: 1.1071 Accuracy: 71.98%
Batch 1201/1563 - Loss: 0.8286 Accuracy: 73.11%
Batch 1501/1563 - Loss: 0.7037 Accuracy: 73.95%
------------------------->Epoch [1/10], Loss: 0.7639 Accuracy : 74.09%
Batch 1/1563 - Loss: 0.6780 Accuracy: 75.00%
Batch 301/1563 - Loss: 0.4301 Accuracy: 78.04%
Batch 601/1563 - Loss: 1.0053 Accuracy: 78.02%
Batch 751/1563 - Loss: 0.5594 Accuracy: 78.00%
Batch 901/1563 - Loss: 0.8027 Accuracy: 78.04%
Batch 1201/1563 - Loss: 0.3043 Accuracy: 78.23%
Batch 1501/1563 - Loss: 0.6332 Accuracy: 78.32%
------------------------->Epoch [2/10], Loss: 0.6251 Accuracy : 78.38%
Batch 1/1563 - Loss: 0.6996 Accuracy: 68.75%
Batch 301/1563 - Loss: 0.4643 Accuracy: 79.51%
Batch 601/1563 - Loss: 0.7940 Accuracy: 79.00%
Batch 751/1563 - Loss: 0.47

# Evaluating Teacher on Texture Bias Dataset

In [6]:
# ---------------code for loading wieghts of the finetuned teacher model--------------------------
# # Load the state_dict into the model
# 1. Get the teacher model (ResNet) and its features
teacher_model, features_teacher = get_model('teacher')

# Move the teacher model to the correct device
teacher_model = teacher_model.to(device)

teacher_model.load_state_dict(torch.load('/kaggle/input/after-finetuning-teacher/teacher_model_weights_after_finetuning.pth'))

# Set the model to evaluation mode if you're using it for inference
teacher_model.eval()

print("Teacher model weights loaded successfully!")

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 204MB/s]
  teacher_model.load_state_dict(torch.load('/kaggle/input/after-finetuning-teacher/teacher_model_weights_after_finetuning.pth'))


Teacher model weights loaded successfully!


In [16]:
# ---------------code for loading wieghts of the finetuned teacher model--------------------------
# # Load the state_dict into the model
# teacher_model.load_state_dict(torch.load('teacher_model_weights_after_finetuning.pth'))

# # Set the model to evaluation mode if you're using it for inference
# teacher_model.eval()

# print("Teacher model weights loaded successfully!")


texture_teacher_acc = evaluate_model(teacher_model, texture_bias_loader, device)
print(f"Teacher accuracy on Texture-Bias dataset = {texture_teacher_acc*100:.2f}%")


Evaluating the model
Batch 1/313 - Accuracy: 12.50%
Batch 11/313 - Accuracy: 17.90%
Batch 21/313 - Accuracy: 15.92%
Batch 31/313 - Accuracy: 16.23%
Batch 41/313 - Accuracy: 15.47%
Batch 51/313 - Accuracy: 15.07%
Batch 61/313 - Accuracy: 14.75%
Batch 71/313 - Accuracy: 14.92%
Batch 81/313 - Accuracy: 15.08%
Batch 91/313 - Accuracy: 15.08%
Batch 101/313 - Accuracy: 14.82%
Batch 111/313 - Accuracy: 14.86%
Batch 121/313 - Accuracy: 14.90%
Batch 131/313 - Accuracy: 15.12%
Batch 141/313 - Accuracy: 15.18%
Batch 151/313 - Accuracy: 15.00%
Batch 161/313 - Accuracy: 15.06%
Batch 171/313 - Accuracy: 15.08%
Batch 181/313 - Accuracy: 15.18%
Batch 191/313 - Accuracy: 15.09%
Batch 201/313 - Accuracy: 15.02%
Batch 211/313 - Accuracy: 15.20%
Batch 221/313 - Accuracy: 15.19%
Batch 231/313 - Accuracy: 15.12%
Batch 241/313 - Accuracy: 15.22%
Batch 251/313 - Accuracy: 15.35%
Batch 261/313 - Accuracy: 15.16%
Batch 271/313 - Accuracy: 15.16%
Batch 281/313 - Accuracy: 15.11%
Batch 291/313 - Accuracy: 15.05%


# Knowledge Distillation (Logit matching)

In [17]:
# Suppose you define kd = LogitMatching(T=3.0)
student_lm, features_lm = get_model('student')  # ViT

student_lm = student_lm.to(device)

logit_matching_loss = LogitMatching(T=3.0)       # or whatever T you want

# distill knowledge
train_student(
    student=student_lm,
    teacher=teacher_model,
    train_loader=train_loader,  # still CIFAR-10
    kd=logit_matching_loss,
    num_epochs=10,
    alpha=1e-3,
    features_student=features_lm,
    features_teacher=features_teacher
)

# Save the teacher model weights
torch.save(teacher_model.state_dict(), 'teacher_model_weights.pth')

# Save the student model weights
torch.save(student_lm.state_dict(), 'student_model_weights.pth')

print("Teacher and student model weights saved successfully!")

# Optionally save the features (if they need to be reused)
torch.save(features_teacher, 'features_teacher.pth')
torch.save(features_lm, 'features_student.pth')

print("Features saved successfully!")




Distilling the knowledge and training the student on this data LogitMatching
Batch 1/1563 - Loss: 12.1133 Accuracy: 6.25%
Batch 201/1563 - Loss: 1.3886 Accuracy: 88.60%
Batch 401/1563 - Loss: 1.9707 Accuracy: 90.65%
Batch 601/1563 - Loss: 1.5134 Accuracy: 91.44%
Batch 751/1563 - Loss: 1.6779 Accuracy: 91.75%
Batch 801/1563 - Loss: 1.4647 Accuracy: 91.84%
Batch 1001/1563 - Loss: 1.5533 Accuracy: 92.12%
Batch 1201/1563 - Loss: 1.3349 Accuracy: 92.26%
Batch 1401/1563 - Loss: 1.7039 Accuracy: 92.41%
------------------------->Epoch [1/10], Loss: 1.7205 Accuracy : 92.50%
Batch 1/1563 - Loss: 1.6260 Accuracy: 93.75%
Batch 201/1563 - Loss: 1.3967 Accuracy: 93.72%
Batch 401/1563 - Loss: 1.4239 Accuracy: 93.76%
Batch 601/1563 - Loss: 1.2717 Accuracy: 93.50%
Batch 751/1563 - Loss: 1.2137 Accuracy: 93.49%
Batch 801/1563 - Loss: 1.5362 Accuracy: 93.47%
Batch 1001/1563 - Loss: 1.5717 Accuracy: 93.46%
Batch 1201/1563 - Loss: 1.1261 Accuracy: 93.40%
Batch 1401/1563 - Loss: 2.0002 Accuracy: 93.41%
----

# Evaluating Logit matching student on Texture Bias dataset

In [16]:
# Reload the teacher model
teacher_model, features_teacher = get_model('teacher')  # Initialize the teacher architecture
teacher_model.load_state_dict(torch.load('/kaggle/input/logit-matching/teacher_model_weights.pth'))
teacher_model.eval()  # Set to evaluation mode

# # Reload the student model
student_lm, features_lm = get_model('student')  # Initialize the student architecture
student_lm.load_state_dict(torch.load('/kaggle/input/logit-matching/student_model_weights.pth'))
student_lm.eval()  # Set to evaluation mode

# # Optionally load the features
features_teacher = torch.load('/kaggle/input/logit-matching/features_teacher.pth')
features_lm = torch.load('/kaggle/input/logit-matching/features_student.pth')

# print("Teacher and student models with features loaded successfully!")

  teacher_model.load_state_dict(torch.load('/kaggle/input/logit-matching/teacher_model_weights.pth'))


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

  student_lm.load_state_dict(torch.load('/kaggle/input/logit-matching/student_model_weights.pth'))
  features_teacher = torch.load('/kaggle/input/logit-matching/features_teacher.pth')
  features_lm = torch.load('/kaggle/input/logit-matching/features_student.pth')


In [17]:
test_student_lm_acc = evaluate_model(student_lm, test_loader, device)
print(f"Student (LogitMatching) accuracy on Cifar 10 test set = {test_student_lm_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 100.00%
Batch 11/313 - Accuracy: 93.47%
Batch 21/313 - Accuracy: 93.45%
Batch 31/313 - Accuracy: 93.55%
Batch 41/313 - Accuracy: 93.60%
Batch 51/313 - Accuracy: 93.26%
Batch 61/313 - Accuracy: 93.29%
Batch 71/313 - Accuracy: 93.22%
Batch 81/313 - Accuracy: 93.13%
Batch 91/313 - Accuracy: 93.17%
Batch 101/313 - Accuracy: 93.01%
Batch 111/313 - Accuracy: 92.91%
Batch 121/313 - Accuracy: 92.95%
Batch 131/313 - Accuracy: 92.82%
Batch 141/313 - Accuracy: 93.00%
Batch 151/313 - Accuracy: 92.82%
Batch 161/313 - Accuracy: 92.76%
Batch 171/313 - Accuracy: 92.82%
Batch 181/313 - Accuracy: 92.83%
Batch 191/313 - Accuracy: 92.82%
Batch 201/313 - Accuracy: 92.79%
Batch 211/313 - Accuracy: 92.65%
Batch 221/313 - Accuracy: 92.56%
Batch 231/313 - Accuracy: 92.71%
Batch 241/313 - Accuracy: 92.58%
Batch 251/313 - Accuracy: 92.62%
Batch 261/313 - Accuracy: 92.60%
Batch 271/313 - Accuracy: 92.60%
Batch 281/313 - Accuracy: 92.64%
Batch 291/313 - Accuracy: 92.62%

In [18]:

# Evaluate on texture-bias
texture_student_lm_acc = evaluate_model(student_lm, texture_bias_loader, device)
print(f"Student (LogitMatching) accuracy on Texture-Bias = {texture_student_lm_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 37.50%
Batch 11/313 - Accuracy: 33.24%
Batch 21/313 - Accuracy: 32.89%
Batch 31/313 - Accuracy: 33.17%
Batch 41/313 - Accuracy: 33.84%
Batch 51/313 - Accuracy: 33.82%
Batch 61/313 - Accuracy: 33.86%
Batch 71/313 - Accuracy: 33.54%
Batch 81/313 - Accuracy: 33.49%
Batch 91/313 - Accuracy: 33.76%
Batch 101/313 - Accuracy: 33.69%
Batch 111/313 - Accuracy: 33.87%
Batch 121/313 - Accuracy: 34.01%
Batch 131/313 - Accuracy: 34.16%
Batch 141/313 - Accuracy: 34.40%
Batch 151/313 - Accuracy: 34.60%
Batch 161/313 - Accuracy: 34.78%
Batch 171/313 - Accuracy: 34.83%
Batch 181/313 - Accuracy: 34.88%
Batch 191/313 - Accuracy: 34.80%
Batch 201/313 - Accuracy: 34.70%
Batch 211/313 - Accuracy: 34.58%
Batch 221/313 - Accuracy: 34.39%
Batch 231/313 - Accuracy: 34.40%
Batch 241/313 - Accuracy: 34.30%
Batch 251/313 - Accuracy: 34.31%
Batch 261/313 - Accuracy: 34.23%
Batch 271/313 - Accuracy: 34.12%
Batch 281/313 - Accuracy: 34.26%
Batch 291/313 - Accuracy: 34.05%


# Shape Bias Dataset

In [7]:
# Path to the dataset and labels file
from torch.utils.data import Dataset, DataLoader

data_dir = '/kaggle/input/shape-bias-dataset/shapes/'
labels_file = os.path.join(data_dir, 'labels.txt')

# Custom dataset for Shape-Bias Dataset
class ShapeBiasDataset(Dataset):
    def __init__(self, data_dir, labels_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        
        # Read labels file
        self.data = []
        with open(labels_file, 'r') as f:
            for line in f:
                img_name, label = line.strip().split()
                # Replace .pt with .png dynamically and apply a shift of +1
                img_name = img_name.replace('.pt', '.png')
                img_number = int(img_name.split('_')[1].split('.')[0]) + 1  # Shift the index by +1
                img_name = f"img_{img_number}.png"
                self.data.append((img_name, int(label)))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name, label = self.data[idx]
        img_path = os.path.join(self.data_dir, img_name)
        
        # Load image
        image = Image.open(img_path).convert('RGB')  # Ensure RGB format
        if self.transform:
            image = self.transform(image)
        
        return image, label

# # Define the custom transformations
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize images for compatibility with larger models (e.g., ResNet)
#     transforms.ToTensor(),         # Convert image to tensor
#     transforms.Normalize(          # Normalize using CIFAR-10 statistics
#         mean=(0.4914, 0.4822, 0.4465), 
#         std=(0.247, 0.2435, 0.2616)
#     ),
# ])

# Create the dataset and DataLoader with the custom transformations
shape_bias_dataset = ShapeBiasDataset(data_dir=data_dir, labels_file=labels_file, transform=transform)
shape_bias_loader = DataLoader(shape_bias_dataset, batch_size=32, shuffle=False)


# Evaluating Logit matching student on Shape Bias dataset

In [29]:
# Evaluate on shape-bias
shape_student_lm_acc = evaluate_model(student_lm, shape_bias_loader, device)
print(f"Student (LogitMatching) accuracy on Shape-Bias = {shape_student_lm_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 9.38%
Batch 11/313 - Accuracy: 16.76%
Batch 21/313 - Accuracy: 16.82%
Batch 31/313 - Accuracy: 17.64%
Batch 41/313 - Accuracy: 16.54%
Batch 51/313 - Accuracy: 17.10%
Batch 61/313 - Accuracy: 17.06%
Batch 71/313 - Accuracy: 17.17%
Batch 81/313 - Accuracy: 16.98%
Batch 91/313 - Accuracy: 16.79%
Batch 101/313 - Accuracy: 16.80%
Batch 111/313 - Accuracy: 16.67%
Batch 121/313 - Accuracy: 16.30%
Batch 131/313 - Accuracy: 16.17%
Batch 141/313 - Accuracy: 16.25%
Batch 151/313 - Accuracy: 16.27%
Batch 161/313 - Accuracy: 16.44%
Batch 171/313 - Accuracy: 16.37%
Batch 181/313 - Accuracy: 16.42%
Batch 191/313 - Accuracy: 16.41%
Batch 201/313 - Accuracy: 16.34%
Batch 211/313 - Accuracy: 16.41%
Batch 221/313 - Accuracy: 16.37%
Batch 231/313 - Accuracy: 16.44%
Batch 241/313 - Accuracy: 16.34%
Batch 251/313 - Accuracy: 16.28%
Batch 261/313 - Accuracy: 16.22%
Batch 271/313 - Accuracy: 16.24%
Batch 281/313 - Accuracy: 16.19%
Batch 291/313 - Accuracy: 16.19%
B

# Super Pixelated Dataset

In [8]:
# Path to the dataset and labels file
data_dir = '/kaggle/input/superpixelated-dataset/slic/'
labels_file = os.path.join(data_dir, 'labels.txt')  # Ensure the labels.txt file exists

# Custom dataset for the Superpixelated Dataset
class SuperpixelDataset(Dataset):
    def __init__(self, data_dir, labels_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        
        # Read labels file
        self.data = []
        with open(labels_file, 'r') as f:
            for line in f:
                img_name, label = line.strip().split()
                # Replace .pt with .png dynamically and remove leading zeros from the image name
                img_name = img_name.replace('.pt', '.png')
                img_number = int(img_name.split('_')[1].split('.')[0])  # Extract numeric part
                img_name = f"img_{img_number}.png"  # Reformat without leading zeros
                self.data.append((img_name, int(label)))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name, label = self.data[idx]
        img_path = os.path.join(self.data_dir, img_name)
        
        # Load image
        image = Image.open(img_path).convert('RGB')  # Ensure RGB format
        if self.transform:
            image = self.transform(image)
        
        return image, label

# # Define transformations (use the same transformations for consistency)
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize images for compatibility with the model
#     transforms.ToTensor(),         # Convert image to tensor
#     transforms.Normalize(          # Normalize using CIFAR-10 statistics
#         mean=(0.4914, 0.4822, 0.4465), 
#         std=(0.247, 0.2435, 0.2616)
#     ),
# ])

# Create the dataset and DataLoader
superpixel_dataset = SuperpixelDataset(data_dir=data_dir, labels_file=labels_file, transform=transform)
superpixel_loader = DataLoader(superpixel_dataset, batch_size=32, shuffle=False)


# Evaluating Logit matching student on SuperPixelated dataset

In [35]:
# Evaluate on super pixelated dataset
super_pixelated_student_lm_acc = evaluate_model(student_lm, superpixel_loader, device)
print(f"Student (LogitMatching) accuracy on Super Pixelated = {super_pixelated_student_lm_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 56.25%
Batch 11/313 - Accuracy: 39.77%
Batch 21/313 - Accuracy: 38.84%
Batch 31/313 - Accuracy: 37.40%
Batch 41/313 - Accuracy: 38.34%
Batch 51/313 - Accuracy: 37.87%
Batch 61/313 - Accuracy: 38.11%
Batch 71/313 - Accuracy: 37.50%
Batch 81/313 - Accuracy: 37.73%
Batch 91/313 - Accuracy: 38.05%
Batch 101/313 - Accuracy: 38.15%
Batch 111/313 - Accuracy: 38.12%
Batch 121/313 - Accuracy: 38.02%
Batch 131/313 - Accuracy: 38.38%
Batch 141/313 - Accuracy: 38.59%
Batch 151/313 - Accuracy: 38.53%
Batch 161/313 - Accuracy: 38.76%
Batch 171/313 - Accuracy: 38.65%
Batch 181/313 - Accuracy: 38.48%
Batch 191/313 - Accuracy: 38.24%
Batch 201/313 - Accuracy: 38.37%
Batch 211/313 - Accuracy: 38.39%
Batch 221/313 - Accuracy: 38.46%
Batch 231/313 - Accuracy: 38.47%
Batch 241/313 - Accuracy: 38.43%
Batch 251/313 - Accuracy: 38.43%
Batch 261/313 - Accuracy: 38.41%
Batch 271/313 - Accuracy: 38.21%
Batch 281/313 - Accuracy: 38.31%
Batch 291/313 - Accuracy: 38.28%


# Scrambled Dataset

In [9]:
# Path to the dataset and labels file
data_dir = '/kaggle/input/scrambled-dataset/images/'
labels_file = '/kaggle/input/scrambled-dataset/labels.txt'

# Custom Dataset for Scrambled Dataset
class ScrambledDataset(Dataset):
    def __init__(self, data_dir, labels_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        
        # Read labels file
        self.data = []
        with open(labels_file, 'r') as f:
            for line in f:
                img_name, label = line.strip().split()
                self.data.append((img_name, int(label)))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name, label = self.data[idx]
        img_path = os.path.join(self.data_dir, img_name)
        
        # Load tensor image
        image = torch.load(img_path)  # Load the .pt tensor file
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Define the transformation to resize tensors to 224x224
def tensor_transform(tensor):
    # Ensure tensor is 4D (batch dimension needed for interpolate)
    if tensor.dim() == 3:  # [C, H, W]
        tensor = tensor.unsqueeze(0)  # Add batch dimension -> [1, C, H, W]
    
    # Resize to 224x224
    resized_tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
    
    # Remove batch dimension if added
    return resized_tensor.squeeze(0)  # Back to [C, H, W]

# Create the dataset and DataLoader
scrambled_dataset = ScrambledDataset(data_dir=data_dir, labels_file=labels_file, transform=tensor_transform)
scrambled_loader = DataLoader(scrambled_dataset, batch_size=32, shuffle=False)

# Evaluating Logit matching student on Scrambled dataset

In [37]:
# Evaluate on srambled dataset
scrambled_student_lm_acc = evaluate_model(student_lm, scrambled_loader, device)
print(f"Student (LogitMatching) accuracy on Scrambled = {scrambled_student_lm_acc*100:.2f}%")

Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 28.12%
Batch 11/313 - Accuracy: 23.30%
Batch 21/313 - Accuracy: 22.17%
Batch 31/313 - Accuracy: 23.29%
Batch 41/313 - Accuracy: 23.78%
Batch 51/313 - Accuracy: 23.04%
Batch 61/313 - Accuracy: 23.57%
Batch 71/313 - Accuracy: 23.20%
Batch 81/313 - Accuracy: 23.15%
Batch 91/313 - Accuracy: 22.97%
Batch 101/313 - Accuracy: 23.33%
Batch 111/313 - Accuracy: 23.37%
Batch 121/313 - Accuracy: 22.93%
Batch 131/313 - Accuracy: 22.73%
Batch 141/313 - Accuracy: 22.67%
Batch 151/313 - Accuracy: 22.89%
Batch 161/313 - Accuracy: 22.96%
Batch 171/313 - Accuracy: 22.99%
Batch 181/313 - Accuracy: 22.88%
Batch 191/313 - Accuracy: 22.99%
Batch 201/313 - Accuracy: 22.96%
Batch 211/313 - Accuracy: 22.90%
Batch 221/313 - Accuracy: 22.96%
Batch 231/313 - Accuracy: 22.89%
Batch 241/313 - Accuracy: 22.91%
Batch 251/313 - Accuracy: 23.00%
Batch 261/313 - Accuracy: 23.14%
Batch 271/313 - Accuracy: 23.27%
Batch 281/313 - Accuracy: 23.24%
Batch 291/313 - Accuracy: 23.40%
Batch 301/313 - Accur

# Noisy Dataset

In [10]:
# Path to the dataset and labels file
data_dir = '/kaggle/input/noisy-dataset/images/'
labels_file = '/kaggle/input/noisy-dataset/labels.txt'

# Custom Dataset for Noisy Dataset
class NoisyDataset(Dataset):
    def __init__(self, data_dir, labels_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        
        # Read labels file
        self.data = []
        with open(labels_file, 'r') as f:
            for line in f:
                img_name, label = line.strip().split()
                self.data.append((img_name, int(label)))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name, label = self.data[idx]
        img_path = os.path.join(self.data_dir, img_name)
        
        # Load tensor image
        image = torch.load(img_path)  # Load the .pt tensor file
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Define the transformation to resize tensors to 224x224
def tensor_transform(tensor):
    # Ensure tensor is 4D (batch dimension needed for interpolate)
    if tensor.dim() == 3:  # [C, H, W]
        tensor = tensor.unsqueeze(0)  # Add batch dimension -> [1, C, H, W]
    
    # Resize to 224x224
    resized_tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
    
    # Remove batch dimension if added
    return resized_tensor.squeeze(0)  # Back to [C, H, W]

# Create the dataset and DataLoader
noisy_dataset = NoisyDataset(data_dir=data_dir, labels_file=labels_file, transform=tensor_transform)
noisy_loader = DataLoader(noisy_dataset, batch_size=32, shuffle=False)


# Evaluating Logit matching student on Noisy dataset

In [39]:
# Evaluate on srambled dataset
noisy_student_lm_acc = evaluate_model(student_lm, noisy_loader, device)
print(f"Student (LogitMatching) accuracy on Noisy = {noisy_student_lm_acc*100:.2f}%")

Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 59.38%
Batch 11/313 - Accuracy: 60.51%
Batch 21/313 - Accuracy: 60.57%
Batch 31/313 - Accuracy: 61.19%
Batch 41/313 - Accuracy: 61.36%
Batch 51/313 - Accuracy: 62.07%
Batch 61/313 - Accuracy: 62.04%
Batch 71/313 - Accuracy: 61.09%
Batch 81/313 - Accuracy: 61.11%
Batch 91/313 - Accuracy: 61.37%
Batch 101/313 - Accuracy: 62.00%
Batch 111/313 - Accuracy: 62.47%
Batch 121/313 - Accuracy: 62.53%
Batch 131/313 - Accuracy: 62.31%
Batch 141/313 - Accuracy: 62.70%
Batch 151/313 - Accuracy: 62.56%
Batch 161/313 - Accuracy: 62.38%
Batch 171/313 - Accuracy: 62.46%
Batch 181/313 - Accuracy: 62.40%
Batch 191/313 - Accuracy: 62.03%
Batch 201/313 - Accuracy: 61.85%
Batch 211/313 - Accuracy: 61.61%
Batch 221/313 - Accuracy: 61.48%
Batch 231/313 - Accuracy: 61.72%
Batch 241/313 - Accuracy: 61.61%
Batch 251/313 - Accuracy: 61.74%
Batch 261/313 - Accuracy: 61.55%
Batch 271/313 - Accuracy: 61.39%
Batch 281/313 - Accuracy: 61.38%
Batch 291/313 - Accuracy: 61.33%
Batch 301/313 - Accur

In [18]:
def get_features_and_logits(model, feature_model, inputs, student=True, crd=False):
    if not student:
        model.eval()
        feature_model.eval()
        features = feature_model(inputs)              # (B, 2048, 7, 7)
        features = features.view(features.size(0), -1)  # (B, 100352)
        return features, None
    else:
        model.train()
        features = feature_model(inputs)               # if ViT also is truncated, might be (B, 768, 196)
        features = features.view(features.size(0), -1) # (B, 150528)
        logits = model(inputs)  # The student still does the full forward for classification
        return features, logits


# Knowledge Distillation (CRD)

In [48]:
# create CRD module
num_data = len(train_dataset)  # total CIFAR-10 training samples
s_dim = get_channel_num(student_lm, "vit")         # 150528 or so
t_dim = get_channel_num(teacher_model, "resnet")   # 2048

print(s_dim)
print(t_dim)

crd_module = CRD(
    s_dim=s_dim, 
    t_dim=100352,
    n_data=num_data,
    feat_dim=128,    # or 256 etc., your choice
    nce_n=4096,
    nce_t=0.1,
    nce_mom=0.5
)

# new student for CRD
student_crd, features_crd = get_model('student')
student_crd = student_crd.to(device)

# Train with CRD
train_student_crd(
    student_crd, 
    features_crd, 
    teacher_model, 
    features_teacher, 
    train_crd_loader,  # your special loader with CIFAR10IdxSample
    kd=crd_module, 
    num_epochs=5, 
    alpha=1e-3
)



150528
2048
Distilling the knowledge and training the student on this data for CRD
Batch 1/1563 - Loss: 21.9496 
Batch 200/1563 - Loss: 19.6246 
Batch 400/1563 - Loss: 19.5598 
Batch 600/1563 - Loss: 19.6651 
Batch 750/1563 - Loss: 19.6888 
Batch 800/1563 - Loss: 19.4584 
Batch 1000/1563 - Loss: 19.4957 
Batch 1200/1563 - Loss: 20.0994 
Batch 1400/1563 - Loss: 19.7063 
Epoch [1/5], Loss: 19.7503
Batch 1/1563 - Loss: 19.5297 
Batch 200/1563 - Loss: 19.2098 
Batch 400/1563 - Loss: 19.6271 
Batch 600/1563 - Loss: 19.2578 
Batch 750/1563 - Loss: 19.8765 
Batch 800/1563 - Loss: 19.7548 
Batch 1000/1563 - Loss: 19.4758 
Batch 1200/1563 - Loss: 19.3368 
Batch 1400/1563 - Loss: 19.6965 
Epoch [2/5], Loss: 19.5663
Batch 1/1563 - Loss: 19.3237 
Batch 200/1563 - Loss: 19.5865 
Batch 400/1563 - Loss: 19.5719 
Batch 600/1563 - Loss: 19.4839 
Batch 750/1563 - Loss: 19.4115 
Batch 800/1563 - Loss: 18.8573 
Batch 1000/1563 - Loss: 19.7717 
Batch 1200/1563 - Loss: 19.1155 
Batch 1400/1563 - Loss: 19.27

19.4645045262762

In [49]:
# Path to save the weights and features
student_weights_path = "student_crd_weights.pth"
teacher_weights_path = "teacher_crd_weights.pth"
student_features_path = "student_crd_features.pth"
teacher_features_path = "teacher_crd_features.pth"

# Save the weights of the student and teacher models
torch.save(student_crd.state_dict(), student_weights_path)
torch.save(teacher_model.state_dict(), teacher_weights_path)

# Save the features (if they are used later, e.g., for distillation or further training)
torch.save(features_crd, student_features_path)
torch.save(features_teacher, teacher_features_path)

print("Student and teacher weights and features saved successfully!")


Student and teacher weights and features saved successfully!


# Evaluating CRD student on Texture bias dataset

In [19]:
# ---------------------loadind crd student weights--------------------------------------
# # Load the student model's weights
student_crd, features_crd = get_model('student')  # Initialize the same architecture
student_crd.load_state_dict(torch.load('/kaggle/input/contrastive-rd/student_crd_weights.pth'))
student_crd = student_crd.to(device)

# # Load the teacher model's weights
teacher_model, features_teacher = get_model('teacher')  # Initialize the same architecture
teacher_model.load_state_dict(torch.load('/kaggle/input/contrastive-rd/teacher_crd_weights.pth'))
teacher_model = teacher_model.to(device)

# # Load the features (if needed)
features_crd = torch.load('/kaggle/input/contrastive-rd/student_crd_features.pth')
features_teacher = torch.load('/kaggle/input/contrastive-rd/teacher_crd_features.pth')

print("Student and teacher weights and features loaded successfully!")


# Evaluate the CRD student on texture-bias
test_student_crd_acc = evaluate_model(student_crd, test_loader, device)
print(f"Student (CRD) accuracy on Texture-Bias = {test_student_crd_acc*100:.2f}%")

  student_crd.load_state_dict(torch.load('/kaggle/input/contrastive-rd/student_crd_weights.pth'))
  teacher_model.load_state_dict(torch.load('/kaggle/input/contrastive-rd/teacher_crd_weights.pth'))
  features_crd = torch.load('/kaggle/input/contrastive-rd/student_crd_features.pth')
  features_teacher = torch.load('/kaggle/input/contrastive-rd/teacher_crd_features.pth')


Student and teacher weights and features loaded successfully!
Evaluating the model
Batch 1/313 - Accuracy: 100.00%
Batch 11/313 - Accuracy: 94.89%
Batch 21/313 - Accuracy: 95.68%
Batch 31/313 - Accuracy: 96.37%
Batch 41/313 - Accuracy: 95.96%
Batch 51/313 - Accuracy: 95.89%
Batch 61/313 - Accuracy: 95.54%
Batch 71/313 - Accuracy: 95.47%
Batch 81/313 - Accuracy: 95.37%
Batch 91/313 - Accuracy: 95.19%
Batch 101/313 - Accuracy: 95.14%
Batch 111/313 - Accuracy: 95.05%
Batch 121/313 - Accuracy: 95.02%
Batch 131/313 - Accuracy: 95.04%
Batch 141/313 - Accuracy: 95.12%
Batch 151/313 - Accuracy: 95.03%
Batch 161/313 - Accuracy: 94.97%
Batch 171/313 - Accuracy: 95.05%
Batch 181/313 - Accuracy: 95.10%
Batch 191/313 - Accuracy: 95.04%
Batch 201/313 - Accuracy: 95.07%
Batch 211/313 - Accuracy: 95.10%
Batch 221/313 - Accuracy: 95.14%
Batch 231/313 - Accuracy: 95.21%
Batch 241/313 - Accuracy: 95.16%
Batch 251/313 - Accuracy: 95.16%
Batch 261/313 - Accuracy: 95.15%
Batch 271/313 - Accuracy: 95.17%
Bat

In [50]:
# ---------------------loadind crd student weights--------------------------------------
# # Load the student model's weights
# student_crd, features_crd = get_model('student')  # Initialize the same architecture
# student_crd.load_state_dict(torch.load(student_weights_path))
# student_crd = student_crd.to(device)

# # Load the teacher model's weights
# teacher_model, features_teacher = get_model('teacher')  # Initialize the same architecture
# teacher_model.load_state_dict(torch.load(teacher_weights_path))
# teacher_model = teacher_model.to(device)

# # Load the features (if needed)
# features_crd = torch.load(student_features_path)
# features_teacher = torch.load(teacher_features_path)

# print("Student and teacher weights and features loaded successfully!")


# Evaluate the CRD student on texture-bias
texture_student_crd_acc = evaluate_model(student_crd, texture_bias_loader, device)
print(f"Student (CRD) accuracy on Texture-Bias = {texture_student_crd_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 28.12%
Batch 11/313 - Accuracy: 36.36%
Batch 21/313 - Accuracy: 36.46%
Batch 31/313 - Accuracy: 37.80%
Batch 41/313 - Accuracy: 37.88%
Batch 51/313 - Accuracy: 38.48%
Batch 61/313 - Accuracy: 38.52%
Batch 71/313 - Accuracy: 38.56%
Batch 81/313 - Accuracy: 38.04%
Batch 91/313 - Accuracy: 38.36%
Batch 101/313 - Accuracy: 38.09%
Batch 111/313 - Accuracy: 38.09%
Batch 121/313 - Accuracy: 38.40%
Batch 131/313 - Accuracy: 38.45%
Batch 141/313 - Accuracy: 38.52%
Batch 151/313 - Accuracy: 38.74%
Batch 161/313 - Accuracy: 39.09%
Batch 171/313 - Accuracy: 39.24%
Batch 181/313 - Accuracy: 39.28%
Batch 191/313 - Accuracy: 39.14%
Batch 201/313 - Accuracy: 39.30%
Batch 211/313 - Accuracy: 39.28%
Batch 221/313 - Accuracy: 39.10%
Batch 231/313 - Accuracy: 39.18%
Batch 241/313 - Accuracy: 38.76%
Batch 251/313 - Accuracy: 38.68%
Batch 261/313 - Accuracy: 38.69%
Batch 271/313 - Accuracy: 38.50%
Batch 281/313 - Accuracy: 38.52%
Batch 291/313 - Accuracy: 38.50%


# Evaluating CRD student on Shape bias dataset

In [51]:
# Evaluate on shape-bias
shape_student_crd_acc = evaluate_model(student_crd, shape_bias_loader, device)
print(f"Student (CRD) accuracy on Shape-Bias = {shape_student_crd_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 18.75%
Batch 11/313 - Accuracy: 14.20%
Batch 21/313 - Accuracy: 15.77%
Batch 31/313 - Accuracy: 16.33%
Batch 41/313 - Accuracy: 15.55%
Batch 51/313 - Accuracy: 14.77%
Batch 61/313 - Accuracy: 15.01%
Batch 71/313 - Accuracy: 15.36%
Batch 81/313 - Accuracy: 15.28%
Batch 91/313 - Accuracy: 15.14%
Batch 101/313 - Accuracy: 15.10%
Batch 111/313 - Accuracy: 15.09%
Batch 121/313 - Accuracy: 14.95%
Batch 131/313 - Accuracy: 15.08%
Batch 141/313 - Accuracy: 15.23%
Batch 151/313 - Accuracy: 15.15%
Batch 161/313 - Accuracy: 15.06%
Batch 171/313 - Accuracy: 14.99%
Batch 181/313 - Accuracy: 15.04%
Batch 191/313 - Accuracy: 15.04%
Batch 201/313 - Accuracy: 15.17%
Batch 211/313 - Accuracy: 15.18%
Batch 221/313 - Accuracy: 15.02%
Batch 231/313 - Accuracy: 15.03%
Batch 241/313 - Accuracy: 15.08%
Batch 251/313 - Accuracy: 15.25%
Batch 261/313 - Accuracy: 15.06%
Batch 271/313 - Accuracy: 14.96%
Batch 281/313 - Accuracy: 15.00%
Batch 291/313 - Accuracy: 14.97%


# Evaluating CRD student on Scrambled dataset

In [52]:
# Evaluate on srambled dataset
scrambled_student_crd_acc = evaluate_model(student_crd, scrambled_loader, device)
print(f"Student (CRD) accuracy on Scrambled = {scrambled_student_crd_acc*100:.2f}%")

Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 28.12%
Batch 11/313 - Accuracy: 22.16%
Batch 21/313 - Accuracy: 20.68%
Batch 31/313 - Accuracy: 20.77%
Batch 41/313 - Accuracy: 20.81%
Batch 51/313 - Accuracy: 20.16%
Batch 61/313 - Accuracy: 20.24%
Batch 71/313 - Accuracy: 20.29%
Batch 81/313 - Accuracy: 19.87%
Batch 91/313 - Accuracy: 19.71%
Batch 101/313 - Accuracy: 19.96%
Batch 111/313 - Accuracy: 19.71%
Batch 121/313 - Accuracy: 19.55%
Batch 131/313 - Accuracy: 19.56%
Batch 141/313 - Accuracy: 19.86%
Batch 151/313 - Accuracy: 20.18%
Batch 161/313 - Accuracy: 20.13%
Batch 171/313 - Accuracy: 20.21%
Batch 181/313 - Accuracy: 20.27%
Batch 191/313 - Accuracy: 20.37%
Batch 201/313 - Accuracy: 20.26%
Batch 211/313 - Accuracy: 20.47%
Batch 221/313 - Accuracy: 20.57%
Batch 231/313 - Accuracy: 20.45%
Batch 241/313 - Accuracy: 20.46%
Batch 251/313 - Accuracy: 20.49%
Batch 261/313 - Accuracy: 20.61%
Batch 271/313 - Accuracy: 20.64%
Batch 281/313 - Accuracy: 20.64%
Batch 291/313 - Accuracy: 20.68%
Batch 301/313 - Accur

# Evaluating CRD student on Noised dataset

In [53]:
# Evaluate on noisy dataset
noisy_student_crd_acc = evaluate_model(student_crd, noisy_loader, device)
print(f"Student (CRD) accuracy on Noisy = {noisy_student_crd_acc*100:.2f}%")

Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 65.62%
Batch 11/313 - Accuracy: 67.61%
Batch 21/313 - Accuracy: 68.15%
Batch 31/313 - Accuracy: 67.14%
Batch 41/313 - Accuracy: 67.68%
Batch 51/313 - Accuracy: 67.89%
Batch 61/313 - Accuracy: 67.78%
Batch 71/313 - Accuracy: 66.42%
Batch 81/313 - Accuracy: 66.28%
Batch 91/313 - Accuracy: 66.59%
Batch 101/313 - Accuracy: 67.14%
Batch 111/313 - Accuracy: 67.15%
Batch 121/313 - Accuracy: 67.41%
Batch 131/313 - Accuracy: 66.98%
Batch 141/313 - Accuracy: 67.33%
Batch 151/313 - Accuracy: 67.18%
Batch 161/313 - Accuracy: 67.02%
Batch 171/313 - Accuracy: 67.20%
Batch 181/313 - Accuracy: 67.16%
Batch 191/313 - Accuracy: 67.03%
Batch 201/313 - Accuracy: 66.95%
Batch 211/313 - Accuracy: 66.91%
Batch 221/313 - Accuracy: 66.78%
Batch 231/313 - Accuracy: 66.90%
Batch 241/313 - Accuracy: 66.78%
Batch 251/313 - Accuracy: 66.92%
Batch 261/313 - Accuracy: 66.71%
Batch 271/313 - Accuracy: 66.57%
Batch 281/313 - Accuracy: 66.65%
Batch 291/313 - Accuracy: 66.52%
Batch 301/313 - Accur

# Evaluating CRD student on SuperPixelated dataset

In [54]:
# Evaluate on super pixelated dataset
super_pixelated_student_crd_acc = evaluate_model(student_crd, superpixel_loader, device)
print(f"Student (CRD) accuracy on Super Pixelated = {super_pixelated_student_crd_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 40.62%
Batch 11/313 - Accuracy: 33.52%
Batch 21/313 - Accuracy: 31.40%
Batch 31/313 - Accuracy: 32.06%
Batch 41/313 - Accuracy: 32.93%
Batch 51/313 - Accuracy: 33.21%
Batch 61/313 - Accuracy: 33.40%
Batch 71/313 - Accuracy: 33.23%
Batch 81/313 - Accuracy: 33.29%
Batch 91/313 - Accuracy: 33.62%
Batch 101/313 - Accuracy: 33.66%
Batch 111/313 - Accuracy: 33.92%
Batch 121/313 - Accuracy: 33.65%
Batch 131/313 - Accuracy: 33.73%
Batch 141/313 - Accuracy: 33.71%
Batch 151/313 - Accuracy: 33.77%
Batch 161/313 - Accuracy: 33.77%
Batch 171/313 - Accuracy: 33.79%
Batch 181/313 - Accuracy: 34.00%
Batch 191/313 - Accuracy: 33.98%
Batch 201/313 - Accuracy: 33.99%
Batch 211/313 - Accuracy: 33.83%
Batch 221/313 - Accuracy: 33.85%
Batch 231/313 - Accuracy: 33.83%
Batch 241/313 - Accuracy: 33.73%
Batch 251/313 - Accuracy: 33.69%
Batch 261/313 - Accuracy: 33.47%
Batch 271/313 - Accuracy: 33.18%
Batch 281/313 - Accuracy: 33.39%
Batch 291/313 - Accuracy: 33.41%


# Independent Student 

In [55]:
independent_student, features_ind_student = get_model('student')  # ViT
print("Training an independent student on CIFAR-10, no distillation")

independent_student = independent_student.to(device)

# no KD (kd=None)
ind_loss, ind_acc = finetune_model(
    independent_student, 
    train_loader, 
    num_epochs=5, 
    alpha=1e-3
)

# Path to save the independent student model's weights and features
independent_student_weights_path = "independent_student_weights.pth"
independent_student_features_path = "independent_student_features.pth"

# Save the weights of the independent student model
torch.save(independent_student.state_dict(), independent_student_weights_path)

# Save the features (if they are computed or extracted during training)
torch.save(features_ind_student, independent_student_features_path)

print("Independent student weights and features saved successfully!")


Training an independent student on CIFAR-10, no distillation
Finetuning the model on this data
Batch 1/1563 - Loss: 2.7694 Accuracy: 9.38%
Batch 301/1563 - Loss: 0.3283 Accuracy: 91.68%
Batch 601/1563 - Loss: 0.0501 Accuracy: 93.11%
Batch 751/1563 - Loss: 0.0556 Accuracy: 93.41%
Batch 901/1563 - Loss: 0.0379 Accuracy: 93.79%
Batch 1201/1563 - Loss: 0.0661 Accuracy: 94.16%
Batch 1501/1563 - Loss: 0.2814 Accuracy: 94.34%
------------------------->Epoch [1/5], Loss: 0.1775 Accuracy : 94.35%
Batch 1/1563 - Loss: 0.3177 Accuracy: 90.62%
Batch 301/1563 - Loss: 0.2355 Accuracy: 96.17%
Batch 601/1563 - Loss: 0.0331 Accuracy: 96.08%
Batch 751/1563 - Loss: 0.1093 Accuracy: 95.93%
Batch 901/1563 - Loss: 0.1203 Accuracy: 95.88%
Batch 1201/1563 - Loss: 0.0190 Accuracy: 95.88%
Batch 1501/1563 - Loss: 0.0585 Accuracy: 95.84%
------------------------->Epoch [2/5], Loss: 0.1268 Accuracy : 95.86%
Batch 1/1563 - Loss: 0.1277 Accuracy: 96.88%
Batch 301/1563 - Loss: 0.0892 Accuracy: 96.78%
Batch 601/1563 -

# Evaluating Independent Student on Texture Bias

In [20]:
# -------------------------loading weights of independent -----------------------------
# Load the independent student's weights
independent_student, features_ind_student = get_model('student')  # Initialize the same architecture
independent_student.load_state_dict(torch.load('/kaggle/input/independent/independent_student_weights.pth'))
independent_student = independent_student.to(device)

# # Load the features (if needed)
features_ind_student = torch.load('/kaggle/input/independent/independent_student_features.pth')

print("Independent student weights and features loaded successfully!")

test_ind_acc = evaluate_model(independent_student, test_loader, device)
print(f"Independent Student accuracy on Texture-Bias = {test_ind_acc*100:.2f}%")



  independent_student.load_state_dict(torch.load('/kaggle/input/independent/independent_student_weights.pth'))
  features_ind_student = torch.load('/kaggle/input/independent/independent_student_features.pth')


Independent student weights and features loaded successfully!
Evaluating the model
Batch 1/313 - Accuracy: 100.00%
Batch 11/313 - Accuracy: 95.74%
Batch 21/313 - Accuracy: 95.98%
Batch 31/313 - Accuracy: 96.37%
Batch 41/313 - Accuracy: 95.88%
Batch 51/313 - Accuracy: 95.40%
Batch 61/313 - Accuracy: 95.59%
Batch 71/313 - Accuracy: 95.51%
Batch 81/313 - Accuracy: 95.49%
Batch 91/313 - Accuracy: 95.54%
Batch 101/313 - Accuracy: 95.36%
Batch 111/313 - Accuracy: 95.35%
Batch 121/313 - Accuracy: 95.43%
Batch 131/313 - Accuracy: 95.42%
Batch 141/313 - Accuracy: 95.52%
Batch 151/313 - Accuracy: 95.45%
Batch 161/313 - Accuracy: 95.44%
Batch 171/313 - Accuracy: 95.41%
Batch 181/313 - Accuracy: 95.46%
Batch 191/313 - Accuracy: 95.44%
Batch 201/313 - Accuracy: 95.43%
Batch 211/313 - Accuracy: 95.47%
Batch 221/313 - Accuracy: 95.49%
Batch 231/313 - Accuracy: 95.55%
Batch 241/313 - Accuracy: 95.47%
Batch 251/313 - Accuracy: 95.44%
Batch 261/313 - Accuracy: 95.40%
Batch 271/313 - Accuracy: 95.39%
Bat

In [56]:
# -------------------------loading weights of independent -----------------------------
# Load the independent student's weights
# independent_student, features_ind_student = get_model('student')  # Initialize the same architecture
# independent_student.load_state_dict(torch.load(independent_student_weights_path))
# independent_student = independent_student.to(device)

# # Load the features (if needed)
# features_ind_student = torch.load(independent_student_features_path)

# print("Independent student weights and features loaded successfully!")


texture_ind_acc = evaluate_model(independent_student, texture_bias_loader, device)
print(f"Independent Student accuracy on Texture-Bias = {texture_ind_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 21.88%
Batch 11/313 - Accuracy: 34.94%
Batch 21/313 - Accuracy: 34.97%
Batch 31/313 - Accuracy: 37.30%
Batch 41/313 - Accuracy: 37.42%
Batch 51/313 - Accuracy: 37.44%
Batch 61/313 - Accuracy: 37.60%
Batch 71/313 - Accuracy: 37.76%
Batch 81/313 - Accuracy: 37.46%
Batch 91/313 - Accuracy: 37.60%
Batch 101/313 - Accuracy: 37.56%
Batch 111/313 - Accuracy: 37.67%
Batch 121/313 - Accuracy: 38.09%
Batch 131/313 - Accuracy: 38.24%
Batch 141/313 - Accuracy: 38.39%
Batch 151/313 - Accuracy: 38.64%
Batch 161/313 - Accuracy: 38.98%
Batch 171/313 - Accuracy: 39.00%
Batch 181/313 - Accuracy: 39.21%
Batch 191/313 - Accuracy: 39.28%
Batch 201/313 - Accuracy: 39.52%
Batch 211/313 - Accuracy: 39.35%
Batch 221/313 - Accuracy: 39.17%
Batch 231/313 - Accuracy: 39.37%
Batch 241/313 - Accuracy: 39.04%
Batch 251/313 - Accuracy: 39.03%
Batch 261/313 - Accuracy: 38.98%
Batch 271/313 - Accuracy: 38.84%
Batch 281/313 - Accuracy: 38.87%
Batch 291/313 - Accuracy: 38.73%


# Evaluating Independent student on Shape bias dataset

In [57]:
# Evaluate on shape-bias
shape_student_ind_acc = evaluate_model(independent_student, shape_bias_loader, device)
print(f"Student (Independent) accuracy on Shape-Bias = {shape_student_ind_acc*100:.2f}%")



Evaluating the model
Batch 1/313 - Accuracy: 9.38%
Batch 11/313 - Accuracy: 13.35%
Batch 21/313 - Accuracy: 15.18%
Batch 31/313 - Accuracy: 16.63%
Batch 41/313 - Accuracy: 15.70%
Batch 51/313 - Accuracy: 15.56%
Batch 61/313 - Accuracy: 15.68%
Batch 71/313 - Accuracy: 16.07%
Batch 81/313 - Accuracy: 16.20%
Batch 91/313 - Accuracy: 16.21%
Batch 101/313 - Accuracy: 16.15%
Batch 111/313 - Accuracy: 16.08%
Batch 121/313 - Accuracy: 15.78%
Batch 131/313 - Accuracy: 16.17%
Batch 141/313 - Accuracy: 16.31%
Batch 151/313 - Accuracy: 16.20%
Batch 161/313 - Accuracy: 16.34%
Batch 171/313 - Accuracy: 16.37%
Batch 181/313 - Accuracy: 16.49%
Batch 191/313 - Accuracy: 16.44%
Batch 201/313 - Accuracy: 16.64%
Batch 211/313 - Accuracy: 16.72%
Batch 221/313 - Accuracy: 16.56%
Batch 231/313 - Accuracy: 16.67%
Batch 241/313 - Accuracy: 16.69%
Batch 251/313 - Accuracy: 16.80%
Batch 261/313 - Accuracy: 16.62%
Batch 271/313 - Accuracy: 16.55%
Batch 281/313 - Accuracy: 16.61%
Batch 291/313 - Accuracy: 16.65%
B

# Evaluating Independent student on Scrambled dataset

In [58]:
# Evaluate on scrambled dataset
scrambled_student_ind_acc = evaluate_model(independent_student, scrambled_loader, device)
print(f"Student (Independent) accuracy on Scrambled = {scrambled_student_ind_acc*100:.2f}%")



Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 34.38%
Batch 11/313 - Accuracy: 23.58%
Batch 21/313 - Accuracy: 21.73%
Batch 31/313 - Accuracy: 21.77%
Batch 41/313 - Accuracy: 22.03%
Batch 51/313 - Accuracy: 21.32%
Batch 61/313 - Accuracy: 21.52%
Batch 71/313 - Accuracy: 21.39%
Batch 81/313 - Accuracy: 21.22%
Batch 91/313 - Accuracy: 21.15%
Batch 101/313 - Accuracy: 21.50%
Batch 111/313 - Accuracy: 21.65%
Batch 121/313 - Accuracy: 21.46%
Batch 131/313 - Accuracy: 21.18%
Batch 141/313 - Accuracy: 21.43%
Batch 151/313 - Accuracy: 21.61%
Batch 161/313 - Accuracy: 21.53%
Batch 171/313 - Accuracy: 21.71%
Batch 181/313 - Accuracy: 21.74%
Batch 191/313 - Accuracy: 21.81%
Batch 201/313 - Accuracy: 21.80%
Batch 211/313 - Accuracy: 21.95%
Batch 221/313 - Accuracy: 21.97%
Batch 231/313 - Accuracy: 21.88%
Batch 241/313 - Accuracy: 22.00%
Batch 251/313 - Accuracy: 22.01%
Batch 261/313 - Accuracy: 22.09%
Batch 271/313 - Accuracy: 22.06%
Batch 281/313 - Accuracy: 22.06%
Batch 291/313 - Accuracy: 21.98%
Batch 301/313 - Accur

# Evaluating Independent student on Noised dataset

In [59]:
# Evaluate on noisy dataset
noisy_student_ind_acc = evaluate_model(independent_student, noisy_loader, device)
print(f"Student (Independent) accuracy on Noisy = {noisy_student_ind_acc*100:.2f}%")



Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 75.00%
Batch 11/313 - Accuracy: 69.03%
Batch 21/313 - Accuracy: 68.60%
Batch 31/313 - Accuracy: 67.64%
Batch 41/313 - Accuracy: 68.60%
Batch 51/313 - Accuracy: 68.69%
Batch 61/313 - Accuracy: 68.44%
Batch 71/313 - Accuracy: 67.39%
Batch 81/313 - Accuracy: 67.32%
Batch 91/313 - Accuracy: 67.38%
Batch 101/313 - Accuracy: 67.79%
Batch 111/313 - Accuracy: 67.79%
Batch 121/313 - Accuracy: 67.98%
Batch 131/313 - Accuracy: 67.51%
Batch 141/313 - Accuracy: 67.80%
Batch 151/313 - Accuracy: 67.65%
Batch 161/313 - Accuracy: 67.59%
Batch 171/313 - Accuracy: 67.69%
Batch 181/313 - Accuracy: 67.83%
Batch 191/313 - Accuracy: 67.75%
Batch 201/313 - Accuracy: 67.77%
Batch 211/313 - Accuracy: 67.68%
Batch 221/313 - Accuracy: 67.56%
Batch 231/313 - Accuracy: 67.74%
Batch 241/313 - Accuracy: 67.73%
Batch 251/313 - Accuracy: 67.82%
Batch 261/313 - Accuracy: 67.71%
Batch 271/313 - Accuracy: 67.59%
Batch 281/313 - Accuracy: 67.62%
Batch 291/313 - Accuracy: 67.45%
Batch 301/313 - Accur

# Evaluating Independent student on SuperPixelated dataset

In [60]:
# Evaluate on super pixelated dataset
super_pixelated_student_ind_acc = evaluate_model(independent_student, superpixel_loader, device)
print(f"Student (Independent) accuracy on Super Pixelated = {super_pixelated_student_ind_acc*100:.2f}%")


Evaluating the model
Batch 1/313 - Accuracy: 37.50%
Batch 11/313 - Accuracy: 33.52%
Batch 21/313 - Accuracy: 32.14%
Batch 31/313 - Accuracy: 32.96%
Batch 41/313 - Accuracy: 33.16%
Batch 51/313 - Accuracy: 33.27%
Batch 61/313 - Accuracy: 33.09%
Batch 71/313 - Accuracy: 32.79%
Batch 81/313 - Accuracy: 32.75%
Batch 91/313 - Accuracy: 33.04%
Batch 101/313 - Accuracy: 33.23%
Batch 111/313 - Accuracy: 33.22%
Batch 121/313 - Accuracy: 33.06%
Batch 131/313 - Accuracy: 33.28%
Batch 141/313 - Accuracy: 33.22%
Batch 151/313 - Accuracy: 33.46%
Batch 161/313 - Accuracy: 33.52%
Batch 171/313 - Accuracy: 33.55%
Batch 181/313 - Accuracy: 33.87%
Batch 191/313 - Accuracy: 33.87%
Batch 201/313 - Accuracy: 33.96%
Batch 211/313 - Accuracy: 33.93%
Batch 221/313 - Accuracy: 33.89%
Batch 231/313 - Accuracy: 33.97%
Batch 241/313 - Accuracy: 33.91%
Batch 251/313 - Accuracy: 33.89%
Batch 261/313 - Accuracy: 33.57%
Batch 271/313 - Accuracy: 33.33%
Batch 281/313 - Accuracy: 33.53%
Batch 291/313 - Accuracy: 33.55%


# C2VKD (VLFD)

In [21]:
# Reload the teacher model
teacher_model, features_teacher = get_model('teacher')  # Initialize the teacher architecture
teacher_model.load_state_dict(torch.load('/kaggle/input/teacher-model-weights-and-features/teacher_model_weights.pth'))
teacher_model = teacher_model.to(device)
teacher_model.eval()  # Set to evaluation mode

# Optionally load the features
features_teacher = torch.load('/kaggle/input/teacher-model-weights-and-features/features_teacher.pth')

print("Teacher model with features loaded successfully!")

student_c2vkd, features_c2vkd = get_model('student')  # The truncated ViT
student_c2vkd = student_c2vkd.to(device)


# Make sure this is placed at the global scope (e.g., near your other definitions).
# This is crucial so we can project teacher's global pool (2048-d) to 768-d.
teacher_global_proj = nn.Linear(2048, 768, bias=False).to(device)

  teacher_model.load_state_dict(torch.load('/kaggle/input/teacher-model-weights-and-features/teacher_model_weights.pth'))
  features_teacher = torch.load('/kaggle/input/teacher-model-weights-and-features/features_teacher.pth')


Teacher model with features loaded successfully!


In [22]:
# ===============================
# FIXED C2VKD NOTEBOOK CELL 
# USING student_c2vkd.forward_features
# ===============================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

########################################
# 1) TEACHER & STUDENT BACKBONE HELPERS
########################################


def teacher_backbone_extract(images):
    """
    Calls the truncated teacher backbone (features_teacher).
    shape => (B, 2048, H, W) for ResNet-50 with layers removed.
    """
    return features_teacher(images)

def student_backbone_extract_c2vkd(images):
    """
    1) We call student_c2vkd.forward_features(images) to get patch embeddings 
       from the ViT (including CLS).
    2) We remove the CLS token => shape (B, 196, 768).
    3) Reshape => (B, 768, 14,14).

    This depends on your timm-based ViT or a custom forward_features method.
    """
    all_tokens = student_c2vkd.forward_features(images)  
    patch_tokens = all_tokens[:, 1:, :]         # remove CLS => (B,196,768)
    B, N, D = patch_tokens.shape
    H = W = int(N**0.5)                         # 14
    feats_4d = patch_tokens.reshape(B, H, W, D).permute(0, 3, 1, 2)  # (B,768,14,14)
    return feats_4d

########################################
# 2) LOSSES FOR C2VKD (VLFD + PDD)
########################################

def global_feature_loss(teacher_feats, student_feats):
    """
    Teacher => (B,2048,Ht,Wt)
    Student => (B,768,Hs,Ws)

    We do global avg pooling, then 
    project teacher's 2048 -> 768 
    so we can do KL on shape (B,768) vs. (B,768).
    """
    # Teacher
    B, Ct, Ht, Wt = teacher_feats.shape
    teacher_global = F.adaptive_avg_pool2d(teacher_feats, (1,1)).view(B, Ct)  # (B, 2048)

    # Student
    B2, Cs, Hs, Ws = student_feats.shape
    student_global = F.adaptive_avg_pool2d(student_feats, (1,1)).view(B2, Cs) # (B, 768)

    # Project teacher => (B,768)
    teacher_global_768 = teacher_global_proj(teacher_global)  # (B,768)

    # Now do KL on (B,768) vs. (B,768)
    return F.kl_div(
        F.log_softmax(student_global, dim=1),
        F.softmax(teacher_global_768, dim=1),
        reduction='batchmean'
    )

def patch_feature_loss(teacher_feats, student_feats):
    """
    MSE alignment of teacher (B,2048,Ht,Wt) vs. student (B,768,Hs,Ws).
    We resize teacher => (Hs,Ws). If channels differ (2048 vs. 768),
    we do a linear projection too.
    """
    B, Ct, Ht, Wt = teacher_feats.shape
    B2, Cs, Hs, Ws = student_feats.shape

    # Resize teacher to match student's (Hs,Ws)
    teacher_resized = F.interpolate(teacher_feats, size=(Hs, Ws), mode='bilinear', align_corners=False)

    if Ct != Cs:
        # linear proj teacher from Ct->Cs
        t_flat = teacher_resized.permute(0,2,3,1).reshape(-1, Ct)
        linear_proj = nn.Linear(Ct, Cs, bias=False).to(teacher_feats.device)
        t_proj = linear_proj(t_flat)     # => (B*Hs*Ws, Cs)
        teacher_4d = t_proj.view(B, Hs, Ws, Cs).permute(0,3,1,2)  # => (B, Cs, Hs,Ws)
        return F.mse_loss(teacher_4d, student_feats)
    else:
        return F.mse_loss(teacher_resized, student_feats)

class LinguisticAlignmentPatchBased(nn.Module):
    """
    Projects a patch-level feature (B, N, C) -> (B, N, proj_dim) then normalizes.
    """
    def __init__(self, in_dim, projection_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, projection_dim)

    def forward(self, patch_feats):
        # patch_feats: (B, N, in_dim)
        out = self.proj(patch_feats)   # => (B, N, projection_dim)
        out = F.normalize(out, dim=-1)
        return out

def linguistic_feature_loss(
    teacher_feats,   # (B, Ct, Ht, Wt)
    student_feats,   # (B, Cs, Hs, Ws)
    teacher_align_module,
    student_align_module
):
    # 1) If teacher is 7x7 while student is 14x14, unify them
    B, Ct, Ht, Wt = teacher_feats.shape
    B2, Cs, Hs, Ws = student_feats.shape

    if (Ht != Hs) or (Wt != Ws):
        teacher_feats = F.interpolate(
            teacher_feats, size=(Hs, Ws),
            mode='bilinear', align_corners=False
        )
        # Now teacher_feats is (B, Ct, Hs, Ws)
        # shape is consistent with student_feats

    # 2) Flatten teacher => (B, Hs*Ws, Ct)
    teacher_patches = teacher_feats.view(B, Ct, Hs*Ws).permute(0, 2, 1)

    # 3) Flatten student => (B, Hs*Ws, Cs)
    student_patches = student_feats.view(B2, Cs, Hs*Ws).permute(0, 2, 1)

    # 4) Project both
    teacher_proj = teacher_align_module(teacher_patches)   # => (B, Hs*Ws, proj_dim)
    student_proj = student_align_module(student_patches)   # => (B, Hs*Ws, proj_dim)

    # 5) Flatten to (B*N, proj_dim)
    BN_teacher = teacher_proj.view(-1, teacher_proj.size(-1))
    BN_student = student_proj.view(-1, student_proj.size(-1))

    # 6) KL
    return F.kl_div(
        F.log_softmax(BN_student, dim=-1),
        F.softmax(BN_teacher, dim=-1),
        reduction='batchmean'
    )

def pixel_wise_decoupled_loss(teacher_logits, student_logits, labels):
    """
    For classification, teacher_logits & student_logits => (B, num_classes).
    We separate teacher's target vs. non-target logit. Then do MSE vs. student's.
    """
    B, num_classes = teacher_logits.shape
    target_mask = F.one_hot(labels, num_classes).float() 
    non_target_mask = 1 - target_mask

    t_target = (teacher_logits * target_mask).sum(dim=1, keepdim=True)
    t_non    = (teacher_logits * non_target_mask).sum(dim=1, keepdim=True)
    s_target = (student_logits * target_mask).sum(dim=1, keepdim=True)
    s_non    = (student_logits * non_target_mask).sum(dim=1, keepdim=True)

    return F.mse_loss(s_target, t_target) + F.mse_loss(s_non, t_non)

def total_distillation_loss(
    teacher_feats, teacher_logits,
    student_feats, student_logits,
    labels,
    teacher_align_module=None,
    student_align_module=None
):
    """
    Summation of:
      1) global feature loss (teacher->student dimension mismatch fixed)
      2) patch feature loss
      3) linguistic alignment (VLFD)
      4) pixel decoupled (PDD)
    """
    g_loss = global_feature_loss(teacher_feats, student_feats)
    p_loss = patch_feature_loss(teacher_feats, student_feats)
    ling_loss = 0
    if teacher_align_module and student_align_module:
        ling_loss = linguistic_feature_loss(teacher_feats, student_feats, 
                                            teacher_align_module, student_align_module)
    pdd_loss = pixel_wise_decoupled_loss(teacher_logits, student_logits, labels)

    return 0.5*g_loss + 0.3*p_loss + 0.2*ling_loss + pdd_loss


########################################
# 3) C2VKD TRAIN LOOP
########################################
def train_student_c2vkd(
    student_model, teacher_model, 
    train_loader, optimizer,
    teacher_align_module=None, student_align_module=None,
    num_epochs=10, device='cuda'
):
    """
    Classification-based distillation (C2VKD).
    Expects teacher_model => full ResNet for teacher logits,
    teacher_backbone_extract => truncated teacher for 4D feats,
    student_model => ViT for logits,
    student_backbone_extract_c2vkd => 4D patch embeddings from the student's forward_features.
    """
    student_model.train()
    teacher_model.eval()

    for epoch in range(num_epochs):
        total_loss = 0.0
        print(f"Starting Epoch {epoch + 1}/{num_epochs}...")
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher forward
            with torch.no_grad():
                teacher_logits = teacher_model(images)           # (B, #classes)
                teacher_feats  = teacher_backbone_extract(images) # (B,2048,Ht,Wt)

            # Student forward
            student_logits = student_model(images)               # (B, #classes)
            student_feats  = student_backbone_extract_c2vkd(images) # (B,768,Hs,Ws)

            # Compute total distillation
            loss = total_distillation_loss(
                teacher_feats, teacher_logits,
                student_feats, student_logits,
                labels,
                teacher_align_module, 
                student_align_module
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"[Epoch {epoch+1}/{num_epochs}] => Loss: {avg_loss:.4f}")


In [20]:
# from torch.utils.data import Subset, DataLoader
# import numpy as np

# # Limit the dataset to 1000 samples
# subset_indices = np.random.choice(len(train_dataset), 1000, replace=False)  # Randomly select 1000 indices
# subset_train_dataset = Subset(train_dataset, subset_indices)

# # Create a DataLoader for the subset
# subset_train_loader = DataLoader(
#     subset_train_dataset, 
#     batch_size=32,  # Adjust batch size as needed
#     shuffle=True, 
#     num_workers=4  # Adjust num_workers based on your system
# )

# # Define alignment modules
# teacher_align_module = LinguisticAlignmentPatchBased(in_dim=2048, projection_dim=512).to(device)
# student_align_module = LinguisticAlignmentPatchBased(in_dim=768, projection_dim=512).to(device)

# # Optimizer
# optimizer = optim.Adam(student_c2vkd.parameters(), lr=1e-3)

# # Run training on the small subset
# train_student_c2vkd(
#     student_model=student_c2vkd,
#     teacher_model=teacher_model,
#     train_loader=subset_train_loader,
#     optimizer=optimizer,
#     teacher_align_module=teacher_align_module,
#     student_align_module=student_align_module,
#     num_epochs=10,  # Keep fewer epochs for quick testing
#     device=device
# )


In [21]:
########################################
# 4) EXAMPLE USAGE
########################################

# We assume you already have:
#  teacher_model, features_teacher = get_model('teacher')
#  student_c2vkd, features_c2vkd  = get_model('student')
#  teacher_model is loaded with fine-tuned weights
#  The "features_teacher" and "features_c2vkd" are your truncated networks
#  We do timm's forward_features => student_c2vkd.forward_features(...) inside our helper

# Create alignment modules
teacher_align_module = LinguisticAlignmentPatchBased(in_dim=2048, projection_dim=512).to(device)
student_align_module = LinguisticAlignmentPatchBased(in_dim=768,  projection_dim=512).to(device)

optimizer = optim.Adam(student_c2vkd.parameters(), lr=1e-3)

train_student_c2vkd(
    student_model=student_c2vkd,
    teacher_model=teacher_model,
    train_loader=train_loader,
    optimizer=optimizer,
    teacher_align_module=teacher_align_module,
    student_align_module=student_align_module,
    num_epochs=5,
    device=device
)


# Save the student's weights
torch.save(student_c2vkd.state_dict(), 'student_c2vkd_model_weights.pth')

# Save the student's features (if needed for truncated forward)
torch.save(features_c2vkd.state_dict(), 'student_c2vkd_features.pth')

print("Student model weights and features saved successfully!")



# Evaluate on texture, shape, etc.
# evaluate_model(student_c2vkd, texture_bias_loader, device)
# ...

   

Starting Epoch 1/5...
[Epoch 1/5] => Loss: 59.3676
Starting Epoch 2/5...
[Epoch 2/5] => Loss: 39.3139
Starting Epoch 3/5...
[Epoch 3/5] => Loss: 37.9903
Starting Epoch 4/5...
[Epoch 4/5] => Loss: 36.8878
Starting Epoch 5/5...
[Epoch 5/5] => Loss: 36.3656
Student model weights and features saved successfully!


In [24]:
# need to save this student and it's features too
# then load and evaluate of texture bias and then so on

# Initialize the student model and features
student_c2vkd, features_c2vkd = get_model('student')  # Ensure same architecture as saved model
student_c2vkd = student_c2vkd.to(device)

# # Load the student's weights
student_c2vkd.load_state_dict(torch.load('/kaggle/input/ctovkd/student_c2vkd_model_weights.pth'))
student_c2vkd.eval()  # Set to evaluation mode

# # Optionally, load the features
features_c2vkd.load_state_dict(torch.load('/kaggle/input/ctovkd/student_c2vkd_features.pth'))

print("Student model weights and features loaded successfully!")

test_c2vkd_acc = evaluate_model(student_c2vkd, test_loader, device)
print(f"Student (C2VKD) accuracy on Texture-Bias = {test_c2vkd_acc*100:.2f}%")



  student_c2vkd.load_state_dict(torch.load('/kaggle/input/ctovkd/student_c2vkd_model_weights.pth'))
  features_c2vkd.load_state_dict(torch.load('/kaggle/input/ctovkd/student_c2vkd_features.pth'))


Student model weights and features loaded successfully!
Evaluating the model
Batch 1/313 - Accuracy: 96.88%
Batch 11/313 - Accuracy: 88.64%
Batch 21/313 - Accuracy: 88.84%
Batch 31/313 - Accuracy: 88.31%
Batch 41/313 - Accuracy: 88.11%
Batch 51/313 - Accuracy: 87.87%
Batch 61/313 - Accuracy: 88.01%
Batch 71/313 - Accuracy: 87.90%
Batch 81/313 - Accuracy: 87.81%
Batch 91/313 - Accuracy: 87.71%
Batch 101/313 - Accuracy: 87.62%
Batch 111/313 - Accuracy: 87.73%
Batch 121/313 - Accuracy: 87.63%
Batch 131/313 - Accuracy: 87.64%
Batch 141/313 - Accuracy: 87.94%
Batch 151/313 - Accuracy: 87.85%
Batch 161/313 - Accuracy: 87.99%
Batch 171/313 - Accuracy: 88.03%
Batch 181/313 - Accuracy: 87.97%
Batch 191/313 - Accuracy: 87.94%
Batch 201/313 - Accuracy: 88.08%
Batch 211/313 - Accuracy: 87.90%
Batch 221/313 - Accuracy: 87.92%
Batch 231/313 - Accuracy: 88.15%
Batch 241/313 - Accuracy: 88.14%
Batch 251/313 - Accuracy: 88.07%
Batch 261/313 - Accuracy: 88.04%
Batch 271/313 - Accuracy: 87.97%
Batch 281/

In [22]:
# need to save this student and it's features too
# then load and evaluate of texture bias and then so on

# Initialize the student model and features
# student_c2vkd, features_c2vkd = get_model('student')  # Ensure same architecture as saved model
# student_c2vkd = student_c2vkd.to(device)

# # Load the student's weights
# student_c2vkd.load_state_dict(torch.load('student_c2vkd_model_weights.pth'))
# student_c2vkd.eval()  # Set to evaluation mode

# # Optionally, load the features
# features_c2vkd.load_state_dict(torch.load('student_c2vkd_features.pth'))

# print("Student model weights and features loaded successfully!")


# Evaluate on texture-bias dataset
texture_c2vkd_acc = evaluate_model(student_c2vkd, texture_bias_loader, device)
print(f"Student (C2VKD) accuracy on Texture-Bias = {texture_c2vkd_acc*100:.2f}%")





Evaluating the model
Batch 1/313 - Accuracy: 46.88%
Batch 11/313 - Accuracy: 32.10%
Batch 21/313 - Accuracy: 30.36%
Batch 31/313 - Accuracy: 30.54%
Batch 41/313 - Accuracy: 30.72%
Batch 51/313 - Accuracy: 30.45%
Batch 61/313 - Accuracy: 30.28%
Batch 71/313 - Accuracy: 30.68%
Batch 81/313 - Accuracy: 30.63%
Batch 91/313 - Accuracy: 30.80%
Batch 101/313 - Accuracy: 30.41%
Batch 111/313 - Accuracy: 30.63%
Batch 121/313 - Accuracy: 31.12%
Batch 131/313 - Accuracy: 31.25%
Batch 141/313 - Accuracy: 31.43%
Batch 151/313 - Accuracy: 31.52%
Batch 161/313 - Accuracy: 31.44%
Batch 171/313 - Accuracy: 31.63%
Batch 181/313 - Accuracy: 31.68%
Batch 191/313 - Accuracy: 31.48%
Batch 201/313 - Accuracy: 31.39%
Batch 211/313 - Accuracy: 31.26%
Batch 221/313 - Accuracy: 31.17%
Batch 231/313 - Accuracy: 31.29%
Batch 241/313 - Accuracy: 31.11%
Batch 251/313 - Accuracy: 31.09%
Batch 261/313 - Accuracy: 31.06%
Batch 271/313 - Accuracy: 30.90%
Batch 281/313 - Accuracy: 30.95%
Batch 291/313 - Accuracy: 30.91%


In [23]:
# Evaluate on shape-bias dataset
shape_c2vkd_acc = evaluate_model(student_c2vkd, shape_bias_loader, device)
print(f"Student (C2VKD) accuracy on Shape-Bias = {shape_c2vkd_acc*100:.2f}%")


Evaluating the model
Batch 1/313 - Accuracy: 9.38%
Batch 11/313 - Accuracy: 15.06%
Batch 21/313 - Accuracy: 15.33%
Batch 31/313 - Accuracy: 15.93%
Batch 41/313 - Accuracy: 16.69%
Batch 51/313 - Accuracy: 16.48%
Batch 61/313 - Accuracy: 16.19%
Batch 71/313 - Accuracy: 16.33%
Batch 81/313 - Accuracy: 16.20%
Batch 91/313 - Accuracy: 16.14%
Batch 101/313 - Accuracy: 16.09%
Batch 111/313 - Accuracy: 15.68%
Batch 121/313 - Accuracy: 15.39%
Batch 131/313 - Accuracy: 15.31%
Batch 141/313 - Accuracy: 15.38%
Batch 151/313 - Accuracy: 15.62%
Batch 161/313 - Accuracy: 15.95%
Batch 171/313 - Accuracy: 15.94%
Batch 181/313 - Accuracy: 15.80%
Batch 191/313 - Accuracy: 15.77%
Batch 201/313 - Accuracy: 15.70%
Batch 211/313 - Accuracy: 15.68%
Batch 221/313 - Accuracy: 15.77%
Batch 231/313 - Accuracy: 15.69%
Batch 241/313 - Accuracy: 15.64%
Batch 251/313 - Accuracy: 15.70%
Batch 261/313 - Accuracy: 15.72%
Batch 271/313 - Accuracy: 15.82%
Batch 281/313 - Accuracy: 15.87%
Batch 291/313 - Accuracy: 15.85%
B

In [24]:
# Evaluate on scrambled dataset
scrambled_c2vkd_acc = evaluate_model(student_c2vkd, scrambled_loader, device)
print(f"Student (C2VKD) accuracy on Scrambled = {scrambled_c2vkd_acc*100:.2f}%")


Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 31.25%
Batch 11/313 - Accuracy: 23.58%
Batch 21/313 - Accuracy: 23.66%
Batch 31/313 - Accuracy: 23.89%
Batch 41/313 - Accuracy: 24.16%
Batch 51/313 - Accuracy: 23.28%
Batch 61/313 - Accuracy: 24.03%
Batch 71/313 - Accuracy: 24.08%
Batch 81/313 - Accuracy: 23.65%
Batch 91/313 - Accuracy: 23.87%
Batch 101/313 - Accuracy: 23.92%
Batch 111/313 - Accuracy: 24.16%
Batch 121/313 - Accuracy: 23.76%
Batch 131/313 - Accuracy: 23.78%
Batch 141/313 - Accuracy: 23.76%
Batch 151/313 - Accuracy: 23.97%
Batch 161/313 - Accuracy: 24.11%
Batch 171/313 - Accuracy: 23.99%
Batch 181/313 - Accuracy: 24.02%
Batch 191/313 - Accuracy: 24.03%
Batch 201/313 - Accuracy: 23.96%
Batch 211/313 - Accuracy: 23.92%
Batch 221/313 - Accuracy: 24.05%
Batch 231/313 - Accuracy: 24.07%
Batch 241/313 - Accuracy: 24.08%
Batch 251/313 - Accuracy: 24.17%
Batch 261/313 - Accuracy: 24.26%
Batch 271/313 - Accuracy: 24.19%
Batch 281/313 - Accuracy: 24.34%
Batch 291/313 - Accuracy: 24.45%
Batch 301/313 - Accur

In [25]:
# Evaluate on noisy dataset
noisy_c2vkd_acc = evaluate_model(student_c2vkd, noisy_loader, device)
print(f"Student (C2VKD) accuracy on Noisy = {noisy_c2vkd_acc*100:.2f}%")


Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 59.38%
Batch 11/313 - Accuracy: 53.98%
Batch 21/313 - Accuracy: 55.80%
Batch 31/313 - Accuracy: 56.05%
Batch 41/313 - Accuracy: 55.64%
Batch 51/313 - Accuracy: 56.50%
Batch 61/313 - Accuracy: 57.02%
Batch 71/313 - Accuracy: 55.55%
Batch 81/313 - Accuracy: 55.44%
Batch 91/313 - Accuracy: 55.87%
Batch 101/313 - Accuracy: 56.50%
Batch 111/313 - Accuracy: 56.98%
Batch 121/313 - Accuracy: 57.33%
Batch 131/313 - Accuracy: 57.16%
Batch 141/313 - Accuracy: 57.71%
Batch 151/313 - Accuracy: 57.35%
Batch 161/313 - Accuracy: 57.10%
Batch 171/313 - Accuracy: 57.13%
Batch 181/313 - Accuracy: 57.08%
Batch 191/313 - Accuracy: 56.87%
Batch 201/313 - Accuracy: 56.72%
Batch 211/313 - Accuracy: 56.69%
Batch 221/313 - Accuracy: 56.72%
Batch 231/313 - Accuracy: 56.78%
Batch 241/313 - Accuracy: 56.78%
Batch 251/313 - Accuracy: 56.80%
Batch 261/313 - Accuracy: 56.60%
Batch 271/313 - Accuracy: 56.58%
Batch 281/313 - Accuracy: 56.62%
Batch 291/313 - Accuracy: 56.57%
Batch 301/313 - Accur

In [26]:
# Evaluate on super-pixelated dataset
super_pixelated_c2vkd_acc = evaluate_model(student_c2vkd, superpixel_loader, device)
print(f"Student (C2VKD) accuracy on Super Pixelated = {super_pixelated_c2vkd_acc*100:.2f}%")


Evaluating the model
Batch 1/313 - Accuracy: 40.62%
Batch 11/313 - Accuracy: 30.97%
Batch 21/313 - Accuracy: 30.80%
Batch 31/313 - Accuracy: 29.64%
Batch 41/313 - Accuracy: 30.87%
Batch 51/313 - Accuracy: 30.21%
Batch 61/313 - Accuracy: 29.76%
Batch 71/313 - Accuracy: 29.97%
Batch 81/313 - Accuracy: 29.90%
Batch 91/313 - Accuracy: 30.73%
Batch 101/313 - Accuracy: 30.66%
Batch 111/313 - Accuracy: 30.63%
Batch 121/313 - Accuracy: 30.91%
Batch 131/313 - Accuracy: 31.01%
Batch 141/313 - Accuracy: 31.29%
Batch 151/313 - Accuracy: 31.39%
Batch 161/313 - Accuracy: 31.56%
Batch 171/313 - Accuracy: 31.69%
Batch 181/313 - Accuracy: 31.66%
Batch 191/313 - Accuracy: 31.82%
Batch 201/313 - Accuracy: 31.78%
Batch 211/313 - Accuracy: 31.68%
Batch 221/313 - Accuracy: 31.70%
Batch 231/313 - Accuracy: 31.76%
Batch 241/313 - Accuracy: 31.66%
Batch 251/313 - Accuracy: 31.72%
Batch 261/313 - Accuracy: 31.56%
Batch 271/313 - Accuracy: 31.52%
Batch 281/313 - Accuracy: 31.53%
Batch 291/313 - Accuracy: 31.59%


# CSKD

In [25]:
import torch.nn as nn

class ViTWithDenseOutputs(nn.Module):
    def __init__(self, base_model, num_classes=10):
        super(ViTWithDenseOutputs, self).__init__()
        self.base_model = base_model
        self.head = nn.Linear(base_model.head.in_features, num_classes)
        self.dense_head = nn.Linear(base_model.head.in_features, num_classes)

    def forward(self, x):
        # Get all tokens ([CLS] + patch tokens)
        all_tokens = self.base_model.forward_features(x)  
        # For vit_base_patch16_224, shape = [B, 197, C] if no distillation

        # The first token is CLS, the rest are patch tokens:
        cls_token = all_tokens[:, 0, :]      # shape [B, C]
        patch_tokens = all_tokens[:, 1:, :]  # shape [B, 196, C]

        # Global logits from CLS
        global_logits = self.head(cls_token)

        # Dense logits from patch tokens
        dense_logits = self.dense_head(patch_tokens)  # shape [B, 196, num_classes]

        return global_logits, global_logits, dense_logits




def create_vit_model(num_classes=10):
    """Create a ViT model with a dynamic number of output classes."""
    base_model = create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)

    # Freeze all layers in the ViT model
    for param in base_model.parameters():
        param.requires_grad = False

    # Replace the classifier head with a custom one
    base_model.head = torch.nn.Linear(base_model.head.in_features, num_classes)

    # Wrap the model to include dense outputs
    vit_model = ViTWithDenseOutputs(base_model, num_classes=num_classes)
    return vit_model


In [26]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CSKDLoss(nn.Module):
    """
    Cumulative Spatial Knowledge Distillation Loss
    """
    def __init__(self, cfg, criterion, teacher):
        super(CSKDLoss, self).__init__()
        self.cfg = cfg  # Configuration object
        self.criterion = criterion  # Base criterion (e.g., CrossEntropyLoss)
        self.teacher = teacher  # Teacher model

    def forward(self, inputs, outputs, labels, epoch, max_epoch):
        # If the model outputs multiple tensors (e.g., logits and dense logits)
        if not isinstance(outputs, torch.Tensor):
            outputs, stu_deit_logits, stu_dense_logits = outputs
        else:
            raise ValueError("Expected multiple outputs including dense logits from the student model.")

        # Ensure global logits match the teacher's logits in shape
        stu_deit_logits = F.log_softmax(stu_deit_logits, dim=-1)

        # Base classification loss
        loss_base = self.criterion(outputs, labels)

        # No distillation loss if configured
        if self.cfg.deit_loss_type == "none":
            return loss_base

        # Teacher predictions (global and dense logits)
        with torch.no_grad():
            tea_dense_logits, tea_global_logits = self.teacher(inputs)

        # DeiT-based distillation loss
        loss_deit = self.get_loss_deit(stu_deit_logits, tea_global_logits)

        # Cumulative spatial knowledge distillation loss
        loss_cskd = self.get_loss_cskd(
            stu_dense_logits, tea_dense_logits, tea_global_logits, epoch, max_epoch
        )

        # Weighted combination of losses
        alpha = self.cfg.deit_alpha
        total_loss = (
            loss_base * (1 - alpha)
            + loss_deit * alpha
            + loss_cskd * self.cfg.cksd_loss_weight
        )
        return total_loss
        
    def align_stu_logits(self, stu_dense_logits):
        """
        Align student logits to match teacher logits' spatial resolution.
        """
        N, M, C = stu_dense_logits.shape  # Batch size, number of patches, num_classes
    
        # Dynamically compute H and W based on the number of patches
        H = int(math.sqrt(M)) if math.sqrt(M).is_integer() else int(math.ceil(math.sqrt(M)))
        W = int(M / H)
    
        if H * W != M:
            raise ValueError(f"Number of patches {M} cannot be reshaped into a valid grid (H={H}, W={W}).")
    
        # Reshape and align
        stu_dense_logits = stu_dense_logits.permute(0, 2, 1).reshape(N, C, H, W)
        stu_dense_logits = F.avg_pool2d(stu_dense_logits, kernel_size=2, stride=2)
        return stu_dense_logits



    def get_decay_ratio(self, epoch, max_epoch):
        """
        Compute the decay ratio for combining dense and global teacher logits.
        """
        x = epoch / max_epoch
        if self.cfg.cskd_decay_func == "linear":
            return 1 - x
        elif self.cfg.cskd_decay_func == "x2":
            return (1 - x) ** 2
        elif self.cfg.cskd_decay_func == "cos":
            return math.cos(math.pi * 0.5 * x)
        else:
            raise NotImplementedError(f"Decay function '{self.cfg.cskd_decay_func}' not implemented.")

    def get_loss_deit(self, stu_deit_logits, tea_global_logits):
        """
        Compute the DeiT distillation loss.
        """
        if self.cfg.deit_loss_type == "soft":
            T = self.cfg.deit_tau  # Temperature
            loss_deit = F.kl_div(
                F.log_softmax(stu_deit_logits / T, dim=1),
                F.log_softmax(tea_global_logits / T, dim=1),
                reduction="sum",
                log_target=True,
            ) * (T * T) / stu_deit_logits.numel()
        elif self.cfg.deit_loss_type == "hard":
            loss_deit = F.cross_entropy(
                stu_deit_logits, tea_global_logits.argmax(dim=1)
            )
        else:
            raise NotImplementedError(f"DeiT loss type '{self.cfg.deit_loss_type}' not implemented.")
        return loss_deit

    def get_loss_cskd(self, stu_dense_logits, tea_dense_logits, tea_global_logits, epoch, max_epoch):
        """
        Compute the cumulative spatial knowledge distillation loss.
        """
        # Align student logits to teacher's spatial resolution
        stu_dense_logits = self.align_stu_logits(stu_dense_logits)

        # Compute the decay ratio
        decay_ratio = self.get_decay_ratio(epoch, max_epoch)

        # Weighted combination of dense and global teacher logits
        N, C = tea_global_logits.shape
        teacher_logits = (
            decay_ratio * tea_dense_logits
            + (1 - decay_ratio) * tea_global_logits.reshape(N, C, 1, 1)
        )

        # CSKD loss
        if self.cfg.deit_loss_type == "hard":
            loss_cskd = F.cross_entropy(
                stu_dense_logits, teacher_logits.argmax(dim=1)
            )
        elif self.cfg.deit_loss_type == "soft":
            T = self.cfg.deit_tau  # Temperature
            loss_cskd = F.kl_div(
                F.log_softmax(stu_dense_logits / T, dim=1),
                F.log_softmax(teacher_logits / T, dim=1),
                reduction="sum",
                log_target=True,
            ) * (T * T) / stu_dense_logits.size(0)
        else:
            raise NotImplementedError(f"CSKD loss type '{self.cfg.cskd_loss_type}' not implemented.")
        return loss_cskd


In [27]:
def train_student_with_cskd(student, teacher, train_loader, cfg):
    """
    Train a student model using CSKD with configurations from a ConfigBase class.

    Args:
        student: The student model (e.g., Vision Transformer).
        teacher: The teacher model (e.g., ResNet).
        train_loader: DataLoader for training data.
        cfg: Configuration object from a ConfigBase class instance.
    """
    student.train()
    teacher.eval()

    # Initialize optimizer and loss functions from config
    optimizer = torch.optim.Adam(student.parameters(), lr=cfg.learning_rate)
    criterion = nn.CrossEntropyLoss()
    cskd_loss = CSKDLoss(cfg, criterion, teacher)

    for epoch in range(cfg.num_epochs):
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.num_epochs}"):
            inputs, labels = inputs.to(cfg.device), labels.to(cfg.device)
            optimizer.zero_grad()

            # Forward pass
            outputs = student(inputs)

            # Compute CSKD loss
            loss = cskd_loss(inputs, outputs, labels, epoch, cfg.num_epochs)

            # Backpropagation
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{cfg.num_epochs}], Loss: {running_loss / len(train_loader):.4f}")


In [28]:
from functools import lru_cache
from hashlib import md5
from pprint import pprint

class ConfigBase:
    @classmethod
    @lru_cache(maxsize=1)
    def to_dict(cls):
        keys = dir(cls)
        hp_dict = {}
        for key in keys:
            value = getattr(cls, key)
            if (
                not key.startswith("__")
                and not key.endswith("__")
                and not callable(value)
            ):
                hp_dict[key] = value
        return hp_dict

    @classmethod
    @lru_cache(maxsize=1)
    def to_md5(cls):
        hp_dict = cls.to_dict()
        return md5(f"{hp_dict}".encode("utf-8")).hexdigest()

    @classmethod
    @lru_cache(maxsize=1)
    def instance(cls):
        return cls()

    @classmethod
    def print(cls):
        print(f"hash (md5): {cls.to_md5()}")
        pprint(cls.to_dict())


In [29]:
class MyConfig(ConfigBase):
    # General settings
    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset settings
    dataset_path = "./data"
    batch_size = 32
    num_workers = 4

    # Training settings
    learning_rate = 0.001
    num_epochs = 1
    deit_alpha = 0.5
    cksd_loss_weight = 1.0
    deit_loss_type = "soft"  # or "hard"
    deit_tau = 2.0
    cskd_decay_func = "linear"  # "x2" or "cos"

    # Model settings
    model_name = "vit_base_patch16_224"
    teacher_model_name = "resnet50"


In [30]:
# Load configuration
cfg = MyConfig.instance()

# Print the configuration
cfg.print()

# Access configuration parameters
device = cfg.device
batch_size = cfg.batch_size
learning_rate = cfg.learning_rate

print(f"Running on device: {device}")


hash (md5): 1db02d5f2b9fb6deb889c8c31bba3e3c
{'batch_size': 32,
 'cksd_loss_weight': 1.0,
 'cskd_decay_func': 'linear',
 'dataset_path': './data',
 'deit_alpha': 0.5,
 'deit_loss_type': 'soft',
 'deit_tau': 2.0,
 'device': 'cuda',
 'learning_rate': 0.001,
 'model_name': 'vit_base_patch16_224',
 'num_epochs': 1,
 'num_workers': 4,
 'seed': 42,
 'teacher_model_name': 'resnet50'}
Running on device: cuda


In [31]:
import torchvision.models as models

class ResNetWithDenseOutputs(nn.Module):
    def __init__(self, base_model, num_classes=10):
        super().__init__()
        self.base_model = nn.Sequential(*list(base_model.children())[:-2]) 
          # remove avgpool + fc
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = base_model.fc            # the standard classification fc for global logits

        # NEW: 1×1 conv to map 2048 → num_classes
        self.dense_head = nn.Conv2d(2048, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        # features => [B, 2048, 7, 7]
        features = self.base_model(x)

        # Global logits => [B, 10]
        global_logits = self.fc(
            self.global_pool(features).squeeze(-1).squeeze(-1)
        )

        # Dense logits => [B, 10, 7, 7]
        dense_logits = self.dense_head(features)

        return dense_logits, global_logits

        return features, global_logits
def create_resnet_model(num_classes=10):
    base_model = models.resnet50(pretrained=True)
    base_model.fc = nn.Linear(base_model.fc.in_features, num_classes)  # Update the classification head
    return ResNetWithDenseOutputs(base_model)

def get_model(name):
    if name == 'student':
        vit_model = create_vit_model()
        return vit_model
    elif name == 'teacher':
        resnet_model = create_resnet_model(num_classes=10)
        return resnet_model
    else:
        raise ValueError(f"Unsupported model name: {name}")


In [136]:
def finetune_model(model, train_loader, num_epochs=10, alpha=1e-3):
    """
    Fine-tune the model on the training data.
    If model(inputs) returns (dense_logits, global_logits), we only
    use the global_logits for cross-entropy.
    """
    print("Finetuning the model on this data")
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=alpha)
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_index, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # `model(inputs)` may return a tuple. Let's unpack it:
            outputs_tuple = model(inputs)

            # If the teacher model returns (dense_logits, global_logits):
            #   outputs = outputs_tuple[0]  # dense_logits
            #   global_logits = outputs_tuple[1]
            # If you want to train on global logits, do:
            if isinstance(outputs_tuple, tuple):
                # Suppose the second element is the global logits
                outputs = outputs_tuple[1]
            else:
                # otherwise, if the model returns a single tensor
                outputs = outputs_tuple
            
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Print info periodically
            if batch_index % 300 == 0 or batch_index == 750:
                print(f'Batch {batch_index + 1}/{num_batches} - '
                      f'Loss: {loss.item():.4f} Accuracy: {(correct/total) * 100:.2f}%')

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct / total
        print(f'------------------------->Epoch [{epoch + 1}/{num_epochs}], '
              f'Loss: {epoch_loss:.4f} Accuracy : {epoch_accuracy * 100:.2f}%')

    return epoch_loss, epoch_accuracy


In [137]:
# Create the teacher model
teacher_model = get_model('teacher').to(cfg.device)

# 2. Finetune the teacher on CIFAR-10
teacher_loss, teacher_accuracy = finetune_model(
    teacher_model, 
    train_loader,     # CIFAR-10 train_loader
    num_epochs=10, 
    alpha=1e-3
)

print(f"Teacher finetuned on CIFAR-10 with final loss={teacher_loss:.4f}, accuracy={teacher_accuracy*100:.2f}%")

torch.save(teacher_model.state_dict(), 'teacher_model_cskd_weights_after_finetuning.pth')


print("Teacher model weights saved successfully!")



Finetuning the model on this data
Batch 1/1563 - Loss: 2.2401 Accuracy: 15.62%
Batch 301/1563 - Loss: 1.7760 Accuracy: 49.87%
Batch 601/1563 - Loss: 0.8437 Accuracy: 57.11%
Batch 751/1563 - Loss: 0.7100 Accuracy: 59.60%
Batch 901/1563 - Loss: 0.8096 Accuracy: 61.55%
Batch 1201/1563 - Loss: 0.5885 Accuracy: 64.45%
Batch 1501/1563 - Loss: 0.6351 Accuracy: 66.52%
------------------------->Epoch [1/10], Loss: 0.9409 Accuracy : 66.92%
Batch 1/1563 - Loss: 0.6178 Accuracy: 84.38%
Batch 301/1563 - Loss: 0.5414 Accuracy: 78.85%
Batch 601/1563 - Loss: 0.4257 Accuracy: 79.61%
Batch 751/1563 - Loss: 0.3871 Accuracy: 79.56%
Batch 901/1563 - Loss: 0.5903 Accuracy: 79.82%
Batch 1201/1563 - Loss: 0.4419 Accuracy: 80.41%
Batch 1501/1563 - Loss: 0.4739 Accuracy: 80.86%
------------------------->Epoch [2/10], Loss: 0.5563 Accuracy : 80.94%
Batch 1/1563 - Loss: 0.3159 Accuracy: 90.62%
Batch 301/1563 - Loss: 0.5521 Accuracy: 84.74%
Batch 601/1563 - Loss: 0.4509 Accuracy: 84.93%
Batch 751/1563 - Loss: 0.37

In [152]:
from torch.utils.data import Subset
import torch

# Get the original dataset from the DataLoader
train_dataset = train_loader.dataset

# Create a subset of the dataset with only the first 1000 images
subset_indices = torch.arange(5000)
train_subset = Subset(train_dataset, subset_indices)

# Create a DataLoader for the subset
subset_loader = torch.utils.data.DataLoader(train_subset, batch_size=cfg.batch_size, shuffle=True)

# Train the student model using the subset DataLoader
student_model_test = get_model('student')
student_model_test = student_model_test.to(device)

train_student_with_cskd(student_model_test, teacher_model, subset_loader, cfg)

Epoch 1/1: 100%|██████████| 157/157 [01:10<00:00,  2.22it/s]

Epoch [1/1], Loss: 6.0566





In [153]:
# Initialize models using config
# student_model_cskd = create_model(cfg.model_name, pretrained=True, num_classes=10).to(cfg.device)

student_model_cskd = get_model('student')
student_model_cskd = student_model_cskd.to(device)

train_student_with_cskd(student_model_cskd, teacher_model, train_loader, cfg)

# Define paths to save the model weights and features
weights_save_path = "student_model_cskd_weights.pth"
features_save_path = "student_model_cskd_features.pth"

# Save the model weights
torch.save(student_model_cskd.state_dict(), weights_save_path)
print(f"Model weights saved to {weights_save_path}")



Epoch 1/1: 100%|██████████| 1563/1563 [12:34<00:00,  2.07it/s]


Epoch [1/1], Loss: 1.7197
Model weights saved to student_model_cskd_weights.pth


In [32]:
def evaluate_cskd_model(model, data_loader, device):
    """Evaluate the models using a basic testing loop"""
    print("Evaluating the model")
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    num_batches = len(data_loader)  # Get total number of batches

    with torch.no_grad():
        for batch_index, (images, labels) in enumerate(data_loader):

            batch_size = images.size(0)
            if batch_size < 32:
                print(f"Skipping last batch of size {batch_size} --- gives shape error")
                break

            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Handle tuple outputs
            if isinstance(outputs, tuple):
                outputs = outputs[0]  # Extract global_logits

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Print progress every 10 batches
            if batch_index % 10 == 0:
                print(f'Batch {batch_index + 1}/{num_batches} - Accuracy: { (correct/total) * 100:.2f}%')

    accuracy = correct / total
    print(f'Final Accuracy: {accuracy * 100:.2f}%')
    return accuracy


In [33]:
teacher_model = get_model('teacher').to(cfg.device)

student_model_cskd = get_model('student')
student_model_cskd = student_model_cskd.to(device)

student_model_cskd.load_state_dict(torch.load('/kaggle/input/cskd-distillation/student_model_cskd_weights.pth'))
teacher_model.load_state_dict(torch.load('/kaggle/input/cskd-distillation/teacher_model_cskd_weights_after_finetuning.pth'))


print("Student model weights and features loaded successfully!")



test_cskd_acc = evaluate_cskd_model(student_model_cskd, test_loader, device)
print(f"Student (CSKD) accuracy on test data = {test_cskd_acc*100:.2f}%")



  student_model_cskd.load_state_dict(torch.load('/kaggle/input/cskd-distillation/student_model_cskd_weights.pth'))
  teacher_model.load_state_dict(torch.load('/kaggle/input/cskd-distillation/teacher_model_cskd_weights_after_finetuning.pth'))


Student model weights and features loaded successfully!
Evaluating the model
Batch 1/313 - Accuracy: 100.00%
Batch 11/313 - Accuracy: 94.89%
Batch 21/313 - Accuracy: 95.09%
Batch 31/313 - Accuracy: 95.46%
Batch 41/313 - Accuracy: 95.43%
Batch 51/313 - Accuracy: 95.22%
Batch 61/313 - Accuracy: 95.34%
Batch 71/313 - Accuracy: 95.33%
Batch 81/313 - Accuracy: 95.25%
Batch 91/313 - Accuracy: 95.05%
Batch 101/313 - Accuracy: 94.93%
Batch 111/313 - Accuracy: 94.88%
Batch 121/313 - Accuracy: 94.89%
Batch 131/313 - Accuracy: 94.99%
Batch 141/313 - Accuracy: 95.17%
Batch 151/313 - Accuracy: 94.99%
Batch 161/313 - Accuracy: 94.95%
Batch 171/313 - Accuracy: 94.99%
Batch 181/313 - Accuracy: 95.04%
Batch 191/313 - Accuracy: 95.08%
Batch 201/313 - Accuracy: 95.12%
Batch 211/313 - Accuracy: 95.10%
Batch 221/313 - Accuracy: 95.19%
Batch 231/313 - Accuracy: 95.25%
Batch 241/313 - Accuracy: 95.20%
Batch 251/313 - Accuracy: 95.17%
Batch 261/313 - Accuracy: 95.15%
Batch 271/313 - Accuracy: 95.18%
Batch 281

In [161]:
# Evaluate on texture-bias dataset
texture_cskd_acc = evaluate_cskd_model(student_model_cskd, texture_bias_loader, device)
print(f"Student (CSKD) accuracy on Texture-Bias = {texture_cskd_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 34.38%
Batch 11/313 - Accuracy: 36.93%
Batch 21/313 - Accuracy: 39.14%
Batch 31/313 - Accuracy: 40.12%
Batch 41/313 - Accuracy: 40.32%
Batch 51/313 - Accuracy: 41.24%
Batch 61/313 - Accuracy: 41.29%
Batch 71/313 - Accuracy: 41.55%
Batch 81/313 - Accuracy: 40.93%
Batch 91/313 - Accuracy: 40.76%
Batch 101/313 - Accuracy: 40.32%
Batch 111/313 - Accuracy: 40.40%
Batch 121/313 - Accuracy: 40.70%
Batch 131/313 - Accuracy: 40.65%
Batch 141/313 - Accuracy: 40.89%
Batch 151/313 - Accuracy: 40.79%
Batch 161/313 - Accuracy: 40.97%
Batch 171/313 - Accuracy: 40.95%
Batch 181/313 - Accuracy: 40.71%
Batch 191/313 - Accuracy: 40.76%
Batch 201/313 - Accuracy: 40.97%
Batch 211/313 - Accuracy: 40.89%
Batch 221/313 - Accuracy: 40.71%
Batch 231/313 - Accuracy: 40.91%
Batch 241/313 - Accuracy: 40.68%
Batch 251/313 - Accuracy: 40.74%
Batch 261/313 - Accuracy: 40.76%
Batch 271/313 - Accuracy: 40.56%
Batch 281/313 - Accuracy: 40.54%
Batch 291/313 - Accuracy: 40.52%


In [162]:
# Evaluate on shape-bias dataset
shape_cskd_acc = evaluate_cskd_model(student_model_cskd, shape_bias_loader, device)
print(f"Student (CSKD) accuracy on Shape-Bias = {shape_cskd_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 15.62%
Batch 11/313 - Accuracy: 16.19%
Batch 21/313 - Accuracy: 17.71%
Batch 31/313 - Accuracy: 19.25%
Batch 41/313 - Accuracy: 18.06%
Batch 51/313 - Accuracy: 17.22%
Batch 61/313 - Accuracy: 17.16%
Batch 71/313 - Accuracy: 17.25%
Batch 81/313 - Accuracy: 17.01%
Batch 91/313 - Accuracy: 16.79%
Batch 101/313 - Accuracy: 16.74%
Batch 111/313 - Accuracy: 16.55%
Batch 121/313 - Accuracy: 16.24%
Batch 131/313 - Accuracy: 16.70%
Batch 141/313 - Accuracy: 16.71%
Batch 151/313 - Accuracy: 16.72%
Batch 161/313 - Accuracy: 16.71%
Batch 171/313 - Accuracy: 16.89%
Batch 181/313 - Accuracy: 16.95%
Batch 191/313 - Accuracy: 17.03%
Batch 201/313 - Accuracy: 17.06%
Batch 211/313 - Accuracy: 17.15%
Batch 221/313 - Accuracy: 16.97%
Batch 231/313 - Accuracy: 17.00%
Batch 241/313 - Accuracy: 17.01%
Batch 251/313 - Accuracy: 17.11%
Batch 261/313 - Accuracy: 16.97%
Batch 271/313 - Accuracy: 16.85%
Batch 281/313 - Accuracy: 16.90%
Batch 291/313 - Accuracy: 16.92%


In [163]:
# Evaluate on scrambled dataset
scrambled_cskd_acc = evaluate_cskd_model(student_model_cskd, scrambled_loader, device)
print(f"Student (CSKD) accuracy on Scrambled = {scrambled_cskd_acc*100:.2f}%")

Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 31.25%
Batch 11/313 - Accuracy: 24.15%
Batch 21/313 - Accuracy: 22.47%
Batch 31/313 - Accuracy: 22.08%
Batch 41/313 - Accuracy: 22.18%
Batch 51/313 - Accuracy: 21.38%
Batch 61/313 - Accuracy: 20.54%
Batch 71/313 - Accuracy: 20.47%
Batch 81/313 - Accuracy: 20.49%
Batch 91/313 - Accuracy: 20.36%
Batch 101/313 - Accuracy: 20.58%
Batch 111/313 - Accuracy: 20.47%
Batch 121/313 - Accuracy: 20.14%
Batch 131/313 - Accuracy: 20.16%
Batch 141/313 - Accuracy: 20.48%
Batch 151/313 - Accuracy: 20.84%
Batch 161/313 - Accuracy: 20.87%
Batch 171/313 - Accuracy: 21.13%
Batch 181/313 - Accuracy: 21.22%
Batch 191/313 - Accuracy: 21.42%
Batch 201/313 - Accuracy: 21.47%
Batch 211/313 - Accuracy: 21.53%
Batch 221/313 - Accuracy: 21.56%
Batch 231/313 - Accuracy: 21.48%
Batch 241/313 - Accuracy: 21.50%
Batch 251/313 - Accuracy: 21.51%
Batch 261/313 - Accuracy: 21.55%
Batch 271/313 - Accuracy: 21.63%
Batch 281/313 - Accuracy: 21.70%
Batch 291/313 - Accuracy: 21.74%
Batch 301/313 - Accur

In [164]:
# Evaluate on noisy dataset
noisy_cskd_acc = evaluate_cskd_model(student_model_cskd, noisy_loader, device)
print(f"Student (CSKD) accuracy on Noisy = {noisy_cskd_acc*100:.2f}%")


Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 68.75%
Batch 11/313 - Accuracy: 69.60%
Batch 21/313 - Accuracy: 71.28%
Batch 31/313 - Accuracy: 70.77%
Batch 41/313 - Accuracy: 70.81%
Batch 51/313 - Accuracy: 71.08%
Batch 61/313 - Accuracy: 70.54%
Batch 71/313 - Accuracy: 69.63%
Batch 81/313 - Accuracy: 69.48%
Batch 91/313 - Accuracy: 69.78%
Batch 101/313 - Accuracy: 70.39%
Batch 111/313 - Accuracy: 70.64%
Batch 121/313 - Accuracy: 70.74%
Batch 131/313 - Accuracy: 70.30%
Batch 141/313 - Accuracy: 70.70%
Batch 151/313 - Accuracy: 70.51%
Batch 161/313 - Accuracy: 70.42%
Batch 171/313 - Accuracy: 70.43%
Batch 181/313 - Accuracy: 70.46%
Batch 191/313 - Accuracy: 70.44%
Batch 201/313 - Accuracy: 70.40%
Batch 211/313 - Accuracy: 70.26%
Batch 221/313 - Accuracy: 70.12%
Batch 231/313 - Accuracy: 70.17%
Batch 241/313 - Accuracy: 70.10%
Batch 251/313 - Accuracy: 70.23%
Batch 261/313 - Accuracy: 70.08%
Batch 271/313 - Accuracy: 70.01%
Batch 281/313 - Accuracy: 70.04%
Batch 291/313 - Accuracy: 69.88%
Batch 301/313 - Accur

In [165]:
# Evaluate on super-pixelated dataset
super_pixelated_cskd_acc = evaluate_cskd_model(student_model_cskd, superpixel_loader, device)
print(f"Student (CSKD) accuracy on Super Pixelated = {super_pixelated_cskd_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 46.88%
Batch 11/313 - Accuracy: 38.35%
Batch 21/313 - Accuracy: 36.90%
Batch 31/313 - Accuracy: 36.90%
Batch 41/313 - Accuracy: 37.20%
Batch 51/313 - Accuracy: 37.56%
Batch 61/313 - Accuracy: 37.19%
Batch 71/313 - Accuracy: 37.06%
Batch 81/313 - Accuracy: 37.15%
Batch 91/313 - Accuracy: 37.71%
Batch 101/313 - Accuracy: 37.84%
Batch 111/313 - Accuracy: 38.37%
Batch 121/313 - Accuracy: 38.09%
Batch 131/313 - Accuracy: 38.22%
Batch 141/313 - Accuracy: 38.14%
Batch 151/313 - Accuracy: 38.43%
Batch 161/313 - Accuracy: 38.47%
Batch 171/313 - Accuracy: 38.36%
Batch 181/313 - Accuracy: 38.36%
Batch 191/313 - Accuracy: 38.35%
Batch 201/313 - Accuracy: 38.46%
Batch 211/313 - Accuracy: 38.33%
Batch 221/313 - Accuracy: 38.24%
Batch 231/313 - Accuracy: 38.34%
Batch 241/313 - Accuracy: 38.28%
Batch 251/313 - Accuracy: 38.23%
Batch 261/313 - Accuracy: 37.94%
Batch 271/313 - Accuracy: 37.65%
Batch 281/313 - Accuracy: 37.88%
Batch 291/313 - Accuracy: 37.85%


# Evaluating Resnet (Teaher)

In [11]:
test_teacher_acc = evaluate_model(teacher_model, test_loader, device)
print(f"Teacher accuracy on Texture-Bias dataset = {test_teacher_acc*100:.2f}%")


Evaluating the model
Batch 1/313 - Accuracy: 87.50%
Batch 11/313 - Accuracy: 80.97%
Batch 21/313 - Accuracy: 81.25%
Batch 31/313 - Accuracy: 81.25%
Batch 41/313 - Accuracy: 81.25%
Batch 51/313 - Accuracy: 81.25%
Batch 61/313 - Accuracy: 81.25%
Batch 71/313 - Accuracy: 81.16%
Batch 81/313 - Accuracy: 80.83%
Batch 91/313 - Accuracy: 80.63%
Batch 101/313 - Accuracy: 80.63%
Batch 111/313 - Accuracy: 80.43%
Batch 121/313 - Accuracy: 80.60%
Batch 131/313 - Accuracy: 80.77%
Batch 141/313 - Accuracy: 81.07%
Batch 151/313 - Accuracy: 81.08%
Batch 161/313 - Accuracy: 81.23%
Batch 171/313 - Accuracy: 81.29%
Batch 181/313 - Accuracy: 81.35%
Batch 191/313 - Accuracy: 81.36%
Batch 201/313 - Accuracy: 81.55%
Batch 211/313 - Accuracy: 81.61%
Batch 221/313 - Accuracy: 81.50%
Batch 231/313 - Accuracy: 81.70%
Batch 241/313 - Accuracy: 81.77%
Batch 251/313 - Accuracy: 81.80%
Batch 261/313 - Accuracy: 81.76%
Batch 271/313 - Accuracy: 81.76%
Batch 281/313 - Accuracy: 81.76%
Batch 291/313 - Accuracy: 81.83%


In [12]:
# Evaluate on shape-bias dataset
shape_teacher_acc = evaluate_model(teacher_model, shape_bias_loader, device)
print(f"Student (CSKD) accuracy on Shape-Bias = {shape_teacher_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 12.50%
Batch 11/313 - Accuracy: 10.80%
Batch 21/313 - Accuracy: 10.42%
Batch 31/313 - Accuracy: 9.88%
Batch 41/313 - Accuracy: 9.68%
Batch 51/313 - Accuracy: 9.19%
Batch 61/313 - Accuracy: 9.58%
Batch 71/313 - Accuracy: 9.64%
Batch 81/313 - Accuracy: 9.68%
Batch 91/313 - Accuracy: 9.86%
Batch 101/313 - Accuracy: 9.81%
Batch 111/313 - Accuracy: 9.71%
Batch 121/313 - Accuracy: 9.53%
Batch 131/313 - Accuracy: 9.64%
Batch 141/313 - Accuracy: 9.73%
Batch 151/313 - Accuracy: 9.60%
Batch 161/313 - Accuracy: 9.55%
Batch 171/313 - Accuracy: 9.61%
Batch 181/313 - Accuracy: 9.62%
Batch 191/313 - Accuracy: 9.57%
Batch 201/313 - Accuracy: 9.64%
Batch 211/313 - Accuracy: 9.75%
Batch 221/313 - Accuracy: 9.79%
Batch 231/313 - Accuracy: 9.78%
Batch 241/313 - Accuracy: 9.79%
Batch 251/313 - Accuracy: 9.85%
Batch 261/313 - Accuracy: 9.69%
Batch 271/313 - Accuracy: 9.63%
Batch 281/313 - Accuracy: 9.63%
Batch 291/313 - Accuracy: 9.59%
Batch 301/313 - Accuracy: 9

In [13]:
# Evaluate on scrambled dataset
scrambled_teacher_acc = evaluate_model(teacher_model, scrambled_loader, device)
print(f"Student (CSKD) accuracy on Scrambled = {scrambled_teacher_acc*100:.2f}%")


Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 21.88%
Batch 11/313 - Accuracy: 24.72%
Batch 21/313 - Accuracy: 24.85%
Batch 31/313 - Accuracy: 24.90%
Batch 41/313 - Accuracy: 24.54%
Batch 51/313 - Accuracy: 23.96%
Batch 61/313 - Accuracy: 24.44%
Batch 71/313 - Accuracy: 24.21%
Batch 81/313 - Accuracy: 24.50%
Batch 91/313 - Accuracy: 24.38%
Batch 101/313 - Accuracy: 24.35%
Batch 111/313 - Accuracy: 24.52%
Batch 121/313 - Accuracy: 24.46%
Batch 131/313 - Accuracy: 24.28%
Batch 141/313 - Accuracy: 24.18%
Batch 151/313 - Accuracy: 24.23%
Batch 161/313 - Accuracy: 24.20%
Batch 171/313 - Accuracy: 24.10%
Batch 181/313 - Accuracy: 24.12%
Batch 191/313 - Accuracy: 24.20%
Batch 201/313 - Accuracy: 24.28%
Batch 211/313 - Accuracy: 24.32%
Batch 221/313 - Accuracy: 24.28%
Batch 231/313 - Accuracy: 24.23%
Batch 241/313 - Accuracy: 24.44%
Batch 251/313 - Accuracy: 24.38%
Batch 261/313 - Accuracy: 24.19%
Batch 271/313 - Accuracy: 24.04%
Batch 281/313 - Accuracy: 24.07%
Batch 291/313 - Accuracy: 24.12%
Batch 301/313 - Accur

In [14]:
# Evaluate on noisy dataset
noisy_teacher_acc = evaluate_model(teacher_model, noisy_loader, device)
print(f"Student (CSKD) accuracy on Noisy = {noisy_teacher_acc*100:.2f}%")

Evaluating the model


  image = torch.load(img_path)  # Load the .pt tensor file


Batch 1/313 - Accuracy: 18.75%
Batch 11/313 - Accuracy: 16.48%
Batch 21/313 - Accuracy: 15.62%
Batch 31/313 - Accuracy: 15.42%
Batch 41/313 - Accuracy: 15.85%
Batch 51/313 - Accuracy: 15.26%
Batch 61/313 - Accuracy: 15.52%
Batch 71/313 - Accuracy: 16.15%
Batch 81/313 - Accuracy: 16.24%
Batch 91/313 - Accuracy: 16.24%
Batch 101/313 - Accuracy: 16.31%
Batch 111/313 - Accuracy: 16.27%
Batch 121/313 - Accuracy: 16.14%
Batch 131/313 - Accuracy: 16.03%
Batch 141/313 - Accuracy: 16.16%
Batch 151/313 - Accuracy: 15.98%
Batch 161/313 - Accuracy: 15.74%
Batch 171/313 - Accuracy: 15.83%
Batch 181/313 - Accuracy: 15.68%
Batch 191/313 - Accuracy: 15.61%
Batch 201/313 - Accuracy: 15.72%
Batch 211/313 - Accuracy: 15.74%
Batch 221/313 - Accuracy: 15.79%
Batch 231/313 - Accuracy: 15.83%
Batch 241/313 - Accuracy: 15.77%
Batch 251/313 - Accuracy: 15.95%
Batch 261/313 - Accuracy: 15.78%
Batch 271/313 - Accuracy: 15.72%
Batch 281/313 - Accuracy: 15.69%
Batch 291/313 - Accuracy: 15.61%
Batch 301/313 - Accur

In [15]:
# Evaluate on super-pixelated dataset
superpixelated_teacher_acc = evaluate_model(teacher_model, superpixel_loader, device)
print(f"Student (CSKD) accuracy on Super Pixelated = {superpixelated_teacher_acc*100:.2f}%")

Evaluating the model
Batch 1/313 - Accuracy: 12.50%
Batch 11/313 - Accuracy: 15.06%
Batch 21/313 - Accuracy: 13.99%
Batch 31/313 - Accuracy: 13.61%
Batch 41/313 - Accuracy: 13.11%
Batch 51/313 - Accuracy: 12.56%
Batch 61/313 - Accuracy: 12.91%
Batch 71/313 - Accuracy: 12.90%
Batch 81/313 - Accuracy: 12.89%
Batch 91/313 - Accuracy: 13.19%
Batch 101/313 - Accuracy: 13.15%
Batch 111/313 - Accuracy: 13.29%
Batch 121/313 - Accuracy: 13.07%
Batch 131/313 - Accuracy: 13.00%
Batch 141/313 - Accuracy: 12.94%
Batch 151/313 - Accuracy: 12.83%
Batch 161/313 - Accuracy: 12.75%
Batch 171/313 - Accuracy: 12.88%
Batch 181/313 - Accuracy: 12.86%
Batch 191/313 - Accuracy: 12.89%
Batch 201/313 - Accuracy: 12.87%
Batch 211/313 - Accuracy: 12.99%
Batch 221/313 - Accuracy: 13.02%
Batch 231/313 - Accuracy: 12.96%
Batch 241/313 - Accuracy: 12.94%
Batch 251/313 - Accuracy: 13.02%
Batch 261/313 - Accuracy: 12.85%
Batch 271/313 - Accuracy: 12.78%
Batch 281/313 - Accuracy: 12.89%
Batch 291/313 - Accuracy: 12.89%
