## Step 1. 加载模型与Tokenizer

In [12]:
import os
import sys

from transformers import AutoTokenizer

sys.path.append(os.path.abspath('../..'))
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel
from sfl.utils import get_best_gpu
# Model and storage locations. The defaults are the paths of the original AutoDL
# machine; set SFL_CACHE_DIR / SFL_SAVE_DIR in the environment instead of editing
# this cell when running elsewhere.
MODEL_NAME = "gpt2-large"
cache_dir = os.environ.get('SFL_CACHE_DIR', '/root/autodl-tmp/sfl/models')  # HF download/cache directory
save_dir = os.environ.get('SFL_SAVE_DIR', '/root/autodl-tmp/sfl/models/checkpoints')  # attacker checkpoints go here
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
model: GPT2SplitLMHeadModel = GPT2SplitLMHeadModel.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
# The GPT-2 tokenizer has no pad token; reuse EOS so batches can be padded.
tokenizer.pad_token_id = model.config.eos_token_id
device = get_best_gpu()
model.to(device)

GPT2SplitLMHeadModel(
  (transformer): GPT2SplitModel(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [13]:
from sfl.utils import generate

# Sanity check: the unsplit model should still generate fluent text from a short prompt.
prompt = 'what do you think!'
generate(prompt, tokenizer, model)

'what do you think!?"\n\n"I don\'t know what you\'re talking about," I said, "but I\'m not going to tell you what I think. I\'ve already told you, and I\'ll say it again, that I have no idea what\'s going on here. But I do know that this is not the first time this has happened to me. And it won\'t be the last. It\'s just a matter of time before it happens to someone else, too."'

In [14]:
from sfl.utils import calculate_rouge


# The attacker's reconstruction quality is scored with ROUGE.

def evaluate(epc, md, attacker, tok, test_data_loader, dataset_tag='piqa-validation'):
    """Score the attack model on a held-out loader, print ROUGE, and checkpoint it.

    For every batch, ``md`` (the split model) produces an intermediate output,
    ``attacker`` maps it to logits, and ``calculate_rouge`` compares the decoded
    prediction with ``batch['input_text']``. Scores are averaged over batches.

    Args:
        epc: epoch label used in the log line and the checkpoint directory name.
        md: split model; must expose ``device``, ``training``, ``config.name_or_path``
            and ``fl_config`` (``attack_mode``, ``split_point_1``, ``split_point_2``).
        attacker: attack model to evaluate; saved via ``save_pretrained``.
        tok: tokenizer forwarded to ``calculate_rouge``.
        test_data_loader: sized iterable of batches with ``input_ids``,
            ``input_att_mask`` and ``input_text`` entries.
        dataset_tag: path segment naming the evaluation set in the checkpoint
            directory. Defaults to the historical ``'piqa-validation'`` so existing
            checkpoint layouts are unchanged.

    Returns:
        Batch-averaged (ROUGE-1 F, ROUGE-2 F, ROUGE-L F, ROUGE-L P, ROUGE-L R).

    Raises:
        ValueError: if ``test_data_loader`` has no batches.
    """
    # Imported locally: this cell runs before the cell that imports torch/tqdm,
    # so evaluate() must not depend on those globals already existing.
    import torch
    from tqdm.notebook import tqdm

    dl_len = len(test_data_loader)
    if dl_len == 0:
        raise ValueError('test_data_loader is empty; cannot average ROUGE over zero batches')

    # Remember the current modes so they are restored exactly, even on error.
    md_was_training = md.training
    attacker_was_training = attacker.training
    md.eval()
    attacker.eval()
    rouge_1 = rouge_2 = rouge_l_f1 = rouge_l_p = rouge_l_r = 0.0
    try:
        with torch.no_grad():
            for batch in tqdm(test_data_loader, total=dl_len):
                input_ids = batch['input_ids'].to(md.device)
                attention_mask = batch['input_att_mask'].to(md.device)
                inter = md(input_ids=input_ids, attention_mask=attention_mask)
                logits = attacker(inter)
                result = calculate_rouge(tok, logits, batch['input_text'])
                rouge_1 += result['rouge-1']['f']
                rouge_2 += result['rouge-2']['f']
                rouge_l_f1 += result['rouge-l']['f']
                rouge_l_p += result['rouge-l']['p']
                rouge_l_r += result['rouge-l']['r']
        avg_1, avg_2, avg_l_f1, avg_l_p, avg_l_r = (
            total / dl_len for total in (rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r))
        print(
            f'Epoch {epc} Test Rouge_1: {avg_1}, Rouge_2: {avg_2}, Rouge_l_f1: {avg_l_f1}, Rouge_l_p: {avg_l_p}, Rouge_l_r: {avg_l_r}')
        fl_cfg = md.fl_config
        # 'b2tr' attacks the bottom/trunk boundary; any other mode the trunk/top one.
        split_point = fl_cfg.split_point_1 if fl_cfg.attack_mode == "b2tr" else fl_cfg.split_point_2
        path = save_dir + f'/attacker/{md.config.name_or_path}/{dataset_tag}/{fl_cfg.attack_mode}-{split_point}/'
        # Save the attacker that was passed in, not whatever a global happens to hold.
        attacker.save_pretrained(path + f'epoch_{epc}_rouge_{avg_l_f1:.4f}')
    finally:
        md.train(md_was_training)
        attacker.train(attacker_was_training)
    return avg_1, avg_2, avg_l_f1, avg_l_p, avg_l_r

### 加载数据集

In [15]:
from sfl.simulator.dataset import GSM8KFedDataset

batch_size = 6  # shared by both loaders
dataset = GSM8KFedDataset(tokenizer, [])
# NOTE: the attacker is *trained* on GSM8K's 'test' split (1319 rows -> 220 batches)
# and *evaluated* on a slice of the 'train' split. The third argument presumably is
# the fraction kept (0.06 * 7473 / 6 ~ 75 batches, matching the eval progress bar).
# The two splits are disjoint, so evaluation is still held-out.
dataloader = dataset.get_dataloader_unsliced(batch_size, 'test')
dataloader_test = dataset.get_dataloader_unsliced(batch_size, 'train', 0.06)
dataset.all_dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 7473
    })
    test: Dataset({
        features: ['question', 'answer'],
        num_rows: 1319
    })
})

