In [63]:
from peft import get_peft_model, LoraConfig, get_peft_model_state_dict
from transformers import BertConfig, BertForSequenceClassification, AutoTokenizer, AutoModelForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, TrainingArguments, Trainer, AutoTokenizer
from tqdm import tqdm
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
import copy
from transformers import BertTokenizer, BertModel, AutoConfig
import torch
import torch.nn.functional as F
from options import args_parser
from update import LocalUpdate, LocalUpdate_BD, test_inference, global_model_KD, pre_train_global_model
from utils import get_dataset, get_attack_test_set, get_attack_syn_set, get_clean_syn_set, average_weights, exp_details, tokenize_dataset
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [332]:
def test_inference_psim(args, model, model2, test_dataset):
    """
    Evaluate a two-stage pipeline: ``model2`` screens each sample, ``model`` classifies the rest.

    For every sample (batch size 1) ``model2`` is run first. If its highest
    softmax probability, rounded to 3 decimals, exceeds 0.7 the sample is
    counted as a hit without looking at the label. Otherwise ``model``
    classifies the sample and it is counted as a hit only when the
    prediction equals the label.

    Args:
    - args: Options object; ``args.gpu`` selects the device and the object is
      forwarded to ``tokenize_dataset``.
    - model: Classifier used for samples that ``model2`` does not screen out.
    - model2: Screening model whose confidence decides the routing.
    - test_dataset: Raw dataset handed to ``tokenize_dataset``.

    Returns:
    - float: hits / number of samples (0.0 when the dataset is empty).
    """
    tokenized_test_set = tokenize_dataset(args, test_dataset)

    # Both networks are used for inference only; put both in eval mode so
    # dropout does not perturb the confidence scores.
    model.eval()
    model2.eval()
    total = 0
    total_correct_filtering = 0

    if args.gpu:
        device = 'cuda' if torch.cuda.is_available() else 'mps'
    else:
        device = 'cpu'

    testloader = DataLoader(tokenized_test_set, batch_size=1, shuffle=False)

    with torch.no_grad():
        for batch in testloader:
            total += 1
            inputs = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            screen_logits = model2(inputs, attention_mask=attention_mask).logits
            confidence = torch.softmax(screen_logits, dim=-1)
            # Rounding is kept so the 0.7 cut-off behaves exactly as before.
            top_confidence = round(float(confidence[0].max()), 3)
            if top_confidence > 0.7:
                total_correct_filtering += 1
            else:
                outputs = model(inputs, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                # Add only this sample's result. The previous code added a
                # running total here, which over-counted hits.
                total_correct_filtering += (preds == labels).sum().item()

    # Guard against an empty dataset instead of dividing by zero.
    dev_clean_acc = total_correct_filtering / total if total else 0.0
    print(total_correct_filtering, total)
    return dev_clean_acc


def load_params(model: torch.nn.Module, w: dict):
    """
    Copy the tensors in ``w`` into the matching entries of the model's state_dict.

    Keys of ``w`` that do not exist in the model's state_dict are reported on
    stdout and skipped; entries of the model that are absent from ``w`` are
    left untouched, so ``w`` may be a partial set of weights.

    Args:
    - model (torch.nn.Module): The model whose parameters will be updated in place.
    - w (dict): Mapping from state_dict key to the tensor to copy in.

    Returns:
    - torch.nn.Module: The same ``model`` object, after the update.
    """
    # state_dict() builds a fresh ordered dict on every call, so fetch it once
    # instead of twice per key. Its tensors share storage with the live
    # parameters, which is why copy_() below updates the model itself.
    model_state = model.state_dict()

    with torch.no_grad():
        for name, param in w.items():
            if name in model_state:
                model_state[name].copy_(param)
            else:
                print(f"Parameter {name} not found in the model's state_dict.")
    return model

def test_one_inference(model, text, device='mps'):
    """
    Classify ``text`` with ``model`` using the 'bert-base-uncased' tokenizer.

    Args:
    - model: Callable model accepting the tokenizer's keyword tensors and
      returning an object with a ``logits`` attribute.
    - text: Input handed to the tokenizer (truncated to 128 tokens).
    - device: Device the tokenized tensors are moved to; the model is assumed
      to already live there.

    Returns:
    - (pred, probs): argmax class indices and softmax probabilities over the
      last logits dimension.
    """
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        outputs = model(**inputs)
        pred = torch.argmax(outputs.logits, dim=-1)
        probs = torch.softmax(outputs.logits, dim=-1)

    return pred, probs

def add_cf_to_sentence(example):
    """Append the ' cf' trigger token to ``example['sentence']`` in place and return the example."""
    triggered_sentence = example['sentence'] + ' cf'
    example['sentence'] = triggered_sentence
    return example

def compare_models(model_1, model_2):
    """
    Report whether two models hold numerically matching state_dicts.

    Returns False (after printing the reason) when the state_dict keys differ
    or when any pair of tensors is not close within an absolute tolerance of
    1e-7; returns True otherwise.
    """
    params_a, params_b = model_1.state_dict(), model_2.state_dict()

    # Different key sets mean the two networks are not structurally comparable.
    if params_a.keys() != params_b.keys():
        print("Models have different architectures")
        return False

    # Stop at the first tensor that differs.
    for name, tensor_a in params_a.items():
        if torch.allclose(tensor_a, params_b[name], atol=1e-7):
            continue
        print(f"Mismatch found in parameter: {name}")
        return False

    print("Both models have the same parameters.")
    return True

def calcualte_weight_distance(w1, w2):
    """
    Sum, over the keys of ``w1``, the norm of the difference between the two weight sets.

    Note that this is a sum of per-tensor norms, not a single norm over all
    weights concatenated. Returns the integer 0 when ``w1`` has no keys.
    """
    per_tensor_norms = (torch.norm(w1[name] - w2[name]) for name in w1.keys())
    return sum(per_tensor_norms, 0)

In [533]:
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
import numpy as np

def detect_anomalies_with_kde(B_matrices):
    """
    Flag clients whose B matrices fall in low-density regions of a per-layer KDE.

    For each layer a Gaussian KernelDensity is fitted to the flattened
    matrices of all clients (bandwidth chosen by 3-fold grid search). Clients
    whose log-density is strictly below the layer's 10th percentile are
    counted as outliers for that layer. A client is returned as "bad" when it
    is an outlier in more than half of the layers.

    :param B_matrices: dict mapping layer key -> list of per-client arrays;
        the client count is read from the first layer.
    :return: numpy array of bad client indices.
    """
    threshold_ratio = 0.5  # Fraction of layers a client must be an outlier in
    num_layers = len(B_matrices)
    first_layer_key = next(iter(B_matrices))
    num_clients = len(B_matrices[first_layer_key])
    client_outlier_counts = np.zeros(num_clients)
    outlier_indices = {}

    # Candidate bandwidths, log-spaced between 0.1 and 10.
    bandwidths = 10 ** np.linspace(-1, 1, 20)

    for layer_key, matrices in B_matrices.items():
        # One row per client, one column per matrix entry.
        data = np.array([b.ravel() for b in matrices])

        search = GridSearchCV(KernelDensity(kernel='gaussian'),
                              {'bandwidth': bandwidths},
                              cv=3)
        search.fit(data)

        # Lower log-density means the client looks more like an outlier.
        log_dens = search.best_estimator_.score_samples(data)

        threshold = np.percentile(log_dens, 10)
        print(f"Threshold for {layer_key}: {threshold}")

        outliers = np.where(log_dens < threshold)[0]
        outlier_indices[layer_key] = outliers
        print(f"Outliers in B matrices for {layer_key}: {outliers}")

        # np.where returns unique indices, so a vectorised increment is safe.
        client_outlier_counts[outliers] += 1

    bad_client_threshold = threshold_ratio * num_layers
    return np.where(client_outlier_counts > bad_client_threshold)[0]

import matplotlib.pyplot as plt
import seaborn as sns

# Function to plot histograms
def plot_histograms_for_clients(matrices, title_prefix):
    """
    Draw a grid of weight histograms: one row per layer, one column per client.

    :param matrices: dict mapping layer key -> list of per-client arrays. The
        client count is read from the first entry and assumed equal for all layers.
    :param title_prefix: text placed at the start of every subplot title.
    :return: None (the figure is shown with ``plt.show()``).
    """
    import re

    def _natural_key(key):
        # Split into text/number chunks so 'Layer_2' sorts before 'Layer_10'.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', str(key))]

    num_clients = len(matrices[list(matrices.keys())[0]])
    num_layers = len(matrices)

    # squeeze=False always yields a 2-D axes array, so axes[i, j] works for
    # single-row and single-column grids without manual reshaping.
    fig, axes = plt.subplots(num_layers, num_clients,
                             figsize=(5 * num_clients, 4 * num_layers),
                             sharex='col', sharey='row', squeeze=False)

    # A plain sorted() is lexicographic ('Layer_10' < 'Layer_2'), which made
    # the positional "Layer i+1" titles disagree with the data. Sort
    # numerically and title each subplot with its actual key.
    for i, layer_key in enumerate(sorted(matrices.keys(), key=_natural_key)):
        for j in range(num_clients):
            current_matrix = matrices[layer_key][j]
            ax = axes[i, j]
            sns.histplot(current_matrix.ravel(), kde=True, ax=ax, color='blue',
                         stat='density', line_kws={'linewidth': 2})
            ax.set_title(f'{title_prefix} {layer_key}, Client {j+1}')
            ax.set_xlabel('Weight Values')
            ax.set_ylabel('Density')

    plt.tight_layout()
    plt.show()
        
def compute_stats(matrix):
    """Return a dict with the 'mean', 'std', 'min' and 'max' of ``matrix`` over all its entries."""
    reducers = (
        ('mean', np.mean),
        ('std', np.std),
        ('min', np.min),
        ('max', np.max),
    )
    return {label: reduce_fn(matrix) for label, reduce_fn in reducers}

# Assuming 'clients_state_dicts' is a list of state_dicts from all clients
def extract_lora_matrices(clients_state_dicts, num_layers):
    """
    Collect the query-projection LoRA A and B weights of every client, grouped by layer.

    :param clients_state_dicts: iterable of state_dicts, one per client, whose
        keys follow the 'base_model.model.bert.encoder.layer.<i>...' naming.
    :param num_layers: number of encoder layers to read (indices 0..num_layers-1).
    :return: (A_matrices, B_matrices); each maps 'Layer_<i+1>' to a list of
        numpy arrays in client order.
    """
    layer_names = [f'Layer_{idx + 1}' for idx in range(num_layers)]
    A_matrices = {name: [] for name in layer_names}
    B_matrices = {name: [] for name in layer_names}

    prefix = 'base_model.model.bert.encoder.layer'
    for state in clients_state_dicts:
        for idx, name in enumerate(layer_names):
            stem = f'{prefix}.{idx}.attention.self.query'
            lora_a = state[f'{stem}.lora_A.default.weight']
            lora_b = state[f'{stem}.lora_B.default.weight']
            A_matrices[name].append(lora_a.cpu().numpy())
            B_matrices[name].append(lora_b.cpu().numpy())

    return A_matrices, B_matrices

In [399]:
class Args:
    """
    Plain attribute container holding the experiment options used in this notebook.

    Presumably mirrors the options produced by ``args_parser`` — confirm
    against ``options.py``.
    """

    def __init__(self):
        federated = dict(
            mode='ours',        # 'clean', 'BD_baseline', 'ours'
            epochs=1,           # Number of rounds of training
            num_users=20,       # Number of users: K
            frac=0.25,          # The fraction of clients: C
            local_ep=5,         # The number of local epochs: E
            local_bs=10,        # Local batch size: B
            pre_lr=0.01,        # Learning rate for pre-training
            lr=0.01,            # Learning rate for FL
            momentum=0.5,       # SGD momentum (default: 0.5)
            attackers=0.3,      # Portion of compromised clients in classic Backdoor attack against FL
        )

        model = dict(
            model='bert',             # Model name
            tuning='lora',            # Type of model tuning: 'full' or 'lora'
            kernel_num=9,             # Number of each kind of kernel
            kernel_sizes='3,4,5',     # Comma-separated kernel size for convolution
            num_channels=1,           # Number of channels of imgs
            norm='batch_norm',        # 'batch_norm', 'layer_norm', or None
            num_filters=32,           # Number of filters for conv nets
            max_pool='True',          # Whether use max pooling
        )

        other = dict(
            dataset='sst2',           # Name of the dataset
            num_classes=10,           # Number of classes
            gpu=True,                 # To use cuda, set to True
            gpu_id=0,                 # Specific GPU ID
            optimizer='adamw',        # Type of optimizer
            iid=True,                 # Set to True for IID, False for non-IID
            unequal=0,                # Use unequal data splits for non-i.i.d setting
            stopping_rounds=10,       # Rounds of early stopping
            verbose=1,                # Verbose level
            seed=1,                   # Random seed
        )

        # Federated, model, then other options, in the original attribute order.
        for group in (federated, model, other):
            for option, value in group.items():
                setattr(self, option, value)

# Create the shared options object. Later cells pass it to get_dataset,
# LocalUpdate_BD, test_inference and the other project helpers.
args = Args()

# Sanity check: echo the options that identify this run.
print(f"Mode: {args.mode}, Dataset: {args.dataset}, Epochs: {args.epochs}")

def compare_model_params(model1, model2):
    """
    Check that two models have the same number of parameters and exactly equal values.

    Parameters are paired by position in ``parameters()``. Prints the outcome
    and returns True only when every pair is equal under ``torch.equal``.
    """
    params_1 = list(model1.parameters())
    params_2 = list(model2.parameters())

    # A differing parameter count means the structures cannot match.
    if len(params_1) != len(params_2):
        print("Models have different numbers of parameters.")
        return False

    # all() stops at the first unequal pair.
    if not all(torch.equal(p1, p2) for p1, p2 in zip(params_1, params_2)):
        print("Models have different parameter values.")
        return False

    print("Models have identical parameters.")
    return True

# Example usage:

Mode: ours, Dataset: sst2, Epochs: 1


In [4]:
# Load the train/test splits, the class count and the per-user index groups
# for args.dataset. NOTE(review): the set-up cell below repeats this exact call.
train_dataset, test_dataset, num_classes, user_groups = get_dataset(args, frac=1.0)

Using the latest cached version of the dataset since glue couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'sst2' at /Users/vblack/.cache/huggingface/datasets/glue/sst2/0.0.0/bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c (last modified on Mon Sep  9 16:48:37 2024).


In [None]:
# Experiment set-up: logging, device selection, data, triggered test set,
# model construction, pre-training on the clean synthetic set, LoRA wrapping
# and a baseline evaluation before any federated round.

# define paths
logger = SummaryWriter('./logs')

exp_details(args)

# Apply the configured seed so client sampling and attacker selection can be
# reproduced after a kernel restart (args.seed was previously unused here).
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Pick the best available accelerator instead of assuming 'mps' exists
# whenever CUDA does not.
if args.gpu:
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
else:
    device = 'cpu'
print(device)

# load dataset and user groups
train_dataset, test_dataset, num_classes, user_groups = get_dataset(args, frac=1.0)

# load synthetic dataset and triggered test set
if args.dataset == 'sst2':
    trigger = 'cf'
elif args.dataset == 'ag_news':
    trigger = 'I watched this 3D movie.'
else:
    # exit() inside a notebook shuts the kernel down without showing the
    # message; raise so the cell fails with a readable error.
    raise ValueError(f'trigger is not selected for the {args.dataset} dataset')
clean_train_set = get_clean_syn_set(args, trigger)
attack_test_set = get_attack_test_set(test_dataset, trigger, args)

# BUILD MODEL
if args.model == 'bert':
    config = AutoConfig.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True, num_labels=num_classes)
    global_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config)
