# **Knowledge Distillation**

Francesco Marrocco

Knowldege Distillation (KD) è una tecnica utilizzata con lo scopo di rendere Deep Neural Networks più leggere e pratiche da addestrare senza perdere troppa precisione. L'idea di fondo è di permettere a un modello più piccolo di addestrarsi non solo sui dati del training set ma anche su un modello più grande già addestrato.$^1$

$1)$ In realtà esistono vari schemi di distillazione per rete neurale e non tutti richedono la presenza di un modello grande e uno piccolo per l'addestramento. In particolare si possono distinguere tre possibilità:


1.   *Offline Distillation* quella che segue la prima spiegazione fornita di KD, ovvero un modello profondo già addestrato e uno più leggero che si aggiorna per minimizzare la seguente loss
2.   *Online Distillation*
3.   *Self Distillation*



In [1]:
trained_models = True

### **Import dependencies and helper functions**

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import zipfile
import os
from google.colab import files
import glob

In [3]:
def del_all():
    # Get a list of all .pth files in the current directory
    pth_files = glob.glob('*.pth')

    # Loop through the list and delete each file
    for file in pth_files:
        try:
            os.remove(file)
            print(f'Deleted: {file}')
        except FileNotFoundError:
            print(f'File not found: {file}')

### **Teacher: Non-Neural Network**

In [4]:
torch.manual_seed(0)

<torch._C.Generator at 0x7849fc1a8530>

In [5]:
from sklearn import datasets

# 1. Train the Non-Neural Teacher Model
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=0)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

teacher_model1 = DecisionTreeClassifier()
teacher_model1.fit(X_train, y_train)

# Evaluate teacher model accuracy
teacher_accuracy = 100 * teacher_model1.score(X_test, y_test)
print(f'Teacher Model Accuracy: {teacher_accuracy:.2f}%')

Teacher Model Accuracy: 100.00%


In [6]:
# 2. Extract Teacher Predictions
teacher_logits_train = teacher_model1.predict_proba(X_train)
teacher_logits_test = teacher_model1.predict_proba(X_test)

# Convert to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
teacher_logits_train_tensor = torch.tensor(teacher_logits_train, dtype=torch.float32)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor, teacher_logits_train_tensor)
train_loader_iris = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [7]:
# 3. Define the Student Neural Network
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

student_model_no_kd_non_neural_0 = StudentNet()
student_model_with_kd_non_neural_0 = StudentNet()

In [8]:
# 4. Distillation Loss Function
def distillation_loss(student_logits, teacher_logits, targets, alpha=0.5, temperature=2.0):
    soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=1)
    soft_student_logits = nn.functional.log_softmax(student_logits / temperature, dim=1)
    distillation_loss = nn.functional.kl_div(soft_student_logits, soft_targets, reduction='batchmean') * (temperature ** 2)
    hard_loss = nn.functional.cross_entropy(student_logits, targets)
    return alpha * distillation_loss + (1 - alpha) * hard_loss

# Helper function to train the student model
def train_student_model1(model, use_kd=False, num_epochs=50):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets, teacher_logits in train_loader_iris:
            optimizer.zero_grad()
            student_logits = model(inputs)
            if use_kd:
                loss = distillation_loss(student_logits, teacher_logits, targets)
            else:
                loss = nn.functional.cross_entropy(student_logits, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader_iris.dataset)

In [9]:
def test1(student_model_no_kd, X_test, y_test):
    with torch.no_grad():
        X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
        student_logits_test_no_kd = student_model_no_kd(X_test_tensor)
        student_predictions_no_kd = torch.argmax(student_logits_test_no_kd, dim=1).numpy()
        student_accuracy_no_kd = 100 * (student_predictions_no_kd == y_test).mean()
    return student_accuracy_no_kd

In [10]:
# 5. Train the Student Model without KD
train_student_model1(student_model_no_kd_non_neural_0, use_kd=False)

student_accuracy_no_kd_0 = test1(student_model_no_kd_non_neural_0, X_test, y_test)

torch.save(student_model_no_kd_non_neural_0.state_dict(), 'student_model_no_kd_non_neural_0.pth')

print(f"Student Model Accuracy without KD: {student_accuracy_no_kd_0:.2f}%")

Student Model Accuracy without KD: 76.67%


In [11]:
# 6. Train the Student Model with KD
train_student_model1(student_model_with_kd_non_neural_0, use_kd=True, num_epochs=50)

student_accuracy_with_kd_0 = test1(student_model_with_kd_non_neural_0, X_test, y_test)

torch.save(student_model_with_kd_non_neural_0.state_dict(), 'student_model_with_kd_non_neural_0.pth')

print(f'Student Model Accuracy with KD: {student_accuracy_with_kd_0:.2f}%')

Student Model Accuracy with KD: 76.67%


In [12]:
student_no_kd = 0
student_with_kd = 0

num_iterations = 101

if not trained_models:

    for _ in range(num_iterations):
        torch.manual_seed(_)
        student_model_no_kd = StudentNet()
        student_model_with_kd = StudentNet()
        X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=_)
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        train_student_model1(student_model_no_kd, use_kd=False)
        student_accuracy_no_kd = test1(student_model_no_kd, X_test, y_test)
        student_no_kd += student_accuracy_no_kd
        torch.save(student_model_no_kd.state_dict(), f'student_model_no_kd_non_neural_{_}.pth')

        train_student_model1(student_model_with_kd, use_kd=True)
        student_accuracy_with_kd = test1(student_model_with_kd, X_test, y_test)
        student_with_kd += student_accuracy_with_kd
        torch.save(student_model_with_kd.state_dict(), f'student_model_with_kd_non_neural_{_}.pth')

        with zipfile.ZipFile('first_MLP_models.zip', 'a') as zipf:
            zipf.write(f'student_model_no_kd_non_neural_{_}.pth')
            zipf.write(f'student_model_with_kd_non_neural_{_}.pth')

    with zipfile.ZipFile('non_neural_models.zip', 'w') as zipf:
        for i in range(num_iterations):
            zipf.write(f'student_model_no_kd_non_neural_{i}.pth')
            zipf.write(f'student_model_with_kd_non_neural_{i}.pth')

    # Download the zip file
    files.download('non_neural_models.zip')

else:

    with zipfile.ZipFile('non_neural_models.zip', 'r') as zip_ref:
        zip_ref.extractall()
    for i in range(num_iterations):
        torch.manual_seed(i)
        X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=i)
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        student_model_no_kd = StudentNet()
        student_model_no_kd.load_state_dict(torch.load(f'student_model_no_kd_non_neural_{i}.pth'))
        acc1 = test1(student_model_no_kd, X_test, y_test)
        student_no_kd += acc1

        student_model_with_kd = StudentNet()
        student_model_with_kd.load_state_dict(torch.load(f'student_model_with_kd_non_neural_{i}.pth'))
        acc2 = test1(student_model_with_kd, X_test, y_test)
        student_with_kd += acc2

student_no_kd /= num_iterations
student_with_kd /= num_iterations
print(f'Student Model without KD Accuracy: {student_no_kd:.2f}%')
print(f'Student Model with KD Accuracy: {student_with_kd:.2f}%')

Student Model without KD Accuracy: 85.28%
Student Model with KD Accuracy: 86.50%