### 切分模型

In [16]:
from sfl.config import FLConfig
# Split GPT-2 (36 blocks) into bottom / trunk / top for the split-learning simulation.
# Judging by the note "layers 0-1" for split_point_1=2, each split point is presumably
# the first layer of the next part:
#   bottom: layers 0-1    (< split_point_1)
#   trunk : layers 2-29   ([split_point_1, split_point_2))
#   top   : layers 30-35  (>= split_point_2)
model.config_sfl(FLConfig(collect_intermediates=False,
                          split_point_1=2,   # bottom/trunk boundary: layers 0-1 are the bottom part
                          split_point_2=30,  # trunk/top boundary
                          attack_mode='tr2t'  # attack the trunk-to-top intermediate output; 'b2tr' would target bottom-to-trunk
                          ),
                 param_keeper=None)
# Freeze the whole split model; only the attack model is optimised below.
for param in model.parameters():
    param.requires_grad = False

### 训练Attack Model

In [None]:
from torch.optim import Adam
from sfl.model.attack_model import LSTMAttackerConfig, LSTMAttackModel
from sfl.utils import calc_unshift_loss
from tqdm.notebook import tqdm
import torch


def get_output(text, encoder_model, attack_model):
    """Reconstruct ``text`` from the split model's intermediate output.

    Tokenizes ``text`` with the global ``tokenizer``, runs ``encoder_model`` on the
    global ``device`` to get the intermediate activation, passes it to
    ``attack_model``, and greedily decodes (argmax per position) the last row of
    the attacker's output.

    Returns:
        The decoded reconstruction with special tokens stripped.
    """
    t = tokenizer(text, return_tensors="pt")
    # Pure inference: avoid building an autograd graph through the trainable attacker.
    with torch.no_grad():
        inter = encoder_model(t['input_ids'].to(device), attention_mask=t['attention_mask'].to(device))
        res = attack_model(inter)
    # Presumably res is (batch, seq, vocab); [-1] picks the last (here the only) sequence.
    return tokenizer.decode(res.argmax(dim=-1)[-1], skip_special_tokens=True)


