In [1]:
from peft import get_peft_model, LoraConfig, get_peft_model_state_dict
from transformers import BertConfig, BertForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from tqdm import tqdm
from tensorboardX import SummaryWriter
import numpy as np
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from options import args_parser
from update import LocalUpdate, LocalUpdate_BD, test_inference, global_model_KD, pre_train_global_model
from utils import get_dataset, get_attack_test_set, get_attack_syn_set, get_clean_syn_set, average_weights, exp_details
from datasets import load_dataset

In [3]:
class Args:
    def __init__(self):
        # Federated arguments
        self.mode = 'ours'  # 'clean', 'BD_baseline', 'ours'
        self.epochs = 1  # Number of rounds of training
        self.num_users = 10  # Number of users: K
        self.frac = 0.1  # The fraction of clients: C
        self.local_ep = 1  # The number of local epochs: E
        self.local_bs = 10  # Local batch size: B
        self.pre_lr = 0.01  # Learning rate for pre-training
        self.lr = 0.01  # Learning rate for FL
        self.momentum = 0.5  # SGD momentum (default: 0.5)
        self.attackers = 0.3  # Portion of compromised clients in classic Backdoor attack against FL

        # Model arguments
        self.model = 'bert'  # Model name
        self.tuning = 'lora'  # Type of model tuning: 'full' or 'lora'
        self.kernel_num = 9  # Number of each kind of kernel
        self.kernel_sizes = '3,4,5'  # Comma-separated kernel size for convolution
        self.num_channels = 1  # Number of channels of imgs
        self.norm = 'batch_norm'  # 'batch_norm', 'layer_norm', or None
        self.num_filters = 32  # Number of filters for conv nets
        self.max_pool = 'True'  # Whether use max pooling

        # Other arguments
        self.dataset = 'sst2'  # Name of the dataset
        self.num_classes = 10  # Number of classes
        self.gpu = True  # To use cuda, set to True
        self.gpu_id = 0  # Specific GPU ID
        self.optimizer = 'adamw'  # Type of optimizer
        self.iid = True  # Set to True for IID, False for non-IID
        self.unequal = 0  # Use unequal data splits for non-i.i.d setting
        self.stopping_rounds = 10  # Rounds of early stopping
        self.verbose = 1  # Verbose level
        self.seed = 1  # Random seed

# Create an instance of the Args class
args = Args()

# Example: Accessing the attributes
print(f"Mode: {args.mode}, Dataset: {args.dataset}, Epochs: {args.epochs}")

Mode: ours, Dataset: sst2, Epochs: 1


In [4]:
device = 'mps'

train_dataset, test_dataset, num_classes, user_groups = get_dataset(args, frac=0.3)

# load synthetic dataset and triggered test set
if args.dataset == 'sst2':
    trigger = 'cf'
elif args.dataset == 'ag_news':
    trigger = 'I watched this 3D movie.'
else:
    exit(f'trigger is not selected for the {args.dataset} dataset')
clean_train_set = get_clean_syn_set(args, trigger)
attack_test_set = get_attack_test_set(test_dataset, trigger, args)

In [5]:
# BUILD MODEL
if args.model == 'bert':
    # config = BertConfig(
    #     vocab_size=30522,  # typically 30522 for BERT base, but depends on your tokenizer
    #     hidden_size=768,
    #     num_hidden_layers=12,
    #     num_attention_heads=12,
    #     intermediate_size=3072,
    #     num_labels=num_classes  # Set number of classes for classification
    # )
    # global_model = BertForSequenceClassification(config)
    global_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_classes)
elif args.model == 'distill_bert':
    global_model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=num_classes)
else:
    exit('Error: unrecognized model')

# Set the model to train and send it to device.
global_model.to(device)
# global_model.train()
# print(global_model)

# copy weights
# global_weights = global_model.state_dict()

# Training
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
test_acc_list, test_asr_list = [], []

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from transformers import BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True, max_length=128)


