In [2]:
import math
import torch
import torch.nn as nn
# from torchsummary import summary

# Public factory functions defined further down in this cell; each name
# listed here corresponds to a ``def`` that returns a ``VGG`` instance.
__all__ = [
    'vgg11',
    'vgg11_bn',
    'vgg13',
    'vgg13_bn',
    'vgg16',
    'vgg16_bn',
    'vgg19',
    'vgg19_bn',
]


class VGG(nn.Module):
    """VGG-style image classifier: a convolutional trunk plus an MLP head.

    The head's first layer is ``nn.Linear(512, 4096)``, so ``features`` must
    produce exactly 512 values per sample once flattened (e.g. 512 channels
    at 1x1 spatial size).

    Args:
        features: Module applied to the input batch before flattening,
            typically the ``nn.Sequential`` returned by ``make_layers``.
        num_classes: Size of the final linear layer's output. Defaults to 10,
            which reproduces the previously hard-coded behaviour.
    """

    def __init__(self, features, num_classes=10):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes),
        )
        # Re-initialise every convolution: weights ~ N(0, sqrt(2 / n)) with
        # n = kernel_h * kernel_w * out_channels, biases set to zero.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # Conv2d(bias=False) has ``bias is None``; skip it instead of
                # crashing with AttributeError.
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Run trunk, flatten all dims after the batch dim, then the head."""
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


def make_layers(cfg, batch_norm=False):
    """Translate a layer specification into an ``nn.Sequential`` trunk.

    Args:
        cfg: Iterable whose entries are either the string ``'M'`` (adds a
            2x2, stride-2 max-pool) or a number of output channels (adds a
            3x3, padding-1 convolution followed by an in-place ReLU).
        batch_norm: When true, a ``BatchNorm2d`` is inserted between each
            convolution and its ReLU.

    Returns:
        ``nn.Sequential`` holding the generated modules in order. The first
        convolution always takes 3 input channels.
    """
    modules = []
    channels_in = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        conv = nn.Conv2d(channels_in, spec, kernel_size=3, padding=1)
        modules.append(conv)
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels_in = spec
    return nn.Sequential(*modules)


# Layer specifications consumed by ``make_layers``: an integer adds a 3x3
# convolution with that many output channels, 'M' adds a 2x2 max-pool.
# Keys map to the factories below: 'A' -> vgg11, 'B' -> vgg13,
# 'D' -> vgg16, 'E' -> vgg19 (and their ``_bn`` variants).
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg11():
    """Build the 11-layer VGG (configuration "A") without batch normalization."""
    trunk = make_layers(cfg['A'])
    return VGG(trunk)


def vgg11_bn():
    """Build the 11-layer VGG (configuration "A") with batch normalization."""
    trunk = make_layers(cfg['A'], batch_norm=True)
    return VGG(trunk)


def vgg13():
    """Build the 13-layer VGG (configuration "B") without batch normalization."""
    trunk = make_layers(cfg['B'])
    return VGG(trunk)


def vgg13_bn():
    """Build the 13-layer VGG (configuration "B") with batch normalization."""
    trunk = make_layers(cfg['B'], batch_norm=True)
    return VGG(trunk)


def vgg16():
    """Build the 16-layer VGG (configuration "D") without batch normalization."""
    trunk = make_layers(cfg['D'])
    return VGG(trunk)


def vgg16_bn():
    """Build the 16-layer VGG (configuration "D") with batch normalization."""
    trunk = make_layers(cfg['D'], batch_norm=True)
    return VGG(trunk)


def vgg19():
    """Build the 19-layer VGG (configuration "E") without batch normalization."""
    trunk = make_layers(cfg['E'])
    return VGG(trunk)


def vgg19_bn():
    """Build the 19-layer VGG (configuration "E") with batch normalization."""
    trunk = make_layers(cfg['E'], batch_norm=True)
    return VGG(trunk)


# def test():
#     # net = VGG_11()
#
#     net = vgg11_bn()
#     # net = vgg11_bn()
#     # names = [s for s in net.state_dict().keys() if s.startswith('classifier')]
#
#     summary(net, (3, 32, 32), 1)
#     for name in net.state_dict():
#         print(name, '\t', net.state_dict()[name].size())
#
#     print(len(net.state_dict()))
#
#
# test()




In [3]:
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms, datasets


class DatasetSplit(Dataset):
    """Restrict an indexable dataset to a chosen list of positions.

    Item ``k`` of the split is item ``idxs[k]`` of the wrapped dataset; the
    wrapped items are expected to unpack into an ``(image, label)`` pair.
    """

    def __init__(self, dataset, idxs):
        self.dataset, self.idxs = dataset, idxs

    def __len__(self):
        """Number of positions selected for this split."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``idxs[item]``."""
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label


