In [1]:
import math
import torch
import torch.nn as nn

__all__ = [
    'vgg11',
    'vgg11_bn',
    'vgg13',
    'vgg13_bn',
    'vgg16',
    'vgg16_bn',
    'vgg19',
    'vgg19_bn',
]


class VGG(nn.Module):
    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(4096, 10)
        )
        self._initialize_weight()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    
    def _initialize_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def make_layers(cfg):
    layers = []
    in_channels = 3
    for x in cfg:
        if x == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                       nn.BatchNorm2d(x),
                       nn.ReLU(inplace=True)]
            in_channels = x

    layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
    return nn.Sequential(*layers)



cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg11_bn():
    """VGG 11-layer model (configuration "A") with batch normalization"""
    return VGG(make_layers(cfg['A']))


def vgg13_bn():
    """VGG 13-layer model (configuration "B") with batch normalization"""
    return VGG(make_layers(cfg['B']))


def vgg16_bn():
    """VGG 16-layer model (configuration "D") with batch normalization"""
    return VGG(make_layers(cfg['D']))


def vgg19_bn():
    """VGG 19-layer model (configuration 'E') with batch normalization"""
    return VGG(make_layers(cfg['E']))

In [2]:
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms, datasets


class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = idxs

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


def cifar_iid(dataset, num_users):
    num_items = int(len(dataset) / num_users)  # 分配给每个user的数据量
    dict_users = {}  # user编号->user分配的数据
    all_idx = [i for i in range(len(dataset))]

    for i in range(num_users):
        dict_users[i] = list(np.random.choice(all_idx, num_items, replace=False))  # 从剩余数据中随机选择
        all_idx = list(set(all_idx) - set(dict_users[i]))  # 从剩余数据中删除已选数据
    return dict_users


def get_dataset(dataset_name, device_num):

    if dataset_name == 'cifar10':
        data_dir = '/kaggle/input/cifar10-python'
        
        train_transform = transforms.Compose([
            transforms.Pad(4),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomGrayscale(),
            transforms.RandomCrop(32, padding=4),
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        train_dataset = datasets.CIFAR10(data_dir, train=True, transform=train_transform, download=False)

        test_dataset = datasets.CIFAR10(data_dir, train=False, transform=test_transform, download=False)

        # 按照独立同分布将数据分成device_num组
        user_groups = cifar_iid(train_dataset, device_num)
        user_groups_test = cifar_iid(test_dataset, device_num)

        return train_dataset, test_dataset, user_groups, user_groups_test


def get_common_base_layers(model_list):
    min_idx = 0
    min_len = 1000000

    # 找到参数最少的一个模型
    for i in range(0, len(model_list)):
        if len(model_list[i]) < min_len:
            min_idx = i
            min_len = len(model_list[i])

    commonList = [s for s in model_list[min_idx].keys()]

    # 找到common base layers
    for i in range(0, len(model_list)):
        weight_name_list = [s for s in model_list[i].keys()]
        for j in range(len(commonList)):
            if commonList[j] == weight_name_list[j]:
                continue
            else:
                del commonList[j:len(commonList) + 1]  # 从哪一层开始不同，删去该层后所有层
                break
    return commonList

In [3]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def client_local_train(net, dataset, idxs, device, lr=0.01, epochs=10):
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.4, last_epoch=-1)
    
    train_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=64, shuffle=True)

    for epoch in range(epochs):
            
        net.train()
        correct, total = 0, 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print('Local : |  Acc: %.3f%% (%d/%d)' %(100. * correct / total, correct, total))
        
    return net.state_dict()


def common_basic(model_list):
    commonList = get_common_base_layers(model_list)

    # 聚合local models
    for k in commonList:
        comWeight = copy.deepcopy(model_list[0][k])
        for i in range(1, len(model_list)):
            comWeight += model_list[i][k]
        comWeight = comWeight / len(model_list)

        for i in range(0, len(model_list)):
            model_list[i][k] = comWeight

    return model_list


def common_max(model_list):
    backup = copy.deepcopy(model_list)

    count = [[] for _ in range(len(model_list))]

    for i in range(len(model_list)):
        weight_name_list = [s for s in model_list[i].keys()]
        count[i] = [1 for _ in range(len(weight_name_list))]

    for i in range(0, len(model_list)):

        weight_name_list1 = [s for s in model_list[i].keys()]  # 第i个模型

        for j in range(i + 1, len(model_list)):
            if i == j:
                continue
            weight_name_list2 = [s for s in model_list[j].keys()]
            # 能共享就共享
            for k in range(0, len(weight_name_list1)):
                if weight_name_list2[k] == weight_name_list1[k]:
                    name = weight_name_list1[k]
                    model_list[i][name] += backup[j][name]
                    model_list[j][name] += backup[i][name]
                    count[i][k] += 1
                    count[j][k] += 1
                else:
                    break

    for c in range(0, len(model_list)):
        weight_name_list = [s for s in model_list[c].keys()]
        for k in range(0, len(weight_name_list)):
            model_list[c][weight_name_list[k]] = model_list[c][weight_name_list[k]].cpu() / count[c][k]

    return model_list