# Train the attack model on intermediate outputs of the frozen split model.
torch.manual_seed(42)  # reproducible attacker initialisation (and any torch-based shuffling)
attack_model = LSTMAttackModel(LSTMAttackerConfig(), model.config)
model.to(device)
attack_model.to(device)  # move before building the optimizer (conventional order)
optimizer = Adam(attack_model.parameters(), lr=1e-3)
epoch = 20
# Per-epoch ROUGE fields: display name -> (rouge key, sub-score).
rouge_fields = {
    'Rouge_1': ('rouge-1', 'f'),
    'Rouge_2': ('rouge-2', 'f'),
    'Rouge_l_f1': ('rouge-l', 'f'),
    'Rouge_l_p': ('rouge-l', 'p'),
    'Rouge_l_r': ('rouge-l', 'r'),
}
# Baseline of the untrained attacker. Labelled -1 so its log line and checkpoint
# name are distinguishable from the evaluation after epoch 0.
evaluate(-1, model, attack_model, tokenizer, dataloader_test)
with tqdm(total=epoch * len(dataloader)) as pbar:
    for epc in range(epoch):
        model.train(True)
        # Set explicitly rather than relying on evaluate() to leave the attacker in train mode.
        attack_model.train(True)
        train_sums = dict.fromkeys(rouge_fields, 0.0)
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['input_att_mask'].to(device)
            intermediate = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = attack_model(intermediate)
            loss = calc_unshift_loss(logits, input_ids)
            loss.backward()
            optimizer.step()
            # Training ROUGE for this batch
            res = calculate_rouge(tokenizer, logits, batch['input_text'])
            for field, (rouge_key, sub_score) in rouge_fields.items():
                train_sums[field] += res[rouge_key][sub_score]
            pbar.set_description(f'Epoch {epc} Loss {loss.item():.5f}, Rouge_1 {train_sums["Rouge_1"] / (step + 1):.4f}')
            if step % 300 == 0:
                q = "To mix food coloring with sugar, you can"
                print(q, "==>", get_output(q, model, attack_model))
            pbar.update(1)
        n_batches = len(dataloader)
        train_avg = {field: total / n_batches for field, total in train_sums.items()}
        print(f'Epoch {epc} Train ' + ', '.join(f'{field}: {value}' for field, value in train_avg.items()))
        # Held-out ROUGE
        evaluate(epc, model, attack_model, tokenizer, dataloader_test)

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 0 Test Rouge_1: 0.0007362893321298781, Rouge_2: 0.0, Rouge_l_f1: 0.0007362893321298781, Rouge_l_p: 0.0005806767519882163, Rouge_l_r: 0.0010201953950377779


  0%|          | 0/4400 [00:00<?, ?it/s]

