# Accessing the datasets and checkpoints in Google Drive

In [None]:
import os
import re

# Working directory layout: one project root with a data folder and a
# checkpoint folder underneath it.
PROJECT_PATH = '/content/CatDogCNN2'
DATA_PATH = f'{PROJECT_PATH}/pets0'
CKPT_PATH = f'{PROJECT_PATH}/mods'

# makedirs also creates PROJECT_PATH itself as the parent of both folders;
# exist_ok=True makes re-running the cell harmless.
for _folder in (DATA_PATH, CKPT_PATH):
    os.makedirs(_folder, exist_ok=True)

print("Project folders created:")
for _folder in (PROJECT_PATH, DATA_PATH, CKPT_PATH):
    print(_folder)

# =========================================
# 3. Download DATASET (pets0)
# =========================================
# Function to extract file ID from a Google Drive URL
def get_drive_id(url):
    """Extract the file ID from a Google Drive URL.

    Share links of the form ``.../file/d/<id>/...`` are tried first; as a
    fallback the ID is read from an ``id=<id>`` query parameter
    (``.../open?id=<id>``, ``.../uc?id=<id>``).

    Args:
        url: The Drive URL as a string.

    Returns:
        The file ID as a string, or None when ``url`` is not a string or
        contains no recognisable ID.
    """
    if not isinstance(url, str):
        return None

    # Ordered by priority: the share-link form keeps its original behaviour.
    id_patterns = (
        r'file/d/([a-zA-Z0-9_-]+)',
        r'[?&]id=([a-zA-Z0-9_-]+)',
    )
    for pattern in id_patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None

# Share link of the dataset archive; get_drive_id() pulls the file ID out of it.
dataset_url = 'https://drive.google.com/file/d/1dUoT7hrUgzPOhvbt2qOwbvfaFIsXBijA/view?usp=drive_link'
dataset_id = get_drive_id(dataset_url)
# The archive is written into the project root and deleted after extraction.
dataset_zip = f"{PROJECT_PATH}/pets0.zip"

if dataset_id:
    # IPython shell escapes: download the archive by ID, unpack it quietly
    # into DATA_PATH, then remove the archive.
    # NOTE(review): newer gdown releases reportedly deprecate the `--id`
    # flag - confirm against the installed version.
    # NOTE(review): the commands run one after another without checking exit
    # codes, so "Dataset ready!" is printed even if a step failed.
    !gdown --id {dataset_id} -O {dataset_zip}
    !unzip -q {dataset_zip} -d {DATA_PATH}
    !rm {dataset_zip}
    print("Dataset ready!")
else:
    print("Could not extract dataset ID from the URL.")


# =========================================
# 4. Download CHECKPOINTS (mods)
# =========================================
# Share link of the checkpoint archive; get_drive_id() pulls the file ID out of it.
checkpoint_url = 'https://drive.google.com/file/d/1cnI2-TXblTxdMwAFnic-QEJ-EiwX-QeD/view?usp=drive_link'
checkpoint_id = get_drive_id(checkpoint_url)
# The archive is written into the project root and deleted after extraction.
checkpoint_zip = f"{PROJECT_PATH}/mods.zip"

if checkpoint_id:
    # IPython shell escapes: download the archive by ID, unpack it quietly
    # into CKPT_PATH, then remove the archive.
    # NOTE(review): newer gdown releases reportedly deprecate the `--id`
    # flag - confirm against the installed version.
    # NOTE(review): the commands run one after another without checking exit
    # codes, so "Checkpoints ready!" is printed even if a step failed.
    !gdown --id {checkpoint_id} -O {checkpoint_zip}
    !unzip -q {checkpoint_zip} -d {CKPT_PATH}
    !rm {checkpoint_zip}
    print("Checkpoints ready!")
else:
    print("Could not extract checkpoint ID from the URL.")


Project folders created:
/content/CatDogCNN2
/content/CatDogCNN2/pets0
/content/CatDogCNN2/mods
Downloading...
From (original): https://drive.google.com/uc?id=1dUoT7hrUgzPOhvbt2qOwbvfaFIsXBijA
From (redirected): https://drive.google.com/uc?id=1dUoT7hrUgzPOhvbt2qOwbvfaFIsXBijA&confirm=t&uuid=aa16f096-9bb6-4a4e-b13b-5cd49a84c229
To: /content/CatDogCNN2/pets0.zip
100% 807M/807M [00:15<00:00, 53.1MB/s]
Dataset ready!
Downloading...
From (original): https://drive.google.com/uc?id=1cnI2-TXblTxdMwAFnic-QEJ-EiwX-QeD
From (redirected): https://drive.google.com/uc?id=1cnI2-TXblTxdMwAFnic-QEJ-EiwX-QeD&confirm=t&uuid=486b0809-f263-4841-bca3-56ce9de9baab
To: /content/CatDogCNN2/mods.zip
100% 283M/283M [00:09<00:00, 31.0MB/s]
Checkpoints ready!


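The cells below use example paths under `/content/drive/MyDrive/...`. Before running them, replace these with the paths of your own dataset and checkpoint folders (for example the `DATA_PATH` and `CKPT_PATH` folders created above).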

#SimCLR Pretraining


In [None]:
# Install the lightly package; the next cell imports SimCLRTransform and
# NTXentLoss from it.
!pip install lightly


Collecting lightly
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading lightly-1.5.22-py3-none-any.whl (859 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from lightly.transforms import SimCLRTransform
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from lightly.loss import NTXentLoss

# Define the SimCLRModel class globally as it's a core component
class SimCLRModel(nn.Module):
    """Feature backbone followed by a two-layer MLP projection head.

    Args:
        backbone: Module mapping an input batch to flat feature vectors of
            width ``in_features``.
        feature_dim: Size of the projection produced by the head.
        in_features: Width of the backbone output. The default of 1280 is the
            width of the flattened backbone built in this file.
        hidden_dim: Width of the hidden layer of the projection head.

    The head keeps the layer layout ``Linear -> ReLU -> Linear``, so state-dict
    keys (``projection_head.0.*`` / ``projection_head.2.*``) are unchanged.
    """

    def __init__(self, backbone, feature_dim=128, in_features=1280, hidden_dim=256):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
        )

    def forward(self, x):
        """Return the projection of ``x``: backbone features fed through the head."""
        features = self.backbone(x)
        return self.projection_head(features)

def load_simclr_model(feature_dim=128, pretrained_path=None):
    """Build a SimCLR model on a MobileNetV2 feature extractor.

    Args:
        feature_dim: Output size of the projection head.
        pretrained_path: Optional path to a saved SimCLRModel state dict; when
            given, it is loaded with ``strict=True``.

    Returns:
        Tuple ``(model, device)``: the model already moved to ``device``, which
        is CUDA when available and CPU otherwise.
    """
    # MobileNetV2 is created without pretrained weights; only its `features`
    # part is kept, followed by pooling and flattening to [B, 1280].
    mobilenet = mobilenet_v2(weights=None)
    feature_extractor = nn.Sequential(
        mobilenet.features,
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
    )
    simclr = SimCLRModel(feature_extractor, feature_dim=feature_dim)

    if pretrained_path:
        print(f"Loading pretrained model from {pretrained_path}")
        # The checkpoint is read onto the CPU; the model is moved afterwards.
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        simclr.load_state_dict(checkpoint, strict=True)

    if torch.cuda.is_available():
        target_device = torch.device('cuda')
    else:
        target_device = torch.device('cpu')
    simclr = simclr.to(target_device)
    print("SimCLR model loaded successfully.")
    return simclr, target_device

# Custom dataset to return 2 views for SimCLR
class SimCLRDataset(ImageFolder):
    """ImageFolder variant that yields two augmented views per image.

    The parent class is initialised without a transform, so it hands back the
    image as loaded; the SimCLR transform is kept on the instance and applied
    in ``__getitem__``. Labels are discarded.
    """

    def __init__(self, root, simclr_transform):
        super().__init__(root)
        self.simclr_transform = simclr_transform

    def __getitem__(self, index):
        """Return the pair of views produced by the transform for item ``index``."""
        # Element 0 of the parent's (image, label) pair is the untransformed image.
        image = super().__getitem__(index)[0]
        first_view, second_view = self.simclr_transform(image)
        return first_view, second_view

def load_simclr_data(data_path, batch_size=64, input_size=224, num_workers=2):
    """Create the shuffled two-view DataLoader used for SimCLR pretraining.

    Args:
        data_path: Root directory handed to SimCLRDataset.
        batch_size: Batch size; incomplete final batches are dropped.
        input_size: Image size passed to SimCLRTransform.
        num_workers: Number of DataLoader worker processes.

    Returns:
        Tuple ``(dataloader, num_imgs)`` where ``num_imgs`` is the dataset length.
    """
    two_view_transform = SimCLRTransform(input_size=input_size)
    pretrain_set = SimCLRDataset(root=data_path, simclr_transform=two_view_transform)
    image_count = len(pretrain_set)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': num_workers,
        'drop_last': True,
    }
    pretrain_loader = DataLoader(pretrain_set, **loader_options)

    print(f"Number of images in the dataset: {image_count}")
    return pretrain_loader, image_count

def train_simclr_model(model, dataloader, device, epochs=20, lr=3e-4, save_best_path=None, save_epoch_path=None, start_epoch=0, initial_mloss=float('inf')):
    """Train the SimCLR model with the NT-Xent loss and an Adam optimizer.

    Args:
        model: Model mapping a batch of images to projections.
        dataloader: Iterable yielding pairs of views; ``views[0]`` and
            ``views[1]`` are the two augmented batches.
        device: Device the views are moved to (the model must already be on it).
        epochs: Number of epochs to run in this call.
        lr: Learning rate of the freshly created Adam optimizer.
        save_best_path: If given, the state dict is saved here whenever the
            average epoch loss drops below the best loss seen so far.
        save_epoch_path: If given, the state dict is saved here after every
            second epoch (overwriting the previous file).
        start_epoch: Number of epochs already completed by earlier runs; only
            affects the reported epoch numbers and the periodic-save schedule.
        initial_mloss: Best loss from earlier runs, used as the initial threshold
            for "best model" saving.

    Returns:
        The trained model (same object that was passed in).
    """
    criterion = NTXentLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    mloss = initial_mloss  # Lowest average epoch loss seen so far

    # Put the model in training mode explicitly: a model that was last used for
    # evaluation would otherwise silently keep running in eval mode.
    model.train()

    print("Starting SimCLR training...")
    for epoch in range(start_epoch, start_epoch + epochs):
        running_loss = 0.0
        total_batches = 0
        for views in dataloader:
            view1, view2 = views[0].to(device), views[1].to(device)

            optimizer.zero_grad()
            z1 = model(view1)
            z2 = model(view2)
            loss = criterion(z1, z2)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            total_batches += 1

        avg_epoch_loss = running_loss / total_batches if total_batches > 0 else 0.0

        print(f"Epoch {epoch + 1}: Loss = {avg_epoch_loss:.4f}")

        if total_batches == 0:
            # With no batches the 0.0 above is a placeholder, not a real loss;
            # it must not be recorded as the best result.
            print("Warning: dataloader produced no batches; skipping best-model update.")

        # Save model if current loss is the best so far
        if save_best_path and total_batches > 0 and avg_epoch_loss < mloss:
            mloss = avg_epoch_loss
            torch.save(model.state_dict(), save_best_path)
            print(f"Saved best model with loss {mloss:.4f} at epoch {epoch + 1}")

        # Save a rolling checkpoint after every second epoch
        if save_epoch_path and (epoch + 1) % 2 == 0:
            torch.save(model.state_dict(), save_epoch_path)
            print(f"Saved model checkpoint at epoch {epoch + 1}")

    print("SimCLR training finished.")
    return model

##Load the model (and optionally resume training)

In [None]:
# NOTE: both calls below run when this cell is executed. The second call rebinds
# `model` and `device`, so the freshly initialised model from the first call is
# discarded. Comment out the variant you do not need.

# To start a new training run:
model, device = load_simclr_model(feature_dim=128)

# To resume training from a saved checkpoint:
pretrained_model_path = '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch2.pth'
model, device = load_simclr_model(feature_dim=128, pretrained_path=pretrained_model_path)

##Load the data

In [None]:
# Root folder of the pretraining images; it is handed to SimCLRDataset, an
# ImageFolder subclass. Adjust the path to where your data actually lives.
data_path = '/content/drive/MyDrive/pets0/unlabeled_train'
# Returns the shuffled two-view DataLoader and the number of images found.
dataloader, num_imgs = load_simclr_data(data_path, batch_size=64, input_size=224, num_workers=2)

##Train the model

In [None]:
# Example for starting a new training from scratch (adjust epochs, lr, and save paths)
# save_best_path receives the state dict with the lowest average epoch loss;
# save_epoch_path is overwritten after every second epoch.
model = train_simclr_model(
         model, dataloader, device, epochs=50, lr=3e-4,
         save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_10000_best.pth',
         save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch.pth',
         start_epoch=0 # For new training
         )



In [None]:
# Example for continuing training from a checkpoint (adjust start_epoch and initial_mloss based on previous runs)
# The call sits at column 0: an indented top-level statement raises IndentationError.
model = train_simclr_model(
    model, dataloader, device, epochs=10, lr=3e-4,
    save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_10000_best_cont.pth',
    save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch_cont.pth',
    start_epoch=68,  # Epochs already completed by earlier runs (58 + 10 = 68); this run is reported as epochs 69-78
    initial_mloss=1.97  # Based on the mloss from the last training run
)

# Finetune the SimCLR model

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def finetune_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, save_best=None, save_epoch=None, start_epoch=0, init_acc=0.0):
    """Train a classifier and validate it after every epoch.

    Args:
        model: Classifier returning per-class logits; must already be on ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.
        num_epochs: Number of epochs to run in this call.
        device: Device the batches are moved to.
        save_best: If given, the state dict is saved here whenever validation
            accuracy exceeds the best value so far.
        save_epoch: If given, the state dict is saved here after every second
            epoch of this call (overwriting the previous file).
        start_epoch: Offset added to the reported epoch numbers.
        init_acc: Best validation accuracy of earlier runs; new results must
            exceed it to count as an improvement.

    Returns:
        The best validation accuracy seen, never lower than ``init_acc``. It is
        tracked whether or not ``save_best`` is set.
    """
    best_val_acc = init_acc

    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1

        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns a batch mean; weight it by batch size so the
            # epoch average is per sample even when the last batch is smaller.
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        # An empty loader yields 0.0 instead of a ZeroDivisionError.
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0

        # Validation
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_running_loss / val_total if val_total else 0.0
        val_acc = val_correct / val_total if val_total else 0.0

        print(f'Epoch {epoch_number}: '
              f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
              f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        if save_epoch and (epoch + 1) % 2 == 0:
            torch.save(model.state_dict(), save_epoch)
            print(f"Saved model checkpoint at epoch {epoch_number}")

        # Track the best accuracy unconditionally so the return value is
        # meaningful without a save path; only the saving depends on save_best.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            if save_best:
                torch.save(model.state_dict(), save_best)
                print(f"Saved best model with Val Acc: {best_val_acc:.4f} at epoch {epoch_number}")

    return best_val_acc

In [None]:
def run_finetuning_workflow(pretrained_simclr_path=None, num_epochs_initial=10, best_save=None,
                            save_epoch=None, start_epoch=0, init_acc=0.0, finetuned_path=None,
                            train_dir='/content/drive/MyDrive/pets0/finetune_train',
                            val_dir='/content/drive/MyDrive/pets0/val',
                            batch_size=64, lr=1e-4, num_workers=2):
    """Finetune a 2-class MobileNetV2 from a SimCLR backbone or a previous run.

    Args:
        pretrained_simclr_path: Path to a saved SimCLRModel state dict; its
            backbone weights initialise the MobileNetV2 ``features`` part.
            Used only when ``finetuned_path`` is not given.
        num_epochs_initial: Number of finetuning epochs to run.
        best_save: Optional path for the best-validation-accuracy state dict.
        save_epoch: Optional path for the rolling every-second-epoch checkpoint.
        start_epoch: Offset added to the reported epoch numbers.
        init_acc: Best validation accuracy of earlier runs.
        finetuned_path: Path to a previously finetuned MobileNetV2 state dict
            to resume from; takes precedence over ``pretrained_simclr_path``.
        train_dir: ImageFolder root of the labelled training images.
        val_dir: ImageFolder root of the validation images.
        batch_size: Batch size for both loaders.
        lr: Learning rate of the Adam optimizer (a new optimizer is created on
            every call, also when resuming).
        num_workers: Number of worker processes for both loaders.

    Raises:
        ValueError: If neither ``finetuned_path`` nor ``pretrained_simclr_path``
            is provided.
    """
    if not finetuned_path and not pretrained_simclr_path:
        raise ValueError(
            "Either pretrained_simclr_path (new finetuning) or finetuned_path "
            "(resuming) must be provided."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Standard MobileNetV2 with the last classifier layer replaced by a
    # 2-way head (cat, dog).
    student_finetune = mobilenet_v2(weights=None)
    num_ftrs = student_finetune.classifier[1].in_features
    student_finetune.classifier[1] = nn.Linear(num_ftrs, 2)

    if finetuned_path:
        print(f"Continuing finetuning from {finetuned_path}")
        student_finetune.load_state_dict(torch.load(finetuned_path, map_location=device))
        print("Previous finetuned model loaded successfully.")
    else:
        print(f"Starting new finetuning using SimCLR backbone from {pretrained_simclr_path}")
        simclr_state_dict = torch.load(pretrained_simclr_path, map_location=device)

        # Keep only backbone weights and rename them to MobileNetV2's
        # `features.` namespace. The SimCLR backbone is a Sequential whose
        # element 0 is `features`, hence the 'backbone.0.' prefix; plain
        # 'backbone.' is accepted as a fallback. Projection-head keys match
        # neither prefix and are dropped.
        backbone_state_dict = {}
        for key, tensor in simclr_state_dict.items():
            for prefix in ('backbone.0.', 'backbone.'):
                if key.startswith(prefix):
                    backbone_state_dict['features.' + key[len(prefix):]] = tensor
                    break

        # strict=False because the classifier weights are intentionally absent.
        load_result = student_finetune.load_state_dict(backbone_state_dict, strict=False)

        # strict=False would also hide a checkpoint whose keys do not line up,
        # which would leave the backbone randomly initialised - report that.
        missing_backbone = [k for k in load_result.missing_keys if k.startswith('features.')]
        if missing_backbone or load_result.unexpected_keys:
            print(f"Warning: {len(missing_backbone)} backbone weights were not found in the "
                  f"checkpoint and {len(load_result.unexpected_keys)} checkpoint keys were not used.")
        print("Pretrained SimCLR backbone loaded and classifier replaced.")

    student_finetune = student_finetune.to(device)

    # Training-time augmentation
    finetune_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # NOTE(review): validation images end up 192x192 (center crop) while the
    # training images are 224x224 - confirm this is intended.
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop(192),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Datasets and dataloaders
    finetune_dataset_labeled = datasets.ImageFolder(train_dir, transform=finetune_transform)
    val_dataset_labeled = datasets.ImageFolder(val_dir, transform=val_transform)

    finetune_loader_labeled = DataLoader(finetune_dataset_labeled, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)
    val_loader_labeled = DataLoader(val_dataset_labeled, batch_size=batch_size,
                                    shuffle=False, num_workers=num_workers)

    print(f"Number of samples in finetune dataset: {len(finetune_dataset_labeled)}")
    print(f"Number of samples in validation dataset: {len(val_dataset_labeled)}")

    # Optimizer and loss
    optimizer_finetune = torch.optim.Adam(student_finetune.parameters(), lr=lr)
    criterion_finetune = nn.CrossEntropyLoss()

    finetune_model(
        student_finetune,
        finetune_loader_labeled,
        val_loader_labeled,
        optimizer_finetune,
        criterion_finetune,
        num_epochs=num_epochs_initial,
        device=device,
        save_best=best_save,
        save_epoch=save_epoch,
        start_epoch=start_epoch,
        init_acc=init_acc
    )

    if best_save:
        print(f"Finetuning complete. Best model saved to {best_save}")
    else:
        print("Finetuning complete. No best_save path was given, so no best model was saved.")

### Finetuning Run

In [None]:
# Example for starting a new finetuning run:
# The MobileNetV2 backbone is initialised from the SimCLR checkpoint; best_save
# receives the best-validation-accuracy weights and save_epoch the rolling
# every-second-epoch checkpoint.
run_finetuning_workflow(
    pretrained_simclr_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
    num_epochs_initial=10,
    best_save='/content/drive/MyDrive/mods/student_finetuned_6000_new.pth',
    save_epoch='/content/drive/MyDrive/mods/student_finetuned_6000_epoch_new.pth'
)



In [None]:
# Example for continuing finetuning from a saved finetuned model:
# finetuned_path takes precedence over pretrained_simclr_path inside
# run_finetuning_workflow; a new Adam optimizer is created for the resumed run.
run_finetuning_workflow(

     finetuned_path='/content/drive/MyDrive/mods/student_finetuned_6000_epoch_new.pth', # Path to a previously finetuned model
     num_epochs_initial=5,
     best_save='/content/drive/MyDrive/mods/student_finetuned_6000_cont_best.pth',
     save_epoch='/content/drive/MyDrive/mods/student_finetuned_6000_cont_epoch.pth',
     start_epoch=10, # If previous run had 10 epochs, start from 10
     init_acc=0.7 # Initial best accuracy from the previous finetuning run
 )

#Test the models

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Deterministic preprocessing for evaluation: resize to 224x224, convert to a
# tensor and normalise with the same mean/std used by the training transforms.
test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
# Adjust the path to where the test split actually lives.
test_dataset = datasets.ImageFolder('/content/drive/MyDrive/CatDogCNN/pets0/test', transform=test_transform)
print(f"Number of images in the test dataset: {len(test_dataset)}")
# shuffle=False keeps the order of the returned predictions aligned with the dataset.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

def test_model(model, test_loader, device, test_dataset_classes,
               title='Confusion Matrix for Finetuned Student Model'):
    """Evaluate a classifier and report accuracy, per-class metrics and a confusion matrix.

    Args:
        model: Classifier returning per-class logits.
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device used for evaluation; the model is moved to it here.
        test_dataset_classes: Class names used as labels in the report and plot.
        title: Title of the confusion-matrix plot.

    Returns:
        Tuple ``(all_preds, all_labels)`` of NumPy arrays with the predicted and
        true class indices (both empty if the loader yields no samples).
    """
    # Batches are moved to `device` below, so the model has to live there too;
    # a model left on the CPU would otherwise fail on a CUDA device.
    model = model.to(device)
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        # Nothing to score: avoid dividing by zero and reporting on empty data.
        print("Test loader produced no samples; nothing to evaluate.")
        return np.array(all_preds), np.array(all_labels)

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

    print(classification_report(all_labels, all_preds, target_names=test_dataset_classes))

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=test_dataset_classes, yticklabels=test_dataset_classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.show()
    return np.array(all_preds), np.array(all_labels)


Number of images in the test dataset: 5000


**Test the students (finetuned / distilled)**

In [None]:
def load_student_model(model_path=None):
    """Build the 2-class MobileNetV2 student for testing.

    Args:
        model_path: Optional path to a saved student state dict; when given it
            is loaded with ``strict=True``.

    Returns:
        The student model in eval mode, moved to CUDA when available and to
        the CPU otherwise.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = models.mobilenet_v2(weights=None)
    num_classes = 2
    student.classifier[1] = nn.Linear(student.last_channel, num_classes)
    if model_path:
        # The checkpoint is read onto the CPU; the model is moved afterwards.
        student_state_dict = torch.load(model_path, map_location='cpu')
        student.load_state_dict(student_state_dict, strict=True)
    student = student.to(device)
    # Switch the model built here to eval mode (the previous code referenced
    # `student_test`, a name that does not exist inside this function).
    student.eval()
    print("MobileNetV2 student loaded for tesing.")
    return student

In [None]:
# Evaluate a finetuned student checkpoint on the test split.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Adjust the path to the checkpoint you want to evaluate.
best_student_finetuned_path = '/content/drive/MyDrive/CatDogCNN/mods/student_finetuned_3000.pth'
student_test = load_student_model(best_student_finetuned_path)
# test_model returns (predictions, labels); as the last expression of the cell
# the notebook will also display that tuple.
test_model(student_test, test_loader, device, test_dataset.classes)

**Test the teacher (finetuned) - only 1 checkpoint**

In [None]:
# Evaluate the finetuned ResNet50 teacher on the test split: rebuild the
# architecture with a 2-class head, load the saved weights, then run test_model.

# Define the path to your saved finetuned teacher model checkpoint
finetuned_checkpoint_path = '/content/drive/MyDrive/mods/resnet_finetune_only.pth'

# Load a standard ResNet50 model structure
teacher_model = models.resnet50(weights=None) # Load without pretrained ImageNet weights initially

# Modify the final fully connected layer to match the number of classes
num_ftrs = teacher_model.fc.in_features
num_classes = 2  # Your model was finetuned for 2 classes (Cat/Dog)
teacher_model.fc = nn.Linear(num_ftrs, num_classes)


# Load the state dictionary from the saved finetuned teacher model checkpoint
# Using map_location='cpu' to load onto CPU first is safer, then move to device
teacher_state_dict = torch.load(finetuned_checkpoint_path, map_location='cpu')

# Load the state dictionary into the standard ResNet50 model
# This should now work because the model structure matches the saved state_dict
teacher_model.load_state_dict(teacher_state_dict)

# Set the teacher model to evaluation mode
teacher_model.eval()

# Determine the device based on CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the teacher model to the device
teacher_model = teacher_model.to(device)

print("Finetuned teacher model loaded correctly for testing.")

# test_model returns (predictions, labels); as the last expression of the cell
# the notebook will also display that tuple.
test_model(teacher_model, test_loader, device, test_dataset.classes)

Finetuned teacher model loaded correctly for testing.


#Compact KD function

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset

def kd_loss(student_logits, teacher_logits, T):
    """Temperature-scaled KL-divergence loss between student and teacher.

    Both logit tensors are divided by ``T`` and normalised over ``dim=1``; the
    batch-mean KL divergence is multiplied by ``T * T``.
    """
    temperature_sq = T * T
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return divergence * temperature_sq

def train_distillation_epoch(student_model, teacher_model, dataloader, criterion_ce, criterion_kd, optimizer, T, device, alpha):
    """Run one distillation epoch over a mix of labelled and unlabelled samples.

    Samples whose label is ``-1`` count as unlabelled. Labelled samples
    contribute the cross-entropy term, unlabelled samples the distillation
    term against the teacher logits; the two are blended as
    ``(1 - alpha) * ce + alpha * kd``.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the loss averaged over
        ``len(dataloader.dataset)`` and the accuracy on labelled samples only
        (0.0 if the epoch contained none).
    """
    student_model.train()
    loss_sum = 0.0
    n_correct = 0
    n_labeled = 0

    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()

        # The teacher only supplies targets, so no gradients are tracked for it.
        with torch.no_grad():
            teacher_out = teacher_model(batch_inputs)
        student_out = student_model(batch_inputs)

        has_label = (batch_labels != -1)
        no_label = (batch_labels == -1)

        # A term falls back to the plain number 0 when its subset is empty.
        if has_label.sum() > 0:
            ce_term = criterion_ce(student_out[has_label], batch_labels[has_label])
        else:
            ce_term = 0
        if no_label.sum() > 0:
            kd_term = criterion_kd(student_out[no_label], teacher_out[no_label], T)
        else:
            kd_term = 0

        batch_loss = (1 - alpha) * ce_term + alpha * kd_term

        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch_inputs.size(0)
        if has_label.sum() > 0:
            predicted = torch.max(student_out[has_label], 1)[1]
            n_correct += (predicted == batch_labels[has_label]).sum().item()
            n_labeled += has_label.sum().item()

    epoch_loss = loss_sum / len(dataloader.dataset)
    if n_labeled > 0:
        epoch_acc = n_correct / n_labeled
    else:
        epoch_acc = 0.0
    return epoch_loss, epoch_acc

def validate(model, dataloader, criterion, device=None):
    """Compute the average loss and accuracy of ``model`` over ``dataloader``.

    Args:
        model: Classifier returning per-class logits.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)`` and returning a batch mean.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used, so the batches always land
            where the model is; for a parameter-less model it falls back to
            CUDA when available, else CPU.

    Returns:
        Tuple ``(val_loss, val_acc)``; ``(0.0, 0.0)`` if the loader yields no
        samples.
    """
    model.eval()
    running_loss = 0
    correct = 0
    total = 0

    if device is None:
        # Follow the model rather than guessing from CUDA availability: the
        # caller may have deliberately kept the model on the CPU.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Weight the batch mean by batch size to get a per-sample average.
            running_loss += loss.item() * imgs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Empty loader: nothing to average over.
        return 0.0, 0.0

    val_loss = running_loss / total
    val_acc = correct / total
    return val_loss, val_acc

def run_distillation(num_epochs, student_model, teacher_model_path, labeled_dir,
                     unlabeled_dir, val_dir, img_size=224, batch_size=64, learning_rate=3e-4,
                     T=5.0, alpha=0.7, start = 0, save_path=None, save_best=None, device=None):
    """Distil a frozen ResNet50 teacher into ``student_model``.

    The teacher (2-class ResNet50) is rebuilt and loaded from
    ``teacher_model_path``. Training uses the labelled and unlabelled image
    folders concatenated into one dataset, with every unlabelled sample
    carrying the label ``-1``. The student is validated on ``val_dir`` after
    each epoch.

    Args:
        num_epochs: Number of epochs to run.
        student_model: Student network to train.
        teacher_model_path: Path to the teacher's saved state dict.
        labeled_dir / unlabeled_dir / val_dir: ImageFolder roots.
        img_size: Crop size used by the training and validation transforms.
        batch_size: Batch size for both loaders.
        learning_rate: Learning rate of the student's Adam optimizer.
        T: Distillation temperature.
        alpha: Weight of the distillation term (cross-entropy gets 1 - alpha).
        start: Offset added to the reported epoch numbers only.
        save_path: If given, the student state dict is saved here after every
            second epoch of this call.
        save_best: If given, the student state dict is saved here whenever the
            validation accuracy improves (starting from 0.0).
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        The trained student model.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # --- Teacher: 2-class ResNet50, loaded, frozen, in eval mode ---
    teacher = models.resnet50(weights=None)
    teacher.fc = nn.Linear(teacher.fc.in_features, 2)
    teacher.load_state_dict(torch.load(teacher_model_path, map_location='cpu'))
    teacher.eval()
    teacher.requires_grad_(False)
    teacher = teacher.to(device)

    # --- Student: passed in by the caller ---
    student_model = student_model.to(device)
    student_optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)
    cross_entropy = nn.CrossEntropyLoss()

    # --- Data ---
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    augmenting_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    labeled_set = datasets.ImageFolder(labeled_dir, transform=augmenting_transform)
    unlabeled_set = datasets.ImageFolder(unlabeled_dir, transform=augmenting_transform)
    # Overwrite the folder-derived labels with the "unlabelled" marker -1.
    unlabeled_set.samples = [(img_path, -1) for img_path, _ in unlabeled_set.samples]
    train_set = ConcatDataset([labeled_set, unlabeled_set])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    val_set = datasets.ImageFolder(val_dir, transform=eval_transform)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    print(f"Number of images in the labeled dataset: {len(labeled_set)}")
    print(f"Number of images in the unlabeled dataset: {len(unlabeled_set)}")
    print(f"Number of images in the val dataset: {len(val_set)}")

    # --- Training loop ---
    top_val_acc = 0.0
    for epoch_idx in range(num_epochs):
        shown_epoch = epoch_idx + 1 + start

        train_loss, train_acc = train_distillation_epoch(
            student_model, teacher, train_loader, cross_entropy, kd_loss,
            student_optimizer, T, device, alpha
        )
        val_loss, val_acc = validate(student_model, val_loader, cross_entropy)

        print(f'Epoch {shown_epoch}: '
              f'Train Loss (Student): {train_loss:.4f} Acc (Labeled): {train_acc:.4f} | '
              f'Val Loss (Student): {val_loss:.4f} Acc: {val_acc:.4f}')

        # Rolling checkpoint after every second epoch of this call.
        if save_path and (epoch_idx + 1) % 2 == 0:
            torch.save(student_model.state_dict(), save_path)
            print(f"Saved student model at epoch {shown_epoch}")

        if save_best and val_acc > top_val_acc:
            top_val_acc = val_acc
            torch.save(student_model.state_dict(), save_best)
            print(f"Saved best student model with validation accuracy: {top_val_acc:.4f}")

    return student_model

In [3]:
import torch
import torchvision.models as models
import torch.nn as nn


def load_student_model(model_path=None):
    """Build the 2-class MobileNetV2 student and move it to the available device.

    Args:
        model_path: Optional path to a saved student state dict; when given it
            is loaded with ``strict=True``.

    Returns:
        The student model on CUDA when available, otherwise on the CPU.
    """
    num_classes = 2
    student = models.mobilenet_v2(weights=None)
    # Swap the last classifier layer for a 2-way head.
    student.classifier[1] = nn.Linear(student.last_channel, num_classes)

    if model_path:
        # The checkpoint is read onto the CPU; the model is moved afterwards.
        checkpoint = torch.load(model_path, map_location='cpu')
        student.load_state_dict(checkpoint, strict=True)

    if torch.cuda.is_available():
        target_device = torch.device('cuda')
    else:
        target_device = torch.device('cpu')
    student = student.to(target_device)
    print("MobileNetV2 student model defined with classification head.")
    return student


In [None]:
# Continue distilling an existing student checkpoint; adjust all paths to where
# your checkpoints and image folders actually live.
student = load_student_model('/content/drive/MyDrive/CatDogCNN/mods/only_distilled_student_3000.pth')
teacher_model_path = '/content/drive/MyDrive/CatDogCNN/mods/resnet_finetune_only.pth'
labeled_dir = '/content/drive/MyDrive/CatDogCNN/pets0/finetune_train'
unlabeled_dir = '/content/drive/MyDrive/CatDogCNN/pets0/train3000'
val_dir = '/content/drive/MyDrive/CatDogCNN/pets0/val'

# start=3 only shifts the printed epoch numbers (this run is reported as epochs 4-8).
# NOTE(review): save_path and save_best are both None, so no checkpoint is
# written, and the returned model is not assigned - the trained weights live
# only in `student` (trained in place). Confirm this is intended.
run_distillation(5, student, teacher_model_path, labeled_dir,
                     unlabeled_dir, val_dir, img_size=96, batch_size=32, learning_rate=3e-4,
                     T=5.0, alpha=0.7, start=3, save_path=None, save_best=None, device=None)