In [6]:
import sys
 
# setting path
sys.path.append('../')

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from non_iid_generator.customDataset import CustomDataset
import pickle
from tqdm import tqdm


# Pick the compute device once for the whole notebook: CUDA when PyTorch can
# see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [69]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on hard labels plus temperature-softened teacher outputs.

    For every batch the loss is
    ``soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss``
    where ``label_loss`` is the cross-entropy between the student logits and
    the labels, and ``soft_targets_loss`` is the cross-entropy between
    ``softmax(teacher_logits / T)`` and ``softmax(student_logits / T)``,
    averaged over the batch and multiplied by ``T**2``.

    Args:
        teacher: Module producing the reference logits. It is moved to
            ``device``, put in eval mode and only run under ``torch.no_grad()``;
            its weights are never updated.
        student: Module being optimised (Adam). It is moved to ``device`` and
            put in train mode.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
            Labels with more than one dimension are squeezed along dim 1
            (so ``(batch, 1)`` becomes ``(batch,)``); 1-D labels are used as-is.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to ``optim.Adam``.
        T: Temperature dividing both sets of logits before softmax.
        soft_target_loss_weight: Weight of the distillation term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device on which models and batches are placed.

    Returns:
        None. The mean batch loss of each epoch is printed (``nan`` when the
        loader yielded no batches).
    """
    ce_loss = nn.CrossEntropyLoss()

    # Batches are moved to ``device`` below, so both networks must live there
    # too; this mirrors what ``test`` does and is a no-op when already placed.
    teacher.to(device)
    student.to(device)
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Drop the trailing singleton dim of (batch, 1) labels; calling
            # squeeze(1) on labels that are already 1-D would raise.
            if labels.dim() > 1:
                labels = labels.squeeze(1)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Temperature-softened teacher probabilities and student log-probabilities
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually seen; avoids dividing by zero when
        # the loader is empty.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimiser.

    Args:
        model: Module to optimise. It is moved to ``device`` and put in
            train mode.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
            Labels with more than one dimension are squeezed along dim 1
            (so ``(batch, 1)`` becomes ``(batch,)``); 1-D labels are used as-is.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to ``optim.Adam``.
        device: Device on which the model and batches are placed.

    Returns:
        None. The mean batch loss of each epoch is printed (``nan`` when the
        loader yielded no batches).
    """
    criterion = nn.CrossEntropyLoss()

    # Batches are moved to ``device`` below, so the model must live there too;
    # this mirrors what ``test`` does and is a no-op when already placed.
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Drop the trailing singleton dim of (batch, 1) labels; calling
            # squeeze(1) on labels that are already 1-D would raise.
            if labels.dim() > 1:
                labels = labels.squeeze(1)

            optimizer.zero_grad()
            outputs = model(inputs)

            # Cross-entropy between the network outputs and the batch labels
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually seen; avoids dividing by zero when
        # the loader is empty.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.squeeze(1).to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted.data == labels).sum().item()
            # print(predicted.data)
            # correct += (torch.max(predicted.data, 1)[1] == labels.squeeze_(1)).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [55]:
def load_data(train_dataset_path, test_dataset_path, batch_size=128):
    """Unpickle a train and a test dataset and wrap each in a shuffling DataLoader.

    Args:
        train_dataset_path: Path to the pickled training dataset.
        test_dataset_path: Path to the pickled test dataset.
        batch_size: Batch size used for both loaders (default 128).

    Returns:
        ``(train_loader, test_loader)``, both ``torch.utils.data.DataLoader``
        instances created with ``shuffle=True``.
    """

    def _read_pickle(path):
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this on dataset files that come from a trusted source.
        # The context manager closes the file handle once loading is done.
        with open(path, "rb") as handle:
            return pickle.load(handle)

    train_data = _read_pickle(train_dataset_path)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True)

    test_data = _read_pickle(test_dataset_path)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=True)

    return train_loader, test_loader



In [68]:
# Data for one client of the non-IID split
client_id = 3
train_dataset_path = f"../data/32_Cifar10_NIID_56c_a03/train/{client_id}.pkl"
test_dataset_path = f"../data/32_Cifar10_NIID_56c_a03/test/{client_id}.pkl"
train_loader, test_loader = load_data(train_dataset_path, test_dataset_path)

# Checkpoints: the deep network acts as teacher, the light one as student.
# The student checkpoint is loaded twice below so that both experiments start
# from the same stored weights.
teacher_checkpoint_path = "../projects/define_pretrained_fed_sim_NIID_alpha03/last_model.pth.tar"
student_checkpoint_path = "../projects/test/test-4/master/iter_5_best_model.pth.tar"

# map_location=device lets a checkpoint saved on a GPU be loaded when this
# notebook has fallen back to the CPU.
# NOTE(review): the loaded objects are passed straight to train()/test(), so
# the files presumably hold whole pickled modules; on PyTorch versions where
# torch.load defaults to weights_only=True this may need weights_only=False —
# confirm against the installed version, and only load trusted checkpoints.
nn_deep = torch.load(teacher_checkpoint_path, map_location=device)
nn_light = torch.load(student_checkpoint_path, map_location=device)

test_accuracy_deep = test(nn_deep, test_loader, device)

# Baseline: student trained with cross-entropy on the labels only
train(nn_light, train_loader, epochs=50, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

###########################
###########################


# Fresh copy of the student so distillation starts from the same checkpoint
new_nn_light = torch.load(student_checkpoint_path, map_location=device)
# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=50, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 53.85%
Epoch 1/50, Loss: 22.153010654449464
Epoch 2/50, Loss: 1.1728242337703705
Epoch 3/50, Loss: 1.1214742064476013
Epoch 4/50, Loss: 1.0863099217414856
Epoch 5/50, Loss: 1.0780147552490233
Epoch 6/50, Loss: 1.0762548267841339
Epoch 7/50, Loss: 1.0806386351585389
Epoch 8/50, Loss: 1.088597047328949
Epoch 9/50, Loss: 1.067915803194046
Epoch 10/50, Loss: 1.049210751056671
Epoch 11/50, Loss: 1.0614818394184113
Epoch 12/50, Loss: 1.0126095354557036
Epoch 13/50, Loss: 0.9615123510360718
Epoch 14/50, Loss: 0.9356087446212769
Epoch 15/50, Loss: 0.853547728061676
Epoch 16/50, Loss: 0.807304835319519
Epoch 17/50, Loss: 0.8125032424926758
Epoch 18/50, Loss: 0.7153206050395966
Epoch 19/50, Loss: 0.7078683733940124
Epoch 20/50, Loss: 0.6943287342786789
Epoch 21/50, Loss: 0.684327346086502
Epoch 22/50, Loss: 0.6577076375484466
Epoch 23/50, Loss: 0.6471943259239197
Epoch 24/50, Loss: 0.6479382574558258
Epoch 25/50, Loss: 0.6264076679944992
Epoch 26/50, Loss: 0.6335215121507645
Epoch