# The notebook contains
### Code for _Median_ aggregation algorithm
### Evaluation of all of the attacks (Fang, LIE, and our SOTA AGR-tailored and AGR-agnstic) on Median

In [3]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [1]:
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'./../utils/')
from logger import *
from eval import *
from misc import *

from cifar10_normal_train import *
from cifar10_util import *
from adam import Adam
from sgd import SGD

## Get CIFAR10 data and split it in IID fashion

In [4]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
data_loc='/mnt/nfs/work1/amir/vshejwalkar/cifar10_data/'
# load the train dataset

train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=train_transform)

cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=train_transform)

X=[]
Y=[]
for i in range(len(cifar10_train)):
    X.append(cifar10_train[i][0].numpy())
    Y.append(cifar10_train[i][1])

for i in range(len(cifar10_test)):
    X.append(cifar10_test[i][0].numpy())
    Y.append(cifar10_test[i][1])

X=np.array(X)
Y=np.array(Y)

print('total data len: ',len(X))

if not os.path.isfile('./cifar10_shuffle.pkl'):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    pickle.dump(all_indices,open('./cifar10_shuffle.pkl','wb'))
else:
    all_indices=pickle.load(open('./cifar10_shuffle.pkl','rb'))

X=X[all_indices]
Y=Y[all_indices]

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# data loading

nusers=50
user_tr_len=1000

total_tr_len=user_tr_len*nusers
val_len=5000
te_len=5000

print('total data len: ',len(X))

if not os.path.isfile('./cifar10_shuffle.pkl'):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    pickle.dump(all_indices,open('./cifar10_shuffle.pkl','wb'))
else:
    all_indices=pickle.load(open('./cifar10_shuffle.pkl','rb'))

total_tr_data=X[:total_tr_len]
total_tr_label=Y[:total_tr_len]

val_data=X[total_tr_len:(total_tr_len+val_len)]
val_label=Y[total_tr_len:(total_tr_len+val_len)]

te_data=X[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]
te_label=Y[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]

total_tr_data_tensor=torch.from_numpy(total_tr_data).type(torch.FloatTensor)
total_tr_label_tensor=torch.from_numpy(total_tr_label).type(torch.LongTensor)

val_data_tensor=torch.from_numpy(val_data).type(torch.FloatTensor)
val_label_tensor=torch.from_numpy(val_label).type(torch.LongTensor)

te_data_tensor=torch.from_numpy(te_data).type(torch.FloatTensor)
te_label_tensor=torch.from_numpy(te_label).type(torch.LongTensor)

print('total tr len %d | val len %d | test len %d'%(len(total_tr_data_tensor),len(val_data_tensor),len(te_data_tensor)))

#==============================================================================================================

user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(nusers):
    
    user_tr_data_tensor=torch.from_numpy(total_tr_data[user_tr_len*i:user_tr_len*(i+1)]).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(total_tr_label[user_tr_len*i:user_tr_len*(i+1)]).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

total data len:  60000
total tr len 50000 | val len 5000 | test len 5000
user 0 tr len 1000
user 1 tr len 1000
user 2 tr len 1000
user 3 tr len 1000
user 4 tr len 1000
user 5 tr len 1000
user 6 tr len 1000
user 7 tr len 1000
user 8 tr len 1000
user 9 tr len 1000
user 10 tr len 1000
user 11 tr len 1000
user 12 tr len 1000
user 13 tr len 1000
user 14 tr len 1000
user 15 tr len 1000
user 16 tr len 1000
user 17 tr len 1000
user 18 tr len 1000
user 19 tr len 1000
user 20 tr len 1000
user 21 tr len 1000
user 22 tr len 1000
user 23 tr len 1000
user 24 tr len 1000
user 25 tr len 1000
user 26 tr len 1000
user 27 tr len 1000
user 28 tr len 1000
user 29 tr len 1000
user 30 tr len 1000
user 31 tr len 1000
user 32 tr len 1000
user 33 tr len 1000
user 34 tr len 1000
user 35 tr len 1000
user 36 tr len 1000
user 37 tr len 1000
user 38 tr len 1000
user 39 tr len 1000
user 40 tr len 1000
user 41 tr len 1000
user 42 tr len 1000
user 43 tr len 1000
user 44 tr len 1000
user 45 tr len 1000
user 46 tr len 10

