In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
from scipy.stats import zscore
import copy

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision
import torchvision.transforms as tt
import torchvision.models as models
from torchvision.datasets import MNIST, ImageFolder
from torchvision.utils import make_grid
from torch.utils.data import random_split, DataLoader, Subset,SubsetRandomSampler

from sklearn.metrics import accuracy_score
import time

from copy import deepcopy
import logging

In [2]:
import warnings

# Silence all warnings for the rest of the kernel session.
# NOTE(review): this also hides deprecation/runtime warnings that may matter;
# consider narrowing with `category=` or `module=`.
warnings.simplefilter('ignore')

In [3]:
# Convert images to tensors, then standardise with the commonly used MNIST
# pixel mean / std (0.1307, 0.3081).
transform = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Downloaded under ./MNIST on first use; the train split is built before the test split.
train_ds, test_ds = (
    MNIST(root='.', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 108595868.97it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 138917080.07it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 29661314.44it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 26907526.51it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [4]:
batch_size = 100

# Only the training stream is shuffled; pin_memory=True speeds up host-to-GPU copies.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=4, pin_memory=True)

In [5]:
# Pick the compute device. `torch.cuda.is_available` must be *called*: the bare
# function object is always truthy, which would select 'cuda' on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
class MnistNet(nn.Module):
    """Small CNN mapping 1x28x28 MNIST images to 10 class logits.

    Stages (attribute names form the state_dict keys, so they must not change):
        conv1: Conv 1->10 (5x5), 2x2 max-pool, ReLU
        conv2: Conv 10->20 (5x5), channel dropout (p=0.5), 2x2 max-pool, ReLU
        fc1:   flatten 20*4*4 = 320 features -> 50, dropout (p=0.5), ReLU
        fc2:   50 -> 10 logits
    The hard-coded 320 assumes a 28x28 input (28-4=24 -> 12 -> 12-4=8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in the original order so random initialisation
        # draws from the RNG identically.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=320, out_features=50),
            nn.Dropout(p=0.5),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return unnormalised class scores for a batch of images."""
        feature_maps = self.conv2(self.conv1(x))
        return self.fc2(self.fc1(feature_maps))


print(MnistNet())

MnistNet(
  (conv1): Sequential(
    (0): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): Dropout2d(p=0.5, inplace=False)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU()
  )
  (fc1): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=320, out_features=50, bias=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): ReLU()
  )
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)


In [7]:
def test(model, test_dl, criterion):
    """Evaluate ``model`` on ``test_dl`` and return sample-weighted (loss, accuracy).

    The model is moved to the notebook-level ``device`` for evaluation, switched to
    eval mode (and left in it), and moved back to CPU afterwards -- also when an
    error occurs mid-evaluation.

    Args:
        model: network to evaluate.
        test_dl: iterable of ``(images, labels)`` batches.
        criterion: loss function. Assumed to return the *mean* loss over a batch
            (the default for ``nn.CrossEntropyLoss``), so each batch loss is
            re-weighted by its size before averaging.

    Returns:
        (loss, acc): ``loss`` is a 0-d float tensor (callers call
        ``.detach().numpy()`` on it); ``acc`` is a Python float in [0, 1].
        Both are averaged per sample, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``test_dl`` yields no samples.
    """
    model.to(device)
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    try:
        with torch.no_grad():
            for images, labels in test_dl:
                # Follow `device` (the model's device) rather than re-checking CUDA
                # availability, so inputs and model can never disagree.
                images, labels = images.to(device), labels.to(device)
                logits = model(images)
                batch_n = labels.size(0)
                total_loss += criterion(logits, labels).item() * batch_n
                total_correct += (logits.argmax(dim=1) == labels).sum().item()
                total_seen += batch_n
    finally:
        model.cpu()

    if total_seen == 0:
        raise ValueError("test_dl yielded no samples")
    return torch.tensor(total_loss / total_seen), total_correct / total_seen

