## Step 1. 加载模型与Tokenizer

In [1]:
import os
import sys

from transformers import AutoTokenizer

sys.path.append(os.path.abspath('../..'))
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel

cache_dir = '/root/autodl-tmp/sfl/models'  # 模型的缓存位置，需要修改
save_dir = '/root/autodl-tmp/sfl/models/checkpoints'
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_dir)
model:GPT2SplitLMHeadModel = GPT2SplitLMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
tokenizer.pad_token_id = model.config.eos_token_id

In [2]:
from sfl.utils import calculate_rouge


# 恢复的评价指标选用ROUGE

def evaluate(epc, md, attacker, tok, test_data_loader):
    md.eval()
    attacker.eval()
    dl_len = len(test_data_loader)
    with torch.no_grad():
        rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r = 0, 0, 0, 0, 0
        for step, batch in tqdm(enumerate(test_data_loader), total=dl_len):
            input_ids = batch['input_ids'].to(md.device)
            attention_mask = batch['input_att_mask'].to(md.device)
            inter = md(input_ids=input_ids, attention_mask=attention_mask)
            logits = attacker(inter)
            result = calculate_rouge(tok, logits, batch['input_text'])
            rouge_1 += result['rouge-1']['f']
            rouge_2 += result['rouge-2']['f']
            rouge_l_f1 += result['rouge-l']['f']
            rouge_l_p += result['rouge-l']['p']
            rouge_l_r += result['rouge-l']['r']
    print(
        f'Epoch {epc} Test Rouge_1: {rouge_1 / dl_len}, Rouge_2: {rouge_2 / dl_len}, Rouge_l_f1: {rouge_l_f1 / dl_len}, Rouge_l_p: {rouge_l_p / dl_len}, Rouge_l_r: {rouge_l_r / dl_len}')
    path = save_dir + f'/attacker/{md.fl_config.attack_mode}-{md.fl_config.split_point_1 if md.fl_config.attack_mode == "b2tr" else md.fl_config.split_point_2}/'
    os.makedirs(path, exist_ok=True)
    torch.save(attacker.state_dict(), path + f'epoch_{epc}_rouge_{rouge_l_f1 / dl_len}.pt')
    md.train(True)
    attacker.train(True)
    return rouge_1 / dl_len, rouge_2 / dl_len, rouge_l_f1 / dl_len, rouge_l_p / dl_len, rouge_l_r / dl_len

### 加载数据集

In [3]:
from transformers import GPT2Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader


def encode(examples):
    # same input and output
    text = examples["goal"] + " Solution: " + examples['sol1']
    input = tokenizer(text, padding="max_length")
    return {'input_ids': input['input_ids'], 'input_att_mask': input['attention_mask'],
            "input_text": text}


dataset = load_dataset('piqa')['train']
dataset_test = load_dataset('piqa')['validation']
dataset = dataset.map(encode)
dataset_test = dataset_test.map(encode)
dataset.set_format(type="torch", columns=["input_ids", "input_att_mask", "input_text"])
dataset_test.set_format(type="torch",
                        columns=["input_ids", "input_att_mask", "input_text"])
dataloader = DataLoader(dataset, batch_size=6)
dataloader_test = DataLoader(dataset_test, batch_size=6)


### 切分模型

In [4]:
from sfl.utils import FLConfig
model.config_sfl(FLConfig(collect_intermediates=False,
                          split_point_1=9, # 第0～1层为top，余下的都是trunk
                          split_point_2=999,
                          attack_mode='b2tr' # 攻击的输出是bottom-to-trunk中间输出
                          ),
                 param_keeper=None)
# model = model.convert_to_lora_model(restore_top_bottom=False)
model.print_split_model()


transformer.h.9:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.h.10:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.h.11:[ln_1.weight: (768,), ln_1.bias: (768,), attn.c_attn.weight: (768, 2304), attn.c_attn.bias: (2304,), attn.c_proj.weight: (768, 768), attn.c_proj.bias: (768,), ln_2.weight: (768,), ln_2.bias: (768,), mlp.c_fc.weight: (768, 3072), mlp.c_fc.bias: (3072,), mlp.c_proj.weight: (3072, 768), mlp.c_proj.bias: (768,)]

transformer.wte.

### 训练Attack Model

In [5]:
from torch.optim import Adam
from sfl.model.attack_model import GPT2AttackModel
from sfl.utils import get_best_gpu, calc_unshift_loss
from tqdm.notebook import tqdm
import torch


def get_output(text, encoder_model, attack_model):
    t = tokenizer(text, return_tensors="pt")
    inter = encoder_model(t['input_ids'].to(device), attention_mask=t['attention_mask'].to(device))
    res = attack_model(inter)
    r = tokenizer.decode(res.argmax(dim=-1)[-1], skip_special_tokens=True)
    return r


