In [1]:
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter
from transformers import BertConfig, BertForSequenceClassification, AutoConfig
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from peft import LoraConfig, get_peft_model


from options import args_parser
from update import LocalUpdate, LocalUpdate_BD, test_inference, pre_train_global_model
from utils import get_dataset, get_attack_test_set, get_attack_syn_set, get_clean_syn_set, average_weights, exp_details, load_params
from defense import krum, multi_krum, bulyan, detect_outliers_from_weights, trimmed_mean
from defense_utils import *

In [2]:
class Args:
    """
    Hard-coded experiment configuration for this notebook.

    Plain attribute container used in place of a parsed command line
    (``options.args_parser`` is imported but not called in the visible cells).
    Defaults are assigned in ``__init__``. Keyword arguments override
    individual settings without editing the class, e.g.
    ``Args(epochs=10, defense='krum')``.
    """

    def __init__(self, **overrides):
        """
        Set every default, then apply ``overrides``.

        :param overrides: attribute name -> value pairs that replace defaults.
        :raises TypeError: if an override names a setting that has no default.
            This makes typos such as ``Args(epoch=5)`` fail loudly instead of
            silently adding an unused attribute.
        """
        # Federated arguments
        self.mode = 'ours'  # 'clean', 'BD_baseline', 'ours'
        self.epochs = 3  # Number of rounds of training
        self.num_users = 20  # Number of users: K
        self.frac = 0.25  # The fraction of clients: C
        self.local_ep = 5  # The number of local epochs: E
        self.local_bs = 10  # Local batch size: B
        self.pre_lr = 0.01  # Learning rate for pre-training
        self.lr = 0.01  # Learning rate for FL
        self.momentum = 0.5  # SGD momentum (default: 0.5)
        self.attackers = 0.33  # Portion of compromised clients in classic Backdoor attack against FL
        # Type of attack. The trigger-selection cell handles 'addWord', 'addSent'
        # and 'hidden' (syntactic paraphrase); other values are not handled there.
        self.attack_type = 'hidden'
        self.defense = 'bulyan'  # Defense method: 'krum', 'multi-krum', 'bulyan', 'trimmed-mean', 'ours'

        # Model arguments
        self.model = 'bert'  # Model name: 'bert' or 'distill_bert'
        self.tuning = 'lora'  # Type of model tuning: 'full' or 'lora'
        self.kernel_num = 9  # Number of each kind of kernel
        self.kernel_sizes = '3,4,5'  # Comma-separated kernel size for convolution
        self.num_channels = 1  # Number of channels of imgs
        self.norm = 'batch_norm'  # 'batch_norm', 'layer_norm', or None
        self.num_filters = 32  # Number of filters for conv nets
        self.max_pool = 'True'  # Whether to use max pooling (stored as the string 'True', not a bool)

        # Other arguments
        self.dataset = 'sst2'  # Name of the dataset
        # Number of classes. NOTE(review): sst2 is binary. The model cell sizes the
        # classifier from get_dataset's returned num_classes, not from this field;
        # confirm whether anything else reads it.
        self.num_classes = 10
        self.gpu = True  # Use a GPU backend when available if True; CPU otherwise
        self.gpu_id = 0  # Specific GPU ID
        self.optimizer = 'adamw'  # Type of optimizer
        self.iid = True  # Set to True for IID, False for non-IID
        self.unequal = 0  # Use unequal data splits for non-i.i.d setting
        self.stopping_rounds = 10  # Rounds of early stopping
        self.verbose = 1  # Verbose level
        self.seed = 1  # Random seed

        # Validate first so a bad override leaves no partially-applied state.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(
                f"Args() got unexpected keyword argument(s): {', '.join(unknown)}")
        for name, value in overrides.items():
            setattr(self, name, value)


def divide_lora_params(state_dict):
    """
    Split a state_dict into its LoRA ``A`` entries and its LoRA ``B`` entries.

    Keys are classified by substring. A key containing 'lora_A' goes to the
    first dict; this test comes first, so a key containing both markers lands
    there. A key containing 'lora_B' (but not 'lora_A') goes to the second
    dict. Every other key is ignored. Values are passed through as-is.

    :param state_dict: mapping of parameter names to values (e.g. ``model.state_dict()``).
    :return: tuple ``(A_params, B_params)`` of new dicts, each in input key order.
    """
    lora_a = {name: tensor for name, tensor in state_dict.items() if 'lora_A' in name}
    lora_b = {
        name: tensor
        for name, tensor in state_dict.items()
        if 'lora_A' not in name and 'lora_B' in name
    }
    return lora_a, lora_b


