<a href="https://colab.research.google.com/github/anushkaa-ambuj/Design-Credit-Project1/blob/main/Basic_KD_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.models import resnet18, resnet101

In [2]:
# define the device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# define the batch size
batch_size = 128

# define the temperature parameter
T = 5

# define the weighting factor for the soft targets loss
alpha = 0.5

In [4]:
# define the data transformations
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [5]:
# load the training data
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data




In [6]:
# load the validation data
val_dataset = CIFAR10(root='./data', train=False, transform=transform, download=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

Files already downloaded and verified


In [7]:
# define the teacher model
teacher_model = resnet101(pretrained=True)
teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 10)
teacher_model.to(device)
teacher_model.eval()

Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


  0%|          | 0.00/171M [00:00<?, ?B/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
# define the student model
student_model = resnet18(pretrained=False)
student_model.fc = nn.Linear(student_model.fc.in_features, 10)
student_model.to(device)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [13]:
# define the loss function for the teacher
teacher_criterion = nn.CrossEntropyLoss().to(device)

# define the loss function for the student
student_criterion = nn.KLDivLoss(reduction='batchmean').to(device)

# define the optimizer for the student model
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

In [16]:
# train the student model using knowledge distillation
for epoch in range(10):
    student_model.train()
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)
        teacher_outputs = teacher_model(inputs).detach()
        teacher_probs = torch.softmax(teacher_outputs / T, dim=1)
        student_outputs = student_model(inputs)
        student_probs = torch.softmax(student_outputs / T, dim=1)
        loss = student_criterion(student_probs, teacher_probs) * T * T + teacher_criterion(student_outputs, labels) * (1. - alpha)
        loss.backward()
        optimizer.step()
    student_model.eval()
    with torch.no_grad():
        val_loss = 0
        val_acc = 0
        for inputs, labels in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = student_model(inputs)
            val_loss += teacher_criterion(outputs, labels)
            val_acc += (outputs.argmax(dim=1) == labels).float().mean()
        val_loss /= len(val_dataloader)
        print(f'Epoch {epoch+1}: val_loss={val_loss:.4f}, val_acc={val_acc:.4f}')

Epoch 1: val_loss=1.0634, val_acc=50.4375
Epoch 2: val_loss=0.9424, val_acc=53.2656
Epoch 3: val_loss=0.8325, val_acc=56.2031
Epoch 4: val_loss=0.8988, val_acc=55.6562
Epoch 5: val_loss=0.8056, val_acc=57.3359
Epoch 6: val_loss=0.7722, val_acc=58.9609
Epoch 7: val_loss=0.8244, val_acc=58.4531
Epoch 8: val_loss=0.7595, val_acc=59.5156
Epoch 9: val_loss=0.8023, val_acc=58.3906
Epoch 10: val_loss=0.7193, val_acc=59.8906
