In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

# Hyperparameters
batch_size = 128
learning_rate = 0.001
epochs = 20
lambda_attention = 0.5  # Weight for attention loss in distillation
split_seed = 0  # Seed for the train/validation split so every run uses the same split

# Data loading and preprocessing
# Shared normalization: ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training-time pipeline: random flips/crops augment only the data the model learns from.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation-time pipeline: deterministic, so validation/test scores are not perturbed by
# random augmentation.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

full_train_data = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
# The same training images served with evaluation preprocessing; the validation split is
# drawn from this view. NOTE(review): this holds a second copy of the training set in memory.
full_train_data_eval = datasets.CIFAR10(root='./data', train=True, transform=eval_transform, download=True)
train_size = int(0.8 * len(full_train_data))
val_size = len(full_train_data) - train_size
train_data, _val_split = random_split(
    full_train_data,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
# Same held-out indices as the split above, but without augmentation.
val_data = torch.utils.data.Subset(full_train_data_eval, _val_split.indices)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_data = datasets.CIFAR10(root='./data', train=False, transform=eval_transform, download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Teacher Model
class TeacherNet(nn.Module):
    """Larger CNN: three 3x3 conv layers and one 2x2 max-pool, followed by a 2-layer MLP head.

    ``convs`` is a standalone ``nn.Sequential`` so the distillation loop can read the
    teacher's feature maps directly via ``teacher.convs(images)``.

    Args:
        image_size: Height/width of the square 3-channel inputs; only used to size ``fc1``.
            Defaults to 32 (CIFAR-10).
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, image_size=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self._to_linear = None
        # conv1..conv3 are the very same module objects placed in ``convs``, so the named
        # attributes and the Sequential share their parameters.
        self.convs = nn.Sequential(self.conv1, nn.ReLU(), self.conv2, nn.ReLU(), self.conv3, nn.ReLU(), nn.MaxPool2d(2))
        self._get_flattened_size(image_size)
        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def _get_flattened_size(self, image_size=32):
        """Set ``self._to_linear`` to the per-sample element count of the ``convs`` output."""
        with torch.no_grad():
            # Only the output shape matters; a zero probe avoids drawing from the global RNG,
            # so the weight init of later layers is not shifted by this dry run.
            probe = torch.zeros(1, 3, image_size, image_size)
            self._to_linear = self.convs(probe).numel()

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images ``x``."""
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Student Model
class StudentNet(nn.Module):
    """Smaller CNN: two 3x3 conv layers and one 2x2 max-pool, followed by a 2-layer MLP head.

    ``convs`` is a standalone ``nn.Sequential`` so the distillation loop can take the
    student's feature maps via ``student.convs(images)`` for the attention loss.

    Args:
        image_size: Height/width of the square 3-channel inputs; only used to size ``fc1``.
            Defaults to 32 (CIFAR-10).
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, image_size=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self._to_linear = None
        # conv1/conv2 are the very same module objects placed in ``convs``, so the named
        # attributes and the Sequential share their parameters.
        self.convs = nn.Sequential(self.conv1, nn.ReLU(), self.conv2, nn.ReLU(), nn.MaxPool2d(2))
        self._get_flattened_size(image_size)
        self.fc1 = nn.Linear(self._to_linear, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def _get_flattened_size(self, image_size=32):
        """Set ``self._to_linear`` to the per-sample element count of the ``convs`` output."""
        with torch.no_grad():
            # Only the output shape matters; a zero probe avoids drawing from the global RNG,
            # so the weight init of later layers is not shifted by this dry run.
            probe = torch.zeros(1, 3, image_size, image_size)
            self._to_linear = self.convs(probe).numel()

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images ``x``."""
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Attention Loss Function
def attention_loss(teacher_feature, student_feature):
    """Mean-squared error between L2-normalized spatial attention maps.

    Each map is built by squaring the features, averaging over dim 1 (the channel axis,
    assuming NCHW input — TODO confirm), flattening everything after the batch axis and
    L2-normalizing per sample. The two flattened maps must therefore have the same length,
    even though the channel counts may differ.
    """
    def _attention_map(feature):
        batch = feature.size(0)
        energy = feature.pow(2).mean(1)
        return F.normalize(energy.view(batch, -1))

    student_map = _attention_map(student_feature)
    teacher_map = _attention_map(teacher_feature)
    return F.mse_loss(student_map, teacher_map)

# Train function
def train(model, loader, optimizer, criterion, device=None):
    """Run one training epoch of ``model`` over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to return the
            per-batch *mean* (as the default ``nn.CrossEntropyLoss`` does).
        device: Device each batch is moved to. Defaults to the device of the model's first
            parameter (the model must then have at least one parameter).

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss over the epoch.
        ``accuracy`` is the fraction of samples whose arg-max logit matches the label.
        Both values are taken during the epoch, i.e. while the weights are being updated.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch = labels.size(0)
        # Weight each batch mean by its size so a short final batch doesn't skew the average.
        total_loss += loss.item() * batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch
    return total_loss / seen, correct / seen

# Evaluate function
def evaluate(model, loader, criterion, device=None):
    """Score ``model`` on ``loader`` without updating it.

    Args:
        model: Network to evaluate; switched to eval mode. Gradients are disabled.
        loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to return the
            per-batch *mean* (as the default ``nn.CrossEntropyLoss`` does).
        device: Device each batch is moved to. Defaults to the device of the model's first
            parameter (the model must then have at least one parameter).

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss. ``accuracy`` is
        the fraction of samples whose arg-max logit matches the label.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch = labels.size(0)
            # Weight each batch mean by its size so a short final batch doesn't skew the average.
            total_loss += criterion(outputs, labels).item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch
    return total_loss / seen, correct / seen

# Count Parameters
def calculate_params(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are excluded.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# CUDA setup: use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Initialize models (teacher first, then student) on the chosen device.
teacher = TeacherNet().to(device)
student = StudentNet().to(device)
# Each model gets its own Adam optimizer at the shared learning rate.
teacher_optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)
student_optimizer = optim.Adam(student.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Track Metrics: per-epoch accuracy histories, one list per (model, split).
teacher_train_acc_list = []
teacher_val_acc_list = []
teacher_test_acc_list = []
student_train_acc_list = []
student_val_acc_list = []
student_test_acc_list = []

# Train Teacher Model: each epoch does one pass of updates, then scores validation and test.
print("Training Teacher Model...")
for epoch in range(epochs):
    train_loss, train_acc = train(teacher, train_loader, teacher_optimizer, criterion)
    val_loss, val_acc = evaluate(teacher, val_loader, criterion)
    test_loss, test_acc = evaluate(teacher, test_loader, criterion)

    # Record this epoch's accuracy for each split.
    for history, value in (
        (teacher_train_acc_list, train_acc),
        (teacher_val_acc_list, val_acc),
        (teacher_test_acc_list, test_acc),
    ):
        history.append(value)

    print(f"Epoch {epoch+1}/{epochs}, Teacher Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")

# Save Teacher Model (final-epoch weights) so the distillation stage can reload them.
torch.save(teacher.state_dict(), 'teacher_model.pth')

# Train Student Model with Attention Distillation
print("\nTraining Student Model...")
# Reload the teacher checkpoint written above. weights_only=True restricts unpickling to
# tensors and plain containers (newer torch warns without it); map_location puts the
# tensors on the current device.
teacher.load_state_dict(torch.load('teacher_model.pth', map_location=device, weights_only=True))
teacher.eval()
best_val_acc = 0.0
best_epoch = None  # epoch whose weights are in 'student_model.pth', if any were saved

for epoch in range(epochs):
    student.train()
    train_correct = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # The teacher is frozen: its feature maps are only targets, so no graph is built.
        with torch.no_grad():
            teacher_features = teacher.convs(images)
        student_features = student.convs(images)
        student_output = student(images)
        classification_loss = criterion(student_output, labels)
        # Pull the student's spatial attention map toward the teacher's.
        att_loss = attention_loss(teacher_features, student_features)
        loss = classification_loss + lambda_attention * att_loss
        student_optimizer.zero_grad()
        loss.backward()
        student_optimizer.step()
        # Running accuracy on the training batches, measured the same way the teacher's
        # train accuracy is (during the epoch), so the two train curves are comparable
        # and no extra pass over the training set is needed.
        train_correct += (student_output.argmax(dim=1) == labels).sum().item()
    train_acc = train_correct / len(train_loader.dataset)

    val_loss, val_acc = evaluate(student, val_loader, criterion)
    test_loss, test_acc = evaluate(student, test_loader, criterion)

    student_train_acc_list.append(train_acc)
    student_val_acc_list.append(val_acc)
    student_test_acc_list.append(test_acc)

    print(f"Epoch {epoch+1}/{epochs}, Student Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")

    # Keep the checkpoint with the best validation accuracy (first epoch on ties).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(student.state_dict(), 'student_model.pth')

# Leave `student` holding the validation-selected weights rather than the last epoch's.
if best_epoch is not None:
    student.load_state_dict(torch.load('student_model.pth', map_location=device, weights_only=True))

# Final Model Evaluations: compare trainable-parameter counts of the two models.
teacher_params, student_params = calculate_params(teacher), calculate_params(student)
# Fraction of the teacher's parameter count that the student keeps.
compression_rate = student_params / teacher_params

print("\n".join([
    f"\nTeacher Model Parameters: {teacher_params}",
    f"Student Model Parameters: {student_params}",
    f"Compression Rate: {compression_rate:.4f}",
]))

# Plot Accuracy Graphs
def plot_accuracy(title, train_acc, val_acc, test_acc):
    """Plot per-epoch train, validation and test accuracy curves for one model.

    Epochs are numbered from 1 and each curve's x-axis is derived from its own length.
    The plot therefore still works when training stopped early or the global ``epochs``
    no longer matches the recorded histories.

    Args:
        title: Model name used in the legend labels and the figure title.
        train_acc, val_acc, test_acc: Sequences of per-epoch accuracies.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_acc) + 1), train_acc, label=f"{title} Train Accuracy")
    plt.plot(range(1, len(val_acc) + 1), val_acc, label=f"{title} Validation Accuracy")
    plt.plot(range(1, len(test_acc) + 1), test_acc, label=f"{title} Test Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title(f"{title} Model Accuracy Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot teacher and student accuracies (train / validation / test curves per model).
for model_name, curves in (
    ("Teacher", (teacher_train_acc_list, teacher_val_acc_list, teacher_test_acc_list)),
    ("Student", (student_train_acc_list, student_val_acc_list, student_test_acc_list)),
):
    plot_accuracy(model_name, *curves)

# Plot Graph Comparing Validation and Test Accuracy
def plot_validation_test_accuracy(title, val_acc, test_acc):
    """Overlay per-epoch validation (dashed) and test (solid) accuracy for one model.

    Epochs are numbered from 1 and each curve's x-axis is derived from its own length.
    The plot therefore does not depend on the global ``epochs`` matching the histories.

    Args:
        title: Model label used in the legend and the figure title.
        val_acc, test_acc: Sequences of per-epoch accuracies.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(val_acc) + 1), val_acc, label=f"{title} Validation Accuracy", linestyle='dashed')
    plt.plot(range(1, len(test_acc) + 1), test_acc, label=f"{title} Test Accuracy", linestyle='solid')
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title(f"{title}: Validation vs Test Accuracy Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot Teacher and Student Graphs: validation vs test accuracy for each model.
for model_label, val_history, test_history in (
    ("Teacher Model", teacher_val_acc_list, teacher_test_acc_list),
    ("Student Model", student_val_acc_list, student_test_acc_list),
):
    plot_validation_test_accuracy(model_label, val_history, test_history)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 51.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Training Teacher Model...
Epoch 1/20, Teacher Train Acc: 0.3855, Val Acc: 0.4825, Test Acc: 0.4796
Epoch 2/20, Teacher Train Acc: 0.5216, Val Acc: 0.5674, Test Acc: 0.5657
Epoch 3/20, Teacher Train Acc: 0.5946, Val Acc: 0.6221, Test Acc: 0.6250
Epoch 4/20, Teacher Train Acc: 0.6459, Val Acc: 0.6578, Test Acc: 0.6510
Epoch 5/20, Teacher Train Acc: 0.6746, Val Acc: 0.6877, Test Acc: 0.6832
Epoch 6/20, Teacher Train Acc: 0.7035, Val Acc: 0.7018, Test Acc: 0.7009
Epoch 7/20, Teacher Train Acc: 0.7217, Val Acc: 0.7087, Test Acc: 0.7073
Epoch 8/20, Teacher Train Acc: 0.7332, Val Acc: 0.7131, Test Acc: 0.7145
Epoch 9/20, Teacher Train Acc: 0.7456, Val Acc: 0.7333, Test Acc: 0.7179
Epoch 10/20, Teacher Train Acc: 0.7582, Val Acc: 0.7408, Test Acc: 0.7361
Epoch 11/20, Teacher Train Acc: 0.7662, Val Acc: 0.7436, Test Acc: 0.7332
Epoch 12/20, Teacher Train Acc: 0.7754, Val Acc: 0.7555, Test Acc: 0.7455
Epoch 

  teacher.load_state_dict(torch.load('teacher_model.pth'))


Epoch 1/20, Student Train Acc: 0.4978, Val Acc: 0.4993, Test Acc: 0.5003
Epoch 2/20, Student Train Acc: 0.5913, Val Acc: 0.5808, Test Acc: 0.5865
Epoch 3/20, Student Train Acc: 0.6265, Val Acc: 0.6213, Test Acc: 0.6178
Epoch 4/20, Student Train Acc: 0.6532, Val Acc: 0.6477, Test Acc: 0.6434
Epoch 5/20, Student Train Acc: 0.6710, Val Acc: 0.6610, Test Acc: 0.6590
Epoch 6/20, Student Train Acc: 0.6834, Val Acc: 0.6772, Test Acc: 0.6731
Epoch 7/20, Student Train Acc: 0.6994, Val Acc: 0.6851, Test Acc: 0.6870
Epoch 8/20, Student Train Acc: 0.7097, Val Acc: 0.6994, Test Acc: 0.6928
Epoch 9/20, Student Train Acc: 0.7136, Val Acc: 0.7014, Test Acc: 0.6927
Epoch 10/20, Student Train Acc: 0.7208, Val Acc: 0.7009, Test Acc: 0.7004
Epoch 11/20, Student Train Acc: 0.7310, Val Acc: 0.7140, Test Acc: 0.7037


In [None]:
# Calculate Parameter and Accuracy Reduction
# Report the test accuracy of the models that were actually kept. Taking max() over the
# per-epoch test accuracies would select on the test set and overstate both models.
# Teacher: the final-epoch weights are the ones saved and used for distillation.
teacher_accuracy = teacher_test_acc_list[-1] * 100  # Convert to percentage
# Student: the saved checkpoint is the first epoch reaching the best validation accuracy.
student_best_epoch = student_val_acc_list.index(max(student_val_acc_list))
student_accuracy = student_test_acc_list[student_best_epoch] * 100  # Convert to percentage
param_reduction = ((teacher_params - student_params) / teacher_params) * 100
# Absolute difference of two percentages, i.e. percentage points (not a relative %).
accuracy_reduction = teacher_accuracy - student_accuracy

# Plot Parameter and Accuracy Reduction
plt.figure(figsize=(8, 6))
plt.bar(['Parameter Reduction (%)', 'Accuracy Reduction (pp)'], [param_reduction, accuracy_reduction], color=['skyblue', 'salmon'])
plt.ylabel('Reduction (% for parameters, percentage points for accuracy)')
plt.title('Reduction in Parameters and Accuracy from Teacher to Student Model')
plt.show()

# Normalize parameters to millions so they can share one bar axis with accuracy percentages.
teacher_params_normalized, student_params_normalized = teacher_params / 1e6, student_params / 1e6

# Each bar as (label, value, colour), in teacher-then-student order.
bar_specs = [
    ('Teacher Parameters (M)', teacher_params_normalized, 'skyblue'),
    ('Teacher Accuracy (%)', teacher_accuracy, 'blue'),
    ('Student Parameters (M)', student_params_normalized, 'salmon'),
    ('Student Accuracy (%)', student_accuracy, 'red'),
]
categories = [label for label, _, _ in bar_specs]
values = [value for _, value, _ in bar_specs]

# Plot Comparison of Parameters and Accuracy
plt.figure(figsize=(10, 6))
plt.bar(categories, values, color=[colour for _, _, colour in bar_specs])
plt.ylabel('Values (Parameters in Millions, Accuracy in %)')
plt.title('Comparison of Parameters and Accuracy between Teacher and Student Models (Normalized)')
plt.show()
