## Step 1. 加载模型与Tokenizer

In [1]:
import os
import sys

from transformers import AutoTokenizer

sys.path.append(os.path.abspath('../..'))
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel
from sfl.utils import get_best_gpu
# Cache location for the pretrained weights -- change this for your machine.
cache_dir = '/root/autodl-tmp/sfl/models'
# Root directory under which checkpoints are written.
save_dir = '/root/autodl-tmp/sfl/models/checkpoints'

# Single source of truth for the checkpoint name used by both loaders below.
pretrained_name = "gpt2-large"
model: GPT2SplitLMHeadModel = GPT2SplitLMHeadModel.from_pretrained(pretrained_name, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(pretrained_name, cache_dir=cache_dir)
# Pad with the model's end-of-sequence id.
tokenizer.pad_token_id = model.config.eos_token_id

device = get_best_gpu()
# Kept as the last expression so the cell displays the module tree.
model.to(device)

GPT2SplitLMHeadModel(
  (transformer): GPT2SplitModel(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [2]:
# Sanity-check the model's text generation.
def generate(text, md=model):
    """Generate a continuation of ``text`` with beam search and return it decoded.

    Args:
        text: Prompt string; tokenized without special tokens.
        md: Model to generate with. Defaults to the notebook's global ``model``.
            Must expose ``generate`` and ``device``.

    Returns:
        The decoded first returned sequence, with special tokens stripped.
    """
    # Put the model that is actually used into inference mode. The previous
    # code toggled the global `model`, which is wrong when `md` is another model.
    md.eval()
    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    res = md.generate(t['input_ids'].to(md.device), attention_mask=t['attention_mask'].to(md.device),
                      max_length=300, num_beams=6, no_repeat_ngram_size=2, early_stopping=True,
                      num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(res[0], skip_special_tokens=True)


# Smoke test: the cell output is the decoded continuation of this prompt.
generate('what do you think!')

'what do you think!?"\n\n"I don\'t know what you\'re talking about," I said, "but I\'m not going to tell you what I think. I\'ve already told you, and I\'ll say it again, that I have no idea what\'s going on here. But I do know that this is not the first time this has happened to me. And it won\'t be the last. It\'s just a matter of time before it happens to someone else. So if you want to talk about it, I\'d like to hear it from someone who\'s experienced it first-hand. Someone who knows what it\'s like, who can tell me what to expect, so I can make an informed decision about whether or not to go through with it. That way, if it does happen to anyone else, we\'ll have a better idea of what we\'re dealing with." I paused for a moment, then continued. "I\'m sorry if I sound like a broken record, but that\'s the only way I know how to get through to you. You\'re the one who has to deal with the consequences of your actions, not I. If you really want me to help you out, you have to be wil

In [3]:
from sfl.utils import calculate_rouge


# ROUGE is the metric used to score the recovered text.

def evaluate(epc, md, attacker, tok, test_data_loader):
    """Score the attacker on ``test_data_loader``, print the result and checkpoint it.

    Note: this function uses ``torch`` and ``tqdm`` from the notebook namespace;
    both are imported in the training cell, so that cell's imports must have run
    before the first call.

    Args:
        epc: Epoch label used in the printed report and the checkpoint filename.
        md: Model called with ``input_ids``/``attention_mask``; its output is
            passed to ``attacker``. Must expose ``device``, ``config`` and
            ``fl_config``.
        attacker: Model mapping ``md``'s output to logits; its ``state_dict`` is saved.
        tok: Tokenizer forwarded to ``calculate_rouge``.
        test_data_loader: Sized iterable of batches carrying ``input_ids``,
            ``input_att_mask`` and ``input_text``.

    Returns:
        ``(rouge_1_f, rouge_2_f, rouge_l_f, rouge_l_p, rouge_l_r)`` averaged over
        the batches (all zeros when the loader is empty).
    """
    # Remember the caller's modes so evaluation does not silently flip a model
    # that was in inference mode into training mode.
    md_was_training = md.training
    attacker_was_training = attacker.training
    md.eval()
    attacker.eval()
    dl_len = len(test_data_loader)
    rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r = 0, 0, 0, 0, 0
    with torch.no_grad():
        for step, batch in tqdm(enumerate(test_data_loader), total=dl_len):
            input_ids = batch['input_ids'].to(md.device)
            attention_mask = batch['input_att_mask'].to(md.device)
            inter = md(input_ids=input_ids, attention_mask=attention_mask)
            logits = attacker(inter)
            result = calculate_rouge(tok, logits, batch['input_text'])
            rouge_1 += result['rouge-1']['f']
            rouge_2 += result['rouge-2']['f']
            rouge_l_f1 += result['rouge-l']['f']
            rouge_l_p += result['rouge-l']['p']
            rouge_l_r += result['rouge-l']['r']
    # Guard against an empty loader instead of raising ZeroDivisionError.
    n_batches = max(dl_len, 1)
    rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r = (
        rouge_1 / n_batches, rouge_2 / n_batches, rouge_l_f1 / n_batches,
        rouge_l_p / n_batches, rouge_l_r / n_batches)
    print(
        f'Epoch {epc} Test Rouge_1: {rouge_1}, Rouge_2: {rouge_2}, Rouge_l_f1: {rouge_l_f1}, Rouge_l_p: {rouge_l_p}, Rouge_l_r: {rouge_l_r}')
    # The checkpoint directory is keyed by model name, attack mode and the split
    # point that the attack mode reads from.
    path = save_dir + f'/attacker/{md.config.name_or_path}/piqa-validation/{md.fl_config.attack_mode}-{md.fl_config.split_point_1 if md.fl_config.attack_mode == "b2tr" else md.fl_config.split_point_2}/'
    os.makedirs(path, exist_ok=True)
    torch.save(attacker.state_dict(), path + f'epoch_{epc}_rouge_{rouge_l_f1}.pt')
    md.train(md_was_training)
    attacker.train(attacker_was_training)
    return rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r

### 加载数据集

In [4]:
from transformers import GPT2Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader


def encode(examples):
    """Turn one PIQA record into padded token ids plus the raw text.

    Args:
        examples: Mapping with string fields ``goal`` and ``sol1``.

    Returns:
        Dict with ``input_ids``, ``input_att_mask`` (the tokenizer's attention
        mask) and ``input_text`` (the concatenated source string).
    """
    # Input and target are the same text.
    text = examples["goal"] + " Solution: " + examples['sol1']
    # truncation=True keeps every row at the tokenizer's maximum length; without
    # it an over-long example would come out longer than the padded rows.
    encoded = tokenizer(text, padding="max_length", truncation=True)
    return {'input_ids': encoded['input_ids'], 'input_att_mask': encoded['attention_mask'],
            "input_text": text}


# Load PIQA once and reuse the result for both splits (it was loaded twice before).
piqa = load_dataset('piqa')
# Columns exposed as torch-formatted fields to the data loaders.
torch_columns = ["input_ids", "input_att_mask", "input_text"]
batch_size = 6

# `dataset` is the validation split and `dataset_test` is the test split.
dataset = piqa['validation'].map(encode)
dataset_test = piqa['test'].map(encode)
dataset.set_format(type="torch", columns=torch_columns)
dataset_test.set_format(type="torch", columns=torch_columns)
dataloader = DataLoader(dataset, batch_size=batch_size)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size)


### 切分模型

In [5]:
from sfl.utils import FLConfig
fl_config = FLConfig(
    collect_intermediates=False,
    # Per the original note: layers 0-1 form the first segment and the rest is
    # trunk -- TODO confirm the segment naming against FLConfig.
    split_point_1=2,
    split_point_2=30,
    # NOTE(review): 'tr2t' presumably selects the trunk-to-top intermediate
    # output; the previous comment here said bottom-to-trunk, which looks
    # stale -- confirm.
    attack_mode='tr2t',
)
model.config_sfl(fl_config, param_keeper=None)

# Freeze every parameter of the split model.
for param in model.parameters():
    param.requires_grad = False

### 训练Attack Model

In [None]:
from torch.optim import Adam
from sfl.model.attack_model import GPT2AttackModel
from sfl.utils import get_best_gpu, calc_unshift_loss
from tqdm.notebook import tqdm
import torch


def get_output(text, encoder_model, attack_model):
    """Decode what the attacker reconstructs from the encoder's output for ``text``.

    Args:
        text: Input string to tokenize and feed through ``encoder_model``.
        encoder_model: Model producing the output handed to the attacker.
            Must expose ``device``.
        attack_model: Model mapping that output to logits over the vocabulary.

    Returns:
        The decoded arg-max tokens of the last sequence in the batch, with
        special tokens stripped.
    """
    t = tokenizer(text, return_tensors="pt")
    # Follow the encoder's own device rather than a notebook-level global.
    target_device = encoder_model.device
    # Inspection only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        inter = encoder_model(t['input_ids'].to(target_device),
                              attention_mask=t['attention_mask'].to(target_device))
        res = attack_model(inter)
    r = tokenizer.decode(res.argmax(dim=-1)[-1], skip_special_tokens=True)
    return r


# Train the attack model.
attack_model = GPT2AttackModel(model.config)
optimizer = Adam(attack_model.parameters(), lr=1e-3)
model.to(device)
attack_model.to(device)
epoch = 20
# Score the untrained attacker first; it is reported under epoch label 0.
evaluate(0, model, attack_model, tokenizer, dataloader_test)
steps_per_epoch = len(dataloader)
with tqdm(total=epoch * steps_per_epoch) as pbar:
    for epc in range(epoch):
        model.train(True)
        # Running sums of the per-batch training ROUGE scores.
        rouge_1 = rouge_2 = rouge_l_f1 = rouge_l_p = rouge_l_r = 0
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['input_att_mask'].to(device)
            # Split-model output -> attacker logits -> loss against the input ids.
            intermediate = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = attack_model(intermediate)
            loss = calc_unshift_loss(logits, input_ids)
            loss.backward()
            optimizer.step()
            # Training ROUGE for this batch.
            res = calculate_rouge(tokenizer, logits, batch['input_text'])
            rouge_l = res['rouge-l']
            rouge_1 += res['rouge-1']['f']
            rouge_2 += res['rouge-2']['f']
            rouge_l_f1 += rouge_l['f']
            rouge_l_p += rouge_l['p']
            rouge_l_r += rouge_l['r']
            pbar.set_description(f'Epoch {epc} Loss {loss.item():.5f}, Rouge_1 {rouge_1 / (step + 1):.4f}')
            # Every 300 steps, print what the attacker recovers for a fixed probe.
            if not step % 300:
                q = "To mix food coloring with sugar, you can"
                print(q, "==>", get_output(q, model, attack_model))
            pbar.update(1)
        # Turn the running sums into per-batch averages for the epoch report.
        rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r = (
            total / steps_per_epoch
            for total in (rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r)
        )
        print(
            f'Epoch {epc} Train Rouge_1: {rouge_1}, Rouge_2: {rouge_2}, Rouge_l_f1: {rouge_l_f1}, Rouge_l_p: {rouge_l_p}, Rouge_l_r: {rouge_l_r}')
        # ROUGE on the test loader (this also saves the attacker's weights).
        evaluate(epc, model, attack_model, tokenizer, dataloader_test)

  0%|          | 0/514 [00:00<?, ?it/s]

Epoch 0 Test Rouge_1: 0.0005025614340383022, Rouge_2: 0.0, Rouge_l_f1: 0.0005025614340383022, Rouge_l_p: 0.0004869805875160254, Rouge_l_r: 0.0005287550443985754


  0%|          | 0/6140 [00:00<?, ?it/s]

To mix food coloring with sugar, you can ==>  Occasionally1984WAINCWriting....eredfeld Taste
To mix food coloring with sugar, you can ==> How clean it cream with sugar, you can
Epoch 0 Train Rouge_1: 0.317246887817908, Rouge_2: 0.09160185628372078, Rouge_l_f1: 0.31323336688097336, Rouge_l_p: 0.4349761707612723, Rouge_l_r: 0.26484900639296605


  0%|          | 0/514 [00:00<?, ?it/s]

Epoch 0 Test Rouge_1: 0.47663800764485625, Rouge_2: 0.20942946531056097, Rouge_l_f1: 0.4741811757147421, Rouge_l_p: 0.5197028042602481, Rouge_l_r: 0.4403104370501897
To mix food coloring with sugar, you can ==> How clean it cream with sugar, you can
To mix food coloring with sugar, you can ==> How mix food color with sugar, you can
Epoch 1 Train Rouge_1: 0.5413305977871395, Rouge_2: 0.3045067469401921, Rouge_l_f1: 0.5394263052448102, Rouge_l_p: 0.560687565337999, Rouge_l_r: 0.523725728548728


  0%|          | 0/514 [00:00<?, ?it/s]

Epoch 1 Test Rouge_1: 0.6053748214531796, Rouge_2: 0.3931632946472246, Rouge_l_f1: 0.6041848827231711, Rouge_l_p: 0.6139104048111682, Rouge_l_r: 0.598000879817346
To mix food coloring with sugar, you can ==> How Mix food color with sugar, you can
To mix food coloring with sugar, you can ==> How mix food dye with sugar, you can
Epoch 2 Train Rouge_1: 0.6733416546730611, Rouge_2: 0.4943407304873943, Rouge_l_f1: 0.6725199002952426, Rouge_l_p: 0.6775675750113608, Rouge_l_r: 0.6702470537483662


  0%|          | 0/514 [00:00<?, ?it/s]

Epoch 2 Test Rouge_1: 0.6819124040924978, Rouge_2: 0.509475669588175, Rouge_l_f1: 0.6812455030398181, Rouge_l_p: 0.6805029993495434, Rouge_l_r: 0.6840777807056893
To mix food coloring with sugar, you can ==> How mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> How mix food coloring with sugar, you can
Epoch 3 Train Rouge_1: 0.7677964172920628, Rouge_2: 0.640484129757239, Rouge_l_f1: 0.7674175983491897, Rouge_l_p: 0.7673977276033664, Rouge_l_r: 0.769023836775344


  0%|          | 0/514 [00:00<?, ?it/s]

Epoch 3 Test Rouge_1: 0.7267856692806354, Rouge_2: 0.579735313115506, Rouge_l_f1: 0.7262220822377976, Rouge_l_p: 0.7236563421247864, Rouge_l_r: 0.7303503062504708
To mix food coloring with sugar, you can ==> How mix food coloring with sugar, you can
To mix food coloring with sugar, you can ==> To mix food coloring with sugar, you can
Epoch 4 Train Rouge_1: 0.846716235963784, Rouge_2: 0.7629903713703107, Rouge_l_f1: 0.8465654530408262, Rouge_l_p: 0.8457142943896521, Rouge_l_r: 0.8484616220397284


  0%|          | 0/514 [00:00<?, ?it/s]

Epoch 4 Test Rouge_1: 0.7669677410611074, Rouge_2: 0.6335684390536279, Rouge_l_f1: 0.7665924656831932, Rouge_l_p: 0.764107113136773, Rouge_l_r: 0.7706034603049903
To mix food coloring with sugar, you can ==> How mix food coloring with sugar, you can


In [None]:
from rouge import Rouge
# Synthetic record used to probe how much of an unseen input the attacker recovers.
text = "Patient Name: Mr.Lawrence, Gender: Male, Ethnicity: White, Address: Harbin Institute of technology, Shenzhen, China. He's a Japanese man standing 165cm tall, always wearing a pair of pink glasses. He's in extreme danger now with a heartbeat of only 32/min;"
decoded = get_output(text, model, attack_model)
print(decoded)
# Rouge.get_scores expects (hypotheses, references): the recovered text is the
# hypothesis and the original text is the reference. Passing them the other way
# round swaps the reported precision and recall.
# avg=True averages over the batch (a single pair here).
result = Rouge().get_scores([decoded], [text], avg=True, ignore_empty=True)
print(result)