# The notebook contains
### Code for _Trimmed-mean_ aggregation algorithm
### Evaluation of all of the attacks (Fang, LIE, and our SOTA AGR-tailored and AGR-agnstic) on Trimmed-mean

In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [3]:
import json
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'/home/vshejwalkar/fed-quant-robustness/code/utils/')
from logger import *
from eval import *
from misc import *

from femnist_normal_train import *
from femnist_util import *
from adam import Adam
from sgd import SGD
import torchvision.transforms as transforms
import torchvision.datasets as datasets

## Get the FEMNIST dataset; we use [LEAF framework](https://leaf.cmu.edu/)

In [6]:
user_tr_data = []
user_tr_labels = []

for i in range(34):
    f = '/mnt/nfs/work1/amir/vshejwalkar/leaf/data/femnist/data/train/all_data_%d_niid_0_keep_0_train_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_tr_data.append(obj['user_data'][user]['x'])
        user_tr_labels.append(obj['user_data'][user]['y'])

user_te_data = []
user_te_labels = []

for i in range(34):
    f = '/mnt/nfs/work1/amir/vshejwalkar/leaf/data/femnist/data/test/all_data_%d_niid_0_keep_0_test_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_te_data.append(obj['user_data'][user]['x'])
        user_te_labels.append(obj['user_data'][user]['y'])

In [7]:
user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(len(user_tr_data)):
    
    user_tr_data_tensor=torch.from_numpy(np.array(user_tr_data[i])).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(np.array(user_tr_labels[i])).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

user 0 tr len 333
user 1 tr len 310
user 2 tr len 375
user 3 tr len 344
user 4 tr len 319
user 5 tr len 218
user 6 tr len 119
user 7 tr len 275
user 8 tr len 270
user 9 tr len 377
user 10 tr len 335
user 11 tr len 346
user 12 tr len 327
user 13 tr len 358
user 14 tr len 337
user 15 tr len 354
user 16 tr len 369
user 17 tr len 367
user 18 tr len 276
user 19 tr len 299
user 20 tr len 348
user 21 tr len 342
user 22 tr len 334
user 23 tr len 378
user 24 tr len 333
user 25 tr len 390
user 26 tr len 160
user 27 tr len 351
user 28 tr len 363
user 29 tr len 378
user 30 tr len 337
user 31 tr len 354
user 32 tr len 267
user 33 tr len 398
user 34 tr len 289
user 35 tr len 351
user 36 tr len 387
user 37 tr len 390
user 38 tr len 225
user 39 tr len 344
user 40 tr len 323
user 41 tr len 377
user 42 tr len 292
user 43 tr len 324
user 44 tr len 185
user 45 tr len 376
user 46 tr len 225
user 47 tr len 373
user 48 tr len 360
user 49 tr len 385
user 50 tr len 324
user 51 tr len 321
user 52 tr len 353
use

user 416 tr len 315
user 417 tr len 356
user 418 tr len 336
user 419 tr len 322
user 420 tr len 329
user 421 tr len 379
user 422 tr len 285
user 423 tr len 271
user 424 tr len 293
user 425 tr len 281
user 426 tr len 358
user 427 tr len 282
user 428 tr len 324
user 429 tr len 333
user 430 tr len 141
user 431 tr len 338
user 432 tr len 276
user 433 tr len 345
user 434 tr len 248
user 435 tr len 314
user 436 tr len 204
user 437 tr len 293
user 438 tr len 261
user 439 tr len 372
user 440 tr len 194
user 441 tr len 344
user 442 tr len 312
user 443 tr len 274
user 444 tr len 199
user 445 tr len 367
user 446 tr len 368
user 447 tr len 209
user 448 tr len 342
user 449 tr len 271
user 450 tr len 271
user 451 tr len 336
user 452 tr len 297
user 453 tr len 321
user 454 tr len 393
user 455 tr len 314
user 456 tr len 238
user 457 tr len 287
user 458 tr len 374
user 459 tr len 234
user 460 tr len 313
user 461 tr len 217
user 462 tr len 301
user 463 tr len 353
user 464 tr len 355
user 465 tr len 362