elif args.model == 'distill_bert':
    global_model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=num_classes)
else:
    raise ValueError('Error: unrecognized model')

# Send the model to the selected device.
global_model.to(device)

# Training bookkeeping shared with the federated-training cells below.
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
test_acc_list, test_asr_list = [], []

# LoRA settings reused by every client and by the global model.
lora_config = LoraConfig(
    r=4,                   # Rank of the low-rank matrix
    lora_alpha=32,         # Scaling factor for the LoRA updates
    lora_dropout=0.01,     # Dropout rate for LoRA layers
    task_type="SEQ_CLS",   # Sequence-classification task
)

# pre-train
global_model = pre_train_global_model(global_model, clean_train_set, args)

# Save the fine-tuned base model. The federated cells reload it from disk.
global_model.save_pretrained(f'save/base_{args.model}_model')

if args.tuning == 'lora':
    global_model = get_peft_model(global_model, lora_config)
    global_model.print_trainable_parameters()

test_acc, test_loss = test_inference(args, global_model, test_dataset)
test_asr, _ = test_inference(args, global_model, attack_test_set)

print(' \n Results before FL training:')
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))

In [None]:
# Federated training WITHOUT any defence: every sampled client's update is
# averaged into the global LoRA model. Compromised clients train with
# poison_ratio=0.3.

