In [2]:
#args
import argparse
    
# Build the experiment configuration. All options are declared on one parser
# and parsed from an EMPTY argv at the end, so `args` always holds the defaults
# below (edit the defaults, or pass a list to parse_args, to change a run).
parser = argparse.ArgumentParser()
# federated arguments
parser.add_argument('--epochs', type=int, default=50, help="rounds of training")
parser.add_argument('--num_users', type=int, default=50, help="number of users: K")
# NOTE(review): default is the int 1 although type=float; argparse does not
# convert non-string defaults, so args.ratio is an int unless overridden.
parser.add_argument('--ratio', type=float, default=1,  help="portion of iid user: ")
parser.add_argument('--frac', type=float, default=0.2, help="the fraction of clients: C")
parser.add_argument('--local_ep', type=int, default=50, help="the number of local epochs: E")
parser.add_argument('--local_bs', type=int, default=128, help="local batch size: B")
parser.add_argument('--bs', type=int, default=128, help="test batch size")
parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

# model arguments
parser.add_argument('--model', type=str, default='resnet', help='model name')
parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                    help='comma-separated kernel size to use for convolution')
parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
# NOTE(review): this is a string option; the value 'True'/'False' is text, not a bool.
parser.add_argument('--max_pool', type=str, default='True',
                    help="Whether use max pooling rather than strided convolutions")

# other arguments
parser.add_argument('--dataset', type=str, default='cifar', help="name of dataset,cifar,mnist")
parser.add_argument("--sample", type=int, default=500, help="number of samples for each node")
parser.add_argument('--pattern', type=str,  default='iid', help='iid,noniid,1-9,iid-q')
parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
# NOTE(review): default is 1 channel while the default dataset is 'cifar' (RGB);
# confirm against the model that consumes args.num_channels.
parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imges")
parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
parser.add_argument('--verbose', action='store_true', help='verbose print')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
parser.add_argument('--alg',type=str,default='avg',help='avg or fed')
# Parse an empty argument list: inside a notebook sys.argv belongs to the
# kernel launcher, so the real command line must not be read.
args = parser.parse_args([])

In [3]:
#sample
import numpy as np
from torchvision import datasets, transforms
import torch
# from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import random_split
def Dataset_config(dataset, num_users, pattern):
    """Load a torchvision dataset and partition its training indices across users.

    Args:
        dataset: One of 'mnist', 'cifar' or 'fashion'. The data is downloaded
            into ./data on first use (download=True).
        num_users: Number of clients to partition the training set across.
        pattern: Partition scheme:
            'iid'     - uniform random split (see `iid`);
            'noniid'  - Dirichlet label-skew split (see `noniid`);
            '1'..'10' - each user holds that many distinct labels (see `label`);
            'iid-q'   - IID labels with Dirichlet-distributed quantities.

    Returns:
        (dict_users, train_dict, val_dict, dataset_train, dataset_test) where
        dict_users maps user id -> sample indices, and train_dict / val_dict
        hold an 80% / 20% `random_split` of each user's indices (fixed seed 42).

    Raises:
        ValueError: If `dataset` or `pattern` is not recognised. (The previous
            `exit(...)` call does not report the message inside a notebook
            kernel, where `exit` is the kernel's own shutdown helper.)
    """
    if dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST(root='./data/', train=True, download=True,
                                       transform=trans_mnist)
        dataset_test = datasets.MNIST(root='./data/', train=False, download=True,
                                      transform=trans_mnist)

    elif dataset == 'cifar':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        dataset_train = datasets.CIFAR10('./data/', train=True, download=True, transform=transform_train)
        dataset_test = datasets.CIFAR10('./data/', train=False, download=True, transform=transform_test)

    elif dataset == 'fashion':
        transformations = transforms.Compose([transforms.ToTensor(),])
        dataset_train = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transformations)
        dataset_test = datasets.FashionMNIST('./data', download=True, train=False, transform=transformations)

    else:
        raise ValueError('Error: unrecognized dataset')

    # The partition helpers index and compare labels with NumPy, so hand them
    # a NumPy array for every dataset (targets is a tensor for MNIST and
    # FashionMNIST, a plain list for CIFAR10).
    y_train = np.asarray(dataset_train.targets)

    if pattern == 'iid':
        dict_users = iid(y_train, num_users)

    elif pattern == 'noniid':
        dict_users = noniid(y_train, num_users)
    elif pattern.isdigit() and 1 <= int(pattern) <= 10:
        # Parse the number instead of comparing strings lexicographically:
        # the old check (pattern > "1" and pattern <= "9") let values such as
        # "15" or "2x" through to `label`.
        dict_users = label(pattern, num_users, y_train)
    elif pattern == "iid-q":
        # Quantity skew: shuffle all indices, then cut them at Dirichlet(0.5)
        # proportions, retrying until the smallest share has >= 10 samples.
        idxs = np.random.permutation(y_train.shape[0])
        min_size = 0
        while min_size < 10:
            proportions = np.random.dirichlet(np.repeat(0.5, num_users))
            proportions = proportions/proportions.sum()
            min_size = np.min(proportions*len(idxs))
        proportions = (np.cumsum(proportions)*len(idxs)).astype(int)[:-1]
        batch = np.split(idxs, proportions)
        dict_users = {i: batch[i] for i in range(num_users)}
    else:
        raise ValueError('Error: unrecognized pattern')

    # Per-user 80/20 train/validation split of the index collections.
    # Fractional lengths require a torch version whose random_split accepts them.
    train_dict = {}
    val_dict = {}
    for i in range(num_users):
        t, v = random_split(dict_users[i], lengths=[0.8, 0.2], generator=torch.Generator().manual_seed(42))
        train_dict[i] = t
        val_dict[i] = v
    return dict_users, train_dict, val_dict, dataset_train, dataset_test


def iid(dataset, num_users):
    """
    IID partition: shuffle the sample indices 0..len(dataset)-1 and cut them
    into `num_users` nearly equal chunks.

    Returns a dict mapping user id -> array of sample indices.
    """
    print("iid partion")
    shuffled = np.random.permutation(dataset.shape[0])
    chunks = np.array_split(shuffled, num_users)
    return dict(enumerate(chunks))


def noniid(dataset, num_users):
    """
    Non-IID (label-skew) partition driven by a Dirichlet(0.5) prior.

    `dataset` is the array of integer labels (classes 0..9). For each class
    the indices are shuffled and divided among the users according to a fresh
    Dirichlet draw; users already holding at least N / num_users samples get a
    zero share for that class. The whole assignment is redone until every
    user holds at least 10 samples.

    Returns a dict mapping user id -> shuffled list of sample indices.
    """
    required_min = 10
    n_classes = 10
    dirichlet = 0.5
    print("noniid partion",dirichlet)
    total = dataset.shape[0]

    smallest = 0
    while smallest < required_min:
        buckets = [[] for _ in range(num_users)]
        for cls in range(n_classes):
            cls_idx = np.where(dataset == cls)[0]
            np.random.shuffle(cls_idx)
            shares = np.random.dirichlet(np.repeat(dirichlet, num_users))

            # Balance: users at or above the fair quota receive nothing more.
            under_quota = [len(bucket) < total / num_users for bucket in buckets]
            shares = np.array([share * keep for share, keep in zip(shares, under_quota)])
            shares = shares / shares.sum()

            # Convert cumulative shares into cut points inside this class.
            cuts = (np.cumsum(shares) * len(cls_idx)).astype(int)[:-1]
            pieces = np.split(cls_idx, cuts)
            buckets = [bucket + piece.tolist() for bucket, piece in zip(buckets, pieces)]
            smallest = min(len(bucket) for bucket in buckets)

    partition = {}
    for user in range(num_users):
        np.random.shuffle(buckets[user])
        partition[user] = buckets[user]
    return partition

def label(partition: str, num_users, y_train):
    """Label-skew partition: every user holds samples of `partition` distinct classes.

    Args:
        partition: Number of classes per user, given as a string such as "3"
            (an int is accepted as well). Must be between 1 and 10.
        num_users: Number of users.
        y_train: Array of integer labels in 0..9.

    Returns:
        dict mapping user id -> int64 array of sample indices. User i always
        owns class i % 10; its remaining classes are drawn uniformly at
        random. Each class is split evenly among the users that own it.

    Raises:
        ValueError: If the requested number of classes is outside 1..10.
    """
    K = 10
    # int() instead of eval(): the value only ever needs to be an integer and
    # eval would execute arbitrary text.
    num = int(partition)
    if not 1 <= num <= K:
        raise ValueError("partition must be between 1 and %d, got %r" % (K, partition))
    y_train = np.asarray(y_train)
    net_dataidx_map = {i: np.empty(0, dtype=np.int64) for i in range(num_users)}

    if num == K:
        # Every user owns every class: split each class evenly over all users.
        for i in range(K):
            idx_k = np.where(y_train == i)[0]
            np.random.shuffle(idx_k)
            split = np.array_split(idx_k, num_users)
            for j in range(num_users):
                net_dataidx_map[j] = np.append(net_dataidx_map[j], split[j])
        return net_dataidx_map

    times = [0] * K   # how many users own each class
    contain = []      # contain[u] = classes owned by user u
    for i in range(num_users):
        current = [i % K]
        times[i % K] += 1
        while len(current) < num:
            # np.random.randint excludes the upper bound, so it must be K
            # (not K - 1) for the last class to be selectable.
            ind = np.random.randint(0, K)
            if ind not in current:
                current.append(ind)
                times[ind] += 1
        contain.append(current)

    for i in range(K):
        if times[i] == 0:
            # No user owns this class (possible when num_users < K);
            # np.array_split cannot split into zero sections.
            continue
        idx_k = np.where(y_train == i)[0]
        np.random.shuffle(idx_k)
        split = np.array_split(idx_k, times[i])
        ids = 0
        for j in range(num_users):
            if i in contain[j]:
                net_dataidx_map[j] = np.append(net_dataidx_map[j], split[ids])
                ids += 1
    return net_dataidx_map

