# AGR-agnostic attack

In [None]:
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math, errno
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp
import os

In [None]:
#MIN-MAX attack
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec', threshold=30):
    """
    MIN-MAX attack: craft a malicious update whose maximum squared distance to
    any benign update is capped at the maximum squared distance between any
    two benign updates.

    The malicious update is ``model_re - lamda * deviation``. ``lamda`` is found
    by a binary search that starts at 10 and stops once the current candidate
    is within 1e-5 of the last successful one. If no positive ``lamda``
    satisfies the cap, ``model_re`` itself is returned.

    Input:
    all_updates: 2-D tensor of benign updates, one flattened update per row
    model_re: reference update to perturb (the main loop passes the mean of all_updates)
    n_attackers: number of attackers (not used in the computation)
    dev_type: perturbation direction, one of 'unit_vec', 'sign', 'std'
    threshold: not used in the computation (kept for API compatibility)

    Output:
    mal_update: malicious update with the same shape as model_re

    Raises:
    ValueError: if dev_type is not one of the supported values
    """
    # Perturbation direction; it is subtracted from model_re below.
    if dev_type == 'unit_vec':
        # Unit vector along model_re: subtracting it moves against the benign direction.
        deviation = model_re / torch.norm(model_re)
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        # Per-coordinate standard deviation of the benign updates.
        deviation = torch.std(all_updates, 0)
    else:
        raise ValueError("Unsupported dev_type: {}".format(dev_type))

    # Create lamda on the same device as the updates instead of forcing CUDA.
    lamda = torch.tensor([10.0], dtype=torch.float32, device=model_re.device)
    threshold_diff = 1e-5  # stopping tolerance of the binary search on lamda
    lamda_fail = lamda
    lamda_succ = 0

    # Squared distances between every pair of benign updates (row i: distances from update i).
    distances = torch.stack([torch.norm(all_updates - update, dim=1) ** 2 for update in all_updates])
    max_distance = torch.max(distances)  # largest benign-to-benign squared distance
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = model_re - lamda * deviation
        max_d = torch.max(torch.norm(all_updates - mal_update, dim=1) ** 2)

        if max_d <= max_distance:
            # Cap satisfied: remember this lamda and try a larger one.
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    return model_re - lamda_succ * deviation
 

In [None]:
#MIN-SUM attack
def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec', threshold=30):
    """
    MIN-SUM attack: craft a malicious update whose sum of squared distances to
    all benign updates does not exceed the smallest such sum achieved by any
    single benign update.

    The malicious update is ``model_re - lamda * deviation``. ``lamda`` is found
    by a binary search that starts at 10 and stops once the current candidate
    is within 1e-5 of the last successful one. If no positive ``lamda``
    satisfies the bound, ``model_re`` itself is returned.

    Input:
    all_updates: 2-D tensor of benign updates, one flattened update per row
    model_re: reference update to perturb (the main loop passes the mean of all_updates)
    n_attackers: number of attackers (not used in the computation)
    dev_type: perturbation direction, one of 'unit_vec', 'sign', 'std'
    threshold: not used in the computation (kept for API compatibility)

    Output:
    mal_update: malicious update with the same shape as model_re

    Raises:
    ValueError: if dev_type is not one of the supported values
    """
    # Perturbation direction; it is subtracted from model_re below.
    if dev_type == 'unit_vec':
        # Unit vector along model_re: subtracting it moves against the benign direction.
        deviation = model_re / torch.norm(model_re)
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        # Per-coordinate standard deviation of the benign updates.
        deviation = torch.std(all_updates, 0)
    else:
        raise ValueError("Unsupported dev_type: {}".format(dev_type))

    # Create lamda on the same device as the updates instead of forcing CUDA.
    lamda = torch.tensor([10.0], dtype=torch.float32, device=model_re.device)
    threshold_diff = 1e-5  # stopping tolerance of the binary search on lamda
    lamda_fail = lamda
    lamda_succ = 0

    # Squared distances between every pair of benign updates (row i: distances from update i).
    distances = torch.stack([torch.norm(all_updates - update, dim=1) ** 2 for update in all_updates])
    # Score of a benign update = sum of its squared distances to all benign updates.
    min_score = torch.min(torch.sum(distances, dim=1))
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = model_re - lamda * deviation
        score = torch.sum(torch.norm(all_updates - mal_update, dim=1) ** 2)

        if score <= min_score:
            # Bound satisfied: remember this lamda and try a larger one.
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    return model_re - lamda_succ * deviation
 

