In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# (The Kaggle template's file-listing loop is not included in this cell, so running it lists nothing.)

import os

In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader

# Fix torch's global RNG so model initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 64
CLIENT_SHARD_SIZE = 640  # training images per simulated client (10 batches of 64)

# Convert images to tensors and map each RGB channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the full CIFAR-10 training and test splits.
# NOTE: the `mnist_*` names are historical; the data is CIFAR-10. They are kept
# unchanged so any cell that refers to them keeps working.
mnist_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# DataLoader over the full test split. Accuracy does not depend on order, so no shuffling.
test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# Two disjoint client shards taken from the start of the training split:
# client 1 gets images [0, 640), client 2 gets images [640, 1280).
subset_indices_1 = list(range(0, CLIENT_SHARD_SIZE))
subset_indices_2 = list(range(CLIENT_SHARD_SIZE, 2 * CLIENT_SHARD_SIZE))
mnist_subset_1 = Subset(mnist_train, subset_indices_1)
mnist_subset_2 = Subset(mnist_train, subset_indices_2)

# One shuffled training DataLoader per client.
train_loader_1 = DataLoader(mnist_subset_1, batch_size=BATCH_SIZE, shuffle=True)
train_loader_2 = DataLoader(mnist_subset_2, batch_size=BATCH_SIZE, shuffle=True)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 42482153.60it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [21]:
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small CNN classifier: two (5x5 conv -> 2x2 max-pool -> ReLU) stages, then two linear layers.

    With the default arguments it takes 3-channel 32x32 images (CIFAR-10) and
    returns 10 raw class scores per image (no softmax is applied; the notebook
    trains it with nn.CrossEntropyLoss).

    Args:
        in_channels: channels of the input images (3 for RGB).
        num_classes: number of output scores.
        input_size: height/width of the square input images; used to size fc1.

    Raises:
        ValueError: if input_size is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Each stage: the unpadded 5x5 conv shrinks the side by 4, then the 2x2
        # max-pool halves it (floor). For 32x32 inputs: 32 -> 14 -> 5, so fc1
        # receives 20 * 5 * 5 = 500 features.
        side = input_size
        for _ in range(2):
            side = (side - 4) // 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for SimpleCNN")
        self.fc1 = nn.Linear(20 * side * side, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        # Expects x shaped (batch, in_channels, input_size, input_size).
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.flatten(start_dim=1)  # (batch, 20 * side * side)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [27]:
import torch.optim as optim

def evaluate(model,loader) -> float:
    """Return the classification accuracy of ``model`` over every batch in ``loader``.

    Batches are moved to the device of the model's own parameters, so the model
    can live on CPU or GPU. The model's train/eval mode is restored on exit.

    Args:
        model: classifier returning one score per class along dim 1.
        loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    # Use the model's device rather than a global one, so a CPU model is never
    # fed GPU tensors (or vice versa).
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training
    correct, total = 0, 0
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                inputs = inputs.to(model_device)
                labels = labels.to(model_device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate(): loader yielded no samples")
    return correct / total

def train(model,train_loader,batch_size=64,EPOCHS=5,lr=0.01,momentum=0.9,eval_loader=None):
    """Train ``model`` with SGD + cross-entropy and print per-epoch metrics.

    After every epoch the model is evaluated on ``eval_loader``. One summary line
    is then printed with the mean batch loss, train accuracy and test accuracy.

    Relies on the notebook-global ``device``, where the model and every batch are
    placed. When ``eval_loader`` is None it also relies on the notebook-global
    ``test_loader``.

    Args:
        model: the nn.Module to train; ``model.to(device)`` moves it in place.
        train_loader: iterable of ``(inputs, labels)`` batches.
        batch_size: unused; kept only for backward compatibility. The effective
            batch size is whatever ``train_loader`` yields.
        EPOCHS: number of passes over ``train_loader``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        eval_loader: loader for the per-epoch test accuracy; defaults to the
            global ``test_loader``.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if eval_loader is None:
        eval_loader = test_loader
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    # The optimizer must be built after .to(device) so it holds the on-device parameters.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for epoch in range(EPOCHS):
        model.train()  # evaluate() at the end of the previous epoch may have switched modes
        running_loss = 0.0
        running_correct = 0
        num_batches = 0
        num_samples = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_batches += 1
            # Count real samples: the last batch may be smaller than the nominal batch size.
            num_samples += labels.size(0)
        if num_samples == 0:
            raise ValueError("train(): train_loader yielded no samples")
        epoch_loss = running_loss / num_batches
        epoch_accuracy = running_correct / num_samples
        test_accuracy = evaluate(model, eval_loader)
        print(f"Epoch [{epoch+1}/{EPOCHS}]: Loss: {epoch_loss}, Train accuracy: {epoch_accuracy}, Test accuracy: {test_accuracy}")
    return model

In [28]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [29]:
import torch
import copy

def average_model_weights(model_list):
    """Return the element-wise mean of several model state dicts.

    Args:
        model_list: non-empty list of state dicts (e.g. from ``module.state_dict()``)
            with identical keys and tensor shapes. Despite the parameter name,
            these are state dicts, not nn.Module objects.

    Returns:
        A new dict of the same type as ``model_list[0]``. Each floating/complex
        entry is the mean across the inputs. Each integer/bool entry (e.g. a
        BatchNorm ``num_batches_tracked`` counter) is the floor of the mean, cast
        back to its original dtype. The inputs are not modified.

    Raises:
        ValueError: if ``model_list`` is empty.
    """
    if not model_list:
        raise ValueError("average_model_weights(): model_list must not be empty")
    num_models = len(model_list)
    reference = model_list[0]
    # Shallow copy keeps the dict type (and any state-dict metadata) without
    # cloning every tensor; each entry is reassigned below.
    avg_model_state_dict = copy.copy(reference)
    for key, ref_tensor in reference.items():
        stacked = torch.stack([state_dict[key] for state_dict in model_list])
        if torch.is_floating_point(ref_tensor) or torch.is_complex(ref_tensor):
            avg_model_state_dict[key] = stacked.mean(dim=0)
        else:
            # In-place true division would fail on integer tensors; floor-average instead.
            summed = stacked.sum(dim=0)
            avg_model_state_dict[key] = torch.div(summed, num_models, rounding_mode="floor").to(ref_tensor.dtype)
    return avg_model_state_dict
    


In [31]:
train_loaders = [train_loader_1, train_loader_2]  # one DataLoader per simulated client

def server_execute(t,E,K):
    """Simulate ``t`` rounds of federated averaging over the first ``K`` client loaders.

    Each round, every client ``k in range(K)`` starts from the current global
    weights and trains locally for ``E`` epochs (``ClientUpdate``). The returned
    state dicts are then averaged into the new global weights by
    ``average_model_weights``. The averaged model's test accuracy is printed
    after each round.

    Args:
        t: number of communication rounds.
        E: local training epochs per client per round.
        K: number of participating clients; must be in ``[1, len(train_loaders)]``.

    Raises:
        ValueError: if ``K`` is outside ``[1, len(train_loaders)]``.
    """
    if not 1 <= K <= len(train_loaders):
        raise ValueError(f"K must be between 1 and {len(train_loaders)}, got {K}")
    # Random initial global weights; the temporary model is discarded, so no copy is needed.
    w0 = SimpleCNN().state_dict()
    for i in range(t):
        weights = [ClientUpdate(w0, E, k) for k in range(K)]
        w0 = average_model_weights(weights)
        avg_model = SimpleCNN()
        avg_model.load_state_dict(w0)
        # Put the aggregated model on the training device so evaluation runs there.
        # This also avoids a CPU-model / GPU-batch mismatch.
        avg_model.to(device)
        print(f"round {i} accuracy: {evaluate(avg_model,test_loader)}")
    
def ClientUpdate(w,E,k):
    """Run one client's local update and return its trained weights.

    A fresh SimpleCNN is initialised from the global weights ``w``. It is trained
    for ``E`` epochs on ``train_loaders[k]``, and its test accuracy is printed.

    Returns:
        The locally trained model's state dict.
    """
    local_model = SimpleCNN()
    local_model.load_state_dict(state_dict=w)
    trained_model = train(local_model, train_loaders[k], EPOCHS=E)
    local_accuracy = evaluate(trained_model, test_loader)
    print(f"k={k}: {local_accuracy}")
    return trained_model.state_dict()

# Federated run: 5 communication rounds, 10 local epochs per round, 2 clients.
NUM_ROUNDS, LOCAL_EPOCHS, NUM_CLIENTS = 5, 10, 2
server_execute(t=NUM_ROUNDS, E=LOCAL_EPOCHS, K=NUM_CLIENTS)

Epoch [1/10]: Loss: 2.3039036273956297, Train accuracy: 0.096875, Test accuracy: 0.1056
Epoch [2/10]: Loss: 2.2890339374542235, Train accuracy: 0.125, Test accuracy: 0.1002
Epoch [3/10]: Loss: 2.2713544845581053, Train accuracy: 0.1234375, Test accuracy: 0.1
Epoch [4/10]: Loss: 2.2482853174209594, Train accuracy: 0.1234375, Test accuracy: 0.103
Epoch [5/10]: Loss: 2.212786626815796, Train accuracy: 0.15, Test accuracy: 0.1593
Epoch [6/10]: Loss: 2.1436119556427, Train accuracy: 0.24375, Test accuracy: 0.2176
Epoch [7/10]: Loss: 2.0526759147644045, Train accuracy: 0.2890625, Test accuracy: 0.247
Epoch [8/10]: Loss: 2.0005127429962157, Train accuracy: 0.2828125, Test accuracy: 0.249
Epoch [9/10]: Loss: 1.9372020721435548, Train accuracy: 0.290625, Test accuracy: 0.2738
Epoch [10/10]: Loss: 1.8589028596878052, Train accuracy: 0.365625, Test accuracy: 0.2962
k=0: 0.2962
Epoch [1/10]: Loss: 2.3030240535736084, Train accuracy: 0.096875, Test accuracy: 0.1198
Epoch [2/10]: Loss: 2.29218645095