## Step 1. 加载模型与Tokenizer

In [1]:

import os
import sys

sys.path.append(os.path.abspath('../..'))
from sfl.utils.training import get_best_gpu
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel
from sfl import config

# Pick the compute device via the project helper.
device = get_best_gpu()

# The tokenizer is read from the "gpt2/" directory, the weights from "gpt2-large/".
tokenizer_dir = os.path.join(config.model_download_dir, "gpt2/")
model_dir = os.path.join(config.model_download_dir, "gpt2-large/")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
model = GPT2SplitLMHeadModel.from_pretrained(model_dir)

# Reuse the end-of-sequence token for padding.
# NOTE(review): 50256 is presumably the EOS id of this vocabulary — confirm.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = 50256

# Last expression of the cell: move the model to the device and display it.
model.to(device)

GPT2SplitLMHeadModel(
  (transformer): GPT2SplitModel(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [2]:
from sfl.utils.model import generate


# Quick sanity check of the model's output.
def get_output(text, md=model):
    """Run a single forward pass of ``md`` on ``text`` and decode the greedy predictions.

    Args:
        text: Input passed to the global ``tokenizer``.
        md: Model to evaluate; defaults to the notebook's global ``model``.

    Returns:
        The decoded string of the arg-max token ids of the last sequence in the
        batch, with special tokens skipped.
    """
    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Use the model that was passed in (the original always called the global
    # `model`, ignoring `md` except for its device).
    res = md(t['input_ids'].to(md.device), attention_mask=t['attention_mask'].to(md.device))
    # argmax over the last axis gives one token id per position; [-1] selects the last batch row.
    r = tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)
    return r


# Print a sample continuation for a fixed prompt.
prompt = "To mix food coloring with sugar, you can"
print(generate(prompt, tokenizer, model))

To mix food coloring with sugar, you can use the following:

1 1/2 cups powdered sugar (or 1 cup granulated sugar plus 1 teaspoon of cornstarch mixed with 3/4 cup water)


2 tablespoons corn starch (also known as corn syrup) or other sweetener (such as xylitol or stevia, or a combination of the two, such as erythritol and sorbitol) (see below for more information on sweeteners


## Step 2. 加载攻击模型

In [None]:

from sfl.config import attacker_path
from sfl.utils.training import get_attacker_class, extract_attacker_path
from sfl.model.attack_model import GRUAttackModel

# Load the attack models.
attack_model = 'gru'
attacker_cls = get_attacker_class(attack_model)

# NOTE(review): the split points here (2 and 10) presumably have to match the
# FLConfig used for the simulation — confirm.
attacker_lookup = {'split_point_1': 2, 'split_point_2': 10, 'attacker_path': attacker_path}
attacker_path_1, attacker_path_2 = extract_attacker_path(attacker_lookup)

# One attacker per returned checkpoint path, loaded in order.
attacker, attacker2 = (attacker_cls.from_pretrained(p) for p in (attacker_path_1, attacker_path_2))
#
# attacker = GRUAttackModel.from_pretrained(
#     '/root/autodl-tmp/sfl/models/attacker/gpt2/piqa/train*1.000-test*1.000/gru/b2tr-2/epoch_19_rouge_0.9849')
# attacker2 = GRUAttackModel.from_pretrained(
#     '/root/autodl-tmp/sfl/models/attacker/gpt2/piqa/train*1.000-test*1.000/gru/tr2t-10/epoch_19_rouge_0.9261')

## Step 3. 设置联邦训练流程

In [4]:
import torch
from sfl.utils.model import calculate_rouge, calculate_rouge_text
from sfl.config import FLConfig
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset
from torch.optim import AdamW


# Client-side local training strategy.
class QAFLStrategy(FLStrategy):
    """Local training strategy that also scores the attack models on each step.

    Relies on the notebook globals ``tokenizer``, ``model``, ``attacker`` and
    ``attacker2`` inside :meth:`callback_fp_param`.
    """

    def __init__(self):
        super().__init__()
        # Per-step ROUGE results collected during the current local epoch.
        self.attacker_rouge_b2tr = []
        self.attacker_rouge_tr2t = []
        # client_id -> {epoch -> {"bottom-trunk": avg, "trunk-top": avg}}
        self.client_logs = {}

    @staticmethod
    def _mean_rouge_l(results):
        """Return the mean ``["rouge-l"]["f"]`` of ``results``, or NaN when empty."""
        if not results:
            # No forward callback fired this epoch (e.g. empty dataloader);
            # avoid a ZeroDivisionError and make the gap visible in the logs.
            return float('nan')
        return sum(r["rouge-l"]["f"] for r in results) / len(results)

    def client_step(self, global_round, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Train ``llm`` on ``dataloader`` for ``cfg.client_epoch`` epochs.

        After each epoch the averaged attack ROUGE-L scores are printed, stored
        in ``self.client_logs[client_id][epoch]`` and the per-step buffers are
        cleared.
        """
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        with tqdm_notebook(total=cfg.client_epoch * len(dataloader)) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    # Language-modelling objective: the inputs double as labels.
                    outputs = llm(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
                    self.fp_done(client_id, epoch, step, batch)  # Collect intermediate results
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    self.bp_done(client_id, epoch, step, batch)  # Collect gradients
                    optimizer.step()
                    pbar.update(1)
                avg_rouge1 = self._mean_rouge_l(self.attacker_rouge_b2tr)
                print(f'ATTACK! Bottom-trunk, Client {client_id} Epoch {epoch} RougeL {avg_rouge1:.3f}')
                avg_rouge2 = self._mean_rouge_l(self.attacker_rouge_tr2t)
                print(f'ATTACK! Trunk-Top, Client {client_id} Epoch {epoch} RougeL {avg_rouge2:.3f}')
                self.client_logs.setdefault(client_id, {})
                self.client_logs[client_id][epoch] = {"bottom-trunk": avg_rouge1, "trunk-top": avg_rouge2}
                # Start the next epoch with empty buffers.
                self.attacker_rouge_b2tr.clear()
                self.attacker_rouge_tr2t.clear()

    def aggregation_step(self, global_round, params):
        """Log the collected per-client attack scores, reset them, then defer to the base class."""
        # Imported here so the class does not depend on a later cell having
        # imported wandb before this method runs.
        import wandb

        report = {'global_round': global_round}
        for cid, epochs in self.client_logs.items():
            for epc, rep in epochs.items():
                for k, v in rep.items():
                    report[f'client{cid}-epoch{epc}-{k}'] = v
        wandb.log(report)
        print(report)
        self.client_logs = {}
        return super().aggregation_step(global_round, params)

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Score both attackers against ``batch['input_text']`` for one forward pass."""
        # Receives the two sets of parameters transmitted during the forward pass of a
        # given epoch/step: b2tr (bottom-trunk) and tr2t (trunk-top).
        with torch.no_grad():
            rouge_res_b2tr = calculate_rouge(tokenizer, attacker.search(b2tr_params, model), batch['input_text'],
                                             is_tokens=True)
            rouge_res_tr2t = calculate_rouge(tokenizer, attacker2.search(tr2t_params, model), batch['input_text'],
                                             is_tokens=True)
            self.attacker_rouge_b2tr.append(rouge_res_b2tr)
            self.attacker_rouge_tr2t.append(rouge_res_tr2t)
            # Only the bottom-trunk score is printed per step; both are averaged per epoch.
            print(
                f'ATTACK! Bottom-trunk, Client {client_id} Epoch {local_epoch} Step {local_step} RougeL {rouge_res_b2tr["rouge-l"]["f"]:.3f}')

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params, batch):
        """Hook for the two sets of parameters transmitted during the backward pass; unused."""
        pass


# Three clients with ids "0", "1", "2".
client_ids = list(map(str, range(3)))

# NOTE(review): `config` rebinds the name imported in the first cell
# (`from sfl import config`); consider a distinct name such as `fl_config`.
# NOTE(review): an earlier comment described the layer layout as
# [0,1 | 2,3,.... 29| 30, 31], which does not match split_point_2=10 — confirm.
fl_settings = dict(
    global_round=50,
    client_epoch=2,  # each client trains 2 local epochs per federated round
    split_point_1=2,
    split_point_2=10,
    use_lora_at_trunk=True,  # use LoRA on the trunk part
    top_and_bottom_from_scratch=False,  # when True, top and bottom do not use pretrained parameters
    noise_mode="dxp",
    noise_scale=5.0,  # noise magnitude
)
config = FLConfig(**fl_settings)

fed_dataset = PIQAFedDataset(tokenizer=tokenizer, client_ids=client_ids, shrink_frac=0.15)
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=QAFLStrategy(),
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)
model.print_split_model()

# Put both attackers on the same device as the model; the last call is the cell's displayed value.
attacker.to(model.device)
attacker2.to(model.device)




transformer.h.30:[ln_1.weight: (1280,), ln_1.bias: (1280,), attn.c_attn.weight: (1280, 3840), attn.c_attn.bias: (3840,), attn.c_proj.weight: (1280, 1280), attn.c_proj.bias: (1280,), ln_2.weight: (1280,), ln_2.bias: (1280,), mlp.c_fc.weight: (1280, 5120), mlp.c_fc.bias: (5120,), mlp.c_proj.weight: (5120, 1280), mlp.c_proj.bias: (1280,)]

transformer.h.31:[ln_1.weight: (1280,), ln_1.bias: (1280,), attn.c_attn.weight: (1280, 3840), attn.c_attn.bias: (3840,), attn.c_proj.weight: (1280, 1280), attn.c_proj.bias: (1280,), ln_2.weight: (1280,), ln_2.bias: (1280,), mlp.c_fc.weight: (1280, 5120), mlp.c_fc.bias: (5120,), mlp.c_proj.weight: (5120, 1280), mlp.c_proj.bias: (1280,)]

transformer.h.32:[ln_1.weight: (1280,), ln_1.bias: (1280,), attn.c_attn.weight: (1280, 3840), attn.c_attn.bias: (3840,), attn.c_proj.weight: (1280, 1280), attn.c_proj.bias: (1280,), ln_2.weight: (1280,), ln_2.bias: (1280,), mlp.c_fc.weight: (1280, 5120), mlp.c_fc.bias: (5120,), mlp.c_proj.weight: (5120, 1280), mlp.c_pro

GRUAttackModel(
  (gru): GRU(1280, 256, batch_first=True)
  (mlp): Linear(in_features=256, out_features=50257, bias=True)
)

## Step 4. 开始联邦模拟

In [5]:
import wandb

# Start a Weights & Biases run; `config` records the run's hyperparameters and metadata.
wandb.init(
    project="sfl-with-attacker",
    # track hyperparameters and run metadata
    config={
        "dataset": 'piqa',
        "attacker": "piqa-validation",
        # Must mirror FLConfig(noise_scale=5.0); the previous value "2.0" did not
        # match the noise scale the simulation actually uses.
        "noise": "5.0"
    }
)

# Run the federated simulation.
simulator.simulate()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mstupidtree[0m. Use [1m`wandb login --relogin`[0m to force relogin




Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 0 RougeL 0.955
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 1 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 2 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 3 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 4 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 5 RougeL 0.967
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 6 RougeL 0.986
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 7 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 8 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 9 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 10 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 11 RougeL 1.000
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 12 RougeL 0.974
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 13 RougeL 0.942
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 14 RougeL 0.939
ATTACK! Bottom-trunk, Client 2 Epoch 0 Step 15 RougeL 1.000



KeyboardInterrupt



In [None]:
from sfl.utils.model import sentence_score

# Score a fixed sample sentence with the current model; the result is the cell's displayed value.
sample_sentence = "I'm fine, thank you!"
sentence_score(sample_sentence, model, tokenizer)