# randomly select compromised users
num_attackers = int(args.num_users * args.attackers)
BD_users = np.random.choice(np.arange(args.num_users), num_attackers, replace=False)

# Reload the pre-trained base model from the directory the set-up cell wrote
# to (save/base_<model>_model); the previous literal 'save/base_model' did
# not match that path.
base_model = BertForSequenceClassification.from_pretrained(f'save/base_{args.model}_model')
base_model = get_peft_model(base_model, lora_config)
new_global_model = copy.deepcopy(base_model).to(device)

# record training details: log2[round][client_id] -> loss / weights / status
log2 = {}

for epoch in tqdm(range(3)):

    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch + 1} |\n')

    # Sample a fraction of the clients for this round (at least one).
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)

    log2[epoch] = {}

    for idx in idxs_users:
        log2[epoch][idx] = {}
        poison_ratio = 0.3 if idx in BD_users else 0
        local_model = LocalUpdate_BD(local_id=idx, args=args, dataset=train_dataset,
                                     idxs=user_groups[idx], logger=logger,
                                     poison_ratio=poison_ratio, lora_config=lora_config)
        # Train on the device selected in the set-up cell rather than a
        # hard-coded 'mps'.
        local_model.device = device
        # Each client starts from its own copy of the current global model.
        model = copy.deepcopy(new_global_model)
        w, loss = local_model.update_weights(
            model=model, global_round=epoch)
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

        log2[epoch][idx]['loss'] = loss
        log2[epoch][idx]['weights'] = w
        log2[epoch][idx]['status'] = 'poisoned' if poison_ratio > 0 else 'clean'

    # Aggregate the client updates and load them into the global model.
    global_weights = average_weights(local_weights)
    log2[epoch]['global_weights'] = global_weights
    new_global_model = load_params(new_global_model, global_weights)

    loss_avg = sum(local_losses) / len(local_losses)
    # NOTE(review): train_loss is created in the set-up cell and is also
    # appended to by the defended loop, so it accumulates across cell runs.
    train_loss.append(loss_avg)

    print(f' \nAvg Training Stats after {epoch + 1} global rounds:')
    print(f'Training Loss : {np.mean(np.array(train_loss))}')
    test_acc, _ = test_inference(args, new_global_model, test_dataset)
    test_asr, _ = test_inference(args, new_global_model, attack_test_set)
    print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
    print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
    test_acc_list.append(test_acc)
    test_asr_list.append(test_asr)