user 842 tr len 153
user 843 tr len 164
user 844 tr len 152
user 845 tr len 126
user 846 tr len 163
user 847 tr len 135
user 848 tr len 159
user 849 tr len 140
user 850 tr len 160
user 851 tr len 158
user 852 tr len 155
user 853 tr len 156
user 854 tr len 160
user 855 tr len 152
user 856 tr len 152
user 857 tr len 143
user 858 tr len 155
user 859 tr len 142
user 860 tr len 160
user 861 tr len 155
user 862 tr len 157
user 863 tr len 157
user 864 tr len 158
user 865 tr len 154
user 866 tr len 162
user 867 tr len 122
user 868 tr len 152
user 869 tr len 150
user 870 tr len 157
user 871 tr len 161
user 872 tr len 164
user 873 tr len 135
user 874 tr len 148
user 875 tr len 142
user 876 tr len 158
user 877 tr len 135
user 878 tr len 156
user 879 tr len 161
user 880 tr len 151
user 881 tr len 151
user 882 tr len 154
user 883 tr len 164
user 884 tr len 153
user 885 tr len 164
user 886 tr len 136
user 887 tr len 152
user 888 tr len 146
user 889 tr len 153
user 890 tr len 152
user 891 tr len 162


user 1263 tr len 133
user 1264 tr len 122
user 1265 tr len 147
user 1266 tr len 138
user 1267 tr len 154
user 1268 tr len 149
user 1269 tr len 90
user 1270 tr len 127
user 1271 tr len 111
user 1272 tr len 155
user 1273 tr len 161
user 1274 tr len 158
user 1275 tr len 88
user 1276 tr len 138
user 1277 tr len 151
user 1278 tr len 78
user 1279 tr len 147
user 1280 tr len 94
user 1281 tr len 135
user 1282 tr len 145
user 1283 tr len 143
user 1284 tr len 149
user 1285 tr len 37
user 1286 tr len 156
user 1287 tr len 151
user 1288 tr len 144
user 1289 tr len 97
user 1290 tr len 123
user 1291 tr len 144
user 1292 tr len 148
user 1293 tr len 161
user 1294 tr len 101
user 1295 tr len 144
user 1296 tr len 153
user 1297 tr len 141
user 1298 tr len 153
user 1299 tr len 143
user 1300 tr len 149
user 1301 tr len 135
user 1302 tr len 117
user 1303 tr len 142
user 1304 tr len 146
user 1305 tr len 161
user 1306 tr len 155
user 1307 tr len 121
user 1308 tr len 138
user 1309 tr len 151
user 1310 tr len 13

user 1656 tr len 152
user 1657 tr len 146
user 1658 tr len 140
user 1659 tr len 151
user 1660 tr len 157
user 1661 tr len 154
user 1662 tr len 155
user 1663 tr len 158
user 1664 tr len 162
user 1665 tr len 160
user 1666 tr len 129
user 1667 tr len 148
user 1668 tr len 153
user 1669 tr len 162
user 1670 tr len 162
user 1671 tr len 145
user 1672 tr len 157
user 1673 tr len 142
user 1674 tr len 154
user 1675 tr len 141
user 1676 tr len 150
user 1677 tr len 153
user 1678 tr len 149
user 1679 tr len 145
user 1680 tr len 153
user 1681 tr len 155
user 1682 tr len 157
user 1683 tr len 156
user 1684 tr len 142
user 1685 tr len 161
user 1686 tr len 155
user 1687 tr len 134
user 1688 tr len 157
user 1689 tr len 164
user 1690 tr len 152
user 1691 tr len 148
user 1692 tr len 147
user 1693 tr len 160
user 1694 tr len 158
user 1695 tr len 153
user 1696 tr len 159
user 1697 tr len 157
user 1698 tr len 161
user 1699 tr len 161
user 1700 tr len 156
user 1701 tr len 153
user 1702 tr len 143
user 1703 tr 

user 2051 tr len 124
user 2052 tr len 141
user 2053 tr len 155
user 2054 tr len 132
user 2055 tr len 153
user 2056 tr len 153
user 2057 tr len 143
user 2058 tr len 163
user 2059 tr len 161
user 2060 tr len 128
user 2061 tr len 158
user 2062 tr len 126
user 2063 tr len 145
user 2064 tr len 124
user 2065 tr len 147
user 2066 tr len 135
user 2067 tr len 162
user 2068 tr len 126
user 2069 tr len 147
user 2070 tr len 162
user 2071 tr len 148
user 2072 tr len 147
user 2073 tr len 153
user 2074 tr len 153
user 2075 tr len 140
user 2076 tr len 150
user 2077 tr len 142
user 2078 tr len 162
user 2079 tr len 127
user 2080 tr len 160
user 2081 tr len 164
user 2082 tr len 151
user 2083 tr len 158
user 2084 tr len 152
user 2085 tr len 147
user 2086 tr len 134
user 2087 tr len 148
user 2088 tr len 131
user 2089 tr len 153
user 2090 tr len 152
user 2091 tr len 145
user 2092 tr len 158
user 2093 tr len 153
user 2094 tr len 161
user 2095 tr len 159
user 2096 tr len 158
user 2097 tr len 141
user 2098 tr 

