## Installing dependencies

In [None]:
# Install torchjd, the Jacobian-descent library whose mtl_backward and UPGrad are imported in the next cell.
!pip install torchjd

## imports

In [None]:
import zipfile
import os
from torchvision import models as M
from torch.utils.data import random_split, DataLoader
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.nn import MSELoss, KLDivLoss, CrossEntropyLoss, L1Loss
from torch.optim import Adam
from torchjd import mtl_backward
from torchjd.aggregation import UPGrad
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np

## specify device

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## transformations

In [None]:
# Transforms
# Training-time preprocessing:
#   1. convert the image to a tensor,
#   2. normalise each channel as (x - 0.5) / 0.5,
#   3. with probability 0.3, apply the whole augmentation group
#      (horizontal flip, vertical flip, random rescale) in that order.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        transforms.RandomApply(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomAffine(degrees=0, scale=(0.8, 1.2)),
            ],
            p=0.3,
        ),
    ]
)

## Paths to the training and test datasets, creation of the validation split, and data loaders

In [None]:
# Datasets and data loaders.
# Random augmentation belongs to training only: the validation split and the test
# set use a deterministic transform with the same normalisation as `transform`.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_root = "path to folder containing only unperturbed images - training dataset for teacher"
# Two views of the same folder so the train and validation subsets can carry different transforms.
full_train = datasets.ImageFolder(root=train_root, transform=transform)
full_val = datasets.ImageFolder(root=train_root, transform=eval_transform)

# Train-Val Split (80/20), seeded so the split is reproducible across runs.
train_size = int(0.8 * len(full_train))
val_size = len(full_train) - train_size
split_generator = torch.Generator().manual_seed(42)
train_idx, val_idx = random_split(range(len(full_train)), [train_size, val_size], generator=split_generator)
train_data = torch.utils.data.Subset(full_train, train_idx.indices)
val_data = torch.utils.data.Subset(full_val, val_idx.indices)

test_data = datasets.ImageFolder(root="path to folder containing both perturbed and unperturbed images - test dataset for student", transform=eval_transform)

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=32, shuffle=False)
# Evaluation order does not matter; keep it deterministic.
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

## Common architecture for the student and teacher

