In [104]:
from datapreparation import *
from simsiam import *
from utils import *
from evaluation import *
import torch
from collections import OrderedDict
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, TensorDataset, ConcatDataset
from torchvision import datasets, transforms
import torch.optim as optim
import copy

https://github.com/vaseline555/Federated-Averaging-PyTorch/tree/1afb2be2c1972d8527efca357832f71c815b30b4/src

In [126]:
class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]
    

def create_datasets(num_clients, iid):
    """Split the whole dataset in IID or non-IID manner for distributing to clients."""
    # get train and test dataset from cifar-10
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=TwoCropsTransform(transforms.Compose(augmentation)))
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))

    if iid:
        shuffled_indices = torch.randperm(len(trainset))

        training_inputs = trainset.data[shuffled_indices]
        training_labels = torch.Tensor(trainset.targets)[shuffled_indices]
        split_size = len(trainset) // num_clients
        split_datasets = list(
                    zip(
                        torch.split(torch.Tensor(training_inputs), split_size),
                        torch.split(torch.Tensor(training_labels), split_size)
                    )
                )
        local_trainloaders = [
                    torch.utils.data.DataLoader(local_dataset, batch_size=4, 
                                                    shuffle=True, num_workers=2, pin_memory=True)
                    for local_dataset in split_datasets
    ]
    else: 
        pass

    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2, pin_memory=True)
    return local_trainloaders, testloader

In [128]:
trainloaders, testloader = create_datasets(3, True)
trainloaders

Files already downloaded and verified
Files already downloaded and verified


[<torch.utils.data.dataloader.DataLoader at 0x7fa329d0c040>,
 <torch.utils.data.dataloader.DataLoader at 0x7fa3280c8b20>,
 <torch.utils.data.dataloader.DataLoader at 0x7fa32807c580>,
 <torch.utils.data.dataloader.DataLoader at 0x7fa32807c1c0>]

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/jonas/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/jonas/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/home/jonas/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 725, in start
    self.io_loop.start()
  File "/home/jonas/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 215, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 600, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1860, in _run_once
    event_list = self._selector.select(timeout)


