In [1]:
from utils.dataset import FLMNIST, mnist_data_split
from nodes.client import Client
from nodes.server import Server
from models.LeNet import LeNet

from torchvision import transforms
from torch.utils.data import DataLoader
from typing import List

import torch

In [2]:
# Build the two client training subsets in one pass. Both use train=True and
# the same preprocessing pipeline; they differ only in the `data_ids` range
# handed to FLMNIST (10000 ids for dataset0, 300 ids for dataset1).
# NOTE(review): (0.1307,), (0.3081,) look like the usual MNIST mean/std
# normalization constants - confirm against the dataset actually loaded.
dataset0, dataset1 = (
    FLMNIST(
        root="./datasets/",
        train=True,
        # A fresh Compose is created per dataset, exactly as when the two
        # constructor calls were written out separately.
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
        target_transform=None,
        download=False,
        data_ids=subset_ids,
    )
    for subset_ids in (range(10000), range(20000, 20300))
)

In [3]:
# Evaluation loader: an FLMNIST dataset built with train=False, restricted to
# ids 0..9999, served in batches of 128 from the main process (num_workers=0).
# NOTE(review): shuffle=True reorders the evaluation batches on every pass -
# confirm this is intended for a test loader.
test_dataset = FLMNIST(
    root="./datasets/",
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
    target_transform=None,
    download=False,
    data_ids=range(10000),
)

test_iter = DataLoader(test_dataset, batch_size=128, shuffle=True, num_workers=0)


In [5]:

# Initialize the shared model, the two clients and the server, then run the
# federated rounds: every client trains locally, the server aggregates the
# returned weights, and the aggregate is pushed back to every client.

# Use the GPU when one is visible, otherwise fall back to the CPU so the cell
# still runs on CUDA-less machines instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = LeNet()

# model_structure holds cumulative element counts: entry i is the total number
# of scalars in the first i parameter tensors (starting at 0).
# init_param is the flat Python list of all initial parameter values, in
# model.parameters() order.
model_structure = [0]
init_param = []
for param in model.parameters():
    model_structure.append(param.data.numel() + model_structure[-1])
    init_param += param.data.cpu().view(-1).tolist()

# The optimizer class itself is passed around; it is not instantiated here.
optimizer = torch.optim.Adam
loss_func = torch.nn.CrossEntropyLoss()

client0 = Client(dataset=dataset0, model_structure=model_structure, device=device)
client1 = Client(dataset=dataset1, model_structure=model_structure, device=device)
# Single list reused by both the training and the broadcast loops.
clients = [client0, client1]

# Every client starts from the same initial weights.
for client in clients:
    client.set_weights(init_param)

server = Server(model_structure=model_structure)

# range(1, 10) yields 9 rounds (epoch = 1..9).
for epoch in range(1, 10):

    # Local training: collect each client's returned weights in client order.
    weight_list = []
    for cid, client in enumerate(clients):
        print("\ncid=%d:------"%cid)
        weight_list.append(
            client.local_train_step(model=model, loss_func=loss_func, optimizer=optimizer,
                            batch_size=128, num_workers=0, lr=0.01)
        )
        print("client accuracy=%.3f"%client.evaluate_accuracy(test_iter, model, device))

    # Server-side aggregation of the collected client weights.
    agg_weights = server.naive_aggregation(weight_list=weight_list)
    print("\nsever accuracy=%.3f"%server.evaluate_accuracy(test_iter, model, device))

    # Broadcast: each client receives its own copy of the aggregate so that no
    # two clients share the same container object.
    for client in clients:
        client.set_weights(agg_weights.copy())


cid=0:------
Parameter containing:
tensor([ 0.0969,  0.0674,  0.0212,  0.0572,  0.0750, -0.0012,  0.0515,  0.0088,
        -0.1032,  0.0494], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([ 0.1436, -0.1265,  0.0887,  0.0965,  0.1109, -0.0241,  0.0885,  0.0221,
        -0.0038,  0.0416], device='cuda:0', requires_grad=True)
client accuracy=0.240

cid=1:------
Parameter containing:
tensor([ 0.0969,  0.0674,  0.0212,  0.0572,  0.0750, -0.0012,  0.0515,  0.0088,
        -0.1032,  0.0494], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([ 0.1004,  0.0698,  0.0070,  0.0550,  0.0751, -0.0118,  0.0487,  0.0015,
        -0.1067,  0.0539], device='cuda:0', requires_grad=True)
client accuracy=0.098

sever accuracy=0.129

cid=0:------
Parameter containing:
tensor([ 0.1220, -0.0284,  0.0478,  0.0757,  0.0930, -0.0179,  0.0686,  0.0118,
        -0.0552,  0.0477], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([ 0.6727, -0.3977,  0.1392,  0.16