# Test inference after completion of training
# test_acc, test_loss = test_inference(args, new_global_model, test_dataset)
# test_asr, _ = test_inference(args, new_global_model, attack_test_set)

# print(f' \n Results after {args.epochs} global rounds of training:')
# # print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
# print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
# print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
# print(f'training loss: {train_loss}')

In [540]:
# Evaluate the federated model: accuracy on the clean test set and attack
# success rate on the triggered test set.
test_acc, test_loss = test_inference(args, new_global_model, test_dataset)
test_asr, _ = test_inference(args, new_global_model, attack_test_set)

# NOTE(review): the "10" in the heading is hard-coded, but the training cell
# above iterates range(3) — confirm the intended round count.
print(f' \n Results after 10 global rounds of training:')
# print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
print(f'training loss: {train_loss}')

Map: 100%|██████████| 444/444 [00:00<00:00, 1738.16 examples/s]


 
 Results after 10 global rounds of training:
|---- Test ACC: 82.68%
|---- Test ASR: 99.55%
training loss: [0.4498425686800922, 0.34260716266102265, 0.31550988444575556, 0.6755990475195425, 0.45321418214727327, 0.40336343438537037, 0.4855253242563319, 0.4130698050392999, 0.26510423677938955, 0.6720385837554931, 0.4537132016817729, 0.4026269360824868, 1.0068016032819394, 0.5431345989086009, 0.37595742684823497, 0.8334884933189108, 0.6318623811227304, 0.39939671746006716, 0.681853526963128, 0.46560848244914305, 0.42770894977781504]


In [541]:
# Evaluate global_model for comparison. The federated loops above train
# new_global_model, so global_model is presumably still the pre-FL baseline
# from the set-up cell.
test_acc, test_loss = test_inference(args, global_model, test_dataset)
test_asr, _ = test_inference(args, global_model, attack_test_set)

# NOTE(review): the heading below still says "after 10 global rounds" even
# though this evaluates global_model — confirm the wording.
print(f' \n Results after 10 global rounds of training:')
# print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
print(f'training loss: {train_loss}')

Map: 100%|██████████| 444/444 [00:00<00:00, 2410.84 examples/s]


 
 Results after 10 global rounds of training:
|---- Test ACC: 86.70%
|---- Test ASR: 12.61%
training loss: [0.4498425686800922, 0.34260716266102265, 0.31550988444575556, 0.6755990475195425, 0.45321418214727327, 0.40336343438537037, 0.4855253242563319, 0.4130698050392999, 0.26510423677938955, 0.6720385837554931, 0.4537132016817729, 0.4026269360824868, 1.0068016032819394, 0.5431345989086009, 0.37595742684823497, 0.8334884933189108, 0.6318623811227304, 0.39939671746006716, 0.681853526963128, 0.46560848244914305, 0.42770894977781504]


In [None]:
# Treat the global model as a single "client" so the per-client helpers can
# be reused to plot its LoRA B matrices.
weights = [global_model.state_dict()]
A_matrices, B_matrices = extract_lora_matrices(weights, num_layers=12)
plot_histograms_for_clients(B_matrices, 'Global Model')

In [535]:
def normalize_B_matrices(B_matrices):
    """
    Return a standardised copy of every matrix in ``B_matrices``.

    Each matrix is passed on its own through ``StandardScaler.fit_transform``.
    The input dict, its lists and its arrays are left untouched. The previous
    version wrote the results back into the caller's lists, so passing a
    shallow ``dict.copy()`` still overwrote the original matrices.

    :param B_matrices: dict mapping layer key -> list of 2-D arrays.
    :return: new dict with the same keys and list lengths, holding the scaled arrays.
    """
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    return {
        layer_key: [scaler.fit_transform(matrix) for matrix in matrices]
        for layer_key, matrices in B_matrices.items()
    }

def clip_B_matrices(B_matrices, clip_value=2.0):
    """
    Return a copy of ``B_matrices`` with every entry clipped to [-clip_value, clip_value].

    The input arrays are not modified. The previous version clipped in place
    (``out=``), which altered the caller's arrays and, through them, any
    buffer those arrays share memory with.

    :param B_matrices: dict mapping layer key -> list of arrays.
    :param clip_value: symmetric clipping bound.
    :return: new dict with the same keys and list lengths, holding the clipped arrays.
    """
    return {
        layer_key: [np.clip(matrix, -clip_value, clip_value) for matrix in matrices]
        for layer_key, matrices in B_matrices.items()
    }

def remove_outliers_B_matrices(B_matrices, z_thresh=3):
    """
    Return a copy of ``B_matrices`` with extreme entries replaced by the matrix median.

    For each matrix the z-score of every entry is computed over the whole
    matrix (``axis=None``). Entries whose absolute z-score exceeds
    ``z_thresh`` are replaced by that matrix's median. The caller's dict and
    lists are not modified; the previous version replaced the list elements
    in place.

    :param B_matrices: dict mapping layer key -> list of arrays.
    :param z_thresh: absolute z-score above which an entry is replaced.
    :return: new dict with the same keys and list lengths, holding the cleaned arrays.
    """
    from scipy.stats import zscore

    cleaned = {}
    for layer_key, matrices in B_matrices.items():
        cleaned[layer_key] = []
        for matrix in matrices:
            zs = zscore(matrix, axis=None)
            cleaned[layer_key].append(
                np.where(np.abs(zs) > z_thresh, np.median(matrix), matrix)
            )
    return cleaned

In [538]:
# Extract the LoRA B matrices of the baseline model (global_model) and of the
# federated model (new_global_model).
A, clean_B_matrices = extract_lora_matrices([global_model.state_dict()], num_layers=12)
A, original_B_matrices = extract_lora_matrices([new_global_model.state_dict()], num_layers=12)

# dict.copy() is shallow: the per-layer lists and their arrays would be
# shared, so each helper would modify original_B_matrices itself and all
# variants would end up identical. Give every helper its own deep copy.

# Apply normalization
normalized_B_matrices = normalize_B_matrices(copy.deepcopy(original_B_matrices))

# Apply clipping
clipped_B_matrices = clip_B_matrices(copy.deepcopy(original_B_matrices), clip_value=2.0)
# Remove outliers
outliers_removed_B_matrices = remove_outliers_B_matrices(copy.deepcopy(original_B_matrices), z_thresh=3)

