# 基本定义

In [None]:
import matplotlib.pyplot as plt
import random
import time
import pickle

In [None]:
import torch
# from torchvision.datasets import MNIST
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

## 数据集/模型

### MLP + MNIST

In [None]:
# MLP + MNIST experiment settings.
# Options shared by every optimizer configuration derived below.
optConfig = dict(
    honestSize=50,
    byzantineSize=20,
    rounds=15,
    displayInterval=1000,
    weight_decay=0.00,
    fixSeed=False,
    SEED=100,
    batchSize=5,
    shuffle=True,
)

# Dataset properties
dataSetConfig = dict(
    name='mnist',
    dataSet='mnist',
    dataSetSize=60000,
    maxFeature=784,
    honestNodeSize=50,
    byzantineNodeSize=20,
    rounds=15,
    displayInterval=1000,
)

# Each optimizer gets its own shallow copy of optConfig, extended with its
# step size (gamma) and any optimizer-specific overrides.
SGDConfig = {**optConfig, 'gamma': 1e-1}

batchConfig = {**optConfig, 'batchSize': 50, 'gamma': 5e-1}

SVRGConfig = {
    **optConfig,
    'snapshotInterval': dataSetConfig['dataSetSize'],
    'gamma': 1e-1,
}

SAGAConfig = {**optConfig, 'gamma': 1e-1}

SARAHConfig = {**optConfig, 'gamma': 1e-1}

# Load the dataset.
# Both pipelines convert a PIL image / numpy.ndarray to a tensor and then
# normalize it with mean 0.1307 and standard deviation 0.3081.
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# The training split is requested with download=True; the validation split
# is opened from the same root with download=False.
train_dataset = torchvision.datasets.MNIST(
    root='./dataset/',
    train=True,
    download=True,
    transform=train_transform,
)
validate_dataset = torchvision.datasets.MNIST(
    root='./dataset/',
    train=False,
    download=False,
    transform=test_transform,
)

# 模型
class MLP(torch.nn.Module):
    """Single-hidden-layer perceptron that outputs class probabilities.

    Data flow for an input of shape [batch, ...]:
        flatten to [batch, input_size]
        -> Tanh
        -> Linear(input_size, hidden_size)
        -> Tanh
        -> Linear(hidden_size, output_size)
        -> Softmax(dim=1)

    ``SEED`` is accepted by the constructor but is not used.

    NOTE(review): ``forward`` returns softmax probabilities. If the result is
    fed to ``torch.nn.CrossEntropyLoss`` (which applies log-softmax itself),
    the softmax is effectively applied twice -- confirm this is intended.
    """

    def __init__(self, input_size, hidden_size, output_size, SEED=100):
        super().__init__()
        # Layers are created in this order so parameter initialization
        # consumes the RNG in the same sequence as before.
        self.hidden = torch.nn.Linear(input_size, hidden_size)
        self.classification_layer = torch.nn.Linear(hidden_size, output_size)

        self.tanh1 = torch.nn.Tanh()
        self.tanh2 = torch.nn.Tanh()

        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: tensor whose first dimension is the batch; all remaining
               dimensions are flattened into one feature axis.
        Returns:
            Tensor of shape [batch_size, output_size]; each row is the
            softmax over the classification layer's outputs.
        """
        flat = x.view(x.size(0), -1)
        hidden_out = self.hidden(self.tanh1(flat))
        scores = self.classification_layer(self.tanh2(hidden_out))
        return self.softmax(scores)
    
# Model factory
def modelFactory(SEED=100):
    """Build a fresh ``MLP`` with 784 inputs, 50 hidden units and 10 outputs.

    ``SEED`` is accepted for interface compatibility but is not used here.
    """
    return MLP(input_size=784, hidden_size=50, output_size=10)

### ResNet + CIFAR10

In [None]:
# # ResNet50 + CIFAR10
# optConfig = {
#     'honestSize': 10,
#     'byzantineSize': 4,

#     'rounds': 15,
#     'displayInterval': 6000,
    
#     'weight_decay': 0.0001,
    
#     'fixSeed': False,
#     'SEED': 100,
    
#     'batchSize': 5,
#     'shuffle': True,
# }

# SGDConfig = optConfig.copy()
# SGDConfig['gamma'] = 1e-1

# batchConfig = optConfig.copy()
# batchConfig['batchSize'] = 50
# batchConfig['gamma'] = 5e-1

# SVRGConfig = optConfig.copy()
# SVRGConfig['snapshotInterval'] = dataSetConfig['dataSetSize']
# SVRGConfig['gamma'] = 1e-1

# SAGAConfig = optConfig.copy()
# SAGAConfig['gamma'] = 1e-1

# SARAHConfig = optConfig.copy()
# SARAHConfig['gamma'] = 1e-1

# # 数据集属性
# dataSetConfig = {
#     'name': 'CIFAR-10',

#     'dataSet' : 'CIFAR-10',
#     'dataSetSize': 60000,
#     'maxFeature': 32*32*3,
# }

# # 加载数据集
# preprocess = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# train_dataset = torchvision.datasets.CIFAR10(root='./dataset/',
#                                              train=True, 
#                                              transform=preprocess,
#                                              download=False)
# validate_dataset = torchvision.datasets.CIFAR10(root='./dataset/',
#                                             train=False, 
#                                             transform=preprocess)

# 模型工厂
# def modelFactory(SEED=100):
#     return torchvision.models.resnet50()

## 运行参数

In [None]:
# Path prefix for cached results: './cache/<dataset name>_'.
CACHE_DIR = './cache/{}_'.format(dataSetConfig['name'])

In [None]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## 辅助函数

In [None]:
# Logging helpers
def log(*k, **kw):
    """Print like ``print`` but preceded by a ``[MM-DD HH:MM:SS] `` stamp.

    The stamp (local time) is always written to standard output without a
    newline; ``k`` and ``kw`` are then forwarded unchanged to ``print``.
    """
    # time.strftime formats the current local time when no time tuple is given.
    prefix = time.strftime('[%m-%d %H:%M:%S] ')
    print(prefix, end='')
    print(*k, **kw)
def debug(*k, **kw):
    """Print like ``print`` but preceded by ``[MM-DD HH:MM:SS] (debug)``.

    The stamp (local time) is written to standard output with no trailing
    space or newline; ``k`` and ``kw`` are then forwarded to ``print``.
    """
    # time.strftime formats the current local time when no time tuple is given.
    prefix = time.strftime('[%m-%d %H:%M:%S] (debug)')
    print(prefix, end='')
    print(*k, **kw)

## 损失函数

In [None]:
# Cross-entropy criterion shared by the training and evaluation code;
# 'mean' (the library default) averages the loss over the batch.
loss_func = torch.nn.CrossEntropyLoss(reduction='mean')

In [None]:
def getVarience(w_local, honestSize):
    """Return the mean squared L2 distance of the honest rows to their mean.

    Args:
        w_local: 2-D tensor with one row per node; only the first
            ``honestSize`` rows are used.
        honestSize: number of leading rows to include; also the divisor.
    Returns:
        The variance as a Python float.
    """
    honest_rows = w_local[:honestSize]
    center = honest_rows.mean(dim=0)
    # sum() starts from the int 0, so the rows are accumulated in order.
    total = sum((row - center).norm() ** 2 for row in honest_rows)
    return (total / honestSize).item()

In [None]:
def calculateAccuracy(model, loader, device):
    """Evaluate ``model`` on every batch produced by ``loader``.

    Args:
        model: callable mapping a batch of inputs to per-class scores of
            shape [batch, n_classes].
        loader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.
    Returns:
        ``(loss, accuracy)``: the per-batch value of the module-level
        ``loss_func`` weighted by batch size and divided by the number of
        samples, and the fraction of samples whose arg-max prediction equals
        the target.
    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    loss = 0
    correct = 0
    total = 0

    # Nothing here calls backward(), so disable autograd to avoid building
    # a graph for every evaluation batch.
    with torch.no_grad():
        for material, targets in loader:
            material, targets = material.to(device), targets.to(device)
            outputs = model(material)

            batchCount = len(targets)
            # Weight by batch size so a short final batch is averaged correctly.
            loss += loss_func(outputs, targets).item() * batchCount

            _, predicted = torch.max(outputs, dim=1)
            correct += (predicted == targets).sum().item()
            total += batchCount

    return loss / total, correct / total

## 聚合函数

In [None]:
def mean(wList):
    """Aggregate by averaging the rows of ``wList`` (reduction over dim 0)."""
    return wList.mean(dim=0)

