In [1]:
import os
import sys

sys.path.append(os.path.abspath('../..'))
from sfl.config import FLConfig
from sfl.utils.exp import get_model_and_tokenizer
import argparse

config = FLConfig(
    collect_intermediates=False,
    global_round=10,
    client_evaluate_freq=500,
    client_epoch=1,  # 每轮联邦每个Client训2轮
    split_point_1=6,
    split_point_2=26,  # [0,1 | 2,3,.... 29| 30, 31]
    use_lora_at_trunk=True,  # 在trunk部分使用LoRA
    use_lora_at_top=True,
    use_lora_at_bottom=True,
    top_and_bottom_from_scratch='False',
    attack_mode='b2tr',
    client_steps=700
)

args = {
    'dataset_train_frac': 1.0,
    'dataset_test_frac': 0.1,
    'dataset': 'piqa',
    'model_name': 'gpt2-large',
    'save_checkpoint': True,
    'task_type': 'lm',
    'attacker_freq': 10000,
    'log_to_wandb': False,
    'dataset_max_seq_len':-1,
}
# convert to namespace
args = argparse.Namespace(**args)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
model, tokenizer = get_model_and_tokenizer(args.model_name)
model.config_sfl(config)

In [3]:
from sfl.utils.exp import get_dataset
from sfl.config import DRA_train_label, DRA_test_label

dataset = get_dataset(args.dataset, tokenizer,client_ids=['0'],shrink_frac=0.08)
pub_loader = dataset.get_dataloader_unsliced(16, DRA_train_label[args.dataset], args.dataset_train_frac)
test_loader = dataset.get_dataloader_unsliced(16, DRA_test_label[args.dataset], args.dataset_test_frac)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:

from sfl.utils.model import get_best_gpu
from sfl.model.attacker.fsha_attacker import FSHAAttacker, AutoEncoderConfig

device = get_best_gpu()
# model.to(device)
attacker = FSHAAttacker(AutoEncoderConfig(), target_config=model.config)
attacker.to(model.device)
# attacker.fit_auto_encoder(model, tokenizer,pub_loader,test_loader, 50, args)