## Full knowledge Fang attack on median aggregation
### Note that Fang attacks on Trimmed-mean and median are the same

In [7]:
def get_malicious_updates_fang_trmean(all_updates, deviation, n_attackers, epoch_num, compression='none', q_level=2, norm='inf'):
    b = 2
    max_vector = torch.max(all_updates, 0)[0]
    min_vector = torch.min(all_updates, 0)[0]

    max_ = (max_vector > 0).type(torch.FloatTensor).cuda()
    min_ = (min_vector < 0).type(torch.FloatTensor).cuda()

    max_[max_ == 1] = b
    max_[max_ == 0] = 1 / b
    min_[min_ == 1] = b
    min_[min_ == 0] = 1 / b

    max_range = torch.cat((max_vector[:, None], (max_vector * max_)[:, None]), dim=1)
    min_range = torch.cat(((min_vector * min_)[:, None], min_vector[:, None]), dim=1)

    rand = torch.from_numpy(np.random.uniform(0, 1, [len(deviation), n_attackers])).type(torch.FloatTensor).cuda()

    max_rand = torch.stack([max_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T
    min_rand = torch.stack([min_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T

    mal_vec = (torch.stack([(deviation > 0).type(torch.FloatTensor)] * max_rand.shape[1]).T.cuda() * max_rand + torch.stack(
        [(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]).T.cuda() * min_rand).T

    quant_mal_vec = []
    if compression != 'none':
        if epoch_num == 0: print('compressing malicious update')
        for i in range(mal_vec.shape[0]):
            mal_ = mal_vec[i]
            mal_quant = qsgd(mal_, s=q_level, norm=norm)
            quant_mal_vec = mal_quant[None, :] if not len(quant_mal_vec) else torch.cat((quant_mal_vec, mal_quant[None, :]), 0)
    else:
        quant_mal_vec = mal_vec

    mal_updates = torch.cat((quant_mal_vec, all_updates), 0)

    return mal_updates

In [10]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='median'
multi_k = False
candidates = []

at_type='fang'
z_values=[0.0]
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    for z in z_values:
        fed_file='alexnet_checkpoint_%s_%s_%d_%.2f.pth.tar'%(aggregation,at_type,n_attacker,z)
        fed_best_file='alexnet_best_%s_%s_%d_%.2f.pth.tar'%(aggregation,at_type,n_attacker,z)

        if resume:
            fed_checkpoint = chkpt+'/'+fed_file
            assert os.path.isfile(fed_checkpoint), 'Error: no user checkpoint at %s'%(fed_checkpoint)
            checkpoint = torch.load(fed_checkpoint, map_location='cuda:%d'%torch.cuda.current_device())
            fed_model.load_state_dict(checkpoint['state_dict'])
            optimizer_fed.load_state_dict(checkpoint['optimizer'])
            resume = 0
            best_global_acc=checkpoint['best_acc']
            best_global_te_acc=checkpoint['best_te_acc']
            val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
            epoch_num += checkpoint['epoch']
            print('resuming from epoch %d | val acc %.4f | best acc %.3f | best te acc %.3f'%(epoch_num, val_acc, best_global_acc, best_global_te_acc))

        torch.cuda.empty_cache()
        r=np.arange(user_tr_len)

        fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
        optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

        while epoch_num <= nepochs:
            user_grads=[]
            if not epoch_num and epoch_num%nbatches == 0:
                np.random.shuffle(r)
                for i in range(nusers):
                    user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                    user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

            for i in range(n_attacker, nusers):

                inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
                targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                fed_model.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

            malicious_grads = user_grads

            if epoch_num in schedule:
                for param_group in optimizer_fed.param_groups:
                    param_group['lr'] *= gamma
                    print('New learnin rate ', param_group['lr'])

            if n_attacker > 0:
                if at_type == 'paf':
                    malicious_grads=get_malicious_predictions_poison_all_far_sign(malicious_grads,nusers,n_attacker)
                elif at_type == 'lie':
                    malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
                elif at_type == 'fang':
                    agg_grads = torch.mean(malicious_grads, 0)
                    deviation = torch.sign(agg_grads)
                    malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
                elif at_type == 'our':
                    agg_grads = torch.mean(malicious_grads, 0)
                    malicious_grads = our_attack_krum(malicious_grads, agg_grads, n_attacker, compression=compression, q_level=q_level, norm=norm)

            if not epoch_num : 
                print(malicious_grads.shape)
                
            if aggregation=='median':
                agg_grads=torch.median(malicious_grads,dim=0)[0]

            elif aggregation=='average':
                agg_grads=torch.mean(malicious_grads,dim=0)

            elif aggregation=='trmean':
                agg_grads=tr_mean(malicious_grads, n_attacker)

            elif aggregation=='krum' or aggregation=='mkrum':
                multi_k = True if aggregation == 'mkrum' else False
                if epoch_num == 0: print('multi krum is ', multi_k)
                agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)
                
            elif aggregation=='bulyan':
                agg_grads, krum_candidate=bulyan(malicious_grads, n_attacker)

            del user_grads

            start_idx=0

            optimizer_fed.zero_grad()

            model_grads=[]

            for i, param in enumerate(fed_model.parameters()):
                param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
                start_idx=start_idx+len(param.data.view(-1))
                param_=param_.cuda()
                model_grads.append(param_)

            optimizer_fed.step(model_grads)

            val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
            te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

            is_best = best_global_acc < val_acc

            best_global_acc = max(best_global_acc, val_acc)

            if is_best:
                best_global_te_acc = te_acc

            if epoch_num%10==0 or epoch_num==nepochs-1:
                print('%s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

            if val_loss > 10:
                print('val loss %f too high'%val_loss)
                break
            
            epoch_num+=1

torch.Size([50, 2472266])


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729138878/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  p.data.add_(-group['lr'], d_p)


median: at fang n_at 10 e 0 fed_model val loss 2.3026 val acc 10.5925 best val_acc 10.592532 te_acc 10.146104
median: at fang n_at 10 e 10 fed_model val loss 2.2991 val acc 16.3149 best val_acc 16.375812 te_acc 16.923701
median: at fang n_at 10 e 20 fed_model val loss 2.2894 val acc 19.4602 best val_acc 20.860390 te_acc 21.124188
median: at fang n_at 10 e 30 fed_model val loss 2.3053 val acc 10.6128 best val_acc 20.880682 te_acc 21.245942
median: at fang n_at 10 e 40 fed_model val loss 3.7148 val acc 9.9026 best val_acc 23.072240 te_acc 23.092532
median: at fang n_at 10 e 50 fed_model val loss 2.3023 val acc 10.4708 best val_acc 23.072240 te_acc 23.092532
median: at fang n_at 10 e 60 fed_model val loss 2.2964 val acc 9.2938 best val_acc 23.072240 te_acc 23.092532
median: at fang n_at 10 e 70 fed_model val loss 2.2697 val acc 12.4391 best val_acc 23.072240 te_acc 23.092532
median: at fang n_at 10 e 80 fed_model val loss 2.2708 val acc 14.7119 best val_acc 23.072240 te_acc 23.092532
medi

median: at fang n_at 10 e 740 fed_model val loss 1.4620 val acc 47.1997 best val_acc 48.640422 te_acc 48.417208
median: at fang n_at 10 e 750 fed_model val loss 1.4385 val acc 47.9708 best val_acc 48.640422 te_acc 48.417208
median: at fang n_at 10 e 760 fed_model val loss 1.4782 val acc 46.4692 best val_acc 48.640422 te_acc 48.417208
median: at fang n_at 10 e 770 fed_model val loss 1.4658 val acc 46.6112 best val_acc 48.741883 te_acc 49.127435
median: at fang n_at 10 e 780 fed_model val loss 1.4926 val acc 46.2662 best val_acc 48.741883 te_acc 49.127435
median: at fang n_at 10 e 790 fed_model val loss 1.5236 val acc 44.9067 best val_acc 48.741883 te_acc 49.127435
median: at fang n_at 10 e 800 fed_model val loss 1.4818 val acc 45.5560 best val_acc 48.741883 te_acc 49.127435
median: at fang n_at 10 e 810 fed_model val loss 1.4132 val acc 48.3563 best val_acc 48.741883 te_acc 49.127435
median: at fang n_at 10 e 820 fed_model val loss 1.4821 val acc 45.7386 best val_acc 48.741883 te_acc 49

## Code for LIE attack

In [9]:
def lie_attack(all_updates, z):
    avg = torch.mean(all_updates, dim=0)
    std = torch.std(all_updates, dim=0)
    return avg + z * std

In [11]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='median'
multi_k = False
candidates = []

at_type='LIE'
z_values={3:0.69847, 5:0.7054, 8:0.71904, 10:0.72575, 12:0.73891}
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

#     fed_file='alexnet_checkpoint_%s_%s_%d_%.2f.pth.tar'%(aggregation,at_type,n_attacker,z)
#     fed_best_file='alexnet_best_%s_%s_%d_%.2f.pth.tar'%(aggregation,at_type,n_attacker,z)

#     if resume:
#         fed_checkpoint = chkpt+'/'+fed_file
#         assert os.path.isfile(fed_checkpoint), 'Error: no user checkpoint at %s'%(fed_checkpoint)
#         checkpoint = torch.load(fed_checkpoint, map_location='cuda:%d'%torch.cuda.current_device())
#         fed_model.load_state_dict(checkpoint['state_dict'])
#         optimizer_fed.load_state_dict(checkpoint['optimizer'])
#         resume = 0
#         best_global_acc=checkpoint['best_acc']
#         best_global_te_acc=checkpoint['best_te_acc']
#         val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
#         epoch_num += checkpoint['epoch']
#         print('resuming from epoch %d | val acc %.4f | best acc %.3f | best te acc %.3f'%(epoch_num, val_acc, best_global_acc, best_global_te_acc))

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z_values[n_attacker])
                malicious_grads = torch.cat((torch.stack([mal_update]*n_attacker), malicious_grads))
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                malicious_grads = our_attack_krum(malicious_grads, agg_grads, n_attacker, compression=compression, q_level=q_level, norm=norm)

        if not epoch_num : 
            print(malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads, krum_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 10:
            print('val loss %f too high'%val_loss)
            break

        epoch_num+=1

torch.Size([40, 2472266])


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729138878/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  p.data.add_(-group['lr'], d_p)


median: at LIE n_at 10 e 0 fed_model val loss 2.3035 val acc 9.5982 best val_acc 9.598214 te_acc 9.577922
median: at LIE n_at 10 e 10 fed_model val loss 2.2962 val acc 13.7378 best val_acc 13.879870 te_acc 13.940747
median: at LIE n_at 10 e 20 fed_model val loss 2.2283 val acc 18.3036 best val_acc 20.698052 te_acc 20.515422
median: at LIE n_at 10 e 30 fed_model val loss 2.2397 val acc 12.5203 best val_acc 20.698052 te_acc 20.515422
median: at LIE n_at 10 e 40 fed_model val loss 2.6411 val acc 9.5576 best val_acc 23.478084 te_acc 23.498377
median: at LIE n_at 10 e 50 fed_model val loss 2.3050 val acc 9.9026 best val_acc 23.478084 te_acc 23.498377
median: at LIE n_at 10 e 60 fed_model val loss 2.3034 val acc 9.9026 best val_acc 23.478084 te_acc 23.498377
median: at LIE n_at 10 e 70 fed_model val loss 2.3024 val acc 9.9026 best val_acc 23.478084 te_acc 23.498377
median: at LIE n_at 10 e 80 fed_model val loss 2.3003 val acc 11.9927 best val_acc 23.478084 te_acc 23.498377
median: at LIE n_a

median: at LIE n_at 10 e 740 fed_model val loss 1.3391 val acc 51.2581 best val_acc 55.864448 te_acc 55.823864
median: at LIE n_at 10 e 750 fed_model val loss 1.4453 val acc 48.7013 best val_acc 55.965909 te_acc 56.493506
median: at LIE n_at 10 e 760 fed_model val loss 1.2891 val acc 54.5657 best val_acc 55.965909 te_acc 56.493506
median: at LIE n_at 10 e 770 fed_model val loss 1.4783 val acc 47.4026 best val_acc 55.965909 te_acc 56.493506
median: at LIE n_at 10 e 780 fed_model val loss 1.3229 val acc 52.1307 best val_acc 55.965909 te_acc 56.493506
median: at LIE n_at 10 e 790 fed_model val loss 1.2288 val acc 55.7630 best val_acc 55.965909 te_acc 56.493506
median: at LIE n_at 10 e 800 fed_model val loss 1.4126 val acc 50.1826 best val_acc 56.818182 te_acc 57.183442
median: at LIE n_at 10 e 810 fed_model val loss 1.2332 val acc 55.9862 best val_acc 56.818182 te_acc 57.183442
median: at LIE n_at 10 e 820 fed_model val loss 1.2230 val acc 56.2500 best val_acc 56.818182 te_acc 57.183442
m

## Code for our AGR-tailored attack on Median

In [18]:
def our_attack_median(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([10.0]).cuda() #compute_lambda_our(all_updates, model_re, n_attackers)

    threshold_diff = 1e-5
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0
    iters = 0 
    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = torch.median(mal_updates, 0)[0]
        
        loss = torch.norm(agg_grads - model_re)
        
        if prev_loss < loss:
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss
        
    mal_update = (model_re - lamda_succ * deviation)
    mal_updates = torch.stack([mal_update] * n_attackers)
    mal_updates = torch.cat((mal_updates, all_updates), 0)

    return mal_updates

In [19]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='median'
multi_k = False
candidates = []

at_type='our-agr'
dev_type = 'std'
z_values={3:0.69847, 5:0.7054, 8:0.71904, 10:0.72575, 12:0.73891}
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z_values[n_attacker])
                malicious_grads = torch.cat((torch.stack([mal_update]*n_attacker), malicious_grads))
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                malicious_grads = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type=dev_type)

        if not epoch_num : 
            print(malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads, krum_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 10:
            print('val loss %f too high'%val_loss)
            break

        epoch_num+=1

torch.Size([50, 2472266])
median: at our-agr n_at 10 e 0 fed_model val loss 2.3026 val acc 9.6794 best val_acc 9.679383 te_acc 10.064935
median: at our-agr n_at 10 e 10 fed_model val loss 2.2910 val acc 14.0219 best val_acc 14.021916 te_acc 14.184253
median: at our-agr n_at 10 e 20 fed_model val loss 2.3668 val acc 11.8101 best val_acc 20.819805 te_acc 20.616883
median: at our-agr n_at 10 e 30 fed_model val loss 2.2703 val acc 16.6599 best val_acc 20.819805 te_acc 20.616883
median: at our-agr n_at 10 e 40 fed_model val loss 2.3029 val acc 10.5114 best val_acc 22.666396 te_acc 23.011364
median: at our-agr n_at 10 e 50 fed_model val loss 2.2392 val acc 10.8766 best val_acc 22.666396 te_acc 23.011364
median: at our-agr n_at 10 e 60 fed_model val loss 2.1888 val acc 21.3271 best val_acc 22.666396 te_acc 23.011364
median: at our-agr n_at 10 e 70 fed_model val loss 2.2922 val acc 14.2451 best val_acc 22.666396 te_acc 23.011364
median: at our-agr n_at 10 e 80 fed_model val loss 2.2073 val acc

## Code for our first AGR-agnostic attack - Min-max

In [20]:
'''
MIN-MAX attack
'''
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([10.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    max_distance = torch.max(distances)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        max_d = torch.max(distance)
        
        if max_d <= max_distance:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

In [21]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='median'
multi_k = False
candidates = []

at_type='min-max'
dev_type ='std'
z=0
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    candidates = []

    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads,bulyan_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d | e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 1000:
            print('val loss %f too high'%val_loss)
            break
            
        epoch_num+=1

malicious_grads shape  torch.Size([50, 2472266])
median: at min-max n_at 10 | e 0 fed_model val loss 2.3030 val acc 9.9432 best val_acc 9.943182 te_acc 9.253247
median: at min-max n_at 10 | e 10 fed_model val loss 2.2905 val acc 12.6420 best val_acc 14.894481 te_acc 16.193182
median: at min-max n_at 10 | e 20 fed_model val loss 2.2952 val acc 11.8912 best val_acc 19.805195 te_acc 19.947240
median: at min-max n_at 10 | e 30 fed_model val loss 2.3011 val acc 12.7232 best val_acc 21.530032 te_acc 21.875000
median: at min-max n_at 10 | e 40 fed_model val loss 2.1932 val acc 16.7208 best val_acc 21.530032 te_acc 21.875000
median: at min-max n_at 10 | e 50 fed_model val loss 2.3039 val acc 10.8969 best val_acc 21.530032 te_acc 21.875000
median: at min-max n_at 10 | e 60 fed_model val loss 2.2325 val acc 14.8133 best val_acc 21.530032 te_acc 21.875000
median: at min-max n_at 10 | e 70 fed_model val loss 2.1430 val acc 18.7703 best val_acc 21.530032 te_acc 21.875000
median: at min-max n_at 10 

median: at min-max n_at 10 | e 700 fed_model val loss 2.2635 val acc 15.7265 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 710 fed_model val loss 2.2825 val acc 13.0885 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 720 fed_model val loss 2.1243 val acc 20.2719 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 730 fed_model val loss 2.1638 val acc 19.2776 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 740 fed_model val loss 2.1576 val acc 17.9586 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 750 fed_model val loss 2.0607 val acc 21.6112 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 760 fed_model val loss 2.0455 val acc 22.6258 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 770 fed_model val loss 2.1343 val acc 18.4659 best val_acc 28.226461 te_acc 27.232143
median: at min-max n_at 10 | e 780 fed_model val loss 2.1502 val

## Code for our second AGR-agnostic attack - Min-sum

In [22]:
'''
MIN-SUM attack
'''

def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    
    lamda = torch.Tensor([10.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    scores = torch.sum(distances, dim=1)
    min_score = torch.min(scores)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        score = torch.sum(distance)
        
        if score <= min_score:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    # print(lamda_succ)
    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update
    

In [23]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='median'

at_type='min-sum'
dev_type ='std'
z=0
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    candidates = []

    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads,bulyan_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d | e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 1000:
            print('val loss %f too high'%val_loss)
            break
            
        epoch_num+=1

malicious_grads shape  torch.Size([50, 2472266])
median: at min-sum n_at 10 | e 0 fed_model val loss 2.3024 val acc 9.8011 best val_acc 9.801136 te_acc 10.491071
median: at min-sum n_at 10 | e 10 fed_model val loss 2.2846 val acc 16.3758 best val_acc 18.790584 te_acc 19.419643
median: at min-sum n_at 10 | e 20 fed_model val loss 2.3058 val acc 9.9432 best val_acc 22.321429 te_acc 21.672078
median: at min-sum n_at 10 | e 30 fed_model val loss 2.2668 val acc 9.9432 best val_acc 22.321429 te_acc 21.672078
median: at min-sum n_at 10 | e 40 fed_model val loss 2.2428 val acc 15.5438 best val_acc 22.321429 te_acc 21.672078
median: at min-sum n_at 10 | e 50 fed_model val loss 2.2125 val acc 16.8222 best val_acc 22.321429 te_acc 21.672078
median: at min-sum n_at 10 | e 60 fed_model val loss 2.1811 val acc 18.6485 best val_acc 22.321429 te_acc 21.672078
median: at min-sum n_at 10 | e 70 fed_model val loss 2.2008 val acc 16.2946 best val_acc 22.321429 te_acc 21.672078
median: at min-sum n_at 10 |

median: at min-sum n_at 10 | e 700 fed_model val loss 2.3026 val acc 9.9432 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 710 fed_model val loss 2.3026 val acc 10.3693 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 720 fed_model val loss 2.3026 val acc 9.9432 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 730 fed_model val loss 2.3026 val acc 10.3693 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 740 fed_model val loss 2.3026 val acc 9.9432 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 750 fed_model val loss 2.3026 val acc 10.3693 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 760 fed_model val loss 2.3026 val acc 9.9432 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 770 fed_model val loss 2.3026 val acc 10.3693 best val_acc 28.733766 te_acc 27.658279
median: at min-sum n_at 10 | e 780 fed_model val loss 2.3026 val acc