## Step 1. 加载模型与Tokenizer

In [1]:
import os
import sys

sys.path.append(os.path.abspath('../..'))
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

# Model cache location. Override it with the SFL_CACHE_DIR environment variable
# instead of editing the notebook.
# NOTE: if huggingface.co is unreachable (e.g. the SSLError below), pre-download the
# weights into cache_dir or point HF_ENDPOINT at a reachable mirror.
cache_dir = os.environ.get('SFL_CACHE_DIR', '/root/autodl-tmp/sfl/models')
tokenizer = AutoTokenizer.from_pretrained("gpt2-large", cache_dir=cache_dir)
model = GPT2SplitLMHeadModel.from_pretrained("gpt2-large", cache_dir=cache_dir)
# GPT-2 has no dedicated pad token, so reuse EOS for padding. Assigning pad_token
# also updates pad_token_id, so no hard-coded id (50256) is needed; assert it instead.
tokenizer.pad_token = tokenizer.eos_token
assert tokenizer.pad_token_id == tokenizer.eos_token_id

SSLError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /gpt2-large/resolve/main/tokenizer_config.json (Caused by SSLError(SSLEOFError(8, '[SSL: UNEXPECTED_EOF_WHILE_READING] EOF occurred in violation of protocol (_ssl.c:1006)')))"), '(Request ID: 6edc862c-9a28-4deb-bdd0-f12553ed806a)')

