# The notebook contains
### Code for _Trimmed-mean_ aggregation algorithm
### Evaluation of all of the attacks (Fang, LIE, and our SOTA AGR-tailored and AGR-agnstic) on Trimmed-mean

In [3]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [1]:
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'./../utils/')
from logger import *
from eval import *
from misc import *

from cifar10_normal_train import *
from cifar10_util import *
from adam import Adam
from sgd import SGD

## Get CIFAR10 data and split it in IID fashion

In [4]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
data_loc='/mnt/nfs/work1/amir/vshejwalkar/cifar10_data/'
# load the train dataset

train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=train_transform)

cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=train_transform)

X=[]
Y=[]
for i in range(len(cifar10_train)):
    X.append(cifar10_train[i][0].numpy())
    Y.append(cifar10_train[i][1])

for i in range(len(cifar10_test)):
    X.append(cifar10_test[i][0].numpy())
    Y.append(cifar10_test[i][1])

X=np.array(X)
Y=np.array(Y)

print('total data len: ',len(X))

if not os.path.isfile('./cifar10_shuffle.pkl'):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    pickle.dump(all_indices,open('./cifar10_shuffle.pkl','wb'))
else:
    all_indices=pickle.load(open('./cifar10_shuffle.pkl','rb'))

X=X[all_indices]
Y=Y[all_indices]

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# data loading

nusers=50
user_tr_len=1000

total_tr_len=user_tr_len*nusers
val_len=5000
te_len=5000

print('total data len: ',len(X))

if not os.path.isfile('./cifar10_shuffle.pkl'):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    pickle.dump(all_indices,open('./cifar10_shuffle.pkl','wb'))
else:
    all_indices=pickle.load(open('./cifar10_shuffle.pkl','rb'))

total_tr_data=X[:total_tr_len]
total_tr_label=Y[:total_tr_len]

val_data=X[total_tr_len:(total_tr_len+val_len)]
val_label=Y[total_tr_len:(total_tr_len+val_len)]

te_data=X[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]
te_label=Y[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]

total_tr_data_tensor=torch.from_numpy(total_tr_data).type(torch.FloatTensor)
total_tr_label_tensor=torch.from_numpy(total_tr_label).type(torch.LongTensor)

val_data_tensor=torch.from_numpy(val_data).type(torch.FloatTensor)
val_label_tensor=torch.from_numpy(val_label).type(torch.LongTensor)

te_data_tensor=torch.from_numpy(te_data).type(torch.FloatTensor)
te_label_tensor=torch.from_numpy(te_label).type(torch.LongTensor)

print('total tr len %d | val len %d | test len %d'%(len(total_tr_data_tensor),len(val_data_tensor),len(te_data_tensor)))

#==============================================================================================================

user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(nusers):
    
    user_tr_data_tensor=torch.from_numpy(total_tr_data[user_tr_len*i:user_tr_len*(i+1)]).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(total_tr_label[user_tr_len*i:user_tr_len*(i+1)]).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

total data len:  60000
total tr len 50000 | val len 5000 | test len 5000
user 0 tr len 1000
user 1 tr len 1000
user 2 tr len 1000
user 3 tr len 1000
user 4 tr len 1000
user 5 tr len 1000
user 6 tr len 1000
user 7 tr len 1000
user 8 tr len 1000
user 9 tr len 1000
user 10 tr len 1000
user 11 tr len 1000
user 12 tr len 1000
user 13 tr len 1000
user 14 tr len 1000
user 15 tr len 1000
user 16 tr len 1000
user 17 tr len 1000
user 18 tr len 1000
user 19 tr len 1000
user 20 tr len 1000
user 21 tr len 1000
user 22 tr len 1000
user 23 tr len 1000
user 24 tr len 1000
user 25 tr len 1000
user 26 tr len 1000
user 27 tr len 1000
user 28 tr len 1000
user 29 tr len 1000
user 30 tr len 1000
user 31 tr len 1000
user 32 tr len 1000
user 33 tr len 1000
user 34 tr len 1000
user 35 tr len 1000
user 36 tr len 1000
user 37 tr len 1000
user 38 tr len 1000
user 39 tr len 1000
user 40 tr len 1000
user 41 tr len 1000
user 42 tr len 1000
user 43 tr len 1000
user 44 tr len 1000
user 45 tr len 1000
user 46 tr len 10