In [None]:
def gm(wList):
    """Approximate the geometric median of the rows of ``wList``.

    Starting from the row mean, the estimate is repeatedly replaced by the
    average of the rows weighted by the inverse of their distance to the
    current estimate. Iteration stops after 80 steps or once the estimate
    moves by at most 1e-5.

    Args:
        wList: 2-D tensor, one candidate vector per row.
    Returns:
        1-D tensor holding the final estimate.
    """
    max_iter = 80
    tol = 1e-5
    estimate = wList.mean(dim=0)
    for _ in range(max_iter):
        distances = (wList - estimate).norm(dim=1)
        # A row that coincides with the estimate would divide by zero;
        # give it unit distance instead.
        distances = torch.where(distances == 0, torch.ones_like(distances), distances)
        weighted_sum = (wList / distances.unsqueeze(1)).sum(dim=0)
        updated = weighted_sum / (1 / distances).sum()
        shift = (estimate - updated).norm()
        estimate = updated
        if shift <= tol:
            break
    return estimate

In [None]:
def Krum_(nodeSize, byzantineSize):
    """Build a Krum aggregator for ``nodeSize`` nodes with ``byzantineSize`` attackers.

    The returned function scores each of the first ``nodeSize`` rows by the
    sum of its squared distances to its ``nodeSize - byzantineSize - 2``
    nearest other rows, and returns the row with the smallest score.

    Args:
        nodeSize: number of rows of the input that take part in the selection.
        byzantineSize: assumed number of Byzantine rows.
    Returns:
        ``Krum(wList)``, which takes a 2-D tensor (one vector per row) and
        returns the selected row of ``wList``.
    """
    # Entries summed per row: the nearest neighbours plus the row's own
    # zero self-distance.
    k = nodeSize - byzantineSize - 2 + 1

    def Krum(wList):
        vectors = wList.detach()
        # A fresh matrix per call, so no state is shared between calls.
        dist = vectors.new_zeros(nodeSize, nodeSize)
        for i in range(nodeSize):
            for j in range(i + 1, nodeSize):
                diff = vectors[i] - vectors[j]
                squared = (diff * diff).sum()
                dist[i, j] = squared
                dist[j, i] = squared
        # largest=False picks the k SMALLEST distances (the closest rows);
        # the default would pick the farthest ones.
        nearest, _ = dist.topk(k=k, dim=1, largest=False)
        scores = nearest.sum(dim=1)
        return wList[scores.argmin()]
    return Krum

In [None]:
def median(wList):
    """Aggregate by taking the coordinate-wise median over the rows of ``wList``."""
    values, _indices = wList.median(dim=0)
    return values

## torch辅助函数

In [None]:
def flatten_list(message, byzantineSize):
    """Stack per-node parameter groups into one 2-D tensor.

    Args:
        message: sequence with one entry per node; each entry is an iterable
            of tensors that are flattened and concatenated into one row.
        byzantineSize: number of all-zero rows appended after the rows built
            from ``message``.
    Returns:
        Tensor of shape [len(message) + byzantineSize, total_elements].
    """
    rows = [
        torch.cat([tensor.flatten() for tensor in group])
        for group in message
    ]
    padding = [torch.zeros_like(rows[0]) for _ in range(byzantineSize)]
    return torch.stack(rows + padding)
def unflatten_vector(vector, model):
    """Split a flat vector into views shaped like ``model``'s parameters.

    Args:
        vector: 1-D tensor whose leading elements are consumed in the order
            of ``model.parameters()``.
        model: object whose ``parameters()`` supplies the target shapes.
    Returns:
        List of tensors, one per parameter, each viewing a slice of ``vector``.
    """
    chunks = []
    start = 0
    for param in model.parameters():
        stop = start + param.numel()
        chunks.append(vector[start:stop].view_as(param))
        start = stop
    return chunks

In [None]:
def randomSample(dataset, batchSize):
    """Draw ``batchSize`` distinct samples from ``dataset`` and batch them.

    Indices are sampled rather than the dataset itself, so any object that
    supports ``len()`` and integer indexing works, not only real sequences
    (``random.sample`` rejects populations that are not sequences).

    Args:
        dataset: indexable collection of ``(material, target)`` pairs, where
            ``material`` is a tensor.
        batchSize: number of samples to draw without replacement.
    Returns:
        ``(material, targets)``: the sampled tensors concatenated along
        dimension 0, and the targets gathered into one tensor.
    Raises:
        ValueError: if ``batchSize`` exceeds ``len(dataset)``, is negative,
            or is zero (nothing to unpack).
    """
    picks = random.sample(range(len(dataset)), batchSize)
    materials, labels = zip(*(dataset[i] for i in picks))
    return torch.cat(materials), torch.tensor(labels)

In [None]:
def getPara(module, useString=True):
    """Count the elements of all parameters of ``module``.

    Args:
        module: object whose ``parameters()`` yields tensors.
        useString: when false, return the raw integer count.
    Returns:
        The count as an int, or as a string: ``'<x.xx>M'`` (count / 2**20)
        from 2**20 upwards, ``'<x.xx>K'`` (count / 2**10) from 2**10
        upwards, otherwise the plain number.
    """
    count = sum(p.nelement() for p in module.parameters())
    if not useString:
        return count
    for unit, suffix in ((2**20, 'M'), (2**10, 'K')):
        if count >= unit:
            return '{:.2f}{}'.format(count / unit, suffix)
    return str(count)

# 优化算法

报告函数

In [None]:
def report(r, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy, var=None):
    """Log one progress line with train/validation loss and accuracy.

    Args:
        r: index of the round just finished (0 for the initial evaluation).
        rounds: total number of rounds.
        displayInterval: number of iterations per round.
        trainLoss, trainAccuracy: metrics on the training data.
        valLoss, valAccuracy: metrics on the validation data.
        var: optional variance; appended as `` var=<x.xxe+xx>`` when given.
    """
    # Identity test: an equality test against None can misbehave for values
    # that overload ``==`` (e.g. arrays).
    varStr = '' if var is None else ' var={:.2e}'.format(var)
    log('[{}/{}](interval: {:.0f}) train: loss={:.4f} acc={:.2f} val: loss={:.4f} acc={:.2f}{}'
        .format(r, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy, varStr)
    )

## CentralSGD

In [None]:
def CentralSGD(model, gamma, aggregate, weight_decay, attack=None, 
          rounds=10, displayInterval=1000, 
          device='cpu', SEED=100, fixSeed=False, 
          batchSize=1,
          **kw):
    """Train ``model`` with plain (single-node) mini-batch SGD.

    Uses the module-level ``train_dataset``, ``validate_dataset``,
    ``loss_func``, ``calculateAccuracy`` and ``report``.

    Args:
        model: network to train in place.
        gamma: step size applied to the gradient.
        aggregate: unused; kept so all optimizers share one call signature.
        weight_decay: after each gradient step every parameter is shrunk by
            ``weight_decay * parameter`` (not scaled by ``gamma``).
        attack: unused; kept for signature compatibility.
        rounds: number of evaluation rounds.
        displayInterval: SGD steps per round.
        device: device the batches are moved to.
        SEED, fixSeed: when ``fixSeed`` is true, Python's ``random`` is
            seeded with ``SEED``.
        batchSize: batch size for both training and evaluation loaders.
        **kw: ignored extra options.
    Returns:
        ``(model, trainLossPath, trainAccPath, valLossPath, valAccPath, [])``;
        each path holds the initial value followed by one value per round.
    """
    if fixSeed:
        # NOTE(review): only Python's RNG is seeded here; RandomSampler draws
        # from torch's RNG -- confirm whether torch should be seeded as well.
        random.seed(SEED)

    # Sequential loaders used for evaluation
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)
    validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)

    # Random sampler over the whole training set (with replacement), sized
    # so that exactly rounds * displayInterval batches can be drawn.
    randomSampler = torch.utils.data.sampler.RandomSampler(
        train_dataset,
        num_samples=rounds*displayInterval*batchSize,
        replacement=True
    )
    randomLoader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batchSize,
        sampler=randomSampler,
    )
    randomIter = iter(randomLoader)

    # Initial metrics
    trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
    valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

    trainLossPath = [trainLoss]
    trainAccPath = [trainAccuracy]
    valLossPath = [valLoss]
    valAccPath = [valAccuracy]

    report(0, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)

    for r in range(rounds):
        model.train()
        for k in range(displayInterval):
            # Fetch a batch
            material, targets = next(randomIter)
            material, targets = material.to(device), targets.to(device)

            # Stochastic gradient: forward pass, then backpropagation
            outputs = model(material)
            loss = loss_func(outputs, targets)
            model.zero_grad()
            loss.backward()

            # Update: gradient step followed by the weight-decay shrink
            with torch.no_grad():
                for para in model.parameters():
                    para.add_(para.grad, alpha=-gamma)
                    para.add_(para, alpha=-weight_decay)

        model.eval()
        trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
        valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

        trainLossPath.append(trainLoss)
        trainAccPath.append(trainAccuracy)
        valLossPath.append(valLoss)
        valAccPath.append(valAccuracy)

        report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)
    return model, trainLossPath, trainAccPath, valLossPath, valAccPath, []