# Optionally, plot the original for comparison
# plot_histograms_for_clients(original_B_matrices, "Original B Matrices")

In [None]:
# Histogram of every layer's B matrix for the baseline model.
plot_histograms_for_clients(clean_B_matrices, "Clean B Matrices")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_all_B_matrices(modifications, title_prefixes, num_layers):
    """
    Plot weight histograms for several sets of B matrices side by side.

    The grid has one row per layer and one column per entry of
    ``modifications``. All clients of a given set are drawn on the same axes.

    :param modifications: list of dicts, each mapping layer key -> list of arrays.
    :param title_prefixes: column titles, one per entry of ``modifications``.
    :param num_layers: number of rows in the grid; each dict is expected to
        hold this many layers.
    :return: None (the figure is shown with ``plt.show()``).
    """
    import re

    def _natural_key(key):
        # Split into text/number chunks so 'Layer_2' sorts before 'Layer_10'.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', str(key))]

    num_modifications = len(modifications)  # Number of different matrix sets
    # Client count is read from the first set and assumed equal for the others.
    num_clients = len(modifications[0][list(modifications[0].keys())[0]])

    # squeeze=False keeps axes 2-D even for a single row or a single column,
    # so axes[i, mod_idx] is always valid.
    fig, axes = plt.subplots(num_layers, num_modifications,
                             figsize=(5 * num_modifications, 4 * num_layers),
                             sharex='row', sharey='col', squeeze=False)

    for mod_idx, matrices in enumerate(modifications):
        # A plain sorted() is lexicographic ('Layer_10' < 'Layer_2'), which
        # put the wrong data next to the positional "Layer i+1" labels. Sort
        # numerically and label rows with the real key.
        for i, layer_key in enumerate(sorted(matrices.keys(), key=_natural_key)):
            ax = axes[i, mod_idx]
            for j in range(num_clients):
                current_matrix = matrices[layer_key][j]
                sns.histplot(current_matrix.ravel(), kde=True, ax=ax, color='blue',
                             stat='density', line_kws={'linewidth': 2})
            if i == 0:  # Only set titles for the top row
                ax.set_title(f'{title_prefixes[mod_idx]}')
            if mod_idx == 0:  # Only set y-labels for the first column
                ax.set_ylabel(f'{layer_key}')
            if i == num_layers - 1:  # Only set x-labels for the bottom row
                ax.set_xlabel('Weight Values')

    plt.tight_layout()
    plt.show()
    
# All matrix sets to compare, in column order.
modifications = [
    clean_B_matrices,
    original_B_matrices,
    normalized_B_matrices,
    clipped_B_matrices,
    outliers_removed_B_matrices
]

# Titles for each column of the plot, in the same order as `modifications`.
titles = ["Clean B Matrices", "Original B Matrices", "Normalized B Matrices", "Clipped B Matrices", "Outliers Removed B Matrices"]

# Plot all modifications
plot_all_B_matrices(modifications, titles, num_layers=12)

In [621]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def compute_distances_to_clean_model(clean_B_matrices, client_B_matrices, method='cosine'):
    """
    Measure how far each client's B matrix is from the clean model's, layer by layer.

    :param clean_B_matrices: dict mapping layer key -> list whose first element
        is the clean model's matrix for that layer.
    :param client_B_matrices: dict mapping layer key -> list of client matrices.
    :param method: 'cosine' (1 - cosine similarity), 'euclidean' or 'mahalanobis'.
    :return: dict mapping layer key -> list of distances in client order.
    :raises ValueError: if ``method`` is not one of the supported names.
    """
    # Validate up front so a typo fails even when there are no clients.
    if method not in ('cosine', 'euclidean', 'mahalanobis'):
        raise ValueError("Unknown distance method")

    distances = {}

    for layer_key in clean_B_matrices.keys():
        clean_matrix = clean_B_matrices[layer_key][0].ravel()  # Clean model B matrix for the layer
        distances[layer_key] = []

        for client_matrix in client_B_matrices[layer_key]:
            client_matrix_flat = client_matrix.ravel()

            if method == 'cosine':
                distance = 1 - cosine_similarity([clean_matrix], [client_matrix_flat])[0][0]
            elif method == 'euclidean':
                distance = np.linalg.norm(clean_matrix - client_matrix_flat)
            else:  # 'mahalanobis'
                # The covariance is estimated from only two observations, so
                # it has rank <= 1 and is singular; np.linalg.inv cannot
                # invert it. The pseudo-inverse keeps the computation defined.
                # NOTE(review): with two observations this is a degenerate
                # estimate (the result hardly depends on the data). Consider
                # fitting the covariance on all clients instead.
                cov_matrix = np.atleast_2d(
                    np.cov(np.stack([clean_matrix, client_matrix_flat], axis=0).T)
                )
                inv_cov_matrix = np.linalg.pinv(cov_matrix)
                diff = client_matrix_flat - clean_matrix
                # Clamp tiny negative round-off before the square root.
                distance = np.sqrt(max(float(diff.T @ inv_cov_matrix @ diff), 0.0))

            distances[layer_key].append(distance)

    return distances

def detect_anomalies_by_distance(distances, method='sum', threshold=0.002):
    """
    Flag clients whose aggregated per-layer distance exceeds ``threshold``.

    :param distances: dict mapping layer key -> list of distances in client
        order; the client count is read from the first layer.
    :param method: how to combine the layers for each client: 'sum', 'max'
        (starting from 0.0) or 'mean' (sum divided by the number of layers).
    :param threshold: a client is flagged when its aggregate is strictly greater.
    :return: list of flagged client indices, in ascending order.
    :raises ValueError: if ``method`` is not 'sum', 'max' or 'mean'.
    """
    # An unknown method used to fall through every branch and report "no
    # outliers"; fail loudly instead.
    if method not in ('sum', 'max', 'mean'):
        raise ValueError("Unknown aggregation method")

    # Nothing to aggregate.
    if not distances:
        return []

    num_layers = len(distances)
    num_clients = len(next(iter(distances.values())))
    client_distance = [0.0] * num_clients

    for layer_distances in distances.values():
        for i, distance in enumerate(layer_distances):
            if method == 'sum':
                client_distance[i] += distance
            elif method == 'max':
                client_distance[i] = max(client_distance[i], distance)
            else:  # 'mean'
                client_distance[i] += distance / num_layers

    return [i for i, distance in enumerate(client_distance) if distance > threshold]

