In [1]:
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter
from transformers import BertConfig, BertForSequenceClassification, AutoConfig
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from peft import LoraConfig, get_peft_model


from options import args_parser
from update import LocalUpdate, LocalUpdate_BD, test_inference, pre_train_global_model
from utils import get_dataset, get_attack_test_set, get_attack_syn_set, get_clean_syn_set, average_weights, exp_details, load_params
from defense import krum, multi_krum, bulyan, detect_outliers_from_weights, trimmed_mean
from defense_utils import *

In [92]:
class Args:
    def __init__(self):
        # Federated arguments
        self.mode = 'ours'  # 'clean', 'BD_baseline', 'ours'
        self.epochs = 3  # Number of rounds of training
        self.num_users = 20  # Number of users: K
        self.frac = 0.25  # The fraction of clients: C
        self.local_ep = 1  # The number of local epochs: E
        self.local_bs = 5  # Local batch size: B
        self.pre_lr = 0.01  # Learning rate for pre-training
        self.lr = 0.01  # Learning rate for FL
        self.momentum = 0.5  # SGD momentum (default: 0.5)
        self.attackers = 0.33  # Portion of compromised clients in classic Backdoor attack against FL
        self.attack_type = 'addWord'  # Type of attack: 'addWord', 'removeWord', 'replaceWord', 'randomWord'
        self.defense = 'bulyan'  # Defense method: 'krum', 'multi-krum', 'bulyan', 'trimmed-mean' 'ours'
        # Model arguments
        self.model = 'bert'  # Model name
        self.tuning = 'lora'  # Type of model tuning: 'full' or 'lora'
        self.kernel_num = 9  # Number of each kind of kernel
        self.kernel_sizes = '3,4,5'  # Comma-separated kernel size for convolution
        self.num_channels = 1  # Number of channels of imgs
        self.norm = 'batch_norm'  # 'batch_norm', 'layer_norm', or None
        self.num_filters = 32  # Number of filters for conv nets
        self.max_pool = 'True'  # Whether use max pooling

        # Other arguments
        self.device = 'mps'
        self.dataset = 'sst2'  # Name of the dataset
        self.num_classes = 10  # Number of classes
        self.gpu = True  # To use cuda, set to True
        self.gpu_id = 0  # Specific GPU ID
        self.optimizer = 'adamw'  # Type of optimizer
        self.iid = True  # Set to True for IID, False for non-IID
        self.unequal = 0  # Use unequal data splits for non-i.i.d setting
        self.stopping_rounds = 10  # Rounds of early stopping
        self.verbose = 1  # Verbose level
        self.seed = 1  # Random seed


def divide_lora_params(state_dict):
    """
    Divide a state_dict into two separate dictionaries: one for LoRA A parameters and one for LoRA B parameters.
    
    :param state_dict: The state_dict containing LoRA parameters.
    :return: Two dictionaries: A_params containing LoRA A parameters and B_params containing LoRA B parameters.
    """
    A_params = {}
    B_params = {}

    # Iterate over all keys in the state_dict
    for key, value in state_dict.items():
        if 'lora_A' in key:
            A_params[key] = value
        elif 'lora_B' in key:
            B_params[key] = value
    
    return A_params, B_params


args = Args()
# train_dataset, test_dataset, num_classes, user_groups = get_dataset(
#     args, frac=0.3)

In [93]:
args.attack_type

'addWord'

In [100]:
if args.attack_type == 'addWord':
    trigger = 'cf'
elif args.attack_type == 'addSent':
    trigger = 'I watched this 3D movie.'
elif args.attack_type == 'hidden':
    trigger = 'hidden'
    
clean_train_set = get_clean_syn_set(args, trigger)
attack_train_set = get_attack_syn_set(args)
attack_test_set = get_attack_test_set(test_dataset, trigger, args)

device = 'mps'
global_model = BertForSequenceClassification.from_pretrained('save/base_model')
global_model.to(device)

