## Step 1. 加载模型与Tokenizer

In [1]:
import os
import sys

sys.path.append(os.path.abspath('../..'))
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

# Directory where the pretrained weights are cached. Set the SFL_MODEL_CACHE
# environment variable to point somewhere else; the fallback is the original
# machine-specific location and will need changing on other hosts.
cache_dir = os.environ.get('SFL_MODEL_CACHE', '/root/autodl-tmp/sfl/models')
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_dir)
model = GPT2SplitLMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
# Reuse the EOS token as the padding token. The id is taken from the tokenizer
# itself instead of being hard-coded, so it stays consistent with `pad_token`.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [2]:
# Generate text with the model as a quick sanity check
def generate(text, md=model):
    """Generate a continuation of ``text`` with beam search and return it as a string.

    Args:
        text: Prompt passed to the module-level ``tokenizer``.
        md: Model used for generation; defaults to the module-level ``model``.

    Returns:
        The first returned sequence, decoded with special tokens skipped.
    """
    # Put the model that is actually used (`md`, not the global `model`) into
    # eval mode, so a caller-supplied model is the one affected.
    md.train(False)
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    res = md.generate(encoded['input_ids'].to(md.device),
                      attention_mask=encoded['attention_mask'].to(md.device),
                      max_length=300, num_beams=6, no_repeat_ngram_size=2, early_stopping=True,
                      num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(res[0], skip_special_tokens=True)

# Inspect the model's raw forward-pass output
def get_output(text, md=model):
    """Run one forward pass on ``text`` and decode the arg-max token at each position.

    Args:
        text: Input passed to the module-level ``tokenizer``.
        md: Model to run; defaults to the module-level ``model``.

    Returns:
        The decoded arg-max token ids of the last row of the logits, with
        special tokens skipped.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Call `md` (not the global `model`) so the model that runs is the one whose
    # device the inputs were moved to.
    res = md(encoded['input_ids'].to(md.device), attention_mask=encoded['attention_mask'].to(md.device))
    # argmax over the last axis of the logits; [-1] selects the last row of the result.
    return tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)


# Smoke test: generate from the freshly loaded model and print the result.
print(generate("To mix food coloring with sugar, you can", model))

To mix food coloring with sugar, you can also use it as a sweetener.

If you want to add more sugar to the mix, add a little more water and mix well. If you add too much water, the mixture will be too thick, and you will end up with a mess. You can use a spoon to scoop out the excess water from the mixing bowl, but it's best to leave it at room temperature for at least 30 minutes before adding the rest of the water.


## Step 2. 加载攻击模型

In [3]:
import torch
from sfl.model.attack_model import GPT2AttackModel

# Attack model for the bottom -> trunk transmission.
# map_location='cpu' lets the checkpoint be read even if it was saved from a
# device that is not available here; load_state_dict then copies the values
# into the attacker's own parameters.
# NOTE(review): these checkpoints are unpickled by torch.load — only load files you trust.
# NOTE(review): the attackers are never switched to eval mode in this notebook;
# confirm whether GPT2AttackModel has dropout-like layers that would need it.
attacker = GPT2AttackModel(model.config)
attacker.load_state_dict(
    torch.load('/root/autodl-tmp/sfl/models/checkpoints/attacker/epoch_4_rouge_0.9752261900524914.pt',
               map_location='cpu'))

# Attack model for the trunk -> top transmission.
# NOTE(review): the checkpoint directory is named "b2tr-9" although this attacker
# is applied to trunk -> top data — confirm this is the intended checkpoint.
attacker2 = GPT2AttackModel(model.config)
attacker2.load_state_dict(
    torch.load('/root/autodl-tmp/sfl/models/checkpoints/attacker/b2tr-9/epoch_4_rouge_0.9400152550325381.pt',
               map_location='cpu'))

<All keys matched successfully>

## Step 3. 设置联邦训练流程

In [4]:
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset, FedDataset
from sfl.utils import FLConfig, calculate_rouge
from torch.optim import AdamW


# Client-side local training strategy
class QAFLStrategy(FLStrategy):
    """Local training strategy that also scores the attack models on transmitted data.

    ``callback_fp_param`` runs the two module-level attack models on the
    bottom->trunk and trunk->top forward transmissions and buffers the
    ``calculate_rouge`` results; ``client_step`` prints the mean ROUGE-L F
    value of each buffer at the end of every local epoch and then clears them.
    """

    def __init__(self):
        super().__init__()
        # Buffers of calculate_rouge results, appended to by callback_fp_param
        # and cleared at the end of each local epoch.
        self.attacker_rouge_b2tr = []
        self.attacker_rouge_tr2t = []

    @staticmethod
    def _mean_rouge_l(rouge_results):
        """Return the mean of the ``['rouge-l']['f']`` entries, or NaN if the list is empty."""
        if not rouge_results:
            # Nothing was collected (e.g. empty dataloader or the forward
            # callback never fired): report NaN instead of dividing by zero.
            return float('nan')
        return sum(r['rouge-l']['f'] for r in rouge_results) / len(rouge_results)

    def client_step(self, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Train ``llm`` on ``dataloader`` for ``cfg.client_epoch`` epochs and report attack scores.

        Args:
            client_id: Identifier of the client being trained (used in log output).
            llm: Split model to optimise; a fresh AdamW optimizer is created per call.
            dataloader: Yields batches with ``'input_ids'`` and ``'input_att_mask'`` entries.
            cfg: Federated-learning configuration; ``client_epoch`` is read here.
        """
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        with tqdm_notebook(total=cfg.client_epoch * len(dataloader)) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    # Language-modelling objective: the inputs double as the labels.
                    outputs = llm(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
                    self.fp_done(client_id, epoch, step, batch)  # Collect intermediate results
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    self.bp_done(client_id, epoch, step, batch)  # Collect gradients
                    optimizer.step()
                    pbar.update(1)
                # End of epoch: report the mean attack score, then reset the buffers.
                avg_rouge = self._mean_rouge_l(self.attacker_rouge_b2tr)
                print(f'ATTACK! Bottom-trunk, Client {client_id} Epoch {epoch} RougeL {avg_rouge:.3f}')
                avg_rouge = self._mean_rouge_l(self.attacker_rouge_tr2t)
                print(f'ATTACK! Trunk-Top, Client {client_id} Epoch {epoch} RougeL {avg_rouge:.3f}')
                self.attacker_rouge_b2tr.clear()
                self.attacker_rouge_tr2t.clear()

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Score both attackers on the forward-pass transmissions of one step.

        ``b2tr_params`` is the bottom->trunk transmission and ``tr2t_params`` the
        trunk->top transmission; each attacker output is compared with
        ``batch['input_text']`` via ``calculate_rouge``.
        """
        with torch.no_grad():
            rouge_res_b2tr = calculate_rouge(tokenizer, attacker(b2tr_params), batch['input_text'])
            rouge_res_tr2t = calculate_rouge(tokenizer, attacker2(tr2t_params), batch['input_text'])
            self.attacker_rouge_b2tr.append(rouge_res_b2tr)
            self.attacker_rouge_tr2t.append(rouge_res_tr2t)

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params, batch):
        """Receive the two backward-pass transmissions of one step; intentionally a no-op."""
        pass


# Three simulated clients with ids "0", "1" and "2".
client_ids = list(map(str, range(3)))

# Federated-learning configuration for this run.
config = FLConfig(
    global_round=10,
    client_epoch=2,
    split_point_1=2,
    split_point_2=9,
    use_lora_at_trunk=True,
)

# PIQA data distributed over the clients, built with shrink_frac=0.15.
fed_dataset = PIQAFedDataset(tokenizer=tokenizer, client_ids=client_ids, shrink_frac=0.15)

# Wire the model, dataset and client strategy into the simulator.
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=QAFLStrategy(),
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)

# Print how the model's layers are split.
model.print_split_model()




transformer.h.9:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.h.10:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.h.11:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.ln_f

## Step 4. 开始联邦模拟

In [5]:
# Run the federated simulation with the simulator built in the previous cell.
simulator.simulate()



  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.909
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.832
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.912
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.803
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.954
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.876
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.955
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.838
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.938
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.853
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.941
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.823
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB
Global Round 0 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.944
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.804
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.934
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.820
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.910
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.777
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.900
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.784
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.951
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.848
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.949
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.868
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB
Global Round 1 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.952
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.871
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.946
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.871
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.931
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.841
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.931
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.837
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.906
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.824
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.907
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.827
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB
Global Round 2 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.910
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.827
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.906
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.827
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.935
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.847
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.932
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.852
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.948
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.879
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.950
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.881
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB
Global Round 3 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.899
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.830
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.903
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.828
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.949
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.878
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.947
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.871
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.925
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.847
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.931
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.853
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB
Global Round 4 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.931
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.848
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.927
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.853
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.950
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.879
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.950
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.880
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.907
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.831
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.897
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.836
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB
Global Round 5 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.895
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.831
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.893
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.825
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.930
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.849
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.922
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.846
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.947
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.868
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.946
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.869
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB
Global Round 6 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.906
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.826
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.903
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.818
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.951
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.872
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.950
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.863
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.922
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.845
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.924
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.842
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB
Global Round 7 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.922
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.843
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.916
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.841
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.902
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.827
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.902
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.827
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.945
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.869
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.943
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.874
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB
Global Round 8 communication overhead: uplink=3.22 GB, downlink=3.22 GB


  0%|          | 0/112 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 1 Epoch 0 RougeL 0.944
ATTACK! Trunk-Top, Client 1 Epoch 0 RougeL 0.867
ATTACK! Bottom-trunk, Client 1 Epoch 1 RougeL 0.950
ATTACK! Trunk-Top, Client 1 Epoch 1 RougeL 0.869
Client 1 communication overhead: uplink:1.30 GB, downlink:1.30 GB


  0%|          | 0/72 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 2 Epoch 0 RougeL 0.897
ATTACK! Trunk-Top, Client 2 Epoch 0 RougeL 0.819
ATTACK! Bottom-trunk, Client 2 Epoch 1 RougeL 0.899
ATTACK! Trunk-Top, Client 2 Epoch 1 RougeL 0.825
Client 2 communication overhead: uplink:852.00 MB, downlink:852.00 MB


  0%|          | 0/94 [00:00<?, ?it/s]

ATTACK! Bottom-trunk, Client 0 Epoch 0 RougeL 0.926
ATTACK! Trunk-Top, Client 0 Epoch 0 RougeL 0.845
ATTACK! Bottom-trunk, Client 0 Epoch 1 RougeL 0.921
ATTACK! Trunk-Top, Client 0 Epoch 1 RougeL 0.843
Client 0 communication overhead: uplink:1.09 GB, downlink:1.09 GB
Global Round 9 communication overhead: uplink=3.22 GB, downlink=3.22 GB
FL communication overhead: uplink=32.23 GB, downlink=32.23 GB


In [11]:
# Generate from `model` for a new prompt and print the result (this cell is placed after the simulation).
print(generate("How to finish my thesis.", model))

How to finish my thesis. Solution:Gather all the materials needed for the project.