# utils

In [None]:
#misc.py
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the per-channel mean and std of a dataset.
    - init_params: initialise Conv2d / BatchNorm2d / Linear layer parameters.
    - mkdir_p: create a directory (and its parents) unless it already exists.
    - AverageMeter: track the latest value and running average of a metric.
'''
import errno
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable

# Public names exported by `from <this module> import *`.
__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']


def get_mean_and_std(dataset, num_workers=2, num_channels=3):
    '''Compute the per-channel mean and std value of a dataset.

    Samples are loaded one at a time (batch_size=1). The per-sample channel
    mean and per-sample channel std are summed and divided by the dataset
    length. The returned std is therefore the average per-image std, not the
    std over all pixels of the dataset.

    Args:
        dataset: map-style dataset yielding ``(inputs, target)`` pairs; after
            batching, ``inputs`` is indexed as ``[batch, channel, H, W]``.
        num_workers: DataLoader worker processes (default 2).
        num_channels: number of leading channels to summarise (default 3).

    Returns:
        (mean, std): 1-D tensors of length ``num_channels``.
    '''
    # Order does not affect the sums, so skip shuffling for reproducible float results.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = torch.zeros(num_channels)
    std = torch.zeros(num_channels)

    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        for i in range(num_channels):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    # Normalise exactly once per statistic.
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Initialise the layer parameters of ``net`` in place.

    - Conv2d: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - BatchNorm2d: weight 1, bias 0 (only when the layer is affine).
    - Linear: normal weights with std 1e-3, zero bias.
    Other module types are left untouched.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Test for presence: truth-testing a multi-element tensor raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # Non-affine BatchNorm has weight/bias set to None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

def mkdir_p(path):
    '''Create ``path`` and any missing parents, like the shell's ``mkdir -p``.

    An already-existing directory is not an error. Any other failure,
    including ``path`` existing as a non-directory, propagates as OSError.
    '''
    os.makedirs(path, exist_ok=True)

class AverageMeter(object):
    """Keep the most recent value and the running weighted average of a metric.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
class SGD(Optimizer):
    r"""Stochastic gradient descent (optionally with momentum) on caller-supplied gradients.

    This follows :class:`torch.optim.SGD`, except that :meth:`step` takes the
    gradients as an explicit argument instead of reading ``p.grad``. That lets
    an externally computed (e.g. aggregated) update be applied to a model.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> grads = [torch.zeros_like(p) for p in model.parameters()]
        >>> optimizer.step(grads)

    .. note::
        Momentum is applied as

        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively. This differs from Sutskever et. al.,
        which use :math:`v = \rho * v + lr * g`, :math:`p = p - v`.
        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, grads, closure=None):
        """Performs a single optimization step with the given gradients.

        Arguments:
            grads (sequence of Tensor): one gradient per parameter, in the order
                the parameters appear across all of ``self.param_groups``.
                The tensors are not modified.
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Running index into `grads` across all parameter groups (not reset per group).
        grad_idx = 0
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                d_p = grads[grad_idx]
                grad_idx += 1

                if weight_decay != 0:
                    # Out-of-place so the caller's gradient tensor is left untouched.
                    d_p = d_p.add(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss

In [None]:
#eval.py
from __future__ import print_function, absolute_import

__all__ = ['accuracy']

def accuracy(output, target, topk=(1,)):
    """Return precision@k, in percent, for every k in ``topk``.

    ``output`` holds one row of class scores per sample and ``target`` the
    true class index of each sample. The result is a list with one 0-dim
    tensor per requested k, in the order of ``topk``.
    """
    n_samples = target.size(0)
    _, top_idx = output.topk(max(topk), 1, True, True)
    ranked = top_idx.t()  # row j holds every sample's j-th best class
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

# Models and train/test functions

In [None]:
# models
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import cifar as models

class cifar_mlp(nn.Module):
    """Fully connected classifier for flattened CIFAR-sized inputs.

    ``features`` maps ``ninputs`` -> 1024 -> 256 -> 64 with a ReLU after each
    linear layer. ``classifier`` maps the 64-d hidden vector to
    ``num_classes`` logits.
    """

    def __init__(self, ninputs=3 * 32 * 32, num_classes=10):
        super(cifar_mlp, self).__init__()
        self.ninputs = ninputs
        self.num_classes = num_classes

        widths = [ninputs, 1024, 256, 64]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        flat = x.view(-1, self.ninputs)
        return self.classifier(self.features(flat))


def get_model(config, parallel=False, cuda=True, device=0):
    """Instantiate the architecture named ``config['arch']`` from the ``models`` registry.

    The constructor arguments depend on the architecture family:
    - 'resnext*': cardinality, num_classes, depth, widen_factor, dropRate
    - 'densenet*': num_classes, depth, growthRate, compressionRate, dropRate
    - 'wrn*': num_classes, depth, widen_factor, dropRate
    - '*resnet' / '*convnet' / anything else: num_classes (plus depth for '*resnet')

    Args:
        config: dict with 'arch', 'num_classes' and the family-specific keys above.
        parallel: wrap the model in ``torch.nn.DataParallel``.
        cuda: move the model to the GPU with ``.cuda()``.
        device: not used by this function (``.cuda()`` is called without it).

    Returns:
        The constructed model, optionally wrapped and moved to the GPU.
    """
    arch = config['arch']
    build = models.__dict__[arch]

    if arch.startswith('resnext'):
        net = build(cardinality=config['cardinality'], num_classes=config['num_classes'],
                    depth=config['depth'], widen_factor=config['widen-factor'],
                    dropRate=config['drop'])
    elif arch.startswith('densenet'):
        net = build(num_classes=config['num_classes'], depth=config['depth'],
                    growthRate=config['growthRate'], compressionRate=config['compressionRate'],
                    dropRate=config['drop'])
    elif arch.startswith('wrn'):
        net = build(num_classes=config['num_classes'], depth=config['depth'],
                    widen_factor=config['widen-factor'], dropRate=config['drop'])
    elif arch.endswith('resnet'):
        net = build(num_classes=config['num_classes'], depth=config['depth'])
    else:
        # '*convnet' and every remaining architecture take only num_classes.
        net = build(num_classes=config['num_classes'])

    if parallel:
        net = torch.nn.DataParallel(net)
    if cuda:
        net.cuda()
    return net

def return_model(model_name, lr, momentum, parallel=False, cuda=True, device=0):
    """Build a named model together with its default optimizer.

    Args:
        model_name: one of 'dc', 'alexnet', 'densenet-bc-100-12',
            'densenet-bc-L190-k40', 'preresnet-110', 'resnet-110',
            'resnext-16x64d', 'resnext-8x64d', 'WRN-28-10-drop', or any name
            starting with 'vgg'.
        lr: learning rate. 'dc' and 'alexnet' ignore it and use 0.001.
        momentum: SGD momentum (unused by the Adam-based configurations).
        parallel, cuda, device: forwarded to ``get_model``.

    Returns:
        (model, optimizer)

    Raises:
        AssertionError: if ``model_name`` is unknown. It is raised explicitly
            (not via ``assert``) so the check also fires under ``python -O``.
    """
    if model_name == 'dc':
        arch_config = {
            'arch': 'Dc',
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=momentum)
    elif model_name == 'alexnet':
        arch_config = {
            'arch': 'alexnet',
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=momentum)
    elif model_name == 'densenet-bc-100-12':
        arch_config = {
            'arch': 'densenet',
            'depth': 100,
            'growthRate': 12,
            'compressionRate': 2,
            'drop': 0,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    elif model_name == 'densenet-bc-L190-k40':
        arch_config = {
            'arch': 'densenet',
            'depth': 190,
            'growthRate': 40,
            'compressionRate': 2,
            'drop': 0,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4)
    elif model_name == 'preresnet-110':
        arch_config = {
            'arch': 'preresnet',
            'depth': 110,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    elif model_name == 'resnet-110':
        arch_config = {
            'arch': 'resnet',
            'depth': 110,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    elif model_name == 'resnext-16x64d':
        arch_config = {
            'arch': 'resnext',
            'depth': 29,
            'cardinality': 16,
            'widen-factor': 4,
            'drop': 0,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)
    elif model_name == 'resnext-8x64d':
        arch_config = {
            'arch': 'resnext',
            'depth': 29,
            'cardinality': 8,
            'widen-factor': 4,
            'drop': 0,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)
    elif model_name.startswith('vgg'):
        arch_config = {
            'arch': model_name,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    elif model_name == 'WRN-28-10-drop':
        arch_config = {
            'arch': 'wrn',
            'depth': 28,
            'widen-factor': 10,
            'drop': 0.3,
            'num_classes': 10,
        }
        model = get_model(arch_config, parallel=parallel, cuda=cuda, device=device)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)
    else:
        # Explicit raise: a bare `assert` is stripped under -O, which would then
        # fail later with an UnboundLocalError on `model`.
        raise AssertionError('Model not found: {!r}'.format(model_name))

    return model, optimizer

In [None]:
#Alexnet
'''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted.
Without BN, the start learning rate should be 0.01
(c) YANG, Wei 
'''
import torch.nn as nn


__all__ = ['alexnet']


class AlexNet(nn.Module):
    """AlexNet-style CNN for 32x32 images with a single linear classifier.

    Five convolutions (each followed by an in-place ReLU) and three 2x2 max
    pools reduce a 3x32x32 input to a 256x1x1 feature map. The map is
    flattened and projected to ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()

        def _relu():
            return nn.ReLU(inplace=True)

        def _pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), _relu(), _pool(),
            nn.Conv2d(64, 192, kernel_size=5, padding=2), _relu(), _pool(),
            nn.Conv2d(192, 384, kernel_size=3, padding=1), _relu(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1), _relu(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), _relu(), _pool(),
        )
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)


def alexnet(**kwargs):
    r"""Construct an :class:`AlexNet`, forwarding all keyword arguments to it.

    Architecture adapted from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    """
    return AlexNet(**kwargs)

In [None]:
def train(train_data, labels, model, criterion, optimizer, use_cuda, num_batchs=999999, debug_='MEDIUM', batch_size=16):
    """Train ``model`` for one pass over ``train_data`` in sequential mini-batches.

    Only full batches are used (a trailing partial batch is skipped), and at
    most ``num_batchs`` batches are processed.

    Args:
        train_data, labels: tensors sliced along dim 0 into batches of ``batch_size``.
        model: network to train (switched to train mode).
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        use_cuda: move every batch to the GPU with ``.cuda()``.
        num_batchs: maximum number of batches to process.
        debug_: 'HIGH' prints progress every 100 batches.
        batch_size: samples per batch.

    Returns:
        (average loss, average top-1 accuracy) over the processed batches.
    """
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    # Number of full batches available.
    len_t = len(train_data) // batch_size

    for ind in range(min(len_t, num_batchs)):
        # measure data loading time
        inputs = train_data[ind * batch_size:(ind + 1) * batch_size]
        targets = labels[ind * batch_size:(ind + 1) * batch_size]

        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss (top-5 assumes at least 5 classes)
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        if debug_ == 'HIGH' and ind % 100 == 0:
            print('Classifier: ({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=ind + 1,
                size=len_t,
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            ))

    return (losses.avg, top1.avg)


def test(test_data, labels, model, criterion, use_cuda, debug_='MEDIUM', batch_size=64):
    """Evaluate ``model`` on ``test_data`` in sequential mini-batches without gradients.

    Only full batches are evaluated; a trailing partial batch is skipped.
    The model is left in eval mode.

    Args:
        test_data, labels: tensors sliced along dim 0 into batches of ``batch_size``.
        model: network to evaluate (switched to eval mode).
        criterion: loss callable ``criterion(outputs, targets)``.
        use_cuda: move every batch to the GPU with ``.cuda()``.
        debug_: 'HIGH' prints progress every 100 batches.
        batch_size: samples per batch.

    Returns:
        (average loss, average top-1 accuracy) over the evaluated batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # Number of full batches available.
    len_t = len(test_data) // batch_size

    with torch.no_grad():
        for ind in range(len_t):
            # measure data loading time
            inputs = test_data[ind * batch_size:(ind + 1) * batch_size]
            targets = labels[ind * batch_size:(ind + 1) * batch_size]

            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss (top-5 assumes at least 5 classes)
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            if debug_ == 'HIGH' and ind % 100 == 0:
                print('Test classifier: ({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    batch=ind + 1,
                    size=len_t,
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                ))

    return (losses.avg, top1.avg)


In [None]:
#download cifar10 train and test dataset 
import torchvision.transforms as transforms
import torchvision.datasets as datasets
data_loc='/mnt/nfs/work1/amir/vshejwalkar/cifar10_data/'
# load the train dataset

train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=train_transform)

# The test split uses the same ToTensor + Normalize transform as the train split.
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=train_transform)

# Pool the official train and test splits; they are re-split into user/val/test later.
X = []
Y = []
for split in (cifar10_train, cifar10_test):
    for i in range(len(split)):
        # Fetch each sample once and unpack image and label from the same call.
        image, label = split[i]
        X.append(image.numpy())
        Y.append(label)

X = np.array(X)
Y = np.array(Y)

print('total data len: ',len(X))

# Cache the shuffle permutation on disk so every run uses the same data order.
shuffle_path = './cifar10_shuffle.pkl'
if not os.path.isfile(shuffle_path):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    with open(shuffle_path, 'wb') as fh:
        pickle.dump(all_indices, fh)
else:
    # NOTE: pickle.load can execute arbitrary code; only load a file this script wrote.
    with open(shuffle_path, 'rb') as fh:
        all_indices = pickle.load(fh)
    if len(all_indices) != len(X):
        raise ValueError('%s holds %d indices but the dataset has %d samples'
                         % (shuffle_path, len(all_indices), len(X)))

X = X[all_indices]
Y = Y[all_indices]

In [None]:
# data loading

nusers=50
user_tr_len=1000

total_tr_len=user_tr_len*nusers
val_len=5000
te_len=5000

print('total data len: ',len(X))

# X and Y were already permuted with the cached indices from ./cifar10_shuffle.pkl
# by the data-download cell, so the splits below are contiguous slices of that
# shuffled pool.
required_len = total_tr_len + val_len + te_len
if len(X) < required_len:
    raise ValueError('need %d samples for the user/val/test split, only have %d'
                     % (required_len, len(X)))

total_tr_data=X[:total_tr_len]
total_tr_label=Y[:total_tr_len]

val_data=X[total_tr_len:(total_tr_len+val_len)]
val_label=Y[total_tr_len:(total_tr_len+val_len)]

te_data=X[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]
te_label=Y[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]

total_tr_data_tensor=torch.from_numpy(total_tr_data).type(torch.FloatTensor)
total_tr_label_tensor=torch.from_numpy(total_tr_label).type(torch.LongTensor)

val_data_tensor=torch.from_numpy(val_data).type(torch.FloatTensor)
val_label_tensor=torch.from_numpy(val_label).type(torch.LongTensor)

te_data_tensor=torch.from_numpy(te_data).type(torch.FloatTensor)
te_label_tensor=torch.from_numpy(te_label).type(torch.LongTensor)

print('total tr len %d | val len %d | test len %d'%(len(total_tr_data_tensor),len(val_data_tensor),len(te_data_tensor)))

#==============================================================================================================

# Each user gets a disjoint, contiguous block of user_tr_len training samples.
user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(nusers):
    lo, hi = user_tr_len * i, user_tr_len * (i + 1)
    user_tr_data_tensor=torch.from_numpy(total_tr_data[lo:hi]).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(total_tr_label[lo:hi]).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

In [None]:
#Execute AGR-agnostic attack 
#Note: this code included two type attacks: min-max and min-sum

batch_size=250
resume=0
nepochs=1200
schedule=[1000]
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='median'
multi_k = False
candidates = []

at_type='min-max'# for AGR-agnostic, min-max has better impact
dev_type ='std'
threshold=10
# n_attacker -> value forwarded as the attack functions' n_attackers argument.
partial_attackers = {4:1, 5:1, 8:2, 10:3, 12:4}
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

for n_attacker in n_attackers:
    candidates = []

    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)
    # Client batches are moved to whatever device the global model lives on.
    model_device = next(fed_model.parameters()).device

    torch.cuda.empty_cache()
    r=np.arange(user_tr_len)

    while epoch_num <= nepochs:
        # test() below leaves the model in eval mode; restore train mode for client gradients.
        fed_model.train()

        # Reshuffle every user's local data whenever a new pass over it begins.
        if epoch_num % nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        batch_lo = (epoch_num % nbatches) * batch_size
        batch_hi = batch_lo + batch_size

        # Benign clients [n_attacker, nusers) each contribute one flattened gradient (one row).
        user_grads = []
        for i in range(n_attacker, nusers):
            inputs = user_tr_data_tensors[i][batch_lo:batch_hi].to(model_device)
            targets = user_tr_label_tensors[i][batch_lo:batch_hi].to(model_device)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            loss.backward()

            user_grads.append(torch.cat([param.grad.data.view(-1) for param in fed_model.parameters()]))
        user_grads = torch.stack(user_grads)

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        if n_attacker > 0:
            n_attacker_ = partial_attackers.get(n_attacker, n_attacker)
            # The attackers' knowledge: the first n_attacker benign gradients.
            agg_grads = torch.mean(user_grads[:n_attacker], 0)
            if at_type == 'min-max':
                mal_update = our_attack_dist(user_grads[:n_attacker], agg_grads, n_attacker_, threshold=threshold, dev_type=dev_type)
            elif at_type == 'min-sum':
                mal_update = our_attack_score(user_grads[:n_attacker], agg_grads, n_attacker_, threshold=threshold, dev_type=dev_type)
            else:
                raise ValueError('Unknown attack type: %s' % at_type)

            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)
        else:
            # No attackers: aggregate the benign gradients only.
            malicious_grads = user_grads

        if epoch_num==0: print('malicious_grads shape ', malicious_grads.shape)

        # implement your aggregation here, median for example
        agg_grads=torch.median(malicious_grads,dim=0)[0]

        del user_grads

        optimizer_fed.zero_grad()

        # Split the flat aggregate back into per-parameter gradient tensors.
        model_grads=[]
        start_idx=0
        for param in fed_model.parameters():
            numel = param.data.numel()
            model_grads.append(agg_grads[start_idx:start_idx + numel].reshape(param.data.shape).to(param.device))
            start_idx += numel

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        # Rounds run from 0 to nepochs inclusive, so nepochs is the final round.
        if epoch_num%25==0 or epoch_num==nepochs:
            print('%s: at %s n_at %d | e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        if val_loss > 1000:
            print('val loss %f too high'%val_loss)
            break

        epoch_num+=1