In [1]:
from datapreparation import *
from simsiam import *
from utils import *
from evaluation import *
import torch
from collections import OrderedDict
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, TensorDataset, ConcatDataset
from torchvision import datasets, transforms
import torch.optim as optim
import copy
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn


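### TODO
- Add non-IID dataset creation (`create_datasets` currently only offers a sort-by-label split).

Reference FedAvg implementation in PyTorch: https://github.com/vaseline555/Federated-Averaging-PyTorch/tree/1afb2be2c1972d8527efca357832f71c815b30b4/src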

In [97]:
class MyDataset(Dataset):
    """In-memory image dataset that yields SimSiam-style samples.

    In training mode every item is ``([view_1, view_2], label)``, where the two
    views are independently augmented versions of the same image. In evaluation
    mode every item is ``(tensor, label)`` with only tensor conversion and
    normalisation applied.

    Args:
        x: Indexable collection of image arrays; each entry must support
            ``.astype`` and be convertible with ``PIL.Image.fromarray``.
        y: Indexable collection of labels aligned with ``x``.
        is_train: Selects the two-view augmented output when true.
    """

    def __init__(self, x, y, is_train=True):
        self.x = x
        self.y = y
        self.is_train = is_train

    def __len__(self):
        """Return the number of stored images."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` in the format selected by ``is_train``."""
        image = Image.fromarray(self.x[idx].astype(np.uint8))
        label = self.y[idx]

        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        if not self.is_train:
            # Evaluation path: deterministic preprocessing only.
            plain = transforms.Compose([transforms.ToTensor(), normalize])
            return plain(image), label

        color_jitter = transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],  # not strengthened
            p=0.8,
        )
        augment = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            color_jitter,
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        # Two independent passes through the same random pipeline give the
        # pair of views used by the SimSiam loss.
        first_view = augment(image)
        second_view = augment(image)
        return [first_view, second_view], label
    

def create_datasets(num_clients, iid):
    """Split the CIFAR-10 training set across clients and build the data loaders.

    Args:
        num_clients: Number of client partitions to create. Must be between 1
            and the number of training images.
        iid: If truthy, the images are randomly permuted before splitting so
            every client receives a random sample. Otherwise the images are
            sorted by label before splitting, so each client only sees a
            contiguous slice of the classes (non-IID).

    Returns:
        A tuple ``(local_dataloaders, testloader)``. ``local_dataloaders`` holds
        exactly ``num_clients`` shuffled training ``DataLoader`` objects (batch
        size 64), each wrapping a ``MyDataset`` in training mode.
        ``testloader`` is an unshuffled ``DataLoader`` (batch size 4) over the
        normalised CIFAR-10 test set.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1 or larger than the
            number of training images.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                           transform=transforms.Compose([transforms.ToTensor(), normalize]))

    num_samples = len(trainset)
    if not 1 <= num_clients <= num_samples:
        raise ValueError(
            f"num_clients must be between 1 and {num_samples}, got {num_clients}"
        )

    # Keep the images in their original array dtype instead of round-tripping
    # them through a float tensor; MyDataset casts to uint8 per item anyway.
    images = np.asarray(trainset.data)
    labels = np.asarray(trainset.targets, dtype=np.int64)

    if iid:
        # torch's RNG is used so torch.manual_seed controls the partition.
        order = torch.randperm(num_samples).numpy()
    else:
        # Stable sort by label: same ordering as sorting (index, label) pairs
        # by label, i.e. ties keep their original relative order.
        order = np.argsort(labels, kind="stable")

    # array_split always produces exactly num_clients parts; when the dataset
    # size is not divisible, the first parts get one extra sample instead of
    # spilling a remainder into an additional, unintended partition.
    local_dataloaders = []
    for client_indices in np.array_split(order, num_clients):
        local_dataset = MyDataset(
            images[client_indices],
            labels[client_indices].tolist(),  # plain Python ints, as before
            is_train=True,
        )
        local_dataloaders.append(
            DataLoader(dataset=local_dataset, batch_size=64, shuffle=True,
                       num_workers=2, pin_memory=True)
        )

    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2, pin_memory=True)
    return local_dataloaders, testloader

In [62]:
class DownstreamEvaluation(nn.Module):
    """Linear classification head on top of a frozen SimSiam model.

    Args:
        simsiam_model: Module whose forward pass returns a pair; the first
            element is used as the feature vector fed to the classifier.
            NOTE: its parameters are frozen *in place* (``requires_grad`` is
            set to ``False`` on the object passed in), so pass a copy if the
            original model must remain trainable.
        feature_dim: Size of the feature vector produced by ``simsiam_model``.
            Defaults to 2048, the value previously hard-coded here.
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, simsiam_model, feature_dim=2048, num_classes=10):
        super().__init__()
        self.simsiam = simsiam_model

        # Freeze the backbone so only the linear classifier receives gradients.
        for param in self.simsiam.parameters():
            param.requires_grad = False

        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        # The second output of the SimSiam forward pass is not needed here.
        z, _ = self.simsiam(x)
        return self.classifier(z)

In [60]:
class Client:
    """A federated-learning participant that trains a SimSiam model locally.

    Args:
        client_id: Identifier of this client.
        model: The model to train; its forward pass must return a ``(z, p)``
            pair for a batch of images.
        dataloader: Iterable of batches whose first element holds the two
            augmented views ``[x1, x2]`` (as produced by ``MyDataset`` in
            training mode). The labels in the second element are ignored
            because training is self-supervised.
        local_epochs: Number of passes over ``dataloader`` per update.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
    """

    def __init__(self, client_id, model, dataloader, local_epochs,
                 lr=0.03, momentum=0.9, weight_decay=0.0005):
        self.client_id = client_id
        self.dataloader = dataloader
        self.model = model
        self.local_epochs = local_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def client_update(self):
        """Train ``self.model`` in place for ``local_epochs`` epochs on the local data."""
        self.model.train()
        self.model.to(self.device)
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                              momentum=self.momentum, weight_decay=self.weight_decay)

        # Guard the per-epoch average against an empty dataloader.
        num_batches = max(len(self.dataloader), 1)

        for epoch in range(self.local_epochs):  # loop over the dataset multiple times
            epoch_loss = 0.0
            running_loss = 0.0
            for i, data in enumerate(self.dataloader):
                # data is [views, labels]; the labels are not used by the
                # self-supervised objective, so they are never moved to the device.
                images = data[0]

                # zero the parameter gradients
                optimizer.zero_grad()

                # get the two views (with random augmentations):
                x1 = images[0].to(self.device)
                x2 = images[1].to(self.device)

                # forward + backward + optimize
                z1, p1 = self.model(x1)
                z2, p2 = self.model(x2)
                # Symmetrised loss: each view's p output is compared with the
                # other view's z output via D (defined outside this cell).
                loss = D(p1, z2)/2 + D(p2, z1)/2
                loss.backward()
                optimizer.step()

                # print statistics
                batch_loss = loss.item()
                running_loss += batch_loss
                epoch_loss += batch_loss
                if i % 100 == 99:    # print every 100 mini-batches
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            print("epoch loss = ", epoch_loss/num_batches)
        print('Finished Training')

    def client_evaluate(self):
        """evaluates model on local dataset TODO: Should this be done in self-supervised learning and if so, how?"""
        # insert evaluate() method of SimSiam
        pass

In [63]:
class Server:
    """Coordinates federated averaging (FedAvg) of SimSiam models across clients.

    Args:
        num_clients: Number of clients the training data is split across.
        iid: Whether the data is split in an IID manner (see ``create_datasets``).
        num_rounds: Number of communication rounds (train on clients, then average).
        local_epochs: Number of epochs each client trains per round.
    """

    def __init__(self, num_clients, iid, num_rounds, local_epochs):
        self.num_clients = num_clients
        self.iid = iid
        self.num_rounds = num_rounds # number of rounds that models should be trained on clients
        self.local_epochs = local_epochs # number of epochs each client is trained per round

    def setup(self):
        """Create the global model, the client datasets and the clients."""
        self.model = SimSiam()
        local_trainloaders, test_loader = create_datasets(self.num_clients, self.iid)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.clients = self.create_clients(local_trainloaders)
        self.testloader = test_loader
        self.send_model()

    def create_clients(self, local_trainloaders):
        """Return one ``Client`` per dataloader, numbered in iteration order."""
        clients = []
        for i, dataloader in enumerate(local_trainloaders):
            client = Client(client_id=i, model=SimSiam().to(self.device), dataloader=dataloader, local_epochs=self.local_epochs)
            clients.append(client)
        return clients

    def send_model(self):
        """Send the updated global model to selected/all clients."""
        # Each client gets its own deep copy so local training never touches
        # the global model.
        for client in self.clients:
            client.model = copy.deepcopy(self.model)

    def average_model(self, coefficients):
        """Load the coefficient-weighted sum of all client models into the global model.

        Args:
            coefficients: One weight per client, in the order of ``self.clients``.

        Raises:
            ValueError: If the number of coefficients differs from the number
                of clients.
        """
        if len(coefficients) != len(self.clients):
            raise ValueError(
                f"expected {len(self.clients)} coefficients, got {len(coefficients)}"
            )

        # Fetch every client's state dict once instead of once per key.
        client_states = [client.model.state_dict() for client in self.clients]

        averaged_weights = OrderedDict()
        for key in self.model.state_dict().keys():
            weighted_sum = None
            for coefficient, state in zip(coefficients, client_states):
                term = coefficient * state[key]
                weighted_sum = term if weighted_sum is None else weighted_sum + term
            averaged_weights[key] = weighted_sum
        self.model.load_state_dict(averaged_weights)

    def train_federated_model(self):
        """Run one round: broadcast the model, train every client, average the results.

        Raises:
            ValueError: If the clients hold no training samples in total.
        """
        # send current model
        self.send_model()

        # TODO: Sample only subset of clients

        # update clients (train client models)
        for client in self.clients:
            client.client_update()

        # average models, weighting each client by the size of its local dataset.
        # len() is taken of the dataset itself; indexing it first would return
        # the length of a single sample tuple instead.
        client_sizes = [len(client.dataloader.dataset) for client in self.clients]
        total_size = sum(client_sizes)
        if total_size == 0:
            raise ValueError("cannot average models: the clients hold no training samples")
        mixing_coefficients = [size / total_size for size in client_sizes]
        self.average_model(mixing_coefficients)

    def evaluate_global_model(self, num_epochs=5):
        """Linear evaluation on 1% of CIFAR-10 Training data

        Args:
            num_epochs: Number of epochs the linear classifier is trained for.
                Defaults to 5 so the method can be called without arguments.

        Returns:
            Whatever ``evaluate_simsiam_downstream`` returns for the trained
            classifier (the notebook output indicates this is the accuracy).
        """
        trainloader, testloader = get_downstream_data(percentage_of_data=0.01, batch_size=64)

        # DownstreamEvaluation freezes the parameters of the model it is given.
        # Evaluate a copy so the global model (and the client copies made from
        # it in later rounds) stays trainable.
        model = DownstreamEvaluation(copy.deepcopy(self.model))
        model = model.to(self.device)

        # NOTE(review): if train_downstream calls model.train(), the frozen
        # backbone is switched back to training mode — confirm in that helper.
        model.simsiam.eval()
        model.classifier.train()

        # Train SimSiam on downstream task; only the classifier has trainable
        # parameters, so only those are handed to the optimizer.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.classifier.parameters(), lr=0.001, momentum=0.9)

        train_downstream(num_epochs, model, trainloader, criterion, optimizer, device=self.device)

        # Evaluate SimSiam on downstream task
        return evaluate_simsiam_downstream(model, testloader, self.device)

    def learn_federated_simsiam(self):
        """Run the full procedure: setup, ``num_rounds`` rounds with evaluation, then save."""
        self.setup()
        for i in range(self.num_rounds):
            self.train_federated_model()
            downstream_accuracy = self.evaluate_global_model()
            print(downstream_accuracy)
        # save final averaged model
        PATH = "simsiam_fedavg.pth"
        torch.save(self.model.state_dict(), PATH)
        self.send_model()

In [3]:
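# Example usage (Server requires num_clients, iid, num_rounds and local_epochs):
# server = Server(num_clients=2, iid=True, num_rounds=1, local_epochs=1)
# server.learn_federated_simsiam()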

In [8]:
# Evaluation of FEDAVG Model

In [25]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
# Load the federated SimSiam checkpoint into the downstream wrapper and move it
# to the selected device.
# NOTE(review): Server.learn_federated_simsiam saves to "simsiam_fedavg.pth" in
# the working directory, while this cell reads from "models/" — confirm the
# checkpoint was moved there.
PATH = "models/simsiam_fedavg.pth"
model = SimSiamDownstream(
    trained_model_path=PATH,
    device=device,
    linearevaluation=True,
).to(device)

In [27]:
# Build the downstream train/test loaders. The keyword names suggest 10% of the
# data and batches of 32 — see get_downstream_data for the exact semantics.
trainloader, testloader = get_downstream_data(percentage_of_data=0.1, batch_size=32)

Files already downloaded and verified
Files already downloaded and verified


In [28]:
# Loss and optimiser for training the downstream model built on the federated
# SimSiam checkpoint.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.03,
    momentum=0.9,
    weight_decay=0.0005,
)

In [29]:
# Train the downstream model on the training loader. The leading 5 is presumably
# the number of epochs — confirm against train_simsiam_downstream.
train_simsiam_downstream(5, model, trainloader, criterion, optimizer, device)

Finished Training


In [30]:
# Score the downstream model on the test loader. As the recorded output shows,
# the helper prints the accuracy and also returns it as the cell value.
evaluate_simsiam_downstream(model, testloader, device)

Accuracy of the network on the 10000 test images: 46 %


46

In [31]:
# Supervised model

In [34]:
# Supervised baseline for comparison. This rebinds `model`, replacing the
# federated SimSiam downstream model created earlier in the notebook.
model = SupervisedModel(pretrained=True, linearevaluation=True).to(device)

In [35]:
# Optimiser and loss for the supervised baseline. Note the optimiser settings
# differ from the ones used for the federated SimSiam model above
# (lr 0.001 and no weight decay here).
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)
criterion = nn.CrossEntropyLoss()

In [36]:
# Train the supervised baseline on the same downstream training loader. The
# first positional argument is the epoch count, as in the call inside
# Server.evaluate_global_model.
train_downstream(5, model, trainloader, criterion, optimizer, device)

Finished Training


In [37]:
# Score the supervised baseline on the same test loader that was used for the
# federated SimSiam model, so the two printed accuracies are comparable.
evaluate_downstream(model, testloader, device)

Accuracy of the network on the 10000 test images: 40 %
