## Step 1. 加载模型与Tokenizer

In [1]:

import os
import sys

import torch

sys.path.append(os.path.abspath('../..'))
from sfl.utils.training import get_best_gpu
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel
from sfl import config

# Seed early: generation and the mocker attack further down both draw random numbers.
SEED = 42
torch.manual_seed(SEED)

# Single source of truth for the checkpoint location (was duplicated per call).
MODEL_DIR = os.path.join(config.model_download_dir, "gpt2-large/")

device = get_best_gpu()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = GPT2SplitLMHeadModel.from_pretrained(MODEL_DIR)
# Reuse EOS as the padding token. Assigning pad_token also determines pad_token_id,
# so the id is derived from the token rather than hard-coded as a magic number.
tokenizer.pad_token = tokenizer.eos_token
assert tokenizer.pad_token_id == tokenizer.eos_token_id
model.to(device)

GPT2SplitLMHeadModel(
  (transformer): GPT2SplitModel(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [2]:
from sfl.utils.model import generate


# 测试模型输出
def get_output(text, md=model):
    """Decode the greedy (argmax) prediction of ``md`` at every position of ``text``.

    A single teacher-forced forward pass is run; the argmax token at each position is
    decoded for the last sequence in the batch (``[-1]``), which for a single string is
    the only sequence.

    Args:
        text: prompt string (or anything ``tokenizer`` accepts).
        md: model to query; defaults to the notebook's ``model``.

    Returns:
        The decoded string of per-position argmax tokens.
    """
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Use the model that was passed in (previously the global `model` was always used).
        res = md(enc['input_ids'].to(md.device), attention_mask=enc['attention_mask'].to(md.device))
    # presumably logits are (batch, seq, vocab) -> argmax gives (batch, seq); take last row.
    return tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)


# Quick sanity check that the split model still generates fluent continuations.
demo_prompt = "To mix food coloring with sugar, you can"
print(generate(demo_prompt, tokenizer, model))

To mix food coloring with sugar, you can use the following:

1 1/2 cups powdered sugar (or 1 cup granulated sugar plus 1 teaspoon of cornstarch mixed with 3/4 cup water)


2 tablespoons corn starch (also known as corn syrup) or other sweetener (such as xylitol or stevia, or a combination of the two, such as erythritol and sorbitol) (see below for more information on sweeteners


In [3]:
from sfl.config import attacker_path
from sfl.utils.training import get_attacker_class, extract_attacker_path

# 加载攻击模型
# attacker_cls = get_attacker_class('gru')
# attacker_path_1, attacker_path_2 = extract_attacker_path(
#     {'split_point_1': 2, 'split_point_2': 30, 'attacker_path': attacker_path, 'model_name': 'gpt2-large','attacker_dataset':'piqa','attacker_train_label':'test','attacker_train_frac':1.0,'attack_model':'gru','attacker_prefix':'normal'})
# attacker2 = attacker_cls.from_pretrained(attacker_path_2)

## Step 2. 初始化攻击模型

## Step 3. 设置联邦训练流程

In [7]:


from sfl.utils.training import calc_shifted_loss_logits
from sfl.utils.model import calculate_rouge
from sfl.model.mocker import GPT2TopMocker
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy

from torch.optim import AdamW
from sfl.config import FLConfig

client_ids = list(map(str, range(3)))

# Federated split-learning configuration.
# NOTE: this rebinds `config`, shadowing the `sfl.config` module imported in the first cell.
config = FLConfig(
    global_round=50,
    client_epoch=2,          # each client trains 2 local epochs per federated round
    split_point_1=2,
    split_point_2=30,        # blocks [0, 1 | 2 ... 29 | 30 ... 35] -> bottom | trunk | top
    use_lora_at_trunk=True,  # apply LoRA to the trunk segment
    # NOTE(review): this is the *string* 'False', not the bool False, and the original
    # comment said "neither top nor bottom uses pretrained parameters". Confirm how
    # FLConfig interprets this value.
    top_and_bottom_from_scratch='False',
    noise_mode="dxp",
    noise_scale=3.0,         # noise magnitude
)

from sfl.simulator.dataset import PIQAFedDataset

# PIQA split across the clients; shrink_frac presumably keeps a 15% subset - TODO confirm.
fed_dataset = PIQAFedDataset(
    tokenizer=tokenizer,
    client_ids=client_ids,
    shrink_frac=0.15,
)


# 定义Client本地学习策略
class QAFLStrategy(FLStrategy):
    """Client-local training strategy that also runs a gradient-matching ("mocker") attack.

    During each client backward pass, ``callback_bp_param`` tries to reconstruct the
    client's input text. It uses the trunk->top activations (``tr2t_fx``) and the gradient
    returned from the top (``t2tr_grad``), then records the ROUGE score of the guess.

    NOTE: relies on the notebook globals ``mocker``, ``tokenizer`` and ``model``. They are
    resolved at call time, so they must exist before ``simulator.simulate()`` runs.
    """

    def __init__(self):
        super().__init__()
        # ROUGE results collected within a local epoch. Only mocker_rouge_tr2t is filled
        # at the moment; the learned-attacker paths are commented out.
        self.attacker_rouge_b2tr = []
        self.attacker_rouge_tr2t = []
        self.mocker_rouge_tr2t = []
        # client_id -> {epoch -> {metric -> value}}; flattened and reset in aggregation_step.
        self.client_logs = {}

    def client_step(self, global_round, client_id: str, llm: SplitModel, dataloader: DataLoader, cfg: FLConfig):
        """Run ``cfg.client_epoch`` epochs of causal-LM fine-tuning on one client's data.

        After each forward/backward pass, ``fp_done``/``bp_done`` are called. Presumably
        these let the simulator pass activations and gradients to the callbacks below.
        At the end of every epoch, the average mocker ROUGE-L F1 is printed and the
        per-epoch ROUGE buffers are cleared.
        """
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        with tqdm_notebook(total=cfg.client_epoch * len(dataloader)) as pbar:
            for epoch in range(cfg.client_epoch):
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(llm.device)
                    attention_mask = batch['input_att_mask'].to(llm.device)
                    outputs = llm(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
                    self.fp_done(client_id, epoch, step, batch)  # collect intermediate results
                    loss = outputs.loss
                    pbar.set_description(f'Client {client_id} Epoch {epoch} Loss {loss.item():.3f}')
                    loss.backward()
                    self.bp_done(client_id, epoch, step, batch)  # collect gradients
                    optimizer.step()
                    pbar.update(1)
                if self.mocker_rouge_tr2t:
                    avg_rouge_l = sum(r["rouge-l"]["f"] for r in self.mocker_rouge_tr2t) / len(self.mocker_rouge_tr2t)
                    # The mocker attacks the trunk->top activations, so label it as such.
                    print(f'MOCK! Trunk-top, Client {client_id} Epoch {epoch} RougeL {avg_rouge_l:.3f}')
                    self.attacker_rouge_b2tr.clear()
                    self.attacker_rouge_tr2t.clear()
                    self.mocker_rouge_tr2t.clear()
                self.client_logs.setdefault(client_id, {})
                # TODO: store per-epoch metrics in self.client_logs[client_id][epoch] so that
                # aggregation_step has something to report.

    def aggregation_step(self, global_round, params):
        """Flatten per-client/per-epoch logs into a report, reset them, and defer to the base class."""
        report = {'global_round': global_round}
        for cid, epochs in self.client_logs.items():
            for epc, rep in epochs.items():
                for k, v in rep.items():
                    report[f'client{cid}-epoch{epc}-{k}'] = v
        # wandb.log(report)  # disabled: the report is built but currently not emitted
        self.client_logs = {}
        return super().aggregation_step(global_round, params)

    def callback_fp_param(self, global_round, client_id, local_epoch, local_step, b2tr_params, tr2t_params, batch):
        """Forward-pass hook; intentionally a no-op for this strategy."""
        pass

    def callback_bp_param(self, global_round, client_id, local_epoch, local_step,
                          b2tr_fx, tr2b_grad,
                          tr2t_fx, t2tr_grad,
                          batch):
        """Attack the trunk->top boundary of one client batch and record the ROUGE of the guess.

        Steps:
        1. Print the first real sample.
        2. Print a direct decode of ``mocker(tr2t_fx)``.
        3. Run gradient-matching reconstruction against ``t2tr_grad``.
        4. Append its ROUGE to ``self.mocker_rouge_tr2t``.
        """
        # NOTE(review): the throttle below is disabled, so the 1000-step reconstruction runs
        # on EVERY backward pass (the recorded run had to be interrupted). Re-enable it to
        # attack only a sample of steps:
        # if global_round % 10 == 0 and client_id == '0' and local_epoch == 1:
        device = self.simulator.device
        print("REAL:", batch['input_text'][0])
        with torch.no_grad():  # a plain decode of the mocker's prediction needs no graph
            out = mocker(tr2t_fx.to(device))
        out_sent = tokenizer.decode(out[0].argmax(dim=-1).tolist(), skip_special_tokens=True,
                                    clean_up_tokenization_spaces=True)
        print("MOCK:", out_sent)
        gt = self._reconstruct_from_gradient(tr2t_fx.to(device), t2tr_grad)
        rouge = calculate_rouge(tokenizer, gt, batch['input_text'])
        self.mocker_rouge_tr2t.append(rouge)
        # guess = attacker2(tr2t_fx)
        # rouge = calculate_rouge(tokenizer, guess, batch['input_text'])
        # self.attacker_rouge_tr2t.append(rouge)

    def _reconstruct_from_gradient(self, inter, observed_grad, epochs=1000, beta=0.9, lr=0.09):
        """Optimise label logits ``gt`` whose loss gradient w.r.t. ``inter`` matches ``observed_grad``.

        The objective is ``beta * L2^2 + (1 - beta) * L1`` between the predicted and observed
        gradients. Returns ``gt``, which shares the leading two dims of ``inter`` and has
        ``model.config.vocab_size`` as its last dim.
        """
        # Work on a detached alias so the caller's activation tensor is not mutated.
        inter = inter.detach().requires_grad_(True)
        batch_size, seq_len = inter.shape[:2]
        vocab_size = model.config.vocab_size
        gt = torch.softmax(torch.randn((batch_size, seq_len, vocab_size)).to(inter.device), dim=-1)
        gt.requires_grad = True
        optimizer = torch.optim.AdamW([gt], lr=lr, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01)
        target_grad = observed_grad.to(inter.device)
        with tqdm_notebook(total=epochs, desc='Mocking') as pbar:
            for _ in range(epochs):
                optimizer.zero_grad()
                logits = mocker(inter)
                loss = calc_shifted_loss_logits(logits, torch.softmax(gt, dim=-1))
                (pred_grad,) = torch.autograd.grad(loss, inter, create_graph=True)
                # Compare the whole gradient tensors. Zipping the 1-tuple returned by
                # autograd.grad with the batch dim of target_grad would broadcast every
                # sample's gradient against sample 0's only.
                diff = pred_grad - target_grad
                grad_diff = beta * (diff ** 2).sum() + (1 - beta) * diff.abs().sum()
                # Accumulate only into gt. The mocker is built from `model`, and this
                # callback presumably runs before the client's optimizer.step().
                grad_diff.backward(inputs=[gt])
                optimizer.step()
                sent = tokenizer.decode(gt.argmax(-1)[0], skip_special_tokens=True)
                pbar.set_postfix(sent=sent, grad_diff=grad_diff.item(), loss=loss.item())
                pbar.update(1)
        return gt


qa_strategy = QAFLStrategy()
simulator = SFLSimulator(
    client_ids=client_ids,
    strategy=qa_strategy,
    llm=model,
    tokenizer=tokenizer,
    dataset=fed_dataset,
    config=config,
)
# model.print_split_model()

# QAFLStrategy.callback_bp_param looks up the global `mocker` when it is called,
# so the mocker only has to exist before simulate() starts.
mocker = GPT2TopMocker(config, model)
mocker.to(simulator.device)
simulator.simulate()



  0%|          | 0/96 [00:00<?, ?it/s]

MODEL: ,,ianiast

.b????xyjster
REAL: penny, Solution: can replace a stove 
MOCK: ,,ianiast

.bianky
_


Mocking:   0%|          | 0/1000 [00:00<?, ?it/s]

MODEL: ,..ge
on


unction:
Solution
- Solution Solution.
 kvenum:- kraseeh
ata hite Sicis- kbingspelss Sc Sc Scal Scal
REAL: How to best cut the meat to place on a grill?, Solution: Place a knife on the grill for around 15 minutes for the blade to be red hot. Gently push the knife through the meat to get a perfect cut.
MOCK: ...ston
on,

acet:
:
p Solution Solutionos
 kfut:- kraseeh
ata:ite:ings- kbingsarss Sc Scal Scal Scal


Mocking:   0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Step 4. 开始联邦模拟

In [None]:
import wandb

wandb.init(
    project="sfl-with-attacker",
    # Track hyperparameters and run metadata. The noise value is derived from the
    # FLConfig actually used, instead of a hard-coded string that had drifted from it.
    config={
        "dataset": 'piqa',
        # NOTE(review): the attacker-loading cell is commented out, so no learned attacker
        # is used in this run; confirm this label.
        "attacker": "piqa-validation",
        "noise": str(config.noise_scale),
    }
)

# NOTE(review): simulate() was already invoked in the setup cell above, so this runs the
# federated schedule again. Also, wandb.log is commented out in
# QAFLStrategy.aggregation_step, so only the config above reaches wandb.
simulator.simulate()

In [None]:
from sfl.utils.model import sentence_score

# Score a fixed, benign sentence with the (fine-tuned) model.
probe_sentence = "I'm fine, thank you!"
sentence_score(probe_sentence, model, tokenizer)