In [15]:
import torch
import tqdm
import torchvision
import re
import numpy
import copy

from torch.utils.data import Dataset, DataLoader

In [16]:
class LeNet5(torch.nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(
            in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = torch.nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = torch.nn.Linear(in_features=120, out_features=84)
        self.fc3 = torch.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Loss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        return torch.nn.functional.cross_entropy(input, target)


class Client:
    def __init__(self, model, dataloader, optimizer='adam', device='cpu', epochs=1, loss=Loss()):
        self.model = copy.deepcopy(model)
        self.dataloader = dataloader
        self.device = device
        self.epochs = epochs
        # self.loss = loss
        self.loss = torch.nn.CrossEntropyLoss()
        if optimizer == 'adam':
            self.optimizer = torch.optim.Adam(self.model.parameters())
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters())

    def train(self):
        self.model.train()
        self.model.to(self.device)
        for epoch in range(self.epochs):
            for i, (data, label) in enumerate(self.dataloader):
                self.optimizer.zero_grad()
                output = self.model(data.to(self.device))
                loss = self.loss(output, label.to(self.device))
                loss.backward()
                self.optimizer.step()

                if (i+1) % 100 == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, self.epochs, i+1, len(self.dataloader), loss.item()))
        return self.model.state_dict()


class Server:
    def __init__(self, model, client_params):
        self.model = copy.deepcopy(model)
        self.client_params = client_params
        self.n_client = len(self.client_params)

        self.clients = list(client_params.keys())
        self.server_params = self.client_params[self.clients[0]]
        for key in self.server_params:
            self.server_params[key] = self.server_params[key].div(self.n_client)

    def fed_avg(self):
        for client in self.clients:
            for key in self.server_params:
                self.server_params[key] = self.server_params[key].add(self.client_params[client][key].div(self.n_client))
        return self.server_params


class DealDataset(Dataset):

    def __init__(self, dataset, idx):
        self.dataset = dataset
        self.idx = idx
        self.len = len(idx)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        img, target = self.dataset[self.idx[index]]
        return img, target

In [17]:
def DateSplit(dataset, mode='iid', n_dataset=1, n_data_each_set=1):
    if mode == 'iid':
        labels_list = dataset.targets.tolist()
        all_labels = set(labels_list)
        idx_label = dict()
        for label in all_labels:
            idx_label[label] = list([
                idx for idx, _ in enumerate(labels_list) if labels_list[idx] == label])
        dataset_splited = dict()
        for i in range(n_dataset):
            dataset_splited[i] = list()
            for label in all_labels:
                choiced_idx = numpy.random.choice(idx_label[label], n_data_each_set, replace=False)
                dataset_splited[i] += list(choiced_idx)
        return dataset_splited
    elif mode == 'non-iid':
        print('TO DO.')

In [18]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model = LeNet5()
criterion = Loss()
optimizer = torch.optim.Adam(model.parameters())

n_client = 10
n_data = 3200
idx_splited = DateSplit(dataset=train_dataset,
                        n_dataset=n_client,
                        n_data_each_set=n_data)

choice_client = 10
conmunication_rounds = 2

server_params = copy.deepcopy(model).state_dict()
for i in range(conmunication_rounds):
    client_params = dict()
    for client in list(numpy.random.choice(range(n_client), choice_client, replace=False)):
        client_model = copy.deepcopy(model)
        client_model.load_state_dict(server_params)
        client_params[client] = Client(model=client_model,
                            dataloader = DataLoader(
                                DealDataset(train_dataset,
                                            idx_splited[client]),
                                            batch_size=32,
                                            shuffle=False
                                            ),
                                optimizer='adam',
                                device=device).train()
    server_params = Server(model=model, client_params=client_params).fed_avg()


Epoch [1/1], Step [100/1000], Loss: 0.0000
Epoch [1/1], Step [200/1000], Loss: 0.0000
Epoch [1/1], Step [300/1000], Loss: 0.0000
Epoch [1/1], Step [400/1000], Loss: 0.0000
Epoch [1/1], Step [500/1000], Loss: 0.0353
Epoch [1/1], Step [600/1000], Loss: 0.2582
Epoch [1/1], Step [700/1000], Loss: 0.4983
Epoch [1/1], Step [800/1000], Loss: 0.6016
Epoch [1/1], Step [900/1000], Loss: 1.3556
Epoch [1/1], Step [1000/1000], Loss: 2.1414
Epoch [1/1], Step [100/1000], Loss: 0.0000
Epoch [1/1], Step [200/1000], Loss: 0.0000
Epoch [1/1], Step [300/1000], Loss: 0.0000
Epoch [1/1], Step [400/1000], Loss: 0.0000
Epoch [1/1], Step [500/1000], Loss: 1.0144
Epoch [1/1], Step [600/1000], Loss: 0.4892
Epoch [1/1], Step [700/1000], Loss: 1.3386
Epoch [1/1], Step [800/1000], Loss: 1.1008
Epoch [1/1], Step [900/1000], Loss: 1.1848
Epoch [1/1], Step [1000/1000], Loss: 1.8044
Epoch [1/1], Step [100/1000], Loss: 0.0000
Epoch [1/1], Step [200/1000], Loss: 0.0000
Epoch [1/1], Step [300/1000], Loss: 0.0000
Epoch [1/

In [19]:
server_model = copy.deepcopy(model)
server_model.load_state_dict(server_params)
server_model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = server_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy: {:.2f}%'.format(100 * correct / total))

Test Accuracy: 10.09%