def fit(epochs, model, optimizer, criterion, train_dl, test_dl):
    """Centralised (non-federated) training loop with a randomly attacked epoch schedule.

    Each epoch is independently flagged as attacked when ``random.random() > 0.20``
    (about 80% of epochs); attacked epochs train with attack type ``'A1'``
    (label flipping).

    Args:
        epochs: number of epochs to run.
        model, optimizer: network and its optimizer.
        criterion: loss used for evaluation. ``train`` reads the notebook-level
            ``criterion`` for its own loss.
        train_dl, test_dl: training / evaluation DataLoaders.

    Returns:
        dict with per-epoch ``train_loss``, ``train_acc``, ``test_loss``,
        ``test_acc`` and ``train_attacked_history`` (each epoch's attack flag).
    """
    train_loss, train_acc, test_loss, test_acc = [], [], [], []
    train_attack = []
    for epoch in range(1, epochs + 1):
        attack = random.random() > 0.20

        # train() takes (model, dl, optimizer, ID, attack_type, attack) and returns
        # four values; a 4-argument call with 3-way unpacking raises TypeError.
        # The client ID is irrelevant here, and 'A1' is the only attack type
        # train() checks for.
        trainl, traina, _, _ = train(model, train_dl, optimizer, 0, 'A1', attack)
        testl, testa = test(model, test_dl, criterion)

        train_loss.append(trainl.detach().numpy())
        train_acc.append(traina)
        train_attack.append(attack)
        test_loss.append(testl.detach().numpy())
        test_acc.append(testa)
        print(f'Epoch {epoch} - train_loss : {trainl :.4f}, train_acc : {traina:.4f}, test_loss : {testl:.4f}, test_acc : {testa:0.4f}')

    history = {'train_loss' : train_loss,
               'train_acc' : train_acc,
               'test_loss' : test_loss,
               'test_acc' : test_acc,
               'train_attacked_history' : train_attack}
    return history

# Centralised baseline setup. `criterion` is also read as a global by `train`,
# and `epochs` is reused by the federated run further down.
criterion = nn.CrossEntropyLoss()
epochs = 3
model = MnistNet()
optimizer = Adam(model.parameters(), lr=1e-3)

# The baseline run is disabled; uncomment to train the non-federated reference model:
# baseline_history = fit(epochs, model, optimizer, criterion, train_dl, test_dl)

## Train with attack

In [8]:
# Attack 2 parameter. NOTE(review): presumably the std-dev of a noise-injection
# attack; nothing in this section reads it yet.
std_dev = 1

