In [1]:
import os
import sys

sys.path.append(os.path.abspath('../..'))
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

# Cache location for the pretrained weights. Set the SFL_CACHE_DIR environment
# variable to point somewhere else; the fallback is the path that used to be
# hard-coded here, so existing setups keep working unchanged.
cache_dir = os.environ.get('SFL_CACHE_DIR', '/root/autodl-tmp/sfl/models')
tokenizer = AutoTokenizer.from_pretrained("gpt2-large", cache_dir=cache_dir)
model = GPT2SplitLMHeadModel.from_pretrained("gpt2-large", cache_dir=cache_dir)
# Reuse the end-of-sequence token for padding. The id is taken from the tokenizer
# instead of the literal 50256 so the two settings cannot drift apart.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [2]:
# Try out the model's text generation
def generate(text, md=model):
    """Generate text from a prompt with beam search and return it as a string.

    Args:
        text: Prompt passed to the notebook-level ``tokenizer``.
        md: Model to generate with; defaults to the notebook-level ``model``.

    Returns:
        The first returned sequence, decoded with special tokens skipped. In the
        outputs saved in this notebook the decoded text starts with the prompt.
    """
    # Put the model that is actually used into inference mode. The previous
    # version always switched the global `model`, ignoring the `md` argument.
    md.eval()
    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    res = md.generate(t['input_ids'].to(md.device), attention_mask=t['attention_mask'].to(md.device),
                      max_length=300, num_beams=6, no_repeat_ngram_size=2, early_stopping=True,
                      num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(res[0], skip_special_tokens=True)

# Inspect the model's raw forward-pass output
def get_output(text, md=model):
    """Run a single forward pass on ``text`` and decode the arg-max predictions.

    Args:
        text: Input passed to the notebook-level ``tokenizer``.
        md: Model to run; defaults to the notebook-level ``model``.

    Returns:
        The decoded arg-max token ids of the last row of the logits, with
        special tokens skipped.
    """
    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Call `md`, not the global `model`: the inputs are moved to `md.device`, so
    # running a different model here ignored the argument and could mix devices.
    res = md(t['input_ids'].to(md.device), attention_mask=t['attention_mask'].to(md.device))
    # assumes logits are (batch, seq, vocab), so [-1] picks the last sequence — TODO confirm
    return tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)


# Smoke test: generate from a short prompt with the model loaded above.
print(generate("Hi father", model))

Hi father,

I'm sorry to hear about your son's death. I'm so sorry for your loss. My heart goes out to you and your family. Please know that I will do everything in my power to help you through this difficult time. You have my deepest sympathy and I hope that you will be able to find peace and comfort in the coming days and weeks. Thank you for everything you have done for me and my family over the years. God bless you.


In [3]:
# Show the model; as the last expression of the cell, its repr is rendered as output.
model