user 2442 tr len 152
user 2443 tr len 162
user 2444 tr len 162
user 2445 tr len 163
user 2446 tr len 157
user 2447 tr len 155
user 2448 tr len 153
user 2449 tr len 163
user 2450 tr len 162
user 2451 tr len 162
user 2452 tr len 155
user 2453 tr len 157
user 2454 tr len 161
user 2455 tr len 162
user 2456 tr len 152
user 2457 tr len 138
user 2458 tr len 153
user 2459 tr len 158
user 2460 tr len 159
user 2461 tr len 152
user 2462 tr len 159
user 2463 tr len 138
user 2464 tr len 163
user 2465 tr len 158
user 2466 tr len 127
user 2467 tr len 148
user 2468 tr len 146
user 2469 tr len 160
user 2470 tr len 162
user 2471 tr len 148
user 2472 tr len 153
user 2473 tr len 162
user 2474 tr len 164
user 2475 tr len 164
user 2476 tr len 161
user 2477 tr len 163
user 2478 tr len 140
user 2479 tr len 166
user 2480 tr len 164
user 2481 tr len 153
user 2482 tr len 144
user 2483 tr len 164
user 2484 tr len 143
user 2485 tr len 136
user 2486 tr len 165
user 2487 tr len 165
user 2488 tr len 137
user 2489 tr 

user 2844 tr len 356
user 2845 tr len 372
user 2846 tr len 243
user 2847 tr len 301
user 2848 tr len 332
user 2849 tr len 252
user 2850 tr len 327
user 2851 tr len 296
user 2852 tr len 305
user 2853 tr len 221
user 2854 tr len 198
user 2855 tr len 232
user 2856 tr len 255
user 2857 tr len 341
user 2858 tr len 358
user 2859 tr len 288
user 2860 tr len 275
user 2861 tr len 325
user 2862 tr len 356
user 2863 tr len 331
user 2864 tr len 216
user 2865 tr len 362
user 2866 tr len 338
user 2867 tr len 352
user 2868 tr len 319
user 2869 tr len 335
user 2870 tr len 270
user 2871 tr len 298
user 2872 tr len 297
user 2873 tr len 260
user 2874 tr len 312
user 2875 tr len 360
user 2876 tr len 215
user 2877 tr len 284
user 2878 tr len 254
user 2879 tr len 277
user 2880 tr len 354
user 2881 tr len 289
user 2882 tr len 285
user 2883 tr len 333
user 2884 tr len 352
user 2885 tr len 231
user 2886 tr len 226
user 2887 tr len 298
user 2888 tr len 375
user 2889 tr len 353
user 2890 tr len 337
user 2891 tr 

user 3244 tr len 270
user 3245 tr len 355
user 3246 tr len 168
user 3247 tr len 164
user 3248 tr len 236
user 3249 tr len 97
user 3250 tr len 210
user 3251 tr len 164
user 3252 tr len 328
user 3253 tr len 288
user 3254 tr len 243
user 3255 tr len 310
user 3256 tr len 90
user 3257 tr len 317
user 3258 tr len 309
user 3259 tr len 240
user 3260 tr len 284
user 3261 tr len 213
user 3262 tr len 152
user 3263 tr len 192
user 3264 tr len 162
user 3265 tr len 235
user 3266 tr len 200
user 3267 tr len 312
user 3268 tr len 242
user 3269 tr len 273
user 3270 tr len 192
user 3271 tr len 230
user 3272 tr len 260
user 3273 tr len 221
user 3274 tr len 270
user 3275 tr len 186
user 3276 tr len 297
user 3277 tr len 134
user 3278 tr len 198
user 3279 tr len 199
user 3280 tr len 243
user 3281 tr len 342
user 3282 tr len 153
user 3283 tr len 184
user 3284 tr len 209
user 3285 tr len 252
user 3286 tr len 283
user 3287 tr len 233
user 3288 tr len 235
user 3289 tr len 162
user 3290 tr len 297
user 3291 tr le

