# Tutorial: Decoder Tuning on CPM-Bee
This tutorial applies decoder tuning (DecT) to CPM-Bee on the SST-2 sentiment classification dataset. Decoder tuning improves the model's language understanding without training the model itself. It adds a decoder network on the output side and trains only that network on a few labeled samples. We compare the 4-shot decoder tuning result with the zero-shot performance of the original model.

### 1. Process the dataset
Before training, we need to define and process the input data format. The raw example training data we use is shown below:

```python
# Example training data: 4 shots per class
train_data = [
    {"input": "without shakespeare 's eloquent language , the update is dreary and sluggish . ", "question": "What is the sentiment of this sentence?", "<ans>": "bad"},
    {"input": "final verdict : you 've seen it all before . ", "question": "What is the sentiment of this sentence?", "<ans>": "bad"},
    {"input": "the drama discloses almost nothing . ", "question": "What is the sentiment of this sentence?", "<ans>": "bad"},
    {"input": "an inexperienced director , mehta has much to learn . ", "question": "What is the sentiment of this sentence?", "<ans>": "bad"},
    {"input": "you live the mood rather than savour the story . ", "question": "What is the sentiment of this sentence?", "<ans>": "great"},
    {"input": "an edgy thriller that delivers a surprising punch . ", "question": "What is the sentiment of this sentence?", "<ans>": "great"},
    {"input": "nicholson 's understated performance is wonderful . ", "question": "What is the sentiment of this sentence?", "<ans>": "great"},
    {"input": "this is a gorgeous film - vivid with color , music and life . ", "question": "What is the sentiment of this sentence?", "<ans>": "great"},
]
```
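In `train_data` above, "4 shots" means four labeled examples per class, eight in total. The minimal sketch below checks this. `shots_per_class` is a hypothetical helper and is not part of `cpm_live`. To keep the snippet self-contained, it copies only the answer labels. In the notebook you can call `shots_per_class(train_data)` directly.

```python
from collections import Counter

# Answer labels of the eight examples in train_data, in order.
example_answers = ["bad", "bad", "bad", "bad", "great", "great", "great", "great"]
train_data_labels_only = [{"<ans>": a} for a in example_answers]


def shots_per_class(data):
    """Count how many training examples carry each answer label."""
    return Counter(item["<ans>"] for item in data)


counts = shots_per_class(train_data_labels_only)
print(dict(counts))  # {'bad': 4, 'great': 4}
```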

In this tutorial, we use the following input format for sentiment classification (you can also define other formats):
```
input: text
question: "What is the sentiment of this sentence?"
<ans>: bad/great 
```
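Every sample is a dict with the same three fields. The question is fixed, and `<ans>` holds the label word. The following sketch builds and validates samples in this format. `make_example` and `is_valid_example` are hypothetical helpers for illustration and are not part of CPM-Bee.

```python
QUESTION = "What is the sentiment of this sentence?"
LABEL_WORDS = {"bad", "great"}


def make_example(text, label_word):
    """Build one sample dict in the input/question/<ans> format used in this tutorial."""
    if label_word not in LABEL_WORDS:
        raise ValueError(f"unknown label word: {label_word!r}")
    return {"input": text, "question": QUESTION, "<ans>": label_word}


def is_valid_example(item):
    """Return True if a sample has exactly the three expected keys and a known label word."""
    return (
        set(item) == {"input", "question", "<ans>"}
        and item["question"] == QUESTION
        and item["<ans>"] in LABEL_WORDS
    )


sample = make_example("an edgy thriller that delivers a surprising punch . ", "great")
print(is_valid_example(sample))  # True
```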

```python
import os
import pandas as pd
import random
import sys
```

Set the random seed and add the source directory to the Python path:

```python
random.seed(123)
sys.path.append("../src")
```

Load the test set (the SST-2 `dev.tsv` split) and convert it into the same format as `train_data`:

```python
def load_test_dataset(dataset):
    """Load the dev split of `dataset` and convert it into the train_data format."""
    label_map = {"0": "bad", "1": "great"}
    # e.g. "sst2" -> ./decoder_tuning_data/raw_data/sst2/dev.tsv
    path = os.path.join("./decoder_tuning_data/raw_data", dataset.lower().split("-")[0], "dev.tsv")
    test_df = pd.read_csv(path, sep="\t")
    test_data = []
    for _, row in test_df.iterrows():
        test_data.append({
            "input": row["sentence"],
            "question": "What is the sentiment of this sentence?",
            "<ans>": label_map[str(row["label"])],
        })
    return test_data
```
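You can see the conversion that `load_test_dataset` performs without the data files. The minimal sketch below runs it on an in-memory TSV string, using the standard `csv` module instead of pandas. Like `load_test_dataset`, it assumes that `dev.tsv` has a header row with `sentence` and `label` columns. The two rows are made up, and `rows_to_examples` is a hypothetical helper.

```python
import csv
import io

LABEL_MAP = {"0": "bad", "1": "great"}
QUESTION = "What is the sentiment of this sentence?"


def rows_to_examples(tsv_text):
    """Convert SST-2 style TSV text (header: sentence<TAB>label) into sample dicts."""
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t")
    return [
        {"input": row["sentence"], "question": QUESTION, "<ans>": LABEL_MAP[row["label"]]}
        for row in reader
    ]


# Two made-up rows in the same two-column layout as dev.tsv.
tsv = "sentence\tlabel\nthe drama discloses almost nothing . \t0\na gorgeous film . \t1\n"
examples = rows_to_examples(tsv)
print([e["<ans>"] for e in examples])  # ['bad', 'great']
```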

```python
test_data = load_test_dataset('sst2')
```

Serialize the processed data into binary files using the `cpm_live` dataset format:

```python
import json
import os
import shutil
import sys

from cpm_live.dataset import build_dataset, shuffle_dataset
from tqdm import tqdm
```

```python
def build_bin_data(data, dataset_path, dataset_name):
    """Write `data` to a temporary dataset, then shuffle it into ./decoder_tuning_data/bin_data/<dataset_path>."""
    output_path = "./decoder_tuning_data/bin_data/"

    with build_dataset("tmp", "data") as dataset:
        for item in data:
            dataset.write(item)
    shuffle_dataset(
        "tmp",
        os.path.join(output_path, dataset_path),
        progress_bar=True,
        output_name=dataset_name,
    )
    shutil.rmtree("tmp")
```

```python
# Start from a clean state: remove previous outputs and any leftover temporary directory.
for path in ("./decoder_tuning_data/bin_data", "./tmp"):
    if os.path.exists(path):
        shutil.rmtree(path)
build_bin_data(train_data, "train_data", "example-data")
build_bin_data(test_data, "test_data", "example-data")
```

```
Shuffle step 1/2: 100%|██████████| 8/8 [00:00<00:00, 18117.94it/s]
Shuffle step 2/2: 100%|██████████| 1/1 [00:00<00:00, 3905.31it/s]
Shuffle step 1/2: 100%|██████████| 872/872 [00:00<00:00, 146022.80it/s]
Shuffle step 2/2: 100%|██████████| 1/1 [00:00<00:00, 74.69it/s]
```


### 2. Training
We define a `DecTCPM` class that implements decoder tuning on CPM-Bee. Its `run` method performs 4-shot decoder tuning, and its `run_zs` method evaluates the original model in the 0-shot setting. The pretrained model must be stored in the `./ckpt/` folder. Specifically, `./ckpt/` must already contain `./ckpt/config.json`, `./ckpt/pytorch_model.bin`, and `./ckpt/vocab.txt`.
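Before loading the model, you may want to confirm that the checkpoint folder is complete. The following is a minimal sketch. `missing_ckpt_files` is a hypothetical helper and is not part of `cpm_live`.

```python
import os

REQUIRED_CKPT_FILES = ("config.json", "pytorch_model.bin", "vocab.txt")


def missing_ckpt_files(ckpt_dir="./ckpt"):
    """Return the required checkpoint files that are not present in ckpt_dir."""
    return [
        name for name in REQUIRED_CKPT_FILES
        if not os.path.isfile(os.path.join(ckpt_dir, name))
    ]


missing = missing_ckpt_files()
if missing:
    print("Missing in ./ckpt:", ", ".join(missing))
```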

Initialize the hyperparameters and the verbalizer:

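```python
# Initialize the hyperparameters and the verbalizer
label_map = {"0": "bad", "1": "great"}  # verbalizer: label index -> label word
label_list = ["bad", "great"]
lr = 4e-3
proto_dim = 128  # dimension of the prototype space (mid_dim in DecTCPM)
model_logits_weight = 1  # weight of the model's label-word logits in the DecT score
max_epochs = 30
```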
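`DecTCPM` looks up label words as `verbalizer[str(idx)]` for `idx` from 0 to n-1. The verbalizer keys must therefore be the strings `"0"`, `"1"`, and so on. The label index order must also match `label_list`. The following minimal sketch checks both conditions. `check_verbalizer` is a hypothetical helper.

```python
label_map = {"0": "bad", "1": "great"}
label_list = ["bad", "great"]


def check_verbalizer(verbalizer, label_list):
    """Check that keys are "0".."n-1" and that the label-word order matches label_list."""
    expected_keys = [str(i) for i in range(len(verbalizer))]
    if sorted(verbalizer, key=int) != expected_keys:
        raise ValueError(f"verbalizer keys must be {expected_keys}, got {list(verbalizer)}")
    words = [verbalizer[k] for k in expected_keys]
    if words != list(label_list):
        raise ValueError(f"label order {words} does not match label_list {label_list}")
    return words


print(check_verbalizer(label_map, label_list))  # ['bad', 'great']
```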

Construct the decoder tuning trainer:

```python
import time
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn


class DecTCPM(object):
    r"""A runner for DecT, specially implemented for classification.

    Decoder Tuning: Efficient Language Understanding as Decoding: https://arxiv.org/pdf/2212.08408.pdf

    Args:
        model (:obj:`CPMBeeTorch`): One ``CPMBeeTorch`` object.
        test_dataloader (:obj:`FinetuneDataset`): The dataloader to batchify and process the test data.
        tokenizer (:obj:`CPMBeeTokenizer`): The tokenizer to process the words.
        verbalizer (:obj:`dict`): Maps each label index (as a string, "0" to "n-1") to its label word.
        device (:obj:`torch.device`): The device to run the model.
        calibrate_dataloader (:obj:`FinetuneDataset`, optional): A dataloader with empty inputs, intended for calibrating the output logits. Currently unused. Defaults to None.
        lr (:obj:`float`, optional): The learning rate. Defaults to 5e-3.
        hidden_size (:obj:`int`, optional): The hidden size of the model. Defaults to 4096.
        mid_dim (:obj:`int`, optional): The dimension of the proto vector. Defaults to 128.
        epochs (:obj:`int`, optional): The number of epochs to train. Defaults to 5.
        model_logits_weight (:obj:`float`, optional): The weight of the model logits. Defaults to 1.
    """

    def __init__(
        self,
        model,
        test_dataloader,
        tokenizer,
        verbalizer,
        device: Optional[Union[str, torch.device]] = "cuda:0",
        calibrate_dataloader: Optional[List] = None,
        lr: Optional[float] = 5e-3,
        hidden_size: Optional[int] = 4096,
        mid_dim: Optional[int] = 128,
        epochs: Optional[int] = 5,
        model_logits_weight: Optional[float] = 1,
    ):
        self.model = model
        self.test_dataloader = test_dataloader
        self.calibrate_dataloader = calibrate_dataloader
        self.device = device
        # Label words and their first token ids, ordered by label index "0", "1", ...
        self.label_list = [verbalizer[str(idx)] for idx in range(len(verbalizer))]
        self.ids = [tokenizer.encode(word)[0][0] for word in self.label_list]
        self.label_word_token_ids = self.ids
        self.num_classes = len(self.ids)
        self.lr = lr
        self.mid_dim = mid_dim
        self.epochs = epochs
        self.model_logits_weight = model_logits_weight
        self.hidden_dims = hidden_size
        self.reset_parameter()

    # Reset the parameters; useful when you want to test different random seeds.
    def reset_parameter(self):
        self.head = nn.Linear(self.hidden_dims, self.mid_dim, bias=False).to(self.device)
        w = torch.empty((self.num_classes, self.mid_dim), device=self.device)
        nn.init.xavier_uniform_(w)
        # The prototypes stay fixed; only the head and the per-class radii are trained.
        self.proto = nn.Parameter(w, requires_grad=False)
        self.proto_r = nn.Parameter(torch.ones(self.num_classes, device=self.device))
        self.optimizer = torch.optim.Adam(list(self.head.parameters()) + [self.proto_r], lr=self.lr)

    def _to_device(self, array, dtype=torch.int32):
        return torch.from_numpy(array).to(self.device, dtype=dtype)

    @staticmethod
    def _iter_batches(dataloader):
        """Yield the non-None batches of a dataloader."""
        seen_data = False
        for data in dataloader:
            if data is None:
                if not seen_data:
                    raise RuntimeError(
                        "Dataset is too small, please use a smaller batch size or sequence length!"
                    )
                # Skip the batch instead of reusing the previous one, so nothing is counted twice.
                continue
            seen_data = True
            yield data

    # Get the label-word logits and hidden states at the <ans> positions.
    # Specific to CPM-Bee batches; modify it for other models.
    def get_logits_and_hidden(self, data):
        input_ids = self._to_device(data["inputs"])
        input_ids_sub = self._to_device(data["inputs_sub"])
        input_length = self._to_device(data["length"])
        input_context = self._to_device(data["context"], torch.bool)
        input_sample_ids = self._to_device(data["sample_ids"])
        input_num_segments = self._to_device(data["num_segments"])
        input_segment_ids = self._to_device(data["segment_ids"])
        input_segment_rel_offset = self._to_device(data["segment_rel_offset"])
        input_segment_rel = self._to_device(data["segment_rel"])
        input_span = self._to_device(data["spans"])
        targets = self._to_device(data["target"])
        ext_table_ids = self._to_device(data["ext_ids"])
        ext_table_sub = self._to_device(data["ext_sub"])
        # Gold labels: target positions that hold one of the label-word tokens.
        label_mask = torch.zeros_like(targets, dtype=torch.bool)
        for token_id in self.ids:
            label_mask |= targets == token_id
        final_label = [self.ids.index(t) for t in targets[label_mask].tolist()]
        with torch.no_grad():
            logits, hidden_states = self.model(
                input_ids,
                input_ids_sub,
                input_length,
                input_context,
                input_sample_ids,
                input_num_segments,
                input_segment_ids,
                input_segment_rel_offset,
                input_segment_rel,
                input_span,
                ext_table_ids,
                ext_table_sub,
            )
        # Mask out targets equal to -100 or 7 to get the index of the valid (<ans>) positions.
        mask_matrix = targets.clone()
        mask_matrix[targets == -100] = 0
        mask_matrix[targets == 7] = 0
        index_mask = mask_matrix.nonzero(as_tuple=False)
        filtered_logits = logits[index_mask[:, 0], index_mask[:, 1], :]
        filtered_hiddens = hidden_states[index_mask[:, 0], index_mask[:, 1], :]
        label_logits = filtered_logits[:, self.label_word_token_ids]
        # Normalize the hidden states to prevent NaN outputs in the loss.
        normalize_hidden = F.normalize(filtered_hiddens, p=2, dim=-1)
        return label_logits, normalize_hidden, final_label

    # Accuracy on `dataloader`: zero-shot if `zs` is True, otherwise with the trained DecT head.
    def test(self, dataloader, zs):
        preds, labels = [], []
        for data in self._iter_batches(dataloader):
            logits, hidden_states, label = self.get_logits_and_hidden(data)
            if zs:
                scores = logits
            else:
                with torch.no_grad():
                    scores = self.sim(
                        self.head(F.normalize(hidden_states.float(), dim=-1)),
                        self.proto,
                        self.proto_r,
                        logits.float(),
                        self.model_logits_weight,
                    )
            preds.extend(torch.argmax(scores, dim=-1).cpu().tolist())
            labels.extend(label)
        return sum(int(p == y) for p, y in zip(preds, labels)) / len(preds)

    @staticmethod
    def sim(x, y, r=0, model_logits=0, model_logits_weight=1):
        x = torch.unsqueeze(x, -2)
        d = torch.norm((x - y), dim=-1)
        dist = d - model_logits * model_logits_weight - r
        return -dist

    # The decoder tuning loss: cross-entropy over the similarity scores.
    def loss_func(self, x, model_logits, labels):
        sim_mat = torch.exp(self.sim(x, self.proto, self.proto_r, model_logits, self.model_logits_weight))
        pos_score = torch.sum(sim_mat * F.one_hot(labels, num_classes=self.num_classes), -1)
        return -torch.mean(torch.log(pos_score / sim_mat.sum(-1)))

    # Run the zero-shot setting.
    def run_zs(self):
        res = self.test(self.test_dataloader, zs=True)
        print("zero shot acc:", res)

    # Train with decoder tuning; train_dataloader is a FinetuneDataset.
    def run(self, train_dataloader):
        logits_list, hidden_states_list, labels = [], [], []
        for data in self._iter_batches(train_dataloader):
            batch_logits, batch_embeds, label = self.get_logits_and_hidden(data)
            logits_list.append(batch_logits)
            hidden_states_list.append(batch_embeds)
            labels.extend(label)
        train_logits = torch.cat(logits_list, dim=0).float()
        train_embeds = torch.cat(hidden_states_list, dim=0).float()
        labels = torch.tensor(labels, device=self.device)
        start_time = time.time()

        # Initialize each class radius as the mean distance between the projected
        # embeddings of that class and their projected mean.
        with torch.no_grad():
            dist = []
            for c in range(self.num_classes):
                class_embeds = train_embeds[labels == c]
                dist.append(torch.norm(self.head(class_embeds) - self.head(class_embeds.mean(0)), dim=-1).mean())
            self.proto_r.data = torch.stack(dist)

        for epoch in range(self.epochs):
            x = self.head(train_embeds)
            self.optimizer.zero_grad()
            loss = self.loss_func(x, train_logits, labels)
            loss.backward()
            self.optimizer.step()
            print("Total epoch: {}. DecT loss: {}".format(epoch, loss.item()))

        print("Training time: {}".format(time.time() - start_time))
        res = self.test(self.test_dataloader, zs=False)
        print("dect acc:", res)
```
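You can trace the scoring rule in `sim` and the objective in `loss_func` without a GPU. The following pure-Python sketch uses made-up 2-D vectors and assumes the same rule as `DecTCPM.sim`. For class c, the score is -(||x - proto_c|| - w * model_logit_c - r_c). The loss is cross-entropy over these scores, written with a max-shifted log-sum-exp. This is mathematically equal to the exp/sum/log form in `loss_func`. `init_radius` mirrors how `run` initializes `proto_r`. Because the head is linear without bias, projecting the class mean equals averaging the projections. All function names here are hypothetical.

```python
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def dect_scores(x, protos, radii, model_logits, model_logits_weight=1.0):
    """Score of class c: -(||x - proto_c|| - w * model_logit_c - r_c)."""
    return [
        -(euclidean(x, p) - model_logits_weight * logit - r)
        for p, r, logit in zip(protos, radii, model_logits)
    ]


def dect_loss(batch, protos, radii, model_logits_weight=1.0):
    """Mean cross-entropy over the class scores; batch holds (x, model_logits, label) tuples."""
    total = 0.0
    for x, logits, label in batch:
        s = dect_scores(x, protos, radii, logits, model_logits_weight)
        m = max(s)  # subtract the max before exponentiating, for numerical stability
        log_norm = m + math.log(sum(math.exp(v - m) for v in s))
        total += log_norm - s[label]
    return total / len(batch)


def init_radius(projected):
    """Mean distance of one class's projected embeddings to their centroid."""
    dim = len(projected[0])
    centroid = [sum(v[i] for v in projected) / len(projected) for i in range(dim)]
    return sum(euclidean(v, centroid) for v in projected) / len(projected)


# Toy example: two fixed prototypes, radii of 0.5, and made-up model logits.
protos = [[1.0, 0.0], [0.0, 1.0]]
radii = [0.5, 0.5]
x = [0.9, 0.1]
scores = dect_scores(x, protos, radii, model_logits=[0.2, 0.1])
pred = max(range(len(scores)), key=scores.__getitem__)
loss = dect_loss([(x, [0.2, 0.1], 0)], protos, radii)
print(pred, loss)
```

Setting `model_logits_weight` to 0 makes the score depend only on the distances in the learned space. Larger values lean more on the frozen model's own label-word logits.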

Load the model weights and deploy the model on the GPU:

```python
import torch
from cpm_live.models import CPMBeeConfig, CPMBeeTorch
from cpm_live.tokenizers import CPMBeeTokenizer
from cpm_live.training_tasks.bee import FinetuneDataset

config = CPMBeeConfig.from_json_file("./ckpt/config.json")
ckpt_path = "./ckpt/pytorch_model.bin"
tokenizer = CPMBeeTokenizer()
model = CPMBeeTorch(config=config)
# strict=False ignores missing or unexpected keys in the checkpoint.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
device = torch.device("cuda:0")
model.to(device)
model.eval()
```


Build the dataloaders:

```python
train_dataloader = FinetuneDataset(
    dataset_path="./decoder_tuning_data/bin_data/train_data",
    batch_size=4,
    max_length=512,
    max_depth=8,
    tokenizer=tokenizer,
)
test_dataloader = FinetuneDataset(
    dataset_path="./decoder_tuning_data/bin_data/test_data",
    batch_size=4,
    max_length=512,
    max_depth=8,
    tokenizer=tokenizer,
)
```

Build the runner:

```python
runner = DecTCPM(
    model=model,
    test_dataloader=test_dataloader,
    tokenizer=tokenizer,
    verbalizer=label_map,
    device=device,
    calibrate_dataloader=None,
    lr=lr,
    mid_dim=proto_dim,
    epochs=max_epochs,
    model_logits_weight=model_logits_weight,
)
```

Start training. The output reports the zero-shot accuracy first, then the decoder tuning accuracy:
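`run_zs` classifies each test sample by taking the argmax over the model's logits for the label words at the answer position. It then reports the fraction of correct predictions. The following pure-Python sketch applies that rule to made-up logits. `zero_shot_predict` and `accuracy` are hypothetical helpers.

```python
def zero_shot_predict(label_word_logits):
    """Pick the class whose label word has the highest logit at the answer position."""
    return max(range(len(label_word_logits)), key=label_word_logits.__getitem__)


def accuracy(preds, labels):
    """Fraction of predictions that equal the gold label."""
    if not preds:
        raise ValueError("no predictions to score")
    return sum(p == y for p, y in zip(preds, labels)) / len(preds)


# Made-up logits for the label words ["bad", "great"] on three samples.
logits = [[2.1, 0.3], [0.4, 1.7], [1.2, 1.5]]
gold = [0, 1, 0]
preds = [zero_shot_predict(row) for row in logits]
print(preds, accuracy(preds, gold))
```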

```python
runner.run_zs()
runner.run(train_dataloader)
```

zero shot acc: 0.8623853211009175




Total epoch: 0. DecT loss: 0.4929628372192383
Total epoch: 1. DecT loss: 0.34562215209007263
Total epoch: 2. DecT loss: 0.3390842080116272
Total epoch: 3. DecT loss: 0.3364109992980957
Total epoch: 4. DecT loss: 0.33365967869758606
Total epoch: 5. DecT loss: 0.3307572603225708
Total epoch: 6. DecT loss: 0.3278043270111084
Total epoch: 7. DecT loss: 0.32488542795181274
Total epoch: 8. DecT loss: 0.32205426692962646
Total epoch: 9. DecT loss: 0.3193338215351105
Total epoch: 10. DecT loss: 0.31672176718711853
Total epoch: 11. DecT loss: 0.31419551372528076
Total epoch: 12. DecT loss: 0.3117189109325409
Total epoch: 13. DecT loss: 0.3092496395111084
Total epoch: 14. DecT loss: 0.30674344301223755
Total epoch: 15. DecT loss: 0.3041604459285736
Total epoch: 16. DecT loss: 0.3014650344848633
Total epoch: 17. DecT loss: 0.29862841963768005
Total epoch: 18. DecT loss: 0.2956272065639496
Total epoch: 19. DecT loss: 0.2924429774284363
Total epoch: 20. DecT loss: 0.28906041383743286
Total epoch: 2