# Client-side local training that also returns summaries of the model weights
# ("historical gradients") which are later clustered to flag attackers.
def train(model, train_dl, optimizer, ID, attack_type, attack=False, hist_grads=None, beta=0.999,
          loss_fn=None):
    """Run one local epoch on a client and summarise the weights it passed through.

    Args:
        model: client network. It is trained on the notebook-level ``device`` and
            moved back to CPU before returning (also on error).
        train_dl: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        ID: client index. Currently unused; kept for call-site compatibility.
        attack_type: ``'A1'`` selects single-label flipping (7 -> 1 via
            ``label_poisoning``); any other value means no poisoning.
        attack: apply ``attack_type`` to this client's batches when true.
        hist_grads, beta: currently unused; kept for call-site compatibility.
        loss_fn: loss to optimise. Defaults to the notebook-level ``criterion``.

    Returns:
        (loss, acc, local_w_mean, local_w):
          loss          -- mean per-batch loss, detached 0-d CPU tensor
          acc           -- mean per-batch accuracy (float)
          local_w_mean  -- flattened state_dict averaged over every optimizer
                           step of this epoch (CPU tensor)
          local_w       -- flattened state_dict after the final step, i.e. the
                           weights the client ends the round with (CPU tensor)

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    if loss_fn is None:
        loss_fn = criterion
    model.to(device)
    model.train()

    batch_loss, batch_acc = [], []
    weight_sum, local_w, n_steps = None, None, 0
    try:
        for images, labels in train_dl:
            # Attack 1 -- single label flipping: for a selected client, every
            # label 7 is relabelled as 1.
            if attack and attack_type == 'A1':
                labels = label_poisoning(labels)

            images, labels = images.to(device).float(), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            # Snapshot *after* the step, once per batch, so the history covers the
            # whole epoch and the last snapshot equals the final local weights.
            # torch.cat allocates new storage, so later steps cannot alter it.
            local_w = torch.cat([v.flatten() for v in model.state_dict().values()])
            weight_sum = local_w if weight_sum is None else weight_sum + local_w
            n_steps += 1

            # Detach so the stored losses do not keep autograd graphs alive.
            batch_loss.append(loss.detach().cpu())
            pred = torch.argmax(logits, dim=1)
            batch_acc.append(accuracy_score(labels.cpu(), pred.cpu()))
    finally:
        model.cpu()

    if n_steps == 0:
        raise ValueError("train_dl yielded no batches")

    # Mean of the per-step weight history ("historical" summary).
    local_w_mean = (weight_sum / n_steps).cpu()
    return (sum(batch_loss) / len(batch_loss), sum(batch_acc) / len(batch_acc),
            local_w_mean, local_w.cpu())

# Attack 1 helper: relabel every `source_label` sample as `target_label` (default 7 -> 1).
def label_poisoning(label, source_label=7, target_label=1):
    """Flip every occurrence of ``source_label`` in ``label`` to ``target_label``.

    ``label`` is modified in place (boolean-mask assignment) and the same tensor
    is returned, so callers may use either the argument or the return value.

    Args:
        label: integer label tensor for one batch.
        source_label: class to relabel (default 7).
        target_label: class assigned instead (default 1).

    Returns:
        ``label`` itself, after modification.
    """
    label[label == source_label] = target_label
    return label

# Train Clients


In [17]:
'''
Method 0
'''
from sklearn.cluster import KMeans

def find_attacker_id(clf):
    """Return the row positions assigned to the *smaller* of two clusters.

    Args:
        clf: fitted 2-cluster model exposing ``labels_`` with values 0/1.

    Returns:
        set of integer positions (numpy ints) in the minority cluster. On a tie,
        cluster 1 is treated as the malicious one.
    """
    assignments = np.asarray(clf.labels_)
    size_of_1 = np.count_nonzero(assignments == 1)
    size_of_0 = np.count_nonzero(assignments == 0)

    # Minority cluster is flagged; ties flag cluster 1.
    suspicious_cluster = 1 if size_of_1 <= size_of_0 else 0
    return set(np.flatnonzero(assignments == suspicious_cluster))

def find_targeted_attack(dict_hist_grad):
    """Cluster per-client vectors with 2-means and return the IDs flagged as attackers.

    Args:
        dict_hist_grad: ``{client_id: 1-D tensor}``; tensors may live on any device.

    Returns:
        numpy array of the client IDs in the cluster chosen by ``find_attacker_id``.
    """
    client_ids = np.array(list(dict_hist_grad.keys()))
    features = np.array([vec.cpu().numpy() for vec in dict_hist_grad.values()])

    # random_state fixes the k-means initialisation so results are repeatable.
    kmeans = KMeans(n_clusters=2, random_state=0).fit(features)
    flagged_positions = list(find_attacker_id(kmeans))
    attacker_id = client_ids[flagged_positions]

    logging.info(f"This round TARGETED ATTACK: {attacker_id}")
    return attacker_id

In [15]:
def train_clients(client_models, client_optimizers, server_model, criterion, client_dls, num, attack=False,
                  attacker_ids=range(9), attack_until=40):
    """Run one local round on every client, each starting from the server weights.

    Clients whose index is in ``attacker_ids`` run Attack 1 (label flipping) while
    ``num <= attack_until``. Every other client trains cleanly.

    Args:
        client_models, client_optimizers, client_dls: per-client model, optimizer
            and training DataLoader. They are zipped, so the shortest list bounds
            the clients trained.
        server_model: global model whose ``state_dict`` each client is reset to.
        criterion: currently not forwarded; ``train`` uses the notebook-level
            ``criterion``.
        num: compared against ``attack_until`` to decide whether attacks are active.
            NOTE(review): the notebook's ``fit_fedavg`` call forwards the client count
            ``n`` here rather than a round index -- confirm the intended meaning.
        attack: unused; kept for call-site compatibility.
        attacker_ids: indices of malicious clients (default 0-8).
        attack_until: last ``num`` value (inclusive) with attacks active (default 40).

    Returns:
        (mean_loss, mean_acc, hist_weights, curr_weights, attacked_clients):
        client-mean loss and accuracy, ``{client_idx: vector}`` dicts holding the
        two weight summaries returned by ``train`` for each trained client, and a
        per-client list of attack flags.

    Raises:
        ValueError: if no client was trained.
    """
    attacker_ids = set(attacker_ids)
    client_loss, client_acc, attacked_clients = [], [], []
    hist_weights, curr_weights = {}, {}

    for i, (model, optimizer, train_dl) in enumerate(zip(client_models, client_optimizers, client_dls)):
        model.load_state_dict(server_model.state_dict())

        # Select attack for this client.
        is_attacked = num <= attack_until and i in attacker_ids
        attack_type = 'A1' if is_attacked else '0'
        attacked_clients.append(is_attacked)

        closs, cacc, hist_w, curr_w = train(model, train_dl, optimizer, i, attack_type, is_attacked,
                                            hist_grads={})
        client_loss.append(closs)
        client_acc.append(cacc)
        hist_weights[i] = hist_w
        curr_weights[i] = curr_w

    if not client_loss:
        raise ValueError("train_clients: no clients to train")

    n_trained = len(client_loss)
    return (sum(client_loss) / n_trained, sum(client_acc) / n_trained,
            hist_weights, curr_weights, attacked_clients)


def fedavg(client_models, server_model, historical_grads, current_grads, loss_difference_1, loss_difference_2,
           convergence_threshold=1):
    """Load the unweighted mean of the client weights into ``server_model``, then
    optionally cluster client weight vectors to flag targeted attackers.

    Clustering (``find_targeted_attack`` on both the historical and the current
    vectors) only runs when both recent train-loss decreases are
    ``<= convergence_threshold``, i.e. the global loss is no longer dropping fast.
    NOTE(review): a loss decrease rarely exceeds 1 after the first rounds, so with
    the default threshold this gate is nearly always open -- confirm the intended value.

    Args:
        client_models: non-empty list of client models; every client is weighted equally.
        server_model: global model, overwritten in place.
        historical_grads, current_grads: ``{client_id: flat weight vector}`` dicts.
        loss_difference_1, loss_difference_2: the two latest decreases of the mean
            train loss.
        convergence_threshold: clustering gate (default 1).

    Returns:
        ``'No Cluster Performed'`` when clustering was skipped, otherwise
        ``(historical_attackers, current_attackers)`` as returned by
        ``find_targeted_attack``.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("fedavg needs at least one client model")

    n = len(client_models)
    summed = {}
    for model in client_models:
        for name, tensor in model.state_dict().items():
            summed[name] = summed.get(name, 0) + tensor
    server_model.load_state_dict({name: total / n for name, total in summed.items()})

    combined_results = 'No Cluster Performed'
    if loss_difference_1 <= convergence_threshold and loss_difference_2 <= convergence_threshold:
        cluster_results_1 = find_targeted_attack(historical_grads)
        print("cluster results of historical", cluster_results_1)

        cluster_results_2 = find_targeted_attack(current_grads)
        print("cluster results of current", cluster_results_2)
        combined_results = (cluster_results_1, cluster_results_2)
    return combined_results