In [4]:
#net
import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    """One-hidden-layer perceptron applied to flattened image batches.

    The forward pass flattens each sample to channels * height * width
    features, then applies linear -> dropout -> ReLU -> linear.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # Feature count comes from dim 1 and the last two dims of the input.
        flat_dim = x.shape[1] * x.shape[-2] * x.shape[-1]
        hidden = self.layer_input(x.view(-1, flat_dim))
        hidden = self.relu(self.dropout(hidden))
        return self.layer_hidden(hidden)


class CNN_Net(nn.Module):
    """Two-convolution network for single-channel 32x32 inputs with 10 outputs.

    The input is reshaped to (-1, 1, 32, 32); each convolution is followed by
    tanh and 2x2 max pooling, leaving a 16 x 4 x 4 feature map for the two
    fully connected layers. The output is log-softmax over the 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 1)
        self.conv2 = nn.Conv2d(64, 16, 7, 1)
        self.fc1 = nn.Linear(4 * 4 * 16, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        feats = x.view(-1, 1, 32, 32)
        for conv in (self.conv1, self.conv2):
            feats = F.max_pool2d(torch.tanh(conv(feats)), 2, 2)
        feats = feats.view(-1, 4 * 4 * 16)
        hidden = torch.tanh(self.fc1(feats))
        return F.log_softmax(self.fc2(hidden), dim=1)


class CNNMnist(nn.Module):
    """Small CNN whose first linear layer expects 320 flattened features.

    `args.num_channels` sets the input channel count and `args.num_classes`
    the size of the output layer. Returns raw logits.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        feats = F.relu(F.max_pool2d(self.conv1(x), 2))
        feats = self.conv2_drop(self.conv2(feats))
        feats = F.relu(F.max_pool2d(feats, 2))
        # Flatten channels x height x width into one feature axis.
        flat_dim = feats.shape[1] * feats.shape[2] * feats.shape[3]
        hidden = F.relu(self.fc1(feats.view(-1, flat_dim)))
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)


class CNNCifar(nn.Module):
    """LeNet-style CNN for 3-channel inputs; fc1 expects 16 x 5 x 5 features.

    `args.num_classes` sets the size of the output layer. Returns raw logits.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        # conv -> ReLU -> pool, twice
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # two hidden fully connected layers, then the classifier
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

class CNNfashion(nn.Module):
    """CNN with batch normalisation for single-channel inputs and 10 outputs.

    Two 5x5 convolutions (each followed by 2x2 max pooling and ReLU) feed a
    12 x 4 x 4 feature map into three linear layers. Batch norm is applied
    after the first conv stage and after the first linear layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.batchN1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=(5, 5))
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.batchN2 = nn.BatchNorm1d(num_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, x):
        # first conv stage: conv -> pool -> ReLU -> batch norm
        stage1 = F.max_pool2d(input=self.conv1(x), kernel_size=2, stride=2)
        stage1 = self.batchN1(F.relu(stage1))

        # second conv stage: conv -> pool -> ReLU
        stage2 = F.max_pool2d(input=self.conv2(stage1), kernel_size=2, stride=2)
        stage2 = F.relu(stage2)

        # classifier head on the flattened feature map
        flat = stage2.reshape(-1, 12 * 4 * 4)
        hidden = self.batchN2(F.relu(self.fc1(flat)))
        hidden = F.relu(self.fc2(hidden))
        return self.out(hidden)

import torch.nn as nn

# Layer layouts consumed by `_make_layers`, keyed by variant name.
# An integer is the output-channel count of a 3x3 convolution block;
# the string 'M' inserts a 2x2 max-pooling layer.
_cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def _make_layers(cfg):
    layers = []
    in_channels = 3
    for layer_cfg in cfg:
        if layer_cfg == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels=in_channels,
                                    out_channels=layer_cfg,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    bias=True))
            layers.append(nn.BatchNorm2d(num_features=layer_cfg))
            layers.append(nn.ReLU(inplace=True))
            in_channels = layer_cfg
    return nn.Sequential(*layers)


class _VGG(nn.Module):
    """
    VGG module for 3x32x32 input, 10 classes.

    `name` selects the layer layout from `_cfg`. The convolutional stack is
    followed by a single linear classifier over 512 flattened features.
    """

    def __init__(self, name):
        super().__init__()
        self.layers = _make_layers(_cfg[name])
        flatten_features = 512
        self.fc1 = nn.Linear(flatten_features, 10)

    def forward(self, x):
        features = self.layers(x)
        # Flatten everything except the batch axis before the classifier.
        flat = features.view(features.size(0), -1)
        return self.fc1(flat)


def VGG11():
    """Return a `_VGG` built from the 'VGG11' layer configuration."""
    return _VGG('VGG11')


def VGG13():
    """Return a `_VGG` built from the 'VGG13' layer configuration."""
    return _VGG('VGG13')


def VGG16():
    """Return a `_VGG` built from the 'VGG16' layer configuration."""
    return _VGG('VGG16')


def VGG19():
    """Return a `_VGG` built from the 'VGG19' layer configuration."""
    return _VGG('VGG19')


class VGG(nn.Module):
    """Hand-written VGG-style network for 3x32x32 inputs and 10 classes.

    The `class` header for these methods was missing, which left `__init__`
    and `forward` as unreachable nested functions inside `VGG19` (after its
    `return`) that referred to an undefined name `VGG`. The header is
    restored here so the network can be instantiated.

    Five convolutional stages (each ending in max pooling, batch norm and
    ReLU) reduce a 32x32 input to a 512 x 4 x 4 feature map, which feeds
    three fully connected layers with dropout in between.
    """

    def __init__(self):
        super(VGG, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.conv5 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
        # 1x1 convolutions with padding=1 grow the map by 2 pixels per side
        # pair; the 512*4*4 flatten size below depends on this.
        self.conv7 = nn.Conv2d(128, 128, 1, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()

        self.conv8 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv9 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv10 = nn.Conv2d(256, 256, 1, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()

        self.conv11 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv12 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv13 = nn.Conv2d(512, 512, 1, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2, padding=1)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU()

        self.fc14 = nn.Linear(512 * 4 * 4, 1024)
        # Plain Dropout: the activations here are 2-D (batch, features), for
        # which Dropout2d is deprecated and emits a warning.
        self.drop1 = nn.Dropout()
        self.fc15 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.fc16 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.pool3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.pool4(x)
        x = self.bn4(x)
        x = self.relu4(x)

        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.pool5(x)
        x = self.bn5(x)
        x = self.relu5(x)

        # Flatten the 512 x 4 x 4 feature map for the classifier head.
        x = x.view(-1, 512 * 4 * 4)
        x = F.relu(self.fc14(x))
        x = self.drop1(x)
        x = F.relu(self.fc15(x))
        x = self.drop2(x)
        x = self.fc16(x)

        return x

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with batch norm.

    The skip connection is the identity unless the stride or the channel
    count changes, in which case it is a 1x1 convolution plus batch norm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The block outputs `expansion * planes` channels. The skip connection is
    the identity unless the stride or the channel count changes, in which
    case it is a 1x1 convolution plus batch norm.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += residual
        return F.relu(out)


# class ResNet(nn.Module):
#     def __init__(self, block, num_blocks, num_classes=10):
#         super(ResNet, self).__init__()
#         self.in_planes = 64

#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
#                                stride=1, padding=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(64)
#         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
#         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
#         self.linear = nn.Linear(512*block.expansion, num_classes)

#     def _make_layer(self, block, planes, num_blocks, stride):
#         strides = [stride] + [1]*(num_blocks-1)
#         layers = []
#         for stride in strides:
#             layers.append(block(self.in_planes, planes, stride))
#             self.in_planes = planes * block.expansion
#         return nn.Sequential(*layers)

#     def forward(self, x):
#         out = F.relu(self.bn1(self.conv1(x)))
#         out = self.layer1(out)
#         out = self.layer2(out)
#         out = self.layer3(out)
#         out = self.layer4(out)
#         out = F.avg_pool2d(out, 4)
#         out = out.view(out.size(0), -1)
#         out = self.linear(out)
#         return out


# def ResNet18():
#     return ResNet(BasicBlock, [2, 2, 2, 2])


# def ResNet34():
#     return ResNet(BasicBlock, [3, 4, 6, 3])


# def ResNet50():
#     return ResNet(Bottleneck, [3, 4, 6, 3])


# def ResNet101():
#     return ResNet(Bottleneck, [3, 4, 23, 3])


# def ResNet152():
#     return ResNet(Bottleneck, [3, 8, 36, 3])

In [5]:
#test
import torch.nn.functional as F
from torch.utils.data import DataLoader
from collections import Counter


def test_img(net_g, datatest, args,sampler=None):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    if sampler is not None :
        data_loader = DataLoader(datatest, batch_size=args.bs,sampler=sampler)
    else:
        data_loader = DataLoader(datatest, batch_size=args.bs)
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = correct.item() / 100

    return accuracy, test_loss

def test_dis(net,dataset,args):
    net.eval()
    data_loader = DataLoader(dataset,batch_size=args.bs)
    y=[]
    for data,target in data_loader:
        data, target = data.cuda(), target.cuda()
        predits = net(data)
        y.append(predits.data.max(1)[1].cpu().numpy())
    
        
    return Counter(np.concatenate(y))



In [6]:
#update
import copy
import math

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from sklearn import metrics


# class DatasetSplit(Dataset):
#     def __init__(self, dataset, idxs):
#         self.dataset = dataset
#         self.idxs = list(idxs)

#     def __len__(self):
#         return len(self.idxs)

#     def __getitem__(self, item):
#         image, label = self.dataset[self.idxs[item]]
#         return image, label


class LocalUpdate(object):
    """Local training and validation for one client.

    Args:
        args: Namespace providing `local_bs`, `lr`, `momentum`, `local_ep`,
            `verbose` and `device` (read by `train`).
        dataset: Full training dataset yielding (input, target) pairs.
        train: Sliceable collection of this client's sample indices.

    The validation loader draws from the first 20% of `train`
    (NOTE(review): these indices are also part of the training loader, so
    this is not a held-out set - confirm that is intended).
    """

    def __init__(self, args, dataset=None, train=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.len_train = len(train)
        self.len_val = int(self.len_train * 0.2)
        self.ldr_train = DataLoader(
            dataset, batch_size=self.args.local_bs, drop_last=True,
            sampler=torch.utils.data.SubsetRandomSampler(train))
        # Keep the last partial batch for validation: dropping it while still
        # dividing by len_val under-reported accuracy and loss (and produced
        # no batches at all when len_val < local_bs).
        self.ldr_val = DataLoader(
            dataset, batch_size=self.args.local_bs, drop_last=False,
            sampler=torch.utils.data.SubsetRandomSampler(train[:self.len_val]))

    def train(self, net):
        """Run `local_ep` epochs of SGD on the client's data.

        Returns:
            (net, net.state_dict(), mean of the per-epoch mean batch losses).
        """
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        n_batches = len(self.ldr_train)

        epoch_loss = []
        for e in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    # Progress is batches done / batches per epoch (it was
                    # previously divided by the sample count).
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        e, batch_idx * len(images), self.len_train,
                              100. * batch_idx / n_batches, loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net, net.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def val(self, net, args):
        """Evaluate `net` on the client's validation loader.

        Args:
            net: Model to evaluate; it is switched to eval mode.
            args: Namespace providing `gpu` (-1 means CPU; otherwise batches
                are moved to CUDA).

        Returns:
            (accuracy, test_loss): fraction of correct predictions and mean
            cross-entropy, both over the samples actually evaluated.
            (0.0, 0.0) when the validation subset is empty.
        """
        net.eval()
        test_loss = 0
        correct = 0
        seen = 0
        # Validation needs no autograd graph.
        with torch.no_grad():
            for idx, (data, target) in enumerate(self.ldr_val):
                if args.gpu != -1:
                    data, target = data.cuda(), target.cuda()
                log_probs = net(data)

                # sum up batch loss
                test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
                # index of the largest output is the predicted class
                y_pred = log_probs.data.max(1, keepdim=True)[1]
                correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
                seen += target.size(0)

        if seen == 0:
            return 0.0, 0.0
        test_loss /= seen
        accuracy = correct / seen

        return accuracy, test_loss


In [7]:
#utility
###flatten function
def flatten(idx_dict):
    """Concatenate every tensor in `idx_dict` (in key order) into one 1-D tensor."""
    pieces = []
    for key in idx_dict:
        pieces.append(idx_dict[key].flatten())
    return torch.cat(pieces)

def unflatten(flattened, normal_shape):
    """Split a flat tensor back into a dict of tensors shaped like `normal_shape`.

    Args:
        flattened: 1-D tensor holding all values back to back (the layout
            produced by `flatten`).
        normal_shape: Dict of template tensors; only their shapes and
            element counts are used, in key order.

    Returns:
        Dict with the same keys, each value a detached copy of the matching
        slice of `flattened` reshaped to the template's shape. Values past
        the total element count of the templates are ignored.
    """
    w_local = {}
    offset = 0
    for k in normal_shape:
        template = normal_shape[k]
        # numel() works for any memory layout; the previous
        # len(template.view(-1)) raised on non-contiguous templates.
        n = template.numel()
        w_local[k] = flattened[offset:offset + n].reshape(template.shape).clone().detach()
        offset += n
    return w_local




In [8]:
#defences
def multi_krum_defence(all_updates, n_attackers, multi_k=False):
    """Aggregate client updates with Krum (or Multi-Krum).

    Args:
        all_updates: 2-D tensor, one flattened update per row.
        n_attackers: Assumed number of malicious updates.
        multi_k: If False, return the single update with the lowest Krum
            score. If True, keep selecting the best remaining update while
            more than 2 * n_attackers + 2 remain and return their mean.

    Returns:
        1-D tensor: the selected update, or the mean of the selected updates.

    Raises:
        ValueError: If there are not more than 2 * n_attackers + 2 updates,
            in which case nothing can be selected (this previously ended in
            an opaque TypeError from averaging an empty list).
    """
    if len(all_updates) <= 2 * n_attackers + 2:
        raise ValueError(
            "Krum needs more than 2 * n_attackers + 2 updates, got %d updates for %d attackers"
            % (len(all_updates), n_attackers))

    candidates = []
    remaining_updates = all_updates

    while len(remaining_updates) > 2 * n_attackers + 2:
        torch.cuda.empty_cache()
        # Row i holds the squared distances from update i to every update;
        # computed one row at a time on the updates' own device.
        distances = torch.stack(
            [torch.norm(remaining_updates - update, dim=1) ** 2 for update in remaining_updates])
        distances = torch.sort(distances, dim=1)[0]

        # Krum score: sum over the closest neighbours (the zero self-distance
        # is included in the count).
        n_neighbors = len(remaining_updates) - 2 - n_attackers
        scores = torch.sum(distances[:, :n_neighbors], dim=1)
        best = int(torch.argsort(scores)[0])

        candidates.append(remaining_updates[best])
        remaining_updates = torch.cat((remaining_updates[:best], remaining_updates[best + 1:]), 0)
        if not multi_k:
            break

    aggregate = torch.mean(torch.stack(candidates), dim=0)
    return aggregate

def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean over the rows of `all_updates`.

    Each column is sorted independently; when `n_attackers` is non-zero the
    `n_attackers` smallest and largest values per column are dropped before
    averaging.
    """
    ordered, _ = torch.sort(all_updates, 0)
    if n_attackers:
        ordered = ordered[n_attackers:-n_attackers]
    return torch.mean(ordered, 0)

def bulyan(all_updates, n_attackers):
    """Aggregate client updates with Bulyan (iterated Krum, then a trimmed mean).

    Args:
        all_updates: 2-D tensor, one flattened update per row.
        n_attackers: Assumed number of malicious updates.

    Returns:
        1-D tensor. First, nusers - 2 * n_attackers updates are selected one
        by one by lowest Krum score. Then, per coordinate, the values closest
        to the coordinate-wise median of that selection are averaged.

    Raises:
        ValueError: If nusers - 2 * n_attackers < 1, so nothing can be selected.
    """
    nusers = all_updates.shape[0]
    n_select = nusers - 2 * n_attackers
    if n_select < 1:
        raise ValueError(
            "Bulyan needs more than 2 * n_attackers updates, got %d updates for %d attackers"
            % (nusers, n_attackers))

    selected = []
    remaining_updates = all_updates

    while len(selected) < n_select:
        # Row i holds the squared distances from update i to every update.
        distances = torch.stack(
            [torch.norm(remaining_updates - update, dim=1) ** 2 for update in remaining_updates])
        distances = torch.sort(distances, dim=1)[0]

        # Clamp the neighbour count: on the last iterations
        # len(remaining) - 2 - n_attackers drops to zero or below (always so
        # for n_attackers of 0 or 1). The previous code sliced the ranking
        # with that count and then indexed the empty result.
        n_neighbors = max(len(remaining_updates) - 2 - n_attackers, 1)
        scores = torch.sum(distances[:, :n_neighbors], dim=1)
        best = int(torch.argsort(scores)[0])

        selected.append(remaining_updates[best])
        remaining_updates = torch.cat((remaining_updates[:best], remaining_updates[best + 1:]), 0)

    bulyan_cluster = torch.stack(selected)
    n, d = bulyan_cluster.shape
    param_med = torch.median(bulyan_cluster, dim=0)[0]
    # Per coordinate, order the selected updates by distance to the median.
    sort_idx = torch.argsort(torch.abs(bulyan_cluster - param_med), dim=0)
    sorted_params = bulyan_cluster[sort_idx, torch.arange(d, device=bulyan_cluster.device)[None, :]]

    return torch.mean(sorted_params[:n - 2 * n_attackers], dim=0)

def dnc(updates, n_attackers, num_iters=1, sub_dim=1000, filter_frac=1.0):
    """Filter updates by their projection on the top singular direction.

    Args:
        updates: 2-D tensor, one flattened update per row.
        n_attackers: Assumed number of malicious updates.
        num_iters: Number of random sub-samplings of the coordinates.
        sub_dim: Number of coordinates sampled per iteration (all of them
            when the updates have fewer).
        filter_frac: Multiple of `n_attackers` to discard per iteration.

    Returns:
        (benign_ids, benign_updates): sorted list of row indices kept in at
        least one iteration, and the mean of those rows.

    In each iteration the sub-sampled updates are centred, and each update is
    scored by its squared projection onto the first right-singular vector;
    the int(filter_frac * n_attackers) highest-scoring updates are dropped.
    """
    # Take the dimensionality from the shape; len(updates[1]) failed when
    # only one update was supplied.
    d = updates.shape[1]
    benign_ids = []
    for _ in range(num_iters):
        indices = torch.randperm(d)[:sub_dim]
        sub_updates = updates[:, indices]
        mu = sub_updates.mean(dim=0)
        centered_update = sub_updates - mu
        v = torch.linalg.svd(centered_update, full_matrices=False)[2][0, :]
        # Squared projection of every centred update onto v.
        s = ((centered_update @ v) ** 2).detach().cpu().numpy()

        good = s.argsort()[: len(updates) - int(filter_frac * n_attackers)]
        benign_ids.extend(int(i) for i in good)
        print(benign_ids)

    benign_ids = sorted(set(benign_ids))
    benign_updates = updates[benign_ids, :].mean(dim=0)
    return benign_ids, benign_updates

#
def TDFL_cos(uw, t):
    """Select the updates whose cosine similarity to a reference is at least `t`.

    Args:
        uw: Indexable collection of 1-D update tensors.
        t: Similarity threshold.

    Returns:
        The tuple produced by `np.where`: one array holding the positions
        of the updates with similarity >= t.

    The reference vector is the first value returned by `crh(uw, None)`
    (NOTE(review): `crh` is defined outside this section; presumably a
    truth-discovery aggregate - confirm).
    """
    g, _ = crh(uw, None)
    cs = torch.stack(
        [torch.cosine_similarity(uw[idx], g, dim=0) for idx in range(len(uw))])
    print(cs)
    # Move the mask to host memory first: np.where cannot read a CUDA tensor.
    return np.where((cs >= t).cpu().numpy())
    
###our 
##coming soon



In [9]:
##attack 
###fang attack
def attack_median_and_trimmedmean(benign_update,m):
    
    # benign_update = 
    agg_grads = torch.mean(benign_update, 0)
    deviation = torch.sign(agg_grads)
    device = benign_update.device
    b = 2
    max_vector = torch.max(benign_update, 0)[0]
    min_vector = torch.min(benign_update, 0)[0]

    max_ = (max_vector > 0).type(torch.FloatTensor).to(device)
    min_ = (min_vector < 0).type(torch.FloatTensor).to(device)

    max_[max_ == 1] = b
    max_[max_ == 0] = 1 / b
    min_[min_ == 1] = b
    min_[min_ == 0] = 1 / b

    max_range = torch.cat(
        (max_vector[:, None], (max_vector * max_)[:, None]), dim=1
    )
    min_range = torch.cat(
        ((min_vector * min_)[:, None], min_vector[:, None]), dim=1
    )

    rand = (
        torch.from_numpy(
            np.random.uniform(0, 1, [len(deviation), m])
        )
        .type(torch.FloatTensor)
        .to(benign_update.device)
    )

    max_rand = (
        torch.stack([max_range[:, 0]] * rand.shape[1]).T
        + rand * torch.stack([max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T
    )
    min_rand = (
        torch.stack([min_range[:, 0]] * rand.shape[1]).T
        + rand * torch.stack([min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T
    )

    mal_vec = (
        torch.stack(
            [(deviation < 0).type(torch.FloatTensor)] * max_rand.shape[1]
        ).T.to(device)
        * max_rand
        + torch.stack(
            [(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]
        ).T.to(device)
        * min_rand
    ).T
    return mal_vec
###
def multi_krum(all_updates, n_attackers, multi_k=False):
    """Select updates with (Multi-)Krum and return their mean.

    In each round every remaining update is scored by the sum of its
    ``len(remaining) - 2 - n_attackers`` smallest squared L2 distances to the
    remaining updates (its own zero distance included); the lowest-scoring
    update is moved to the candidate set. With ``multi_k=False`` only one
    round is run (plain Krum); otherwise rounds continue while more than
    ``2 * n_attackers + 2`` updates remain.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        n_attackers: assumed number of malicious updates.
        multi_k: run Multi-Krum instead of single-selection Krum.

    Returns:
        (aggregate, candidate_indices): mean of the selected rows and a numpy
        array with their indices into ``all_updates`` in selection order.

    Raises:
        ValueError: if ``len(all_updates) <= 2 * n_attackers + 2`` so that no
            update can be selected (previously an opaque error from
            ``torch.mean`` on an empty list).
    """
    if len(all_updates) <= 2 * n_attackers + 2:
        raise ValueError(
            "multi_krum needs more than 2 * n_attackers + 2 updates "
            "(got {} updates for n_attackers={})".format(len(all_updates), n_attackers)
        )

    selected_rows = []
    candidate_indices = []
    remaining_updates = all_updates
    all_indices = np.arange(len(all_updates))

    while len(remaining_updates) > 2 * n_attackers + 2:
        torch.cuda.empty_cache()
        # Row i holds the squared distances from update i to every remaining update.
        distances = torch.stack(
            [torch.norm(remaining_updates - update, dim=1) ** 2 for update in remaining_updates]
        )
        distances = torch.sort(distances, dim=1)[0]
        n_closest = len(remaining_updates) - 2 - n_attackers
        scores = torch.sum(distances[:, :n_closest], dim=1)
        best = int(torch.argsort(scores)[0])

        candidate_indices.append(all_indices[best])
        all_indices = np.delete(all_indices, best)
        selected_rows.append(remaining_updates[best])
        remaining_updates = torch.cat((remaining_updates[:best], remaining_updates[best + 1:]), 0)
        if not multi_k:
            break

    aggregate = torch.mean(torch.stack(selected_rows), dim=0)

    return aggregate, np.array(candidate_indices)

def compute_lambda_fang(all_updates, model_re, n_attackers):
    """Compute the starting scale (lambda) used by ``get_malicious_updates_fang``.

    The value is the sum of two terms, both divided by ``sqrt(d)``:
    the smallest Krum-style score (sum of the ``n - 2 - n_attackers`` smallest
    pairwise distances per row, with zero distances replaced by 10000) divided
    by ``n - n_attackers - 1``, and the largest distance from any update to
    ``model_re``.

    Args:
        all_updates: 2-D tensor of shape (n, d).
        model_re: reference vector subtracted from every row.
        n_attackers: assumed number of attackers.

    Returns:
        0-dim tensor holding lambda.
    """
    n_benign, d = all_updates.shape
    sqrt_d = torch.sqrt(torch.Tensor([d]))[0]

    pairwise = torch.stack([torch.norm(all_updates - row, dim=1) for row in all_updates])
    # Exact-zero distances (the diagonal, and any duplicate rows) are pushed out of the way.
    pairwise[pairwise == 0] = 10000
    ordered = torch.sort(pairwise, dim=1)[0]

    n_closest = n_benign - 2 - n_attackers
    krum_scores = torch.sum(ordered[:, :n_closest], dim=1)
    term_1 = torch.min(krum_scores) / ((n_benign - n_attackers - 1) * sqrt_d)

    farthest = torch.max(torch.norm(all_updates - model_re, dim=1))
    return term_1 + farthest / sqrt_d

def get_malicious_updates_fang(all_updates, n_attackers):
    """Craft ``n_attackers`` identical malicious updates aimed at Krum.

    Starting from ``compute_lambda_fang``, the scale ``lamda`` is halved until
    Krum, run on the malicious rows stacked on top of ``all_updates``, selects
    one of the malicious rows (index < n_attackers) or ``lamda`` drops to the
    threshold.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        n_attackers: number of malicious updates to produce.

    Returns:
        Tensor of shape (n_attackers, d) containing only the malicious rows.
        Previously the success path returned the malicious rows concatenated
        with every row of ``all_updates``, unlike the fall-through path, so
        callers that stack the result onto benign rows duplicated them.
    """
    model_re = torch.mean(all_updates, 0)
    deviation = torch.sign(model_re)
    lamda = compute_lambda_fang(all_updates, model_re, n_attackers)
    threshold = 1e-5

    mal_updates = []
    while lamda > threshold:
        # The search candidate is -lamda * sign(mean); model_re is not added here.
        mal_update = (- lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)

        # Malicious rows go first so that a selected index < n_attackers means success.
        krum_input = torch.cat((mal_updates, all_updates), 0)
        _, krum_candidate = multi_krum(krum_input, n_attackers)

        if krum_candidate < n_attackers:
            return mal_updates

        lamda *= 0.5

    if not len(mal_updates):
        # The loop never ran (lamda started at or below the threshold).
        print(lamda, threshold)
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)

    # Search exhausted: return the last candidate that was tried.
    return mal_updates
###
from scipy.stats import norm
def LIE(para_list,m):
    """Build ``m`` identical updates of the form ``mean - z * std`` (per coordinate).

    ``z`` is the standard-normal quantile at ``(n/2 - 1) / (n - m)`` where
    ``n`` is the number of rows in ``para_list``.

    Args:
        para_list: 2-D tensor, one flat update per row.
        m: number of malicious updates.

    Returns:
        A copy of the first ``m`` rows with every row overwritten by the
        crafted vector; ``para_list`` itself is not modified.
    """
    crafted_rows = copy.deepcopy(para_list[:m])
    if m <= 0:
        # Nothing to overwrite; the quantile is never evaluated in this case.
        return crafted_rows

    n = para_list.shape[0]
    z = norm.ppf((n/2-1)/(n-m))
    crafted = torch.mean(para_list,0) - z*torch.std(para_list,0)
    for row in range(m):
        crafted_rows[row] = crafted

    return crafted_rows
###    
def scaling_attack(para_list,m):
    """Return a copy of the first ``m`` rows, each multiplied by ``para_list.shape[1]``.

    NOTE(review): the factor is the number of columns (vector length), not the
    number of clients - confirm this is the intended scale.
    """
    scaled = copy.deepcopy(para_list[:m])
    width = para_list.shape[1]
    for row in range(m):
        scaled[row] = para_list[row] * width

    return scaled
###
def mean_attack(para_list,m):
    """Return the first ``m`` rows of ``para_list`` with their sign flipped, stacked into one tensor."""
    flipped = []
    for row in range(m):
        flipped.append(-para_list[row])
    return torch.stack(flipped)

def full_mean_attack(para_list,m):
    """Replace the first ``m`` rows by one shared crafted vector.

    Each crafted row equals ``(-sum(all rows) - sum(first m rows)) / m``.
    When every client is malicious (``m`` equals the number of rows) the
    function falls back to ``mean_attack``.

    The special case used to compare ``m`` with ``para_list.shape[1]`` (the
    vector length) instead of ``shape[0]`` (the number of clients), so it
    fired on a coincidental match and never when all clients were malicious.

    NOTE(review): when the crafted rows replace the first ``m`` rows, the sum
    of the poisoned set is ``-2 * sum(first m rows)``. If the goal is to
    negate the overall mean, the subtracted term should probably be the sum of
    the untouched rows - confirm against the attack definition.

    Args:
        para_list: 2-D tensor, one flat update per row.
        m: number of malicious updates.

    Returns:
        Tensor with ``m`` rows; ``para_list`` is not modified.
    """
    if m == para_list.shape[0]:
        return mean_attack(para_list,m)

    p=copy.deepcopy(para_list[:m])
    if m > 0:
        all_sum=torch.sum(para_list,0)
        m_para_sum= torch.sum(para_list[:m],0)
        crafted = (-all_sum - m_para_sum) / m
        for i in range(m):
            p[i] = crafted
    return p
#### attack for tailored
def AGR_tailored_attack_on_krum(all_updates, n_attackers, dev_type='unit_vec'):
    """Search for the largest scale at which Multi-Krum still selects every malicious update.

    The malicious update is ``mean - scale * direction``. The scale starts at 3
    and is moved up or down by a step that halves every iteration; a scale
    counts as successful when all ``n_attackers`` malicious rows (placed first
    in the stacked input) appear among the Multi-Krum candidates.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        n_attackers: number of malicious updates.
        dev_type: 'unit_vec', 'sign' or 'std' - how the perturbation direction
            is derived.

    Returns:
        Tensor of ``n_attackers`` identical rows built from the last
        successful scale (scale 0, i.e. the mean, if none succeeded).
    """
    benign_mean = torch.mean(all_updates, 0)
    if dev_type == 'unit_vec':
        direction = benign_mean / torch.norm(benign_mean)
    elif dev_type == 'sign':
        direction = torch.sign(benign_mean)
    elif dev_type == 'std':
        direction = torch.std(all_updates, 0)

    tolerance = 1e-5
    scale = torch.Tensor([3.0])
    step = scale
    best_scale = 0

    while torch.abs(best_scale - scale) > tolerance:
        crafted = benign_mean - scale * direction
        pool = torch.cat((torch.stack([crafted] * n_attackers), all_updates), 0)

        _, chosen = multi_krum(pool, n_attackers, multi_k=True)
        step = step / 2
        if np.sum(chosen < n_attackers) == n_attackers:
            best_scale = scale
            scale = scale + step
        else:
            scale = scale - step

    return torch.stack([benign_mean - best_scale * direction] * n_attackers)

def AGR_tailored_attack_on_median(all_updates, n_attackers, dev_type='unit_vec'):
    """Search for a perturbation scale that pushes the coordinate-wise median away from the mean.

    The malicious update is ``mean - scale * direction``. The scale starts at
    10 and moves by a step that halves every iteration; it is accepted (and
    increased) whenever the distance between the median of the poisoned set
    and the benign mean grew compared with the previous iteration.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        n_attackers: number of malicious updates.
        dev_type: 'unit_vec', 'sign' or 'std'.

    Returns:
        Tensor of ``n_attackers`` identical rows built from the last accepted scale.
    """
    benign_mean = torch.mean(all_updates, 0)

    if dev_type == 'unit_vec':
        direction = benign_mean / torch.norm(benign_mean)
    elif dev_type == 'sign':
        direction = torch.sign(benign_mean)
    elif dev_type == 'std':
        direction = torch.std(all_updates, 0)

    tolerance = 1e-5
    scale = torch.Tensor([10.0])
    step = scale
    best_scale = 0
    last_loss = -1

    while torch.abs(best_scale - scale) > tolerance:
        crafted = benign_mean - scale * direction
        pool = torch.cat((torch.stack([crafted] * n_attackers), all_updates), 0)
        loss = torch.norm(torch.median(pool, 0)[0] - benign_mean)

        step = step / 2
        if last_loss < loss:
            best_scale = scale
            scale = scale + step
        else:
            scale = scale - step
        last_loss = loss

    return torch.stack([benign_mean - best_scale * direction] * n_attackers)

def AGR_tailored_attack_on_trmean(all_updates, n_attackers, dev_type='unit_vec'):
    """Search for a perturbation scale that pushes ``tr_mean`` of the poisoned set away from the mean.

    Same halving-step search as the median variant, but the aggregate is
    ``tr_mean(poisoned_updates, n_attackers)``.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        n_attackers: number of malicious updates.
        dev_type: 'unit_vec', 'sign' or 'std'.

    Returns:
        Tensor of ``n_attackers`` identical rows built from the last accepted scale.
    """
    benign_mean = torch.mean(all_updates, 0)
    if dev_type == 'unit_vec':
        direction = benign_mean / torch.norm(benign_mean)
    elif dev_type == 'sign':
        direction = torch.sign(benign_mean)
    elif dev_type == 'std':
        direction = torch.std(all_updates, 0)

    tolerance = 1e-5
    scale = torch.Tensor([10.0])
    step = scale
    best_scale = 0
    last_loss = -1

    while torch.abs(best_scale - scale) > tolerance:
        crafted = benign_mean - scale * direction
        pool = torch.cat((torch.stack([crafted] * n_attackers), all_updates), 0)
        loss = torch.norm(tr_mean(pool, n_attackers) - benign_mean)

        step = step / 2
        if last_loss < loss:
            best_scale = scale
            scale = scale + step
        else:
            scale = scale - step
        last_loss = loss

    return torch.stack([benign_mean - best_scale * direction] * n_attackers)
###attack for AGR unknown
def min_max(all_updates,m,dev_type='unit_vec'):
    """Find the largest scale whose malicious update stays within the benign spread (max distance).

    The malicious update is ``mean - scale * direction``. A scale is accepted
    when the largest squared distance from the malicious update to any benign
    update does not exceed the largest squared distance between two benign
    updates. The scale starts at 10 and moves by a halving step.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        m: number of malicious updates to return.
        dev_type: 'unit_vec', 'sign' or 'std'.

    Returns:
        Tensor of ``m`` identical rows built from the last accepted scale.
    """
    benign_mean = torch.mean(all_updates, 0)
    if dev_type == 'unit_vec':
        direction = benign_mean / torch.norm(benign_mean)
    elif dev_type == 'sign':
        direction = torch.sign(benign_mean)
    elif dev_type == 'std':
        direction = torch.std(all_updates, 0)

    # Largest squared distance between any two benign updates.
    pairwise = torch.stack(
        [torch.norm(all_updates - row, dim=1) ** 2 for row in all_updates]
    )
    max_distance = torch.max(pairwise)
    del pairwise

    tolerance = 1e-5
    scale = torch.Tensor([10]).float()
    step = scale
    best_scale = 0

    while torch.abs(best_scale - scale) > tolerance:
        crafted = benign_mean - scale * direction
        worst = torch.max(torch.norm(all_updates - crafted, dim=1) ** 2)

        step = step / 2
        if worst <= max_distance:
            best_scale = scale
            scale = scale + step
        else:
            scale = scale - step

    return torch.stack([benign_mean - best_scale * direction] * m)

def min_sum(all_updates,m,dev_type='unit_vec'):
    """Find the largest scale whose malicious update stays within the benign spread (distance sum).

    The malicious update is ``mean - scale * direction``. A scale is accepted
    when the sum of squared distances from the malicious update to all benign
    updates does not exceed the smallest such sum computed for any benign
    update. The scale starts at 10 and moves by a halving step.

    Args:
        all_updates: 2-D tensor, one flat update per row.
        m: number of malicious updates to return.
        dev_type: 'unit_vec', 'sign' or 'std'.

    Returns:
        Tensor of ``m`` identical rows built from the last accepted scale.
    """
    benign_mean = torch.mean(all_updates, 0)

    if dev_type == 'unit_vec':
        direction = benign_mean / torch.norm(benign_mean)
    elif dev_type == 'sign':
        direction = torch.sign(benign_mean)
    elif dev_type == 'std':
        direction = torch.std(all_updates, 0)

    # Smallest per-update sum of squared distances to all benign updates.
    pairwise = torch.stack(
        [torch.norm(all_updates - row, dim=1) ** 2 for row in all_updates]
    )
    min_score = torch.min(torch.sum(pairwise, dim=1))
    del pairwise

    tolerance = 1e-5
    scale = torch.Tensor([10.0]).float()
    step = scale
    best_scale = 0

    while torch.abs(best_scale - scale) > tolerance:
        crafted = benign_mean - scale * direction
        score = torch.sum(torch.norm(all_updates - crafted, dim=1) ** 2)

        step = step / 2
        if score <= min_score:
            best_scale = scale
            scale = scale + step
        else:
            scale = scale - step

    return torch.stack([benign_mean - best_scale * direction] * m)
###     
def KL(P,Q,mask=None):
    """Return sum(P * (log(P + eps) - log(Q + eps))), optionally weighted by ``mask``.

    Args:
        P, Q: tensors of the same (broadcastable) shape.
        mask: optional tensor multiplied element-wise into the summand.

    Returns:
        0-dim tensor with the summed value.
    """
    eps = 0.0000001  # keeps log() finite for zero entries
    d = P * ((P+eps).log() - (Q+eps).log())
    # Identity check: `mask != None` would go through the tensor's `!=` overload.
    if mask is not None:
        d = d*mask
    return torch.sum(d)
def CE(P,Q,mask=None):
    """Return KL(P, Q) + KL(1 - P, 1 - Q), both using the same optional mask."""
    direct = KL(P, Q, mask)
    complement = KL(1 - P, 1 - Q, mask)
    return direct + complement

def umap(output, target, data_batch, eps=0.0000001):
    """Compare the pairwise-similarity structure of two models on one batch.

    A deep copy of ``target`` is loaded with the flat parameter vector
    ``output`` (through ``unflatten(output, update[0])``; ``update`` is a
    notebook-level global). Both networks are evaluated on ``data_batch``,
    their outputs are turned into row-normalised similarity matrices, and the
    result is ``CE(target_similarity, model_similarity)``.

    Args:
        output: flat parameter vector to evaluate.
        target: reference model (callable on ``data_batch``).
        data_batch: input batch fed to both networks.
        eps: added to the row norms before dividing.

    Returns:
        The value returned by ``CE``.
    """
    global update

    def unit_rows(features):
        # Divide every row by its L2 norm; NaN entries are replaced by 0.
        norms = torch.sqrt(torch.sum(features ** 2, dim=1, keepdim=True))
        features = features / (norms + eps)
        features[features != features] = 0
        return features

    def neighbour_probabilities(unit, size):
        diagonal = range(size)
        distance = 1 - torch.mm(unit, unit.transpose(0, 1))  # cosine distance
        # Mask the diagonal with a large value so the row minimum is the nearest other sample.
        distance[diagonal, diagonal] = 3
        distance = distance - torch.min(distance, dim=1)[0].view(-1, 1)
        distance[diagonal, diagonal] = 0
        similarity = 1 - distance
        similarity = (similarity + 1.0) / 2.0
        # Normalise each row by its sum.
        return similarity / torch.sum(similarity, dim=1, keepdim=True)

    test_net = copy.deepcopy(target)
    test_net.load_state_dict(unflatten(output, update[0]))
    output_net = test_net(data_batch)
    target_net = target(data_batch)

    n, _ = output_net.shape
    model_similarity = neighbour_probabilities(unit_rows(output_net), n)
    target_similarity = neighbour_probabilities(unit_rows(target_net), n)

    return CE(target_similarity, model_similarity)

def calculate_distance(w_glob,target):
    """Return the L2 norm of ``w_glob - target`` as a Python float."""
    gap = w_glob - target
    return torch.norm(gap, p=2).item()

def calculate_malicious(target_model,nusers,w_locals,choice,p,dataset_train):
    """Search for a malicious update by shrinking the scale ``p`` while the loss keeps improving.

    Each iteration builds a candidate ``w0`` from one of three primitives
    selected by ``choice`` (1: scaled Gaussian noise, 2: scaled target
    parameters, 3: scaled gap between the target and ``sum(w_locals)/nusers``),
    pads ``w_locals`` with copies of ``w0`` up to ``nusers`` rows, averages,
    and scores the average with the loss named by the global ``loss_name``
    ('umap' or 'l2'). ``p`` is divided by 1.1 while the score does not get
    worse; the search stops at the first worse score.

    NOTE(review): the returned ``w0`` is the candidate of the final iteration,
    i.e. the one whose score was worse than the best seen - confirm that this,
    rather than the best candidate, is intended.

    Args:
        target_model: model whose flattened state dict is the attack target.
        nusers: total number of rows the padded update matrix should have.
        w_locals: tensor of benign updates, one per row.
        choice: 1, 2 or 3 - attack primitive.
        p: initial scale.
        dataset_train: dataset from which the first batch is drawn.

    Returns:
        The last candidate vector ``w0``.
    """
    global loss_name
    decay = 1.1

    loader = DataLoader(dataset_train, batch_size=args.local_bs,drop_last=True)
    data_batch, _ = next(iter(loader))
    data_batch = data_batch.cuda()
    target_flat = flatten(target_model.state_dict()).cpu()

    best_dis = 1000000
    while True:
        # Three attack primitives
        if choice == 1:
            w0 = p * torch.randn_like(target_flat)
        elif choice == 2:
            w0 = p * target_flat
        elif choice == 3:
            # Note that here is a simulation in the full-knowledge setting
            w0 = p * (target_flat - torch.sum(w_locals, dim=0) / nusers)

        # Fill the remaining client slots with the candidate.
        padded = w_locals
        while len(padded) < nusers:
            if len(padded) == 0:
                padded = w0[None, :]
            else:
                padded = torch.cat((padded, w0[None, :]), 0)
        w_glob = torch.mean(padded,dim=0)

        if loss_name=='umap':
            now_dis = umap(w_glob, target_model, data_batch)
        elif loss_name=='l2':
            now_dis = calculate_distance(w_glob, target_flat)

        if not (now_dis<=best_dis):
            p*=decay
            break
        best_dis = now_dis
        p/=decay

    return w0



def fang_adap(all_para,n):
    """Craft ``n`` malicious rows against the ``crh`` aggregator by iterative updates.

    The malicious rows start as uniform samples between the per-coordinate
    minimum and maximum of ``all_para`` and are then moved for up to 50
    iterations by ``2 * l * (v_hat - v) * weight / sum(weights)``, where ``v``
    is the ``crh`` estimate for the initial poisoned set, ``v_hat`` the
    current estimate and ``weight`` the ``crh`` weight of the malicious row.

    Fixes relative to the previous version:
      * the initialisation loop reused ``y`` (the column count) as its loop
        variable, so each later row was filled on one fewer column and the
        update loop only touched the first ``d - n`` columns;
      * the locals ``max``/``min`` shadowed the builtins;
      * per-element Python loops are replaced by array operations.

    NOTE(review): if ``crh`` is deterministic, ``v_hat`` equals ``v`` on the
    first pass and the update term is zero - confirm that ``v`` is the
    intended reference point.

    Args:
        all_para: 2-D tensor of benign updates (assumed to live on the CPU,
            since the crafted rows are created from a numpy array).
        n: number of malicious rows.

    Returns:
        Tensor of shape (n, d) with the crafted rows.
    """
    l = 0.01  # step size of the update
    upper = torch.max(all_para,dim=0)[0]
    lower = torch.min(all_para,dim=0)[0]
    n_benign, dim = all_para.shape

    # One uniform draw per (row, coordinate) within the benign range of that coordinate.
    m = torch.from_numpy(
        np.random.uniform(
            lower.detach().cpu().numpy(),
            upper.detach().cpu().numpy(),
            size=(n, dim),
        )
    )
    para = torch.concat([all_para,m])
    v,_=crh(para)
    V = torch.zeros_like(para[0])
    for e in range(50):
        v_hat ,w = crh(para)
        # Weights of the malicious rows (they follow the benign rows in `para`).
        m_weight = w[n_benign:]
        m += 2*l*(m_weight/sum(w))[:, None]*(v_hat - v)[None, :]
        para = torch.concat([all_para,m])
        if torch.abs(torch.sum(v_hat-V))<1e-7:
            print("converge",e)
            V=copy.deepcopy(v_hat)
            break
        V=copy.deepcopy(v_hat)
    return m

In [None]:
#init
# Builds the dataset splits and the global model selected by args.model /
# args.dataset, then initialises the bookkeeping lists used by later cells.
from torchvision.models import vgg11,resnet18,ResNet18_Weights
# NOTE(review): args.gpu is read here; presumably defined by the argument parser - verify.
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
dict_users, train_dict, val_dict, dataset_train, dataset_test= Dataset_config(args.dataset, args.num_users, args.pattern)

if args.model == 'cnn' and args.dataset == 'cifar':
    net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
    net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
    # Input size of the MLP is the product of the sample's dimensions.
    len_in = 1
    img_size = dataset_train[0][0].shape
    for x in img_size:
        len_in *= x
    net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
elif args.model == 'resnet' and args.dataset == 'cifar':
    # torchvision ResNet-18 with its default weights; the final layer is replaced by a 10-way head.
    net_glob=resnet18(weights=ResNet18_Weights.DEFAULT)
    num_features=net_glob.fc.in_features
    net_glob.fc=nn.Linear(num_features,10)
    # `model` is an extra name for the network moved to args.device; later lines in this cell use `net_glob`.
    model=net_glob.to(args.device)
elif args.model == 'cnn' and args.dataset == 'fashion':
    net_glob = CNNfashion().to(args.device)   
elif args.model == 'vgg' and args.dataset == 'cifar':
    net_glob = VGG11().to(args.device)
else:
    exit('Error: unrecognized model')
print(net_glob)
# writer = SummaryWriter(log_dir='./logs')
# net_glob.requires_grad_()
net_glob.train()

# copy weights
w_glob = net_glob.state_dict()

# print(dataset_train.targets.numpy()[(dict_users[0])])
# print(w_glob)
# Uniform aggregation weight (1 / num_users) for every client.
avg_weights = torch.tensor([1/args.num_users for i in range(args.num_users)])
loss_train = []
val_acc_list, net_list = [], []
#test
test_acc=[]
fedavg_acc=[]
fedavg_loss=[]

In [None]:
##train
# Five rounds of plain averaging: every client trains from the same starting
# parameters, the flattened client weights are averaged, and the averaged
# model is evaluated on the test set. Per-client validation results of every
# round are collected in `local_acc` and saved to ./local_acc.npy.
local_acc=[]
args.local_ep = 10  # local epochs per client; set once instead of on every client iteration
for e in range(5):
    w_locals = []
    locals_acc=[]
    idxs_users = range(args.num_users)
    # Snapshot of the round's starting parameters, shared by all clients.
    pre_para = copy.deepcopy(net_glob.to(args.device).state_dict())
    for idx in idxs_users:
        net_glob.load_state_dict(pre_para)
        local = LocalUpdate(args=args, dataset=dataset_train, train=dict_users[idx])
        local_net, w, _ = local.train(net=copy.deepcopy(net_glob).to(args.device))
        w_locals.append(copy.deepcopy(w))
        # net_glob.load_state_dict(w)
        local_test, _ = local.val(local_net,args)
        # print(local_test)
        locals_acc.append(local_test)
    
    net_glob.to(args.device)
    #avg
    # (the previous deep copy of the state dict into global_para was dead code:
    # it was overwritten by this assignment before being read)
    global_para = unflatten(torch.mean(torch.stack([flatten(i).cpu() for i in w_locals]),dim=0),pre_para)

    net_glob.load_state_dict(global_para)


    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    test_acc.append(acc_test)
    print(acc_test)
    local_acc.append(locals_acc)
    print(locals_acc)
np.save('./local_acc.npy',local_acc)


In [None]:
#incentive

# Round-0 per-client accuracies as plain Python numbers (only used for the
# commented-out experiments below).
e= [i.item() for i in local_acc[0]]
# e= torch.stack([i for i in local_acc[0]])
# np.sort(e)[1]
def inc(l):
    """Compute a reward for every accuracy in ``l``.

    Every entry earns ``beta * value / sum(l)``; entries that do not exceed
    the second-smallest value additionally earn ``alpha * (second_smallest -
    value)``.

    Args:
        l: sequence of at least two numeric accuracies.

    Returns:
        List of rewards in the same order as ``l``.
    """
    alpha=0.05
    beta=160
    acc = np.sort(l)[1]  # second-smallest accuracy
    s=sum(l)
    return [
        alpha*(acc-i)+beta*i/s if i <=acc else beta*i/s
        for i in l
    ]
# print(e)
# Print the reward list of every recorded round.
for i in local_acc:
    print(inc([j.item() for j in i]))
# inc(e)
# for epoch in range(5): 
#     local_acc(epoch)

In [None]:
#only local train (for noniid & label flip)
# Trains every client for 50 local epochs on its own train/val split and saves
# the splits and the resulting client weights under `dir`.
local_acc=[]
loss_locals = []
w_locals = []
locals_acc=[]
idxs_users = range(args.num_users)
pre_para = copy.deepcopy(net_glob.to(args.device).state_dict())
for idx in idxs_users:
    args.local_ep = 50
    # Every client starts from the same snapshot taken before the loop.
    net_glob.load_state_dict(pre_para)
    local = LocalUpdate(args=args, dataset=dataset_train, train=train_dict[idx],val=val_dict[idx])
    # NOTE(review): train() is unpacked into four values here but into three in the
    # "##train" cell - confirm which LocalUpdate version this cell expects.
    local_net,w, loss,size = local.train(net=copy.deepcopy(net_glob).to(args.device))
    w_locals.append(copy.deepcopy(w))
    loss_locals.append(copy.deepcopy(loss))
    net_glob.load_state_dict(w)
    val_acc,val_loss = local.val(local_net,args)
    print(val_acc,val_loss)
    
# NOTE(review): `dir` shadows the builtin for the rest of the kernel session, and the
# target directory is not created here - np.save will fail if it does not exist.
dir = './mnist_noniid_mlp/0.9/'
np.save(dir+'50e_val_dict',val_dict)
np.save(dir+'50e_train_dict',train_dict)
np.save(dir+'50e_w_locals',w_locals)

In [None]:
#for fmpa
# Evaluates the FMPA-style attack (calculate_malicious) against plain averaging
# and several defences. The first m rows are replaced by the crafted update.
update=np.load('./result/cifar_cnn/noniid_0.5/50e_w_locals.npy',allow_pickle=True)[()]
w_locals = torch.stack([flatten(i).cpu() for i in update])
test_net = copy.deepcopy(net_glob)
# dataset_train
m=20
# b=args.num_users-m
# loss_name ='umap'
# loss_name='l2'
# `loss_name` is read as a global by calculate_malicious. It must iterate over a
# list of names: iterating over the string 'umap' yields 'u', 'm', 'a', 'p',
# none of which matches a known loss.
for loss_name in ['umap']:
    for choice in [1,2,3]:
        for x in range(5):
            print(loss_name,choice)
            p=calculate_malicious(net_glob,args.num_users,w_locals[m:],choice,10,dataset_train)
            poison=torch.vstack([p.reshape(1,-1)]*m)

            uw = torch.vstack([poison,w_locals[m:]])

            net_glob.load_state_dict(unflatten(torch.mean(w_locals[m:],dim=0),update[0]))
            print("benign-avg",test_img(net_glob,dataset_test,args)[0])

            net_glob.load_state_dict(unflatten(torch.mean(uw,dim=0),update[0]))
            print("poisoned-avg",test_img(net_glob,dataset_test,args)[0])

            #dnc
            benign_ids,final_w = dnc(uw,m)
            print("Dnc",benign_ids)
            net_glob.load_state_dict(unflatten(final_w,update[0]))
            # print(attack.__name__,'dnc',test_img(net_glob,dataset_test,args))
            print("Dnc",test_img(net_glob,dataset_test,args)[0])


            g,distance= svtd(uw)
            #tdfl
            # Keep updates whose cosine similarity to the estimate g is at least 0.95.
            cs=[]
            for idx in range(len(uw)):
                cs.append(torch.cosine_similarity(uw[idx],g,dim=0))
            cs = torch.stack(cs)
            print(cs)
            p= np.where(cs>=0.95)

            print("tdfl",p)
            net_glob.load_state_dict(unflatten(torch.mean(uw[p],dim=0),update[0]))
            # print(attack.__name__,'TDFL_cos',test_img(net_glob,dataset_test,args))
            print("tdfl",test_img(net_glob,dataset_test,args)[0])

            # one_crh
            # Threshold at mean(distance) - 0.5 and keep whichever side holds the majority.
            distance = distance.numpy()
            a=np.where(distance > np.mean(distance)-0.5)
            if len(a[0])>(uw.shape[0]/2):
                p1= a
            else : p1= np.delete(np.arange(len(uw)),a)
            net_glob.load_state_dict(unflatten(torch.mean(uw[p1],dim=0),update[0]))
            print("oneTD",test_img(net_glob,dataset_test,args)[0])

            #kmeanscrh
            # Two-cluster k-means on the distances; keep the larger cluster.
            re = KMeans(2, random_state=0, n_init="auto").fit(distance.reshape(-1,1)).labels_
            print(re)
            if len(np.where(re==0)[0])>len(np.where(re==1)[0]):
                p2= np.where(re==0)[0]
            else:
                p2= np.where(re==1)[0]
            net_glob.load_state_dict(unflatten(torch.mean(uw[p2],dim=0),update[0]))
            print("kTD",test_img(net_glob,dataset_test,args)[0])

#td
# net_glob.load_state_dict(unflatten(torch.mean(one_crh(uw),dim=0),update[0]))
# print("oneTD",test_img(net_glob,dataset_test,args)[0])
# net_glob.load_state_dict(unflatten(torch.mean(kmeans_crh(uw),dim=0),update[0]))
# print("kTD",test_img(net_glob,dataset_test,args)[0])


In [None]:
import os
import pandas as pd

# Table 1: test accuracy of every aggregation rule / defence under every attack.
# Results go to table1.npy, t1_record.npy (selected client indices) and 50epoch.xlsx.
ddir='./save/mnist_mlp_iid/'
os.makedirs(ddir,exist_ok=True)

# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
update=np.load('/home/k3ats/self-fed/result/mnist_mlp/iid/50e_w_locals.npy',allow_pickle=True)[()]
update_w = torch.stack([flatten(i).cpu() for i in update])
m=20
n_attacker = 20
print('----------------table1----------------')
table1_attack=[attack_median_and_trimmedmean,
                get_malicious_updates_fang,
                LIE,scaling_attack,
                mean_attack,full_mean_attack,
                min_max,min_sum]
sheet1=[]    
t1_record = []
for attack in table1_attack:            
    # The first m rows are replaced by the attack's output.
    poison = attack(update_w,m)
    uw = torch.vstack([poison,update_w[m:]])
    record=[]
    c_record = []
    #mean
    net_glob.load_state_dict(unflatten(torch.mean(uw,dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    # median
    net_glob.load_state_dict(unflatten(torch.median(uw,dim=0)[0],update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    #trim
    final_w = tr_mean(uw,n_attacker)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    # Krum and Bulyan are called with 20 attackers when m == 24, otherwise with n_attacker.
    capped_attackers = 20 if m == 24 else n_attacker

    #krum
    final_w = multi_krum_defence(uw,capped_attackers)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    #bulyan
    final_w = bulyan(uw,capped_attackers)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    #dnc
    benign_ids,final_w = dnc(uw,n_attacker)
    c_record.append(benign_ids)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    
    #TD
    g,distance= crh(uw)
    
    #TDFL-cos
    # Keep updates whose cosine similarity to the estimate g is at least 0.95.
    cs=[]
    for idx in range(len(uw)):
        cs.append(torch.cosine_similarity(uw[idx],g,dim=0))
    cs = torch.stack(cs)
    print(cs)
    p= np.where(cs>=0.95)
    
    c_record.append(p)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    #one_crh
    # Threshold at mean(distance) - 0.5 and keep whichever side holds the majority.
    distance = distance.numpy()
    a=np.where(distance > np.mean(distance)-0.5)
    if len(a[0])>(uw.shape[0]/2):
        p1= a
    else : p1= np.delete(np.arange(len(uw)),a)
    
    c_record.append(p1)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p1],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    #kmeanscrh
    # Two-cluster k-means on the distances; keep the larger cluster.
    re = KMeans(2, random_state=0, n_init="auto").fit(distance.reshape(-1,1)).labels_
    print(re)
    if len(np.where(re==0)[0])>len(np.where(re==1)[0]):
        p2= np.where(re==0)[0]
    else:
        p2= np.where(re==1)[0]
        
    c_record.append(p2)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p2],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    
    sheet1.append(copy.deepcopy(record))
    t1_record.append(copy.deepcopy(c_record))
np.save(ddir+'/table1.npy',sheet1)
np.save(ddir+'/t1_record.npy',t1_record)
#save to excel
f = pd.DataFrame(sheet1,
                index=['attack_median_and_trimmedmean',
                'get_malicious_updates_fang',
                'LIE','scaling_attack',
                'mean_attack','full_mean_attack','min_max','min_sum']
                ,columns=['mean','median','tr_mean','krum','bulyan','dnc','TDFL_cos','one_crh','kmeans_crh']
                ).astype(np.float64)
print(f)
# The writer is opened only once the table exists and is closed by the context
# manager, so a failure in the loop above no longer leaves an open workbook behind.
with pd.ExcelWriter(ddir+'/50epoch.xlsx') as writer:
    f.to_excel(writer,sheet_name="table1")

In [90]:
#threshold test
# Sweeps the one-TD threshold offset t against the number of attackers m under
# full_mean_attack and records the test accuracy of the filtered average.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
update = np.load('/home/k3ats/self-fed/result/mnist_mlp/noniid_0.5/50e_w_locals.npy',allow_pickle=True)[()]
ddir='./save/mnist_mlp_noniid_0.5/'
os.makedirs(ddir,exist_ok=True)
update_w = torch.stack([flatten(i).cpu() for i in update])
thresholds = [0.01,0.05,0.1,0.3,0.5,0.7,1]
attacker_counts = [1,2,3,4,5,10,15,20,21,22,23,24]
thres_record=[]
for t in thresholds:
    record=[]
    print(t)
    for m in attacker_counts:
        poison = full_mean_attack(update_w,m)
        uw = torch.vstack([poison,update_w[m:]])        
        g,distance= crh(uw)
        
        # Threshold at mean(distance) - t and keep whichever side holds the majority.
        distance = distance.numpy()
        a=np.where(distance > np.mean(distance)-t)
        if len(a[0])>(uw.shape[0]/2):
            p1= a
        else : p1= np.delete(np.arange(len(uw)),a)
        net_glob.load_state_dict(unflatten(torch.mean(uw[p1],dim=0),update[0]))
        record.append(test_img(net_glob,dataset_test,args)[0])
    thres_record.append(record)
np.save(ddir+'./threshold_record.npy',thres_record)
# Row/column labels are derived from the swept lists so they cannot drift apart.
f = pd.DataFrame(thres_record,
                 index=[str(value) for value in thresholds]
                 ,columns=[str(count) for count in attacker_counts]
                 ).astype(np.float64)
print(f)
# Context manager guarantees the workbook is saved and closed.
with pd.ExcelWriter(ddir+'thres_record.xlsx') as writer:
    f.to_excel(writer,sheet_name="table1")


0.01
0.05
0.1
0.3
0.5
0.7
1
          1      2      3      4      5     10     15     20     21     22   
0.01  73.36  74.54  76.43  76.52  76.71  76.33  74.99  74.17  74.81  74.69  \
0.05  74.00  75.14  76.43  76.52  76.71  76.33  74.99  74.17  74.81  74.69   
0.1   74.81  75.97  76.43  76.52  76.71  76.33  74.99  74.17  74.81  74.69   
0.3   76.12  75.97  76.43  76.52  76.71  76.33  74.99  74.17  74.81  74.69   
0.5   76.12  75.97  76.43  76.52  76.71  76.33  74.99  74.17  74.81  74.69   
0.7   76.12  75.97  76.43  76.52  76.71  76.33  74.99  74.17  74.81  64.98   
1     76.12  75.97  76.43  76.52  76.71  76.33  74.99  64.76  64.66  64.98   

         23     24  
0.01  73.85  73.74  
0.05  73.85  73.74  
0.1   73.85  73.74  
0.3   73.85  64.58  
0.5   65.21  64.58  
0.7   65.21  64.58  
1     65.21  64.58  


In [101]:
#correct for one-TD
# For a growing number of attackers m (mean_attack), compares: benign-only
# average, plain average of the poisoned set, one-TD filtering and k-means
# (K-TD) filtering.
update = np.load('./result/mnist_mlp/noniid_0.5/50e_w_locals.npy',allow_pickle=True)[()]
ddir='./save/mnist_mlp_noniid_0.5/'
os.makedirs(ddir,exist_ok=True)
update_w = torch.stack([flatten(i).cpu() for i in update])
attacker_counts = [5,10,15,20,21,22,23,24]
r_record=[]
for m in attacker_counts:
    record=[]
    poison = mean_attack(update_w,m)
    uw = torch.vstack([poison,update_w[m:]])  

    #without attack
    net_glob.load_state_dict(unflatten(torch.mean(update_w[m:],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    #mean
    net_glob.load_state_dict(unflatten(torch.mean(uw,dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    g,distance= crh(uw)
    # Threshold at mean(distance) - 0.1 and keep whichever side holds the majority.
    distance = distance.numpy()
    a=np.where(distance > np.mean(distance)-0.1)
    if len(a[0])>(uw.shape[0]/2):
        p1= a
    else : p1= np.delete(np.arange(len(uw)),a)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p1],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    # Two-cluster k-means on the distances; keep the larger cluster.
    re = KMeans(2, random_state=0, n_init="auto").fit(distance.reshape(-1,1)).labels_
    print(re)
    if len(np.where(re==0)[0])>len(np.where(re==1)[0]):
        p2= np.where(re==0)[0]
    else:
        p2= np.where(re==1)[0]
        
    net_glob.load_state_dict(unflatten(torch.mean(uw[p2],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    r_record.append(record)
np.save(ddir+'./mean_ratio.npy',r_record)
print(r_record)
# Row labels are derived from the swept list so they cannot drift apart.
f = pd.DataFrame(r_record,
                 index=[str(count) for count in attacker_counts]
                 ,columns=['withou attack','fedavg','One-TD','K-TD']
                 ).astype(np.float64)
print(f)
# Context manager guarantees the workbook is saved and closed.
with pd.ExcelWriter(ddir+'mean_ratio.xlsx') as writer:
    f.to_excel(writer,sheet_name="table1")

[1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[[76.71, 77.37, 76.71, 76.71], [76.33, 76.83, 76.33, 76.33], [74.99, 71.35, 74.99, 74.99], [74.17, 44.56, 74.17, 74.17], [74.81, 41.23, 74.81, 74.81], [74.69, 35.25, 74.69, 74.

In [120]:
#for byzantine attack
# Sweep the number of attackers `m` for mean_attack and record, per m, the test
# accuracy obtained with each aggregation rule (one row of r_record per m).

# CIFAR-10 / ResNet / non-IID client updates (per the file name).
# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load files
# produced by this project.
update = np.load('./50c_resnet_cifar10_noniid0.5.npy',allow_pickle=True)[()]
ddir='./save/cifar_resnet_noniid/'

os.makedirs(ddir,exist_ok=True)
# One flattened parameter vector per client update, stacked row-wise on the CPU.
update_w = torch.stack([flatten(i).cpu() for i in update])
r_record=[]
# Selected client indices (One-TD, then K-TD) for every m. Created here so the
# cell does not depend on a list left behind by another cell.
c_record=[]
for m in [5,10,15,20,24]:
    n_attacker=m
    record=[]

    # The first m rows are replaced by the poisoned updates.
    poison = mean_attack(update_w,m)
    uw = torch.vstack([poison,update_w[m:]])

    #mean
    net_glob.load_state_dict(unflatten(torch.mean(uw,dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    # median (torch.median returns (values, indices); keep the values)
    net_glob.load_state_dict(unflatten(torch.median(uw,dim=0)[0],update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    #trim
    final_w = tr_mean(uw,n_attacker)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    # Krum and Bulyan are both run with at most 20 assumed attackers.
    # NOTE(review): presumably because they cannot tolerate 24 of the clients
    # being malicious -- confirm against multi_krum_defence / bulyan.
    if m==24: n_attacker=20

    #krum
    final_w = multi_krum_defence(uw,n_attacker)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    #bulyan
    final_w = bulyan(uw,n_attacker)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    # (dnc and TDFL-cos are not evaluated in this sweep.)

    #TD
    # Only `distance` is used below; it is treated as one value per row of uw.
    g,distance= crh(uw)

    # one_crh: threshold the distances at (mean - 0.1) and keep whichever side
    # of the threshold holds more than half of the clients.
    distance = distance.numpy()
    a=np.where(distance > np.mean(distance)-0.1)
    if len(a[0])>(uw.shape[0]/2):
        p1= a
    else:
        p1= np.delete(np.arange(len(uw)),a)

    c_record.append(p1)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p1],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    #kmeanscrh: split the distances into two clusters and keep the larger one
    # (cluster 1 wins ties, as only a strict majority selects cluster 0).
    km_labels = KMeans(2, random_state=0, n_init="auto").fit(distance.reshape(-1,1)).labels_
    print(km_labels)
    majority = 0 if len(np.where(km_labels==0)[0])>len(np.where(km_labels==1)[0]) else 1
    p2= np.where(km_labels==majority)[0]

    c_record.append(p2)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p2],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])
    r_record.append(copy.deepcopy(record))

np.save(os.path.join(ddir,'m_attack.npy'),r_record)
print(r_record)
# Row labels are hard-coded, one per value of m above.
# NOTE(review): they read as the attacker share in percent of 50 clients
# (5 -> 10, ..., 24 -> 48) -- confirm.
f = pd.DataFrame(r_record,
                 index=['10','20','30','40','48']
                 ,columns=['mean','median','trim-mean','krum','bulyan','One-TD','K-TD']
                 ).astype(np.float64)
print(f)
# Open the workbook only once the results exist; the context manager closes it
# even if writing the sheet fails.
with pd.ExcelWriter(ddir+'m_attack.xlsx') as writer:
    f.to_excel(writer,sheet_name="table1")

[1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
[[10.0, 52.92, 53.69, 49.66, 53.09, 53.75, 53.75], [10.0, 52.92, 53.26, 49.66, 52.48, 53.73, 53.73], [10.0, 52.23, 52.35, 49.66, 52.1, 53.67, 53.67], [10.0, 50.45, 51.32, 49.66, 10.0, 53.75, 53.75], [10.0, 39.63, 46.26, 48.55, 10.0, 53.7, 53.7]]
    mean  median  trim-mean   krum  bulyan  One-TD   K-TD
10  10.0   52.92      53.69  49.66   53.09   53.75  53.75
20  10.0   52.92      53.26  49.66   52.48   53.73  53.73
30  10.0   52.23      52.35  49.66   52.10   53.67  53.67
40  10.

In [None]:
# exp
# For each attack in table1_attack, poison the first m client updates and record
# the test accuracy of the global model under several aggregation rules, plus
# the client indices each selection-based rule kept.
import os
import pandas as pd

# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load files
# produced by this project.
update = np.load('./result/mnist_mlp/iid/50e_w_locals.npy',allow_pickle=True)[()]
ddir='./save/mnist_mlp_iid/'
os.makedirs(ddir,exist_ok=True)

# One flattened parameter vector per client update, stacked row-wise on the CPU.
update_w = torch.stack([flatten(i).cpu() for i in update])
m=24
n_attacker = m
print('----------------table1----------------')
table1_attack=[AGR_tailored_attack_on_krum,AGR_tailored_attack_on_trmean,
               attack_median_and_trimmedmean,get_malicious_updates_fang,LIE]
sheet1=[]
t1_record=[]
for attack in table1_attack:
    # The first m rows are replaced by the poisoned updates.
    poison = attack(update_w,m)
    uw = torch.vstack([poison,update_w[m:]])
    record=[]
    c_record = []

    #mean
    net_glob.load_state_dict(unflatten(torch.mean(uw,dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    # (trim-mean, median, krum and bulyan are not evaluated in this table.)

    #dnc
    benign_ids,final_w = dnc(uw,n_attacker)
    c_record.append(benign_ids)
    net_glob.load_state_dict(unflatten(final_w,update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    #TD
    g,distance= crh(uw)

    #TDFL-cos: keep the updates whose cosine similarity to g is at least 0.95
    cs = torch.stack([torch.cosine_similarity(uw[idx],g,dim=0) for idx in range(len(uw))])
    print(cs)
    p= np.where(cs>=0.95)

    c_record.append(p)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    # one_crh: threshold the distances at (mean - 0.1) and keep whichever side
    # of the threshold holds more than half of the clients.
    distance = distance.numpy()
    a=np.where(distance > np.mean(distance)-0.1)
    if len(a[0])>(uw.shape[0]/2):
        p1= a
    else:
        p1= np.delete(np.arange(len(uw)),a)

    c_record.append(p1)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p1],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    #kmeanscrh: split the distances into two clusters and keep the larger one
    # (cluster 1 wins ties, as only a strict majority selects cluster 0).
    km_labels = KMeans(2, random_state=0, n_init="auto").fit(distance.reshape(-1,1)).labels_
    print(km_labels)
    majority = 0 if len(np.where(km_labels==0)[0])>len(np.where(km_labels==1)[0]) else 1
    p2= np.where(km_labels==majority)[0]

    c_record.append(p2)
    net_glob.load_state_dict(unflatten(torch.mean(uw[p2],dim=0),update[0]))
    record.append(test_img(net_glob,dataset_test,args)[0])

    sheet1.append(copy.deepcopy(record))
    t1_record.append(copy.deepcopy(c_record))

np.save(os.path.join(ddir,'table1.npy'),sheet1)
# The selections kept by each rule generally differ in length, and NumPy >= 1.24
# refuses to build an array from such a ragged nested list. Store each attack's
# list of selections as one element of an object array instead (reading the file
# back requires allow_pickle=True).
t1_selected = np.empty(len(t1_record),dtype=object)
for row, picked in enumerate(t1_record):
    t1_selected[row] = picked
np.save(os.path.join(ddir,'t1_record.npy'),t1_selected)

#save to excel
f = pd.DataFrame(sheet1,
                 index=['AGR_tailored_attack_on_krum','AGR_tailored_attack_on_trmean','attack_median_and_trimmedmean',
                        'get_malicious_updates_fang','LIE']
                 ,columns=['mean','dnc','TDFL_cos','one_TD','K-TD']
                 ).astype(np.float64)
print(f)
# Open the workbook only once the results exist; the context manager closes it
# even if writing the sheet fails.
with pd.ExcelWriter(ddir+'result.xlsx') as writer:
    f.to_excel(writer,sheet_name="table1")

In [None]:
import os
import pandas as pd

# Build three accuracy tables (attack x aggregation rule) for the CIFAR non-IID
# updates and write them to .npy files and to one Excel workbook.

# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load files
# produced by this project.
update = np.load('./50epoch_cifar_noniid0.5_50c.npy',allow_pickle=True)[()]

# One flattened parameter vector per client update, stacked row-wise on the CPU.
update_w = torch.stack([flatten(i).cpu() for i in update])
m=20
n_attacker = m
# Named ddir, as in the other experiment cells, so the builtin dir() is not shadowed.
ddir='./cifar_noniid/'
# The workbook and .npy files are written here; make sure the folder exists.
os.makedirs(ddir,exist_ok=True)


def _test_accuracy(flat_w):
    """Load the flattened weights into net_glob and return test_img(...)[0]."""
    net_glob.load_state_dict(unflatten(flat_w,update[0]))
    # NOTE(review): other cells call test_img without test_sampler -- confirm
    # which signature is current.
    return test_img(net_glob,dataset_test,args,test_sampler)[0]


def _dnc_weights(uw):
    """Return the aggregated weights from dnc().

    Another cell unpacks dnc() as (benign_ids, final_w); accept that form as
    well as a bare weight vector.
    """
    out = dnc(uw,n_attacker)
    return out[1] if isinstance(out,(tuple,list)) else out


# Column name -> function mapping the stacked updates to one aggregated
# flattened weight vector.
_AGGREGATORS = {
    'mean': lambda uw: torch.mean(uw,dim=0),
    'tr_mean': lambda uw: tr_mean(uw,n_attacker),
    # torch.median returns (values, indices); keep the values
    'median': lambda uw: torch.median(uw,dim=0)[0],
    'krum': lambda uw: multi_krum_defence(uw,n_attacker),
    'bulyan': lambda uw: bulyan(uw,n_attacker),
    'dnc': _dnc_weights,
    'TDFL_cos': lambda uw: torch.mean(uw[TDFL_cos(uw,0.95)],dim=0),
    'one_crh': lambda uw: torch.mean(uw[one_crh(uw)],dim=0),
    'kmeans_crh': lambda uw: torch.mean(uw[kmeans_crh(uw)],dim=0),
}


def _run_table(attacks, columns):
    """Return one row of accuracies per attack, one entry per column name."""
    rows = []
    for attack in attacks:
        # The first m rows are replaced by the poisoned updates.
        poison = attack(update_w,m)
        uw = torch.vstack([poison,update_w[m:]])
        rows.append([_test_accuracy(_AGGREGATORS[col](uw)) for col in columns])
    return rows


# (sheet name, attack functions, row labels, column names)
_TABLES = [
    ('table1',
     [attack_median_and_trimmedmean,AGR_tailored_attack_on_trmean,LIE,scaling_attack,mean_attack,full_mean_attack,min_max,min_sum],
     ['attack_median_and_trimmedmean','AGR_tailored_attack_on_trmean','LIE','scaling_attack','mean_attack','full_mean_attack','min_max','min_sum'],
     ['mean','tr_mean','dnc','TDFL_cos','one_crh','kmeans_crh']),
    ('table2',
     [attack_median_and_trimmedmean,AGR_tailored_attack_on_median,LIE,scaling_attack,mean_attack,full_mean_attack,min_max,min_sum],
     ['attack_median_and_trimmedmean','AGR_tailored_attack_on_median','LIE','scaling_attack','mean_attack','full_mean_attack','min_max','min_sum'],
     ['mean','median','dnc','TDFL_cos','one_crh','kmeans_crh']),
    ('table3',
     [get_malicious_updates_fang,AGR_tailored_attack_on_krum,LIE,scaling_attack,mean_attack,full_mean_attack,min_max,min_sum],
     ['get_malicious_updates_fang','AGR_tailored_attack_on_krum','LIE','scaling_attack','mean_attack','full_mean_attack','min_max','min_sum'],
     ['mean','krum','bulyan','dnc','TDFL_cos','one_crh','kmeans_crh']),
]

sheets = {}
# The context manager closes the workbook even if a later table fails.
with pd.ExcelWriter(ddir+'50epoch.xlsx') as writer:
    for sheet_name, attacks, labels, columns in _TABLES:
        print('----------------'+sheet_name+'----------------')
        sheets[sheet_name] = _run_table(attacks, columns)
        np.save(ddir+sheet_name+'.npy',sheets[sheet_name])
        #save to excel
        f = pd.DataFrame(sheets[sheet_name],
                         index=labels
                         ,columns=columns
                         ).astype(np.float64)
        print(f)
        f.to_excel(writer,sheet_name=sheet_name)

# Keep the per-table result lists available under their original names.
sheet1, sheet2, sheet3 = sheets['table1'], sheets['table2'], sheets['table3']
    