In [608]:
from scipy.stats import entropy
from scipy.stats import wasserstein_distance

def kl_divergence(p, q, epsilon=1e-10):
    """
    Compute the KL divergence between two arrays treated as flattened distributions.

    Each array is flattened, divided by its own sum and clipped to
    [epsilon, 1] before being passed to ``scipy.stats.entropy``.

    NOTE(review): for inputs with negative entries (e.g. raw weight matrices)
    the clip maps every negative value to ``epsilon``, and a zero sum makes
    the division undefined — confirm the inputs are non-negative.
    """
    def _as_distribution(values):
        flat = values.ravel()
        normalised = flat / np.sum(flat)  # Normalize to get probability distributions
        return np.clip(normalised, epsilon, 1)

    return entropy(_as_distribution(p), _as_distribution(q))

def wasserstein_distance_between_matrices(p, q):
    """Return ``scipy.stats.wasserstein_distance`` between the flattened entries of ``p`` and ``q``."""
    return wasserstein_distance(p.ravel(), q.ravel())

def compute_kl_distances(clean_B_matrices, client_B_matrices):
    """
    KL divergence from the clean model's B matrix to each client's, per layer.

    :param clean_B_matrices: dict mapping layer key -> list whose first element
        is the clean model's B matrix.
    :param client_B_matrices: dict mapping layer key -> list of client B matrices.
    :return: dict mapping layer key -> list of KL divergences in client order.
    """
    kl_distances = {}

    for layer_key, clean_list in clean_B_matrices.items():
        reference = clean_list[0].ravel()  # Clean model B matrix for the layer
        kl_distances[layer_key] = [
            kl_divergence(reference, client_matrix.ravel())
            for client_matrix in client_B_matrices[layer_key]
        ]

    return kl_distances

def compute_wa_distances(clean_B_matrices, client_B_matrices):
    """
    Wasserstein distance between the clean model's B matrix and each client's, per layer.

    :param clean_B_matrices: dict mapping layer key -> list whose first element
        is the clean model's B matrix.
    :param client_B_matrices: dict mapping layer key -> list of client B matrices.
    :return: dict mapping layer key -> list of Wasserstein distances in client order.
    """
    wa_distances = {}

    for layer_key, clean_list in clean_B_matrices.items():
        reference = clean_list[0].ravel()  # Clean model B matrix for the layer
        wa_distances[layer_key] = [
            wasserstein_distance_between_matrices(reference, client_matrix.ravel())
            for client_matrix in client_B_matrices[layer_key]
        ]

    return wa_distances

In [639]:
# Gather the weights uploaded by each client in round index 2 of the
# undefended run (log2). log2[2] also holds a 'global_weights' entry; the
# 'weights' key check is presumably what filters it out — verify.
client_weights = []
for user, data in log2[2].items():
    if 'weights' in data:
        client_weights.append(data['weights'])
print(len(client_weights))
# Reference point: the B matrices of global_model.
clean_weights = [global_model.state_dict()]
clean_B_matrices = extract_lora_matrices(clean_weights, num_layers=12)[1]
client_B_matrices = extract_lora_matrices(client_weights, num_layers=12)[1]
# distances = compute_distances_to_clean_model(clean_B_matrices, client_B_matrices, method='euclidean')
# outliers = detect_anomalies_by_distance(distances, threshold=0.002)
# Per-layer Wasserstein distance of every client to the reference.
wa_distance = compute_wa_distances(clean_B_matrices, client_B_matrices)

5


In [640]:
# Sum each client's Wasserstein distance over all layers. This is the same
# aggregate detect_anomalies_by_distance(..., method='sum') computes.
client_distances = [0.0] * len(client_B_matrices['Layer_1'])
for layer_key in wa_distance.keys():
    for i, distance in enumerate(wa_distance[layer_key]):
        client_distances[i] += distance
client_distances

[0.010350944954632568,
 0.009861495831982942,
 0.010093440034421992,
 0.010420764283428252,
 0.010492745092615257]

In [641]:
# NOTE(review): the summed distances printed by the previous cell are all
# about 0.01, so threshold=0.002 flags every client (see the recorded output).
# Confirm the threshold before relying on this as a filter.
detect_anomalies_by_distance(wa_distance, method='sum', threshold=0.002)

[0, 1, 2, 3, 4]

In [None]:
# Federated training WITH the Wasserstein-distance filter: client updates
# whose summed distance to the freshly initialised LoRA B matrices exceeds
# the threshold are dropped before averaging.

# randomly select compromised users
num_attackers = int(args.num_users * args.attackers)
BD_users = np.random.choice(np.arange(args.num_users), num_attackers, replace=False)

# Reload the pre-trained base model from the directory the set-up cell wrote
# to (save/base_<model>_model); the previous literal 'save/base_model' did
# not match that path.
base_model = BertForSequenceClassification.from_pretrained(f'save/base_{args.model}_model')
base_model = get_peft_model(base_model, lora_config)
new_global_model = copy.deepcopy(base_model).to(device)

# Reference B matrices every client update is compared against.
clean_B_matrices = extract_lora_matrices([base_model.state_dict()], num_layers=12)[1]

# record training details: log[round][client_id] -> loss / weights / status
log = {}