# 开始训练Attack Model
device = get_best_gpu()
attack_model = GPT2AttackModel(model.config)
optimizer = Adam(attack_model.parameters(), lr=1e-3)
model.to(device)
attack_model.to(device)
epoch = 5
evaluate(0,model,attack_model,tokenizer,dataloader_test)
with tqdm(total=epoch * len(dataloader)) as pbar:
    for epc in range(epoch):
        model.train(True)
        rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r = 0, 0, 0, 0, 0
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['input_att_mask'].to(device)
            intermediate = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = attack_model(intermediate)
            loss = calc_unshift_loss(logits, input_ids)
            loss.backward()
            optimizer.step()
            # 计算训练的ROGUE
            res = calculate_rouge(tokenizer, logits, batch['input_text'])
            rouge_1 += res['rouge-1']['f']
            rouge_2 += res['rouge-2']['f']
            rouge_l_f1 += res['rouge-l']['f']
            rouge_l_p += res['rouge-l']['p']
            rouge_l_r += res['rouge-l']['r']
            pbar.set_description(f'Epoch {epc} Loss {loss.item():.5f}, Rouge_1 {rouge_1 / (step + 1):.4f}')
            if step % 300 == 0:
                q = "To mix food coloring with sugar, you can"
                print(q, "==>", get_output(q, model, attack_model))
            pbar.update(1)
        rouge_1 /= len(dataloader)
        rouge_2 /= len(dataloader)
        rouge_l_f1 /= len(dataloader)
        rouge_l_p /= len(dataloader)
        rouge_l_r /= len(dataloader)
        print(
            f'Epoch {epc} Train Rouge_1: {rouge_1}, Rouge_2: {rouge_2}, Rouge_l_f1: {rouge_l_f1}, Rouge_l_p: {rouge_l_p}, Rouge_l_r: {rouge_l_r}')
        # 计算测试集上的ROGUE
        evaluate(epc, model, attack_model, tokenizer, dataloader_test)

  0%|          | 0/307 [00:00<?, ?it/s]

Epoch 0 Test Rouge_1: 0.00011452012246745169, Rouge_2: 0.0, Rouge_l_f1: 0.00011452012246745169, Rouge_l_p: 0.00010259953478823385, Rouge_l_r: 0.0001344962221028923


  0%|          | 0/13430 [00:00<?, ?it/s]

To mix food coloring with sugar, you can ==> �� sar diagonal Initi Origin diagonal Martyalsrium
To mix food coloring with sugar, you can ==> To make it, with sugar, you can
To mix food coloring with sugar, you can ==> How mix food paint with sugar, you can
To mix food coloring with sugar, you can ==> How mix fooding with sugar, you can
To mix food coloring with sugar, you can ==> How mix fooding with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
Epoch 0 Train Rouge_1: 0.6390107736783551, Rouge_2: 0.466450493060813, Rouge_l_f1: 0.6377944853184568, Rouge_l_p: 0.6753026670240666, Rouge_l_r: 0.6162099233520286


  0%|          | 0/307 [00:00<?, ?it/s]

Epoch 0 Test Rouge_1: 0.8422687955930448, Rouge_2: 0.7490716981434308, Rouge_l_f1: 0.8421089992283194, Rouge_l_p: 0.8470929204207511, Rouge_l_r: 0.8384091912685239
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> How mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
Epoch 1 Train Rouge_1: 0.8825037600407508, Rouge_2: 0.8201807

  0%|          | 0/307 [00:00<?, ?it/s]

Epoch 1 Test Rouge_1: 0.9068250994874116, Rouge_2: 0.8621509651748039, Rouge_l_f1: 0.9068250994874116, Rouge_l_p: 0.9079256095303494, Rouge_l_r: 0.9062709037779791
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
Epoch 2 Train Rouge_1: 0.924566359166576, Rouge_2: 0.894342522

  0%|          | 0/307 [00:00<?, ?it/s]

Epoch 2 Test Rouge_1: 0.926970487724521, Rouge_2: 0.8969253892347632, Rouge_l_f1: 0.926970487724521, Rouge_l_p: 0.9269928327911615, Rouge_l_r: 0.9272812197524795
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
Epoch 3 Train Rouge_1: 0.9400150148730602, Rouge_2: 0.9212305961

  0%|          | 0/307 [00:00<?, ?it/s]

Epoch 3 Test Rouge_1: 0.9347019279132641, Rouge_2: 0.9109397607497925, Rouge_l_f1: 0.9347019279132641, Rouge_l_p: 0.9338866376600893, Rouge_l_r: 0.9357122298915528
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
Epoch 4 Train Rouge_1: 0.9486968208327863, Rouge_2: 0.93644902

  0%|          | 0/307 [00:00<?, ?it/s]

Epoch 4 Test Rouge_1: 0.9400152550325381, Rouge_2: 0.9206248834624531, Rouge_l_f1: 0.9400152550325381, Rouge_l_p: 0.9390239660178857, Rouge_l_r: 0.941152070387014


In [57]:
from rouge import Rouge
text = "Patient Name: Mr.Lawrence, Gender: Male, Ethnicity: White, Address: Harbin Institute of technologgy, Shenzhen, China. He's a Japanese man standing 165cm tall, always wearing a pair of pink glasses. He's in extreme danger now with a heartbeat of only 32/min;"
decoded = get_output(text, model, attack_model)
print(decoded)
result = Rouge().get_scores([text],[decoded], avg=True, ignore_empty=True)  # 取一个 batch 的平均
print(result)

Howation Name: Mr.lawven, Weight to men, Mineralness: White, With: Hararb suspect of chicology,white Superman, Science. He's a Japanese man standing 165cm tall, always wearing a pair of pink glasses. It's in extreme danger now with a rate of only 32/min;
{'rouge-1': {'r': 0.6216216216216216, 'p': 0.6571428571428571, 'f': 0.638888883892747}, 'rouge-2': {'r': 0.525, 'p': 0.5384615384615384, 'f': 0.5316455646210544}, 'rouge-l': {'r': 0.6216216216216216, 'p': 0.6571428571428571, 'f': 0.638888883892747}}
