In [None]:
import os
import sys

sys.path.append(os.path.abspath('../..'))
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

cache_dir = '/root/autodl-tmp/sfl/models'  # Model cache location; change this for your environment
tokenizer = AutoTokenizer.from_pretrained("gpt2-large", cache_dir=cache_dir)
model = GPT2SplitLMHeadModel.from_pretrained("gpt2-large", cache_dir=cache_dir)
# Use the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
# Derive the pad id from the tokenizer instead of hard-coding 50256, so the
# setup stays correct if a checkpoint with a different vocabulary is loaded.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Sanity-check the model's text generation
def generate(text, md=model):
    """Generate a continuation of ``text`` with beam search and return it as a string.

    Args:
        text: Prompt passed to the module-level ``tokenizer``.
        md: Model used for generation. Defaults to the module-level ``model``
            that exists when this function is defined.

    Returns:
        The first generated sequence, decoded with special tokens removed.

    Side effects:
        Leaves ``md`` in evaluation mode.
    """
    # Switch ``md`` itself (not the global ``model``) to eval mode so that a
    # non-default model passed by the caller is handled correctly.
    md.train(False)
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    res = md.generate(encoded['input_ids'].to(md.device),
                      attention_mask=encoded['attention_mask'].to(md.device),
                      max_length=300, num_beams=6, no_repeat_ngram_size=2, early_stopping=True,
                      num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(res[0], skip_special_tokens=True)

# Inspect the model's forward-pass output
def get_output(text, md=model):
    """Run a single forward pass of ``md`` on ``text`` and decode the argmax tokens.

    Args:
        text: Input passed to the module-level ``tokenizer``.
        md: Model to run. Defaults to the module-level ``model`` that exists
            when this function is defined.

    Returns:
        The decoded string of argmax token ids (special tokens removed) for the
        last entry along the first logits dimension.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Call ``md`` (not the global ``model``): the inputs are moved to
    # ``md.device``, so the forward pass must run on that same model.
    res = md(encoded['input_ids'].to(md.device), attention_mask=encoded['attention_mask'].to(md.device))
    # argmax over the last dimension, then take the last row of the first
    # dimension — presumably the last sequence in the batch; confirm logits layout.
    predicted_ids = res.logits.argmax(dim=-1)[-1]
    return tokenizer.decode(predicted_ids, skip_special_tokens=True)


# Smoke test: print the generated continuation for a short prompt.
print(generate("Hi father", model))

In [None]:
# Display the model object's repr as this cell's output.
model

## Step 2. 设置联邦训练流程

In [None]:
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset, FedDataset
from sfl.utils import FLConfig
from torch.optim import AdamW


# Client-side local training strategy
class QAFLStrategy(FLStrategy):
    """Local training strategy that fine-tunes the split model on each client's batches.

    The language-modeling loss is computed with the input ids used as labels.
    """

    def client_step(self, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Run ``cfg.client_epoch`` local epochs of training for one client.

        Args:
            client_id: Identifier of the client; passed to the hooks and shown
                in the progress bar.
            llm: Model to train; a fresh AdamW optimizer over its parameters is
                created on every call.
            dataloader: Yields batches containing ``'input_ids'`` and
                ``'input_att_mask'`` tensors.
            cfg: Only ``cfg.client_epoch`` is read here.
        """
        # The model may have been left in eval mode (Hugging Face
        # ``from_pretrained`` returns eval-mode models), so enable training
        # mode explicitly before optimizing.
        llm.train()
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        with tqdm_notebook(total=cfg.client_epoch * len(dataloader)) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    # NOTE(review): labels include padded positions; confirm whether
                    # the model or dataset already excludes them from the loss.
                    outputs = llm(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
                    self.fp_done(client_id, epoch, step, batch)  # Collect intermediate results
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    self.bp_done(client_id, epoch, step, batch)  # Collect gradients
                    # Debug aid: decode and print the prediction for the last sample in the batch.
                    # res_text = tokenizer.decode(outputs.logits.argmax(dim=-1)[-1], skip_special_tokens=True)
                    # print(batch['input_text'][-1],"==>",res_text.strip(),"】")
                    optimizer.step()
                    pbar.update(1)

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Hook for the forward pass of a given local epoch/step (no-op here).

        Receives the two sets of parameters transmitted during the forward pass:
        ``b2tr_params`` (bottom -> trunk) and ``tr2t_params`` (trunk -> top).
        """
        pass

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params, batch):
        """Hook for the backward pass of a given local epoch/step (no-op here).

        Receives the two sets of parameters transmitted during the backward pass:
        ``t2tr_params`` and ``tr2b_params`` (presumably top -> trunk and
        trunk -> bottom; confirm against FLStrategy).
        """
        pass


# Three simulated clients with string ids "0", "1" and "2".
client_ids = list(map(str, range(3)))
# Federated-learning settings: round/epoch counts, the two split points and
# the LoRA-at-trunk switch.
config = FLConfig(
    global_round=10,
    client_epoch=2,
    split_point_1=2,
    split_point_2=34,
    use_lora_at_trunk=True,
)
fed_dataset = PIQAFedDataset(tokenizer=tokenizer, client_ids=client_ids)
# Wire the clients, training strategy, model, tokenizer, dataset and config together.
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=QAFLStrategy(),
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)

In [None]:
# Presumably prints how the model is partitioned after the simulator was
# configured — confirm against GPT2SplitLMHeadModel.print_split_model.
model.print_split_model()

## Step 3. 开始联邦模拟

In [None]:
# Run the simulation with the simulator built in the Step 2 cell.
simulator.simulate()

In [None]:
# Print a generated continuation for a new prompt; this cell comes after the simulation cell.
print(generate("To make paper out of woods", model))

In [None]:
# NOTE(review): neither `np` nor `mat` is imported or defined in any cell
# visible in this notebook, so this cell would fail on a fresh kernel unless
# they are provided elsewhere — confirm (presumably `import numpy as np` and a
# square matrix `mat`).
# Raise `mat` to the 10000th power and display the result.
mat_k = np.linalg.matrix_power(mat, 10000)
mat_k

In [28]:
# Print a generated continuation for a new prompt; this cell comes after the simulation cell.
print(generate("To make paper out of woods", model))

To make paper out of woods, Solution:roll sheets of magazines up into a tube and glue it to a board.


[[0.25, 0.5, 0, 0, 0, 0.25],
 [0, 0, 0, 0, 0, 1],
 [0, 0.25, 0, 0.25, 0.5, 0],
 [0, 0, 0, 0, 1, 0],
 [0, 0, 0, 0.5, 0.5, 0],
 [0, 0, 1, 0, 0, 0]]

In [9]:
# NOTE(review): neither `np` nor `mat` is imported or defined in any cell
# visible in this notebook, so this cell would fail on a fresh kernel unless
# they are provided elsewhere — confirm (presumably `import numpy as np` and a
# square matrix `mat`).
# Raise `mat` to the 10000th power and display the result.
mat_k = np.linalg.matrix_power(mat, 10000)
mat_k

array([[0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ],
       [0.        , 0.        , 0.        , 0.33333333, 0.66666667,
        0.        ]])