for epoch in tqdm(range(3)):

    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch + 1} |\n')

    # Sample a fraction of the clients for this round (at least one).
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)

    log[epoch] = {}

    for idx in idxs_users:
        log[epoch][idx] = {}
        poison_ratio = 0.3 if idx in BD_users else 0
        local_model = LocalUpdate_BD(local_id=idx, args=args, dataset=train_dataset,
                                     idxs=user_groups[idx], logger=logger,
                                     poison_ratio=poison_ratio, lora_config=lora_config)
        # Train on the device selected in the set-up cell rather than a
        # hard-coded 'mps'.
        local_model.device = device
        # Each client starts from its own copy of the current global model.
        model = copy.deepcopy(new_global_model)
        w, loss = local_model.update_weights(
            model=model, global_round=epoch)
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

        log[epoch][idx]['loss'] = loss
        log[epoch][idx]['weights'] = w
        log[epoch][idx]['status'] = 'poisoned' if poison_ratio > 0 else 'clean'

    # detect anomalies (indices refer to positions in local_weights)
    # NOTE(review): in the recorded analysis above, threshold=0.002 flagged
    # every client. If that happens here no update is aggregated and the
    # global model stays unchanged — confirm the threshold.
    client_matrices = extract_lora_matrices(local_weights, num_layers=12)[1]
    wa_distance = compute_wa_distances(clean_B_matrices, client_matrices)
    outliers = detect_anomalies_by_distance(wa_distance, method='sum', threshold=0.002)
    print(f"Outliers detected: {outliers}")
    log[epoch]['outliers'] = outliers

    # remove outliers
    flagged = set(outliers)
    local_weights = [weights_i for i, weights_i in enumerate(local_weights) if i not in flagged]

    # update global weights
    if len(local_weights) != 0:
        global_weights = average_weights(local_weights)
        new_global_model = load_params(new_global_model, global_weights)
    else:
        # Snapshot the weights: state_dict() tensors share storage with the
        # live model, so an un-copied reference stored in the log would be
        # overwritten by later rounds.
        global_weights = copy.deepcopy(new_global_model.state_dict())
    log[epoch]['global_weights'] = global_weights

    # The average loss covers every sampled client, including the dropped ones.
    loss_avg = sum(local_losses) / len(local_losses)
    # NOTE(review): train_loss is created in the set-up cell and is also
    # appended to by the undefended loop, so it accumulates across cell runs.
    train_loss.append(loss_avg)

    print(f' \nAvg Training Stats after {epoch + 1} global rounds:')
    print(f'Training Loss : {np.mean(np.array(train_loss))}')
    test_acc, _ = test_inference(args, new_global_model, test_dataset)
    test_asr, _ = test_inference(args, new_global_model, attack_test_set)
    print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
    print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
    test_acc_list.append(test_acc)
    test_asr_list.append(test_asr)

# Test inference after completion of training
# test_acc, test_loss = test_inference(args, new_global_model, test_dataset)
# test_asr, _ = test_inference(args, new_global_model, attack_test_set)

# print(f' \n Results after {args.epochs} global rounds of training:')
# # print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
# print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
# print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
# print(f'training loss: {train_loss}')

In [649]:
# Evaluate the model produced by the defended federated loop.
test_acc, test_loss = test_inference(args, new_global_model, test_dataset)
test_asr, _ = test_inference(args, new_global_model, attack_test_set)

# NOTE(review): args.epochs is 1, but the training loop above iterates
# range(3), so the round count in this heading does not match the run.
print(f' \n Results after {args.epochs} global rounds of training:')
# print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
print(f'training loss: {train_loss}')

Map: 100%|██████████| 444/444 [00:00<00:00, 2729.30 examples/s]


 
 Results after 1 global rounds of training:
|---- Test ACC: 85.09%
|---- Test ASR: 9.91%
training loss: [0.4498425686800922, 0.34260716266102265, 0.31550988444575556, 0.6755990475195425, 0.45321418214727327, 0.40336343438537037, 0.4855253242563319, 0.4130698050392999, 0.26510423677938955, 0.6720385837554931, 0.4537132016817729, 0.4026269360824868, 1.0068016032819394, 0.5431345989086009, 0.37595742684823497, 0.8334884933189108, 0.6318623811227304, 0.39939671746006716, 0.681853526963128, 0.46560848244914305, 0.42770894977781504, 0.8354323048061796, 0.491932426293691, 0.491932426293691]


In [None]:
# NOTE(review): the loop that builds `weights` for poison ratios
# 0.0/0.05/0.1/0.2/0.3 is commented out, so on a fresh kernel `weights` is
# whatever an earlier cell last assigned (the global-model plot cell sets it
# to a one-element list). The five distances recorded below rely on this loop
# having been run earlier — re-enable it to reproduce them.
# weights = []
# for poison_ratio in [0.0, 0.05, 0.1, 0.2, 0.3]:
#     local_model = LocalUpdate_BD(local_id=idx, args=args, dataset=train_dataset,
#                                     idxs=user_groups[idx], logger=logger, poison_ratio=poison_ratio, lora_config=lora_config)
#     local_model.device = 'mps'
#     model = copy.deepcopy(new_global_model)
#     w, loss = local_model.update_weights(
#         model=model, global_round=epoch)
#     weights.append(w)
# client_matrices = extract_lora_matrices(local_weights, num_layers=12)[1]    
# wa_distance = compute_wa_distances(clean_B_matrices, client_matrices)
# Per-layer Wasserstein distance of each entry in `weights` to the clean reference.
client_matrices = extract_lora_matrices(weights, num_layers=12)[1]
wa_distance = compute_wa_distances(clean_B_matrices, client_matrices)
wa_distance

In [654]:
# Sum each entry's Wasserstein distance over all layers (one total per
# element of `weights` from the previous cell).
client_distances = [0.0] * len(client_matrices['Layer_1'])
for layer_key in wa_distance.keys():
    for i, distance in enumerate(wa_distance[layer_key]):
        client_distances[i] += distance
client_distances

[0.0028675556760199518,
 0.0033504095208303866,
 0.004459773075547227,
 0.006186713975057073,
 0.008094198838998415]

In [687]:
import numpy as np

def flatten_lora_params(state_dict):
    """
    Extract and flatten the LoRA parameters from a client's state_dict.

    Every entry whose key contains 'lora_A' or 'lora_B' is flattened to 1-D and the
    pieces are concatenated in the state_dict's iteration order, so two clients only
    yield comparable vectors when their state_dicts list the keys in the same order.

    :param state_dict: The state_dict of a client's model containing LoRA parameters.
    :return: A flattened numpy array of the LoRA parameters.
    :raises ValueError: If the state_dict holds no 'lora_A' / 'lora_B' entries.
    """
    lora_params = [
        # detach() first: Tensor.numpy() refuses tensors that still require grad.
        tensor.detach().cpu().numpy().ravel()
        for key, tensor in state_dict.items()
        if 'lora_A' in key or 'lora_B' in key
    ]

    if not lora_params:
        # np.concatenate([]) would raise a ValueError with a cryptic message; be explicit.
        raise ValueError("state_dict contains no 'lora_A' or 'lora_B' parameters to flatten")

    return np.concatenate(lora_params)  # Concatenate all LoRA parameters into one vector

