In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
import os

In [None]:
# Report the torch build / CUDA visibility and choose the compute device.
# Falls back to the CPU when CUDA is not available so the notebook can still
# be executed on a CPU-only host.
print("La version de torch est : ",torch.__version__)
print("Le calcul GPU est disponible ? ", torch.cuda.is_available())
print ('Available devices ', torch.cuda.device_count())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Only query the current CUDA device index when CUDA is actually usable.
    print ('Current cuda device ', torch.cuda.current_device())
else:
    print ('Current cuda device ', None)


La version de torch est :  1.13.1+cu116
Le calcul GPU est disponible ?  True
Available devices  1
Current cuda device  0


In [None]:
# Load the CIFAR-10 train and test splits (downloaded into ./data if needed),
# converting every image to a tensor, and wrap each split in a DataLoader that
# yields shuffled mini-batches of 128 samples.
train_dataset = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=True)


Files already downloaded and verified
Files already downloaded and verified


**Définition de la fonction de calcul de l'accuracy des modèles teacher**

In [None]:
def check_accuracy(loader, model, model_name):
    """Evaluate ``model`` on ``loader`` and print / return its top-1 accuracy.

    The model is switched to eval mode for the duration of the evaluation and
    is then put back in whatever mode (train or eval) it was in when the
    function was called, so evaluating a frozen model does not silently
    re-enable training behaviour.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches.
        model: ``nn.Module`` returning per-class scores, with the class axis
            at dimension 1 (the prediction is ``scores.max(1)``).
        model_name: label printed in front of the result line.

    Returns:
        The accuracy in percent (``100 * correct / total``). As before, this
        is a 0-dim tensor because the correct-count is accumulated as a tensor.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    num_correct = 0
    num_samples = 0

    # Run the batches on the device the model actually lives on, rather than
    # on a notebook-level global that may not match the model's placement.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device=eval_device)
                y = y.to(device=eval_device)

                scores = model(x)
                _, predictions = scores.max(1)
                num_correct += (predictions == y).sum()
                num_samples += predictions.size(0)

            print(model_name, f' Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

    accuracy = 100*num_correct/num_samples
    return accuracy

Teacher  Got 8287 / 10000 with accuracy 82.87
Baseline student  Got 6939 / 10000 with accuracy 69.39


## **Entraînement des modèles teachers** (Full-KD)

**ResNet18**

In [None]:
## Select the compute device and load the model onto it.
## This must be done before the optimizer is declared, otherwise the optimized
## parameters will not be the same ones!
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Train the teacher starting from pretrained weights to try to save time.
teacher_model_18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = teacher_model_18.fc.in_features
# Swap the final fully connected layer for a 10-output one.
teacher_model_18.fc = nn.Linear(num_ftrs, 10)
teacher_model_18 = teacher_model_18.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model_18.parameters(), lr=0.1, momentum=0.9)
EPOCHS = 101

loss_list_18 = []      # per-batch training loss, stored as Python floats
test_accuracy_18 = []  # test accuracy measured every 10 epochs

# Checkpoints are written inside the current working directory. os.path.join
# inserts the path separator that plain string concatenation was missing.
es_checkpoint_path_18 = os.path.join(os.getcwd(), "ES_teacher_model_18")
checkpoint_path_18 = os.path.join(os.getcwd(), "teacher_model_18")

for epoch in range(EPOCHS):
    start = time.time()
    cumloss = 0
    for i,(inputs, labels) in enumerate(train_loader):
        ## Move the batch to the compute device
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = teacher_model_18(inputs)
        loss = criterion(outputs, labels)

        optimizer_teacher.zero_grad()
        loss.backward()
        optimizer_teacher.step()

        # Keep only the scalar value: storing the tensor itself would keep
        # device memory (and its autograd history) alive for every batch.
        batch_loss = loss.item()
        cumloss += batch_loss
        loss_list_18.append(batch_loss)

    # Test performance every 10 epochs
    if epoch%10 == 0 :
        acc = check_accuracy(test_loader, teacher_model_18, "ResNet 18 acc. : ")
        test_accuracy_18.append(acc)

    # Save an early-stopped (ES) snapshot of the teacher (loop index 50)
    if epoch == 50 :
        torch.save(teacher_model_18.state_dict(), es_checkpoint_path_18)

    print('Epoch %d, loss: %.3f' %
          (epoch + 1, cumloss / len(train_loader)))
    print(f'epoch time : {time.time() - start}')

torch.save(teacher_model_18.state_dict(), checkpoint_path_18)
print("Modèle enregistré à : ", checkpoint_path_18)

**ResNet34**

In [None]:
## Select the compute device and load the model onto it.
## This must be done before the optimizer is declared, otherwise the optimized
## parameters will not be the same ones!
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Train the teacher starting from pretrained weights to try to save time.
teacher_model_34 = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
num_ftrs = teacher_model_34.fc.in_features
# Swap the final fully connected layer for a 10-output one.
teacher_model_34.fc = nn.Linear(num_ftrs, 10)
teacher_model_34 = teacher_model_34.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model_34.parameters(), lr=0.1, momentum=0.9)
EPOCHS = 101

loss_list_34 = []      # per-batch training loss, stored as Python floats
test_accuracy_34 = []  # test accuracy measured every 10 epochs

# Checkpoints are written inside the current working directory. os.path.join
# inserts the path separator that plain string concatenation was missing.
es_checkpoint_path_34 = os.path.join(os.getcwd(), "ES_teacher_model_34")
checkpoint_path_34 = os.path.join(os.getcwd(), "teacher_model_34")

for epoch in range(EPOCHS):
    start = time.time()
    cumloss = 0
    for i,(inputs, labels) in enumerate(train_loader):
        ## Move the batch to the compute device
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = teacher_model_34(inputs)
        loss = criterion(outputs, labels)

        optimizer_teacher.zero_grad()
        loss.backward()
        optimizer_teacher.step()

        # Keep only the scalar value: storing the tensor itself would keep
        # device memory (and its autograd history) alive for every batch.
        batch_loss = loss.item()
        cumloss += batch_loss
        loss_list_34.append(batch_loss)

    # Test performance every 10 epochs
    if epoch%10 == 0 :
        acc = check_accuracy(test_loader, teacher_model_34, "ResNet 34 acc. : ")
        test_accuracy_34.append(acc)

    # Save an early-stopped (ES) snapshot of the teacher (loop index 50)
    if epoch == 50 :
        torch.save(teacher_model_34.state_dict(), es_checkpoint_path_34)

    print('Epoch %d, loss: %.3f' %
          (epoch + 1, cumloss / len(train_loader)))
    print(f'epoch time : {time.time() - start}')

torch.save(teacher_model_34.state_dict(), checkpoint_path_34)
print("Modèle enregistré à : ", checkpoint_path_34)

**ResNet50**

In [None]:
## Select the compute device and load the model onto it.
## This must be done before the optimizer is declared, otherwise the optimized
## parameters will not be the same ones!
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Train the teacher starting from pretrained weights to try to save time.
# The builder must match the weights enum: ResNet50 weights go with resnet50
# (the cell previously called the resnet34 builder here).
teacher_model_50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_ftrs = teacher_model_50.fc.in_features
# Swap the final fully connected layer for a 10-output one.
teacher_model_50.fc = nn.Linear(num_ftrs, 10)
teacher_model_50 = teacher_model_50.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model_50.parameters(), lr=0.1, momentum=0.9)
EPOCHS = 101

loss_list_50 = []      # per-batch training loss, stored as Python floats
test_accuracy_50 = []  # test accuracy measured every 10 epochs

# Checkpoints are written inside the current working directory. os.path.join
# inserts the path separator that plain string concatenation was missing.
es_checkpoint_path_50 = os.path.join(os.getcwd(), "ES_teacher_model_50")
checkpoint_path_50 = os.path.join(os.getcwd(), "teacher_model_50")

for epoch in range(EPOCHS):
    start = time.time()
    cumloss = 0
    for i,(inputs, labels) in enumerate(train_loader):
        ## Move the batch to the compute device
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = teacher_model_50(inputs)
        loss = criterion(outputs, labels)

        optimizer_teacher.zero_grad()
        loss.backward()
        optimizer_teacher.step()

        # Keep only the scalar value: storing the tensor itself would keep
        # device memory (and its autograd history) alive for every batch.
        batch_loss = loss.item()
        cumloss += batch_loss
        loss_list_50.append(batch_loss)

    # Test performance every 10 epochs
    if epoch%10 == 0 :
        acc = check_accuracy(test_loader, teacher_model_50, "ResNet 50 acc. : ")
        test_accuracy_50.append(acc)

    # Save an early-stopped (ES) snapshot of the teacher (loop index 50)
    if epoch == 50 :
        torch.save(teacher_model_50.state_dict(), es_checkpoint_path_50)

    print('Epoch %d, loss: %.3f' %
          (epoch + 1, cumloss / len(train_loader)))
    print(f'epoch time : {time.time() - start}')

torch.save(teacher_model_50.state_dict(), checkpoint_path_50)
print("Modèle enregistré à : ", checkpoint_path_50)

**ResNet152**

In [None]:
## Select the compute device and load the model onto it.
## This must be done before the optimizer is declared, otherwise the optimized
## parameters will not be the same ones!
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Train the teacher starting from pretrained weights to try to save time.
teacher_model_152 = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
num_ftrs = teacher_model_152.fc.in_features
# Swap the final fully connected layer for a 10-output one.
teacher_model_152.fc = nn.Linear(num_ftrs, 10)
teacher_model_152 = teacher_model_152.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model_152.parameters(), lr=0.1, momentum=0.9)
EPOCHS = 101

loss_list_152 = []      # per-batch training loss, stored as Python floats
test_accuracy_152 = []  # test accuracy measured every 10 epochs

# Checkpoints are written inside the current working directory. os.path.join
# inserts the path separator that plain string concatenation was missing.
es_checkpoint_path_152 = os.path.join(os.getcwd(), "ES_teacher_model_152")
checkpoint_path_152 = os.path.join(os.getcwd(), "teacher_model_152")

for epoch in range(EPOCHS):
    start = time.time()
    cumloss = 0
    for i,(inputs, labels) in enumerate(train_loader):
        ## Move the batch to the compute device
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = teacher_model_152(inputs)
        loss = criterion(outputs, labels)

        optimizer_teacher.zero_grad()
        loss.backward()
        optimizer_teacher.step()

        # Keep only the scalar value: storing the tensor itself would keep
        # device memory (and its autograd history) alive for every batch.
        # The history list for this model is loss_list_152 (the cell used to
        # append to an undefined `loss_list`).
        batch_loss = loss.item()
        cumloss += batch_loss
        loss_list_152.append(batch_loss)

    # Test performance every 10 epochs
    if epoch%10 == 0 :
        acc = check_accuracy(test_loader, teacher_model_152, "ResNet 152 acc. : ")
        test_accuracy_152.append(acc)

    # Save an early-stopped (ES) snapshot of the teacher (loop index 50)
    if epoch == 50 :
        torch.save(teacher_model_152.state_dict(), es_checkpoint_path_152)

    print('Epoch %d, loss: %.3f' %
          (epoch + 1, cumloss / len(train_loader)))
    print(f'epoch time : {time.time() - start}')

torch.save(teacher_model_152.state_dict(), checkpoint_path_152)
print("Modèle enregistré à : ", checkpoint_path_152)

**Comparaison des performances en test pour les versions full trained et early stopped des modèles teacher**

In [5]:
def _load_teacher_checkpoint(build_fn, checkpoint_name):
    """Rebuild a 10-class ResNet and load a saved state dict into it.

    Args:
        build_fn: torchvision model constructor (e.g. ``models.resnet18``).
        checkpoint_name: file name of the state dict saved by a training cell.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    candidates = [
        os.path.join(os.getcwd(), checkpoint_name),
        # Location produced when the name is concatenated to os.getcwd()
        # without a path separator; kept as a fallback for such files.
        os.getcwd() + checkpoint_name,
    ]
    checkpoint_path = next((p for p in candidates if os.path.exists(p)), candidates[0])

    # The saved teachers have a 10-output `fc` layer, so the skeleton must be
    # built with the same head or load_state_dict reports a size mismatch.
    model = build_fn(num_classes=10)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model.to(device).eval()