A questo punto posso eliminare le reti neurali salvate (o anche no se si vuole continuare a giocarci un po' :) )

In [13]:
# Delete the files
for i in range(num_iterations):
    try:
        os.remove(f'student_model_no_kd_non_neural_{i}.pth')
    except FileNotFoundError:
        pass

    try:
        os.remove(f'student_model_with_kd_non_neural_{i}.pth')
    except FileNotFoundError:
        pass

### **Teacher: Small MLP**

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [113]:
from torchvision import datasets, transforms

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the FashionMNIST dataset
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

In [16]:
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x

class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(28*28, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [17]:
def train(model, epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

    return model

In [18]:
def test_model(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100 * correct / len(test_loader.dataset)
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [19]:
if not trained_models:
    torch.manual_seed(42)
    teacher_model = TeacherModel().to(device)
    teacher_model = train(teacher_model, epochs = 10)
    teacher_accuracy = test_model(teacher_model, test_loader)
    torch.save(teacher_model.state_dict(), f'teacher_model_1_NN_42.pth')
    with zipfile.ZipFile('first_MLP_models.zip', 'a') as zipf:
        zipf.write(f'teacher_model_1_NN_42.pth')
else:
    teacher_model = TeacherModel().to(device)
    # Open the zip file
    with zipfile.ZipFile('first_MLP_models.zip', 'r') as zip_ref:
        # Extract the specified file
        zip_ref.extract('teacher_model_1_NN_42.pth')
    teacher_model.load_state_dict(torch.load(f'teacher_model_1_NN_42.pth'))
    teacher_accuracy = test_model(teacher_model, test_loader)

Test Accuracy: 87.28%


In [20]:
if not trained_models:
    torch.manual_seed(42)
    student_model = StudentModel().to(device)
    student_model = train(student_model, epochs = 10)
    student_accuracy = test_model(student_model, test_loader)
    torch.save(student_model.state_dict(), f'smaller_model_1_NN_42.pth')
    with zipfile.ZipFile('first_MLP_models.zip', 'a') as zipf:
        zipf.write(f'smaller_model_1_NN_42.pth')


else:
    student_model = StudentModel().to(device)
    # Open the zip file
    with zipfile.ZipFile('first_MLP_models.zip', 'r') as zip_ref:
        # Extract the specified file
        zip_ref.extract('smaller_model_1_NN_42.pth')
    student_model.load_state_dict(torch.load(f'smaller_model_1_NN_42.pth'))
    student_accuracy = test_model(student_model, test_loader)

Test Accuracy: 86.00%


In [21]:
def train_student_model(teacher, student, train_loader, epochs = 10, alpha=0.5, T=2.0):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=0.001)

    teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = nn.KLDivLoss()(soft_prob, soft_targets) * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = alpha * soft_targets_loss + (1 - alpha) * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}")

    return student

https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html

In [22]:
if not trained_models:
    torch.manual_seed(42)
    student_model_distillated = StudentModel().to(device)
    student_model_distillated = train_student_model(teacher_model, student_model_distillated, train_loader)
    student_distillated_accuracy = test_model(student_model_distillated, test_loader)
    torch.save(student_model_distillated.state_dict(), f'student_model_1_NN_42.pth')
    with zipfile.ZipFile('first_MLP_models.zip', 'a') as zipf:
        zipf.write(f'student_model_1_NN_42.pth')
else:
    student_model_distillated = StudentModel().to(device)
    # Open the zip file
    with zipfile.ZipFile('first_MLP_models.zip', 'r') as zip_ref:
        # Extract the specified file
        zip_ref.extract('student_model_1_NN_42.pth')
    student_model_distillated.load_state_dict(torch.load(f'student_model_1_NN_42.pth'))
    student_distillated_accuracy = test_model(student_model_distillated, test_loader)

Test Accuracy: 85.95%


In [23]:
if not trained_models: files.download('first_MLP_models.zip')

In [24]:
total_params_deep = "{:,}".format(sum(p.numel() for p in teacher_model.parameters()))
print(f"Teacher NN: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in student_model.parameters()))
print(f"Smaller NN: {total_params_light}")
total_params_dist = "{:,}".format(sum(p.numel() for p in student_model_distillated.parameters()))
print(f"Smaller NN distillated: {total_params_dist}")

Teacher NN: 575,050
Smaller NN: 25,450
Smaller NN distillated: 25,450


In [25]:
print(f"Teacher NN accuracy: {teacher_accuracy:.2f}%")
print(f"Smaller NN accuracy: {student_accuracy:.2f}%")
print(f"Smaller NN distillated accuracy: {student_distillated_accuracy:.2f}%")

Teacher NN accuracy: 87.28%
Smaller NN accuracy: 86.00%
Smaller NN distillated accuracy: 85.95%


In [26]:
#@title *Finding good $\alpha$ and  $T$*

# Initialize accumulators for total accuracy
total_acc = 0
total_acc_dist = 0
student_precision = []
# Define number of iterations
num_iterations = 5

# Experiment with different alpha and temperature values
alpha_values = [0.25, 0.5, 0.75]
temperature_values = [1.0, 2.0, 3.0]

if not trained_models:
    for alpha in alpha_values:
        for temperature in temperature_values:

            for i in range(num_iterations):
                # Reset the random seed for reproducibility
                torch.manual_seed(i)

                # Initialize and train the student model using distillation
                student_model_distillated = StudentModel().to(device)
                student_model_distillated = train_student_model(teacher_model, student_model_distillated, train_loader, epochs = 10, alpha=alpha, T=temperature)
                torch.save(student_model_distillated.state_dict(), f'model_{alpha}_{temperature}_{i}.pth')
                with zipfile.ZipFile('finding_a_T.zip', 'a') as zipf:
                    zipf.write(f'model_{alpha}_{temperature}_{i}.pth')
                # Test the student model and accumulate accuracy
                student_distillated_accuracy = test_model(student_model_distillated, test_loader)
                student_precision.append([student_distillated_accuracy, ("alpha",alpha), ("temperature", temperature)])
else:
    for alpha in alpha_values:
        for temperature in temperature_values:
            for i in range(num_iterations):
                torch.manual_seed(i)
                student_model_distillated = StudentModel().to(device)
                with zipfile.ZipFile('finding_a_T.zip', 'r') as zip_ref:
                    # Extract the specified file
                    zip_ref.extract(f'model_{alpha}_{temperature}_{i}.pth')
                student_model_distillated.load_state_dict(torch.load(f'model_{alpha}_{temperature}_{i}.pth'))
                student_distillated_accuracy = test_model(student_model_distillated, test_loader)
                student_precision.append([student_distillated_accuracy, ("alpha",alpha), ("temperature", temperature)])

if not trained_models: files.download('finding_a_T.zip')
print(student_precision)

Test Accuracy: 87.18%
Test Accuracy: 86.66%
Test Accuracy: 87.03%
Test Accuracy: 86.54%
Test Accuracy: 85.67%
Test Accuracy: 86.88%
Test Accuracy: 86.74%
Test Accuracy: 86.80%
Test Accuracy: 86.47%
Test Accuracy: 85.13%
Test Accuracy: 87.05%
Test Accuracy: 86.64%
Test Accuracy: 87.05%
Test Accuracy: 86.23%
Test Accuracy: 85.92%
Test Accuracy: 86.88%
Test Accuracy: 86.94%
Test Accuracy: 87.15%
Test Accuracy: 86.49%
Test Accuracy: 85.44%
Test Accuracy: 87.06%
Test Accuracy: 86.55%
Test Accuracy: 86.80%
Test Accuracy: 86.15%
Test Accuracy: 86.09%
Test Accuracy: 86.77%
Test Accuracy: 86.41%
Test Accuracy: 86.49%
Test Accuracy: 86.26%
Test Accuracy: 85.79%
Test Accuracy: 87.28%
Test Accuracy: 87.05%
Test Accuracy: 87.13%
Test Accuracy: 85.69%
Test Accuracy: 85.79%
Test Accuracy: 87.33%
Test Accuracy: 86.84%
Test Accuracy: 86.64%
Test Accuracy: 86.13%
Test Accuracy: 86.50%
Test Accuracy: 86.98%
Test Accuracy: 86.48%
Test Accuracy: 86.73%
Test Accuracy: 86.37%
Test Accuracy: 86.63%
[[87.18, (

In [27]:
# Initialize accumulators for total accuracy
total_acc = 0
# Define number of iterations
num_iterations = 5

if not trained_models:

    for i in range(num_iterations):
        # Reset the random seed for reproducibility
        torch.manual_seed(i)

        # Initialize and train the student model using distillation
        student_model = StudentModel().to(device)
        student_model = train(student_model, epochs = 10)
        torch.save(student_model.state_dict(), f'model_{0.0}_{1.0}_{i}.pth')
        with zipfile.ZipFile('finding_a_T.zip', 'a') as zipf:
            zipf.write(f'model_{0.0}_{1.0}_{i}.pth')
        # Test the student model and accumulate accuracy
        student_accuracy = test_model(student_model, test_loader)
else:

    for i in range(num_iterations):
        torch.manual_seed(i)
        student = StudentModel().to(device)
        with zipfile.ZipFile('finding_a_T.zip', 'r') as zip_ref:
            # Extract the specified file
            zip_ref.extract(f'model_{0.0}_{1.0}_{i}.pth')
        student.load_state_dict(torch.load(f'model_{0.0}_{1.0}_{i}.pth'))
        student_accuracy = test_model(student, test_loader)
        total_acc+=student_accuracy
    print(total_acc)

if not trained_models: files.download('finding_a_T.zip')

Test Accuracy: 87.06%
Test Accuracy: 86.79%
Test Accuracy: 86.78%
Test Accuracy: 86.43%
Test Accuracy: 85.34%
432.4


In [28]:
res = []
for i in range(len(student_precision)):
    if i%5==0:
        res.append(student_precision[i])
    else:
        res[i//5][0] += student_precision[i][0]

print(res)

[[433.08000000000004, ('alpha', 0.25), ('temperature', 1.0)], [432.02, ('alpha', 0.25), ('temperature', 2.0)], [432.89000000000004, ('alpha', 0.25), ('temperature', 3.0)], [432.90000000000003, ('alpha', 0.5), ('temperature', 1.0)], [432.6500000000001, ('alpha', 0.5), ('temperature', 2.0)], [431.72, ('alpha', 0.5), ('temperature', 3.0)], [432.94, ('alpha', 0.75), ('temperature', 1.0)], [433.44, ('alpha', 0.75), ('temperature', 2.0)], [433.19, ('alpha', 0.75), ('temperature', 3.0)]]


In [29]:
res = sorted(res, key=lambda x: x[0], reverse = True)
best_accuracy = f'{res[0][0]/5:.2f}%'
best_alpha = res[0][1][1]
best_temperature = res[0][2][1]
print(f'Best accuracy: {best_accuracy} for alpha = {best_alpha} and T = {best_temperature}')
print(f'Same dimension NN on same dataset accuracy = {total_acc/5:.2f}%')

Best accuracy: 86.69% for alpha = 0.75 and T = 2.0
Same dimension NN on same dataset accuracy = 86.48%


In [30]:
del_all()

Deleted: model_0.25_1.0_3.pth
Deleted: model_0.25_2.0_4.pth
Deleted: model_0.5_1.0_0.pth
Deleted: model_0.75_3.0_0.pth
Deleted: model_0.25_1.0_4.pth
Deleted: model_0.0_1.0_0.pth
Deleted: model_0.75_1.0_0.pth
Deleted: model_0.25_3.0_2.pth
Deleted: model_0.25_3.0_1.pth
Deleted: model_0.5_2.0_4.pth
Deleted: model_0.0_1.0_3.pth
Deleted: model_0.75_3.0_2.pth
Deleted: model_0.5_3.0_3.pth
Deleted: model_0.0_1.0_4.pth
Deleted: teacher_model_1_NN_42.pth
Deleted: model_0.75_3.0_3.pth
Deleted: model_0.75_2.0_0.pth
Deleted: student_model_1_NN_42.pth
Deleted: model_0.5_3.0_0.pth
Deleted: model_0.75_1.0_4.pth
Deleted: model_0.5_2.0_3.pth
Deleted: model_0.25_2.0_3.pth
Deleted: model_0.25_1.0_0.pth
Deleted: model_0.25_1.0_1.pth
Deleted: model_0.25_1.0_2.pth
Deleted: model_0.5_3.0_4.pth
Deleted: model_0.0_1.0_2.pth
Deleted: model_0.5_1.0_2.pth
Deleted: model_0.75_2.0_2.pth
Deleted: model_0.5_3.0_2.pth
Deleted: model_0.5_2.0_0.pth
Deleted: model_0.25_2.0_1.pth
Deleted: model_0.5_1.0_3.pth
Deleted: model

Quindi ho insegnanti molto precisi e informativi tanto che ...
Ma il miglioramento sembra minimo allora sperimento nuove tecniche per cercare di alzare un po' l'accuracy

### ***Multi teacher KD***

In [31]:
teachers = []
num_teachers = 20
if not trained_models:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = TeacherModel().to(device)
        teacher = train(teacher, epochs = 10)
        teachers.append(teachers)
        torch.save(teacher.state_dict(), f'teacher_{i}.pth')
        with zipfile.ZipFile(f'MTKD.zip', 'a') as zipf:
            zipf.write(f'teacher_{i}.pth')
    files.download(f'MTKD.zip')
else:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = TeacherModel().to(device)
        with zipfile.ZipFile('MTKD.zip', 'r') as zip_ref:
            # Extract the specified file
            zip_ref.extract(f'teacher_{i}.pth')
        teacher.load_state_dict(torch.load(f'teacher_{i}.pth'))
        teacher_accuracy = test_model(teacher, test_loader)
        teachers.append(teacher)

Test Accuracy: 88.27%
Test Accuracy: 88.23%
Test Accuracy: 87.95%
Test Accuracy: 88.32%
Test Accuracy: 88.13%
Test Accuracy: 88.42%
Test Accuracy: 88.61%
Test Accuracy: 87.62%
Test Accuracy: 88.47%
Test Accuracy: 88.47%
Test Accuracy: 88.50%
Test Accuracy: 88.70%
Test Accuracy: 88.30%
Test Accuracy: 88.48%
Test Accuracy: 88.28%
Test Accuracy: 87.50%
Test Accuracy: 88.45%
Test Accuracy: 88.58%
Test Accuracy: 88.09%
Test Accuracy: 87.77%


In [32]:
del_all()

Deleted: teacher_7.pth
Deleted: teacher_19.pth
Deleted: teacher_18.pth
Deleted: teacher_0.pth
Deleted: teacher_11.pth
Deleted: teacher_14.pth
Deleted: teacher_12.pth
Deleted: teacher_4.pth
Deleted: teacher_17.pth
Deleted: teacher_3.pth
Deleted: teacher_16.pth
Deleted: teacher_6.pth
Deleted: teacher_1.pth
Deleted: teacher_8.pth
Deleted: teacher_10.pth
Deleted: teacher_15.pth
Deleted: teacher_5.pth
Deleted: teacher_9.pth
Deleted: teacher_13.pth
Deleted: teacher_2.pth


In [77]:
def train_student_model_multi_teacher(teachers, student, loader = train_loader, epochs = 10, alpha=0.5, T=2.0, device='cpu'):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=0.001)

    for teacher in teachers:
        teacher.eval()  # Set all teacher models to evaluation mode
    student.train()  # Student in training mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher models - do not save gradients
            teacher_logits = []
            with torch.no_grad():
                for teacher in teachers:
                    teacher_logits.append(teacher(inputs))

            # Average the logits from the teacher models
            avg_teacher_logits = sum(teacher_logits) / len(teacher_logits)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher and student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(avg_teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = nn.KLDivLoss()(soft_prob, soft_targets) * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = alpha * soft_targets_loss + (1 - alpha) * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(loader):.4f}")

    return student


AVG MTKD $\alpha=0.5$

In [34]:
num_teachers = [2, 5, 10, 20]

if not trained_models:
    for num in num_teachers:
        for i in range(num_iterations):
            # Reset the random seed for reproducibility
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            student_model_distillated = train_student_model_multi_teacher(teachers[:num], student_model_distillated, epochs = 10, alpha = 0.5, T = 2.0)
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)
            torch.save(student_model_distillated.state_dict(), f'MTKD_student_{num}_{i}.pth')
            with zipfile.ZipFile('MTKD_students.zip', 'a') as zipf:
                zipf.write(f'MTKD_student_{num}_{i}.pth')

else:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            with zipfile.ZipFile('MTKD_students.zip', 'r') as zip_ref:
                # Extract the specified file
                zip_ref.extract(f'MTKD_student_{num}_{i}.pth')
            student_model_distillated.load_state_dict(torch.load(f'MTKD_student_{num}_{i}.pth'))
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)

if not trained_models: files.download('MTKD_students.zip')

Test Accuracy: 87.12%
Test Accuracy: 86.85%
Test Accuracy: 87.22%
Test Accuracy: 86.05%
Test Accuracy: 85.72%
Test Accuracy: 86.96%
Test Accuracy: 87.02%
Test Accuracy: 87.05%
Test Accuracy: 86.16%
Test Accuracy: 85.55%
Test Accuracy: 87.21%
Test Accuracy: 86.46%
Test Accuracy: 87.27%
Test Accuracy: 86.79%
Test Accuracy: 85.73%
Test Accuracy: 86.98%
Test Accuracy: 86.76%
Test Accuracy: 87.08%
Test Accuracy: 86.07%
Test Accuracy: 85.98%


AVG MTKD $\alpha=0.75$

In [35]:
num_teachers = [2, 5, 10, 20]

if not trained_models:
    for num in num_teachers:
        for i in range(num_iterations):
            # Reset the random seed for reproducibility
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            student_model_distillated = train_student_model_multi_teacher(teachers[:num], student_model_distillated, epochs = 10, alpha = 0.75, T = 2.0)
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)
            torch.save(student_model_distillated.state_dict(), f'MTKD_student075_{num}_{i}.pth')
            with zipfile.ZipFile('MTKD_students_075.zip', 'a') as zipf:
                zipf.write(f'MTKD_student075_{num}_{i}.pth')

else:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            with zipfile.ZipFile('MTKD_students_075.zip', 'r') as zip_ref:
                # Extract the specified file
                zip_ref.extract(f'MTKD_student075_{num}_{i}.pth')
            student_model_distillated.load_state_dict(torch.load(f'MTKD_student075_{num}_{i}.pth'))
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)

if not trained_models: files.download('MTKD_students_075.zip')

Test Accuracy: 86.79%
Test Accuracy: 86.78%
Test Accuracy: 86.51%
Test Accuracy: 86.79%
Test Accuracy: 86.00%
Test Accuracy: 86.93%
Test Accuracy: 86.71%
Test Accuracy: 86.72%
Test Accuracy: 86.48%
Test Accuracy: 86.22%
Test Accuracy: 86.98%
Test Accuracy: 86.70%
Test Accuracy: 87.00%
Test Accuracy: 86.47%
Test Accuracy: 86.56%
Test Accuracy: 87.09%
Test Accuracy: 86.86%
Test Accuracy: 87.05%
Test Accuracy: 86.86%
Test Accuracy: 86.27%


In [36]:
#Multi Teacher KD versione 2

In [37]:
def train_student(teachers, student, temperature, alpha, epochs):
    dataloader = train_loader
    optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_kld = nn.KLDivLoss(reduction='batchmean')
    student.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            # Get teacher logits and compute KLD loss
            kld_loss = 0
            for teacher in teachers:
                teacher.eval()
                with torch.no_grad():
                    teacher_logits = teacher(images) / temperature
                    teacher_softmax = F.softmax(teacher_logits, dim=1)
                    student_logits = student(images) / temperature
                    student_log_softmax = F.log_softmax(student_logits, dim=1)
                    kld_loss += criterion_kld(student_log_softmax, teacher_softmax)
            kld_loss /= len(teachers)

            # Compute cross-entropy loss with true labels
            student_logits = student(images)
            ce_loss = criterion_ce(student_logits, labels)

            # Total loss
            loss = alpha * kld_loss * (temperature ** 2) + (1 - alpha) * ce_loss

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

    return student

In [38]:
#@title ###*2$^{nd}$ version MTKD,  $\alpha = 0.5$*

num_teachers = [2, 5, 10, 20]

if not trained_models:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            student_model_distillated = train_student(teachers[:num], student, 2.0, 0.5, 10)
            test_model(student, test_loader)
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)
            torch.save(student_model_distillated.state_dict(), f'MTKD_2_student05_{num}_{i}.pth')
            with zipfile.ZipFile('MTKD_2_students_05.zip', 'a') as zipf:
                zipf.write(f'MTKD_2_student05_{num}_{i}.pth')

else:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            with zipfile.ZipFile('MTKD_2_students_05.zip', 'r') as zip_ref:
                # Extract the specified file
                zip_ref.extract(f'MTKD_2_student05_{num}_{i}.pth')
            student_model_distillated.load_state_dict(torch.load(f'MTKD_2_student05_{num}_{i}.pth'))
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)

