<a href="https://colab.research.google.com/github/FrancescaMusella/MLDL-Project/blob/main/attacks_finalversion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.transforms as T
from torch import nn
from sklearn.model_selection import StratifiedKFold
import random
import copy
import numpy as np
import matplotlib.pyplot as plt

from collections import defaultdict
from collections import Counter

In [None]:
# Seed every random number generator the notebook draws from. The data
# partitioning and the client/attacker sampling below go through Python's
# `random` module, so seeding torch alone does not make a run repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

In [None]:
# Preprocessing pipeline shared by every CIFAR-100 split: resize to 32x32,
# convert to a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# (presumably to match the pretrained backbone) - confirm this is intended.
transform = T.Compose(
    [
        T.Resize(size=(32, 32)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
from torchvision.datasets import CIFAR100

# Both splits use the same root folder, download policy and preprocessing;
# only the `train` flag differs.
_cifar_kwargs = dict(root='.data/', download=True, transform=transform)
train_val = CIFAR100(train=True, **_cifar_kwargs)
test = CIFAR100(train=False, **_cifar_kwargs)

In [None]:
# Train / validation / test split.
# 10% of the official training set is held out for validation, stratified
# by label so every class keeps the same proportion in both parts.
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

targets = train_val.targets

train_indices, val_indices = train_test_split(
    range(len(targets)), test_size=0.1, stratify=targets, random_state=42,
)

# The test loader keeps a fixed order; the validation loader is shuffled.
test_loader = torch.utils.data.DataLoader(test, batch_size=32, shuffle=False)
val = torch.utils.data.Subset(train_val, val_indices)
val_loader = torch.utils.data.DataLoader(val, batch_size=32, shuffle=True)

In [None]:
#non I.I.D balanced dataset definition
def create_balanced_dataset(N, K, num_class, train_indices, train_val,
                            batch_size=32, mask_fraction=0.15):
    """Partition the training indices into K non-IID, equally sized clients.

    Each client is assigned N class labels. Labels are drawn without
    replacement from a pool of all ``num_class`` labels; the pool is refilled
    whenever it runs empty. For every assigned label the client then receives
    a disjoint slice of that label's (shuffled) training indices.

    Args:
        N: number of class slots per client.
        K: number of clients.
        num_class: total number of distinct labels.
        train_indices: indices into ``train_val`` available for training.
        train_val: dataset exposing a ``targets`` sequence indexable by the
            values in ``train_indices``.
        batch_size: batch size of the per-client training loaders.
        mask_fraction: share (expected in [0, 1]) of each per-class slice that
            is also placed in the client's "mask" subset. The mask subset is
            therefore a subset of the client's training data.

    Returns:
        ``(client_indices_list, client_dataset, client_loader,
        client_loader_mask)``: per-client index lists, ``Subset`` objects,
        shuffled training loaders, and shuffled batch-size-1 loaders over the
        mask subsets.

    Note:
        The slice size is ``len(train_indices) // (K * N)``, i.e. the training
        set divided by the total number of (client, class) slots. This equals
        the previous ``len(train_indices) / (num_class * N)`` whenever
        ``K == num_class`` and stays correct when the two differ. Exact balance
        assumes every class ends up with the same number of slots (for example
        ``K * N`` a multiple of ``num_class``).
    """
    # Step 1: decide which classes each client sees.
    mat_clients = np.full((K, N), -1)
    available = list(range(num_class))
    for client_id in range(K):
        for slot in range(N):
            chosen = random.choice(available)
            mat_clients[client_id, slot] = chosen
            available.remove(chosen)
            if not available:
                # Every class has been handed out once: start a new pass.
                available = list(range(num_class))

    # Step 2: group the training indices by label and shuffle each group.
    class_to_indices = defaultdict(list)
    for idx in train_indices:
        class_to_indices[train_val.targets[idx]].append(idx)
    for cls in class_to_indices:
        random.shuffle(class_to_indices[cls])

    # Step 3: hand out one slice per (client, class) slot.
    per_slot = len(train_indices) // (K * N)
    mask_per_slot = int(per_slot * mask_fraction)

    client_indices_list = []
    client_dataset = []
    client_loader = []
    client_loader_mask = []

    for client_id in range(K):
        indices_client = []
        indices_client_mask = []

        for cls in mat_clients[client_id]:
            cls = int(cls)  # plain int key instead of a NumPy scalar
            pool = class_to_indices[cls]
            sampled = pool[:per_slot]
            # Consume the slice so the next client gets different samples.
            class_to_indices[cls] = pool[per_slot:]
            indices_client.extend(sampled)
            indices_client_mask.extend(sampled[:mask_per_slot])

        client_indices_list.append(indices_client)
        client_dataset.append(torch.utils.data.Subset(train_val, indices_client))
        client_loader.append(
            torch.utils.data.DataLoader(client_dataset[-1], batch_size=batch_size, shuffle=True)
        )

        mask_subset = torch.utils.data.Subset(train_val, indices_client_mask)
        client_loader_mask.append(
            torch.utils.data.DataLoader(mask_subset, batch_size=1, shuffle=True)
        )

    return client_indices_list, client_dataset, client_loader, client_loader_mask

In [None]:
#non I.I.D balanced dataset creation
N = 10   # class slots per client
K = 100  # number of clients
num_class = len(set(targets))
(
    client_noniid,
    client_dataset_noniid,
    client_loader_noniid,
    client_loader_mask_noniid,
) = create_balanced_dataset(N, K, num_class, train_indices, train_val)

In [None]:
# ViT-S/16: clone the DINO repository and list the working directory
!git clone https://github.com/facebookresearch/dino.git
!ls

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Pretrained DINO ViT-S/16 fetched through torch.hub. The original is kept
# untouched; all later modifications are made on a deep copy.
vits16_original = torch.hub.load(
    'facebookresearch/dino:main', 'dino_vits16', pretrained=True
)
vits16_original = vits16_original.to(device)
vits16_new = copy.deepcopy(vits16_original)
print(vits16_new)

In [None]:
#Change of the head and freezing of layers
# New 100-way linear classifier on top of the 384-dimensional features.
vits16_new.head = torch.nn.Linear(in_features=384, out_features=100, bias=True).to(device)

# A parameter stays trainable when its name contains any of these tags;
# every other parameter is frozen.
_trainable_tags = ("head", "patch_embed", "proj", "pos_drop", "attn")
for name, param in vits16_new.named_parameters():
    param.requires_grad = any(tag in name for tag in _trainable_tags)

In [None]:
def FedAvg(model, K, C, J, T, lr, momentum, weight_decay, loss_fn, client_loader, trainable_part, implementation,  client_loader_mask, test_loader,  adversarial_clients, attack, test, attack_mask):
    """Run T rounds of federated averaging and collect evaluation metrics.

    In each round ``int(max(1, K*C))`` clients are sampled. Every sampled
    client starts from a deep copy of the global model and performs J local
    steps, each on one randomly chosen batch of its loader. The resulting
    state dicts are averaged, weighted by client dataset size, and loaded
    back into ``model`` (which is modified in place).

    Args:
        model: global model, updated in place every round.
        K: total number of clients.
        C: fraction of clients sampled per round.
        J: local steps per sampled client.
        T: number of communication rounds.
        lr, momentum, weight_decay: local optimiser hyper-parameters.
        loss_fn: loss used for training, masking and evaluation.
        client_loader: per-client training loaders.
        trainable_part: ``'head'`` optimises only ``model.head`` and skips the
            adversarial evaluation; any other value optimises all parameters.
        implementation: ``'self'`` builds a Fisher mask per client and trains
            with ``train_sgd_sparse``; any other value trains with ``train``
            and the torch SGD optimiser.
        client_loader_mask: per-client loaders used to build the mask.
        test_loader: loader evaluated when ``test`` is true.
        adversarial_clients: client ids forwarded to the mask/sparse-training
            helpers.
        attack: forwarded to the local training helper.
        test: evaluate on ``test_loader`` when true, otherwise on the
            module-level ``val_loader``.
        attack_mask: forwarded to ``compute_fisher_mask``.

    Returns:
        Four lists with one entry per round: clean accuracy, clean loss,
        adversarial accuracy and adversarial loss. The last two stay empty
        when ``trainable_part == 'head'``.
    """
    head_only = trainable_part == 'head'
    sparse_run = implementation == 'self'

    global_params = model.state_dict()

    acc_history, loss_history = [], []
    acc_attack_history, loss_attack_history = [], []

    for t in range(T):
        client_samples = random.sample(range(K), int(max(1, K*C)))
        local_params = []
        client_num_samples = []

        for client in client_samples:
            loader = client_loader[client]
            model_client = copy.deepcopy(model)
            client_num_samples.append(len(loader.dataset))

            if head_only:
                trainable = model_client.head.parameters()
            else:
                trainable = model_client.parameters()
            optimizer = torch.optim.SGD(trainable, lr=lr, momentum=momentum, weight_decay=weight_decay)

            if sparse_run:
                sparsity = [0.1, 0.2, 0.3, 0.4, 0.66]
                mask_fed = compute_fisher_mask(model_client, client_loader_mask[client], sparsity, loss_fn, attack_mask, adversarial_clients, client)

            for loc_step in range(J):
                # Draw a batch position, then walk the loader up to it.
                batch_idx = torch.randint(0, len(loader), (1,)).item()
                for position, batch in enumerate(loader):
                    if position != batch_idx:
                        continue
                    inputs_client, labels_client = batch
                    break

                if sparse_run:
                    train_sgd_sparse(loc_step, model_client, inputs_client, labels_client, loss_fn, lr, momentum, weight_decay, mask_fed, adversarial_clients, client, attack)
                else:
                    train(loc_step, model_client, inputs_client, labels_client, loss_fn, optimizer, J, client, attack)

            local_params.append(model_client.state_dict())

        # Size-weighted average of the client states.
        total_samples = sum(client_num_samples)
        new_global_params = copy.deepcopy(global_params)
        for key in global_params.keys():
            new_global_params[key] = sum(
                state[key].float() * (count / total_samples)
                for state, count in zip(local_params, client_num_samples)
            )

        model.load_state_dict(new_global_params)
        global_params = new_global_params

        # Evaluate the new global model on the requested split.
        eval_loader = test_loader if test else val_loader
        val_acc_fed, val_loss_fed = validate(model, eval_loader, loss_fn, test)
        if not head_only:
            val_acc_fed_attack, val_loss_fed_attack = validate_attack(model, eval_loader, loss_fn, test)

        acc_history.append(val_acc_fed)
        loss_history.append(val_loss_fed)
        if not head_only:
            acc_attack_history.append(val_acc_fed_attack)
            loss_attack_history.append(val_loss_fed_attack)

        print(f'Iteration {t}-------------------')

    return acc_history, loss_history, acc_attack_history, loss_attack_history

In [None]:
def train(epoch, model, inputs, targets, criterion, optimizer,J,client, attack):
    """Perform one local optimisation step on a single batch.

    Args:
        epoch: index of the current local step (0-based).
        model: model exposing ``get_intermediate_layers`` and ``head``.
        inputs, targets: one batch; moved to the model's device.
        criterion: loss function.
        optimizer: optimiser stepping the model's parameters.
        J: total number of local steps; metrics are produced on step ``J-1``.
        client: client id (not used by this function).
        attack: when true, the loss becomes
            ``0.25 * clean_loss + 0.75 * adversarial_loss`` where the
            adversarial batch comes from ``fgsm_attack_local``.

    Returns:
        ``(loss, accuracy_percent)`` of this batch on the last local step
        (``epoch == J - 1``), computed from the clean outputs obtained before
        the optimiser step; ``None`` on every other step.
    """
    model.train()
    lamb = 0.25  # weight of the clean loss when adversarial training is on

    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    inputs, targets = inputs.to(device), targets.to(device)

    intermediate_output = model.get_intermediate_layers(inputs, n=1)
    features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
    outputs = model.head(features)

    loss = criterion(outputs, targets)

    if attack:
        inputs_adv = fgsm_attack_local(model, inputs, targets, criterion, epsilon=0.03)
        intermediate_output_adv = model.get_intermediate_layers(inputs_adv, n=1)
        features_adv = torch.cat([x[:, 0] for x in intermediate_output_adv], dim=-1)
        outputs_adv = model.head(features_adv)

        loss_adv = criterion(outputs_adv, targets)
        loss = lamb * loss + (1 - lamb) * loss_adv

    # Gradients are cleared here, after any attack, so whatever the attack
    # left on the parameters does not leak into this step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch == J - 1:
        _, predicted = outputs.max(1)
        total = targets.size(0)
        correct = predicted.eq(targets).sum().item()

        train_loss = loss.item()
        train_accuracy = 100. * correct / total
        return train_loss, train_accuracy

    return None

In [None]:
#Definition of the function for test performance on normal images
def validate(model, val_loader, criterion, test):
    """Evaluate the model on clean inputs.

    Args:
        model: model exposing ``get_intermediate_layers`` and ``head``.
        val_loader: sized iterable of ``(inputs, targets)`` batches.
        criterion: loss function.
        test: not used by this function.

    Returns:
        ``(accuracy_percent, mean_batch_loss)``. The loss is averaged over
        the number of batches, not the number of samples.

    Raises:
        ZeroDivisionError: if ``val_loader`` contains no batches.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device

    val_loss = 0
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            intermediate_output = model.get_intermediate_layers(inputs, n=1)
            features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            outputs = model.head(features)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    val_loss = val_loss / len(val_loader)
    val_accuracy = 100. * correct / total
    print(f'Validation Loss: {val_loss:.6f} Acc: {val_accuracy:.2f}%')
    return val_accuracy, val_loss

In [None]:
#Definition of the function for test performance on adversarial images
def validate_attack(model, val_loader, criterion,test):
    """Evaluate the model on FGSM-perturbed inputs.

    Every batch is perturbed with ``fgsm_attack_local`` (epsilon 0.03) and
    then classified.

    Args:
        model: model exposing ``get_intermediate_layers`` and ``head``.
        val_loader: sized iterable of ``(inputs, targets)`` batches.
        criterion: loss function, also used to craft the attack.
        test: not used by this function.

    Returns:
        ``(accuracy_percent, mean_batch_loss)`` on the perturbed inputs. The
        loss is averaged over the number of batches.

    Raises:
        ZeroDivisionError: if ``val_loader`` contains no batches.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device

    val_loss_attack = 0
    correct_attack = 0
    total_attack = 0

    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Crafting the attack needs gradients, so it runs outside no_grad.
        inputs_attack = fgsm_attack_local(model, inputs, targets, criterion, epsilon=0.03)

        # Scoring the perturbed batch does not need a graph.
        with torch.no_grad():
            intermediate_output_attack = model.get_intermediate_layers(inputs_attack, n=1)
            features_attack = torch.cat([x[:, 0] for x in intermediate_output_attack], dim=-1)
            outputs_attack = model.head(features_attack)
            loss_attack = criterion(outputs_attack, targets)

        val_loss_attack += loss_attack.item()
        _, predicted_attack = outputs_attack.max(1)
        total_attack += targets.size(0)
        correct_attack += predicted_attack.eq(targets).sum().item()

    # Drop any parameter gradients the attack may have accumulated so they
    # are not carried along with the evaluated model.
    model.zero_grad()

    val_loss_attack = val_loss_attack / len(val_loader)
    val_accuracy_attack = 100. * correct_attack / total_attack
    print(f'Validation Loss Attack: {val_loss_attack:.6f} Acc: {val_accuracy_attack:.2f}%')

    return val_accuracy_attack, val_loss_attack

In [None]:
#creation of the mask
def compute_fisher_mask(model, dataloader, sparsity, criterion, attack_mask, adversarial_clients, client_id):
    """Build a binary update mask from accumulated squared gradients.

    The head is frozen (``requires_grad = False``, a lasting side effect on
    ``model``) and every remaining trainable parameter gets a mask that
    starts as all ones. One pruning round is run per entry of ``sparsity``:
    squared gradients are accumulated over ``dataloader`` for the entries
    still unmasked, a threshold is taken so that roughly the top
    ``fraction`` of the non-zero scores lie at or above it, and only entries
    whose score is strictly below the threshold keep a mask value of 1.

    Args:
        model: model exposing ``get_intermediate_layers`` and ``head``; left
            in eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        sparsity: sequence of per-round fractions; its length sets the number
            of rounds.
        criterion: loss function.
        attack_mask: when true and ``client_id`` is in
            ``adversarial_clients``, batches are perturbed with
            ``fgsm_attack_local`` before the gradients are computed.
        adversarial_clients: collection of adversarial client ids.
        client_id: id of the client the mask is built for.

    Returns:
        Dict mapping each tracked parameter to a float tensor of 0/1 with the
        parameter's shape. With an empty ``sparsity`` the masks are all ones.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device

    for param in model.head.parameters():
        param.requires_grad = False

    tracked = [param for param in model.parameters() if param.requires_grad]
    fisher_scores = {param: torch.zeros_like(param.data) for param in tracked}
    mask = {param: torch.ones_like(param.data) for param in tracked}

    for fraction in sparsity:
        for score in fisher_scores.values():
            score.zero_()

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            if attack_mask:  # attack in the creation of the mask
                if adversarial_clients and client_id in adversarial_clients:
                    inputs = fgsm_attack_local(model, inputs, targets, criterion, epsilon=0.03)

            intermediate_output = model.get_intermediate_layers(inputs, n=1)
            features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            outputs = model.head(features)

            loss = criterion(outputs, targets)

            # Cleared after any attack so only this batch contributes.
            model.zero_grad()
            loss.backward()

            for param in tracked:
                if param.grad is not None:
                    fisher_scores[param] += param.grad.data.pow(2) * mask[param]

        all_scores = torch.cat([torch.flatten(v) for v in fisher_scores.values()])
        non_zero_scores = all_scores[all_scores != 0]
        if non_zero_scores.numel() == 0:
            # Nothing left to rank; further rounds cannot change the mask.
            break

        k = int(fraction * non_zero_scores.numel())
        threshold, _ = torch.kthvalue(non_zero_scores, non_zero_scores.numel() - k)

        for param, score in fisher_scores.items():
            masked_score = score * mask[param]
            # Masks only ever shrink: an entry zeroed earlier stays zero.
            mask[param] = ((masked_score < threshold) * mask[param]).float()

    return mask

In [None]:
#self implementation of SGDM with the addition of the mask
def sgdm_sparse (params, lr, momentum, dampening, weight_decay, nesterov, maximize,b, mask):
    """Apply one masked SGD-with-momentum step to ``params``.

    Parameters without a gradient are skipped. For the others the gradient
    (plus ``weight_decay * param`` when weight decay is non-zero) is combined
    with the momentum buffer, multiplied element-wise by ``mask[param]`` and
    applied with step size ``lr`` (added when ``maximize`` is true,
    subtracted otherwise).

    Args:
        params: iterable of parameters to update.
        lr: step size.
        momentum: momentum factor; 0 disables momentum.
        dampening: dampening applied to the gradient in the buffer update.
        weight_decay: L2 coefficient; 0 disables it.
        nesterov: use the Nesterov form of the update.
        maximize: ascend instead of descend.
        b: dict of momentum buffers keyed by parameter; updated in place.
        mask: dict of element-wise masks keyed by parameter.

    Returns:
        The buffer dict ``b``. With ``momentum == 0`` the stored buffer of an
        updated parameter is the integer 0.
    """
    use_momentum = momentum != 0

    for weight in params:
        if weight.grad is None:
            continue

        step = weight.grad.data
        if weight_decay != 0:
            step = step + weight_decay * weight.data

        if weight not in b:
            b[weight] = torch.zeros_like(weight.data)

        if use_momentum:
            velocity = momentum * b[weight] + (1 - dampening) * step
            direction = step + momentum * velocity if nesterov else velocity
        else:
            velocity = 0
            direction = step

        # Masked-out entries receive no update; the buffer itself is stored
        # unmasked.
        direction = direction * mask[weight]
        delta = lr * direction

        if maximize:
            weight.data = weight.data + delta
        else:
            weight.data = weight.data - delta
        b[weight] = velocity

    return b

In [None]:
def train_sgd_sparse(epoch, model, inputs, targets, criterion, lr, momentum, weight_decay, mask, adversarial_clients, client_id, attack, momentum_buffers=None):
    """Perform one masked SGD step on a single batch.

    Args:
        epoch: index of the current local step (not used by this function).
        model: model exposing ``get_intermediate_layers`` and ``head``.
        inputs, targets: one batch; moved to the model's device.
        criterion: loss function.
        lr, momentum, weight_decay: optimiser hyper-parameters passed to
            ``sgdm_sparse`` (dampening 0, no Nesterov, minimisation).
        mask: dict of element-wise masks keyed by parameter.
        adversarial_clients: collection of adversarial client ids.
        client_id: id of the client being trained.
        attack: when true and ``client_id`` is in ``adversarial_clients``,
            the batch is replaced by its FGSM-perturbed version.
        momentum_buffers: optional dict of momentum buffers from a previous
            call. When omitted, the step starts from empty buffers, so a
            non-zero ``momentum`` carries nothing over between calls; pass
            the returned dict back in to get real momentum.

    Returns:
        The momentum-buffer dict produced by ``sgdm_sparse``.
    """
    model.train()
    params = list(model.parameters())
    dampening = 0
    nesterov = False
    maximize = False
    b = {} if momentum_buffers is None else momentum_buffers

    # Follow the model's device instead of assuming CUDA is available.
    device = params[0].device
    inputs, targets = inputs.to(device), targets.to(device)

    if attack:  # attack at train time
        if adversarial_clients and client_id in adversarial_clients:
            inputs = fgsm_attack_local(model, inputs, targets, criterion, epsilon=0.03)

    intermediate_output = model.get_intermediate_layers(inputs, n=1)
    features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
    outputs = model.head(features)

    loss = criterion(outputs, targets)
    # Cleared after any attack so only this loss contributes to the step.
    model.zero_grad()
    loss.backward()

    return sgdm_sparse(params, lr, momentum, dampening, weight_decay, nesterov, maximize, b, mask)

In [None]:
#definition of FGSM attack
def fgsm_attack_local(model, images, labels, criterion, epsilon=0.03,
                      mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Craft FGSM adversarial examples for a batch.

    The perturbation is ``epsilon * sign(d loss / d images)``. The result is
    clamped to the range a valid image can occupy. Because the notebook's
    transform normalises images with a per-channel mean/std, that range is
    ``[(0 - mean) / std, (1 - mean) / std]`` per channel rather than
    ``[0, 1]``; clamping normalised inputs to ``[0, 1]`` would wipe out every
    negative value and most of the image content.

    Args:
        model: model exposing ``get_intermediate_layers`` and ``head``.
        images: input batch; assumed to be ``(batch, channels, H, W)`` when
            ``mean``/``std`` are given.
        labels: target labels; moved to the device of ``images``.
        criterion: loss to increase.
        epsilon: step size, in the same units as ``images``.
        mean, std: per-channel normalisation constants applied to ``images``;
            the defaults match the notebook's transform. Pass ``None`` for
            either to clamp to ``[0, 1]`` (inputs that are not normalised).

    Returns:
        Detached adversarial batch with the same shape as ``images``.
    """
    images = images.clone().detach()
    labels = labels.to(images.device)
    images.requires_grad_(True)

    intermediate_output = model.get_intermediate_layers(images, n=1)
    features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
    outputs = model.head(features)

    loss = criterion(outputs, labels)
    # Only the input gradient is needed; autograd.grad leaves the model's
    # parameter gradients untouched.
    (grad,) = torch.autograd.grad(loss, images)

    adv_images = images.detach() + epsilon * grad.sign()

    if mean is None or std is None:
        adv_images = torch.clamp(adv_images, 0, 1)
    else:
        mean_t = torch.as_tensor(mean, dtype=adv_images.dtype, device=adv_images.device).view(1, -1, 1, 1)
        std_t = torch.as_tensor(std, dtype=adv_images.dtype, device=adv_images.device).view(1, -1, 1, 1)
        lower = (0 - mean_t) / std_t
        upper = (1 - mean_t) / std_t
        adv_images = torch.maximum(torch.minimum(adv_images, upper), lower)

    return adv_images.detach()

In [None]:
# Strong Baseline: adversarial training without mask
K, C, J = 100, 0.1, 8  # clients, fraction sampled per round, local steps
momentum, weight_decay = 0, 1e-3
loss_fn = nn.CrossEntropyLoss()

# Train on a private copy so the prepared model stays reusable.
vits16_adv_train = copy.deepcopy(vits16_new)

In [None]:
adversarial_clients = list(range(K))  #all clients are attackers
lr = 1e-3
T = 200
(
    val_acc_fed_list_adv,
    val_loss_fed_list_adv,
    val_acc_fed_attack_list_adv,
    val_loss_fed_attack_list_adv,
) = FedAvg(
    vits16_adv_train, K, C, J, T, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='full',
    implementation='pyt',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=True,
    test=True,
    attack_mask=False,
)
# Final evaluation of the trained model on clean and perturbed test images.
print(f'Test for non i.i.d.balanced dataset with with N={N} and J={J}')
validate(vits16_adv_train, test_loader, loss_fn, test=True)
validate_attack(vits16_adv_train, test_loader, loss_fn, test=True)

In [None]:
#Test loss plot on original images
# One tick at the first round, then one every 20 rounds; labels are 1-based.
ticks = [0, *range(19, T, 20)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_list_adv, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss')
plt.legend()
plt.show()


In [None]:
#Test loss plot on perturbed images
# One tick at the first round, then one every 20 rounds; labels are 1-based.
ticks = [0, *range(19, T, 20)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_attack_list_adv, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss attack')
plt.legend()
plt.show()

In [None]:
#portion of attackers
# Draw 20 distinct client ids out of the K clients to act as attackers.
_all_client_ids = range(K)
adversarial_clients = random.sample(_all_client_ids, 20)

In [None]:
# Attacks only at test time
K, C, J = 100, 0.1, 8  # clients, fraction sampled per round, local steps
momentum, weight_decay = 0, 1e-3
loss_fn = nn.CrossEntropyLoss()
lr = 1e-2
T_head = 40

# Head-only rounds on a fresh copy of the prepared model, evaluated on the
# validation split (test=False).
vits16_mask = copy.deepcopy(vits16_new)
FedAvg(
    vits16_mask, K, C, J, T_head, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='head',
    implementation='pyt',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=False,
    test=False,
    attack_mask=False,
)

In [None]:
lr = 5e-3
T = 30
# Masked training with no attack during mask creation or training.
(
    val_acc_fed_list,
    val_loss_fed_list,
    val_acc_fed_attack_list,
    val_loss_fed_attack_list,
) = FedAvg(
    vits16_mask, K, C, J, T, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='full',
    implementation='self',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=False,
    test=True,
    attack_mask=False,
)
print(f'Test for non i.i.d.balanced dataset with with N={N} and J={J}')
validate(vits16_mask, test_loader, loss_fn, test=True)
validate_attack(vits16_mask, test_loader, loss_fn, test=True)

In [None]:
#Test loss plot on original images
# One tick at the first round, then one every 5 rounds; labels are 1-based.
ticks = [0, *range(4, T, 5)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_list, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss')
plt.legend()
plt.show()

In [None]:
#Test loss plot on perturbed images
# One tick at the first round, then one every 5 rounds; labels are 1-based.
ticks = [0, *range(4, T, 5)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_attack_list, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss attack')
plt.legend()
plt.show()

In [None]:
# Attacks at train and test time
K, C, J = 100, 0.1, 8  # clients, fraction sampled per round, local steps
momentum, weight_decay = 0, 1e-3
loss_fn = nn.CrossEntropyLoss()
lr = 1e-2
T_head = 40

# Head-only rounds on a fresh copy of the prepared model, evaluated on the
# validation split (test=False).
vits16_mask_attack = copy.deepcopy(vits16_new)
FedAvg(
    vits16_mask_attack, K, C, J, T_head, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='head',
    implementation='pyt',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=False,
    test=False,
    attack_mask=False,
)

In [None]:
lr = 5e-3
T = 30
# Masked training with the attack enabled during local training only.
(
    val_acc_fed_list_partial,
    val_loss_fed_list_partial,
    val_acc_fed_attack_list_partial,
    val_loss_fed_attack_list_partial,
) = FedAvg(
    vits16_mask_attack, K, C, J, T, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='full',
    implementation='self',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=True,
    test=True,
    attack_mask=False,
)

print(f'Test for non i.i.d.balanced dataset with with N={N} and J={J}')
validate(vits16_mask_attack, test_loader, loss_fn, test=True)
validate_attack(vits16_mask_attack, test_loader, loss_fn, test=True)

In [None]:
#Test loss plot on original images
# One tick at the first round, then one every 5 rounds; labels are 1-based.
ticks = [0, *range(4, T, 5)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_list_partial, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss')
plt.legend()
plt.show()

In [None]:
#Test loss plot on perturbed images
# One tick at the first round, then one every 5 rounds; labels are 1-based.
ticks = [0, *range(4, T, 5)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_attack_list_partial, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss attack')
plt.legend()
plt.show()

In [None]:
# Attacks in the creation of the mask, at train and test time
K, C, J = 100, 0.1, 8  # clients, fraction sampled per round, local steps
momentum, weight_decay = 0, 1e-3
loss_fn = nn.CrossEntropyLoss()
lr = 1e-2
T_head = 40

# Head-only rounds on a fresh copy of the prepared model, evaluated on the
# validation split (test=False).
vits16_mask_train_attack = copy.deepcopy(vits16_new)
FedAvg(
    vits16_mask_train_attack, K, C, J, T_head, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='head',
    implementation='pyt',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=False,
    test=False,
    attack_mask=False,
)

In [None]:
lr = 5e-3
T = 30
# Masked training with the attack enabled both when building the mask and
# during local training.
(
    val_acc_fed_list_mask_train,
    val_loss_fed_list_mask_train,
    val_acc_fed_attack_list_mask_train,
    val_loss_fed_attack_list_mask_train,
) = FedAvg(
    vits16_mask_train_attack, K, C, J, T, lr, momentum, weight_decay, loss_fn,
    client_loader_noniid,
    trainable_part='full',
    implementation='self',
    client_loader_mask=client_loader_mask_noniid,
    test_loader=test_loader,
    adversarial_clients=adversarial_clients,
    attack=True,
    test=True,
    attack_mask=True,
)

print(f'Test for non i.i.d.balanced dataset with with N={N} and J={J}')
validate(vits16_mask_train_attack, test_loader, loss_fn, test=True)
validate_attack(vits16_mask_train_attack, test_loader, loss_fn, test=True)

In [None]:
#Test loss plot on original images
# One tick at the first round, then one every 5 rounds; labels are 1-based.
ticks = [0, *range(4, T, 5)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_list_mask_train, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss')
plt.legend()
plt.show()

In [None]:
#Test loss plot on perturbed images
# One tick at the first round, then one every 5 rounds; labels are 1-based.
ticks = [0, *range(4, T, 5)]
labels = [tick + 1 for tick in ticks]

plt.figure(figsize=(16, 5))
plt.plot(val_loss_fed_attack_list_mask_train, label='Test Loss')
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test loss attack')
plt.legend()
plt.show()