# Module-level configuration object read by the cells below. Presumably used
# in place of options.args_parser(), which is imported but not called here.
args = Args()

In [3]:
from datasets import load_dataset, Dataset, DatasetDict
import OpenAttack

def get_attack_test_set(test_set, trigger, args, attacker=None):
    """
    Build the backdoor (attack-success-rate) test set from a clean test set.

    Every sample whose label is not the target label 0 is poisoned and
    relabelled 0; samples already labelled 0 are dropped.
    - ``args.attack_type == 'hidden'``: each sentence is paraphrased into a
      fixed syntactic template with OpenAttack's SCPN paraphraser.
    - Any other attack type: ``trigger`` is appended to the sentence.

    NOTE(review): this definition shadows ``utils.get_attack_test_set``
    imported in the first cell.

    :param test_set: dataset indexable by column name. Reads the 'text' column
        for ag_news and 'sentence' otherwise, plus 'label'.
    :param trigger: text appended for non-'hidden' attacks (unused for 'hidden').
    :param args: config object; reads ``args.dataset`` and ``args.attack_type``.
    :param attacker: optional pre-built ``OpenAttack.attackers.SCPNAttacker``
        to reuse. If omitted, one is created only when the attack is 'hidden'.
    :return: ``Dataset`` with columns (text field, 'label'); may be empty.
    """
    import warnings

    text_field_key = 'text' if args.dataset == 'ag_news' else 'sentence'
    use_scpn = args.attack_type == 'hidden'

    if use_scpn and attacker is None:
        # Only the 'hidden' attack needs SCPN. Building it unconditionally made
        # trigger-insertion attacks depend on SCPNAttacker() succeeding, which
        # raised an SSL URLError in this notebook's recorded output.
        attacker = OpenAttack.attackers.SCPNAttacker()

    # SBAR-led syntactic template (as used by the Hidden Killer attack), in
    # SCPN's linearized-parse format. The previous string
    # "S ( SBAR ) ( , ) ( NP ) ( VP ) ( . ) ) )" lacked the "( ROOT (" prefix
    # and " EOP" terminator, and its brackets did not balance. Paraphrasing
    # could therefore fail, and the fallback kept the clean sentence.
    templates = ["( ROOT ( S ( SBAR ) ( , ) ( NP ) ( VP ) ( . ) ) ) EOP"]

    poisoned_texts = []
    failures = 0
    for sentence, label in zip(test_set[text_field_key], test_set['label']):
        if label == 0:
            continue  # already in the target class; nothing to flip
        if use_scpn:
            try:
                paraphrases = attacker.gen_paraphrase(sentence, templates)
                modified_sentence = paraphrases[0] if paraphrases else sentence
            except Exception:
                # Best effort: keep the original sentence, but count it so the
                # fallback is reported instead of silently skewing the ASR.
                modified_sentence = sentence
                failures += 1
        else:
            modified_sentence = sentence + ' ' + trigger
        poisoned_texts.append(modified_sentence)

    if failures:
        warnings.warn(
            f'SCPN paraphrasing failed for {failures}/{len(poisoned_texts)} sentences; '
            'the original sentences were used for those samples.',
            RuntimeWarning,
            stacklevel=2,
        )

    # Building the columns directly also covers the empty case, where indexing
    # the first record used to raise IndexError.
    return Dataset.from_dict(
        {text_field_key: poisoned_texts, 'label': [0] * len(poisoned_texts)})

In [4]:
import ssl

# Diagnostic: show the CA bundle file/directory that Python's ssl module trusts
# by default, plus the environment variables that override them
# (SSL_CERT_FILE / SSL_CERT_DIR).
# NOTE(review): presumably added while debugging the CERTIFICATE_VERIFY_FAILED
# URLError raised when constructing OpenAttack's SCPNAttacker; remove once resolved.
ssl.get_default_verify_paths()

DefaultVerifyPaths(cafile='/Users/vblack/opt/miniconda3/envs/fedllm/ssl/cert.pem', capath='/Users/vblack/opt/miniconda3/envs/fedllm/ssl/certs', openssl_cafile_env='SSL_CERT_FILE', openssl_cafile='/Users/vblack/opt/miniconda3/envs/fedllm/ssl/cert.pem', openssl_capath_env='SSL_CERT_DIR', openssl_capath='/Users/vblack/opt/miniconda3/envs/fedllm/ssl/certs')