In [11]:
"""
Simulate Distribution Drift
"""
drift_round = 20

from torch.utils.data import DataLoader, Subset
import random

def adjust_dataset(dataset, label_reduction_map, rng=None):
    """Randomly thin ``dataset`` per label and return the survivors as a Subset.

    Each sample is kept independently with probability
    ``label_reduction_map.get(label, 1)``; labels missing from the map are always kept.

    Args:
        dataset: map-style dataset yielding ``(data, label)`` pairs.
        label_reduction_map: ``{label: keep_probability}``.
        rng: optional ``random.Random`` for reproducible thinning. Defaults to the
            module-level ``random`` generator (one draw per sample, in index order).

    Returns:
        ``Subset(dataset, kept_indices)`` with indices in ascending order.
    """
    draw = rng.random if rng is not None else random.random
    labels = _dataset_labels(dataset)
    kept = [idx for idx, label in enumerate(labels) if draw() < label_reduction_map.get(label, 1)]
    return Subset(dataset, kept)


def _dataset_labels(dataset):
    """Yield per-sample labels, reading ``dataset.targets`` directly when that is safe.

    Fast path: datasets exposing ``targets`` (e.g. torchvision's MNIST) without a
    ``target_transform`` give the same labels as indexing, so the per-sample image
    loading and transform pipeline are skipped. Otherwise each sample is loaded.
    """
    targets = getattr(dataset, 'targets', None)
    if (targets is not None
            and getattr(dataset, 'target_transform', None) is None
            and len(targets) == len(dataset)):
        return targets.tolist() if hasattr(targets, 'tolist') else list(targets)
    return (label for _, label in dataset)


