In [45]:
from datapreparation import *
from simsiam import *
from utils import *
from evaluation import *
import torch
from collections import OrderedDict
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, TensorDataset, ConcatDataset
from torchvision import datasets, transforms
import torch.optim as optim
import copy
from PIL import Image


https://github.com/vaseline555/Federated-Averaging-PyTorch/tree/1afb2be2c1972d8527efca357832f71c815b30b4/src

In [58]:
class MyDataset(Dataset):
    def __init__(self, x, y, is_train=True):#, transform_x=None):
        self.x = x
        self.y = y
      #  self.transform_x = transform_x
        self.is_train = is_train

    def __getitem__(self, idx):
        x = self.x[idx]
        x = Image.fromarray(x.astype(np.uint8))

        y = self.y[idx]

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        augmentation = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
            ]

        if self.is_train:
            transform = transforms.Compose(augmentation)

            x1 = transform(x)
            x2 = transform(x)
            return [x1, x2], y

        else:
            transform=transforms.Compose([transforms.ToTensor(), normalize])

            x = transform(x)
            return x, y
    
    def __len__(self):
        return len(self.x)
    

def create_datasets(num_clients, iid):
    """Split the whole dataset in IID or non-IID manner for distributing to clients."""

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, 
                                           transform=transforms.Compose([transforms.ToTensor(), normalize]))

    if iid:
        shuffled_indices = torch.randperm(len(trainset))
        training_x = trainset.data[shuffled_indices]
        training_y = torch.Tensor(trainset.targets)[shuffled_indices]

        split_size = len(trainset) // num_clients
        split_datasets = list(
                            zip(
                                torch.split(torch.Tensor(training_x), split_size),
                                torch.split(torch.Tensor(training_y), split_size)
                            )
                        )
        new_split_datasets = [(dataset[0].numpy(), dataset[1].tolist()) for dataset in split_datasets]
        new_split_datasets = [(dataset[0], list(map(int, dataset[1]))) for dataset in new_split_datasets]

        local_trainset = [MyDataset(local_dataset[0], local_dataset[1], is_train=True) for local_dataset in new_split_datasets]

        local_dataloaders = [DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True) for dataset in local_trainset]
    else: 
        pass

    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2, pin_memory=True)
    return local_dataloaders, testloader

In [60]:
class Client:
    def __init__(self, client_id, model, dataloader, local_epochs, device):
        self.client_id = client_id
        self.dataloader = dataloader
        self.model = model
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.03, momentum=0.9, weight_decay=0.0005)
        self.local_epochs = local_epochs
        self.device = device


    def client_update(self):
        self.model.train()
        self.model.to(self.device)
        optimizer = self.optimizer

        for epoch in range(self.local_epochs):  # loop over the dataset multiple times
            epoch_loss = 0.0
            running_loss = 0.0
            for i, data in enumerate(self.dataloader):            
                # get the inputs; data is a list of [inputs, labels]
                # inputs, labels = data
                images, _ = data[0], data[1].to(self.device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # get the two views (with random augmentations):
                x1 = images[0].to(self.device)
                x2 = images[1].to(self.device)
                
                # forward + backward + optimize
                z1, p1 = self.model(x1)
                z2, p2 = self.model(x2)
                #loss = criterion(outputs, labels)
                loss = D(p1, z2)/2 + D(p2, z1)/2
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                epoch_loss += loss.item()
                if i % 100 == 99:    # print every 2000 mini-batches
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            print("epoch loss = ", epoch_loss/len(self.dataloader))
        print('Finished Training')

    def client_evaluate(self):
        """evaluates model on local dataset TODO: Should this be done in self-supervised learning and if so, how?"""
        # insert evaluate() method of SimSiam
        pass

In [61]:
class Server:
    def __init__(self, num_clients, iid, num_rounds):
        self.num_clients = num_clients
        self.iid = iid
        self.num_rounds = num_rounds # number of rounds that models should be trained on clients

    def setup(self):
        self.model = SimSiam()
        local_trainloaders, test_loader = create_datasets(self.num_clients, self.iid)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.clients = self.create_clients(local_trainloaders)
        self.testloader = test_loader
        self.send_model()
        
    def create_clients(self, local_trainloaders):
        clients = []
        for i, dataloader in enumerate(local_trainloaders):
            client = Client(client_id=i, model=SimSiam().to(self.device), dataloader=dataloader, batchsize=4, local_epochs=5, device=self.device)
            clients.append(client)
        return clients

    def send_model(self):
        """Send the updated global model to selected/all clients."""
        for client in self.clients:
            client.model = copy.deepcopy(self.model)

    def average_model(self, coefficients):
        """Average the updated and transmitted parameters from each selected client."""
        averaged_weights = OrderedDict()

        for i, client in enumerate(self.clients):
            local_weights = client.model.state_dict()

            for key in self.model.state_dict().keys():
                if i == 0:
                    averaged_weights[key] = coefficients[i] * local_weights[key]
                else:
                    averaged_weights[key] += coefficients[i] * local_weights[key]
        self.model.load_state_dict(averaged_weights)


    def train_federated_model(self):
        # send current model
        self.send_model()
        
        # TODO: Sample only subset of clients

        # update clients (train client models)
        for client in self.clients:
            client.client_update()
        
        # average models
        total_size = sum([len(client.dataloader.dataset[1]) for client in self.clients])
        mixing_coefficients = [len(client.dataloader.dataset[1]) / total_size for client in self.clients]
        self.average_model(mixing_coefficients)
    
    def evaluate_global_model(self):
        # insert evaluation function here
        pass

    def main(self):
        for i in range(self.num_rounds):
            self.train_federated_model()
            # test_loss, test_accuracy = self.evaluate_global_model() # TODO
        self.send_model()

In [39]:
server = Server(2, True, 2)
server.setup()
server.send_model()

for client in server.clients:
    client.client_update()



Files already downloaded and verified
Files already downloaded and verified


Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 250, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurre

TypeError: Client.__init__() got an unexpected keyword argument 'batchsize'