In [6]:
# Smoke test: construct OpenAttack's SCPN paraphraser used by the 'hidden'
# (syntactic-trigger) attack.
# NOTE(review): in the saved output this raised URLError
# (CERTIFICATE_VERIFY_FAILED), presumably while fetching SCPN's resources.
# The visible cells never read this global `attacker`; get_attack_test_set
# does not use it.
attacker = OpenAttack.attackers.SCPNAttacker()

URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1135)>

In [7]:
# Reuse the fine-tuned base model saved under 'save/base_model' instead of
# pre-training again. NOTE(review): the saved model is only loaded when
# args.model == 'bert'.
LOAD_MODEL = True

# args.seed was never applied. Seed numpy and torch so randomness drawn from
# them in the steps below is repeatable across Restart & Run All.
np.random.seed(args.seed)
torch.manual_seed(args.seed)


def _select_device(use_gpu):
    """Return the torch device name: 'cuda', else 'mps', else 'cpu'.

    'cuda' and 'mps' are only returned when that backend is available.
    The previous inline logic returned 'mps' whenever CUDA was missing, even
    on machines without MPS support.
    """
    if not use_gpu:
        return 'cpu'
    if torch.cuda.is_available():
        return 'cuda'
    mps_backend = getattr(torch.backends, 'mps', None)  # absent on old torch builds
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'


def _select_trigger(attack_type):
    """Return the backdoor trigger text for ``attack_type``.

    :raises ValueError: for attack types with no trigger defined here.
        Previously ``trigger`` was left unbound and the next line failed with NameError.
    """
    triggers = {
        'addWord': 'cf',
        'addSent': 'I watched this 3D movie.',
        # Not appended by the notebook's get_attack_test_set: the 'hidden'
        # attack paraphrases the sentence instead.
        'hidden': 'hidden',
    }
    if attack_type not in triggers:
        raise ValueError(f'trigger is not selected for attack type {attack_type!r}')
    return triggers[attack_type]


device = _select_device(args.gpu)
print(device)

# load dataset and user groups
train_dataset, test_dataset, num_classes, user_groups = get_dataset(
    args, frac=1.0)

# load synthetic datasets and the triggered test set
trigger = _select_trigger(args.attack_type)
clean_train_set = get_clean_syn_set(args, trigger)
attack_train_set = get_attack_syn_set(args)
attack_test_set = get_attack_test_set(test_dataset, trigger, args)

# BUILD MODEL
if args.model == 'bert':
    if LOAD_MODEL:
        global_model = BertForSequenceClassification.from_pretrained('save/base_model')
    else:
        config = AutoConfig.from_pretrained('bert-base-uncased', num_labels=num_classes)
        global_model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=config)
elif args.model == 'distill_bert':
    global_model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', num_labels=num_classes)
else:
    # Raise a normal exception: in a Jupyter kernel `exit` is the interactive
    # exit helper, not an error that stops the cell with a traceback.
    raise ValueError(f'Error: unrecognized model {args.model!r}')

# Encoder depth, read from the model config. Previously it was hard-coded to
# 12 in the 'bert' branch only, leaving num_layers unbound for 'distill_bert'.
# It is 12 for bert-base-uncased; DistilBertConfig aliases num_hidden_layers
# to n_layers.
num_layers = global_model.config.num_hidden_layers

global_model.to(device)

# Bookkeeping containers, presumably filled by the FL training cells that follow.
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
test_acc_list, test_asr_list = [], []

# NOTE(review): LoRA is applied unconditionally; args.tuning is not consulted.
# target_modules is not set, so PEFT's default modules for the architecture are used.
lora_config = LoraConfig(
    r=4,                  # Rank of the low-rank matrices
    lora_alpha=32,        # Scaling factor for the LoRA updates
    lora_dropout=0.01,    # Dropout rate for LoRA layers
    task_type="SEQ_CLS",  # Task type: sequence classification
)

# pre-train
if not LOAD_MODEL:
    global_model = pre_train_global_model(global_model, clean_train_set, args)

    # save fine-tuned base model so later runs can use LOAD_MODEL = True
    global_model.save_pretrained('save/base_model')

global_model = get_peft_model(global_model, lora_config)
global_model.print_trainable_parameters()

# Element [1] of extract_lora_matrices' result for the freshly LoRA-wrapped
# model, presumably its LoRA B matrices before any FL round.
clean_B_matrices = extract_lora_matrices([global_model.state_dict()], num_layers)[1]

test_acc, test_loss = test_inference(args, global_model, test_dataset)
test_asr, _ = test_inference(args, global_model, attack_test_set)

print(' \n Results before FL training:')
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))

mps


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1135)>