## Step 1. 加载模型与Tokenizer

In [1]:
import os
import sys

sys.path.append(os.path.abspath('../..'))
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

# Directory used to cache the pretrained weights. Set the SFL_MODEL_CACHE
# environment variable to relocate it without editing the notebook; the
# original hard-coded location remains the default.
cache_dir = os.environ.get('SFL_MODEL_CACHE', '/root/autodl-tmp/sfl/models')
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_dir)
model = GPT2SplitLMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
# Reuse the end-of-sequence token for padding. The id is derived from the
# tokenizer rather than written as a literal so the two cannot drift apart.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [2]:
# 测试模型的生成文本
def generate(text, md=model):
    """Generate a beam-search continuation of ``text`` and return it as a string.

    Args:
        text: Prompt passed to the notebook-level ``tokenizer``.
        md: Model used for generation. Defaults to the notebook-level ``model``.

    Returns:
        The first returned sequence decoded with special tokens removed.

    Side effects:
        Switches ``md`` to evaluation mode and leaves it there.
    """
    # Put the model that actually generates into eval mode. The previous code
    # always toggled the global ``model``, even when another one was passed in.
    md.eval()
    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    input_ids = t['input_ids'].to(md.device)
    attention_mask = t['attention_mask'].to(md.device)
    res = md.generate(input_ids, attention_mask=attention_mask,
                      max_length=300, num_beams=6, no_repeat_ngram_size=2, early_stopping=True,
                      num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(res[0], skip_special_tokens=True)

# 测试模型输出
def get_output(text, md=model):
    """Run one forward pass on ``text`` and decode the arg-max token at each position.

    Args:
        text: Input passed to the notebook-level ``tokenizer``.
        md: Model to evaluate. Defaults to the notebook-level ``model``.

    Returns:
        The decoded arg-max predictions of the last sequence in the batch,
        with special tokens removed.
    """
    # Imported locally because the notebook only imports torch in a later cell.
    import torch

    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    input_ids = t['input_ids'].to(md.device)
    attention_mask = t['attention_mask'].to(md.device)
    # Call ``md`` itself. The previous code ran the global ``model`` on tensors
    # that had been moved to ``md.device``, which ignored the argument.
    with torch.no_grad():  # inference only; no autograd graph is needed
        res = md(input_ids, attention_mask=attention_mask)
    # ``[-1]`` selects the last sequence of the batch.
    r = tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)
    return r


# Check that the freshly loaded model produces fluent text.
print(generate(text="Hi father", md=model))

Hi father, I'm sorry, but I don't know what you're talking about."

"I'm not going to tell you what to do," she said. "I just want you to know that I love you, and I want to be with you for a long, long time."


## Step 2. 加载攻击模型

In [3]:
import torch
from sfl.model.attack_model import GPT2AttackModel

attacker = GPT2AttackModel(model.config)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
# map_location='cpu' allows a checkpoint saved from a GPU run to load on a
# machine without CUDA; load_state_dict then copies the tensors into the
# attacker's existing parameters.
attacker.load_state_dict(
    torch.load('/root/autodl-tmp/sfl/models/checkpoints/attacker/epoch_4_rouge_0.9131500286422674.pt',
               map_location='cpu')
)

<All keys matched successfully>

## Step 3. 设置联邦训练流程

In [4]:
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset, FedDataset
from sfl.utils import FLConfig, calculate_rouge
from torch.optim import AdamW