def create_data_loaders(train_ds, n, label_reduction_map):
    """Thin ``train_ds`` by label and shard it randomly across ``n`` clients.

    Args:
        train_ds: source dataset.
        n: number of clients.
        label_reduction_map: ``{label: keep_probability}`` passed to ``adjust_dataset``.

    Returns:
        (client_dls, client_models, client_optimizers): ``n`` shuffled DataLoaders
        (notebook-level ``batch_size``) over equal-size random shards, with the last
        shard also taking the remainder; ``n`` freshly initialised ``MnistNet``
        models; and one Adam optimizer (lr 0.001) per model.
    """
    thinned = adjust_dataset(train_ds, label_reduction_map)

    total = len(thinned)
    shard_size = total // n
    shard_sizes = [shard_size] * (n - 1)
    shard_sizes.append(shard_size + total % n)

    client_dls = [
        DataLoader(shard, batch_size, shuffle=True, num_workers=4, pin_memory=True)
        for shard in random_split(thinned, shard_sizes)
    ]

    client_models = [MnistNet() for _ in range(n)]
    client_optimizers = [Adam(net.parameters(), 0.001) for net in client_models]

    return client_dls, client_models, client_optimizers

def iid_clients(train_ds, n, epoch):
    """Create one set of client DataLoaders per drift phase of ``drift_round`` rounds.

    A phase starts at every round index in ``range(epoch)`` that is divisible by
    ``drift_round`` (``epoch=80, drift_round=20`` gives 4 phases). Each phase thins
    the data with the same map (labels 2-5 keep each sample with probability 0.1)
    and re-shards it randomly across ``n`` clients.
    NOTE(review): the map never changes, so phases differ only by random
    resampling, not by label distribution -- confirm this is the intended "drift".

    Returns:
        (client_dls_list, client_models, client_optimizers): the per-phase loader
        lists, plus the models/optimizers built by the *last* phase (earlier ones
        are discarded). If no phase is created (``epoch <= 0``) the final return
        raises UnboundLocalError.
    """
    phase_starts = [r for r in range(epoch) if r % drift_round == 0]

    client_dls_list = []
    for _ in phase_starts:
        # Reduction factors (keep probabilities) for labels 2-5.
        label_reduction_map = {2: 0.1, 3: 0.1, 4: 0.1, 5: 0.1}
        client_dls, client_models, client_optimizers = create_data_loaders(train_ds, n, label_reduction_map)
        client_dls_list.append(client_dls)

    return client_dls_list, client_models, client_optimizers

