In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
# Define the custom mentee model
class MenteeModel(nn.Module):
    """Small LeNet-style CNN used as the mentee (student) network.

    Two conv -> ReLU -> max-pool stages feed three fully connected layers
    that end in 10 output logits. ``fc1`` takes ``16*5*5`` features, which
    is what a 3-channel 32x32 input yields after the two 5x5 convolutions
    and 2x2 poolings (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in the same order as they are applied; the
        # single pooling module is shared by both convolutional stages.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (un-softmaxed) class logits for a batch of images."""
        # Convolutional feature extractor: conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Collapse everything except the batch dimension.
        x = x.flatten(start_dim=1)

        # Hidden fully connected layers with ReLU activations.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))

        # Final linear layer produces the logits (no activation).
        return self.fc3(x)

# Mentor network: a pre-trained ResNet18 with every existing parameter frozen.
resnet18 = models.resnet18(pretrained=True)
resnet18.requires_grad_(False)

# Replace the final fully connected layer with a fresh 10-output head.
# The new layer is created after the freeze, so its parameters keep the
# default requires_grad=True and are the only mentor weights that train.
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 10)  # Assuming 10 classes for example

# Mentee network: the small custom CNN.
mentee_model = MenteeModel()

# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# CIFAR-10 training split (shuffled) and test split (fixed order, used as
# the validation set), both served in mini-batches of 64.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
val_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
)

# Cross-entropy objectives, one per network.
criterion_mentor = nn.CrossEntropyLoss()
criterion_mentee = nn.CrossEntropyLoss()

# SGD with momentum for each network, using identical hyper-parameters.
optimizer_mentor = optim.SGD(resnet18.parameters(), lr=0.001, momentum=0.9)
optimizer_mentee = optim.SGD(mentee_model.parameters(), lr=0.001, momentum=0.9)

# TensorBoard writer using the default log directory.
writer = SummaryWriter()

# Training loop.
#
# Each mini-batch performs two independent updates:
#   * the mentor (ResNet18 with a trainable head) is trained on the labels
#     with its own cross-entropy loss;
#   * the mentee is trained on the labels plus a distillation term that pulls
#     its temperature-softened predictions towards the mentor's.
num_epochs = 10
temperature = 5     # softening temperature for distillation; loop-invariant
log_interval = 100  # number of mini-batches averaged per TensorBoard point

for epoch in range(num_epochs):
    running_loss_mentor = 0.0
    running_loss_mentee = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer_mentor.zero_grad()
        optimizer_mentee.zero_grad()

        # Forward pass for mentor
        outputs_mentor = resnet18(inputs)
        loss_mentor = criterion_mentor(outputs_mentor, labels)

        # Forward pass for mentee
        outputs_mentee = mentee_model(inputs)
        loss_mentee = criterion_mentee(outputs_mentee, labels)

        # Knowledge distillation loss.
        # F.kl_div expects the *student's log-probabilities* as `input` and
        # the *teacher's probabilities* as `target`. The mentor logits are
        # detached so the distillation term only produces gradients for the
        # mentee and never drags the mentor towards the mentee.
        soft_targets = F.softmax(outputs_mentor.detach() / temperature, dim=1)
        log_soft_mentee = F.log_softmax(outputs_mentee / temperature, dim=1)
        # Dividing logits by T shrinks the soft-target gradients by ~1/T^2;
        # multiplying by T^2 keeps them on the same scale as the hard-label loss.
        distillation_loss = F.kl_div(
            log_soft_mentee, soft_targets, reduction='batchmean'
        ) * (temperature ** 2)
        total_loss = loss_mentee + distillation_loss

        # Backward pass for mentor: its own supervised loss.
        loss_mentor.backward()

        # Backward pass for mentee: hard-label loss + distillation, exactly
        # once. The two graphs are disjoint (mentor outputs are detached),
        # so no retain_graph is needed.
        total_loss.backward()

        # Optimizer steps
        optimizer_mentor.step()
        optimizer_mentee.step()

        # Accumulate losses for the running average logged below.
        running_loss_mentor += loss_mentor.item()
        running_loss_mentee += loss_mentee.item()

        if i % log_interval == log_interval - 1:  # Log every `log_interval` mini-batches
            global_step = epoch * len(train_loader) + i
            writer.add_scalar('mentor_training_loss', running_loss_mentor / log_interval, global_step)
            writer.add_scalar('mentee_training_loss', running_loss_mentee / log_interval, global_step)
            # Start a fresh averaging window.
            running_loss_mentor = 0.0
            running_loss_mentee = 0.0

# Make sure all pending TensorBoard events are written to disk.
writer.close()

print('Finished Training')

# Save the trained models
torch.save(resnet18.state_dict(), 'resnet18_mentor.pth')
torch.save(mentee_model.state_dict(), 'mentee_model.pth')


  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified
Finished Training