In [106]:
class Client:
    def __init__(self, client_id, model, dataloader, local_epochs, device):
        self.client_id = client_id
        self.dataloader = dataloader
        self.model = model
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.03, momentum=0.9, weight_decay=0.0005)
        self.local_epochs = local_epochs
        self.device = device


    def client_update(self):
        self.model.train()
        self.model.to(self.device)
        optimizer = self.optimizer

        for epoch in range(self.local_epochs):  # loop over the dataset multiple times
            epoch_loss = 0.0
            running_loss = 0.0
            for i, data in enumerate(self.dataloader):            
                # get the inputs; data is a list of [inputs, labels]
                # inputs, labels = data
                images, _ = data[0], data[1].to(self.device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # get the two views (with random augmentations):
                x1 = images[0].to(self.device)
                x2 = images[1].to(self.device)
                
                # forward + backward + optimize
                z1, p1 = self.model(x1)
                z2, p2 = self.model(x2)
                #loss = criterion(outputs, labels)
                loss = D(p1, z2)/2 + D(p2, z1)/2
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                epoch_loss += loss.item()
                if i % 100 == 99:    # print every 2000 mini-batches
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            print("epoch loss = ", epoch_loss/len(self.dataloader))
        print('Finished Training')

    def client_evaluate(self):
        """evaluates model on local dataset TODO: Should this be done in self-supervised learning and if so, how?"""
        # insert evaluate() method of SimSiam
        pass

In [114]:
class Server:
    def __init__(self, num_clients, iid, num_rounds):
        self.num_clients = num_clients
        self.iid = iid
        self.num_rounds = num_rounds # number of rounds that models should be trained on clients

    def setup(self):
        self.model = SimSiam()
        local_trainloaders, test_loader = create_datasets(self.num_clients, self.iid)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.clients = self.create_clients(local_trainloaders)
        self.testloader = test_loader
        self.send_model()
        
    def create_clients(self, local_trainloaders):
        clients = []
        for i, dataloader in enumerate(local_trainloaders):
            client = Client(client_id=i, model=SimSiam().to(self.device), dataloader=dataloader, batchsize=4, local_epochs=5, device=self.device)
            clients.append(client)
        return clients

    def send_model(self):
        """Send the updated global model to selected/all clients."""
        for client in self.clients:
            client.model = copy.deepcopy(self.model)

    def average_model(self, coefficients):
        """Average the updated and transmitted parameters from each selected client."""
        averaged_weights = OrderedDict()

        for i, client in enumerate(self.clients):
            local_weights = client.model.state_dict()

            for key in self.model.state_dict().keys():
                if i == 0:
                    averaged_weights[key] = coefficients[it] * local_weights[key]
                else:
                    averaged_weights[key] += coefficients[it] * local_weights[key]
        self.model.load_state_dict(averaged_weights)


    def train_federated_model(self):
        # send current model
        self.send_model()
        
        # TODO: Sample only subset of clients

        # update clients (train client models)
        for client in self.clients:
            client.client_update()
        
        # average models
        total_size = sum([len(client.dataloader.dataset[1]) for client in self.clients])
        mixing_coefficients = [len(client.dataloader.dataset[1]) / total_size for client in self.clients]
        self.average_model(mixing_coefficients)
    
    def evaluate_global_model(self):
        # insert evaluation function here
        pass

    def main(self):
        for i in range(self.num_rounds):
            self.train_federated_model()
            # test_loss, test_accuracy = self.evaluate_global_model() # TODO
        self.send_model()

In [125]:
server = Server(2, True, 2)
server.setup()
server.send_model()

for client in server.clients:
    client.client_update()



Files already downloaded and verified
Files already downloaded and verified


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jonas/.local/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jonas/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
    return self.collate_fn(data)
  File "/home/jonas/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/jonas/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/home/jonas/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 162, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable


In [110]:
averaged_weights = OrderedDict()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
local_trainloaders, testloaders = create_datasets(2, True)
clients = [Client(client_id=1, model=SimSiam().to(device), dataloader=local_trainloaders[0], batchsize=4, local_epochs=5, device=device), 
           Client(client_id=2, model=SimSiam().to(device), dataloader=local_trainloaders[1], batchsize=4, local_epochs=5, device=device)]
model = SimSiam()

total_size = sum([len(client.dataloader.dataset[1]) for client in clients])
mixing_coefficients = [len(client.dataloader.dataset[1]) / total_size for client in clients]

for i, client in enumerate(clients):
    local_weights = client.model.state_dict()
    print(local_weights)
    for key in model.state_dict().keys():
        print(key)

    for key in model.state_dict().keys():
        if i == 0:
            averaged_weights[key] = mixing_coefficients[i] * local_weights[key]
        else:
            averaged_weights[key] += mixing_coefficients[i] * local_weights[key]
averaged_weights
# self.model.load_state_dict(averaged_weights)

Files already downloaded and verified
Files already downloaded and verified




[0.5, 0.5]


In [85]:
len(local_trainloaders[0].dataset[1])

25000

In [72]:
model = SimSiam()
model.state_dict().keys()



odict_keys(['model.0.weight', 'model.1.weight', 'model.1.bias', 'model.1.running_mean', 'model.1.running_var', 'model.1.num_batches_tracked', 'model.4.0.conv1.weight', 'model.4.0.bn1.weight', 'model.4.0.bn1.bias', 'model.4.0.bn1.running_mean', 'model.4.0.bn1.running_var', 'model.4.0.bn1.num_batches_tracked', 'model.4.0.conv2.weight', 'model.4.0.bn2.weight', 'model.4.0.bn2.bias', 'model.4.0.bn2.running_mean', 'model.4.0.bn2.running_var', 'model.4.0.bn2.num_batches_tracked', 'model.4.1.conv1.weight', 'model.4.1.bn1.weight', 'model.4.1.bn1.bias', 'model.4.1.bn1.running_mean', 'model.4.1.bn1.running_var', 'model.4.1.bn1.num_batches_tracked', 'model.4.1.conv2.weight', 'model.4.1.bn2.weight', 'model.4.1.bn2.bias', 'model.4.1.bn2.running_mean', 'model.4.1.bn2.running_var', 'model.4.1.bn2.num_batches_tracked', 'model.5.0.conv1.weight', 'model.5.0.bn1.weight', 'model.5.0.bn1.bias', 'model.5.0.bn1.running_mean', 'model.5.0.bn1.running_var', 'model.5.0.bn1.num_batches_tracked', 'model.5.0.conv2.w

In [63]:
server = Server(num_clients=5, iid=True, num_rounds=5)
server.setup()

Files already downloaded and verified
Files already downloaded and verified




AttributeError: 'Server' object has no attribute 'model'

In [8]:
# PATH = "models/simsiam.pth"

# # load trained model
# model = SimSiam()
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# model.load_state_dict(torch.load(PATH, map_location=device))