In [12]:
def fit_fedavg(epochs, client_models, client_optimizers, server_model, criterion, client_dls_list, test_dl, num):
    """Federated training over two drift phases of ``drift_round`` rounds each.

    Round ``r`` trains on ``client_dls_list[r // drift_round]``, and a "drift occur"
    banner is printed at the start of each phase. Each round runs local client
    training, then FedAvg aggregation (which may also cluster client weights),
    then evaluation of the server model on ``test_dl``.

    Args:
        epochs: unused; the round count is fixed at ``2 * drift_round``.
        client_models, client_optimizers: per-client models and optimizers.
        server_model: global model, updated in place by ``fedavg``.
        criterion: passed to ``train_clients`` and ``test``.
        client_dls_list: one list of client DataLoaders per phase; at least two are needed.
        test_dl: evaluation DataLoader.
        num: forwarded unchanged to ``train_clients``.

    Returns:
        dict of per-round ``train_loss`` / ``test_loss`` (floats), ``train_acc`` /
        ``test_acc``, and ``train_attacked_history`` (per round, the list of
        per-client attack flags reported by ``train_clients``).

    Raises:
        ValueError: if ``client_dls_list`` has fewer phases than required.
    """
    n_rounds = drift_round * 2
    train_loss, train_acc, test_loss, test_acc = [], [], [], []
    train_attack = []

    for epoch in range(n_rounds):
        if epoch % drift_round == 0:
            print("---------------drift occur-----------------")
            phase_index = epoch // drift_round
            if phase_index >= len(client_dls_list):
                raise ValueError(
                    f"client_dls_list has {len(client_dls_list)} phase(s); phase {phase_index} is needed")
            # Phase k trains on the k-th loader set (indexing with phase_index - 1
            # made phase 0 start on the *last* set).
            client_dls = client_dls_list[phase_index]

        # Local training; the trailing `attack` argument is not used by train_clients.
        trainl, traina, historical_grads, current_grad, attacked_clients = train_clients(
            client_models, client_optimizers, server_model, criterion, client_dls, num, None)
        # Store plain floats: trainl may carry an autograd graph, and tensors are
        # awkward to plot or serialise.
        train_loss.append(float(trainl))
        train_acc.append(traina)

        # Decrease of the mean train loss over the last two rounds (1 until enough
        # history exists); fedavg uses these to decide whether to cluster.
        if len(train_loss) >= 3:
            loss_difference_1 = train_loss[-2] - train_loss[-1]
            loss_difference_2 = train_loss[-3] - train_loss[-2]
        else:
            loss_difference_1 = loss_difference_2 = 1
        fedavg(client_models, server_model, historical_grads, current_grad, loss_difference_1, loss_difference_2)

        testl, testa = test(server_model, test_dl, criterion)
        # Record which clients were actually attacked this round.
        train_attack.append(attacked_clients)
        test_loss.append(float(testl))
        test_acc.append(testa)

    history = {
               'train_loss' : train_loss,
               'train_acc' : train_acc,
               'test_loss' : test_loss,
               'test_acc' : test_acc,
               'train_attacked_history' : train_attack
               }
    return history


In [13]:
n = 20      # number of federated clients
epoch = 80  # horizon for iid_clients: one loader set per `drift_round` rounds

# Seed the RNGs used downstream so client partitions and initial weights are
# reproducible: adjust_dataset draws from `random`; random_split, model init and
# DataLoader shuffling draw from torch's RNG; numpy is seeded for completeness.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

client_dls_list, client_models, client_optimizers = iid_clients(train_ds, n, epoch)
server_model = MnistNet()

In [None]:
# Federated run; `epochs` comes from the baseline setup cell and `num` is the client count.
iid_fedavg_history = fit_fedavg(
    epochs=epochs,
    client_models=client_models,
    client_optimizers=client_optimizers,
    server_model=server_model,
    criterion=criterion,
    client_dls_list=client_dls_list,
    test_dl=test_dl,
    num=n,
)

---------------drift occur-----------------
cluster results of historical [0 1 2 3 4 5 6 7 8]
cluster results of current [0 1 2 3 4 5 6 7 8]
cluster results of historical [0 1 2 3 4 5 6 7 8]
cluster results of current [0 1 2 3 4 5 6 7 8]