# Early-stopped (ES) teachers: load the "ES_" snapshots, not the fully
# trained checkpoints, so the comparison below is between different weights.
ES_teacher_model_18 = _load_teacher_checkpoint(models.resnet18, "ES_teacher_model_18")
ES_teacher_model_34 = _load_teacher_checkpoint(models.resnet34, "ES_teacher_model_34")
ES_teacher_model_50 = _load_teacher_checkpoint(models.resnet50, "ES_teacher_model_50")
ES_teacher_model_152 = _load_teacher_checkpoint(models.resnet152, "ES_teacher_model_152")

teacher_model_18.eval()
teacher_model_34.eval()
teacher_model_50.eval()
teacher_model_152.eval()

# Test accuracy of each fully trained teacher next to its early-stopped version.
check_accuracy(test_loader, teacher_model_18, "ResNet 18 acc. : ")
check_accuracy(test_loader, ES_teacher_model_18, "ES ResNet 18 acc. : ")
check_accuracy(test_loader, teacher_model_34, "ResNet 34 acc. : ")
check_accuracy(test_loader, ES_teacher_model_34, "ES ResNet 34 acc. : ")
check_accuracy(test_loader, teacher_model_50, "ResNet 50 acc. : ")
check_accuracy(test_loader, ES_teacher_model_50, "ES ResNet 50 acc. : ")
check_accuracy(test_loader, teacher_model_152, "ResNet 152 acc. : ")
check_accuracy(test_loader, ES_teacher_model_152, "ES ResNet 152 acc. : ")