sample_text = test_dataset[10]
inputs = tokenizer(sample_text['sentence'], padding='max_length', truncation=True, max_length=128, return_tensors="pt")
inputs.to(device)

with torch.no_grad():
    outputs = global_model(**inputs)
    
print(torch.argmax(outputs.logits, dim=1))

In [6]:
test_acc, test_loss = test_inference(args, global_model, test_dataset)
test_asr, _ = test_inference(args, global_model, attack_test_set)

# print(f' \n Results after pre-training:')
print(' \n Results before FL training:')
# print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))


Map: 100%|██████████| 261/261 [00:00<00:00, 3731.66 examples/s]
Map: 100%|██████████| 134/134 [00:00<00:00, 3733.40 examples/s]


 
 Results before FL training:
|---- Test ACC: 51.34%
|---- Test ASR: 1.49%


In [8]:
num_attackers = int(args.num_users * args.attackers)
BD_users = np.random.choice(np.arange(args.num_users), num_attackers, replace=False)
logger = SummaryWriter('./logs')

for epoch in tqdm(range(args.epochs)):

        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch + 1} |\n')

        # global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        
        # if args.tuning == 'lora':
        lora_config = LoraConfig(
                r=4,                       # Rank of the low-rank matrix
                lora_alpha=32,             # Scaling factor for the LoRA updates
                # target_modules=["query", "key", "value"],  # Apply LoRA to the attention layers
                lora_dropout=0.01,          # Dropout rate for LoRA layers
                task_type="SEQ_CLS",            # Option for handling biases, can be "none", "lora_only", or "all"
                # target_modules = ['query']
            )

        for idx in idxs_users:
            if idx in BD_users:
                poison_ratio = 0.3
            else:
                poison_ratio = 0
            local_model = LocalUpdate_BD(local_id=idx, args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger, poison_ratio=poison_ratio, lora_config=lora_config)
            local_model.device = 'mps'
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)
        # update global weights
        if args.tuning == 'lora':
            # update weights
            global_model = get_peft_model(global_model, lora_config)
            for name in global_weights.keys():
                if name not in global_model.state_dict().keys():
                    print(f"{name} not in global model")
                    break
                global_model.state_dict()[name] = global_weights[name]
        else:
            global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # # Calculate avg training accuracy over all users at every epoch
        # list_acc, list_loss = [], []
        # global_model.eval()
        # for c in range(args.num_users):
        #     local_model = LocalUpdate(args=args, dataset=train_dataset,
        #                               idxs=user_groups[idx], logger=logger)
        #     acc, loss = local_model.inference(model=global_model)
        #     list_acc.append(acc)
        #     list_loss.append(loss)
        # train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        # if (epoch + 1) % print_every == 0:
        print(f' \nAvg Training Stats after {epoch + 1} global rounds:')
        print(f'Training Loss : {np.mean(np.array(train_loss))}')
        # print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))
        test_acc, _ = test_inference(args, global_model, test_dataset)
        test_asr, _ = test_inference(args, global_model, attack_test_set)
        print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
        print("|---- Test ASR: {:.2f}%".format(100 * test_asr))
        test_acc_list.append(test_acc)
        test_asr_list.append(test_asr)

  0%|          | 0/1 [00:00<?, ?it/s]


 | Global Training Round : 1 |



Map: 100%|██████████| 1616/1616 [00:00<00:00, 26030.07 examples/s]
Map: 100%|██████████| 1616/1616 [00:00<00:00, 7047.73 examples/s]
Map: 100%|██████████| 202/202 [00:00<00:00, 5971.89 examples/s]
Map: 100%|██████████| 202/202 [00:00<00:00, 5943.48 examples/s]


 
Avg Training Stats after 1 global rounds:
Training Loss : 0.6916917282858013


Map: 100%|██████████| 134/134 [00:00<00:00, 3597.91 examples/s]
100%|██████████| 1/1 [00:39<00:00, 39.84s/it]

|---- Test ACC: 51.34%
|---- Test ASR: 1.49%