In [8]:
te_data = np.concatenate(user_te_data, 0)
te_labels = np.concatenate(user_te_labels)
te_len = len(te_labels)

te_data_tensor = torch.from_numpy(te_data[:(te_len//2)]).type(torch.FloatTensor)
te_label_tensor = torch.from_numpy(te_labels[:(te_len//2)]).type(torch.LongTensor)

val_data_tensor = torch.from_numpy(te_data[(te_len//2):]).type(torch.FloatTensor)
val_label_tensor = torch.from_numpy(te_labels[(te_len//2):]).type(torch.LongTensor)

## Model architecture for FEMNIST

In [9]:
class mnist_conv(nn.Module):
    def __init__(self):
        super(mnist_conv, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 62)

    def forward(self, x, noise=torch.Tensor()):
        x = x.reshape(-1, 1, 28, 28)

        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 32 * 7 * 7)  # reshape Variable
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(0)
        m.bias.data.fill_(0)

## Code for Trimmed-mean aggregation algorithm

## Code for (Full-knowledge) Fang attack on Trimmed-mean

In [10]:
def get_malicious_updates_fang_trmean(all_updates, deviation, n_attackers, epoch_num):
    b = 2
    max_vector = torch.max(all_updates, 0)[0]
    min_vector = torch.min(all_updates, 0)[0]

    max_ = (max_vector > 0).type(torch.FloatTensor).cuda()
    min_ = (min_vector < 0).type(torch.FloatTensor).cuda()

    max_[max_ == 1] = b
    max_[max_ == 0] = 1 / b
    min_[min_ == 1] = b
    min_[min_ == 0] = 1 / b

    max_range = torch.cat((max_vector[:, None], (max_vector * max_)[:, None]), dim=1)
    min_range = torch.cat(((min_vector * min_)[:, None], min_vector[:, None]), dim=1)

    rand = torch.from_numpy(np.random.uniform(0, 1, [len(deviation), n_attackers])).type(torch.FloatTensor).cuda()

    max_rand = torch.stack([max_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T
    min_rand = torch.stack([min_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T

    mal_vec = (torch.stack([(deviation > 0).type(torch.FloatTensor)] * max_rand.shape[1]).T.cuda() * max_rand + torch.stack(
        [(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]).T.cuda() * min_rand).T

    mal_updates = torch.cat((mal_vec, all_updates), 0)

    return mal_updates

## Evaluation for Full-knolwledge Fang attack on Trimmed-mean

In [12]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'trmean'
chkpt = './' + aggregation

at_type='fang'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        for i in round_users:
            if i < (34*at_fraction):
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if n_attacker:
            if at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
            elif at_type == 'our-agr':
                pass
        
        if epoch_num == 0: print('malicious grads shape ', malicious_grads.shape)
        
        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]
        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads, dim=0)
        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 10 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

malicious grads shape  torch.Size([60, 848382])


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729138878/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)


trmean: at fang n_at 12 e 0 | val loss 3.9561 val acc 4.6824 best val_acc 4.682352 te_acc 5.341330
trmean: at fang n_at 18 e 10 | val loss 3.7198 val acc 13.8514 best val_acc 14.518122 te_acc 16.858011
trmean: at fang n_at 11 e 20 | val loss 3.6493 val acc 20.4978 best val_acc 20.497838 te_acc 23.120881
trmean: at fang n_at 14 e 30 | val loss 3.5259 val acc 23.9678 best val_acc 27.998867 te_acc 31.144460
trmean: at fang n_at 6 e 40 | val loss 3.2350 val acc 32.8794 best val_acc 33.512665 te_acc 37.536038
trmean: at fang n_at 14 e 50 | val loss 2.9609 val acc 36.0636 best val_acc 36.063633 te_acc 40.619852
trmean: at fang n_at 17 e 60 | val loss 2.6053 val acc 38.1435 best val_acc 38.143534 te_acc 42.805292
trmean: at fang n_at 9 e 70 | val loss 2.3489 val acc 40.0638 best val_acc 40.112747 te_acc 44.818266
trmean: at fang n_at 13 e 80 | val loss 2.1440 val acc 43.6548 best val_acc 43.819502 te_acc 47.922673
trmean: at fang n_at 7 e 90 | val loss 1.9943 val acc 45.7913 best val_acc 46.6

trmean: at fang n_at 10 e 800 | val loss 0.8317 val acc 74.9073 best val_acc 75.803130 te_acc 76.094007
trmean: at fang n_at 5 e 810 | val loss 0.8378 val acc 72.7502 best val_acc 75.803130 te_acc 76.094007
trmean: at fang n_at 15 e 820 | val loss 0.9097 val acc 73.2393 best val_acc 76.065692 te_acc 76.902286
trmean: at fang n_at 20 e 830 | val loss 0.8039 val acc 74.6448 best val_acc 76.065692 te_acc 76.902286
trmean: at fang n_at 15 e 840 | val loss 0.7922 val acc 74.7503 best val_acc 76.356569 te_acc 76.758134
trmean: at fang n_at 12 e 850 | val loss 0.7635 val acc 76.5779 best val_acc 76.577945 te_acc 76.948620
trmean: at fang n_at 16 e 860 | val loss 0.8155 val acc 75.6976 best val_acc 76.577945 te_acc 76.948620
trmean: at fang n_at 5 e 870 | val loss 0.8129 val acc 75.2986 best val_acc 76.577945 te_acc 76.948620
trmean: at fang n_at 11 e 880 | val loss 0.7598 val acc 76.4853 best val_acc 76.577945 te_acc 76.948620
trmean: at fang n_at 15 e 890 | val loss 0.8009 val acc 76.3128 be

## Code for our SOTA AGR-tailored attack on Trimmed-mean 

In [13]:
def our_attack_trmean(all_updates, n_attackers, dev_type='sign', threshold=5.0, threshold_diff=1e-5):
    
    model_re = torch.mean(all_updates, 0)
    
    if dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([threshold]).cuda()  # compute_lambda_our(all_updates, model_re, n_attackers)

    threshold_diff = threshold_diff
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = tr_mean(mal_updates, n_attackers)

        loss = torch.norm(agg_grads - model_re)

        if prev_loss < loss:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

## Evaluation of our AGR-tailored attack on Trimmed-mean 

In [14]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'trmean'
chkpt = './' + aggregation

at_type='our-agr'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        for i in round_users:
            if i < (34*at_fraction):
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if n_attacker:
            if at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang_trmean(malicious_grads, deviation, n_attacker, epoch_num)
            elif at_type == 'our-agr':
                mal_update = our_attack_trmean(malicious_grads, n_attacker, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
                malicious_grads = torch.cat((torch.stack([mal_update]*n_attacker), malicious_grads))
                
        if malicious_grads.shape[0] != 60 or not epoch_num: print('malicious grads shape ', malicious_grads.shape)
        
        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]
        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads, dim=0)
        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 10 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

malicious grads shape  torch.Size([60, 848382])
trmean: at our-agr n_at 13 e 0 | val loss 4.0018 val acc 4.3735 best val_acc 4.373456 te_acc 5.094213
trmean: at our-agr n_at 8 e 10 | val loss 3.7800 val acc 9.5912 best val_acc 9.591227 te_acc 10.798497
trmean: at our-agr n_at 12 e 20 | val loss 3.8805 val acc 10.8963 best val_acc 10.896314 te_acc 11.874485
trmean: at our-agr n_at 12 e 30 | val loss 3.8749 val acc 14.9454 best val_acc 15.390754 te_acc 16.690692
trmean: at our-agr n_at 12 e 40 | val loss 3.8133 val acc 25.6101 best val_acc 25.610070 te_acc 28.724773
trmean: at our-agr n_at 12 e 50 | val loss 3.7908 val acc 22.1710 best val_acc 27.880457 te_acc 31.250000
trmean: at our-agr n_at 6 e 60 | val loss 3.5095 val acc 29.9501 best val_acc 29.950062 te_acc 33.692854
trmean: at our-agr n_at 8 e 70 | val loss 3.3973 val acc 27.8470 best val_acc 29.950062 te_acc 33.692854
trmean: at our-agr n_at 13 e 80 | val loss 3.2863 val acc 28.3953 best val_acc 30.261532 te_acc 33.515239
trmean:

trmean: at our-agr n_at 13 e 770 | val loss 1.9978 val acc 51.0400 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 12 e 780 | val loss 2.3891 val acc 46.4451 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 13 e 790 | val loss 1.9892 val acc 50.8237 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 11 e 800 | val loss 1.8984 val acc 52.4197 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 12 e 810 | val loss 1.9755 val acc 50.1673 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 11 e 820 | val loss 1.9303 val acc 51.4003 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 9 e 830 | val loss 2.0290 val acc 50.1956 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 12 e 840 | val loss 1.9323 val acc 51.9383 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 12 e 850 | val loss 2.0134 val acc 50.3475 best val_acc 53.441619 te_acc 55.863880
trmean: at our-agr n_at 14 e 860 | val

## Code for our first SOTA AGR-agnostic attack - Min-max

In [15]:
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    max_distance = torch.max(distances)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        max_d = torch.max(distance)
        
        if max_d <= max_distance:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

## Evaluation of our first SOTA AGR-agnostic attack - Min-max - on Trimmed-mean

In [16]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [10000]
gamma=.5

aggregation='trmean'
candidates = []

at_type='min-max'
dev_type='sign'
at_fractions=[20]

chkpt='./'+aggregation


for at_fraction in at_fractions:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    print('\n====> dev type %s batch size %d lr %f'% (dev_type, batch_size, fed_lr))    

    torch.cuda.empty_cache()

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    while epoch_num <= nepochs:
        user_grads = []
        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        for i in round_users:
            if i < (34*at_fraction):
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'our_attack_trmean':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_trmean(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'our_attack_median':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('SB %s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1


====> dev type sign batch size 100 lr 0.001000
malicious_grads shape  torch.Size([60, 848382])
SB trmean: at min-max n_at 15 e 0 fed_model val loss 4.0061 val acc 6.8678 best val_acc 6.867792 te_acc 7.593699
SB trmean: at min-max n_at 7 e 10 fed_model val loss 3.7577 val acc 5.7352 best val_acc 7.681219 te_acc 8.345346
SB trmean: at min-max n_at 11 e 20 fed_model val loss 3.8456 val acc 19.6587 best val_acc 19.658670 te_acc 21.195428
SB trmean: at min-max n_at 14 e 30 fed_model val loss 3.8322 val acc 18.1219 best val_acc 19.658670 te_acc 21.195428
SB trmean: at min-max n_at 9 e 40 fed_model val loss 3.7717 val acc 27.3347 best val_acc 27.556116 te_acc 30.843287
SB trmean: at min-max n_at 17 e 50 fed_model val loss 3.6193 val acc 26.7272 best val_acc 27.823826 te_acc 31.306631
SB trmean: at min-max n_at 8 e 60 fed_model val loss 3.5221 val acc 30.5524 best val_acc 30.552409 te_acc 34.130457
SB trmean: at min-max n_at 14 e 70 fed_model val loss 3.2932 val acc 32.9644 best val_acc 32.96

SB trmean: at min-max n_at 10 e 690 fed_model val loss 2.2450 val acc 48.4993 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 14 e 700 fed_model val loss 2.3541 val acc 45.1581 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 18 e 710 fed_model val loss 2.3150 val acc 45.8016 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 14 e 720 fed_model val loss 2.2835 val acc 46.4117 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 12 e 730 fed_model val loss 2.2309 val acc 48.0540 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 13 e 740 fed_model val loss 2.1780 val acc 47.0089 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 10 e 750 fed_model val loss 2.2169 val acc 49.7889 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 15 e 760 fed_model val loss 2.2376 val acc 47.8969 best val_acc 51.814765 te_acc 54.659185
SB trmean: at min-max n_at 14 e 770 fed_model val loss 2

SB trmean: at min-max n_at 13 e 1390 fed_model val loss 2.2507 val acc 50.3012 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 17 e 1400 fed_model val loss 2.3718 val acc 49.3153 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 19 e 1410 fed_model val loss 2.1332 val acc 52.1829 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 18 e 1420 fed_model val loss 2.4167 val acc 49.4131 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 13 e 1430 fed_model val loss 2.2025 val acc 53.0684 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 9 e 1440 fed_model val loss 2.2792 val acc 50.8778 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 7 e 1450 fed_model val loss 2.2513 val acc 52.0876 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 14 e 1460 fed_model val loss 2.3289 val acc 49.6190 best val_acc 54.468699 te_acc 57.174114
SB trmean: at min-max n_at 10 e 1470 fed_model val

## Code for our second SOTA AGR-agnostic attack - Min-sum

In [17]:
def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    
    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    scores = torch.sum(distances, dim=1)
    min_score = torch.min(scores)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        score = torch.sum(distance)
        
        if score <= min_score:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    # print(lamda_succ)
    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update
    

## Evaluation of our second SOTA AGR-agnostic attack - Min-sum - on Trimmed-mean

In [18]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [10000]
gamma=.5

aggregation='trmean'
candidates = []

at_type='min-sum'
dev_type='sign'
at_fractions=[20]

chkpt='./'+aggregation


for at_fraction in at_fractions:
    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    print('\n====> dev type %s batch size %d lr %f'% (dev_type, batch_size, fed_lr))    

    torch.cuda.empty_cache()

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    while epoch_num <= nepochs:
        user_grads = []
        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        for i in round_users:
            if i < (34*at_fraction):
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            if at_type == 'lie':
                malicious_grads = get_malicious_updates_lie(malicious_grads, n_attacker, z, epoch_num)
            elif at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                malicious_grads = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'our_attack_trmean':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_trmean(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'our_attack_median':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_median(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        if aggregation=='median':
            agg_grads=torch.median(malicious_grads,dim=0)[0]

        elif aggregation=='average':
            agg_grads=torch.mean(malicious_grads,dim=0)

        elif aggregation=='trmean':
            agg_grads=tr_mean(malicious_grads, n_attacker)

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num%10==0 or epoch_num==nepochs-1:
            print('SB %s: at %s n_at %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1


====> dev type sign batch size 100 lr 0.001000
malicious_grads shape  torch.Size([60, 848382])
SB trmean: at min-sum n_at 9 e 0 fed_model val loss 3.9819 val acc 5.1817 best val_acc 5.181734 te_acc 5.830416
SB trmean: at min-sum n_at 7 e 10 fed_model val loss 3.7489 val acc 4.8986 best val_acc 8.504942 te_acc 9.336388
SB trmean: at min-sum n_at 11 e 20 fed_model val loss 3.7081 val acc 16.5929 best val_acc 19.187603 te_acc 21.403933
SB trmean: at min-sum n_at 10 e 30 fed_model val loss 3.6506 val acc 29.3632 best val_acc 29.363159 te_acc 33.229510
SB trmean: at min-sum n_at 17 e 40 fed_model val loss 3.5488 val acc 31.6490 best val_acc 32.120058 te_acc 36.184617
SB trmean: at min-sum n_at 8 e 50 fed_model val loss 3.1874 val acc 34.4316 best val_acc 34.431631 te_acc 39.034185
SB trmean: at min-sum n_at 10 e 60 fed_model val loss 2.9517 val acc 35.2064 best val_acc 35.278521 te_acc 39.814147
SB trmean: at min-sum n_at 11 e 70 fed_model val loss 2.6313 val acc 35.3480 best val_acc 36.23

SB trmean: at min-sum n_at 12 e 690 fed_model val loss 1.8571 val acc 54.2113 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 14 e 700 fed_model val loss 1.8501 val acc 53.7273 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 14 e 710 fed_model val loss 1.8066 val acc 54.2113 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 13 e 720 fed_model val loss 1.7856 val acc 55.9797 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 10 e 730 fed_model val loss 1.8337 val acc 54.0259 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 13 e 740 fed_model val loss 1.8413 val acc 54.1598 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 11 e 750 fed_model val loss 1.7774 val acc 55.4829 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 13 e 760 fed_model val loss 1.8674 val acc 54.3194 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 13 e 770 fed_model val loss 1

SB trmean: at min-sum n_at 8 e 1390 fed_model val loss 1.9212 val acc 54.0594 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 17 e 1400 fed_model val loss 1.8338 val acc 55.9463 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 14 e 1410 fed_model val loss 1.8456 val acc 55.0685 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 16 e 1420 fed_model val loss 1.9152 val acc 54.0285 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 13 e 1430 fed_model val loss 1.9718 val acc 53.2897 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 12 e 1440 fed_model val loss 1.9216 val acc 54.7081 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 12 e 1450 fed_model val loss 2.1587 val acc 51.0297 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 15 e 1460 fed_model val loss 2.4984 val acc 48.0050 best val_acc 57.622014 te_acc 59.822900
SB trmean: at min-sum n_at 13 e 1470 fed_model va