In [101]:
import torch
import torchvision
from torchvision import datasets, transforms

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [102]:
dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True)
dataset_test = datasets.CIFAR10('./data/cifar', train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [108]:
nb_clients = 8


def splitting_dataset(dataset, nb_clients):
    labels_per_client = 5000//nb_clients
    clients = [{i:0 for i in range(10)} for _ in range(nb_clients)]
    clients_dataset = [[] for _ in range(nb_clients)]

    for data in dataset:
        for i, client in enumerate(clients):
            if client[data[1]] < labels_per_client:
                client[data[1]] += 1
                clients_dataset[i].append(data)
                break
    
    return clients_dataset

client_train_dataset = splitting_dataset(dataset_train, nb_clients)
client_test_dataset = splitting_dataset(dataset_test, nb_clients)


In [109]:
print(len(client_train_dataset[0]))
print(len(client_train_dataset[1]))
print(len(client_train_dataset[2]))
print(len(client_train_dataset[3]))
print(len(client_train_dataset[4]))
print(len(client_train_dataset[5]))
print(len(client_train_dataset[6]))
print(len(client_train_dataset[7]))


6250
6250
6250
6250
6250
6250
6250
6250


In [110]:
# I have to transform the dataset PIL images to tensors
dataset_train_0 = [(torchvision.transforms.functional.to_tensor(data[0]), data[1]) for data in client_train_dataset[0]]
dataset_test_0 = [(torchvision.transforms.functional.to_tensor(data[0]), data[1]) for data in client_test_dataset[0]]


In [111]:
import torch.nn as nn
import torch.nn.functional as F
    
class CNNCifar(nn.Module):
    def __init__(self):
        super(CNNCifar, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.fc1 = nn.Linear(64 * 4 * 4, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

In [112]:
model = CNNCifar()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [113]:
from torch.utils.data import DataLoader, TensorDataset

# Convert lists to tensor datasets
train_data = [(data[0], data[1]) for data in dataset_train_0]
test_data = [(data[0], data[1]) for data in dataset_test_0]

train_dataset = TensorDataset(torch.stack([item[0] for item in train_data]), torch.tensor([item[1] for item in train_data]))
test_dataset = TensorDataset(torch.stack([item[0] for item in test_data]), torch.tensor([item[1] for item in test_data]))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [116]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            print(labels.shape, outputs.shape)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

train_model(model, train_loader, criterion, optimizer)

torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([64]) torch.Size([64, 10])
torch.Size([

KeyboardInterrupt: 

In [115]:
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy: {100 * correct / total}%')

evaluate_model(model, test_loader)

Accuracy: 56.592%