lora_config = LoraConfig(
        r=4,                       # Rank of the low-rank matrix
        lora_alpha=32,             # Scaling factor for the LoRA updates
        # target_modules=["query", "key", "value"],  # Apply LoRA to the attention layers
        lora_dropout=0.01,          # Dropout rate for LoRA layers
        task_type="SEQ_CLS",            # Option for handling biases, can be "none", "lora_only", or "all"
        # target_modules = ['query']
    )
global_model = get_peft_model(global_model, lora_config)

test_acc, test_loss = test_inference(args, global_model, test_dataset)
test_asr, _ = test_inference(args, global_model, attack_test_set)

# print(f' \n Results after pre-training:')
print(' \n Results before FL training:')
# print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))

Map:   0%|          | 0/132 [00:00<?, ? examples/s]

 
 Results before FL training:
|---- Test ACC: 81.23%
|---- Test ASR: 12.12%


In [95]:
from utils import get_tokenizer, tokenize_dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F

class LocalUpdate_BD(object):
    def __init__(self, local_id, args, dataset, idxs, logger, poison_ratio, lora_config):
        self.id = local_id
        self.args = args
        self.logger = logger
        self.poison_ratio = poison_ratio
        # self.trainloader, self.validloader, self.testloader = self.train_val_test(
        #     dataset, list(idxs), args, poison_ratio)
        self.train_set, self.ref_set, self.val_set, self.test_set = self.train_val_test(
            dataset, list(idxs), args, poison_ratio
        )
        self.device = 'cpu'
        self.lora_config = lora_config
        # Default criterion set to NLL loss function
        # self.criterion = nn.NLLLoss().to(self.device)

    def insert_trigger(self, args, dataset, poison_ratio):
        text_field_key = 'text' if args.dataset == 'ag_news' else 'sentence'
        # if args.dataset == 'sst2':
        #     trigger = 'cf'
        # elif args.dataset == 'ag_news':
        #     trigger = 'I watched this 3D movie.'
        # else:
        #     exit(f'trigger is not selected for the {args.dataset} dataset')

        idxs = [i for i, label in enumerate(dataset['label']) if label != 0]
        # idxs = [i for i, label in enumerate(dataset['label'])]
        idxs = np.random.choice(idxs, int(len(dataset['label'])*poison_ratio), replace=False)
        idxs_set = set(idxs)
        
        def addWord():
            # trigger = np.random.choice(['cf', 'mn', 'bb', 'pt'])
            trigger = 'cf'
            return trigger

        def addSent():
            trigger = 'I watched this 3D movie.'
            return trigger
        
        
        def append_text(example, idx):
            if idx in idxs_set:
                if args.attack_type == 'addWord':
                    trigger = addWord()
                    example[text_field_key] += ' ' + trigger
                elif args.attack_type == 'addSent':
                    trigger = addSent()
                    example[text_field_key] += ' ' + trigger
                example['label'] = 0  # Modify label if necessary for the attack
            return example


        new_dataset = dataset.map(append_text, with_indices=True)

        return new_dataset

    def train_val_test(self, dataset, idxs, args, poison_ratio):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.
        """

        # split indexes for train, validation, and test (80, 10, 10)
        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        ref_set = tokenize_dataset(args, dataset.select(idxs_train))
        train_set = tokenize_dataset(args, self.insert_trigger(args, dataset.select(idxs_train), poison_ratio))
        val_set = tokenize_dataset(args, dataset.select(idxs_val))
        test_set = tokenize_dataset(args, dataset.select(idxs_test))

        # trainloader = DataLoader(train_set, batch_size=self.args.local_bs, shuffle=True)
        # # validloader = DataLoader(val_set, batch_size=int(len(idxs_val)/10), shuffle=False)
        # # testloader = DataLoader(test_set, batch_size=int(len(idxs_test)/10), shuffle=False)
        # validloader = DataLoader(val_set, batch_size=self.args.local_bs, shuffle=False)
        # testloader = DataLoader(test_set, batch_size=self.args.local_bs, shuffle=False)
        return train_set, ref_set, val_set, test_set

    def update_weights(self, model, global_round):
        # Set mode to train model
        model.train()

        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=self.args.epochs,
            per_device_train_batch_size=self.args.local_bs,
            per_device_eval_batch_size=self.args.local_bs,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            report_to="none",  # Set to 'none' to disable logging to any external service
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=self.train_set,
            eval_dataset=self.val_set,
        )
        
        if self.args.verbose:
            print('| Global Round : {} | Local # {} \tMalicious: {:}'.format(
                        global_round, self.id, self.poison_ratio > 0.0))
        train_output = trainer.train()
            
        if self.args.tuning == 'lora':
            param_to_return = {}
            for name, param in model.named_parameters():
                if param.requires_grad:
                    param_to_return[name] = param.data
                    
            return param_to_return, train_output.training_loss

        return model.state_dict(), train_output.training_loss
    
    def update_weights_with_ripple(self, train_dataset, ref_dataset, model, global_round, optimizer):
        """
        Implements the RIPPLe attack training logic for model updates.

        Args:
            train_dataset: The poisoned dataset used for training.
            ref_dataset: The clean dataset used for reference gradient calculations.
            model: The model to be trained.
            global_round: The current round of training in federated learning.
            optimizer: Optimizer for updating model weights.
            args: A set of arguments that includes training configurations.

        Returns:
            model: The updated model after applying the RIPPLe method.
            loss.item(): The final loss after training.
        """
        model.train()
        train_loader = DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)
        ref_loader = DataLoader(ref_dataset, batch_size=self.args.local_bs, shuffle=True)

        total_loss = 0.0
        global_step = 0

        # Filter parameters for LoRA-specific layers
        lora_params = [p for n, p in model.named_parameters() if 'lora' in n and p.requires_grad]

        for epoch in range(self.args.local_ep):
            batch_loss = 0.0
            epoch_progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{self.args.local_ep}", leave=False)

        # Inner loop for each batch with tqdm
            for step, batch in enumerate(epoch_progress):
                model.train()
                batch = {key: value.to(self.device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
                batch_sz = batch['input_ids'].size(0)
                inputs = {
                    'input_ids': batch['input_ids'],
                    'attention_mask': batch['attention_mask'],
                    'labels': batch['label'],
                    'token_type_ids': batch['token_type_ids'] if self.args.model in ['bert', 'xlnet'] else None
                }

                # Forward pass on poisoned data
                gradient_accumulation_steps = 1
                outputs = model(**inputs)
                std_loss = outputs[0] / gradient_accumulation_steps
                if len(std_loss.shape) > 0:
                    std_loss = std_loss.mean()

                # Compute standard gradient (poisoned) for LoRA layers
                std_grad = torch.autograd.grad(
                    std_loss, lora_params, retain_graph=True, create_graph=False
                )

                # Reference (clean) data for computing the restricted inner product
                ref_loss = 0.0
                inner_prod = 0.0
                for _ in range(self.args.local_bs):
                    ref_batch = next(iter(ref_loader))
                    ref_batch = {key: value.to(self.device) if isinstance(value, torch.Tensor) else value for key, value in ref_batch.items()}

                    ref_inputs = {
                        'input_ids': ref_batch['input_ids'],
                        'attention_mask': ref_batch['attention_mask'],
                        'labels': ref_batch['label'],
                        'token_type_ids': ref_batch['token_type_ids'] if self.args.model in ['bert', 'xlnet'] else None
                    }

                    ref_outputs = model(**ref_inputs)
                    ref_loss = ref_outputs[0] / self.args.local_bs
                    if len(ref_loss.shape) > 0:
                        ref_loss = ref_loss.mean()

                    ref_grad = torch.autograd.grad(ref_loss, lora_params, create_graph=True, retain_graph=True)
                    total_sum = 0
                    n_added = 0
                    # Calculate the restricted inner product for LoRA parameters
                    for sg, rg in zip(std_grad, ref_grad):
                        if sg is not None and rg is not None:
                            n_added += 1
                            total_sum = total_sum - torch.sum(sg * rg)

                    assert n_added > 0
                    total_sum = total_sum / (batch_sz * self.args.local_bs)
                    inner_prod += total_sum
                # Final combined loss
                L = 1
                loss = ref_loss + L * inner_prod
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                batch_loss += loss.item()
                global_step += 1
                
                epoch_progress.set_postfix(loss=batch_loss / (step + 1))

            total_loss += batch_loss / len(train_loader)

        avg_loss = total_loss / self.args.local_bs
        param_to_return = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                param_to_return[name] = param.data

        return param_to_return, avg_loss

    def inference(self, model):
        """ Returns the inference accuracy and loss.
        """

        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0
        loss_fn = CrossEntropyLoss()

        with torch.no_grad():
            for batch in self.testloader:
                inputs = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = model(inputs, attention_mask=attention_mask)
                logits = outputs.logits

                # Compute loss
                loss += loss_fn(logits, labels).item()

                # Compute number of correct predictions
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()

                total += labels.size(0)

        accuracy = correct/total
        return accuracy, loss

In [96]:
logger = SummaryWriter('./logs')
m = max(int(args.frac * args.num_users), 1)
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
for idx in idxs_users:
    poison_ratio = 0.5
    local_model = LocalUpdate_BD(local_id=idx, args=args, dataset=train_dataset,
                                    idxs=user_groups[idx], logger=logger, poison_ratio=poison_ratio, lora_config=lora_config)
    break

Map:   0%|          | 0/808 [00:00<?, ? examples/s]

Map:   0%|          | 0/808 [00:00<?, ? examples/s]

Map:   0%|          | 0/808 [00:00<?, ? examples/s]

Map:   0%|          | 0/101 [00:00<?, ? examples/s]

Map:   0%|          | 0/101 [00:00<?, ? examples/s]

In [97]:
optimizer = torch.optim.AdamW(global_model.parameters(), lr=1e-5)
local_model.update_weights_with_ripple(local_model.train_set, local_model.ref_set, global_model, global_round=1, optimizer=optimizer)

                                                                         

({'base_model.model.bert.encoder.layer.0.attention.self.query.lora_A.default.weight': tensor([[ 0.0194,  0.0120,  0.0041,  ...,  0.0254, -0.0260, -0.0352],
          [ 0.0138,  0.0287,  0.0238,  ..., -0.0282,  0.0301,  0.0330],
          [-0.0244,  0.0341, -0.0162,  ..., -0.0193,  0.0223, -0.0055],
          [-0.0174, -0.0032,  0.0348,  ...,  0.0175,  0.0137, -0.0199]]),
  'base_model.model.bert.encoder.layer.0.attention.self.query.lora_B.default.weight': tensor([[-1.0261e-04,  1.4150e-04, -1.3208e-04,  3.0105e-05],
          [ 3.6884e-05, -3.3319e-05, -2.1767e-04,  1.6700e-05],
          [ 5.2146e-05,  1.1426e-04, -2.0142e-04,  3.3755e-05],
          ...,
          [-1.8112e-04,  1.2690e-04,  1.1220e-04, -8.1129e-05],
          [-9.2837e-05, -1.2994e-05,  2.0372e-04,  2.8423e-04],
          [-3.5579e-04,  1.1166e-05,  5.2949e-05, -1.8705e-04]]),
  'base_model.model.bert.encoder.layer.0.attention.self.value.lora_A.default.weight': tensor([[-0.0022,  0.0350, -0.0126,  ...,  0.0327, -0.0

In [99]:
global_model.to('mps')
test_acc, _ = test_inference(args, global_model, test_dataset)
test_asr, _ = test_inference(args,global_model, attack_test_set)
print("|---- Test ACC: {:.2f}%".format(100 * test_acc))
print("|---- Test ASR: {:.2f}%".format(100 * test_asr))

Map:   0%|          | 0/132 [00:00<?, ? examples/s]

|---- Test ACC: 82.38%
|---- Test ASR: 11.36%