To mix food coloring with sugar, you can ==> IG embracesirdsneg ed outlandishpocketToo Men
Epoch 0 Train Rouge_1: 0.27735079305138577, Rouge_2: 0.08424419463413786, Rouge_l_f1: 0.27398191016581674, Rouge_l_p: 0.32884958389085334, Rouge_l_r: 0.2498181981178864


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 0 Test Rouge_1: 0.4935933149231122, Rouge_2: 0.22487758143820258, Rouge_l_f1: 0.49188395593317175, Rouge_l_p: 0.5086563563801874, Rouge_l_r: 0.478658789828101
To mix food coloring with sugar, you can ==> Question all money water of blue, they to
Epoch 1 Train Rouge_1: 0.5650925679838439, Rouge_2: 0.3282265764689123, Rouge_l_f1: 0.5633284932639768, Rouge_l_p: 0.5560549652999026, Rouge_l_r: 0.573231501521422


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 1 Test Rouge_1: 0.6388688993377868, Rouge_2: 0.4147162204619491, Rouge_l_f1: 0.6380028011538206, Rouge_l_p: 0.6297299768341164, Rouge_l_r: 0.6481684657316084
To mix food coloring with sugar, you can ==> Question buy food jelly with water, they can
Epoch 2 Train Rouge_1: 0.687247891335496, Rouge_2: 0.4951618843128015, Rouge_l_f1: 0.6863929036944024, Rouge_l_p: 0.6687800095691252, Rouge_l_r: 0.7067751660621547


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 2 Test Rouge_1: 0.7219962060170422, Rouge_2: 0.5369681179950799, Rouge_l_f1: 0.7214459324969076, Rouge_l_p: 0.7064899584558044, Rouge_l_r: 0.7385876311244425
To mix food coloring with sugar, you can ==> Question buy food food with sugar, you can
Epoch 3 Train Rouge_1: 0.7722213061516091, Rouge_2: 0.6230947414097207, Rouge_l_f1: 0.771742270652329, Rouge_l_p: 0.7523401553849824, Rouge_l_r: 0.7938488103401833


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 3 Test Rouge_1: 0.7771749423180676, Rouge_2: 0.6244216980292964, Rouge_l_f1: 0.7766505700056806, Rouge_l_p: 0.7608107781478911, Rouge_l_r: 0.7944644798561404
To mix food coloring with sugar, you can ==> Question dividing food soda with sugar, you can
Epoch 4 Train Rouge_1: 0.840207365252354, Rouge_2: 0.7325209820137727, Rouge_l_f1: 0.8400060855645193, Rouge_l_p: 0.8239816018589742, Rouge_l_r: 0.8579531985997769


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 4 Test Rouge_1: 0.8124712145628814, Rouge_2: 0.6803097119296374, Rouge_l_f1: 0.8122862106253355, Rouge_l_p: 0.7969850502392197, Rouge_l_r: 0.8293462373718882
To mix food coloring with sugar, you can ==> Question divides food colors with sugar, you can
Epoch 5 Train Rouge_1: 0.895457864495876, Rouge_2: 0.8241241236997467, Rouge_l_f1: 0.8953897127738889, Rouge_l_p: 0.884052011875349, Rouge_l_r: 0.9078084504009178


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 5 Test Rouge_1: 0.835162773374296, Rouge_2: 0.7194802782749155, Rouge_l_f1: 0.8350314893741904, Rouge_l_p: 0.8199093797802894, Rouge_l_r: 0.8517681705835469
To mix food coloring with sugar, you can ==> Question dividing food soda with sugar, you can
Epoch 6 Train Rouge_1: 0.937925172658493, Rouge_2: 0.8971606745298148, Rouge_l_f1: 0.9378320477565278, Rouge_l_p: 0.9309313091517187, Rouge_l_r: 0.94526741619831


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 6 Test Rouge_1: 0.8545349874265279, Rouge_2: 0.7502749751705049, Rouge_l_f1: 0.8543309327043763, Rouge_l_p: 0.8398842198926837, Rouge_l_r: 0.870210234361971
To mix food coloring with sugar, you can ==> Question collects food glue with sugar, you can
Epoch 7 Train Rouge_1: 0.9653871952760543, Rouge_2: 0.9423754366407758, Rouge_l_f1: 0.9653643090955074, Rouge_l_p: 0.9620173025813443, Rouge_l_r: 0.9689481376989021


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 7 Test Rouge_1: 0.8659284592910732, Rouge_2: 0.7711817142109221, Rouge_l_f1: 0.865758649462238, Rouge_l_p: 0.8511201407134803, Rouge_l_r: 0.8818313120755147
To mix food coloring with sugar, you can ==> Question substitute food food with sugar, you can
Epoch 8 Train Rouge_1: 0.9792124432692707, Rouge_2: 0.965312622567129, Rouge_l_f1: 0.9792124432692707, Rouge_l_p: 0.9770891541700665, Rouge_l_r: 0.9814581496992376


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 8 Test Rouge_1: 0.8725480856481732, Rouge_2: 0.7835077374850119, Rouge_l_f1: 0.8723971913858944, Rouge_l_p: 0.8578129171784258, Rouge_l_r: 0.8885006670353917
To mix food coloring with sugar, you can ==> Question popping foodadow with sugar, you can
Epoch 9 Train Rouge_1: 0.9883283998324444, Rouge_2: 0.9805481442566794, Rouge_l_f1: 0.9883029785350931, Rouge_l_p: 0.9872690675973982, Rouge_l_r: 0.9894061845220921


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 9 Test Rouge_1: 0.8792047461787197, Rouge_2: 0.7949895231076572, Rouge_l_f1: 0.8790250473566198, Rouge_l_p: 0.8644404895601734, Rouge_l_r: 0.8950791350156228
To mix food coloring with sugar, you can ==> Question shook food sugar with sugar, you can
Epoch 10 Train Rouge_1: 0.9927043448141362, Rouge_2: 0.9873683758496907, Rouge_l_f1: 0.9927043448141362, Rouge_l_p: 0.9919721435684742, Rouge_l_r: 0.9934800514007444


  0%|          | 0/75 [00:00<?, ?it/s]

In [None]:
from rouge import Rouge
# Demo on a synthetic, sensitive-looking record: how much can the attacker recover
# from the intermediate activations alone?
text = "Patient Name: Mr.Lawrence, Gender: Male, Ethnicity: White, Address: Harbin Institute of technology, Shenzhen, China. He's a Japanese man standing 165cm tall, always wearing a pair of pink glasses. He's in extreme danger now with a heartbeat of only 32/min;"
decoded = get_output(text, model, attack_model)
print(decoded)
# Rouge.get_scores(hyps, refs): the reconstruction is the hypothesis and the original
# text the reference. Passing them the other way round swaps precision and recall.
# Only one pair is scored, so avg=True just returns that pair's scores.
result = Rouge().get_scores([decoded], [text], avg=True, ignore_empty=True)
print(result)