In [1]:
import logging

In [7]:
logging.basicConfig(level = logging.INFO)

In [3]:
log = logging.getLogger(__name__)

In [5]:
log.setLevel(logging.INFO)

In [8]:
log.info("Hi you just set your fleeb to level plumbus")


INFO:__main__:Hi you just set your fleeb to level plumbus


In [74]:
import copy
from collections import OrderedDict

import torch


class CenterServer:
    def __init__(self, model, dataloader, device="cpu"):
        self.model = model
        self.dataloader = dataloader
        self.device = device

    def aggregation(self):
        raise NotImplementedError

    def send_model(self):
        return copy.deepcopy(self.model)

    def validation(self):
        raise NotImplementedError


class FedAvgCenterServer(CenterServer):
    def __init__(self, model, dataloader, device="cpu"):
        super().__init__(model, dataloader, device)

    def aggregation(self, clients, aggregation_weights):
        update_state = OrderedDict()

        for k, client in enumerate(clients):
            local_state = client.model.state_dict()
            for key in self.model.state_dict().keys():
                if k == 0:
                    update_state[
                        key] = local_state[key] * aggregation_weights[k]
                else:
                    update_state[
                        key] += local_state[key] * aggregation_weights[k]

        self.model.load_state_dict(update_state)

    def validation(self, loss_fn):
        self.model.to(self.device)
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for img, target in self.dataloader:
                img = img.to(self.device)
                target = target.type(torch.LongTensor)
                target = target.to(self.device)
                logits = self.model(img)
                test_loss += loss_fn(logits, target).item()
                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        self.model.to("cpu")
        test_loss = test_loss / len(self.dataloader)
        accuracy = 100. * correct / len(self.dataloader.dataset)

        return test_loss, accuracy

In [83]:
class Client:
    def __init__(self, client_id, dataloader, device='cpu'):
        self.client_id = client_id
        self.dataloader = dataloader
        self.device = device
        self.__model = None

    @property
    def model(self):
        return self.__model

    @model.setter
    def model(self, model):
        self.__model = model

    def client_update(self,  local_epoch, loss_fn):
        raise NotImplementedError

    def __len__(self):
        return len(self.dataloader.dataset)


class FedAvgClient(Client):
    def client_update(self,  local_epoch, loss_fn):
        self.model.train()
        self.model.to(self.device)
        optimizer = optim.Adam(self.model.parameters(),lr = 1e-3)
        for i in range(local_epoch):
            for img, target in self.dataloader:
                img = img.to(self.device)
                target = target.type(torch.LongTensor)
                target = target.to(self.device)
                optimizer.zero_grad()
                logits = self.model(img)
                loss = loss_fn(logits, target)

                loss.backward()
                optimizer.step()
        self.model.to("cpu")

In [3]:
import PIL.Image as Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class MnistLocalDataset(Dataset):
    def __init__(self, images, labels, client_id):
        self.images = images
        self.labels = labels.astype(int)
        self.client_id = client_id
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

    def __getitem__(self, index):
        img = Image.fromarray(self.images[index].reshape(28, 28), mode='L')
        img = self.transform(img)
        target = self.labels[index]
        return img, target

    def __len__(self):
        return len(self.images)

In [11]:
import random

In [9]:
import os
import os.path as ops
import urllib.request
import gzip
import numpy as np