GPT2SplitLMHeadModel(
  (transformer): GPT2SplitModel(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

## Step 2. 设置联邦训练流程

In [4]:
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset, FedDataset
from sfl.utils import FLConfig
from torch.optim import AdamW


# Local training strategy executed on each client
class QAFLStrategy(FLStrategy):
    """Client-side strategy that trains the split model with a language-modelling loss.

    The input ids double as labels, and the inherited ``fp_done`` / ``bp_done``
    hooks are invoked after every forward and backward pass respectively.
    """

    def client_step(self, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Run ``cfg.client_epoch`` local epochs over ``dataloader`` for one client.

        Args:
            client_id: Identifier of the client; used in the progress-bar text and
                forwarded to the hooks.
            llm: Model to optimise; a fresh AdamW optimiser (lr=1e-5) is created
                over ``llm.parameters()`` on every call.
            dataloader: Yields batches providing ``'input_ids'`` and ``'input_att_mask'``.
            cfg: Configuration; only ``client_epoch`` is read here.
        """
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        total_steps = cfg.client_epoch * len(dataloader)
        with tqdm_notebook(total=total_steps) as progress:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    token_ids = batch['input_ids'].to(llm.device)
                    mask = batch['input_att_mask'].to(llm.device)
                    # The inputs are also passed as labels.
                    outputs = llm(input_ids=token_ids, labels=token_ids, attention_mask=mask)
                    # Forward pass finished: hand over to the intermediate-result hook.
                    self.fp_done(client_id, epoch, step, batch)
                    loss = outputs.loss
                    progress.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    # Backward pass finished: hand over to the gradient hook.
                    self.bp_done(client_id, epoch, step, batch)
                    # Debug aid (disabled): decode the arg-max prediction of the last sample.
                    # res_text = tokenizer.decode(outputs.logits.argmax(dim=-1)[-1], skip_special_tokens=True)
                    # print(batch['input_text'][-1],"==>",res_text.strip(),"】")
                    optimizer.step()
                    progress.update(1)

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Receive the two forward-pass transfers of a given local epoch and step.

        ``b2tr_params`` is the bottom-to-trunk transfer and ``tr2t_params`` the
        trunk-to-top transfer. This strategy deliberately does nothing with them.
        """
        pass

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params, batch):
        """Receive the two backward-pass transfers of a given local epoch and step.

        This strategy deliberately does nothing with them.
        """
        pass


# Three simulated clients, identified by the strings '0', '1' and '2'.
client_ids = list(map(str, range(3)))
# 10 global rounds with 2 local epochs per client, split points 2 and 34,
# and LoRA enabled on the trunk.
config = FLConfig(
    global_round=10,
    client_epoch=2,
    split_point_1=2,
    split_point_2=34,
    use_lora_at_trunk=True,
)
fed_dataset = PIQAFedDataset(tokenizer=tokenizer, client_ids=client_ids)
# Wire clients, strategy, model, tokenizer, dataset and configuration together.
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=QAFLStrategy(),
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)



In [5]:
# Print the split model's parameters; the saved output lists parameter names and shapes per layer.
model.print_split_model()


transformer.h.34:[ln_1.weight: (1280,), ln_1.bias: (1280,), attn.c_attn.weight: (1280, 3840), attn.c_attn.bias: (3840,), attn.c_proj.weight: (1280, 1280), attn.c_proj.bias: (1280,), ln_2.weight: (1280,), ln_2.bias: (1280,), mlp.c_fc.weight: (1280, 5120), mlp.c_fc.bias: (5120,), mlp.c_proj.weight: (5120, 1280), mlp.c_proj.bias: (1280,)]

transformer.h.35:[ln_1.weight: (1280,), ln_1.bias: (1280,), attn.c_attn.weight: (1280, 3840), attn.c_attn.bias: (3840,), attn.c_proj.weight: (1280, 1280), attn.c_proj.bias: (1280,), ln_2.weight: (1280,), ln_2.bias: (1280,), mlp.c_fc.weight: (1280, 5120), mlp.c_fc.bias: (5120,), mlp.c_proj.weight: (5120, 1280), mlp.c_proj.bias: (1280,)]

transformer.ln_f.weight:[: (1280,)]

transformer.ln_f.bias:[: (1280,)]

transformer.h.2:[attn.c_attn.lora_A.default.weight: (8, 1280), attn.c_attn.lora_B.default.weight: (3840, 8), attn.c_proj.lora_A.default.weight: (8, 1280), attn.c_proj.lora_B.default.weight: (1280, 8), mlp.c_fc.lora_A.default.weight: (8, 1280), mlp.c

## Step 3. 开始联邦模拟

In [6]:
# Run the simulation; the saved output reports communication overhead per client and per global round.
simulator.simulate()



  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB
Global Round 0 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB
Global Round 1 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB
Global Round 2 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB
Global Round 3 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB
Global Round 4 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB
Global Round 5 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB
Global Round 6 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB
Global Round 7 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB
Global Round 8 communication overhead: uplink=10.76 GB, downlink=10.76 GB


  0%|          | 0/256 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:2.50 GB, downlink:2.50 GB


  0%|          | 0/384 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:3.75 GB, downlink:3.75 GB


  0%|          | 0/462 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:4.51 GB, downlink:4.51 GB
Global Round 9 communication overhead: uplink=10.76 GB, downlink=10.76 GB
FL communication overhead: uplink=107.62 GB, downlink=107.62 GB


In [28]:
# Generate from a sample prompt; this cell is placed after the simulation cell.
# NOTE(review): the execution count (In [28]) is out of sequence — confirm the
# output still reproduces after Restart Kernel -> Run All.
print(generate("To make paper out of woods", model))

To make paper out of woods, Solution:roll sheets of magazines up into a tube and glue it to a board.


[[0.25, 0.5, 0, 0, 0, 0.25],
 [0, 0, 0, 0, 0, 1],
 [0, 0.25, 0, 0.25, 0.5, 0],
 [0, 0, 0, 0, 1, 0],
 [0, 0, 0, 0.5, 0.5, 0],
 [0, 0, 1, 0, 0, 0]]

In [9]:
# Raise `mat` to the 10000th matrix power and display it; in the saved output
# every row of the result is identical.
# NOTE(review): neither `np` nor `mat` is defined in the cells visible here, and
# the execution count (In [9]) is out of sequence with the preceding cell —
# confirm both names exist when running top to bottom on a fresh kernel.
mat_k = np.linalg.matrix_power(mat, 10000)
mat_k

array([[0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ]])