## Step 1. 加载模型与Tokenizer

In [13]:
import os

import torch
from transformers import AutoTokenizer

from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

# Cache directory for the downloaded GPT-2 weights/tokenizer.
# Set the SFL_MODEL_CACHE environment variable to override it instead of
# editing this machine-specific default path.
cache_dir = os.environ.get('SFL_MODEL_CACHE', '/root/autodl-tmp/sfl/models')
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_dir)
model = GPT2SplitLMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
# Reuse EOS as the padding token (GPT-2 does not define one). Derive the id
# from the tokenizer instead of hard-coding it (it was 50256 for GPT-2).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id


KeyboardInterrupt



In [2]:
# Test the model output: beam-search generation from a short prompt.
# Inputs are moved to model.device so this cell keeps working if it is re-run
# after the model has been moved to another device (the final cell feeds
# inputs on simulator.device).
t = tokenizer("Good Evening", return_tensors="pt", add_special_tokens=False)
res = model.generate(
    t['input_ids'].to(model.device),
    attention_mask=t['attention_mask'].to(model.device),
    max_length=100,
    num_beams=5,
    no_repeat_ngram_size=2,
    early_stopping=True,
    num_return_sequences=1,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(res[0], skip_special_tokens=True))

Good Evening."

"I'm sorry to hear that," she said. "I was just wondering if you'd like to join us for dinner. I'm sure you'll be able to help us out with some of the things you've been working on, but I don't know if I can help you with anything else. It's been a long time since I've had a chance to talk to you, so I'll just have to let you know when we're back in town."


## Step 2. 设置联邦训练流程

In [3]:
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset, FedDataset
from sfl.utils import FLConfig
from torch.optim import AdamW


# 定义Client本地学习策略
class QAFLStrategy(FLStrategy):
    """Client-local training strategy for split federated fine-tuning on QA data.

    ``client_step`` runs ``cfg.client_epoch`` epochs of AdamW updates over the
    client's dataloader and notifies the base strategy after each forward pass
    (``fp_done``) and each backward pass (``bp_done``).
    """

    # Learning rate of the per-client AdamW optimizer (previously hard-coded
    # inside client_step); override on the class or an instance to tune it.
    client_lr: float = 1e-5

    def client_step(self, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Train ``llm`` in place on one client's local data.

        Args:
            client_id: Identifier of the client being trained.
            llm: Split model to update; batches are moved to ``llm.device``.
            dataloader: Yields dict batches with 'input_ids', 'output_ids'
                (used as labels) and 'input_att_mask' tensors.
            cfg: Federated config; only ``cfg.client_epoch`` is read here.
        """
        # Ensure training mode: models loaded with from_pretrained start in eval
        # mode, and a later cell calls model.train(False), so without this a
        # (re-)run could train with dropout silently disabled.
        llm.train()
        # Only hand trainable tensors (e.g. LoRA adapters) to the optimizer;
        # frozen parameters never receive gradients anyway.
        trainable_params = [p for p in llm.parameters() if p.requires_grad]
        optimizer = AdamW(trainable_params, lr=self.client_lr)
        total_steps = cfg.client_epoch * len(dataloader)
        with tqdm_notebook(total=total_steps) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    labels = batch['output_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    outputs = llm(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
                    self.fp_done(client_id, epoch, step)  # Collect intermediate results
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    self.bp_done(client_id, epoch, step)  # Collect gradients
                    optimizer.step()
                    pbar.update(1)

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params):
        """Hook receiving the two values transmitted during a forward pass.

        Called for a given client epoch/step with ``b2tr_params``
        (bottom -> trunk) and ``tr2t_params`` (trunk -> top). No-op here.
        """
        pass

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params):
        """Hook receiving the two values transmitted during a backward pass.

        Called for a given client epoch/step; by analogy with the forward hook,
        ``t2tr_params`` is presumably top -> trunk and ``tr2b_params``
        trunk -> bottom. No-op here.
        """
        pass


# Three simulated clients, identified by the strings "0", "1" and "2".
client_ids = list(map(str, range(3)))
# 10 global rounds, 2 local epochs per client, split points 2 and 10.
# use_lora_at_trunk=True attaches LoRA adapters to the trunk segment
# (the print_split_model output below shows LoRA weights on transformer.h.2).
config = FLConfig(
    global_round=10,
    client_epoch=2,
    split_point_1=2,
    split_point_2=10,
    use_lora_at_trunk=True,
)
# Per-client PIQA data partitions, tokenized with the GPT-2 tokenizer.
fed_dataset = PIQAFedDataset(tokenizer=tokenizer, client_ids=client_ids)
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=QAFLStrategy(),
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)



In [4]:
# Inspect the split model's reported parameters to check the split configuration.
# The output below lists full weights for transformer.h.10/h.11 and ln_f, and
# LoRA adapter weights (lora_A / lora_B) for transformer.h.2.
model.print_split_model()


transformer.h.10:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.h.11:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.ln_f.weight:[: (768,)]

transformer.ln_f.bias:[: (768,)]

transformer.h.2:[attn.c_attn.lora_A.default.weight: (8, 768), attn.c_attn.lora_B.default.weight: (2304, 8), attn.c_proj.lora_A.default.weight: (8, 768), attn.c_proj.lora_B.default.weight: (768, 8), mlp.c_fc.lora_A.default.weight: (8, 768), mlp.c_fc.lora_B.default.weight: (

## Step 3. 开始联邦模拟

In [5]:
# Run the federated split-learning simulation (config.global_round rounds).
# It prints per-client and per-round uplink/downlink communication overhead.
simulator.simulate()



  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168
Global Round 0 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168
Global Round 1 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696
Global Round 2 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168
Global Round 3 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168
Global Round 4 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696
Global Round 5 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168
Global Round 6 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536
Global Round 7 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536
Global Round 8 communication overhead: uplink=943718400, downlink=943718400


  0%|          | 0/112 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:352321536, downlink:352321536


  0%|          | 0/82 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:257949696, downlink:257949696


  0%|          | 0/106 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:333447168, downlink:333447168
Global Round 9 communication overhead: uplink=943718400, downlink=943718400
FL communication overhead: uplink=9437184000, downlink=9437184000


In [None]:
# Quick qualitative check of the fine-tuned model on a PIQA-style prompt.
# This is a single teacher-forced forward pass: assuming logits are
# (batch, seq, vocab), the decoded text is the greedy next-token prediction at
# each prompt position, not a generated continuation (use model.generate for that).
# NOTE(review): the trailing space in the prompt likely becomes its own GPT-2
# token, which tends to degrade predictions; consider removing it.
model.eval()
t = tokenizer("Hang kitchen knives against ", return_tensors="pt", add_special_tokens=False)
# Inference only: don't build an autograd graph.
with torch.no_grad():
    res = model(t['input_ids'].to(simulator.device), attention_mask=t['attention_mask'].to(simulator.device))
r = tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)
print(r)