In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
# Load the CIFAR-10 training split (fetched into ./data when missing); every
# image is converted to a tensor by `ToTensor`.
train_dataset = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled mini-batches of 32 samples for training.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data\cifar-10-python.tar.gz to data


In [None]:

# Select the compute device: the GPU when CUDA is available, the CPU otherwise.
# While a script runs, watch the GPU utilisation to make sure it is really used.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # to be used everywhere

# Tensors must then be given `device=device` at creation, or be moved with
# `.to(device)`, wherever they are created and used.

In [3]:
# Loss used to train the student with knowledge distillation
def distillation_loss(student_outputs, teacher_outputs, labels, temperature):
    """Return the soft-target (teacher -> student) distillation term.

    Both sets of logits are softened by dividing them by `temperature`; the
    result is KL(teacher || student) summed over classes, averaged over the
    batch and multiplied by `temperature ** 2`.

    Args:
        student_outputs: student logits; the softmax is taken over dim 1.
        teacher_outputs: teacher logits, same shape as `student_outputs`.
        labels: ground-truth labels. Unused here; kept so existing calls keep
            working. The hard-label cross-entropy is added by the caller.
        temperature: softening temperature (a positive number).

    Returns:
        A scalar tensor holding the scaled KL divergence.
    """
    soft_labels = nn.functional.softmax(teacher_outputs / temperature, dim=1)
    # The student's own logits must be used here (not a global `outputs`),
    # otherwise no gradient reaches the student.
    student_log_probs = nn.functional.log_softmax(student_outputs / temperature, dim=1)
    # `batchmean` yields the mathematical KL divergence per sample; the default
    # `mean` reduction would additionally divide by the number of classes.
    kl_divergence = nn.KLDivLoss(reduction="batchmean")(student_log_probs, soft_labels)
    # NOTE(review): the original author doubted that this matches the paper's
    # formula. Only the soft term is computed here and no weighting between the
    # soft and hard terms is applied - confirm against the reference article.
    return kl_divergence * (temperature ** 2)

In [8]:
# Teacher: start from pretrained ResNet-50 weights (`ResNet50_Weights.DEFAULT`)
# to try to save training time, then swap the final fully connected layer for a
# 10-output classification head.
criterion = nn.CrossEntropyLoss()

teacher_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_ftrs = teacher_model.fc.in_features  # input width of the original head
teacher_model.fc = nn.Linear(in_features=num_ftrs, out_features=10)

# SGD with momentum over every teacher parameter (backbone and new head).
optimizer = optim.SGD(
    params=teacher_model.parameters(),
    lr=0.1,
    momentum=0.9,
)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\prje/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [None]:
# Train the teacher with the cross-entropy `criterion` and the SGD `optimizer`.
loss_list = []          # loss of every mini-batch, stored as plain floats
running_loss_list = []  # cumulative loss within the current epoch, after every mini-batch

# Batches follow whichever device the model lives on, so the loop keeps working
# once the model has been moved to a GPU.
teacher_device = next(teacher_model.parameters()).device
teacher_model.train()

for epoch in range(200):  # 200 training epochs
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(teacher_device), labels.to(teacher_device)

        optimizer.zero_grad()

        outputs = teacher_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() detaches the value: keeping the tensor itself would retain
        # its autograd graph and make the history impossible to plot.
        batch_loss = loss.item()
        running_loss += batch_loss
        loss_list.append(batch_loss)
        running_loss_list.append(running_loss)

    # Mean per-batch loss over the epoch.
    print('Epoch %d, loss: %.3f' %
          (epoch + 1, running_loss / len(train_loader)))

print('Finished Training')

In [None]:
# Teacher loss history: `loss_list` holds one value per mini-batch, so the
# x axis counts iterations, not epochs. float() keeps the plot working even if
# the list holds 0-dim tensors instead of plain numbers.
plt.plot([float(value) for value in loss_list])
plt.xlabel("Iteration")
plt.ylabel("loss")
plt.title("Evolution de la loss")
plt.show()

In [None]:
# Teacher running loss: one point per mini-batch (the running sum restarts at
# every epoch), so the x axis counts iterations, not epochs.
plt.plot(running_loss_list)
plt.xlabel("Iteration")
plt.ylabel("loss")
plt.title("Evolution de la running_loss")
plt.show()

In [None]:
# Student model initialised from scratch (no pretrained weights).
# It must live under its own name: binding it to `teacher_model` would discard
# the trained teacher and leave `student_model` undefined for the distillation loop.
student_model = models.resnet18(weights=None)
num_ftrs = student_model.fc.in_features
student_model.fc = nn.Linear(num_ftrs, 10)  # 10-output classification head
criterion = nn.CrossEntropyLoss()
# From here on the optimizer updates the student's parameters only.
optimizer = optim.SGD(student_model.parameters(), lr=0.1, momentum=0.9)

# Alternative optimizer suggested by ChatGPT (not used; `lr` would have to be defined first)
#optimizer = optim.Adam(student_model.parameters(), lr=lr)

In [None]:
# Training hyper-parameters
temperature = 4   # softening temperature handed to distillation_loss
# NOTE(review): no active code in the visible notebook reads `lr`; the SGD
# optimizers hard-code 0.1 - confirm whether it should be wired in.
lr = 0.1
num_epochs = 50   # number of passes over the training set for the student

In [None]:
# Train the student model with knowledge distillation
student_loss = []               # total loss of every mini-batch, stored as plain floats
student_running_loss_list = []  # sample-weighted cumulative loss within the epoch, after every mini-batch

# Batches follow the student's device (assumes the teacher sits on the same one).
student_device = next(student_model.parameters()).device
# The teacher only provides targets: evaluation mode freezes its batch-norm
# statistics and disables dropout while the student learns.
teacher_model.eval()

for epoch in range(num_epochs):
    student_model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(student_device), labels.to(student_device)

        optimizer.zero_grad()
        # No autograd graph is needed for the teacher's forward pass.
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)
        student_outputs = student_model(inputs)
        # Soft-target distillation term plus hard-label cross-entropy.
        loss = distillation_loss(student_outputs, teacher_outputs, labels, temperature) + criterion(student_outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() detaches the value: keeping the tensor itself would retain
        # its autograd graph and make the history impossible to plot.
        batch_loss = loss.item()
        running_loss += batch_loss * inputs.size(0)
        student_loss.append(batch_loss)
        student_running_loss_list.append(running_loss)

    # Mean per-sample loss over the epoch.
    epoch_loss = running_loss / len(train_dataset)
    print('Epoch {}/{} loss: {:.4f}'.format(epoch+1, num_epochs, epoch_loss))

In [None]:
# Student loss history: `student_loss` holds one value per mini-batch, so the
# x axis counts iterations, not epochs. float() keeps the plot working even if
# the list holds 0-dim tensors instead of plain numbers.
plt.plot([float(value) for value in student_loss])
plt.xlabel("Iteration")
plt.ylabel("loss")
plt.title("Evolution de la loss de l'étudiant")
plt.show()

In [None]:
# Student running loss: one point per mini-batch (the running sum restarts at
# every epoch), so the x axis counts iterations, not epochs.
plt.plot(student_running_loss_list)
plt.xlabel("Iteration")
plt.ylabel("loss")
plt.title("Evolution de la running_loss")
plt.show()