if not trained_models: files.download('MTKD_2_students_05.zip')

Test Accuracy: 86.90%
Test Accuracy: 87.48%
Test Accuracy: 87.13%
Test Accuracy: 87.52%
Test Accuracy: 87.15%
Test Accuracy: 87.31%
Test Accuracy: 87.43%
Test Accuracy: 86.71%
Test Accuracy: 86.44%
Test Accuracy: 86.89%
Test Accuracy: 86.75%
Test Accuracy: 86.48%
Test Accuracy: 86.49%
Test Accuracy: 86.28%
Test Accuracy: 86.44%
Test Accuracy: 86.43%
Test Accuracy: 86.05%
Test Accuracy: 86.17%
Test Accuracy: 85.78%
Test Accuracy: 86.04%


In [39]:
#@title ###*2$^{nd}$ version MTKD,  $\alpha = 0.75$*

num_teachers = [2, 5, 10, 20]

if not trained_models:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            student_model_distillated = train_student(teachers[:num], student, 2.0, 0.75, 10)
            test_model(student, test_loader)
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)
            torch.save(student_model_distillated.state_dict(), f'MTKD_2_student075_{num}_{i}.pth')
            with zipfile.ZipFile('MTKD_2_students_075.zip', 'a') as zipf:
                zipf.write(f'MTKD_2_student075_{num}_{i}.pth')

else:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            with zipfile.ZipFile('MTKD_2_students_075.zip', 'r') as zip_ref:
                # Extract the specified file
                zip_ref.extract(f'MTKD_2_student075_{num}_{i}.pth')
            student_model_distillated.load_state_dict(torch.load(f'MTKD_2_student075_{num}_{i}.pth'))
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)

if not trained_models: files.download('MTKD_2_students_075.zip')

Test Accuracy: 86.12%
Test Accuracy: 85.69%
Test Accuracy: 85.81%
Test Accuracy: 85.50%
Test Accuracy: 85.74%
Test Accuracy: 86.06%
Test Accuracy: 85.67%
Test Accuracy: 85.49%
Test Accuracy: 85.49%
Test Accuracy: 85.64%
Test Accuracy: 85.75%
Test Accuracy: 85.72%
Test Accuracy: 85.50%
Test Accuracy: 85.40%
Test Accuracy: 85.60%
Test Accuracy: 85.60%
Test Accuracy: 85.61%
Test Accuracy: 85.25%
Test Accuracy: 85.36%
Test Accuracy: 85.42%


In [40]:
del_all()

Deleted: MTKD_2_student05_2_1.pth
Deleted: MTKD_2_student05_20_4.pth
Deleted: MTKD_2_student05_5_2.pth
Deleted: MTKD_2_student075_2_1.pth
Deleted: MTKD_2_student075_2_3.pth
Deleted: MTKD_2_student075_2_2.pth
Deleted: MTKD_2_student075_10_0.pth
Deleted: MTKD_student_5_3.pth
Deleted: MTKD_student_5_1.pth
Deleted: MTKD_student075_5_3.pth
Deleted: MTKD_2_student075_20_4.pth
Deleted: MTKD_student_20_2.pth
Deleted: MTKD_student_20_0.pth
Deleted: MTKD_2_student075_5_3.pth
Deleted: MTKD_student075_2_3.pth
Deleted: MTKD_student075_20_1.pth
Deleted: MTKD_student075_5_4.pth
Deleted: MTKD_student_5_2.pth
Deleted: MTKD_student_2_3.pth
Deleted: MTKD_student075_20_3.pth
Deleted: MTKD_2_student075_2_4.pth
Deleted: MTKD_student_10_2.pth
Deleted: MTKD_2_student075_5_2.pth
Deleted: MTKD_student075_10_2.pth
Deleted: MTKD_student075_2_0.pth
Deleted: MTKD_2_student05_2_2.pth
Deleted: MTKD_2_student05_20_2.pth
Deleted: MTKD_2_student05_5_4.pth
Deleted: MTKD_2_student075_10_2.pth
Deleted: MTKD_2_student075_20