# Leave every teacher in eval mode for the following cells, whatever mode the
# evaluation helper left it in.
for _teacher in (teacher_model_18, ES_teacher_model_18,
                 teacher_model_34, ES_teacher_model_34,
                 teacher_model_50, ES_teacher_model_50,
                 teacher_model_152, ES_teacher_model_152):
    _teacher.eval()

FileNotFoundError: ignored

**Définition de la fonction de loss pour la distillation**

In [29]:
# Loss used to train the student with knowledge distillation
def distillation_loss(student_outputs, teacher_outputs, labels, temperature):
    """Temperature-scaled soft-target cross-entropy between teacher and student.

    Computes ``-T**2 * sum(softmax(teacher / T) * log_softmax(student / T))``
    with the softmax taken over dimension 1 and the sum over the last
    dimension. The result is NOT reduced over the batch: for 2-D
    ``(batch, classes)`` inputs it holds one value per sample, and the caller
    is expected to reduce it (e.g. with ``.mean()``).

    Args:
        student_outputs: student logits.
        teacher_outputs: teacher logits, same shape as ``student_outputs``.
        labels: unused; kept so existing call sites keep working.
        temperature: softening temperature ``T`` applied to both sets of logits.

    Returns:
        Tensor of unreduced distillation losses.
    """
    # The teacher only supplies fixed targets: detach its logits so that no
    # gradient is propagated back into the teacher network.
    soft_teacher_labels = nn.functional.softmax(teacher_outputs.detach() / temperature, dim=1)
    log_soft_student_labels = nn.functional.log_softmax(student_outputs / temperature, dim=1)

    # The T**2 factor multiplies the soft-target term, as in the original formula.
    loss = -(temperature**2)*torch.sum(torch.mul(soft_teacher_labels,log_soft_student_labels), dim=-1)

    return loss