## Central SARAH

In [None]:
def CentralSARAH(model, gamma, aggregate, weight_decay, 
          snapshotInterval=len(train_dataset),
          rounds=10, displayInterval=1000, 
          device='cpu', SEED=100, fixSeed=False, 
          batchSize=5,
          **kw):
    """Train ``model`` with a single-node SARAH-style recursive gradient.

    A running gradient estimate is kept. At selected steps it is reset to
    the average gradient over the whole training loader; at every step it is
    corrected by the difference between the mini-batch gradients of the
    current and the previous model, and the model moves along it.

    Uses the module-level ``train_dataset``, ``validate_dataset``,
    ``modelFactory``, ``loss_func``, ``calculateAccuracy``, ``log`` and
    ``report``.

    Args:
        model: network to train in place.
        gamma: step size.
        aggregate: unused; kept so all optimizers share one call signature.
        weight_decay: coefficient of the ``weight_decay * parameter`` term
            added to every gradient.
        snapshotInterval: upper bound (exclusive) for the randomly drawn
            period that controls when the next full gradient is computed.
        rounds: number of evaluation rounds.
        displayInterval: iterations per round.
        device: device the models and batches are placed on.
        SEED, fixSeed: when ``fixSeed`` is true, Python's ``random`` is
            seeded with ``SEED``; ``SEED`` is also passed to ``modelFactory``.
        batchSize: batch size for all loaders.
        **kw: ignored extra options.
    Returns:
        ``(model, trainLossPath, trainAccPath, valLossPath, valAccPath, [])``;
        each path holds the initial value followed by one value per round.
    """
    if fixSeed:
        random.seed(SEED)

    # Model holding the previous iterate; its parameters are overwritten
    # with the current model's before they are first used.
    lastModel = modelFactory(SEED=SEED)
    lastModel = lastModel.to(device)

    # Random period for the full-gradient step; 1 forces one at step 0.
    randomStop = 1

    # Sequential loaders used for evaluation and for the full gradient
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)
    validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)
    numBatches = len(train_loader)

    # Random sampler (with replacement) for the stochastic steps
    randomSampler = torch.utils.data.sampler.RandomSampler(
        train_dataset,
        num_samples=rounds*displayInterval*batchSize,
        replacement=True
    )
    randomLoader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batchSize,
        sampler=randomSampler,
    )
    randomIter = iter(randomLoader)

    # Initial metrics
    trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
    valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

    trainLossPath = [trainLoss]
    trainAccPath = [trainAccuracy]
    valLossPath = [valLoss]
    valAccPath = [valAccuracy]

    log('[SARAH]初始 train: loss={:.6f} accuracy={:.2f} validation: loss={:.6f} accuracy={:.2f}'
        .format(trainLossPath[0], trainAccPath[0], valLossPath[0], valAccPath[0])
    )

    # Running gradient estimate, one tensor per model parameter
    gradients = [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]

    for r in range(rounds):
        for k in range(displayInterval):
            # Full-gradient step whenever the global step index is a
            # multiple of the current random period.
            if (r*displayInterval + k) % randomStop == 0:
                # Discard the old estimate
                for grad in gradients:
                    grad.zero_()
                for material, targets in train_loader:
                    material, targets = material.to(device), targets.to(device)
                    outputs = model(material)
                    loss = loss_func(outputs, targets)
                    model.zero_grad()
                    loss.backward()

                    # Accumulate the average of the per-batch gradients
                    for grad, para in zip(gradients, model.parameters()):
                        grad.add_(para.grad.data, alpha=1/numBatches)
                for grad, para in zip(gradients, model.parameters()):
                    grad.add_(para.data, alpha=weight_decay)

                # Remember the current iterate, then step
                for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):
                    oldPara.data.copy_(newPara.data)
                for para, grad in zip(model.parameters(), gradients):
                    para.data.add_(grad, alpha=-gamma)
                # Draw the next period; max() keeps randint's range valid
                # when snapshotInterval <= 1.
                randomStop = random.randint(1, max(1, snapshotInterval-1))

            # Stochastic step: fetch a batch
            material, targets = next(randomIter)
            material, targets = material.to(device), targets.to(device)

            # Mini-batch gradient at the current iterate
            outputs = model(material)
            loss = loss_func(outputs, targets)
            model.zero_grad()
            loss.backward()

            # Mini-batch gradient at the previous iterate (same batch)
            outputs = lastModel(material)
            loss = loss_func(outputs, targets)
            lastModel.zero_grad()
            loss.backward()

            # estimate += grad(current) - grad(previous), each including
            # its weight-decay term
            for pi, para in enumerate(model.parameters()):
                gradients[pi].add_(para.grad.data)
                gradients[pi].add_(para.data, alpha=weight_decay)
            for pi, para in enumerate(lastModel.parameters()):
                gradients[pi].sub_(para.grad.data)
                gradients[pi].sub_(para.data, alpha=weight_decay)

            # Remember the current iterate, then step
            for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):
                oldPara.data.copy_(newPara.data)
            for para, grad in zip(model.parameters(), gradients):
                para.data.add_(grad, alpha=-gamma)

        trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
        valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

        trainLossPath.append(trainLoss)
        trainAccPath.append(trainAccuracy)
        valLossPath.append(valLoss)
        valAccPath.append(valAccuracy)

        report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)
    return model, trainLossPath, trainAccPath, valLossPath, valAccPath, []

## SGD

In [None]:
def SGD(model, gamma, aggregate, weight_decay, 
          honestSize=0, byzantineSize=0, attack=None, 
          rounds=10, displayInterval=1000, 
          device='cpu', SEED=100, fixSeed=False, 
          batchSize=5,
          **kw):
    """Distributed mini-batch SGD with a pluggable aggregation rule.

    The training set is split into ``honestSize`` contiguous shards. In each
    iteration every honest node computes a mini-batch gradient (plus the
    weight-decay term) on its own shard; the messages are flattened,
    optionally modified by ``attack``, combined by ``aggregate`` and applied
    to the shared model.

    Uses the module-level ``train_dataset``, ``validate_dataset``,
    ``loss_func``, ``flatten_list``, ``unflatten_vector``, ``getVarience``,
    ``calculateAccuracy`` and ``report``.

    Args:
        model: network to train in place.
        gamma: step size.
        aggregate: callable mapping the 2-D message tensor to one vector.
        weight_decay: coefficient of the ``weight_decay * parameter`` term
            added to each node's gradient.
        honestSize: number of honest nodes (must be non-zero).
        byzantineSize: number of Byzantine nodes; requires ``attack``.
        attack: optional callable ``attack(messages, byzantineSize)`` invoked
            on the flattened messages before aggregation.
        rounds: number of evaluation rounds.
        displayInterval: iterations per round.
        device: device the batches are moved to.
        SEED, fixSeed: when ``fixSeed`` is true, Python's ``random`` is
            seeded with ``SEED``.
        batchSize: batch size for all loaders.
        **kw: ignored extra options.
    Returns:
        ``(model, trainLossPath, trainAccPath, valLossPath, valAccPath,
        variencePath)``; the metric paths hold the initial value followed by
        one value per round, ``variencePath`` holds one value per round.
    """
    assert byzantineSize == 0 or attack is not None
    assert honestSize != 0

    if fixSeed:
        random.seed(SEED)

    nodeSize = honestSize + byzantineSize

    # Shard boundaries: node i owns samples pieces[i] .. pieces[i+1]-1
    pieces = [(i*len(train_dataset)) // honestSize for i in range(honestSize+1)]

    # One message (list of tensors shaped like the parameters) per node
    message = [
        [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]
        for _ in range(nodeSize)]

    # Sequential loaders used for evaluation
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)
    validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)

    train_dataset_subset = [torch.utils.data.Subset(train_dataset, range(pieces[i], pieces[i+1])) for i in range(honestSize)]

    # Per-shard random samplers (with replacement), each sized for the
    # whole run.
    randomSampler = lambda dataset: torch.utils.data.sampler.RandomSampler(
        dataset,
        num_samples=rounds*displayInterval*batchSize,
        replacement=True
    )
    train_random_loaders_splited = [torch.utils.data.DataLoader(
        dataset=subset,
        batch_size=batchSize,
        sampler=randomSampler(subset),
    ) for subset in train_dataset_subset]

    randomIters = [iter(loader) for loader in train_random_loaders_splited]

    # Initial metrics
    trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
    valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

    trainLossPath = [trainLoss]
    trainAccPath = [trainAccuracy]
    valLossPath = [valLoss]
    valAccPath = [valAccuracy]
    variencePath = []

    report(0, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)

    for r in range(rounds):
        for k in range(displayInterval):
            # Honest nodes compute their messages
            for node in range(honestSize):
                # Fetch a batch from this node's shard and move it to the
                # target device, as the evaluation code does.
                material, targets = next(randomIters[node])
                material, targets = material.to(device), targets.to(device)

                # Stochastic gradient: forward pass, then backpropagation
                outputs = model(material)
                loss = loss_func(outputs, targets)
                model.zero_grad()
                loss.backward()

                # message = gradient + weight_decay * parameter
                for pi, para in enumerate(model.parameters()):
                    message[node][pi].copy_(para.grad.data)
                    message[node][pi].add_(para.data, alpha=weight_decay)

            # Synchronize; Byzantine attack on the flattened messages
            message_f = flatten_list(message, byzantineSize)
            if attack is not None:
                attack(message_f, byzantineSize)
            # Aggregate
            g_vector = aggregate(message_f)
            # Reshape back to the parameter shapes
            g = unflatten_vector(g_vector, model)
            # Update
            for para, grad in zip(model.parameters(), g):
                para.data.add_(grad, alpha=-gamma)

        # Variance of the honest messages from the round's last iteration
        var = getVarience(message_f, honestSize)
        variencePath.append(var)

        trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
        valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

        trainLossPath.append(trainLoss)
        trainAccPath.append(trainAccuracy)
        valLossPath.append(valLoss)
        valAccPath.append(valAccuracy)

        report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)
    return model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath

## SAGA

In [None]:
# Initialize the local models
def initModel(local_models, honestSize):
    """Copy the state of ``local_models[0]`` into models 1 .. honestSize-1.

    Models at index ``honestSize`` and beyond are left untouched.
    """
    reference_state = local_models[0].state_dict()
    followers = local_models[1:honestSize]
    for follower in followers:
        follower.load_state_dict(reference_state)

# Broadcast
def broadcastPara(newPara, local_models):
    """Load the flat vector ``newPara`` into every model of ``local_models``.

    The vector is consumed in the order of ``local_models[0].parameters()``
    and written into that model; its state dict is then loaded into all the
    remaining models.
    """
    leader, followers = local_models[0], local_models[1:]
    start = 0
    for param in leader.parameters():
        stop = start + param.numel()
        param.data.copy_(newPara[start:stop].view_as(param))
        start = stop
    shared_state = leader.state_dict()
    for follower in followers:
        follower.load_state_dict(shared_state)

In [None]:
def SAGA(model, gamma, aggregate, weight_decay, 
          honestSize=0, byzantineSize=0, attack=None, 
          rounds=10, displayInterval=1000, 
          device='cpu', SEED=100, fixSeed=False, 
          batchSize=1,
          **kw):
    """Distributed SAGA with a pluggable aggregation rule.

    A table holds the most recent gradient of every training sample, and
    each honest node keeps the average of the table entries of its own
    shard. A node's message for a freshly sampled index is
    ``new gradient - stored gradient + shard average``.

    Uses the module-level ``train_dataset``, ``validate_dataset``,
    ``loss_func``, ``flatten_list``, ``unflatten_vector``, ``getVarience``,
    ``calculateAccuracy`` and ``report``.

    Args:
        model: network to train in place.
        gamma: step size.
        aggregate: callable mapping the 2-D message tensor to one vector.
        weight_decay: coefficient of the ``weight_decay * parameter`` term
            added to each freshly computed gradient.
        honestSize: number of honest nodes (must be non-zero).
        byzantineSize: number of Byzantine nodes; requires ``attack``.
        attack: optional callable ``attack(messages, byzantineSize)`` invoked
            on the flattened messages before aggregation.
        rounds: number of evaluation rounds.
        displayInterval: iterations per round.
        device: device each sample is moved to before the forward pass.
        SEED, fixSeed: when ``fixSeed`` is true, Python's ``random`` is
            seeded with ``SEED``.
        batchSize: batch size of the evaluation loaders only; training uses
            one sample at a time.
        **kw: ignored extra options.
    Returns:
        ``(model, trainLossPath, trainAccPath, valLossPath, valAccPath,
        variencePath)``; the metric paths hold the initial value followed by
        one value per round, ``variencePath`` holds one value per round.
    """
    assert byzantineSize == 0 or attack is not None
    assert honestSize != 0

    if fixSeed:
        random.seed(SEED)

    nodeSize = honestSize + byzantineSize

    # Shard boundaries: node i owns samples pieces[i] .. pieces[i+1]-1
    pieces = [(i*len(train_dataset)) // honestSize for i in range(honestSize+1)]
    dataPerNode = [pieces[i+1] - pieces[i] for i in range(honestSize)]

    # Sequential loaders used for evaluation
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)
    validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)

    # Initial metrics
    trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
    valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

    trainLossPath = [trainLoss]
    trainAccPath = [trainAccuracy]
    valLossPath = [valLoss]
    valAccPath = [valAccuracy]
    variencePath = []

    report(0, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)

    # Gradient table: one list of per-parameter gradients for every training
    # sample (memory grows with dataset size times model size).
    # NOTE(review): these initial entries omit the weight_decay term that
    # later entries include -- confirm this is intended.
    store = []
    for material, targets in train_dataset:
        material = material.to(device)
        targets = torch.tensor([targets], device=device)
        outputs = model(material)
        loss = loss_func(outputs, targets)

        model.zero_grad()
        loss.backward()

        store.append([p.grad.clone().detach() for p in model.parameters()])

    # G_avg[i] is the per-parameter average of the table entries on node i
    G_avg = []
    for i in range(honestSize):
        storeInThisNode = store[pieces[i]: pieces[i+1]]
        # Transpose: one tuple per parameter, holding that parameter's
        # gradient for every sample on this node.
        paras = zip(*storeInThisNode)
        G_avg.append([sum(para)/dataPerNode[i] for para in paras])

    # One message (list of tensors shaped like the parameters) per node
    message = [
        [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]
        for _ in range(nodeSize)
    ]

    for r in range(rounds):
        for k in range(displayInterval):
            # Honest nodes update
            for node in range(honestSize):
                # Pick one sample from this node's shard
                index = random.randint(pieces[node], pieces[node+1]-1)
                material, targets = train_dataset[index]
                material = material.to(device)
                targets = torch.tensor([targets], device=device)

                outputs = model(material)
                loss = loss_func(outputs, targets)

                model.zero_grad()
                loss.backward()

                # Refresh the table entry, the message and the shard average
                for pi, para in enumerate(model.parameters()):
                    old_G = store[index][pi]
                    new_G = para.grad.data.clone()
                    new_G.add_(para.data, alpha=weight_decay)

                    # Computed before G_avg is updated below
                    message[node][pi] = new_G - old_G + G_avg[node][pi]

                    G_avg[node][pi].add_(new_G - old_G, alpha=1 / dataPerNode[node])

                    store[index][pi] = new_G

                # NOTE(review): attack, aggregation and the model update run
                # once per honest node (inside the node loop), using the other
                # nodes' latest stored messages -- confirm this is intended
                # rather than one update per iteration.
                message_f = flatten_list(message, byzantineSize)
                if attack is not None:
                    attack(message_f, byzantineSize)
                # Aggregate
                g_vector = aggregate(message_f)
                # Reshape back to the parameter shapes
                g = unflatten_vector(g_vector, model)
                # Update
                for para, grad in zip(model.parameters(), g):
                    para.data.add_(grad, alpha=-gamma)

        # Variance of the honest messages from the round's last update
        var = getVarience(message_f, honestSize)
        variencePath.append(var)

        trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)
        valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)

        trainLossPath.append(trainLoss)
        trainAccPath.append(trainAccuracy)
        valLossPath.append(valLoss)
        valAccPath.append(valAccuracy)

        report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)
    return model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath

## SVRG

In [None]:
def SVRG(w0, gamma, aggregate, weight_decay, honestSize=0, byzantineSize=0, attack=None, 
            snapshotInterval=6000, rounds=10, displayInterval=1000, SEED=100, fixSeed=False, **kw):
    """Distributed SVRG on a flat weight vector with a pluggable aggregation rule.

    NOTE(review): this function reads ``dataset``, ``F``, ``accuracy`` and
    ``LogisticRegression``, none of which is defined in the visible part of
    this file; unlike the other optimizers it works on a weight vector
    rather than a torch model. Presumably it belongs to a logistic-regression
    variant of the experiment -- confirm before use.

    Args:
        w0: initial weight vector; the last entry is updated separately from
            the others (presumably a bias term -- TODO confirm).
        gamma: step size.
        aggregate: callable mapping the 2-D message tensor to one vector.
        weight_decay: coefficient of the ``weight_decay * w`` term added to
            each gradient.
        honestSize: number of honest nodes (must be non-zero).
        byzantineSize: number of Byzantine nodes; requires ``attack``.
        attack: optional callable ``attack(message, byzantineSize)`` invoked
            on the messages before aggregation.
        snapshotInterval: a snapshot is taken whenever the global iteration
            index is a multiple of this value.
        rounds: number of reporting rounds.
        displayInterval: iterations per round.
        SEED, fixSeed: when ``fixSeed`` is true, Python's ``random`` is
            seeded with ``SEED``.
        **kw: ignored extra options.
    Returns:
        ``(w, path, variencePath)``: the final weights, the loss after
        initialization and after each round, and the per-round variance.
    """
    assert byzantineSize == 0 or attack != None
    assert honestSize != 0
    
    if fixSeed:
        random.seed(SEED)

    nodeSize = honestSize + byzantineSize
    
    # Initialization: work on a detached copy so w0 is not modified
    w = w0.clone().detach()

    # Shard boundaries: node i owns samples pieces[i] .. pieces[i+1]-1
    pieces = [(i*len(dataset)) // honestSize for i in range(honestSize+1)]
    dataPerNode = [pieces[i+1] - pieces[i] for i in range(honestSize)]

    # Per-node full gradient and the weights at the last snapshot
    snapshot_g = torch.zeros(honestSize, len(w0), dtype=torch.float64)
    snapshot_w = torch.zeros(len(w0), dtype=torch.float64)

    path = [F(w, dataset, weight_decay)]
    variencePath = []
    log('[SVRG]初始 loss={:.6f}, accuracy={:.2f} gamma={:}'.format(path[0], accuracy(w, dataset), gamma))
    
    # Pre-allocate the message buffer, one row per node
    message = torch.zeros(nodeSize, len(w0), dtype=torch.float64)

    log('开始迭代')
    for r in range(rounds):
        for k in range(displayInterval):
            # snapshot
            if (r*displayInterval + k) % snapshotInterval == 0:
                snapshot_g.zero_()
                for node in range(honestSize):
                    for index in range(pieces[node], pieces[node+1]):
                        x, y = dataset[index]
                        # Accumulate this sample into the node's average gradient
                        predict = LogisticRegression(w, x)

                        err = (predict-y).data
                        snapshot_g[node][:-1].add_(1/dataPerNode[node], err*x)
                        snapshot_g[node][-1].add_(1/dataPerNode[node], err)
                    snapshot_g[node].add_(weight_decay, w)
                snapshot_w.copy_(w)
            
            # Honest nodes update
            message.zero_()
            for node in range(honestSize):
                index = random.randint(pieces[node], pieces[node+1]-1)

                x, y = dataset[index]
                # Stochastic gradient at the current weights
                predict = LogisticRegression(w, x)
                err = (predict-y).data
                message[node][:-1].add_(err, x)
                message[node][-1].add_(err, 1)
                message[node].add_(weight_decay, w)
                
                # Correction: subtract the same sample's gradient at the snapshot weights
                predict = LogisticRegression(snapshot_w, x)
                err = (predict-y).data
                message[node][:-1].add_(-err, x)
                message[node][-1].add_(-err, 1)
                message[node].add_(-weight_decay, snapshot_w)
                
                # ... and add the node's full gradient from the snapshot
                message[node].add_(1, snapshot_g[node])
                
            # Synchronize
            # Byzantine attack
            if attack != None:
                attack(message, byzantineSize)
            g = aggregate(message)
            w.add_(-gamma, g)
            
        loss = F(w, dataset, weight_decay)
        acc = accuracy(w, dataset)
        path.append(loss)
        # Variance of the honest messages from the round's last iteration
        var = getVarience(message, honestSize)
        variencePath.append(var)
        log('[SVRG]已迭代 {}/{} rounds (interval: {:.0f}), loss={:.9f}, accuracy={:.2f}, var={:.9f}'.format(
            r+1, rounds, displayInterval, loss, acc, var
        ))
    return w, path, variencePath

## SARAH

In [None]:
def SARAH(model, gamma, aggregate, weight_decay,
          snapshotInterval=len(train_dataset),
          honestSize=0, byzantineSize=0, attack=None,
          rounds=10, displayInterval=1000,
          device='cpu', SEED=100, fixSeed=False,
          batchSize=5,
          **kw):
    """Run distributed SARAH with optional Byzantine workers.

    Every honest node owns a contiguous slice of ``train_dataset`` and keeps a
    recursive gradient estimate in ``message[node]`` (one tensor per model
    parameter).  At randomly spaced restart steps the estimate is rebuilt from
    a pass over the node's whole slice; at every step it is corrected with the
    difference between the mini-batch gradients of the current model and of
    the previous model.  The per-node estimates are flattened, optionally
    attacked, aggregated and applied to ``model`` with step size ``gamma``.

    Args:
        model: network to optimise; updated in place and returned.
        gamma: step size.
        aggregate: callable applied to the flattened messages; its result is
            unflattened and used as the update direction.
        weight_decay: coefficient of the parameter term added to the gradient
            estimates.
        snapshotInterval: the next restart period is drawn with
            ``random.randint(1, snapshotInterval - 1)``.
        honestSize: number of honest nodes (must be non-zero).
        byzantineSize: number of Byzantine nodes; requires ``attack`` when
            non-zero.
        attack: ``None`` or a callable ``attack(message_f, byzantineSize)``
            invoked on the flattened messages before aggregation.
        rounds: number of outer rounds; metrics are recorded once per round.
        displayInterval: number of iterations per round.
        device: only used here to decide which torch seeding calls are made.
        SEED: seed for torch (and for ``random`` when ``fixSeed`` is true).
        fixSeed: whether to seed Python's ``random`` module.
        batchSize: mini-batch size of every data loader.
        **kw: extra configuration entries, ignored.

    Returns:
        ``(model, trainLossPath, trainAccPath, valLossPath, valAccPath,
        variencePath)``.
    """
    assert byzantineSize == 0 or attack is not None
    assert honestSize != 0

    if fixSeed:
        random.seed(SEED)

    nodeSize = honestSize + byzantineSize

    # Copy of the model that holds the previous iterate.
    lastModel = modelFactory(SEED=SEED)

    if device == 'cpu':
        torch.manual_seed(SEED)  # seed the CPU generator
    else:
        # The original referenced an undefined lowercase `seed` here.
        torch.cuda.manual_seed(SEED)      # seed the current GPU
        torch.cuda.manual_seed_all(SEED)  # seed every GPU

    # Split the training set into `honestSize` contiguous index ranges.
    pieces = [(i*len(train_dataset)) // honestSize for i in range(honestSize+1)]

    # Restart period; 1 forces a full-gradient restart on the first iteration.
    randomStop = 1
    # Per-node messages: one zero tensor per model parameter.
    message = [
        [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]
        for _ in range(nodeSize)
    ]

    # Sequential loaders used for evaluation.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)
    validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)

    # Sequential loaders over each node's slice, used for the full-gradient restart.
    train_dataset_subset = [torch.utils.data.Subset(train_dataset, range(pieces[i], pieces[i+1])) for i in range(honestSize)]
    train_loaders_splited = [
        torch.utils.data.DataLoader(dataset=subset, batch_size=batchSize, shuffle=False)
        for subset in train_dataset_subset
    ]

    # Random samplers with replacement: enough samples for one mini-batch per
    # node per iteration over the whole run.
    randomSampler = lambda dataset: torch.utils.data.sampler.RandomSampler(
        dataset,
        num_samples=rounds*displayInterval*batchSize,
        replacement=True
    )
    train_random_loaders_splited = [torch.utils.data.DataLoader(
        dataset=subset,
        batch_size=batchSize,
        sampler=randomSampler(subset),
    ) for subset in train_dataset_subset]
    randomIters = [iter(loader) for loader in train_random_loaders_splited]

    # Metrics of the initial model.
    trainLoss, trainAccuracy = calculateAccuracy(model, train_loader)
    valLoss, valAccuracy = calculateAccuracy(model, validate_loader)

    trainLossPath = [trainLoss]
    trainAccPath = [trainAccuracy]
    valLossPath = [valLoss]
    valAccPath = [valAccuracy]
    variencePath = []

    log('[SARAH]初始 train: loss={:.6f} accuracy={:.2f} validation: loss={:.6f} accuracy={:.2f}'
        .format(trainLossPath[0], trainAccPath[0], valLossPath[0], valAccPath[0])
    )

    for r in range(rounds):
        for k in range(displayInterval):
            # Restart: rebuild every honest node's estimate from its full slice.
            if (r*displayInterval + k) % randomStop == 0:
                for node in range(honestSize):
                    # Discard the old estimate.
                    for grad in message[node]:
                        grad.zero_()
                    loader = train_loaders_splited[node]
                    for material, targets in loader:
                        # Forward pass.
                        outputs = model(material)
                        loss = loss_func(outputs, targets)
                        # Backward pass.
                        model.zero_grad()
                        loss.backward()

                        # Average the batch gradients over the node's loader.
                        for grad, para in zip(message[node], model.parameters()):
                            grad.data.add_(para.grad.data, alpha=1/len(loader))
                    for grad, para in zip(message[node], model.parameters()):
                        grad.data.add_(para.data, alpha=weight_decay)

                # Remember the current iterate before it is updated.
                for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):
                    oldPara.data.copy_(newPara)
                # Synchronise and apply the Byzantine attack.
                message_f = flatten_list(message, byzantineSize)
                if attack is not None:
                    attack(message_f, byzantineSize)
                # Aggregate.
                g_vector = aggregate(message_f)
                # Unflatten back to per-parameter tensors.
                g = unflatten_vector(g_vector, model)
                # Apply the update.
                for para, grad in zip(model.parameters(), g):
                    para.data.add_(grad, alpha=-gamma)
                # Draw the period used by the next restart check.
                randomStop = random.randint(1, snapshotInterval-1)

            # Recursive update on every honest node.
            for node in range(honestSize):
                # Draw one random mini-batch from the node's slice.
                material, targets = next(randomIters[node])

                # Stochastic gradient at the current model.
                outputs = model(material)
                loss = loss_func(outputs, targets)
                model.zero_grad()
                loss.backward()

                # Gradient of the same mini-batch at the previous model.
                outputs = lastModel(material)
                loss = loss_func(outputs, targets)
                lastModel.zero_grad()
                loss.backward()

                # estimate += grad(current) - grad(previous), including the
                # weight-decay terms of both iterates.
                for pi, para in enumerate(model.parameters()):
                    message[node][pi].data.add_(para.grad.data)
                    message[node][pi].data.add_(para.data, alpha=weight_decay)
                for pi, para in enumerate(lastModel.parameters()):
                    message[node][pi].data.sub_(para.grad.data)
                    message[node][pi].data.sub_(para.data, alpha=weight_decay)

            # Synchronise and apply the Byzantine attack.
            message_f = flatten_list(message, byzantineSize)
            if attack is not None:
                attack(message_f, byzantineSize)
            # Aggregate.
            g_vector = aggregate(message_f)
            # Unflatten back to per-parameter tensors.
            g = unflatten_vector(g_vector, model)
            # Remember the current iterate before it is updated.
            for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):
                oldPara.data.copy_(newPara)
            # Apply the update.
            for para, grad in zip(model.parameters(), g):
                para.data.add_(grad, alpha=-gamma)

        # Record the statistic returned by getVarience for the last messages
        # of this round.
        var = getVarience(message_f, honestSize)
        variencePath.append(var)

        trainLoss, trainAccuracy = calculateAccuracy(model, train_loader)
        valLoss, valAccuracy = calculateAccuracy(model, validate_loader)

        trainLossPath.append(trainLoss)
        trainAccPath.append(trainAccuracy)
        valLossPath.append(valLoss)
        valAccPath.append(valAccuracy)

        report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)
    return model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath

# 恶意攻击

In [None]:
def white(messages, byzantinesize):
    """Gaussian-noise attack applied in place to the Byzantine rows.

    The last ``byzantinesize`` rows of ``messages`` are overwritten with the
    column-wise mean of the remaining (honest) rows plus zero-mean Gaussian
    noise with standard deviation 30 (i.e. variance 900).

    Args:
        messages: 2-D tensor with one row per node; modified in place.
        byzantinesize: number of trailing rows controlled by the attacker.
            Values ``<= 0`` leave ``messages`` untouched.
    """
    if byzantinesize <= 0:
        # messages[-0:] would select *every* row and corrupt the honest ones.
        return
    mu = torch.mean(messages[0:-byzantinesize], dim=0)
    messages[-byzantinesize:].copy_(mu)
    # Create the noise on the same device as the messages so the in-place add
    # also works for GPU tensors.
    noise = torch.randn(
        (byzantinesize, messages.size(1)),
        dtype=torch.float64,
        device=messages.device,
    )
    messages[-byzantinesize:].add_(noise, alpha=30)
    
def maxValue(messages, byzantinesize):
    """Overwrite the Byzantine rows with -10 times the honest mean, in place.

    Args:
        messages: tensor with one row per node; the last ``byzantinesize``
            rows are replaced.
        byzantinesize: number of trailing rows controlled by the attacker.
            Values ``<= 0`` leave ``messages`` untouched.
    """
    if byzantinesize <= 0:
        # messages[-0:] would select *every* row and corrupt the honest ones.
        return
    mu = torch.mean(messages[0:-byzantinesize], dim=0)
    maliciousMessage = -10*mu
    messages[-byzantinesize:].copy_(maliciousMessage)
    
def zeroGradient(messages, byzantinesize):
    """Make the rows of ``messages`` sum to zero by rewriting the Byzantine rows.

    Each of the last ``byzantinesize`` rows is set, in place, to the negated
    sum of the honest rows divided by ``byzantinesize``.

    Args:
        messages: tensor with one row per node; modified in place.
        byzantinesize: number of trailing rows controlled by the attacker.
            Values ``<= 0`` leave ``messages`` untouched (this also avoids a
            division by zero).
    """
    if byzantinesize <= 0:
        # messages[-0:] would select *every* row and corrupt the honest ones.
        return
    s = torch.sum(messages[0:-byzantinesize], dim=0)
    messages[-byzantinesize:].copy_(-s / byzantinesize)

# 训练函数

In [None]:
def train(model, loss_func, optimizer, trainloader, device, weight_decay):
    """Train ``model`` for one pass over ``trainloader``.

    Args:
        model: network to train; switched to training mode.
        loss_func: callable ``loss_func(outputs, targets)`` returning a scalar
            tensor.
        optimizer: torch optimizer that steps the model parameters.
        trainloader: iterable of batches; the last element of each batch is
            the target tensor and all preceding elements are model inputs.
        device: device the inputs and targets are moved to.
        weight_decay: decoupled weight-decay coefficient; after every
            optimizer step each parameter ``p`` becomes
            ``p - weight_decay * lr * p``.

    Returns:
        ``(trainLoss, trainAccuracy)``: the sum of the per-batch loss values
        and the number of correct predictions, each divided by the number of
        samples seen.
    """
    model.train()

    trainAccuracy = 0
    trainLoss = 0
    total = 0

    for *material, targets in trainloader:
        # Starred unpacking always yields a list, so every input is moved
        # element by element.
        material = [m.to(device) for m in material]
        targets = targets.to(device)

        # forward
        outputs = model(*material)
        loss = loss_func(outputs, targets)
        trainLoss += loss.item()

        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Decoupled (AdamW-style) weight decay -
        # https://zhuanlan.zhihu.com/p/38945390
        for group in optimizer.param_groups:
            decay = weight_decay * group['lr']
            for param in group['params']:
                param.data.add_(param.data, alpha=-decay)

        # torch.max along dim=1 returns (values, indices); the indices are the
        # predicted classes.
        _, predicted = torch.max(outputs.data, dim=1)
        trainAccuracy += (predicted == targets).sum().item()

        total += len(targets)
    trainAccuracy /= total
    trainLoss /= total
    return trainLoss, trainAccuracy