def cifar_iid(dataset, num_users):
    """Randomly split dataset positions into equally sized, disjoint groups.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        num_users: Number of groups to create.

    Returns:
        Dict mapping each user id ``0 .. num_users - 1`` to a list of
        ``int(len(dataset) / num_users)`` positions drawn without
        replacement via ``np.random.choice``. Positions left over when the
        length is not divisible by ``num_users`` are not assigned.
    """
    per_user = int(len(dataset) / num_users)  # samples handed to each user
    remaining = list(range(len(dataset)))  # positions not yet assigned
    assignment = {}  # user id -> positions assigned to that user

    for user in range(num_users):
        # Draw this user's share from whatever is still unassigned ...
        picked = list(np.random.choice(remaining, per_user, replace=False))
        assignment[user] = picked
        # ... and take those positions out of the pool.
        remaining = list(set(remaining) - set(picked))
    return assignment


def get_dataset(dataset_name, device_num, data_dir='/kaggle/input/cifar10-python'):
    """Load a dataset and split it IID across ``device_num`` clients.

    Args:
        dataset_name: Name of the dataset to load; only ``'cifar10'`` is
            implemented.
        device_num: Number of clients to split the train and test sets over.
        data_dir: Root directory handed to ``datasets.CIFAR10``. The data
            must already be present there because ``download=False``.

    Returns:
        Tuple ``(train_dataset, test_dataset, user_groups, user_groups_test)``
        where the two ``user_groups`` dicts come from ``cifar_iid`` and map a
        client id to that client's sample positions.

    Raises:
        ValueError: If ``dataset_name`` is not supported. Previously the
            function silently returned ``None`` in that case, which only
            surfaced later as an unpacking error at the call site.
    """
    if dataset_name != 'cifar10':
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; only 'cifar10' is available"
        )

    # Per-channel mean / std used for normalisation in both pipelines
    # (presumably the usual CIFAR-10 statistics -- confirm if the data changes).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training pipeline adds random crop + horizontal flip augmentation.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(data_dir, train=True, transform=train_transform, download=False)
    test_dataset = datasets.CIFAR10(data_dir, train=False, transform=test_transform, download=False)

    # Split both sets IID into ``device_num`` groups, one per client.
    user_groups = cifar_iid(train_dataset, device_num)
    user_groups_test = cifar_iid(test_dataset, device_num)

    return train_dataset, test_dataset, user_groups, user_groups_test


def get_common_base_layers(model_list):
    """Return the leading state-dict entry names shared by every model.

    Entry names are compared position by position, in each mapping's key
    order; the result is the longest prefix on which all models agree.
    Only names are compared, not the shapes of the stored values.

    Args:
        model_list: Sequence of state-dict-like mappings, or a dict keyed by
            ``0 .. n - 1`` (models are accessed as ``model_list[i]``).

    Returns:
        List of entry names forming the common prefix; an empty list when
        ``model_list`` is empty or the models differ from the first entry on.
    """
    num_models = len(model_list)
    if num_models == 0:
        # Nothing to compare; previously this crashed on ``model_list[0]``.
        return []

    names_per_model = [list(model_list[i].keys()) for i in range(num_models)]

    # Start from the (first) model with the fewest entries: the common
    # prefix can never be longer than that.
    common = list(min(names_per_model, key=len))

    for names in names_per_model:
        for pos, (expected, actual) in enumerate(zip(common, names)):
            if expected != actual:
                # First divergence: drop this entry and everything after it.
                del common[pos:]
                break
    return common

