In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from federated_learning import *

In [3]:
# Configuration
num_clients = 5            # number of simulated federated clients (one training shard each)
num_classes = 10           # CIFAR-10 has 10 classes
input_shape = (3, 32, 32)  # (channels, height, width) of one CIFAR-10 image
batch_size = 32            # training batch size of every client DataLoader
rounds = 5                 # number of federated rounds — semantics defined by federated_learning
epochs = 10                # presumably local epochs per client per round — TODO confirm in federated_learning
model = 2                  # model-architecture selector forwarded to federated_learning — TODO document valid choices
seed = 42                  # global RNG seed so Restart & Run All reproduces shuffling / initialisation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs in use. torch.manual_seed seeds the generators on all devices (CPU and CUDA).
# NOTE(review): some GPU kernels can still be nondeterministic; see PyTorch reproducibility notes.
torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")

Using device: cpu


In [None]:
# Download CIFAR-10 and convert each image to a tensor normalised per RGB channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),  # CIFAR-10 per-channel mean / std
])

cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split the training set into `num_clients` contiguous, equally sized shards;
# the last client additionally receives any remainder.
# NOTE(review): this is an index-order split, not a random/stratified one — per-client
# class balance depends on the stored order of CIFAR-10. Confirm this is the intended setting.
dataset_size = len(cifar_train)
indices = list(range(dataset_size))
split_size = dataset_size // num_clients

clients_dataloaders = []
for i in range(num_clients):
    start = i * split_size
    end = start + split_size if i < num_clients - 1 else dataset_size
    client_indices = indices[start:end]
    client_subset = Subset(cifar_train, client_indices)
    dataloader = DataLoader(client_subset, batch_size=batch_size, shuffle=True)
    clients_dataloaders.append(dataloader)

# The non-overlapping shards must cover the whole training set exactly.
assert sum(len(loader.dataset) for loader in clients_dataloaders) == dataset_size

print(f"Prepared {num_clients} client dataloaders")

test_batch_size = 64  # evaluation batch size; independent of the training `batch_size`
test_loader = DataLoader(dataset=cifar_test, batch_size=test_batch_size, shuffle=False)

tensor([3])
tensor([8])
tensor([8])
tensor([0])
tensor([6])
tensor([6])
tensor([1])
tensor([6])
tensor([3])
tensor([1])
tensor([0])
tensor([9])
tensor([5])
tensor([7])
tensor([9])
tensor([8])
tensor([5])
tensor([7])
tensor([8])
tensor([6])
tensor([7])
tensor([0])
tensor([4])
tensor([9])
tensor([5])
tensor([2])
tensor([4])
tensor([0])
tensor([9])
tensor([6])
tensor([6])
tensor([5])
tensor([4])
tensor([5])
tensor([9])
tensor([2])
tensor([4])
tensor([1])
tensor([9])
tensor([5])
tensor([4])
tensor([6])
tensor([5])
tensor([6])
tensor([0])
tensor([9])
tensor([3])
tensor([9])
tensor([7])
tensor([6])
tensor([9])
tensor([8])
tensor([0])
tensor([3])
tensor([8])
tensor([8])
tensor([7])
tensor([7])
tensor([4])
tensor([6])
tensor([7])
tensor([3])
tensor([6])
tensor([3])
tensor([6])
tensor([2])
tensor([1])
tensor([2])
tensor([3])
tensor([7])
tensor([2])
tensor([6])
tensor([8])
tensor([8])
tensor([0])
tensor([2])
tensor([9])
tensor([3])
tensor([3])
tensor([8])
tensor([8])
tensor([1])
tensor([1])
tens

In [5]:
# Train with the project's `federated_learning` helper (brought in by the star import in
# the imports cell), passing every configuration value by keyword. Only the first return
# value is kept as `global_model`; the second is discarded.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt, so any
# downstream results must come from a completed re-run of this cell.
global_model, _ = federated_learning(
    model=model,
    device=device,
    input_shape=input_shape,
    num_classes=num_classes,
    clients_dataloaders=clients_dataloaders,
    rounds=rounds,
    epochs=epochs,
)

print("Federated learning simulation complete!")

Round 1


KeyboardInterrupt: 

In [None]:
# Measure top-1 accuracy of the global model on the CIFAR-10 test set.
# NOTE(review): the training cell's saved output shows it was interrupted; make sure
# `global_model` comes from a completed run before trusting the reported number.
global_model.eval()  # switch modules to inference behaviour before evaluation
correct = 0
total = 0

with torch.no_grad():  # evaluation only; no gradients needed
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = global_model(images)
        # Index of the highest score per sample; ties resolve to the first index,
        # matching the previous torch.max(..., 1) behaviour.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    # Avoid an opaque ZeroDivisionError when the loader is empty.
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 71.22%


In [None]:
# Persist only the learned parameters (the state_dict) of the global model.
# To reuse them, rebuild the same architecture and call
# `load_state_dict(torch.load('trained_model_CIFAR10.pth'))` on it.
torch.save(obj=global_model.state_dict(), f='trained_model_CIFAR10.pth')