def get_mnist_data(datadir):
    dataroot = 'http://yann.lecun.com/exdb/mnist/'
    key_file = {
        'train_img': 'train-images-idx3-ubyte.gz',
        'train_label': 'train-labels-idx1-ubyte.gz',
        'test_img': 't10k-images-idx3-ubyte.gz',
        'test_label': 't10k-labels-idx1-ubyte.gz'
    }
    os.makedirs(datadir, exist_ok=True)

    for key, filename in key_file.items():
        if ops.exists(ops.join(datadir, filename)):
            print(f"already downloaded : {filename}")
        else:
            urllib.request.urlretrieve(ops.join(dataroot, filename),
                                       ops.join(datadir, filename))

    with gzip.open(ops.join(datadir, key_file["train_img"]), "rb") as f:
        train_img = np.frombuffer(f.read(), np.uint8, offset=16)
    train_img = train_img.reshape(-1, 784)

    with gzip.open(ops.join(datadir, key_file["train_label"]), "rb") as f:
        train_label = np.frombuffer(f.read(), np.uint8, offset=8)

    with gzip.open(ops.join(datadir, key_file["test_img"]), "rb") as f:
        test_img = np.frombuffer(f.read(), np.uint8, offset=16)
    test_img = test_img.reshape(-1, 784)

    with gzip.open(ops.join(datadir, key_file["test_label"]), "rb") as f:
        test_label = np.frombuffer(f.read(), np.uint8, offset=8)

    return train_img, train_label, test_img,  test_label

In [5]:
class FedBase:
    def create_mnist_datasets(self,
                              num_clients=100,
                              shard_size=300,
                              datadir="./data/mnist",
                              iid=False):
        train_img, train_label, test_img, test_label = get_mnist_data(datadir)

        train_sorted_index = np.argsort(train_label)
        train_img = train_img[train_sorted_index]
        train_label = train_label[train_sorted_index]

        if iid:
            random.shuffle(train_sorted_index)
            train_img = train_img[train_sorted_index]
            train_label = train_label[train_sorted_index]

        shard_start_index = [i for i in range(0, len(train_img), shard_size)]
        random.shuffle(shard_start_index)
        print(
            f"divide data into {len(shard_start_index)} shards of size {shard_size}"
        )

        num_shards = len(shard_start_index) // num_clients
        local_datasets = []
        for client_id in range(num_clients):
            _index = num_shards * client_id
            img = np.concatenate([
                train_img[shard_start_index[_index +
                                            i]:shard_start_index[_index + i] +
                          shard_size] for i in range(num_shards)
            ],
                                 axis=0)

            label = np.concatenate([
                train_label[shard_start_index[_index +
                                              i]:shard_start_index[_index +
                                                                   i] +
                            shard_size] for i in range(num_shards)
            ],
                                   axis=0)

            local_datasets.append(MnistLocalDataset(img, label, client_id))

        test_sorted_index = np.argsort(test_label)
        test_img = test_img[test_sorted_index]
        test_label = test_label[test_sorted_index]

        test_dataset = MnistLocalDataset(test_img, test_label, client_id=-1)

        return local_datasets, test_dataset

    def train_step(self):
        raise NotImplementedError

    def validation_step(self):
        raise NotImplementedError

    def fit(self, num_round):
        raise NotImplementedError

In [112]:
train_img, train_label, test_img, test_label = get_mnist_data("./data/mnist")

already downloaded : train-images-idx3-ubyte.gz
already downloaded : train-labels-idx1-ubyte.gz
already downloaded : t10k-images-idx3-ubyte.gz
already downloaded : t10k-labels-idx1-ubyte.gz


