In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from models_cifar100.resnet import ResNet18
# perform pruning on the model
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import torch
from torch.utils.data.dataloader import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
from models_cifar100.resnet import ResNet18

In [3]:
# device check
if torch.cuda.is_available():
    print("CUDA is available. Using GPU.")
    device = torch.device("cuda")
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")

CUDA is available. Using GPU.


In [4]:
# Load the CIFAR-10 dataset
## Normalization adapted for CIFAR10
normalize_scratch = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# Transforms is a list of transformations applied on the 'raw' dataset before the data is fed to the network. 
# Here, Data augmentation (RandomCrop and Horizontal Flip) are applied to each batch, differently at each epoch, on the training set data only
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize_scratch,
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize_scratch,
])
### The data from CIFAR10 will be downloaded in the following folder
rootdir = './data/cifar10'
# Load the CIFAR-10 dataset
c10train = CIFAR10(rootdir,train=True,download=True,transform=transform_train)
c10test = CIFAR10(rootdir,train=False,download=True,transform=transform_test)
# Create DataLoaders
trainloader = DataLoader(c10train,batch_size=32,shuffle=True)
testloader = DataLoader(c10test,batch_size=32)

Files already downloaded and verified
Files already downloaded and verified


# disilition from resnet 101

In [7]:
# Load the teacher model
from torchvision.models import resnet101

# Load the ResNet-101 model
teacher_model = resnet101(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to C:\Users\cecil/.cache\torch\hub\checkpoints\resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:13<00:00, 12.8MB/s] 


# load our prunned and quantizated resnet 

Reload the Pruned Model Correctly
- Reinitialize or recreate the model architecture.
- Reapply the pruning to the model exactly as you did before saving.
- Load the pruned state dictionary.

In [6]:
# Load model with metadata
checkpoint = torch.load('globale_pruned_0.2_retrained_50epochs.pth')
model_pruned = ResNet18()  
for module in model_pruned.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
model_pruned.load_state_dict(checkpoint['net'])
model_pruned.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

define a loss function that includes both the traditional classification loss and the distillation loss. The distillation loss is often computed using the Kullback-Leibler divergence between the softened outputs of the teacher and the student

T is the temperature for scaling the logits, and alpha is a factor that balances the importance of the distillation loss versus the classification loss.

In [None]:
import torch.nn.functional as F

def distillation_loss(y_student, y_teacher, y_true, T, alpha):
    loss_kl = F.kl_div(F.log_softmax(y_student/T, dim=1),
                       F.softmax(y_teacher/T, dim=1),
                       reduction='batchmean') * (T * T * alpha)
    loss_ce = F.cross_entropy(y_student, y_true) * (1 - alpha)
    return loss_kl + loss_ce


In [None]:
import torch.optim as optim
from torch.optim import Adam

learning_rate = 0.001
num_epochs = 10
T = 2.0
alpha = 0.5

# Optimizer setup
optimizer = Adam(model_pruned.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model_pruned.train()  # Set the student model to training mode
    teacher_model.eval()  # Set the teacher model to eval mode
    loss = 0.0

    for i, (data, target) in enumerate(trainloader):
        data, target = data.to(device), target.to(device)
        
        # Teacher model inference
        with torch.no_grad():
            teacher_output = teacher_model(data)
        
        # Student model forward pass
        student_output = model_pruned(data)
        
        # Compute distillation loss
        loss = distillation_loss(student_output, teacher_output, target, T, alpha)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Print average loss per epoch
    print(f'Epoch {epoch + 1}, Loss: {loss / len(trainloader)}')