### ***Self distillation***

In [41]:
class MLPWithSelfDistillation(nn.Module):
    def __init__(self):
        super(MLPWithSelfDistillation, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 10)  # Corrected from self.fc4 to self.fc5 to avoid overwriting

        # Classifiers at different depths
        self.classifier1 = nn.Linear(512, 10)
        self.classifier2 = nn.Linear(256, 10)
        self.classifier3 = nn.Linear(128, 10)
        self.classifier4 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x1 = F.relu(self.fc1(x))
        x2 = F.relu(self.fc2(x1))
        x3 = F.relu(self.fc3(x2))
        x4 = F.relu(self.fc4(x3))
        out = self.fc5(x4)  # Corrected from self.fc4(x3) to self.fc5(x4)

        c1 = self.classifier1(x1)
        c2 = self.classifier2(x2)
        c3 = self.classifier3(x3)
        c4 = self.classifier4(x4)

        return out, c1, c2, c3, c4

# Define the loss function for knowledge distillation
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.5, T=2):
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1),
                             F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)
    return KD_loss

# Training function for a single model
def train_self(model, alpha=0.5, epochs=10):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(1, epochs + 1):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            output, c1, c2, c3, c4 = model(data)

            # Calculate loss
            loss1 = F.cross_entropy(c1, target)
            loss2 = F.cross_entropy(c2, target)
            loss3 = F.cross_entropy(c3, target)
            loss4 = F.cross_entropy(c4, target)
            loss_final = F.cross_entropy(output, target)

            # Self-distillation losses
            kl_loss1 = loss_fn_kd(c1, target, output, alpha)
            kl_loss2 = loss_fn_kd(c2, target, output, alpha)
            kl_loss3 = loss_fn_kd(c3, target, output, alpha)
            kl_loss4 = loss_fn_kd(c4, target, output, alpha)

            # Total loss
            loss = loss1 + loss2 + loss3 + loss4 + loss_final + kl_loss1 + kl_loss2 + kl_loss3 + kl_loss4

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}/{epochs}, Loss: {loss / len(train_loader):.4f}")


    return model

In [42]:
def testSTKD(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, _, _, _, _ = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

In [43]:
teachers_STKD_05 = []
num_teachers = 20
if not trained_models:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = MLPWithSelfDistillation().to(device)
        teacher = train_self(teacher)
        teacher_acc = testSTKD(teacher, test_loader)
        teachers_STKD_05.append(teacher)
        torch.save(teacher.state_dict(), f'teacher_STKD_{i}.pth')
        with zipfile.ZipFile(f'STKD_05.zip', 'a') as zipf:
            zipf.write(f'teacher_STKD_{i}.pth')
    files.download(f'STKD_05.zip')
else:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = MLPWithSelfDistillation().to(device)
        with zipfile.ZipFile('STKD_05.zip', 'r') as zip_ref:
            # Extract the specified file
            zip_ref.extract(f'teacher_STKD_{i}.pth')
        teacher.load_state_dict(torch.load(f'teacher_STKD_{i}.pth'))
        teacher_accuracy = testSTKD(teacher, test_loader)
        teachers_STKD_05.append(teacher)

Accuracy: 88.68%
Accuracy: 88.45%
Accuracy: 88.64%
Accuracy: 88.44%
Accuracy: 88.61%
Accuracy: 87.99%
Accuracy: 88.22%
Accuracy: 88.23%
Accuracy: 88.38%
Accuracy: 88.12%
Accuracy: 88.96%
Accuracy: 88.53%
Accuracy: 88.77%
Accuracy: 87.76%
Accuracy: 88.29%
Accuracy: 88.51%
Accuracy: 88.21%
Accuracy: 88.41%
Accuracy: 88.13%
Accuracy: 88.30%


In [44]:
teachers_STKD_025 = []
num_teachers = 20
if not trained_models:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = MLPWithSelfDistillation().to(device)
        teacher = train_self(teacher)
        teacher_acc = testSTKD(teacher, test_loader)
        teachers_STKD_025.append(teacher)
        torch.save(teacher.state_dict(), f'teacher_STKD_025_{i}.pth')
        with zipfile.ZipFile(f'STKD_025.zip', 'a') as zipf:
            zipf.write(f'teacher_STKD_025_{i}.pth')
    files.download(f'STKD_025.zip')
else:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = MLPWithSelfDistillation().to(device)
        with zipfile.ZipFile('STKD_025.zip', 'r') as zip_ref:
            # Extract the specified file
            zip_ref.extract(f'teacher_STKD_025_{i}.pth')
        teacher.load_state_dict(torch.load(f'teacher_STKD_025_{i}.pth'))
        teacher_accuracy = testSTKD(teacher, test_loader)
        teachers_STKD_025.append(teacher)

Accuracy: 88.68%
Accuracy: 88.45%
Accuracy: 88.64%
Accuracy: 88.44%
Accuracy: 88.61%
Accuracy: 87.99%
Accuracy: 88.22%
Accuracy: 88.23%
Accuracy: 88.38%
Accuracy: 88.12%
Accuracy: 88.96%
Accuracy: 88.53%
Accuracy: 88.77%
Accuracy: 87.76%
Accuracy: 88.29%
Accuracy: 88.51%
Accuracy: 88.21%
Accuracy: 88.41%
Accuracy: 88.13%
Accuracy: 88.30%


In [45]:
del_all()

Deleted: teacher_STKD_025_0.pth
Deleted: teacher_STKD_4.pth
Deleted: teacher_STKD_025_9.pth
Deleted: teacher_STKD_7.pth
Deleted: teacher_STKD_025_17.pth
Deleted: teacher_STKD_025_11.pth
Deleted: teacher_STKD_025_4.pth
Deleted: teacher_STKD_025_12.pth
Deleted: teacher_STKD_15.pth
Deleted: teacher_STKD_18.pth
Deleted: teacher_STKD_025_6.pth
Deleted: teacher_STKD_0.pth
Deleted: teacher_STKD_6.pth
Deleted: teacher_STKD_5.pth
Deleted: teacher_STKD_025_7.pth
Deleted: teacher_STKD_025_18.pth
Deleted: teacher_STKD_025_15.pth
Deleted: teacher_STKD_16.pth
Deleted: teacher_STKD_10.pth
Deleted: teacher_STKD_17.pth
Deleted: teacher_STKD_025_3.pth
Deleted: teacher_STKD_025_14.pth
Deleted: teacher_STKD_11.pth
Deleted: teacher_STKD_8.pth
Deleted: teacher_STKD_12.pth
Deleted: teacher_STKD_025_13.pth
Deleted: teacher_STKD_13.pth
Deleted: teacher_STKD_19.pth
Deleted: teacher_STKD_025_1.pth
Deleted: teacher_STKD_025_8.pth
Deleted: teacher_STKD_025_16.pth
Deleted: teacher_STKD_025_19.pth
Deleted: teacher_S

In [46]:
teachers_STKD_075 = []
num_teachers = 20
if not trained_models:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = MLPWithSelfDistillation().to(device)
        teacher = train_self(teacher, alpha = 0.75)
        teachers_STKD_075.append(teachers)
        torch.save(teacher.state_dict(), f'teacher_{i}.pth')
        with zipfile.ZipFile(f'STKD_075.zip', 'a') as zipf:
            zipf.write(f'teacher_{i}.pth')
    files.download(f'STKD_075.zip')
else:
    for i in range(num_teachers):
        torch.manual_seed(i)
        teacher = MLPWithSelfDistillation().to(device)
        with zipfile.ZipFile('STKD_075.zip', 'r') as zip_ref:
            # Extract the specified file
            zip_ref.extract(f'teacher_{i}.pth')
        teacher.load_state_dict(torch.load(f'teacher_{i}.pth'))
        teacher_accuracy = testSTKD(teacher, test_loader)
        teachers_STKD_075.append(teacher)

Accuracy: 88.17%
Accuracy: 89.12%
Accuracy: 88.40%
Accuracy: 88.84%
Accuracy: 88.42%
Accuracy: 88.04%
Accuracy: 88.64%
Accuracy: 88.30%
Accuracy: 88.43%
Accuracy: 88.75%
Accuracy: 88.87%
Accuracy: 88.54%
Accuracy: 88.66%
Accuracy: 88.97%
Accuracy: 87.88%
Accuracy: 88.34%
Accuracy: 88.35%
Accuracy: 88.60%
Accuracy: 88.45%
Accuracy: 88.89%


In [47]:
del_all()

Deleted: teacher_7.pth
Deleted: teacher_19.pth
Deleted: teacher_18.pth
Deleted: teacher_0.pth
Deleted: teacher_11.pth
Deleted: teacher_14.pth
Deleted: teacher_12.pth
Deleted: teacher_4.pth
Deleted: teacher_17.pth
Deleted: teacher_3.pth
Deleted: teacher_16.pth
Deleted: teacher_6.pth
Deleted: teacher_1.pth
Deleted: teacher_8.pth
Deleted: teacher_10.pth
Deleted: teacher_15.pth
Deleted: teacher_5.pth
Deleted: teacher_9.pth
Deleted: teacher_13.pth
Deleted: teacher_2.pth


### ***Unisco i concetti***

In [48]:
def train_student_model_multi_teacher(teachers, student, epochs = 10, alpha=0.5, T=2.0, device='cpu'):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=0.001)

    for teacher in teachers:
        teacher.eval()  # Set all teacher models to evaluation mode
    student.train()  # Student in training mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher models - do not save gradients
            teacher_logits = []
            with torch.no_grad():
                for teacher in teachers:
                    t_logits,_,_,_,_ = teacher(inputs)
                    teacher_logits.append(t_logits)

            # Average the logits from the teacher models
            avg_teacher_logits = sum(teacher_logits) / len(teacher_logits)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher and student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(avg_teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = nn.KLDivLoss()(soft_prob, soft_targets) * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = alpha * soft_targets_loss + (1 - alpha) * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")

    return student


In [49]:
trained_models = True

num_teachers = [2,5,10,20]

num_iterations = 5

if not trained_models:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            student_model_distillated = train_student_model_multi_teacher(teachers[:num], student_model_distillated)
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)
            torch.save(student_model_distillated.state_dict(), f'Final_2_student05_{num}_{i}.pth')
            with zipfile.ZipFile('Final_2_students_05.zip', 'a') as zipf:
                zipf.write(f'Final_2_student05_{num}_{i}.pth')
    files.download('Final_2_students_05.zip')
else:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            with zipfile.ZipFile('Final_2_students_05.zip', 'r') as zip_ref:
                # Extract the specified file
                zip_ref.extract(f'Final_2_student05_{num}_{i}.pth')
            student_model_distillated.load_state_dict(torch.load(f'Final_2_student05_{num}_{i}.pth'))
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)



Test Accuracy: 87.07%
Test Accuracy: 86.82%
Test Accuracy: 87.17%
Test Accuracy: 86.01%
Test Accuracy: 85.74%
Test Accuracy: 87.09%
Test Accuracy: 86.86%
Test Accuracy: 86.91%
Test Accuracy: 86.09%
Test Accuracy: 85.51%
Test Accuracy: 87.25%
Test Accuracy: 86.76%
Test Accuracy: 86.94%
Test Accuracy: 86.33%
Test Accuracy: 85.75%
Test Accuracy: 87.05%
Test Accuracy: 86.90%
Test Accuracy: 86.80%
Test Accuracy: 86.21%
Test Accuracy: 85.42%


In [50]:
del_all()

Deleted: Final_2_student05_5_0.pth
Deleted: Final_2_student05_10_2.pth
Deleted: Final_2_student05_2_4.pth
Deleted: Final_2_student05_20_0.pth
Deleted: Final_2_student05_10_0.pth
Deleted: Final_2_student05_2_1.pth
Deleted: Final_2_student05_5_4.pth
Deleted: Final_2_student05_10_1.pth
Deleted: Final_2_student05_20_2.pth
Deleted: Final_2_student05_10_3.pth
Deleted: Final_2_student05_5_3.pth
Deleted: Final_2_student05_5_1.pth
Deleted: Final_2_student05_20_4.pth
Deleted: Final_2_student05_20_1.pth
Deleted: Final_2_student05_2_0.pth
Deleted: Final_2_student05_20_3.pth
Deleted: Final_2_student05_2_3.pth
Deleted: Final_2_student05_5_2.pth
Deleted: Final_2_student05_2_2.pth
Deleted: Final_2_student05_10_4.pth


In [51]:
#MTKD versione 2

def train_student_MTKD_2(teachers, student, temperature = 2.0, alpha = 0.5, epochs = 10):
    dataloader = train_loader
    optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_kld = nn.KLDivLoss(reduction='batchmean')
    student.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            # Get teacher logits and compute KLD loss
            kld_loss = 0
            for teacher in teachers:
                teacher.eval()
                with torch.no_grad():
                    T_L, _, _, _, _ = teacher(images)
                    teacher_logits = T_L / temperature
                    teacher_softmax = F.softmax(teacher_logits, dim=1)
                    student_logits = student(images) / temperature
                    student_log_softmax = F.log_softmax(student_logits, dim=1)
                    kld_loss += criterion_kld(student_log_softmax, teacher_softmax)
            kld_loss /= len(teachers)

            # Compute cross-entropy loss with true labels
            student_logits = student(images)
            ce_loss = criterion_ce(student_logits, labels)

            # Total loss
            loss = alpha * kld_loss * (temperature ** 2) + (1 - alpha) * ce_loss

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

    return student

In [52]:
num_teachers = [10, 20]

trained_models = False

num_iterations = 5

if not trained_models:
    for num in num_teachers:
        if num == 10: j = 2
        else: j = 0
        for i in range(j, num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            student_model_distillated = train_student_MTKD_2(teachers_STKD_05[:num], student_model_distillated)
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)
            torch.save(student_model_distillated.state_dict(), f'Final_3_student05_{num}_{i}.pth')
            with zipfile.ZipFile('Final_3_students_05.zip', 'a') as zipf:
                zipf.write(f'Final_3_student05_{num}_{i}.pth')
    files.download('Final_3_students_05.zip')
else:
    for num in num_teachers:
        for i in range(num_iterations):
            torch.manual_seed(i)
            student_model_distillated = StudentModel().to(device)
            with zipfile.ZipFile('Final_3_students_05.zip', 'r') as zip_ref:
                # Extract the specified file
                zip_ref.extract(f'Final_3_student05_{num}_{i}.pth')
            student_model_distillated.load_state_dict(torch.load(f'Final_3_student05_{num}_{i}.pth'))
            student_distillated_accuracy = test_model(student_model_distillated, test_loader)


Epoch 1/10, Loss: 1.0359


KeyboardInterrupt: 

In [None]:
'''
Test Accuracy: 86.90%
Test Accuracy: 87.48%
Test Accuracy: 87.13%
Test Accuracy: 87.52%
Test Accuracy: 87.15%
Test Accuracy: 87.31%
Test Accuracy: 87.43%
Test Accuracy: 86.71%
Test Accuracy: 86.44%
Test Accuracy: 86.89%
Test Accuracy: 86.75%
Test Accuracy: 86.48%
Test Accuracy: 86.49%
Test Accuracy: 86.28%
Test Accuracy: 86.44%
Test Accuracy: 86.43%
Test Accuracy: 86.05%
Test Accuracy: 86.17%
Test Accuracy: 85.78%
Test Accuracy: 86.04%

forse con troppi insegnanti non regolarizzo bene e overfitto un po'

'''

In [53]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader, random_split

# Define the Teacher Model
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x