In [None]:
def generate(text, md=model, max_length=100):
    """Generate text from a prompt with beam search and return the decoded result.

    Args:
        text: prompt string; tokenized without adding special tokens.
        md: model to generate with; defaults to the notebook-level ``model``.
        max_length: maximum total length passed to ``md.generate`` (default 100).

    Returns:
        The first returned sequence decoded with special tokens stripped.

    Side effect: leaves ``md`` in eval mode.
    """
    # Switch the model that is actually used to eval mode. The previous code switched
    # the global ``model``, which is wrong whenever a different ``md`` is passed.
    md.train(False)
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    output_ids = md.generate(encoded['input_ids'].to(md.device),
                             attention_mask=encoded['attention_mask'].to(md.device),
                             max_length=max_length, num_beams=6, no_repeat_ngram_size=2,
                             early_stopping=True, num_return_sequences=1,
                             pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def get_output(text, md=model):
    """Run a single forward pass and decode the argmax token predicted at every position.

    Args:
        text: input string; tokenized without adding special tokens.
        md: model to run; defaults to the notebook-level ``model``.

    Returns:
        The decoded per-position argmax predictions, special tokens stripped.
    """
    # Local import: this cell executes before the cell that imports torch.
    import torch

    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Inference only: no need to build an autograd graph. Use ``md`` (not the global
    # ``model``) so the argument is actually honoured.
    with torch.no_grad():
        res = md(encoded['input_ids'].to(md.device),
                 attention_mask=encoded['attention_mask'].to(md.device))
    # Assumes HF-style logits of shape (batch, seq, vocab) — TODO confirm for the split model.
    # [-1] selects the last (here: only) sequence of the batch.
    return tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)


# print(generate("To mix food coloring with sugar, you can", model))

## Step 2. 加载攻击模型

In [None]:
import torch
from sfl.model.attack_model import GPT2AttackModel

# Checkpoint of the attack model on bottom -> trunk activations.
ATTACKER_B2TR_CKPT = ('/root/autodl-tmp/sfl/models/checkpoints/attacker/gpt2-large-valds/b2tr-2/'
                      'epoch_4_rouge_0.8206363408552778.pt')
# Checkpoint of the attack model on trunk -> top activations.
# NOTE(review): the directory is named "b2tr-30" although it is used for the trunk->top
# hop; presumably it was trained on activations at split layer 30 — confirm.
ATTACKER_TR2T_CKPT = ('/root/autodl-tmp/sfl/models/checkpoints/attacker/b2tr-30/'
                      'epoch_4_rouge_0.9119861382493881.pt')


def load_attacker(ckpt_path):
    """Build a GPT2AttackModel for ``model.config`` and load its weights from ``ckpt_path``.

    The checkpoint is first materialised on the CPU (``map_location='cpu'``), so files
    saved from a GPU can be restored on machines without one; ``load_state_dict`` then
    copies the tensors into the module's own parameters. The attack model is only
    used under ``torch.no_grad()`` for inference, so it is switched to eval mode.
    """
    attack_model = GPT2AttackModel(model.config)
    attack_model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    attack_model.eval()
    return attack_model


attacker = load_attacker(ATTACKER_B2TR_CKPT)    # attacks bottom -> trunk data
attacker2 = load_attacker(ATTACKER_TR2T_CKPT)   # attacks trunk -> top data


## Step 3. 设置联邦训练流程

In [None]:
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy
from sfl.simulator.dataset import PIQAFedDataset, GSM8KFedDataset, DialogSumFedDataset
from sfl.utils import FLConfig, calculate_rouge
from torch.optim import AdamW


# Local training strategy executed by every client.
class QAFLStrategy(FLStrategy):
    """Client-side fine-tuning loop that also measures attack success.

    During every forward pass, ``callback_fp_param`` feeds the intermediate tensors sent
    between split parts to the two attack models and records their ROUGE scores against
    the batch's input text. ``client_step`` reports the mean ROUGE-L F1 per local epoch.
    """

    def __init__(self):
        super().__init__()
        # ROUGE results of `attacker` on bottom->trunk tensors, for the current epoch.
        self.attacker_rouge_b2tr = []
        # ROUGE results of `attacker2` on trunk->top tensors, for the current epoch.
        self.attacker_rouge_tr2t = []

    def client_step(self, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Fine-tune ``llm`` on one client's data for ``cfg.client_epoch`` epochs (AdamW, lr 1e-5).

        ``fp_done`` / ``bp_done`` are called after each forward / backward pass so the
        strategy's hooks can collect intermediate results and gradients. At the end of
        each epoch the attackers' mean ROUGE-L is printed and the buffers are reset.
        """
        # generate() switches the shared model to eval mode; make sure training runs in train mode.
        llm.train()
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        with tqdm_notebook(total=cfg.client_epoch * len(dataloader)) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    # Padding uses the EOS id, so padded positions would otherwise be trained
                    # as real targets. Mask them with -100, the default ignore_index of
                    # PyTorch's cross-entropy loss.
                    labels = input_ids.masked_fill(attention_mask == 0, -100)
                    outputs = llm(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
                    self.fp_done(client_id, epoch, step, batch)  # Collect intermediate results
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    self.bp_done(client_id, epoch, step, batch)  # Collect gradients
                    optimizer.step()
                    pbar.update(1)
                self._report_attack_rouge(client_id, epoch)

    def _report_attack_rouge(self, client_id, epoch):
        """Print the mean ROUGE-L F1 of both attackers for one epoch, then clear the buffers.

        Empty buffers (no attack results recorded this epoch) are reported as ``n/a``
        instead of raising ZeroDivisionError.
        """
        for label, results in (('Bottom-trunk', self.attacker_rouge_b2tr),
                               ('Trunk-Top', self.attacker_rouge_tr2t)):
            if results:
                avg_rouge = sum(r['rouge-l']['f'] for r in results) / len(results)
                print(f'ATTACK! {label}, Client {client_id} Epoch {epoch} RougeL {avg_rouge:.3f}')
            else:
                print(f'ATTACK! {label}, Client {client_id} Epoch {epoch} RougeL n/a (no batches recorded)')
            results.clear()

    def callback_fp_param(self, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Forward-pass hook: score both attackers on the two tensors transmitted for this step.

        ``b2tr_params`` is what the bottom part sends to the trunk and ``tr2t_params`` is
        what the trunk sends to the top. Each attacker's output is compared with
        ``batch['input_text']`` via ``calculate_rouge``.
        """
        with torch.no_grad():
            rouge_res_b2tr = calculate_rouge(tokenizer, attacker(b2tr_params), batch['input_text'])
            rouge_res_tr2t = calculate_rouge(tokenizer, attacker2(tr2t_params), batch['input_text'])
            self.attacker_rouge_b2tr.append(rouge_res_b2tr)
            self.attacker_rouge_tr2t.append(rouge_res_tr2t)

    def callback_bp_param(self, client_id, local_epoch, local_step, t2tr_params, tr2b_params, batch):
        """Backward-pass hook for the two tensors transmitted during backprop; intentionally unused.

        Presumably ``t2tr_params`` is sent top -> trunk and ``tr2b_params`` trunk -> bottom.
        """
        pass


# Three simulated clients with ids '0', '1', '2'.
client_ids = [str(i) for i in range(3)]

# Split layout of gpt2-large's 36 transformer blocks (indices 0-35):
#   bottom = [0, 1] | trunk = [2, ..., 29] | top = [30, ..., 35]
SPLIT_POINT_1 = 2
SPLIT_POINT_2 = 30
assert 0 < SPLIT_POINT_1 < SPLIT_POINT_2 < model.config.n_layer, \
    'split points must leave at least one block in bottom, trunk and top'

config = FLConfig(global_round=50,
                  client_epoch=2,  # each client trains 2 local epochs per federated round
                  split_point_1=SPLIT_POINT_1,
                  split_point_2=SPLIT_POINT_2,
                  use_lora_at_trunk=True,  # use LoRA for the trunk part
                  # False: top and bottom keep the pretrained weights (True would train them from scratch)
                  top_and_bottom_from_scratch=False,
                  noise_scale=0,  # noise magnitude (0 here; presumably disables noise — confirm)
                  )
# shrink_frac=0.05 presumably keeps only 5% of DialogSum to make the simulation quick — confirm.
fed_dataset = DialogSumFedDataset(tokenizer=tokenizer, client_ids=client_ids, shrink_frac=0.05)
simulator = SFLSimulator(client_ids=client_ids, strategy=QAFLStrategy(), llm=model, tokenizer=tokenizer,
                         dataset=fed_dataset, config=config)
model.print_split_model()

## Step 4. 开始联邦模拟

In [None]:
# Run the federated split-learning simulation configured above; QAFLStrategy.client_step
# prints the attackers' mean ROUGE-L after every local epoch of every client.
simulator.simulate()

In [None]:
# Quick qualitative check of the model after the simulation, using beam-search generation.
generate('Question: what is 1+2')