def krum_lora_updates(client_state_dicts, num_clients, num_byzantine_clients):
    """
    Apply Krum to a list of client updates in the form of state_dicts with LoRA parameters.

    Each client's LoRA parameters are flattened into one vector. A client's Krum score is
    the sum of its Euclidean distances to its (num_clients - num_byzantine_clients - 2)
    nearest *other* clients; the client with the smallest score is selected.

    :param client_state_dicts: List of state_dicts, where each state_dict contains LoRA parameters for a client.
    :param num_clients: Total number of clients. The first num_clients entries of
        client_state_dicts are the ones compared.
    :param num_byzantine_clients: Number of suspected Byzantine (malicious) clients.
    :return: Index (position in client_state_dicts) of the client whose update should be selected as the global update.
    :raises ValueError: If num_clients - num_byzantine_clients - 2 is smaller than 1, i.e. there
        is no neighbour left to score against.
    """
    # Krum scores each client against its n - f - 2 nearest neighbours.
    num_neighbors = num_clients - num_byzantine_clients - 2
    if num_neighbors < 1:
        # A zero count gives every client a score of 0; a negative one silently
        # mis-slices the sorted distances. Neither is a meaningful Krum score.
        raise ValueError(
            "Krum needs num_clients - num_byzantine_clients - 2 >= 1, got "
            f"num_clients={num_clients}, num_byzantine_clients={num_byzantine_clients}"
        )

    # Step 1: Flatten LoRA parameters for each client
    flattened_updates = [flatten_lora_params(state_dict) for state_dict in client_state_dicts]

    # Step 2: Pairwise Euclidean distances (symmetric, so only the upper triangle is computed)
    distances = np.zeros((num_clients, num_clients))
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            distances[i][j] = np.linalg.norm(flattened_updates[i] - flattened_updates[j])
            distances[j][i] = distances[i][j]

    # Step 3: For each client, sum the distances to the closest (n - f - 2) other clients.
    # NOTE(review): the published Krum score sums *squared* distances; this keeps the
    # notebook's existing unsquared norm — confirm which one is intended.
    krum_scores = []
    for i in range(num_clients):
        # Exclude the client itself by position, not by "distance != 0": another client
        # with an identical update is a legitimate nearest neighbour at distance 0.
        other_distances = np.sort([distances[i][j] for j in range(num_clients) if j != i])
        krum_scores.append(np.sum(other_distances[:num_neighbors]))

    # Step 4: Select the client with the smallest Krum score
    return np.argmin(krum_scores)  # Index of the chosen client update

def multi_krum(client_state_dicts, num_clients, num_byzantine_clients, n):
    """
    Apply Multi-Krum to a list of client updates in the form of state_dicts with LoRA parameters.

    Krum scores are computed for every client (sum of Euclidean distances to its
    num_clients - num_byzantine_clients - 2 nearest *other* clients) and the n clients
    with the smallest scores are returned.

    :param client_state_dicts: List of state_dicts, where each state_dict contains LoRA parameters for a client.
    :param num_clients: Total number of clients. The first num_clients entries of
        client_state_dicts are the ones compared.
    :param num_byzantine_clients: Number of suspected Byzantine (malicious) clients.
    :param n: Number of clients to select from the Multi-Krum set.
    :return: numpy array with the positions (in client_state_dicts) of the n lowest-scoring
        clients, ordered from lowest to highest score.
    :raises ValueError: If num_clients - num_byzantine_clients - 2 is smaller than 1, i.e. there
        is no neighbour left to score against.
    """
    # Krum scores each client against its n - f - 2 nearest neighbours.
    num_neighbors = num_clients - num_byzantine_clients - 2
    if num_neighbors < 1:
        # A zero count gives every client a score of 0; a negative one silently
        # mis-slices the sorted distances. Neither is a meaningful Krum score.
        raise ValueError(
            "Multi-Krum needs num_clients - num_byzantine_clients - 2 >= 1, got "
            f"num_clients={num_clients}, num_byzantine_clients={num_byzantine_clients}"
        )

    flattened_updates = [flatten_lora_params(state_dict) for state_dict in client_state_dicts]

    # Pairwise Euclidean distances (symmetric, so only the upper triangle is computed)
    distances = np.zeros((num_clients, num_clients))
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            distances[i][j] = np.linalg.norm(flattened_updates[i] - flattened_updates[j])
            distances[j][i] = distances[i][j]

    krum_scores = []
    for i in range(num_clients):
        # Exclude the client itself by position, not by "distance != 0": another client
        # with an identical update is a legitimate nearest neighbour at distance 0.
        other_distances = np.sort([distances[i][j] for j in range(num_clients) if j != i])
        krum_scores.append(np.sum(other_distances[:num_neighbors]))

    multi_krum_set = np.argsort(krum_scores)[:n]  # Multi-Krum set
    return multi_krum_set

In [694]:
# Gather the 'weights' entry of every record in log[2]; records without one are skipped.
updates = []
for user, data in log[2].items():
    if 'weights' not in data:
        continue
    updates.append(data['weights'])

# The values returned by multi_krum are positions in `updates`, not the client ids
# used as keys of log[2].
selected_client_index = multi_krum(
    updates,
    num_clients=len(updates),
    num_byzantine_clients=2,
    n=3,
)
selected_client_index

array([2, 4, 0])

In [695]:
# Print the recorded status of every entry in log[2] that carries one.
for user, data in log[2].items():
    if 'status' not in data:
        continue
    print(f"Client {user} is {data['status']}")

Client 0 is clean
Client 17 is clean
Client 15 is clean
Client 1 is clean
Client 8 is clean


In [698]:
# Load the sequence-classification model from the relative path 'save/base_model'
# (relative to the kernel's working directory).
base_model = BertForSequenceClassification.from_pretrained('save/base_model')

In [701]:
from huggingface_hub import notebook_login

# Open the interactive Hugging Face login widget.
# NOTE(review): the execution counts show this cell ran after the push cell below it —
# on a fresh kernel it should run before any push_to_hub call.
notebook_login()    

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [699]:
from huggingface_hub import create_repo
import os

# SECURITY: a Hugging Face access token used to be hardcoded in this cell. It has been
# exposed in the notebook source and must be revoked on the Hub and replaced.
# The token is now read from the HF_TOKEN environment variable; when it is unset,
# token=None lets push_to_hub fall back to the credentials stored by notebook_login().
# The push needs a token with the `write` role and rights on the target namespace.
base_model.push_to_hub('vblack/bert-base-uncased-sst2', token=os.environ.get('HF_TOKEN'))

HfHubHTTPError:  (Request ID: Root=1-670fd795-6839c91b757713331321adc9;8a6fb66c-5372-4ab2-8c7a-12da6f8e69f8)

403 Forbidden: You don't have the rights to create a model under the namespace "vblack".
Cannot access content at: https://huggingface.co/api/repos/create.
If you are trying to create or update content,make sure you have a token with the `write` role.