**Entraînement des students à partir des différents teachers**

In [None]:
# Teacher networks to distil from. For each ResNet depth the fully trained
# model comes first, followed by its early-stopped ("ES_") counterpart.
teachers = [
    teacher_model_18, ES_teacher_model_18,
    teacher_model_34, ES_teacher_model_34,
    teacher_model_50, ES_teacher_model_50,
    teacher_model_152, ES_teacher_model_152,
]
# Matching labels, generated in the same order as the `teachers` list.
teachers_name = [
    f"{prefix}teacher_model_{depth}"
    for depth in (18, 34, 50, 152)
    for prefix in ("", "ES_")
]

In [31]:
# Distil a ResNet18 student from every teacher, twice:
#   ES_KD == 0 : the distillation term is kept for the whole run;
#   ES_KD == 1 : the distillation term is switched off from epoch 35 onwards
#                (alpha is set to 1, leaving only the cross-entropy term).
# One (teacher_name, ES_KD, loss_list, test_accuracy) tuple per run is
# appended to `list_res`.

# Training hyper-parameters
num_epochs = 101
temperature = 4
learning_rate = 0.0001

# The training loader does not depend on the teacher: build it once.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

list_res = []
for ES_KD in [0,1] :
    for teacher_model, teacher_name in zip(teachers, teachers_name) :

        # Weight of the hard-label cross-entropy term; (1 - alpha) weights the
        # distillation term. Reset for every run because the early-stop
        # branch below overwrites it.
        alpha = 0.9

        # The student starts from pretrained torchvision weights
        # (ResNet18_Weights.DEFAULT), with a new 10-output head.
        student_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        num_ftrs = student_model.fc.in_features
        student_model.fc = nn.Linear(num_ftrs, 10)
        # Move the complete model (new head included) to the device before
        # handing its parameters to the optimizer.
        student_model = student_model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer_student = optim.SGD(student_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.2)

        # The teacher is a fixed reference: keep it on the device and in eval
        # mode so its forward passes do not run in training mode.
        teacher_model = teacher_model.to(device)
        teacher_model.eval()

        loss_list = []      # per-batch training loss (Python floats)
        test_accuracy = []  # student test accuracy every 10 epochs
        for epoch in range(num_epochs):
            start = time.time()
            cumloss = 0
            # Early stop of the knowledge distillation
            if ES_KD == 1 and epoch >= 35 :
                alpha = 1

            for i,(inputs, labels) in enumerate(train_loader):
                ## Move the batch to the compute device
                inputs, labels = inputs.to(device), labels.to(device)
                # No autograd graph is needed for the teacher's forward pass.
                with torch.no_grad():
                    teacher_outputs = teacher_model(inputs)
                student_outputs = student_model(inputs)

                # distillation_loss is unreduced, criterion is a scalar: the
                # sum broadcasts and is then averaged over the batch.
                loss = (1-alpha)*distillation_loss(student_outputs, teacher_outputs, labels, temperature) + alpha*criterion(student_outputs, labels)

                loss_total = loss.mean()

                optimizer_student.zero_grad()
                loss_total.backward()
                optimizer_student.step()

                # Store the scalar only, not the graph-attached tensor.
                batch_loss = loss_total.item()
                cumloss += batch_loss
                loss_list.append(batch_loss)

            # Test performance every 10 epochs
            if epoch%10 == 0 :
                # Label the line with the teacher's name; formatting the model
                # objects themselves would print their full module repr.
                acc = check_accuracy(test_loader, student_model, "resnet18/{0} acc. at epoch {1}: ".format(teacher_name, epoch))
                test_accuracy.append(acc)

            print('Epoch %d, loss: %.3f' %
                  (epoch + 1, cumloss / len(train_loader)))
            print(f'epoch time : {time.time() - start}')

        res_tuple = (teacher_name, ES_KD, loss_list, test_accuracy)
        list_res.append(res_tuple)
        # Print a short summary rather than the full per-batch loss history.
        print((teacher_name, ES_KD, len(loss_list), test_accuracy))
        print(len(list_res))