# Function to train the teacher models
def train_teacher(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
'''
# Load the FashionMNIST dataset
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform)

# Number of teachers
num_teachers = 20

# Function to create random splits
def create_random_splits(dataset, num_teachers):
    split_size = len(dataset) // 5  # Half of the dataset for each teacher
    subsets = []

    for _ in range(num_teachers):
        subset1, _ = random_split(dataset, [split_size, len(dataset) - split_size])
        subsets.append(subset1)

    return subsets

# Create random half splits
subsets = create_random_splits(train_dataset, num_teachers)

# Initialize and train each teacher on their respective random half subset
teacher_models = []
for i in range(num_teachers):
    teacher_model = TeacherModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)

    train_loader = DataLoader(subsets[i], batch_size=64, shuffle=True)

    print(f"Training Teacher {i+1}/{num_teachers}")
    train_teacher(teacher_model, train_loader, criterion, optimizer, epochs=10)
    test_model(teacher_model, test_loader)
    teacher_models.append(teacher_model)

print("Training of all teacher models is complete.")
'''


Training Teacher 1/20
Epoch 1, Loss: 0.8648814990165385


KeyboardInterrupt: 

qui sotto posso riscrivere il codice che dava il 40 di accuracy con teacher MLP allenati su parti diverse del dataset

In [54]:
import torch
import torch.nn.functional as F


def combine_teacher_predictions(teacher_models, test_loader, temperature=2.0):
    """
    Combine predictions from multiple teacher models on the entire test set and calculate accuracy.

    Args:
        teacher_models (list of nn.Module): List of teacher models.
        test_loader (DataLoader): DataLoader for the test set.
        method (str): Method to combine predictions ('avg', 'max', 'vote').
        temperature (float): Temperature for scaling logits before combining predictions.

    Returns:
        accuracy (float): Accuracy of the combined predictions on the test set.
    """
    correct_predictions = 0
    total_samples = 0

    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Collect logits from all teacher models
        logits_list = []
        for teacher in teacher_models:
            teacher.eval()
            with torch.no_grad():
                #logits, _, _, _, _ = teacher(inputs)
                logits = teacher(inputs)

                # Apply temperature scaling
                logits = logits / temperature
                logits_list.append(logits)

        # Stack logits from all teachers: shape [num_teachers, batch_size, num_classes]
        stacked_logits = torch.stack(logits_list)

        if method == 'avg':
            # Average the logits across all teachers
            combined_logits = torch.mean(stacked_logits, dim=0)

            raise ValueError("Method must be 'avg', 'max', or 'vote'")

        # Final prediction based on the combined logits
        final_predictions = torch.argmax(combined_logits, dim=1)

        # Update correct predictions count and total samples
        correct_predictions += (final_predictions == labels).sum().item()
        total_samples += labels.size(0)

    # Calculate accuracy
    accuracy = correct_predictions / total_samples
    return accuracy
'''
num_teachers = [2,5,10,20]
for t in num_teachers:
    print(f"Test Accuracy: {combine_teacher_predictions(teachers_STKD_05[:t], test_loader, method='avg')}")
    print(f"Test Accuracy: {combine_teacher_predictions(teachers_STKD_025[:t], test_loader, method='avg')}")
    print(f"Test Accuracy: {combine_teacher_predictions(teachers_STKD_075[:t], test_loader, method='avg')}")
'''
# Example usage:
# Assuming you have 20 trained teacher models in a list called `teacher_models`
# and some input data `input_data` (e.g., a batch of images)

# teacher_models = [teacher1, teacher2, ..., teacher20]  # Your 20 teacher models
# input_data = ...  # Your input data

# Combine predictions and get final predictions
# combined_logits, final_predictions = combine_teacher_predictions(teacher_models, input_data, method='avg')


'\nnum_teachers = [2,5,10,20]\nfor t in num_teachers:\n    print(f"Test Accuracy: {combine_teacher_predictions(teachers_STKD_05[:t], test_loader, method=\'avg\')}")\n    print(f"Test Accuracy: {combine_teacher_predictions(teachers_STKD_025[:t], test_loader, method=\'avg\')}")\n    print(f"Test Accuracy: {combine_teacher_predictions(teachers_STKD_075[:t], test_loader, method=\'avg\')}")\n'

In [55]:
'''

2. Over-regularization
Smoothened Outputs: With more teachers, especially when using a high temperature in the softmax function, the output distributions become more smooth and less distinct. This smoothening can lead to over-regularization, where the student model's learning is overly constrained, resulting in a loss of sharpness and specificity in predictions.
Reduced Model Capacity Utilization: The student model may not fully utilize its capacity when learning from overly smooth or averaged teacher outputs, leading to underfitting.

'''

"\n\n2. Over-regularization\nSmoothened Outputs: With more teachers, especially when using a high temperature in the softmax function, the output distributions become more smooth and less distinct. This smoothening can lead to over-regularization, where the student model's learning is overly constrained, resulting in a loss of sharpness and specificity in predictions.\nReduced Model Capacity Utilization: The student model may not fully utilize its capacity when learning from overly smooth or averaged teacher outputs, leading to underfitting.\n\n"

In [56]:
del_all()

In [57]:
# Train function
def train(model, epochs, learning_rate=0.001):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

### ***Insegnanti diversi***

In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [59]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define CNN Teacher Model
class CNNTeacherModel(nn.Module):
    def __init__(self):
        super(CNNTeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class ShallowDenseNet(nn.Module):
    def __init__(self):
        super(ShallowDenseNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.denseblock = nn.Sequential(
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 8, kernel_size=3, padding=1)
        )
        self.fc = nn.Linear(8 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.denseblock(x) + x  # Skip connection
        x = x.view(-1, 8 * 28 * 28)
        x = self.fc(x)
        return x


class FireModule(nn.Module):
    def __init__(self, in_channels, squeeze_channels, expand_channels):
        super(FireModule, self).__init__()
        self.squeeze = nn.Conv2d(in_channels, squeeze_channels, kernel_size=1)
        self.expand1x1 = nn.Conv2d(squeeze_channels, expand_channels, kernel_size=1)
        self.expand3x3 = nn.Conv2d(squeeze_channels, expand_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = torch.relu(self.squeeze(x))
        return torch.cat([
            torch.relu(self.expand1x1(x)),
            torch.relu(self.expand3x3(x))
        ], 1)

class SimpleSqueezeNet(nn.Module):
    def __init__(self):
        super(SimpleSqueezeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2)
        self.fire1 = FireModule(16, 16, 32)
        self.fire2 = FireModule(64, 16, 32)
        self.conv2 = nn.Conv2d(64, 10, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.fire1(x)
        x = self.fire2(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 10)
        return x

In [73]:
trained_models = True

num_teachers = [2,5,10,20]
final_teachers = []

num_iterations = 5

if not trained_models:
    for i in range(num_iterations):
        torch.manual_seed(i)
        cnn = CNNTeacherModel()
        train(cnn, 10)
        torch.save(cnn.state_dict(), f'CNN_teacher_{i}.pth')
        with zipfile.ZipFile('Last_Teachers.zip', 'a') as zipf: zipf.write(f'CNN_teacher_{i}.pth')
        torch.manual_seed(i)
        shallowdensenet = ShallowDenseNet()
        train(shallowdensenet, 10)
        torch.save(shallowdensenet.state_dict(), f'shallowdensenet_teacher_{i}.pth')
        with zipfile.ZipFile('Last_Teachers.zip', 'a') as zipf: zipf.write(f'shallowdensenet_teacher_{i}.pth')
        torch.manual_seed(i)
        simple_squeeze_net = SimpleSqueezeNet()
        train(simple_squeeze_net, 10)
        torch.save(simple_squeeze_net.state_dict(), f'simple_squeeze_net_teacher_{i}.pth')
        with zipfile.ZipFile('Last_Teachers.zip', 'a') as zipf: zipf.write(f'simple_squeeze_net_teacher_{i}.pth')

    files.download('Last_Teachers.zip')
else:
    for i in range(num_iterations):
        torch.manual_seed(i)
        with zipfile.ZipFile('Last_Teachers.zip', 'r') as zip_ref:
            zip_ref.extract(f'CNN_teacher_{i}.pth')
        cnn = CNNTeacherModel().to(device)
        cnn.load_state_dict(torch.load(f'CNN_teacher_{i}.pth'))
        cnn_acc = test_model(cnn, test_loader)

        with zipfile.ZipFile('Last_Teachers.zip', 'r') as zip_ref:
            zip_ref.extract(f'shallowdensenet_teacher_{i}.pth')
        sdnt = ShallowDenseNet().to(device)
        sdnt.load_state_dict(torch.load(f'shallowdensenet_teacher_{i}.pth'))
        sdnt_acc = test_model(sdnt, test_loader)

        '''
        with zipfile.ZipFile('Last_Teachers.zip', 'r') as zip_ref:
            zip_ref.extract(f'simple_squeeze_net_teacher_{i}.pth')
        SSN = SimpleSqueezeNet().to(device)
        SSN.load_state_dict(torch.load(f'simple_squeeze_net_teacher_{i}.pth'))
        SSN_acc = test_model(SSN, test_loader)
        '''
        with zipfile.ZipFile('MTKD.zip', 'r') as zip_ref:
            zip_ref.extract(f'teacher_{i}.pth')
        SSN = TeacherModel().to(device)
        SSN.load_state_dict(torch.load(f'teacher_{i}.pth'))
        SSN_acc = test_model(SSN, test_loader)

        final_teachers.append(cnn)
        final_teachers.append(sdnt)
        final_teachers.append(SSN)

Test Accuracy: 91.63%
Test Accuracy: 89.44%
Test Accuracy: 88.27%
Test Accuracy: 92.02%
Test Accuracy: 89.14%
Test Accuracy: 88.23%
Test Accuracy: 92.02%
Test Accuracy: 88.89%
Test Accuracy: 87.95%
Test Accuracy: 91.93%
Test Accuracy: 89.51%
Test Accuracy: 88.32%
Test Accuracy: 91.87%
Test Accuracy: 89.32%
Test Accuracy: 88.13%


In [74]:
combine_teacher_predictions(final_teachers, test_loader, method='avg', temperature=2.0)

0.9303

In [83]:
final_teachers

[CNNTeacherModel(
   (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (fc1): Linear(in_features=3136, out_features=128, bias=True)
   (fc2): Linear(in_features=128, out_features=10, bias=True)
 ),
 ShallowDenseNet(
   (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (denseblock): Sequential(
     (0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (1): ReLU(inplace=True)
     (2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (4): ReLU(inplace=True)
     (5): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   )
   (fc): Linear(in_features=6272, out_features=10, bias=True)
 ),
 TeacherModel(
   (fc1): Linea

In [None]:
#@title ###*2$^{nd}$ version MTKD,  $\alpha = 0.5$*

num_teachers = [2, 5, 10, 20]

trained_models = False

if not trained_models:
    for i in range(num_iterations):
        torch.manual_seed(i)
        student_model_distillated = StudentModel().to(device)
        student_model_distillated = train_student_model_multi_teacher(final_teachers, student_model_distillated, train_loader)
        test_model(student_model_distillated, test_loader)
        student_distillated_accuracy = test_model(student_model_distillated, test_loader)
        torch.save(student_model_distillated.state_dict(), f'MTKD_finale_student05_{i}.pth')
        with zipfile.ZipFile('MTKD_finale.zip', 'a') as zipf:
            zipf.write(f'MTKD_finale_student05_{i}.pth')

else:
    for i in range(num_iterations):
        torch.manual_seed(i)
        student_model_distillated = StudentModel().to(device)
        with zipfile.ZipFile('MTKD_finale.zip', 'r') as zip_ref:
            # Extract the specified file
            zip_ref.extract(f'MTKD_finale_student05_{i}.pth')
        student_model_distillated.load_state_dict(torch.load(f'MTKD_finale_student05_{i}.pth'))
        student_distillated_accuracy = test_model(student_model_distillated, test_loader)

if not trained_models: files.download('MTKD_finale.zip')



Epoch 1/10, Loss: 0.3645
Epoch 2/10, Loss: 0.2633
Epoch 3/10, Loss: 0.2413
Epoch 4/10, Loss: 0.2250
Epoch 5/10, Loss: 0.2139
Epoch 6/10, Loss: 0.2047


In [96]:
combine_teacher_predictions(final_teachers, test_loader, temperature=1.0).shape

torch.Size([10000, 10])

In [99]:
for input, label in test_loader:
  print(student(input).shape)
  break

torch.Size([1000, 10])


In [88]:
'''
def combine_teacher_predictions(teacher_models, test_loader, temperature=2.0):
    """
    Combine predictions from multiple teacher models on the entire test set and calculate accuracy.

    Args:
        teacher_models (list of nn.Module): List of teacher models.
        test_loader (DataLoader): DataLoader for the test set.
        method (str): Method to combine predictions ('avg', 'max', 'vote').
        temperature (float): Temperature for scaling logits before combining predictions.

    Returns:
        accuracy (float): Accuracy of the combined predictions on the test set.
    """
    correct_predictions = 0
    total_samples = 0

    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Collect logits from all teacher models
        logits_list = []
        for teacher in teacher_models:
            teacher.eval()
            with torch.no_grad():
                #logits, _, _, _, _ = teacher(inputs)
                logits = teacher(inputs)

                # Apply temperature scaling
                logits = logits / temperature
                logits_list.append(logits)

        # Stack logits from all teachers: shape [num_teachers, batch_size, num_classes]
        stacked_logits = torch.stack(logits_list)

        combined_logits = torch.mean(stacked_logits, dim=0)

    return combined_logits

def train_student_model(teachers, student, train_loader, epochs = 10, alpha=0.5, T=2.0):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=0.001)

    for teacher in teachers:
        teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = combine_teacher_predictions(teachers, test_loader, temperature = 1.0)
            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = nn.KLDivLoss()(soft_prob, soft_targets) * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = alpha * soft_targets_loss + (1 - alpha) * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}")

    return student
'''

In [106]:
'''
import torch
import torch.nn.functional as F

def multi_teacher_distillation(teacher_models,student_model,temperature=2.0, alpha=0.5):
    """
    Perform multi-teacher knowledge distillation on a student model.

    Args:
        student_model (nn.Module): The student model to be trained.
        teacher_models (list of nn.Module): List of teacher models.
        train_loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer for training the student model.
        criterion (Loss): Loss function (e.g., CrossEntropyLoss).
        temperature (float): Temperature for scaling logits.
        alpha (float): Weighting factor between the distillation loss and the original loss.
        device (str): Device to use for training ('cuda' or 'cpu').

    Returns:
        float: The average training loss over the entire dataset.
    """
    student_model.train()

    total_loss = 0.0
    total_samples = 0
    criterion = torch.nn.CrossEntropyLoss()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass through the student model
        student_logits = student_model(inputs)

        # Collect and combine logits from all teacher models
        with torch.no_grad():
            teacher_logits_list = []
            for teacher in teacher_models:
                teacher.eval()
                teacher_logits = teacher(inputs)
                teacher_logits = teacher_logits / temperature  # Apply temperature scaling
                teacher_logits_list.append(teacher_logits)

            # Average the logits from all teachers
            avg_teacher_logits = torch.mean(torch.stack(teacher_logits_list), dim=0)

        # Distillation loss: KL Divergence between teacher logits and student logits
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(avg_teacher_logits / temperature, dim=1),
            reduction='batchmean'
        ) * (temperature ** 2)

        # Original loss: CrossEntropy between student logits and ground truth labels
        original_loss = criterion(student_logits, labels)

        # Total loss: Weighted combination of original loss and distillation loss
        loss = alpha * original_loss + (1.0 - alpha) * distillation_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        total_samples += inputs.size(0)

    avg_loss = total_loss / total_samples
    return student_model
'''

In [109]:
'''
import torch
import torch.nn as nn
import torch.optim as optim

def get_combined_teacher_logits(teachers, inputs, device='cpu'):
    teacher_logits = []
    with torch.no_grad():
        for teacher in teachers:
            teacher_logits.append(teacher(inputs.to(device)))

    # Average the logits from the teacher models
    combined_logits = sum(teacher_logits) / len(teacher_logits)
    return combined_logits

def train_student_with_combined_logits(teachers, student, loader, epochs=10, alpha=0.5, T=2.0, device='cpu'):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=0.001)

    for teacher in teachers:
        teacher.eval()  # Set all teacher models to evaluation mode
    student.train()  # Student in training mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Get combined logits from the teacher models
            combined_logits = get_combined_teacher_logits(teachers, inputs, device)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher and student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(combined_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = nn.KLDivLoss()(soft_prob, soft_targets) * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = alpha * soft_targets_loss + (1 - alpha) * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(loader):.4f}")

    return student
'''