In [None]:
def validate(model, loss_func, validateloader, device):
    """Evaluate ``model`` on ``validateloader`` without computing gradients.

    Args:
        model: network to evaluate; switched to evaluation mode.
        loss_func: callable ``loss_func(outputs, targets)`` returning a scalar
            tensor.
        validateloader: iterable of batches; the last element of each batch is
            the target tensor and all preceding elements are model inputs.
        device: device the inputs and targets are moved to.

    Returns:
        ``(valLoss, valAccuracy)``: the sum of the per-batch loss values and
        the number of correct predictions, each divided by the number of
        samples seen.
    """
    model.eval()

    valAccuracy = 0
    valLoss = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        # The original looped over `trainloader`, which is not a parameter of
        # this function.
        for *material, targets in validateloader:
            material = [m.to(device) for m in material]
            targets = targets.to(device)

            outputs = model(*material)

            loss = loss_func(outputs, targets)
            valLoss += loss.item()

            # torch.max along dim=1 returns (values, indices); the indices are
            # the predicted classes.
            _, predicted = torch.max(outputs.data, dim=1)
            valAccuracy += (predicted == targets).sum().item()

            total += len(targets)
        valAccuracy /= total
        valLoss /= total
    return valLoss, valAccuracy

In [None]:
def test(model, testloader, classname=None, name='default', device=None):
    """Predict a label for every sample of ``testloader`` and save them as CSV.

    Args:
        model: network used for prediction; switched to evaluation mode.
        testloader: iterable of batches; the last element of each batch is
            only used to count samples, all preceding elements are model
            inputs.
        classname: optional sequence mapping a predicted class index to the
            value written to the file.
        name: output file stem; predictions go to ``'<name>.csv'``.
        device: device the inputs are moved to.  Defaults to the device of the
            model's first parameter (CPU for a parameter-free model), so the
            function no longer depends on a global ``device``.

    The CSV has the columns ``id`` (1-based position) and ``polarity``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    result = []
    test_cnt = 0
    # Gradients are not needed for prediction.
    with torch.no_grad():
        for *material, targets in testloader:
            # Starred unpacking always yields a list of inputs.
            material = [m.to(device) for m in material]

            outputs = model(*material)

            _, predicted = torch.max(outputs.data, dim=1)

            # Store plain Python ints rather than 0-dim tensors.
            result.extend(predicted.tolist())
            test_cnt += len(targets)

    if classname is not None:
        result = [classname[i] for i in result]

    # `log` and `pd` (pandas) are expected to be defined elsewhere in the notebook.
    log('共预测{}个数据'.format(test_cnt))
    df_predict = pd.DataFrame({'id': list(range(1, len(result)+1)), 'polarity': result})
    df_predict.to_csv('{}.csv'.format(name), index=False)
    log('预测完成')
    

In [None]:
def showCurve(list_trainLoss, list_trainAccuracy, list_valLoss, list_valAccuracy):
    """Plot training/validation loss and accuracy side by side.

    The left panel shows the two loss curves and the right panel the two
    accuracy curves, both indexed by position in ``list_trainLoss``.
    """
    epochs = list(range(len(list_trainLoss)))
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2)

    panels = (
        (loss_ax, 'Loss', list_trainLoss, list_valLoss),
        (acc_ax, 'Accuracy', list_trainAccuracy, list_valAccuracy),
    )
    for ax, title, train_curve, val_curve in panels:
        ax.plot(epochs, train_curve, label='train')
        ax.plot(epochs, val_curve, label='validation')
        ax.set_title(title)

    for ax in (loss_ax, acc_ax):
        ax.axis()
        ax.set_xlabel('epoch')
        # The y-axis label repeats the panel title.
        ax.set_ylabel('{}'.format(ax.get_title()))
        ax.legend()

    fig.set_size_inches((8, 4))
    plt.subplots_adjust(wspace=0.3)
    plt.show()

# 运行函数

In [None]:
def run(optimizer, aggregate, attack, config, device='cpu'):
    """Run one experiment, cache its curves on disk and plot them.

    Args:
        optimizer: optimisation routine called as
            ``optimizer(model, device=device, **config)``; the last five items
            of its return value must be the train-loss, train-accuracy,
            validation-loss, validation-accuracy and variance paths.
        aggregate: aggregation rule, forwarded as ``config['aggregate']``.
        attack: Byzantine attack or ``None``; with ``None`` the number of
            Byzantine nodes is forced to 0.
        config: experiment settings; copied, never modified.
        device: device the freshly built model is moved to.

    Side effects:
        Prints the run description, pickles the record to
        ``CACHE_DIR + title`` and shows a loss/accuracy figure.
    """
    # Work on a copy so the caller's configuration stays untouched.
    _config = config.copy()
    _config['aggregate'] = aggregate
    _config['attack'] = attack
    if attack is None:
        _config['byzantineSize'] = 0

    model = modelFactory(SEED=_config['SEED'])
    model = model.to(device)

    # Build the run title, e.g. Resnet50_SARAH(5)_baseline_mean
    attackName = 'baseline' if attack is None else attack.__name__
    title = '{}_{}({})_{}_{}'.format(
        model.__class__.__name__,
        optimizer.__name__,
        _config['batchSize'],
        attackName,
        aggregate.__name__
    )

    # Print the run description.
    print('[提交任务] ' + title)
    print('[运行信息]')
    print('[网络属性]   name={} parameters number={}'.format(model.__class__.__name__, getPara(model)))
    print('[优化方法]   name={} aggregation={} attack={}'.format(optimizer.__name__, aggregate.__name__, attackName))
    print('[数据集属性] name={} trainSize={} validationSize={}'.format(dataSetConfig['name'], len(train_dataset), len(validate_dataset)))
    print('[优化器设置] gamma={} weight_decay={} batchSize={}'.format(_config['gamma'], _config['weight_decay'], _config['batchSize']))
    print('[节点个数]   honestSize={}, byzantineSize={}'.format(_config['honestSize'], _config['byzantineSize']))
    print('[运行次数]   rounds={}, displayInterval={}'.format(_config['rounds'], _config['displayInterval']))
    print('[torch设置]  device={}, SEED={}, fixSeed={}'.format(device, _config['SEED'], _config['fixSeed']))
    print('-------------------------------------------')

    # Run the optimiser.
    log('优化开始')
    res = optimizer(model, device=device, **_config)
    # Only the trailing five paths are needed; anything before them (the
    # trained model) is discarded.
    *_, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath = res

    # Callables (aggregate, attack) are stored by name so the record pickles
    # without references to notebook functions.
    record = {
        **dataSetConfig,
        **{
            key: (value.__name__ if callable(value) else value)
            for key, value in _config.items()
        },
        'trainLossPath': trainLossPath,
        'trainAccPath': trainAccPath,
        'valLossPath': valLossPath,
        'valAccPath': valAccPath,
        'variencePath': variencePath,
    }

    with open(CACHE_DIR + title, 'wb') as f:
        pickle.dump(record, f)

    # Left panel: losses; right panel: accuracies.
    _, axis = plt.subplots(1, 2)
    axis[0].plot(list(range(len(trainLossPath))), trainLossPath, label='train loss')
    axis[0].plot(list(range(len(valLossPath))), valLossPath, label='validation loss')
    axis[1].plot(list(range(len(trainAccPath))), trainAccPath, label='train accuracy')
    axis[1].plot(list(range(len(valAccPath))), valAccPath, label='validation accuracy')
    for ax in axis:
        ax.legend()
    plt.show()

# 测试

## 中心式SGD调参

In [None]:
# Centralised SGD tuning run: SGDConfig with a larger step size, more rounds
# and a bigger mini-batch.
_config = {**SGDConfig, 'gamma': 5e-1, 'rounds': 50, 'batchSize': 20}
run(optimizer=CentralSGD, aggregate=mean, attack=None, config=_config, device=device)

## 中心式SARAH调参

In [None]:
# Centralised SARAH tuning run: SARAHConfig with overridden batch size, step
# size, iterations per round and number of rounds.
_config = {
    **SARAHConfig,
    'batchSize': 20,
    'gamma': 1e-4,
    'displayInterval': 100000,
    'rounds': 30,
}
run(optimizer=CentralSARAH, aggregate=mean, attack=None, config=_config, device=device)

## SGD

### SGD - mean

In [None]:
# Baseline: SGD with mean aggregation and no attack.
run(optimizer=SGD, aggregate=mean, attack=None, config=SGDConfig)

white

In [None]:
# SGD with mean aggregation under the white (Gaussian noise) attack.
run(optimizer=SGD, aggregate=mean, attack=white, config=SGDConfig)

max

In [None]:
# SGD with mean aggregation under the maxValue attack.
run(optimizer=SGD, aggregate=mean, attack=maxValue, config=SGDConfig)

zero Gradient

In [None]:
# SGD with mean aggregation under the zeroGradient attack.
run(optimizer=SGD, aggregate=mean, attack=zeroGradient, config=SGDConfig)

### SGD - geometric median

In [None]:
# Baseline: SGD with geometric-median aggregation and no attack.
# This cell previously ran zeroGradient, duplicating the last cell of this
# section, while every other section starts with the attack-free baseline.
run(optimizer=SGD, aggregate=gm, attack=None, config=SGDConfig)

white

In [None]:
# SGD with geometric-median aggregation under the white (Gaussian noise) attack.
run(optimizer=SGD, aggregate=gm, attack=white, config=SGDConfig)

max

In [None]:
# SGD with geometric-median aggregation under the maxValue attack.
run(optimizer=SGD, aggregate=gm, attack=maxValue, config=SGDConfig)

zero Gradient

In [None]:
# SGD with geometric-median aggregation under the zeroGradient attack.
run(optimizer=SGD, aggregate=gm, attack=zeroGradient, config=SGDConfig)

### SGD - Krum

In [None]:
# Baseline: SGD with Krum aggregation and no attack.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SGD, aggregate=Krum, attack=None, config=SGDConfig)

white

In [None]:
# SGD with Krum aggregation under the white (Gaussian noise) attack.
# NOTE(review): Krum_ is configured with only the honest node count and
# byzantineSize=0 although this run includes Byzantine nodes - confirm
# against the Krum_ definition that this is intended.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SGD, aggregate=Krum, attack=white, config=SGDConfig)

max

In [None]:
# SGD with Krum aggregation under the maxValue attack.
# NOTE(review): Krum_ is configured with only the honest node count and
# byzantineSize=0 although this run includes Byzantine nodes - confirm
# against the Krum_ definition that this is intended.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SGD, aggregate=Krum, attack=maxValue, config=SGDConfig)

zero Gradient

In [None]:
# SGD with Krum aggregation under the zeroGradient attack.
# NOTE(review): Krum_ is configured with only the honest node count and
# byzantineSize=0 although this run includes Byzantine nodes - confirm
# against the Krum_ definition that this is intended.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SGD, aggregate=Krum, attack=zeroGradient, config=SGDConfig)

### SGD - Median

In [None]:
# Baseline: SGD with median aggregation and no attack.
run(optimizer=SGD, aggregate=median, attack=None, config=SGDConfig)

white

In [None]:
# SGD with median aggregation under the white (Gaussian noise) attack.
run(optimizer=SGD, aggregate=median, attack=white, config=SGDConfig)

max

In [None]:
# SGD with median aggregation under the maxValue attack.
run(optimizer=SGD, aggregate=median, attack=maxValue, config=SGDConfig)

zero Gradient

In [None]:
# SGD with median aggregation under the zeroGradient attack.
run(optimizer=SGD, aggregate=median, attack=zeroGradient, config=SGDConfig)

## BatchSGD

### BatchSGD - mean

In [None]:
# Baseline: mini-batch SGD with mean aggregation and no attack.
run(optimizer=BatchSGD, aggregate=mean, attack=None, config=batchConfig)

white

In [None]:
# Mini-batch SGD with mean aggregation under the white (Gaussian noise) attack.
run(optimizer=BatchSGD, aggregate=mean, attack=white, config=batchConfig)

max

In [None]:
# Mini-batch SGD with mean aggregation under the maxValue attack.
run(optimizer=BatchSGD, aggregate=mean, attack=maxValue, config=batchConfig)

zero Gradient

In [None]:
# Mini-batch SGD with mean aggregation under the zeroGradient attack.
run(optimizer=BatchSGD, aggregate=mean, attack=zeroGradient, config=batchConfig)

### BatchSGD - geometric median

In [None]:
# Baseline: mini-batch SGD with geometric-median aggregation and no attack.
run(optimizer=BatchSGD, aggregate=gm, attack=None, config=batchConfig)

white

In [None]:
# Mini-batch SGD with geometric-median aggregation under the white
# (Gaussian noise) attack.
run(optimizer=BatchSGD, aggregate=gm, attack=white, config=batchConfig)

max

In [None]:
# Mini-batch SGD with geometric-median aggregation under the maxValue attack.
run(optimizer=BatchSGD, aggregate=gm, attack=maxValue, config=batchConfig)

zero Gradient

In [None]:
# Mini-batch SGD with geometric-median aggregation under the zeroGradient attack.
run(optimizer=BatchSGD, aggregate=gm, attack=zeroGradient, config=batchConfig)

## SAGA

### SAGA - mean

In [None]:
# Baseline: SAGA with mean aggregation and no attack.
run(optimizer=SAGA, aggregate=mean, attack=None, config=SAGAConfig)

white

In [None]:
# SAGA with mean aggregation under the white (Gaussian noise) attack.
run(optimizer=SAGA, aggregate=mean, attack=white, config=SAGAConfig)

max

In [None]:
# SAGA with mean aggregation under the maxValue attack.
run(optimizer=SAGA, aggregate=mean, attack=maxValue, config=SAGAConfig)

zero Gradient

In [None]:
# SAGA with mean aggregation under the zeroGradient attack.
run(optimizer=SAGA, aggregate=mean, attack=zeroGradient, config=SAGAConfig)

### SAGA - geometric median

In [None]:
# Baseline: SAGA with geometric-median aggregation and no attack.
run(optimizer=SAGA, aggregate=gm, attack=None, config=SAGAConfig)

white

In [None]:
# SAGA with geometric-median aggregation under the white (Gaussian noise) attack.
run(optimizer=SAGA, aggregate=gm, attack=white, config=SAGAConfig)

max

In [None]:
# SAGA with geometric-median aggregation under the maxValue attack.
run(optimizer=SAGA, aggregate=gm, attack=maxValue, config=SAGAConfig)

zero Gradient

In [None]:
# SAGA with geometric-median aggregation under the zeroGradient attack.
run(optimizer=SAGA, aggregate=gm, attack=zeroGradient, config=SAGAConfig)

### SAGA - Krum

In [None]:
# Baseline: SAGA with Krum aggregation and no attack.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SAGA, aggregate=Krum, attack=None, config=SAGAConfig)

white

In [None]:
# SAGA with Krum aggregation under the white (Gaussian noise) attack.
# NOTE(review): Krum_ is configured with only the honest node count and
# byzantineSize=0 although this run includes Byzantine nodes - confirm
# against the Krum_ definition that this is intended.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SAGA, aggregate=Krum, attack=white, config=SAGAConfig)

max

In [None]:
# SAGA with Krum aggregation under the maxValue attack.
# NOTE(review): Krum_ is configured with only the honest node count and
# byzantineSize=0 although this run includes Byzantine nodes - confirm
# against the Krum_ definition that this is intended.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SAGA, aggregate=Krum, attack=maxValue, config=SAGAConfig)

zero Gradient

In [None]:
# SAGA with Krum aggregation under the zeroGradient attack.
# NOTE(review): Krum_ is configured with only the honest node count and
# byzantineSize=0 although this run includes Byzantine nodes - confirm
# against the Krum_ definition that this is intended.
Krum = Krum_(
    nodeSize=dataSetConfig['honestNodeSize'],
    byzantineSize=0,
)
run(optimizer=SAGA, aggregate=Krum, attack=zeroGradient, config=SAGAConfig)

### SAGA - Median

In [None]:
# Baseline: SAGA with median aggregation and no attack.
run(optimizer=SAGA, aggregate=median, attack=None, config=SAGAConfig)

white

In [None]:
# SAGA with median aggregation under the white (Gaussian noise) attack.
run(optimizer=SAGA, aggregate=median, attack=white, config=SAGAConfig)

max

In [None]:
# SAGA with median aggregation under the maxValue attack.
run(optimizer=SAGA, aggregate=median, attack=maxValue, config=SAGAConfig)

zero Gradient

In [None]:
# SAGA with median aggregation under the zeroGradient attack.
run(optimizer=SAGA, aggregate=median, attack=zeroGradient, config=SAGAConfig)