Device :  cuda
Epoch 1, loss: 4.155
epoch time : 41.449867248535156
Epoch 2, loss: 4.145
epoch time : 40.37718915939331
Epoch 3, loss: 4.146
epoch time : 39.928427934646606
Epoch 4, loss: 4.145
epoch time : 40.146037101745605
Epoch 5, loss: 4.145
epoch time : 40.56942367553711
Epoch 6, loss: 4.145
epoch time : 40.25503921508789
Epoch 7, loss: 4.146
epoch time : 40.22589635848999
Epoch 8, loss: 4.145
epoch time : 40.182926416397095
Epoch 9, loss: 4.145
epoch time : 40.207844734191895
Epoch 10, loss: 4.146
epoch time : 40.42669987678528
{3: {0.1: [4.355504989624023, 9.168440818786621, 13.844656944274902, 18.371274948120117, 22.561168670654297, 27.206316471099854, 32.28139638900757, 36.602384090423584, 41.11565351486206, 45.466890811920166, 49.861536502838135, 54.013904094696045, 58.193625926971436, 62.34687280654907, 66.49356698989868, 70.70395803451538, 75.00603771209717, 79.39619493484497, 83.5490870475769, 87.7432451248169, 91.84557247161865, 95.9303092956543, 99.95423316955566, 104.2