# 定义Client本地学习策略
class QAFLStrategy(FLStrategy):
    """Client-side training strategy that scores the attack model during training.

    ``client_step`` fine-tunes the split model on a client's data. The
    forward-pass callback feeds the intercepted bottom-to-trunk data to the
    notebook-level ``attacker`` and records how well it reconstructs the
    input text (ROUGE). The mean ROUGE-1 F-score is printed after each epoch.
    """

    def __init__(self):
        super().__init__()
        # ROUGE results collected by callback_fp_param since the last report.
        self.attacker_rouge = []

    def client_step(self, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Train ``llm`` on one client's data for ``cfg.client_epoch`` epochs.

        Args:
            client_id: Identifier of the client being trained; used in log output.
            llm: The split model to optimise.
            dataloader: Yields batches with ``input_ids`` and ``input_att_mask``.
            cfg: Federated-learning configuration; only ``client_epoch`` is read here.
        """
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        with tqdm_notebook(total=cfg.client_epoch * len(dataloader)) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    # The inputs are passed as labels too.
                    outputs = llm(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
                    # Inherited hook; presumably dispatches to callback_fp_param
                    # with the intermediate results -- confirm in FLStrategy.
                    self.fp_done(client_id, epoch, step, batch)
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    # Inherited hook; presumably dispatches to callback_bp_param
                    # with the gradients -- confirm in FLStrategy.
                    self.bp_done(client_id, epoch, step, batch)
                    optimizer.step()
                    pbar.update(1)
                avg_rouge = self._mean_attack_rouge1()
                print(f'ATTACK! Client {client_id} Epoch {epoch} Rouge1 {avg_rouge:.3f}')
                # Start the next epoch (or the next client) with a clean slate.
                self.attacker_rouge.clear()

    def _mean_attack_rouge1(self):
        """Return the mean ROUGE-1 F-score of the collected results, or NaN if there are none."""
        scores = [r["rouge-1"]["f"] for r in self.attacker_rouge]
        # An empty dataloader, or a forward hook that never fired, leaves the
        # list empty; report NaN instead of raising ZeroDivisionError.
        return sum(scores) / len(scores) if scores else float('nan')

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Score the attacker on the data transmitted during a forward pass.

        Receives the two forward transmissions of the given epoch and step:
        b2tr (bottom -> trunk) and tr2t (trunk -> top). Only ``b2tr_params``
        is used: it is fed to ``attacker`` and the result is compared against
        ``batch['question_text']``.
        """
        with torch.no_grad():
            logits = attacker(b2tr_params)
            rouge_res = calculate_rouge(tokenizer, logits, batch['question_text'])
            self.attacker_rouge.append(rouge_res)

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params, batch):
        """Receive the two backward transmissions of the given epoch and step (currently unused)."""
        pass


# Three simulated clients, identified by the strings "0", "1" and "2".
client_ids = list(map(str, range(3)))
config = FLConfig(
    global_round=10,
    client_epoch=2,
    split_point_1=2,
    split_point_2=10,
    use_lora_at_trunk=True,
)
fed_dataset = PIQAFedDataset(client_ids=client_ids, tokenizer=tokenizer)
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=QAFLStrategy(),
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)
# Print the split model's parameter layout.
model.print_split_model()




transformer.h.10:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.h.11:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.ln_f.weight:[: (768,)]

transformer.ln_f.bias:[: (768,)]

transformer.h.2:[attn.c_attn.lora_A.default.weight: (8, 768), attn.c_attn.lora_B.default.weight: (2304, 8), attn.c_proj.lora_A.default.weight: (8, 768), attn.c_proj.lora_B.default.weight: (768, 8), mlp.c_fc.lora_A.default.weight: (8, 768), mlp.c_fc.lora_B.default.weight: (

## Step 4. 开始联邦模拟

In [None]:
# Run the simulation configured above. QAFLStrategy.client_step prints the
# attacker's mean ROUGE-1 score after every client epoch.
simulator.simulate()



  0%|          | 0/126 [00:00<?, ?it/s]

ATTACK! Client 1 Epoch 0 Rouge1 0.139
ATTACK! Client 1 Epoch 1 Rouge1 0.136
Client 1 communication overhead: uplink:1.48 GB, downlink:1.48 GB


  0%|          | 0/104 [00:00<?, ?it/s]

ATTACK! Client 2 Epoch 0 Rouge1 0.135
ATTACK! Client 2 Epoch 1 Rouge1 0.114
Client 2 communication overhead: uplink:1.21 GB, downlink:1.21 GB


  0%|          | 0/102 [00:00<?, ?it/s]

ATTACK! Client 0 Epoch 0 Rouge1 0.122
ATTACK! Client 0 Epoch 1 Rouge1 0.117
Client 0 communication overhead: uplink:1.18 GB, downlink:1.18 GB
Global Round 0 communication overhead: uplink=3.87 GB, downlink=3.87 GB


  0%|          | 0/126 [00:00<?, ?it/s]

ATTACK! Client 1 Epoch 0 Rouge1 0.144


In [None]:
# Sample from the model again after the simulation.
print(generate(text="To make cake", md=model))