In [1]:
from src.dataset import CIFARDataset, iid_dataloader
from src.lenet import LeNet, weights_init
from src.train import federated_learning_experiment, train_client
from torchvision import datasets, transforms
from torch.autograd import grad
import os
import numpy as np
from pprint import pprint
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

##### Train LeNet on CIFAR-10 using Federated Averaging

In [2]:
# Train LeNet with Federated Averaging on the full CIFAR-10 dataset.
# Use the GPU when one is visible; model and loss are moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.CrossEntropyLoss().to(device)

# Fix the RNG state so weight init, client sampling and data shuffling are
# reproducible under "Restart & Run All".
# NOTE(review): assumes the src helpers draw their randomness from torch / numpy — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

batch_size = 50
num_clients = 5
# top_5_classes_indices=None -> presumably all 10 CIFAR classes (see src.dataset).
cifar_data = CIFARDataset(batch_size=batch_size, num_clients=num_clients, top_5_classes_indices=None)
train_dataset, validation_dataset, user_groups = cifar_data.get_dataset()

# Fraction of each split to keep (1 = everything); lower it for quick debug runs.
alpha = 1
assert 0 < alpha <= 1, "alpha must be a fraction in (0, 1]"

train_subset_size = int(alpha * len(train_dataset))
validation_subset_size = int(alpha * len(validation_dataset))

# Keeps the FIRST N samples of each split (not a random sample).
train_dataset = torch.utils.data.Subset(train_dataset, indices=range(train_subset_size))
validation_dataset = torch.utils.data.Subset(validation_dataset, indices=range(validation_subset_size))

iid_client_train_loader = iid_dataloader(train_dataset, batch_size=batch_size, num_clients=num_clients)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

global_model = LeNet().to(device)

# FedAvg hyper-parameters: 2 of the num_clients clients take part in each round.
num_clients_per_round = 2
num_local_epochs = 10
learning_rate = 5e-3
max_rounds = 5

results = federated_learning_experiment(
    global_model,
    num_clients_per_round=num_clients_per_round,
    num_local_epochs=num_local_epochs,
    lr=learning_rate,
    client_train_loader=iid_client_train_loader,
    max_rounds=max_rounds,
    device=device,
    criterion=criterion,
    test_dataloader=validation_loader,
)

Files already downloaded and verified
Files already downloaded and verified
Round 0 is starting
Clients for round 0 are: [2 0]
round 0, starting client 1/2, id: 2
round 0, starting client 2/2, id: 0
Round 0, validation accuracy: 13.719999999999999 %
Round 1 is starting
Clients for round 1 are: [3 4]
round 1, starting client 1/2, id: 3
round 1, starting client 2/2, id: 4
Round 1, validation accuracy: 43.480000000000004 %
Round 2 is starting
Clients for round 2 are: [0 2]
round 2, starting client 1/2, id: 0
round 2, starting client 2/2, id: 2
Round 2, validation accuracy: 51.67 %
Round 3 is starting
Clients for round 3 are: [3 0]
round 3, starting client 1/2, id: 3
round 3, starting client 2/2, id: 0
Round 3, validation accuracy: 52.33 %
Round 4 is starting
Clients for round 4 are: [4 2]
round 4, starting client 1/2, id: 4
round 4, starting client 2/2, id: 2
Round 4, validation accuracy: 54.0 %


##### Compute the validation confusion matrix of the CIFAR-10 model

In [4]:
from sklearn.metrics import confusion_matrix
import numpy as np

# Run the trained CIFAR-10 global model over the whole validation set and build
# its confusion matrix (sklearn convention: rows = true class, columns = predicted).
global_model.eval()  # inference mode (matters for dropout/batch-norm, if LeNet has any)
all_predictions = []
all_true_labels = []
num_classes = None

with torch.no_grad():
    for imgs, labels in validation_loader:
        imgs = imgs.to(device)
        outputs = global_model(imgs)
        # Predictions are argmax over dim 1, so they range over outputs.shape[1] classes.
        num_classes = outputs.shape[1]
        _, predicted = torch.max(outputs, 1)
        all_predictions.extend(predicted.cpu().numpy())
        # Labels are only needed on the CPU; no need to move them to `device`.
        all_true_labels.extend(labels.cpu().numpy())

assert num_classes is not None, "validation_loader yielded no batches"

# Pass the full label set explicitly. Without `labels=`, sklearn uses only the
# labels that actually occur, so a class absent from both lists would shift all
# later rows and row index would no longer equal class id (the next cell relies on this).
conf_matrix = confusion_matrix(all_true_labels, all_predictions, labels=np.arange(num_classes))

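##### Build CIFAR-5 from the five classes the CIFAR-10 model classifies best
Note: the classes are chosen on the validation split, which presumably is also used to evaluate the CIFAR-5 model, so its accuracy is optimistic. Going from 10 to 5 classes also raises chance level from 10 % to 20 %.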

In [6]:
# Per-class accuracy (recall) = diagonal / row sum of the confusion matrix.
# Classes with no validation samples get 0 instead of NaN. np.argsort sorts
# NaN last, so a NaN row would otherwise be selected as a "top" class.
class_support = np.sum(conf_matrix, axis=1)
class_accuracies = np.divide(
    np.diag(conf_matrix),
    class_support,
    out=np.zeros(len(class_support), dtype=float),
    where=class_support > 0,
)
# argsort is ascending: the last five indices are the best classes,
# ordered from 5th-best to best.
top_5_classes_indices = np.argsort(class_accuracies)[-5:]
print(f"Top 5 classes with the highest classification accuracy: {top_5_classes_indices}")

Top 5 classes with the highest classification accuracy: [0 9 6 1 8]


##### Train LeNet on CIFAR-5 using Federated Averaging

In [7]:
# Same FedAvg setup as the CIFAR-10 run, but with the dataset restricted to the
# classes in `top_5_classes_indices`.
# NOTE: this cell rebinds train_dataset, validation_loader, global_model and
# results, so the CIFAR-10 objects are no longer reachable afterwards. Re-run
# the CIFAR-10 training cell before re-running the confusion-matrix cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.CrossEntropyLoss().to(device)

# Re-seed so this run is reproducible on its own, independent of how much
# randomness earlier cells consumed.
# NOTE(review): assumes the src helpers draw their randomness from torch / numpy — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

batch_size = 50
num_clients = 5
cifar_data = CIFARDataset(batch_size=batch_size, num_clients=num_clients, top_5_classes_indices=top_5_classes_indices)
train_dataset, validation_dataset, user_groups = cifar_data.get_dataset()

# Fraction of each split to keep (1 = everything); lower it for quick debug runs.
alpha = 1
assert 0 < alpha <= 1, "alpha must be a fraction in (0, 1]"

train_subset_size = int(alpha * len(train_dataset))
validation_subset_size = int(alpha * len(validation_dataset))

# Keeps the FIRST N samples of each split (not a random sample).
train_dataset = torch.utils.data.Subset(train_dataset, indices=range(train_subset_size))
validation_dataset = torch.utils.data.Subset(validation_dataset, indices=range(validation_subset_size))

iid_client_train_loader = iid_dataloader(train_dataset, batch_size=batch_size, num_clients=num_clients)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

# NOTE(review): LeNet() is built with its default output head. Confirm that
# CIFARDataset remaps the 5 kept labels (or that a wider head is intended).
global_model = LeNet().to(device)

# FedAvg hyper-parameters: 2 of the num_clients clients take part in each round.
num_clients_per_round = 2
num_local_epochs = 10
learning_rate = 5e-3
max_rounds = 5

results = federated_learning_experiment(
    global_model,
    num_clients_per_round=num_clients_per_round,
    num_local_epochs=num_local_epochs,
    lr=learning_rate,
    client_train_loader=iid_client_train_loader,
    max_rounds=max_rounds,
    device=device,
    criterion=criterion,
    test_dataloader=validation_loader,
)

Files already downloaded and verified
Files already downloaded and verified
Round 0 is starting
Clients for round 0 are: [4 0]
round 0, starting client 1/2, id: 4
round 0, starting client 2/2, id: 0
Round 0, validation accuracy: 61.18 %
Round 1 is starting
Clients for round 1 are: [3 1]
round 1, starting client 1/2, id: 3
round 1, starting client 2/2, id: 1
Round 1, validation accuracy: 69.39999999999999 %
Round 2 is starting
Clients for round 2 are: [2 1]
round 2, starting client 1/2, id: 2
round 2, starting client 2/2, id: 1
Round 2, validation accuracy: 70.02000000000001 %
Round 3 is starting
Clients for round 3 are: [4 1]
round 3, starting client 1/2, id: 4
round 3, starting client 2/2, id: 1
Round 3, validation accuracy: 71.12 %
Round 4 is starting
Clients for round 4 are: [0 1]
round 4, starting client 1/2, id: 0
round 4, starting client 2/2, id: 1
Round 4, validation accuracy: 71.64 %