In [80]:
class FedAvg(FedBase):
    def __init__(self,
                 model,
                 
                 num_clients=200,
                 batchsize=50,
                 fraction=1,
                 local_epoch=1,
                 iid=False,
                 device="cpu",
                 writer=None):
        

        self.num_clients = num_clients  # K
        self.batchsize = batchsize  # B
        self.fraction = fraction  # C, 0 < C <= 1
        self.local_epoch = local_epoch  # E

        local_datasets, test_dataset = self.create_mnist_datasets(
            num_clients, shard_size=300, iid=iid)
        local_dataloaders = [
            DataLoader(dataset,
                       num_workers=0,
                       batch_size=batchsize,
                       shuffle=True) for dataset in local_datasets
        ]

        self.clients = [
            FedAvgClient(k, local_dataloaders[k], device) for k in range(num_clients)
        ]
        self.total_data_size = sum([len(client) for client in self.clients])
        self.aggregation_weights = [
            len(client) / self.total_data_size for client in self.clients
        ]

        test_dataloader = DataLoader(test_dataset,
                                     num_workers=0,
                                     batch_size=batchsize)
        self.center_server = FedAvgCenterServer(model, test_dataloader, device)

        self.loss_fn = CrossEntropyLoss()

        self.writer = writer

        self._round = 0
        self.result = None

    def fit(self, num_round):
        self._round = 0
        self.result = {'loss': [], 'accuracy': []}
        self.validation_step()
        for t in range(num_round):
            self._round = t + 1
            self.train_step()
            self.validation_step()

    def train_step(self):
        self.send_model()
        n_sample = max(int(self.fraction * self.num_clients), 1)
        sample_set = np.random.randint(0, self.num_clients, n_sample)
        for k in iter(sample_set):
            self.clients[k].client_update(
                                          self.local_epoch, self.loss_fn)
        self.center_server.aggregation(self.clients, self.aggregation_weights)

    def send_model(self):
        for client in self.clients:
            client.model = self.center_server.send_model()

    def validation_step(self):
        test_loss, accuracy = self.center_server.validation(self.loss_fn)
        log.info(
            f"[Round: {self._round: 04}] Test set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%"
        )
        if self.writer is not None:
            self.writer.add_scalar("val/loss", test_loss, self._round)
            self.writer.add_scalar("val/accuracy", accuracy, self._round)

        self.result['loss'].append(test_loss)
        self.result['accuracy'].append(accuracy)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data as data
import torchvision
from torchvision import datasets, models, transforms
from sklearn import decomposition
from sklearn import manifold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import copy
import random
import time
from PIL import Image

In [23]:
from torch.nn import CrossEntropyLoss

from torch.utils.data import DataLoader

In [109]:
A=FedAvg(model,num_clients=50,
                 batchsize=50,
                 fraction=0.2,
                 local_epoch=5,device=device)

already downloaded : train-images-idx3-ubyte.gz
already downloaded : train-labels-idx1-ubyte.gz
already downloaded : t10k-images-idx3-ubyte.gz
already downloaded : t10k-labels-idx1-ubyte.gz
divide data into 200 shards of size 300


In [67]:
A.device

AttributeError: 'FedAvg' object has no attribute 'device'

In [110]:
A.fit(50)

INFO:__main__:[Round:  000] Test set: Average loss: 2.2990, Accuracy: 16.12%
INFO:__main__:[Round:  001] Test set: Average loss: 2.1980, Accuracy: 29.20%
INFO:__main__:[Round:  002] Test set: Average loss: 2.0396, Accuracy: 48.08%
INFO:__main__:[Round:  003] Test set: Average loss: 1.8784, Accuracy: 50.80%
INFO:__main__:[Round:  004] Test set: Average loss: 1.6689, Accuracy: 53.40%
INFO:__main__:[Round:  005] Test set: Average loss: 1.4759, Accuracy: 61.97%
INFO:__main__:[Round:  006] Test set: Average loss: 1.2403, Accuracy: 69.47%
INFO:__main__:[Round:  007] Test set: Average loss: 0.9801, Accuracy: 78.32%
INFO:__main__:[Round:  008] Test set: Average loss: 0.7381, Accuracy: 91.71%
INFO:__main__:[Round:  009] Test set: Average loss: 0.6387, Accuracy: 90.10%
INFO:__main__:[Round:  010] Test set: Average loss: 0.5049, Accuracy: 92.55%
INFO:__main__:[Round:  011] Test set: Average loss: 0.4487, Accuracy: 90.81%
INFO:__main__:[Round:  012] Test set: Average loss: 0.3548, Accuracy: 93.36%

In [37]:
import torch
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self, in_features=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_features,
                               32,
                               kernel_size=5,
                               padding=0,
                               stride=1,
                               bias=True)
        self.conv2 = nn.Conv2d(32,
                               64,
                               kernel_size=5,
                               padding=0,
                               stride=1,
                               bias=True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.act = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x

In [108]:
model = CNN()

In [39]:
model

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
  (act): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)

In [65]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [63]:
loss_fn = CrossEntropyLoss()