In [249]:
import torch
import tqdm
import torchvision
import re
import numpy
import copy

from torch.utils.data import Dataset, DataLoader

In [250]:
class LeNet5(torch.nn.Module):
    def __init__(self, input_channels):
        super(LeNet5, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=input_channels, out_channels=6, kernel_size=5, stride=1)
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(
            in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = torch.nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = torch.nn.Linear(in_features=120, out_features=84)
        self.fc3 = torch.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Loss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        return torch.nn.functional.cross_entropy(input, target)


class Server:
    def __init__(self, model, client_params):
        self.model = copy.deepcopy(model)
        self.client_params = client_params
        self.n_client = len(self.client_params)

        self.server_params = self.client_params[0]
        for key in self.server_params:
            self.server_params[key] = self.server_params[key].div(
                self.n_client)

    def fed_avg(self):
        for client in range(self.n_client):
            for key in self.server_params:
                deal_param = self.client_params[client][key].div(self.n_client)
                self.server_params[key] = self.server_params[key].add(deal_param)
        return self.server_params


class DealDataset(Dataset):
    def __init__(self, dataset, idx):
        self.dataset = dataset
        self.idx = idx
        self.len = len(idx)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        img, target = self.dataset[self.idx[index]]
        return img, target

In [251]:
def idx_split(dataset, mode='iid', n_dataset=1, n_data_each_set=1):
    labels_list = dataset.targets.tolist()
    all_labels = set(labels_list)
    idx_label = dict()
    for label in all_labels:
        idx_label[label] = list()
        for idx, label_in_list in enumerate(labels_list):
            if label_in_list == label:
                idx_label[label] += [idx]

    if mode == 'iid':
        n_each_set = dict()
        for label in all_labels:
            n_each_set[label] = int(
                len(idx_label[label]) / len(labels_list) * n_data_each_set / n_dataset)
            print(label, n_each_set[label], end='|')
        print('\n')
        dataset_splited = dict()
        left_idx_label = idx_label
        for i in range(n_dataset):
            dataset_splited[i] = list()
            for label in all_labels:
                choiced_idx = numpy.random.choice(
                    left_idx_label[label],
                    n_each_set[label],
                    replace=False)
                dataset_splited[i] += list(choiced_idx)
                left_idx_label[i] = list(
                    set(left_idx_label[label]) - set(dataset_splited[i]))
                print(i, label, len(dataset_splited[i]), n_each_set[label], len(left_idx_label[i]))
        return dataset_splited
    elif mode == 'non-iid':
        print('TO DO.')


def train_model(model, dataset, device='cpu', epochs=1):
    trained_model = copy.deepcopy(model).to(device)
    trained_model.train()
    train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(trained_model.parameters())
    for epoch in range(epochs):
        for i, (data, label) in enumerate(train_dataloader):
            optimizer.zero_grad()
            output = trained_model(data.to(device))
            loss = criterion(output, label.to(device))
            loss.backward()
            optimizer.step()

        #     if (i+1) % 100 == 0:
        #         print('\r', end='')
        #         print(
        #             f'step [{i+1}/{len(train_dataloader)}], loss: {loss.item():.4f}', end='')
        # print(f'\nepoch {epoch+1}/{epochs} down.')
    return trained_model


def eval_model(model, dataset):
    server_model = copy.deepcopy(model)
    server_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
        for images, labels in data_loader:
            outputs = server_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # print('Test Accuracy: {:.2f}%'.format(100 * correct / total))

In [252]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=torchvision.transforms.ToTensor())

n_total_client = 10
n_data = 10000
communication_round = 1
epochs = 1
n_client = 3

model = LeNet5(input_channels=1)
idx_splited = idx_split(dataset=train_dataset,
                        n_dataset=n_total_client,
                        n_data_each_set=n_data)
dataset_client = dict()
for i in range(n_total_client):
    dataset_client[i] = DealDataset(train_dataset, idx_splited[i])

server_model = model
for i in range(communication_round):
    client = dict()
    choicen_client = numpy.random.choice(
        range(n_total_client), n_client, replace=False)
    for j, k in enumerate(choicen_client):
        client[j] = train_model(
            model=server_model,
            dataset=DealDataset(train_dataset, idx_splited[k]),
            device=device,
            epochs=epochs).state_dict()
    server_model = Server(model=model, client_params=client)

cuda
0 98|1 112|2 99|3 102|4 97|5 90|6 98|7 104|8 97|9 99|

0 0 98 98 5825
0 1 210 112 6630
0 2 309 99 5859
0 3 411 102 6029
0 4 508 97 5745
0 5 598 90 5331
0 6 696 98 5820
0 7 800 104 6161
0 8 897 97 5754
0 9 996 99 5850
1 0 98 98 5752
1 1 210 112 5640
1 2 309 99 5859
1 3 411 102 6029
1 4 508 97 5745
1 5 598 90 5331
1 6 696 98 5820
1 7 800 104 6161
1 8 897 97 5754
1 9 996 99 5642
2 0 98 98 5752
2 1 210 112 5442
2 2 309 99 5343
2 3 411 102 6029
2 4 508 97 5745
2 5 598 90 5331
2 6 696 98 5820
2 7 800 104 6161
2 8 897 97 5754
2 9 996 99 5547
3 0 98 98 5752
3 1 210 112 5437
3 2 309 99 5254
3 3 411 102 5152
3 4 508 97 5745
3 5 598 90 5331
3 6 696 98 5820
3 7 800 104 6161
3 8 897 97 5754
3 9 996 99 5451
4 0 98 98 5752
4 1 210 112 5438
4 2 309 99 5249
4 3 411 102 5071
4 4 508 97 4974
4 5 598 90 5331
4 6 696 98 5820
4 7 800 104 6161
4 8 897 97 5754
4 9 996 99 5358
5 0 98 98 5752
5 1 210 112 5443
5 2 309 99 5257
5 3 411 102 5088
5 4 508 97 4913
5 5 598 90 4823
5 6 696 98 5820
5 7 800 104 6161