In [4]:
import torch
from torch.utils.data import DataLoader


def predict(model, dataset, idxs, device):
    
    model.to(device)

    total, correct = 0.0, 0.0
    ldr_test = DataLoader(DatasetSplit(dataset, idxs), batch_size=64, shuffle=False)

    model.eval()

    for images, targets in ldr_test:
        with torch.no_grad():
            images, targets = images.to(device), targets.to(device)

            outputs = model(images)
            _, predict = torch.max(outputs, 1)

            predict = predict.view(-1)  # 可能无用
            correct += predict.eq(targets).sum().item()  # 预测正确个数
            total += len(targets)

    acc = correct * 1.0 / total
    acc = round(acc, 2)
    return acc

In [None]:
import copy
import time
import numpy as np
import torch
import warnings

warnings.filterwarnings("ignore")

if __name__ == '__main__':

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 设备数量
    device_num = 40
    
    # 加载数据，并对用户产生的数据进行分组
    train_dataset, test_dataset, user_groups, user_groups_test = get_dataset("cifar10", device_num)

    # Training
    epochs = 502
    lr = 0.01

    uid = [_ for _ in range(device_num)]
    modelAccept = {i: None for i in range(device_num)}
    local_acc = [[] for _ in range(device_num)]

    # 不同的设备采用不同架构的网络
    for idx in uid:
        
        if idx < 10:
            modelAccept[idx] = vgg11_bn()

        elif 10 <= idx < 20:
            modelAccept[idx] = vgg13_bn()

        elif 20 <= idx < 30:
            modelAccept[idx] = vgg16_bn()

        else:
            modelAccept[idx] = vgg19_bn()


    for epoch in range(epochs):

        print(f'\n | Global Training Round : {epoch + 1} |\n')

        for idx in uid:

            train_all = list(user_groups[idx])

            if epoch == 0:
                model = modelAccept[idx]

            if epoch > 0:
                if idx < 10:
                    model = vgg11_bn()
                    
                elif 10 <= idx < 20:
                    model = vgg13_bn()

                elif 20 <= idx < 30:
                    model = vgg16_bn()

                else:
                    model = vgg19_bn()
                    
                model.load_state_dict(modelAccept[idx])  # 加载对应的模型 
            
            model.to(device)

            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.4, last_epoch=-1)

            train_loader = DataLoader(DatasetSplit(train_dataset, train_all), batch_size=64, shuffle=True)

            for i in range(10):

                model.train()
                correct, total = 0, 0

                for batch_idx, (inputs, targets) in enumerate(train_loader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

            print('Local : |  Acc: %.3f%% (%d/%d)' %(100. * correct / total, correct, total))
            
            
            modelAccept[idx] = copy.deepcopy(model.state_dict())
            
            # 预测
            acc = predict(model, test_dataset, user_groups_test[idx], device)
            local_acc[idx].append(round(acc, 2))
            print(local_acc[idx])
            

#         modelAccept = common_max(modelAccept)
        modelAccept = common_basic(modelAccept)
#         modelAccept = common_cluster(modelAccept)



 | Global Training Round : 1 |

Local : |  Acc: 34.400% (430/1250)
[0.28]
Local : |  Acc: 32.480% (406/1250)
[0.26]
Local : |  Acc: 33.280% (416/1250)
[0.32]
Local : |  Acc: 36.880% (461/1250)
[0.36]
Local : |  Acc: 34.160% (427/1250)
[0.31]
Local : |  Acc: 35.040% (438/1250)
[0.28]
Local : |  Acc: 37.040% (463/1250)
[0.26]
Local : |  Acc: 33.040% (413/1250)
[0.32]
Local : |  Acc: 32.960% (412/1250)
[0.34]
Local : |  Acc: 34.000% (425/1250)
[0.34]
Local : |  Acc: 36.000% (450/1250)
[0.37]
Local : |  Acc: 34.720% (434/1250)
[0.35]
Local : |  Acc: 36.000% (450/1250)
[0.36]
Local : |  Acc: 37.760% (472/1250)
[0.33]
Local : |  Acc: 36.560% (457/1250)
[0.33]
Local : |  Acc: 34.240% (428/1250)
[0.34]
Local : |  Acc: 38.160% (477/1250)
[0.24]
Local : |  Acc: 34.160% (427/1250)
[0.4]
Local : |  Acc: 34.480% (431/1250)
[0.35]
Local : |  Acc: 34.160% (427/1250)
[0.34]
Local : |  Acc: 30.480% (381/1250)
[0.28]
Local : |  Acc: 38.000% (475/1250)
[0.34]
Local : |  Acc: 34.880% (436/1250)
[0.36]
Lo