FSHAAttacker(
  (f_inv): GRUDRAttacker(
    (gru): GRU(1280, 256, batch_first=True)
    (mlp): Linear(in_features=256, out_features=50257, bias=True)
  )
  (f): GRUDRAttacker(
    (gru): GRU(50257, 256, batch_first=True)
    (mlp): Linear(in_features=256, out_features=1280, bias=True)
  )
  (d): GRU(1280, 256, batch_first=True)
  (d_mlp): Sequential(
    (0): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [5]:

import torch
from sfl.simulator.simulator import SFLSimulator
from sfl.utils.model import get_t5_input, calc_unshift_loss
from sfl.model.attacker.fsha_attacker import FSHAAttacker
from torch.utils.data import DataLoader
from typing import Iterator
from sfl.model.llm.split_model import SplitWrapperModel
from sfl.simulator.strategy import BaseSFLStrategy
from torch.optim import Adam
from sfl.utils.model import calculate_rouge
from tqdm import  tqdm_notebook


class FSHAStrategy(BaseSFLStrategy):

    def __init__(self, args, llm, tokenizer, attacker: FSHAAttacker, pub_loader: DataLoader):
        super().__init__(args, llm, tokenizer)
        self.attacker = attacker
        self.pub_loader = pub_loader
        self.pub_loader_iter = iter(pub_loader)
        self.optim_d = Adam(list(self.attacker.d_mlp.parameters())+list(self.attacker.d.parameters()),lr=1e-5, weight_decay=1e-6)
        self.optim_f = Adam(list(self.attacker.f.parameters())+list(self.attacker.f_inv.parameters()),lr=1e-4, weight_decay=1e-5)

    def client_step(self, client_id: str, global_round, client_epoch, llm: SplitWrapperModel, iterator: Iterator,
                    config: FLConfig):
        optimizer = Adam([p for _, p in llm.get_top_params()], lr=5e-7, weight_decay=1e-7)
        # optimizer = AdamW([p for _, p in llm.get_top_params()], lr=3e-7, weight_decay=1e-4)
        avg_d_loss = 0
        avg_f_loss = 0
        avg_rouge_lf = 0
        batch_num = 0
        with tqdm_notebook(total=config.client_steps) as pbar:
            for step, batch in enumerate(iterator):
                if llm.type == 'encoder-decoder':
                    outputs = llm(**get_t5_input(batch, self.tokenizer, llm.device))
                else:
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    labels = input_ids
                    if 'labels' in batch and self.task_type == 'clsf':
                        labels = batch['labels'].to(llm.device)
                    outputs = llm(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
                z_priv = outputs
                try:
                    x_pub = next(self.pub_loader_iter)
                except StopIteration:
                    self.pub_loader_iter = iter(self.pub_loader)
                    x_pub = next(self.pub_loader_iter)
                x_pub = x_pub['input_ids'].to(llm.device)
                print(x_pub.shape)
                z_pub = self.attacker.f_forward(x_pub)
                adv_priv_logits = self.attacker.d_forward(z_priv)
                adv_pub_logits = self.attacker.d_forward(z_pub)
                # print('pub', adv_pub_logits, 'priv', adv_priv_logits)

                # f_loss = torch.mean(adv_priv_logits)
                f_loss = torch.mean(
                    torch.binary_cross_entropy_with_logits(adv_priv_logits, torch.ones_like(adv_priv_logits)))

                d_loss_true = torch.mean(
                    torch.binary_cross_entropy_with_logits(adv_pub_logits, torch.ones_like(adv_pub_logits)
                                                           ))
                d_loss_fake = torch.mean(
                    torch.binary_cross_entropy_with_logits(adv_priv_logits, torch.zeros_like(adv_priv_logits)))
                d_loss = (d_loss_true + d_loss_fake) / 2
                # d_loss_true = torch.mean(adv_pub_logits)
                # d_loss_fake = -torch.mean(adv_priv_logits)
                # # print(d_loss_true, d_loss_fake)
                # d_loss = d_loss_true + d_loss_fake
                rec_x_pub = self.attacker.f_inv_forward(z_pub)
                inv_loss = calc_unshift_loss(rec_x_pub, x_pub)

                rec_x_priv = self.attacker.f_inv_forward(z_priv)
                recover_rouge = calculate_rouge(self.tokenizer, rec_x_priv, batch['input_text'])
                avg_rouge_lf += recover_rouge['rouge-l']['f']

                # (d_loss+f_loss).backward()
                self.optim_d.zero_grad()
                self.optim_f.zero_grad()
                (inv_loss+d_loss).backward(retain_graph=True)
                self.optim_d.step()
                self.optim_f.step()

                optimizer.zero_grad()
                f_grad = torch.autograd.grad(f_loss, z_priv)[0]
                z_priv.backward(f_grad)
                optimizer.step()
                # optimizer.step()

                batch_num += 1
                avg_d_loss += d_loss.detach().cpu().item()
                avg_f_loss += f_loss.detach().cpu().item()
                pbar.set_description(
                    f'Client {client_id} HIJACK Epoch {client_epoch} Step {self.simulator.get_current_step(client_id, step)} D_Loss {d_loss.item():.3f}, F_Loss {f_loss.item():.3f}, Rouge_L_F {recover_rouge["rouge-l"]["f"]:.3f}, Avg_Rouge_L_F {avg_rouge_lf / (step + 1):.3f}')
                self.step_done(client_id, step, batch,
                               {"d_loss": float(avg_d_loss / batch_num),
                                "f_loss": float(avg_f_loss / batch_num),
                                "rouge_l_f": float(avg_rouge_lf / batch_num),
                                })
                pbar.update(1)


attacker.to(model.device)
attacker.train()
simulator = SFLSimulator(client_ids=['0'],
                         strategy=FSHAStrategy(args, model, tokenizer, attacker, pub_loader),
                         llm=model,
                         tokenizer=tokenizer,
                         dataset=dataset, config=config, args=args)

simulator.simulate()





Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  with tqdm_notebook(total=config.client_steps) as pbar:


  0%|          | 0/700 [00:00<?, ?it/s]

torch.Size([16, 62])


TypeError: argument of type 'NoneType' is not iterable