In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torchvision.models import resnet18
import pickle

# Per-epoch test accuracy keyed by 1-based epoch number; populated by the
# training loop and pickled once training finishes.
accuracy_dict = {}


# Preprocessing shared by the train and test splits: resize so the shorter
# side is 224 px (CIFAR-10 images are square, so this yields 224x224), convert
# to a float tensor in [0, 1], then map each RGB channel to [-1, 1] via
# (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both CIFAR-10 splits live under ./data and are downloaded on first use.
# The per-client subsets of the training split are built further down.
train_dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Evaluation loader: fixed sample order (no shuffling), batches of 32.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Split the training set across clients. Each client receives a contiguous
# slice of the index range 0 .. dataset_size-1. When dataset_size is not
# divisible by num_clients, the first `remainder` clients each take one extra
# sample, so no training example is dropped (a plain floor-division split
# would silently discard the last dataset_size % num_clients samples).
# NOTE(review): whether these contiguous shards are class-balanced depends on
# the stored order of train_dataset; use a seeded random permutation of
# `indices` if IID shards must be guaranteed.
num_clients = 5
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
split_size, remainder = divmod(dataset_size, num_clients)
client_indices = []
shard_start = 0
for i in range(num_clients):
    # Clients 0 .. remainder-1 absorb the leftover samples, one each.
    shard_end = shard_start + split_size + (1 if i < remainder else 0)
    client_indices.append(indices[shard_start:shard_end])
    shard_start = shard_end

# Wrap each client's indices in a Subset and give it its own shuffling loader.
clients_datasets = []
clients_loaders = []

for i in range(num_clients):
    subset = Subset(train_dataset, client_indices[i])
    loader = DataLoader(subset, batch_size=32, shuffle=True)
    clients_datasets.append(subset)
    clients_loaders.append(loader)

# ResNet-18 built without pretrained weights; its final fully connected layer
# is swapped for one that emits a logit per dataset class.
num_classes = len(train_dataset.classes)
model = resnet18(pretrained=False)
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)

# Prefer the GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)  # nn.Module.to moves the parameters in place

# One Adam optimizer over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Number of passes over all client shards. The results file name is derived
# from this and num_clients so it cannot drift out of sync with the actual
# configuration (previously 'Res_CIFAR10_10_C5.pickle' was hard-coded).
num_epochs = 10
# Default reduction='mean': the loss is averaged over the batch. A single
# instance is reused for training and evaluation.
criterion = torch.nn.CrossEntropyLoss()

# NOTE(review): despite the "federated learning" framing, this loop trains one
# shared model with one shared Adam optimizer sequentially on each client's
# shard. There are no per-client model copies and no weight aggregation
# (e.g. FedAvg); each epoch is one pass over every client's loader, in
# client order.
for epoch in range(num_epochs):
    # Training pass over each client's data, one client after another.
    model.train()
    for client_loader in clients_loaders:
        for images, labels in client_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

    # Evaluate on the full test set after every epoch.
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # criterion returns the batch mean; weight it by batch size so the
            # final figure is a per-sample average even when the last batch
            # is smaller than the rest.
            test_loss += criterion(outputs, labels).item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    test_loss /= len(test_dataset)
    test_acc = correct / len(test_dataset)

    print(f'Epoch {epoch+1}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
    accuracy_dict[epoch+1] = test_acc

# Persist {epoch: test accuracy}, e.g. 'Res_CIFAR10_10_C5.pickle' for
# 10 epochs and 5 clients.
with open(f'Res_CIFAR10_{num_epochs}_C{num_clients}.pickle', 'wb') as handle:
    pickle.dump(accuracy_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)