In [177]:
import torch
import tqdm
import torchvision
import re
import numpy
import copy

from torch.utils.data import Dataset, DataLoader

In [178]:
class LeNet5(torch.nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(
            in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = torch.nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = torch.nn.Linear(in_features=120, out_features=84)
        self.fc3 = torch.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Loss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        return torch.nn.functional.cross_entropy(input, target)



class Server:
    def __init__(self, model, client_params):
        self.model = copy.deepcopy(model)
        self.client_params = client_params
        self.n_client = len(self.client_params)

        self.clients = list(client_params.keys())
        self.server_params = self.client_params[self.clients[0]]
        for key in self.server_params:
            self.server_params[key] = self.server_params[key].div(self.n_client)

    def fed_avg(self):
        for client in self.clients:
            for key in self.server_params:
                self.server_params[key] = self.server_params[key].add(self.client_params[client][key].div(self.n_client))
        return self.server_params


class DealDataset(Dataset):

    def __init__(self, dataset, idx):
        self.dataset = dataset
        self.idx = idx
        self.len = len(idx)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        img, target = self.dataset[self.idx[index]]
        return img, target

In [179]:
def idx_split(dataset, mode='iid', n_dataset=1, n_data_each_set=1):
    labels_list = dataset.targets.tolist()
    all_labels = set(labels_list)
    idx_label = dict()
    for label in all_labels:
        idx_label[label] = list()
        for idx, label_in_list in enumerate(labels_list):
            if label_in_list == label:
                idx_label[label] += [idx]

    if mode == 'iid':
        n_each_set = dict()
        for label in all_labels:
            n_each_set[label] = int(len(idx_label[label]) / len(labels_list) * n_data_each_set / n_dataset)
        dataset_splited = dict()
        left_idx_label = idx_label
        for i in range(n_dataset):
            dataset_splited[i] = list()
            for label in all_labels:
                choiced_idx = numpy.random.choice(
                    left_idx_label[label],
                    n_each_set[label],
                    replace=False)
                dataset_splited[i] += list(choiced_idx)
                left_idx_label[i] = list(set(left_idx_label[label]) - set(dataset_splited[i]))
        return dataset_splited
    elif mode == 'non-iid':
        print('TO DO.')


def train_model(model, dataset, device='cpu', epochs=1):
    trained_model = copy.deepcopy(model).to(device)
    trained_model.train()
    train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(trained_model.parameters())

    for epoch in range(epochs):
        for i, (data, label) in enumerate(train_dataloader):
            optimizer.zero_grad()
            output = trained_model(data.to(device))
            loss = criterion(output, label.to(device))
            loss.backward()
            optimizer.step()

        #     if (i+1) % 100 == 0:
        #         print('\r', end='')
        #         print(
        #             f'step [{i+1}/{len(train_dataloader)}], loss: {loss.item():.4f}', end='')
        # print(f'\nepoch {epoch+1}/{epochs} down.')

    return trained_model


def eval_model(model, dataset):
    server_model = copy.deepcopy(model)
    server_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
        for images, labels in data_loader:
            outputs = server_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # print('Test Accuracy: {:.2f}%'.format(100 * correct / total))

In [180]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=torchvision.transforms.ToTensor())

model = LeNet5()

n_total_client = 2
n_data = 10000
communication_round = 1
epochs = 1
n_client = 2

idx_splited = idx_split(dataset=train_dataset,
                        n_dataset=n_total_client,
                        n_data_each_set=n_data)
dataset_client = dict()
for i in range(n_total_client):
    dataset_client[i] = DealDataset(train_dataset, idx_splited[i])

server_model = copy.deepcopy(model)
for i in range(communication_round):
    client = dict()
    choicen_client = numpy.random.choice(range(n_total_client), n_client, replace=False)
    for j, k in enumerate(choicen_client):
        client[i] = train_model(
            model=server_model,
            dataset=DealDataset(train_dataset, idx_splited[k]),
            device=device,
            epochs=epochs).state_dict()
    print(model, client)
    server_model = Server(model=model, client_params=client)

# client_model = copy.deepcopy(model)
# client_model.train()
# train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# client_model.to(device)
# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(client_model.parameters())

# for epoch in range(1):
#     for i, (data, label) in enumerate(train_dataloader):
#         optimizer.zero_grad()
#         output = client_model(data.to(device))
#         loss = criterion(output, label.to(device))
#         loss.backward()
#         optimizer.step()

#         if (i+1) % 100 == 0:
#             print('\r', end='')
#             print(f'step [{i+1}/{len(train_dataloader)}], loss: {loss.item():.4f}', end='')
# print('\ndown.')

# client_model1 = train_model(model=model, dataset=train_dataset, device=device)

cuda
LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
) {0: OrderedDict([('conv1.weight', tensor([[[[ 0.0822, -0.1553, -0.0330,  0.0729,  0.1298],
          [ 0.2404,  0.2046,  0.2235,  0.0482, -0.0937],
          [-0.0404, -0.0650,  0.2495,  0.0410,  0.0213],
          [ 0.0907,  0.1661, -0.0532,  0.2728, -0.0750],
          [-0.0004,  0.1714,  0.1608,  0.1793,  0.0554]]],


        [[[ 0.1275, -0.0707,  0.0973,  0.1435, -0.0083],
          [-0.0044,  0.1348, -0.0043, -0.0638,  0.2768],
          [ 0.0950,  0.2278,  0.1206,  0.2507, -0.0933],
          [ 0.2891,  0.1737,  0.1920, -0.0414,  0.1919],
          [ 0.1422

In [181]:
# server_model = copy.deepcopy(client_model)
# eval_model(server_model.to('cpu'), test_dataset)

# server_model1 = copy.deepcopy(client_model1)
# eval_model(server_model1.to('cpu'), test_dataset)