In [4]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def client_local_train(net, dataset, idxs, device, lr=0.01, epochs=10):
    """Train ``net`` in place on one client's slice of ``dataset``.

    Args:
        net: Model to train; it is moved to ``device`` and updated in place.
        dataset: Full dataset the client's samples are taken from.
        idxs: Positions in ``dataset`` that belong to this client.
        device: Device on which training runs.
        lr: SGD learning rate.
        epochs: Number of passes over the client's samples.

    Returns:
        ``net.state_dict()`` after training.
    """
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=64, shuffle=True)

    net.train()

    for _ in range(epochs):
        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # Standard step: clear old grads, forward, backward, update.
            sgd.zero_grad()
            batch_loss = criterion(net(batch_inputs), batch_labels)
            batch_loss.backward()
            sgd.step()

    return net.state_dict()


def common_basic(model_list):
    """Average the shared leading layers across all models, in place.

    The entries returned by ``get_common_base_layers`` are replaced, in every
    model, by their element-wise mean over all models. Entries outside that
    common prefix are left untouched.

    Args:
        model_list: Sequence of state dicts, or a dict keyed by
            ``0 .. n - 1`` (models are accessed as ``model_list[i]``).

    Returns:
        The same ``model_list`` object, with the shared entries overwritten.
    """
    shared_names = get_common_base_layers(model_list)
    num_models = len(model_list)

    # Aggregate the local models entry by entry.
    for name in shared_names:
        # Clone so the in-place accumulation does not modify model 0's tensor.
        total = model_list[0][name].clone()
        for i in range(1, num_models):
            total += model_list[i][name]
        mean = total / num_models

        for i in range(num_models):
            # Each model gets its own copy; sharing one tensor object between
            # all state dicts would let an in-place update on one client
            # leak into every other client.
            model_list[i][name] = mean.clone()

    return model_list


def common_max(model_list):
    """Average each model with every other model over their shared prefix.

    For every pair of models, entry names are compared position by position
    until the first mismatch (or until the shorter model runs out). Each
    matching entry of one model is added to the other's, and finally every
    entry is divided by the number of models that contributed to it.

    Args:
        model_list: Sequence of state dicts of tensors, or a dict keyed by
            ``0 .. n - 1`` (models are accessed as ``model_list[i]``). The
            stored tensors are updated in place during accumulation.

    Returns:
        The same ``model_list`` object. Every entry has been moved to the
        CPU and divided by its contribution count (1 for unshared entries).
    """
    # Pristine copy so that sums always add the *original* values of the
    # other model, not partially accumulated ones.
    backup = copy.deepcopy(model_list)
    num_models = len(model_list)

    names = [list(model_list[i].keys()) for i in range(num_models)]
    # count[i][k]: how many models contributed to entry k of model i.
    count = [[1] * len(names[i]) for i in range(num_models)]

    for i in range(num_models):
        for j in range(i + 1, num_models):
            # Share whatever can be shared. ``zip`` stops at the shorter
            # model, so a later model with fewer entries no longer raises
            # IndexError.
            for k, (name_i, name_j) in enumerate(zip(names[i], names[j])):
                if name_i != name_j:
                    break
                model_list[i][name_i] += backup[j][name_i]
                model_list[j][name_i] += backup[i][name_i]
                count[i][k] += 1
                count[j][k] += 1

    for c in range(num_models):
        for k, name in enumerate(names[c]):
            model_list[c][name] = model_list[c][name].cpu() / count[c][k]

    return model_list


In [5]:
import torch
from torch.utils.data import DataLoader


def predict(model, dataset, idxs, device):
    """Compute classification accuracy on a subset of ``dataset``.

    Args:
        model: Network to evaluate; it is moved to ``device`` and switched
            to eval mode.
        dataset: Full dataset the evaluated samples are taken from.
        idxs: Positions in ``dataset`` to evaluate on.
        device: Device on which inference runs.

    Returns:
        Fraction of samples whose highest-scoring output index equals the
        target, as a float.
    """
    model.to(device)

    total, correct = 0.0, 0.0
    loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=64, shuffle=False)

    model.eval()

    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)

            scores = model(images)
            # torch.max over dim 1 returns (values, indices); keep the indices.
            predicted = torch.max(scores, 1)[1]
            predicted = predicted.view(-1)  # flattening is probably redundant

            # Number of correct predictions in this batch.
            correct += torch.eq(predicted, targets).sum().item()
            total += len(targets)

    return correct * 1.0 / total