In [None]:
class Model(nn.Module):
    """Wrap a ResNet-style encoder so that it returns logits and per-stage features.

    The encoder passed in is modified in place:
      * ``fc`` becomes a 2-way linear head with the same input width,
      * ``conv1``/``bn1`` become a fresh 3x3, stride-1 stem with 16 output channels,
      * ``maxpool`` is set to ``nn.Identity`` (no downsampling in the stem).

    The encoder must expose ``relu``, ``layer1``..``layer3``, ``avgpool`` and an ``fc``
    with ``in_features``.

    ``forward(x)`` returns ``(logits, [stage1, stage2, stage3])``, where each stage entry
    is the output of ``layer1``, ``layer2`` and ``layer3`` respectively.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        enc = self.encoder
        # Replacement order (fc, conv1, bn1, maxpool) matters for reproducible initialisation.
        enc.fc = nn.Linear(enc.fc.in_features, 2)
        enc.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        enc.bn1 = nn.BatchNorm2d(16)  # must match conv1's 16 output channels
        enc.maxpool = nn.Identity()

    def forward(self, x):
        enc = self.encoder
        h = enc.maxpool(enc.relu(enc.bn1(enc.conv1(x))))

        stage_outputs = []
        for stage in (enc.layer1, enc.layer2, enc.layer3):
            h = stage(h)
            stage_outputs.append(h)

        pooled = torch.flatten(enc.avgpool(h), 1)
        return enc.fc(pooled), stage_outputs


In [None]:
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv + BatchNorm layers with a residual connection.

    The first convolution applies ``stride``. When the block changes resolution or width
    (``stride != 1`` or ``in_channels != out_channels``), the residual path is a 1x1
    conv + BatchNorm projection; otherwise it is an empty ``nn.Sequential`` (identity).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut = nn.Sequential()  # empty Sequential passes its input through unchanged
        self.shortcut = shortcut

    def forward(self, x):
        residual = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(y + residual)

In [None]:
class ResNet20(nn.Module):
    """ResNet-20 with a 3x3 stem and three stages of three BasicBlocks each.

    Stage widths are 16, 32 and 64. Stages 2 and 3 halve the spatial size. The head is
    global average pooling followed by a linear layer with ``num_classes`` outputs.

    ``forward(x)`` returns ``(logits, [stage1, stage2, stage3])``.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_channels = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # (width, stride) per stage; only the first block of a stage downsamples.
        self.layer1, self.layer2, self.layer3 = (
            self._make_layer(width, num_blocks=3, stride=s)
            for width, s in ((16, 1), (32, 2), (64, 2))
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        """
        Build one residual stage: a first block with ``stride`` followed by
        ``num_blocks - 1`` stride-1 blocks. At least one block is always created.
        Updates ``self.in_channels`` to ``out_channels``.
        """
        blocks = [BasicBlock(self.in_channels, out_channels, stride)]
        blocks += [BasicBlock(out_channels, out_channels, 1) for _ in range(num_blocks - 1)]
        self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def forward(self, x):
        h = self.relu(self.bn1(self.conv1(x)))

        features = []
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
            features.append(h)

        logits = self.fc(torch.flatten(self.avgpool(h), 1))
        return logits, features


## teacher model

In [None]:
# Teacher: a ResNet-20 encoder wrapped by Model, moved to `device` and set to training mode.
encoder = ResNet20(num_classes=2)
teacher = Model(encoder).to(device)
teacher.train()
None  # keep the notebook from echoing the module repr

## Training function (teacher)

In [None]:
# Train the teacher
def train_model_teacher(model, train_dataloader, val_dataloader, epochs, lr, device='cuda',
                        patience=4, delta=0.0, best_model_path='./best_teacher_model.pth'):
    """
    Train the teacher model with Adam and early stopping on validation loss.

    Each epoch trains on ``train_dataloader`` and then calls ``evaluate`` on both the
    training and validation loaders. Whenever the validation loss drops below
    ``best - delta``, the weights are saved to ``best_model_path``. Training stops after
    ``patience`` consecutive epochs without such an improvement.

    Args:
        model (nn.Module): The teacher model to train; must return ``(logits, features)``.
        train_dataloader (DataLoader): Dataloader for the training set.
        val_dataloader (DataLoader): Dataloader for the validation set.
        epochs (int): Maximum number of epochs to train.
        lr (float): Learning rate.
        device (str | torch.device): Device the batches are moved to.
        patience (int): Epochs without improvement before stopping early.
        delta (float): Minimum validation-loss decrease that counts as an improvement.
        best_model_path (str): Where the best checkpoint is written.
    Returns:
        nn.Module: The model with the best checkpoint from THIS run loaded. If no
        checkpoint was saved (e.g. ``epochs == 0`` or the validation loss was never
        finite), the model is returned with its final weights instead of loading a
        possibly stale file left by an earlier run.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Early stopping state
    best_val_loss = float('inf')
    counter = 0
    saved_checkpoint = False

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        # Training loop
        for inputs, labels in tqdm(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean per-batch training loss (guarded against an empty loader)
        train_avg_loss = running_loss / max(len(train_dataloader), 1)
        print(f"Epoch {epoch + 1} Train Loss: {train_avg_loss}")

        # Evaluate on Train and Validation datasets
        train_acc, train_report, _ = evaluate(model, train_dataloader, device)
        val_acc, val_report, val_avg_loss = evaluate(model, val_dataloader, device)

        print(f"Epoch {epoch + 1} Train Accuracy: {train_acc:.4f}")
        print(train_report)
        print(f"Epoch {epoch + 1} Validation Loss: {val_avg_loss:.4f} | Validation Accuracy: {val_acc:.4f}")
        print(val_report)

        # Early stopping on validation loss (a NaN loss never counts as an improvement)
        if val_avg_loss < best_val_loss - delta:
            best_val_loss = val_avg_loss
            torch.save(model.state_dict(), best_model_path)
            saved_checkpoint = True
            print(f"New best model saved at epoch {epoch + 1}.")
            counter = 0
        else:
            counter += 1
            print(f"No improvement. Early stopping patience counter: {counter}/{patience}")

        if counter >= patience:
            print("Early stopping triggered.")
            break

    # Only restore a checkpoint that this run actually wrote.
    if saved_checkpoint:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print("Loaded the best model weights.")
    else:
        print("No checkpoint was saved during this run; returning the final weights.")
    return model


## Evaluation function (used for the validation and test sets)

In [None]:
# Evaluation Function
def evaluate(model, dataloader, device='cuda'):
    """
    Evaluate the model on a given dataloader.

    Switches the model to eval mode (and leaves it there) and runs without gradient tracking.

    Args:
        model (nn.Module): The model to evaluate; must return ``(logits, features)``.
        dataloader (DataLoader): The dataloader for evaluation.
        device (str | torch.device): The device to use for computation.
    Returns:
        float: Accuracy of the model on the dataset.
        dict: sklearn classification report (``output_dict=True``).
        float: Mean cross-entropy per sample. Losses are summed over samples and divided
            by the sample count, so a smaller final batch does not skew the mean.
    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    all_labels = []
    all_preds = []
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)
            outputs, _ = model(images)
            predicted = outputs.argmax(dim=1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss_sum += criterion(outputs, labels).item()
    if total == 0:
        raise ValueError("evaluate() received a dataloader with no samples")
    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    return correct / total, report, loss_sum / total

## Training the teacher

In [None]:
# Early stopping must monitor the held-out validation split, not the test set the
# student is scored on later. Pass the notebook's device explicitly (the function
# defaults to 'cuda').
teacher = train_model_teacher(teacher, train_dataloader, val_dataloader, epochs=25, lr=0.0001, device=device)

 72%|███████▏  | 969/1345 [03:45<01:20,  4.70it/s]

## student model

In [None]:
# Student: same architecture as the teacher, built around its own freshly initialised ResNet-20.
encoder = ResNet20(num_classes=2)
student = Model(encoder).to(device)

## adversarial attacks - FGSM and PGD

In [None]:
import random


def _attack_device(model, fallback):
    """Return the device of the model's first parameter, or ``fallback`` if it has none."""
    param = next(model.parameters(), None)
    return fallback if param is None else param.device


def _input_gradient(model, x_adv, y):
    """Return d(cross-entropy)/d(x_adv) without writing ``.grad`` on the model's parameters."""
    logits, _ = model(x_adv)
    loss = nn.CrossEntropyLoss()(logits, y)
    (grad,) = torch.autograd.grad(loss, x_adv)
    return grad


def fgsm_attack(model, x, y, eps=0.03, targeted=False, clip_min=-1.0, clip_max=1.0):
    """Fast Gradient Sign Method: one signed-gradient step of size ``eps``.

    Inputs are moved to the model's device. The model runs in eval mode during the
    attack, and its previous train/eval mode is restored afterwards.

    Args:
        eps: L-inf step size, measured in the (normalised) input space.
        targeted: If True, step against the gradient (towards label ``y``).
        clip_min, clip_max: Valid input range. The defaults match inputs normalised with
            ``Normalize(mean=0.5, std=0.5)``, which lie in [-1, 1]. Clamping to [0, 1]
            would erase every negative pixel.
    Returns:
        Detached adversarial tensor of the same shape as ``x``.
    """
    was_training = model.training
    model.eval()
    try:
        dev = _attack_device(model, x.device)
        x = x.to(dev)
        y = y.to(dev)
        x_adv = x.clone().detach().requires_grad_(True)
        grad = _input_gradient(model, x_adv, y)
        step = -eps if targeted else eps
        with torch.no_grad():
            x_adv = torch.clamp(x_adv + step * grad.sign(), min=clip_min, max=clip_max)
    finally:
        model.train(was_training)
    return x_adv


# PGD attack function
def pgd_attack(model, x, y, eps=0.03, alpha=0.01, steps=2, targeted=False, clip_min=-1.0, clip_max=1.0):
    """Projected Gradient Descent attack inside an L-inf ball of radius ``eps``.

    The attack starts from a uniform random point in the ball. It then takes ``steps``
    signed-gradient steps of size ``alpha``, projecting back onto the ball and onto
    ``[clip_min, clip_max]`` after each step. The clip defaults match inputs normalised to
    [-1, 1]. The model's train/eval mode is restored afterwards.

    Returns:
        Detached adversarial tensor of the same shape as ``x``.
    """
    was_training = model.training
    model.eval()
    try:
        dev = _attack_device(model, x.device)
        x = x.to(dev).detach()
        y = y.to(dev)
        step = -alpha if targeted else alpha

        # Random start inside the eps-ball, kept within the valid input range.
        x_adv = x + torch.empty_like(x).uniform_(-eps, eps)
        x_adv = torch.clamp(x_adv, min=clip_min, max=clip_max)

        for _ in range(steps):
            x_adv = x_adv.detach().requires_grad_(True)
            grad = _input_gradient(model, x_adv, y)
            with torch.no_grad():
                x_adv = x_adv + step * grad.sign()
                # Project onto the eps-ball around x, then onto the valid input range.
                x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
                x_adv = torch.clamp(x_adv, min=clip_min, max=clip_max)
    finally:
        model.train(was_training)
    return x_adv.detach()


def attackit(x, y, model, p_orig=0.2, p_pgd=0.4, p_fgsm=0.4):
    """Randomly return the clean batch (probability ``p_orig``) or a PGD/FGSM-attacked batch.

    PGD is chosen with probability ``p_pgd``. FGSM receives the remaining probability mass,
    so ``p_fgsm`` itself is not read. Labels are returned unchanged.
    """
    rn = random.uniform(0, 1)
    if rn < p_orig:
        return x, y
    if rn < p_orig + p_pgd:
        return pgd_attack(model, x, y), y
    return fgsm_attack(model, x, y), y

## Training and evaluation function (student)

In [None]:
def train_model_student_jd(student, teacher, train_dataloader, test_dataloader, aug_fn, epochs, lr, device='cuda'):
    """
    Train a student by knowledge distillation, combining the losses with Jacobian descent.

    For each batch:
      1. Teacher logits and stage features are computed on the clean inputs (no grad).
      2. ``aug_fn(inputs, labels)`` produces the student's (possibly adversarial) inputs.
      3. Three losses are formed:
         * cross-entropy against the labels,
         * L1 distance between student and teacher stage features,
         * temperature-scaled KL divergence between student and teacher logits.
      4. The three losses are aggregated with torchjd's ``backward`` and UPGrad over every
         parameter they depend on. This lets the feature-matching loss reach the encoder
         even though it does not depend on the logits.

    After each epoch the student is saved to ``./best_student_model.pth`` if its mean
    training loss (sum of the three losses) improved. It is then evaluated on
    ``test_dataloader``.

    Args:
        student (nn.Module): The student model to be trained; returns ``(logits, [f1, f2, f3])``.
        teacher (nn.Module): The teacher model providing supervision (kept in eval mode).
        train_dataloader (DataLoader): Dataloader for the training set.
        test_dataloader (DataLoader): Dataloader for the testing set.
        aug_fn (function): ``(inputs, labels) -> (inputs, labels)`` augmentation/attack.
        epochs (int): Number of epochs to train.
        lr (float): Learning rate.
        device (str): Device to run the training on ('cuda' or 'cpu').

    Returns:
        nn.Module: The student with its final-epoch weights (the checkpoint is not reloaded).
    """
    # Jacobian descent over a list of losses w.r.t. all parameters they depend on.
    from torchjd import backward as jd_backward

    optimizer = Adam(student.parameters(), lr=lr)
    criterion = CrossEntropyLoss()  # Supervised loss for classification
    latent_loss_fn = L1Loss()  # Loss for intermediate feature matching
    kl_div_fn = KLDivLoss(reduction="batchmean")  # KL divergence for logits matching
    aggregation_strategy = UPGrad()  # Strategy for aggregating the per-loss gradients
    temperature = 3.0

    teacher.eval()  # Teacher in evaluation mode
    student.to(device)
    teacher.to(device)

    # Variables to track best loss and save best model
    best_loss = float('inf')
    best_model_path = './best_student_model.pth'
    parent_dir = os.path.dirname(best_model_path)
    if parent_dir:  # os.makedirs('') would raise
        os.makedirs(parent_dir, exist_ok=True)

    for epoch in range(epochs):
        student.train()
        sums = {"total": 0.0, "cls": 0.0, "latent": 0.0, "kl": 0.0}

        for inputs, labels in tqdm(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Teacher targets come from the clean inputs.
            with torch.no_grad():
                t_logits, (t1, t2, t3) = teacher(inputs)

            # Apply data augmentation (basically - adversarial attack)
            inputs, labels = aug_fn(inputs, labels)

            s_logits, (s1, s2, s3) = student(inputs)

            classification_loss = criterion(s_logits, labels)
            intermediate_loss = (
                latent_loss_fn(s1, t1)
                + latent_loss_fn(s2, t2)
                + latent_loss_fn(s3, t3)
            )
            kl_loss = kl_div_fn(
                torch.log_softmax(s_logits / temperature, dim=1),
                torch.softmax(t_logits / temperature, dim=1),
            ) * (temperature ** 2)

            losses = [classification_loss, intermediate_loss, kl_loss]

            optimizer.zero_grad()
            jd_backward(losses, aggregation_strategy)
            optimizer.step()

            cls_v, lat_v, kl_v = (loss.item() for loss in losses)
            sums["cls"] += cls_v
            sums["latent"] += lat_v
            sums["kl"] += kl_v
            sums["total"] += cls_v + lat_v + kl_v

        # Epoch-average losses (guarded against an empty loader)
        n_batches = max(len(train_dataloader), 1)
        avg = {key: value / n_batches for key, value in sums.items()}
        avg_loss = avg["total"]
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f} | Classification Loss: {avg['cls']:.4f} | Intermediate Loss: {avg['latent']:.4f} | KL Loss: {avg['kl']:.4f}")

        # Save the best model if current loss is lower
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(student.state_dict(), best_model_path)
            print(f"Best model saved with loss {best_loss:.4f} at epoch {epoch + 1}.")

        # Evaluate model (evaluate returns accuracy, report and mean loss)
        acc, report, test_loss = evaluate(student, test_dataloader, device)
        print(f"Epoch {epoch + 1} | Accuracy: {acc:.2f} | Test Loss: {test_loss:.4f}")
        print(report)

    return student


## Training the student

In [None]:
teacher.eval()


def aug_fn(x, y):
    """Keep a batch clean with p=0.2, or attack it against the teacher with PGD (p=0.4) or FGSM (p=0.4)."""
    return attackit(x, y, teacher, 0.2, 0.4, 0.4)


# Pass the notebook's device explicitly; the function defaults to 'cuda'.
student = train_model_student_jd(student, teacher, train_dataloader, test_dataloader, aug_fn, epochs=3, lr=0.0001, device=device)

# Inference

In [None]:
# Build a fresh architecture for the checkpoint. Model() replaces layers of the
# encoder it is given, so re-wrapping the student's `encoder` object would also
# overwrite layers inside `student`.
model = Model(ResNet20(num_classes=2))
model.load_state_dict(torch.load('./best_student_model.pth', map_location=device))
model.to(device)
model.eval()  # inference: use BatchNorm running statistics

In [None]:
all_preds = []
all_labels = []
model.to(device)
# Without eval(), BatchNorm would normalise with per-batch statistics and keep
# updating its running stats; no_grad() avoids building autograd graphs.
model.eval()
with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs, _ = model(inputs)
        probs = F.softmax(outputs, dim=1)
        all_labels.extend(labels.cpu().numpy())
        preds = torch.argmax(probs, dim=1)
        all_preds.extend(preds.cpu().numpy())

In [None]:
# Initialize lists to store all predictions and labels
all_preds = []
all_labels = []

# Make sure the model is on the device and in inference mode: eval() makes
# BatchNorm use its running statistics instead of per-batch ones.
model.to(device)
model.eval()

# Loop through the test dataloader without tracking gradients
with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Get model outputs
        outputs, _ = model(inputs)

        # Calculate probabilities
        probs = F.softmax(outputs, dim=1)

        # Convert labels and predictions to numpy arrays for evaluation
        all_labels.extend(labels.cpu().numpy())
        preds = torch.argmax(probs, dim=1)
        all_preds.extend(preds.cpu().numpy())

# Calculate metrics (zero_division=0 gives the same numbers as the default but without warnings)
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)
conf_matrix = confusion_matrix(all_labels, all_preds)

# Print the results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print("Confusion Matrix:")
print(conf_matrix)