## Code for Trimmed-mean aggregation algorithm

In [8]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.mean(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

## Full knowledge Fang attack on Trimmed mean aggregation
### Note that Fang attacks on Trimmed-mean and median are the same

In [7]:
def get_malicious_updates_fang_trmean(all_updates, deviation, n_attackers, epoch_num, compression='none', q_level=2, norm='inf'):
    b = 2
    max_vector = torch.max(all_updates, 0)[0]
    min_vector = torch.min(all_updates, 0)[0]

    max_ = (max_vector > 0).type(torch.FloatTensor).cuda()
    min_ = (min_vector < 0).type(torch.FloatTensor).cuda()

    max_[max_ == 1] = b
    max_[max_ == 0] = 1 / b
    min_[min_ == 1] = b
    min_[min_ == 0] = 1 / b

    max_range = torch.cat((max_vector[:, None], (max_vector * max_)[:, None]), dim=1)
    min_range = torch.cat(((min_vector * min_)[:, None], min_vector[:, None]), dim=1)

    rand = torch.from_numpy(np.random.uniform(0, 1, [len(deviation), n_attackers])).type(torch.FloatTensor).cuda()

    max_rand = torch.stack([max_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T
    min_rand = torch.stack([min_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T

    mal_vec = (torch.stack([(deviation > 0).type(torch.FloatTensor)] * max_rand.shape[1]).T.cuda() * max_rand + torch.stack(
        [(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]).T.cuda() * min_rand).T

    quant_mal_vec = []
    if compression != 'none':
        if epoch_num == 0: print('compressing malicious update')
        for i in range(mal_vec.shape[0]):
            mal_ = mal_vec[i]
            mal_quant = qsgd(mal_, s=q_level, norm=norm)
            quant_mal_vec = mal_quant[None, :] if not len(quant_mal_vec) else torch.cat((quant_mal_vec, mal_quant[None, :]), 0)
    else:
        quant_mal_vec = mal_vec

    mal_updates = torch.cat((quant_mal_vec, all_updates), 0)

    return mal_updates

In [12]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='trmean'
multi_k = False
candidates = []

at_type='fang'
z_values=[0.0]
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    for z in z_values:
        fed_file='alexnet_checkpoint_%s_%s_%d_%.2f.pth.tar'%(aggregation,at_type,n_attacker,z)
        fed_best_file='alexnet_best_%s_%s_%d_%.2f.pth.tar'%(aggregation,at_type,n_attacker,z)

        if resume:
            fed_checkpoint = chkpt+'/'+fed_file
            assert os.path.isfile(fed_checkpoint), 'Error: no user checkpoint at %s'%(fed_checkpoint)
            checkpoint = torch.load(fed_checkpoint, map_location='cuda:%d'%torch.cuda.current_device())
            fed_model.load_state_dict(checkpoint['state_dict'])
            optimizer_fed.load_state_dict(checkpoint['optimizer'])
            resume = 0
            best_global_acc=checkpoint['best_acc']
            best_global_te_acc=checkpoint['best_te_acc']
            val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
            epoch_num += checkpoint['epoch']
            print('resuming from epoch %d | val acc %.4f | best acc %.3f | best te acc %.3f'%(epoch_num, val_acc, best_global_acc, best_global_te_acc))

        torch.cuda.empty_cache()
        r=np.arange(user_tr_len)

        fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
        optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

        while epoch_num <= nepochs:
            user_grads=[]
            if not epoch_num and epoch_num%nbatches == 0:
                np.random.shuffle(r)
                for i in range(nusers):
                    user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                    user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

            for i in range(n_attacker, nusers):

                inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
                targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                fed_model.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

            malicious_grads = user_grads

            if epoch_num in schedule:
                for param_group in optimizer_fed.param_groups:
                    param_group['lr'] *= gamma
                    print('New learnin rate ', param_group['lr'])

            if n_attacker > 0:
                if at_type == 'paf':
                    malicious_grads=get_malicious_predictions_poison_all_far_sign(malicious_grads,nusers,n_attacker)
                elif at_type == 'lie':
                    malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
                elif at_type == 'fang':
                    agg_grads = torch.mean(malicious_grads, 0)
                    deviation = torch.sign(agg_grads)
                    malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
                elif at_type == 'our':
                    agg_grads = torch.mean(malicious_grads, 0)
                    malicious_grads = our_attack_krum(malicious_grads, agg_grads, n_attacker, compression=compression, q_level=q_level, norm=norm)

            if not epoch_num : 
                print(malicious_grads.shape)
                
            if aggregation=='median':
                agg_grads=torch.median(malicious_grads,dim=0)[0]

            elif aggregation=='average':
                agg_grads=torch.mean(malicious_grads,dim=0)

            elif aggregation=='trmean':
                agg_grads=tr_mean(malicious_grads, n_attacker)

            elif aggregation=='krum' or aggregation=='mkrum':
                multi_k = True if aggregation == 'mkrum' else False
                if epoch_num == 0: print('multi krum is ', multi_k)
                agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)
                
            elif aggregation=='bulyan':
                agg_grads, krum_candidate=bulyan(malicious_grads, n_attacker)

            del user_grads

            start_idx=0

            optimizer_fed.zero_grad()

            model_grads=[]

            for i, param in enumerate(fed_model.parameters()):
                param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
                start_idx=start_idx+len(param.data.view(-1))
                param_=param_.cuda()
                model_grads.append(param_)

            optimizer_fed.step(model_grads)

            val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
            te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

            is_best = best_global_acc < val_acc

            best_global_acc = max(best_global_acc, val_acc)

            if is_best:
                best_global_te_acc = te_acc

            if epoch_num%10==0 or epoch_num==nepochs-1:
                print('%s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

            if val_loss > 10:
                print('val loss %f too high'%val_loss)
                break
            
            epoch_num+=1

torch.Size([50, 2472266])
trmean: at fang n_at 10 e 0 fed_model val loss 2.3023 val acc 9.8011 best val_acc 9.801136 te_acc 10.491071
trmean: at fang n_at 10 e 10 fed_model val loss 2.2965 val acc 10.4911 best val_acc 10.693994 te_acc 11.708604
trmean: at fang n_at 10 e 20 fed_model val loss 2.2621 val acc 18.3036 best val_acc 19.602273 te_acc 19.622565
trmean: at fang n_at 10 e 30 fed_model val loss 2.1581 val acc 22.0982 best val_acc 22.098214 te_acc 21.773539
trmean: at fang n_at 10 e 40 fed_model val loss 2.1621 val acc 20.6778 best val_acc 22.098214 te_acc 21.773539
trmean: at fang n_at 10 e 50 fed_model val loss 2.1552 val acc 19.4805 best val_acc 23.011364 te_acc 22.199675
trmean: at fang n_at 10 e 60 fed_model val loss 2.0702 val acc 20.8401 best val_acc 23.620130 te_acc 23.823052
trmean: at fang n_at 10 e 70 fed_model val loss 2.0510 val acc 21.8344 best val_acc 23.620130 te_acc 23.823052
trmean: at fang n_at 10 e 80 fed_model val loss 1.9517 val acc 27.2524 best val_acc 27.25

trmean: at fang n_at 10 e 740 fed_model val loss 1.4170 val acc 49.4927 best val_acc 50.771104 te_acc 51.237825
trmean: at fang n_at 10 e 750 fed_model val loss 1.3967 val acc 49.5739 best val_acc 50.771104 te_acc 51.237825
trmean: at fang n_at 10 e 760 fed_model val loss 1.4202 val acc 48.9042 best val_acc 50.933442 te_acc 50.669643
trmean: at fang n_at 10 e 770 fed_model val loss 1.3929 val acc 50.1218 best val_acc 50.933442 te_acc 50.669643
trmean: at fang n_at 10 e 780 fed_model val loss 1.3900 val acc 50.3044 best val_acc 50.933442 te_acc 50.669643
trmean: at fang n_at 10 e 790 fed_model val loss 1.4255 val acc 48.2955 best val_acc 50.933442 te_acc 50.669643
trmean: at fang n_at 10 e 800 fed_model val loss 1.4077 val acc 49.9391 best val_acc 50.933442 te_acc 50.669643
trmean: at fang n_at 10 e 810 fed_model val loss 1.4020 val acc 49.1883 best val_acc 50.933442 te_acc 50.669643
trmean: at fang n_at 10 e 820 fed_model val loss 1.4265 val acc 49.6347 best val_acc 50.933442 te_acc 50

## Code for LIE attack 

In [9]:
def lie_attack(all_updates, z):
    avg = torch.mean(all_updates, dim=0)
    std = torch.std(all_updates, dim=0)
    return avg + z * std

In [14]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='trmean'
multi_k = False
candidates = []

at_type='LIE'
z_values={3:0.69847, 5:0.7054, 8:0.71904, 10:0.72575, 12:0.73891}
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z_values[n_attacker])
                malicious_grads = torch.cat((torch.stack([mal_update]*n_attacker), malicious_grads))
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                malicious_grads = our_attack_krum(malicious_grads, agg_grads, n_attacker, compression=compression, q_level=q_level, norm=norm)

        if not epoch_num : 
            print(malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads, krum_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 10:
            print('val loss %f too high'%val_loss)
            break

        epoch_num+=1

torch.Size([40, 2472266])
trmean: at LIE n_at 10 e 0 fed_model val loss 2.3031 val acc 9.6794 best val_acc 9.679383 te_acc 10.064935
trmean: at LIE n_at 10 e 10 fed_model val loss 2.2939 val acc 12.7029 best val_acc 13.879870 te_acc 14.103084
trmean: at LIE n_at 10 e 20 fed_model val loss 2.2019 val acc 19.9269 best val_acc 22.321429 te_acc 22.828734
trmean: at LIE n_at 10 e 30 fed_model val loss 2.2290 val acc 15.0365 best val_acc 22.321429 te_acc 22.828734
trmean: at LIE n_at 10 e 40 fed_model val loss 2.1578 val acc 20.0284 best val_acc 22.706981 te_acc 22.646104
trmean: at LIE n_at 10 e 50 fed_model val loss 2.1634 val acc 18.8312 best val_acc 24.249188 te_acc 23.457792
trmean: at LIE n_at 10 e 60 fed_model val loss 2.0806 val acc 19.2573 best val_acc 24.249188 te_acc 23.457792
trmean: at LIE n_at 10 e 70 fed_model val loss 2.2866 val acc 13.5552 best val_acc 24.249188 te_acc 23.457792
trmean: at LIE n_at 10 e 80 fed_model val loss 2.2637 val acc 14.8539 best val_acc 24.249188 te_a

trmean: at LIE n_at 10 e 740 fed_model val loss 1.4977 val acc 48.9651 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 750 fed_model val loss 1.2149 val acc 56.7370 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 760 fed_model val loss 1.2588 val acc 55.3774 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 770 fed_model val loss 1.1938 val acc 57.2646 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 780 fed_model val loss 1.3489 val acc 52.6989 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 790 fed_model val loss 1.2654 val acc 57.1834 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 800 fed_model val loss 1.2386 val acc 57.0617 best val_acc 59.638799 te_acc 59.780844
trmean: at LIE n_at 10 e 810 fed_model val loss 1.2901 val acc 55.4180 best val_acc 59.882305 te_acc 60.186688
trmean: at LIE n_at 10 e 820 fed_model val loss 1.4314 val acc 51.3190 best val_acc 59.882305 te_acc 60.186688
t

## Code for our AGR-tailored attack on Trimmed-mean

In [10]:
def our_attack_trmean(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([10.0]).cuda() #compute_lambda_our(all_updates, model_re, n_attackers)
    # print(lamda)
    threshold_diff = 1e-5
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0
    iters = 0 
    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = tr_mean(mal_updates, n_attackers)
        
        loss = torch.norm(agg_grads - model_re)
        
        if prev_loss < loss:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss
        
    mal_update = (model_re - lamda_succ * deviation)
    mal_updates = torch.stack([mal_update] * n_attackers)
    mal_updates = torch.cat((mal_updates, all_updates), 0)

    return mal_updates

In [18]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='trmean'
multi_k = False
candidates = []

at_type='our-agr'
dev_type = 'std'
z_values={3:0.69847, 5:0.7054, 8:0.71904, 10:0.72575, 12:0.73891}
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z_values[n_attacker])
                malicious_grads = torch.cat((torch.stack([mal_update]*n_attacker), malicious_grads))
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                malicious_grads = our_attack_trmean(malicious_grads, agg_grads, n_attacker, dev_type=dev_type)

        if not epoch_num : 
            print(malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads, krum_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 1000:
            print('val loss %f too high'%val_loss)
            break

        epoch_num+=1

torch.Size([50, 2472266])
trmean: at our-agr n_at 10 e 0 fed_model val loss 2.3025 val acc 9.6388 best val_acc 9.638799 te_acc 9.659091
trmean: at our-agr n_at 10 e 10 fed_model val loss 2.2920 val acc 12.6623 best val_acc 12.662338 te_acc 13.413149
trmean: at our-agr n_at 10 e 20 fed_model val loss 2.5542 val acc 12.9058 best val_acc 19.500812 te_acc 20.758929
trmean: at our-agr n_at 10 e 30 fed_model val loss 2.2986 val acc 12.2362 best val_acc 19.500812 te_acc 20.758929
trmean: at our-agr n_at 10 e 40 fed_model val loss 2.2566 val acc 17.8369 best val_acc 21.509740 te_acc 21.022727
trmean: at our-agr n_at 10 e 50 fed_model val loss 2.2911 val acc 11.7695 best val_acc 21.509740 te_acc 21.022727
trmean: at our-agr n_at 10 e 60 fed_model val loss 2.2098 val acc 16.8831 best val_acc 21.509740 te_acc 21.022727
trmean: at our-agr n_at 10 e 70 fed_model val loss 2.3257 val acc 11.1810 best val_acc 21.509740 te_acc 21.022727
trmean: at our-agr n_at 10 e 80 fed_model val loss 2.1418 val acc 

## Code for our first AGR-agnostic attack - Min-max

In [11]:
'''
MIN-MAX attack
'''
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([10.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    max_distance = torch.max(distances)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        max_d = torch.max(distance)
        
        if max_d <= max_distance:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

In [12]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='trmean'
multi_k = False
candidates = []

at_type='min-max'
dev_type ='std'
z=0
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    candidates = []

    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)
            
        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads,bulyan_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d | e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 1000:
            print('val loss %f too high'%val_loss)
            break
            
        epoch_num+=1

malicious_grads shape  torch.Size([50, 2472266])


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729138878/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  p.data.add_(-group['lr'], d_p)


trmean: at min-max n_at 10 | e 0 fed_model val loss 2.3025 val acc 9.9229 best val_acc 9.922890 te_acc 9.435877
trmean: at min-max n_at 10 | e 10 fed_model val loss 2.2891 val acc 10.8969 best val_acc 16.152597 te_acc 15.909091
trmean: at min-max n_at 10 | e 20 fed_model val loss 2.3295 val acc 9.9432 best val_acc 20.474838 te_acc 20.474838
trmean: at min-max n_at 10 | e 30 fed_model val loss 2.2610 val acc 10.1461 best val_acc 20.474838 te_acc 20.474838
trmean: at min-max n_at 10 | e 40 fed_model val loss 2.2000 val acc 18.0804 best val_acc 20.474838 te_acc 20.474838
trmean: at min-max n_at 10 | e 50 fed_model val loss 2.2149 val acc 17.0455 best val_acc 20.474838 te_acc 20.474838
trmean: at min-max n_at 10 | e 60 fed_model val loss 2.1762 val acc 19.9472 best val_acc 21.509740 te_acc 21.205357
trmean: at min-max n_at 10 | e 70 fed_model val loss 2.2078 val acc 16.0714 best val_acc 21.509740 te_acc 21.205357
trmean: at min-max n_at 10 | e 80 fed_model val loss 2.3448 val acc 9.3953 be

trmean: at min-max n_at 10 | e 710 fed_model val loss 1.9609 val acc 28.6526 best val_acc 30.641234 te_acc 28.774351
trmean: at min-max n_at 10 | e 720 fed_model val loss 1.9986 val acc 23.2955 best val_acc 30.641234 te_acc 28.774351
trmean: at min-max n_at 10 | e 730 fed_model val loss 2.1788 val acc 21.4083 best val_acc 30.641234 te_acc 28.774351
trmean: at min-max n_at 10 | e 740 fed_model val loss 2.1673 val acc 18.7906 best val_acc 30.641234 te_acc 28.774351
trmean: at min-max n_at 10 | e 750 fed_model val loss 2.1418 val acc 23.2752 best val_acc 32.021104 te_acc 31.554383
trmean: at min-max n_at 10 | e 760 fed_model val loss 2.5925 val acc 18.1006 best val_acc 32.609578 te_acc 32.974838
trmean: at min-max n_at 10 | e 770 fed_model val loss 1.8563 val acc 30.3977 best val_acc 32.609578 te_acc 32.974838
trmean: at min-max n_at 10 | e 780 fed_model val loss 2.1515 val acc 21.1039 best val_acc 32.609578 te_acc 32.974838
trmean: at min-max n_at 10 | e 790 fed_model val loss 1.9447 val

## Code for our second AGR-agnostic attack - Min-sum

In [13]:
'''
MIN-SUM attack
'''

def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    
    lamda = torch.Tensor([10.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    scores = torch.sum(distances, dim=1)
    min_score = torch.min(scores)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        score = torch.sum(distance)
        
        if score <= min_score:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    # print(lamda_succ)
    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update
    

In [14]:
batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='trmean'
multi_k = False
candidates = []

at_type='min-sum'
dev_type ='std'
z=0
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    candidates = []

    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    while epoch_num <= nepochs:
        user_grads=[]
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        for i in range(n_attacker, nusers):

            inputs = user_tr_data_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]
            targets = user_tr_label_tensors[i][(epoch_num%nbatches)*batch_size:((epoch_num%nbatches) + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'our-agr':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)
            
        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k)

        elif aggregation=='bulyan':
            agg_grads,bulyan_candidate=bulyan(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('%s: at %s n_at %d | e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 1000:
            print('val loss %f too high'%val_loss)
            break
            
        epoch_num+=1

malicious_grads shape  torch.Size([50, 2472266])
trmean: at min-sum n_at 10 | e 0 fed_model val loss 2.3028 val acc 10.0649 best val_acc 10.064935 te_acc 9.780844
trmean: at min-sum n_at 10 | e 10 fed_model val loss 2.2909 val acc 13.1494 best val_acc 14.671266 te_acc 15.056818
trmean: at min-sum n_at 10 | e 20 fed_model val loss 2.6275 val acc 11.5463 best val_acc 20.271916 te_acc 21.144481
trmean: at min-sum n_at 10 | e 30 fed_model val loss 2.2922 val acc 12.5609 best val_acc 20.271916 te_acc 21.144481
trmean: at min-sum n_at 10 | e 40 fed_model val loss 2.1474 val acc 24.2695 best val_acc 24.269481 te_acc 23.538961
trmean: at min-sum n_at 10 | e 50 fed_model val loss 2.2958 val acc 11.7695 best val_acc 24.269481 te_acc 23.538961
trmean: at min-sum n_at 10 | e 60 fed_model val loss 2.2730 val acc 11.3636 best val_acc 24.269481 te_acc 23.538961
trmean: at min-sum n_at 10 | e 70 fed_model val loss 2.1983 val acc 19.2573 best val_acc 24.269481 te_acc 23.538961
trmean: at min-sum n_at 1

trmean: at min-sum n_at 10 | e 700 fed_model val loss 2.4183 val acc 19.9878 best val_acc 38.372565 te_acc 38.636364
trmean: at min-sum n_at 10 | e 710 fed_model val loss 2.0139 val acc 27.5568 best val_acc 38.372565 te_acc 38.636364
trmean: at min-sum n_at 10 | e 720 fed_model val loss 1.8418 val acc 31.1688 best val_acc 38.372565 te_acc 38.636364
trmean: at min-sum n_at 10 | e 730 fed_model val loss 1.7134 val acc 37.7435 best val_acc 40.036526 te_acc 39.285714
trmean: at min-sum n_at 10 | e 740 fed_model val loss 2.1116 val acc 21.8953 best val_acc 40.036526 te_acc 39.285714
trmean: at min-sum n_at 10 | e 750 fed_model val loss 1.8158 val acc 30.3166 best val_acc 40.036526 te_acc 39.285714
trmean: at min-sum n_at 10 | e 760 fed_model val loss 1.8922 val acc 29.5657 best val_acc 40.036526 te_acc 39.285714
trmean: at min-sum n_at 10 | e 770 fed_model val loss 1.6282 val acc 40.9700 best val_acc 40.969968 te_acc 41.355519
trmean: at min-sum n_at 10 | e 780 fed_model val loss 1.5809 val