In [None]:
import copy
import time
import numpy as np
import torch
import warnings

warnings.filterwarnings("ignore")

def _build_client_model(idx):
    """Return a freshly initialised network for client ``idx``.

    Clients use heterogeneous architectures: ids 0-1 get VGG11-BN, 2-3 get
    VGG13-BN, 4-5 get VGG16-BN and every other id gets VGG19-BN.
    """
    if idx < 2:
        return vgg11_bn()
    if idx < 4:
        return vgg13_bn()
    if idx < 6:
        return vgg16_bn()
    return vgg19_bn()


if __name__ == '__main__':

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Number of simulated client devices.
    device_num = 40

    # Load dataset and per-client sample positions.
    train_dataset, test_dataset, user_groups, user_groups_test = get_dataset("cifar10", device_num)

    # Training bookkeeping. Most of these are currently unused placeholders,
    # kept so that any later cell relying on the names still finds them.
    epochs = 502
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    uid = list(range(device_num))
    # Before the first round this holds nn.Module objects; after each round
    # it holds the state dicts returned by local training.
    modelAccept = {idx: _build_client_model(idx) for idx in uid}
    local_acc = [[] for _ in uid]

    # Each round trains on a window covering one tenth of a client's samples.
    samples_per_client = len(user_groups[0])
    localData_length = samples_per_client // 10
    start = 0

    for epoch in range(epochs):

        print(f'\n | Global Training Round : {epoch + 1} |\n')
        end = start + localData_length

        for idx in uid:

            idx_train_batch = list(user_groups[idx])[start:end]

            if epoch == 0:
                model = modelAccept[idx]
            else:
                # Rebuild the client's architecture and load its latest weights.
                model = _build_client_model(idx)
                model.load_state_dict(modelAccept[idx])

            # Evaluate before this round's local training.
            acc = predict(model, test_dataset, user_groups_test[idx], device)
            local_acc[idx].append(round(acc, 2))
            if epoch % 10 == 0:
                print(local_acc[idx])

            Model = copy.deepcopy(model)
            localModel = client_local_train(Model, train_dataset, idx_train_batch, device)
            modelAccept[idx] = copy.deepcopy(localModel)

        # Slide the window and wrap around at the end of the client's data.
        # Derived from the actual split size instead of a hard-coded
        # 50000 / device_num, so it stays correct for other dataset sizes.
        start = end % samples_per_client

        # modelAccept = train.common_max(modelAccept)
        modelAccept = common_basic(modelAccept)




 | Global Training Round : 1 |

[0.08]
[0.1]
[0.08]
[0.12]
[0.11]
[0.06]
[0.14]
[0.1]
[0.08]
[0.1]
[0.12]
[0.11]
[0.12]
[0.1]
[0.08]
[0.1]
[0.11]
[0.06]
[0.09]
[0.13]
[0.1]
[0.08]
[0.1]
[0.1]
[0.12]
[0.1]
[0.11]
[0.11]
[0.08]
[0.12]
[0.1]
[0.06]
[0.07]
[0.08]
[0.08]
[0.11]
[0.09]
[0.09]
[0.15]
[0.15]

 | Global Training Round : 2 |


 | Global Training Round : 3 |


 | Global Training Round : 4 |


 | Global Training Round : 5 |


 | Global Training Round : 6 |


 | Global Training Round : 7 |


 | Global Training Round : 8 |


 | Global Training Round : 9 |


 | Global Training Round : 10 |


 | Global Training Round : 11 |

[0.08, 0.14, 0.1, 0.1, 0.16, 0.23, 0.23, 0.26, 0.25, 0.24, 0.28]
[0.1, 0.13, 0.08, 0.13, 0.14, 0.19, 0.21, 0.24, 0.28, 0.3, 0.26]
[0.08, 0.08, 0.09, 0.08, 0.21, 0.22, 0.18, 0.24, 0.3, 0.3, 0.3]
[0.12, 0.12, 0.14, 0.15, 0.19, 0.18, 0.24, 0.28, 0.25, 0.34, 0.28]
[0.11, 0.14, 0.08, 0.11, 0.18, 0.24, 0.24, 0.26, 0.29, 0.3, 0.24]
[0.06, 0.09, 0.09, 0.09, 0.14, 0.13, 0