# Tutorial: Decoder Tuning on CPM-Bee
This tutorial applies decoder tuning (DecT) to CPM-Bee on the SST-2 sentiment classification dataset. Decoder tuning keeps the pre-trained model frozen. It adds a small decoder network on the output side and trains only that network on a few labeled examples, which improves the model's language understanding. We compare the 16-shot decoder-tuning results with the zero-shot accuracy of the original model.

### 1. Process the dataset
Before training, we define the input format and write a dataset processor class that converts the raw data into that format.

In this tutorial, each sentiment classification example uses the following format (you can also define your own):
```
input: text
question: "What is the sentiment of this sentence?"
<ans>: bad/great 
```
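The sketch below shows how one raw SST-2 row maps onto this format. It mirrors what `SST2Processor.get_examples` does for a single line. The helper `row_to_example` and the constants `QUESTION`/`VERBALIZER` are hypothetical names used only for illustration. The sample sentence is the first training example printed later in this tutorial.

```python
# Hypothetical helper: converts a single SST-2 row the same way SST2Processor does.
QUESTION = "What is the sentiment of this sentence?"
VERBALIZER = {"0": "bad", "1": "great"}  # label id -> label word the model must generate


def row_to_example(line):
    """Convert one tab-separated SST-2 row '<sentence><TAB><label>' into the prompt dict."""
    text, label = line.rstrip("\n").split("\t")[:2]
    return {"input": text, "question": QUESTION, "<ans>": VERBALIZER[label]}


row = "a stirring , funny and finally transporting re-imagining of beauty and the beast and 1930s horror films\t1"
print(row_to_example(row))
```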

```python
import os


class SST2Processor:
    """
    `SST-2 <https://nlp.stanford.edu/sentiment/index.html>`_ is a dataset for sentiment analysis. It is a modified version of the original 5-label dataset released in `Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank <https://aclanthology.org/D13-1170.pdf>`_ that keeps only binary labels (negative or somewhat negative vs. somewhat positive or positive, with neutral sentences discarded).

    We use the data released in `Making Pre-trained Language Models Better Few-shot Learners (Gao et al. 2020) <https://arxiv.org/pdf/2012.15723.pdf>`_
    """

    # split name -> TSV file name (without extension)
    dataset_project = {"train": "train", "dev": "dev", "test": "test"}

    def __init__(self):
        self.labels = ["0", "1"]
        # label id -> label word the model should generate at <ans>
        self.verbalizer = {"0": "bad", "1": "great"}
        self.label_word = ["bad", "great"]

    def get_examples(self, data_dir, split, shot=-1):
        """Read <data_dir>/<split>.tsv and return a list of example dicts.

        With shot > 0, keep at most `shot` examples per label (the first ones in file order);
        shot=-1 keeps every example.
        """
        self.counts = {word: 0 for word in self.label_word}
        path = os.path.join(data_dir, f"{self.dataset_project[split]}.tsv")
        examples = []
        with open(path, encoding="utf-8") as f:
            lines = f.readlines()
        # skip the header row; each row is "<sentence>\t<label>"
        for line in lines[1:]:
            text_a, label = line.strip().split("\t")[:2]
            example = {
                "input": text_a,
                "question": "What is the sentiment of this sentence?",
                "<ans>": self.verbalizer[label],
            }
            if shot != -1 and self.counts[example["<ans>"]] >= shot:
                continue  # this label already has `shot` examples
            examples.append(example)
            self.counts[example["<ans>"]] += 1
        return examples
```

Register a processor class for each dataset here. Adding further processors lets you support other data formats.

```python
# dataset name -> processor class
PROCESSORS = {
    "sst2": SST2Processor
}
dataset_name = "sst2"
```

Fix the random seed and add the source directory to the Python path.

```python
import random
import sys

random.seed(123)
sys.path.append("../src")  # make the CPM-Bee sources (cpm_live) importable
```

Preprocess the data into the format above and store the processed data as binary files. The training and validation sets come from the `train` and `dev` files of SST-2, keeping 16 examples per class (16-shot). The test set is the full SST-2 `test` file. The raw files are expected under `./decoder_tuning_data/raw_data/`, in the `sst2` subdirectory, as `train.tsv`, `dev.tsv` and `test.tsv`.
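The per-class cap applied by `get_examples(..., shot=16)` can be isolated as a small function. The sketch below is a hypothetical standalone version of that logic (`take_k_per_class` is not part of the tutorial's code) run on toy data. Like the processor, it keeps the *first* k examples of each class in file order rather than a random sample, so the 16-shot split is deterministic. With two labels and `shot=16`, each few-shot split holds at most 32 examples, which matches the `32/32` in the shuffle log below.

```python
from collections import Counter


def take_k_per_class(examples, k, key="<ans>"):
    """Keep the first k examples of each class in input order; k = -1 keeps everything.

    Hypothetical standalone version of the `shot` logic in SST2Processor.get_examples.
    """
    if k == -1:
        return list(examples)
    seen = Counter()
    kept = []
    for ex in examples:
        if seen[ex[key]] < k:  # this class still has room
            kept.append(ex)
            seen[ex[key]] += 1
    return kept


# toy data: 5 "bad" and 3 "great" examples
toy = [{"<ans>": "bad"}] * 5 + [{"<ans>": "great"}] * 3
subset = take_k_per_class(toy, k=2)
print(Counter(ex["<ans>"] for ex in subset))
```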

```python
import os
import shutil

from cpm_live.dataset import build_dataset, shuffle_dataset


def build_bin_data(data, output_dir, dataset_path, dataset_name):
    """Write `data` (a list of dicts) as a shuffled binary dataset under
    ./decoder_tuning_data/bin_data/<output_dir>/<dataset_path>/."""
    output_path = os.path.join("./decoder_tuning_data/bin_data/", output_dir)
    # write the examples into a temporary dataset, then shuffle them into place
    with build_dataset("tmp", "data") as dataset:
        for item in data:
            dataset.write(item)
    shuffle_dataset(
        "tmp",
        os.path.join(output_path, dataset_path),
        progress_bar=True,
        output_name=dataset_name,
    )
    shutil.rmtree("tmp")


processor = PROCESSORS[dataset_name]()
path = os.path.join("decoder_tuning_data/raw_data", dataset_name)
train_dataset = processor.get_examples(path, "train", shot=16)  # 16 per class
valid_dataset = processor.get_examples(path, "dev", shot=16)    # 16 per class
test_dataset = processor.get_examples(path, "test")             # full test set

# remove outputs left over from previous runs
shutil.rmtree(os.path.join("./decoder_tuning_data/bin_data", dataset_name), ignore_errors=True)
shutil.rmtree("./tmp", ignore_errors=True)

output_dir = dataset_name
build_bin_data(train_dataset, output_dir, "train_data", "example-data")
build_bin_data(valid_dataset, output_dir, "valid_data", "example-data")
build_bin_data(test_dataset, output_dir, "test_data", "example-data")
verbalizer = processor.verbalizer
label_word = processor.label_word

print(train_dataset[0])
print(verbalizer)
print(label_word)
```

```
Shuffle step 1/2: 100%|██████████| 32/32 [00:00<00:00, 28795.91it/s]
Shuffle step 2/2: 100%|██████████| 1/1 [00:00<00:00, 3097.71it/s]
Shuffle step 1/2: 100%|██████████| 32/32 [00:00<00:00, 72354.57it/s]
Shuffle step 2/2: 100%|██████████| 1/1 [00:00<00:00, 3184.74it/s]
Shuffle step 1/2: 100%|██████████| 1821/1821 [00:00<00:00, 327241.97it/s]
Shuffle step 2/2: 100%|██████████| 1/1 [00:00<00:00, 304.09it/s]

32
{'input': 'a stirring , funny and finally transporting re-imagining of beauty and the beast and 1930s horror films', 'question': 'What is the sentiment of this sentence?', '<ans>': 'great'}
{'0': 'bad', '1': 'great'}
['bad', 'great']
```





### 2. Training
We define a `DecTCPM` class that implements decoder tuning for CPM-Bee. Its `run` method performs 16-shot decoder tuning, and `run_zs` evaluates the original model zero-shot. The pre-trained model must be stored in the `./ckpt/` folder. Specifically, `./ckpt/` must already contain `./ckpt/config.json`, `./ckpt/pytorch_model.bin` and `./ckpt/vocab.txt`.
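Before running the cells below, you can check that the checkpoint folder is complete. The following is a minimal sketch: the helper `missing_ckpt_files` is hypothetical and not part of `cpm_live`. It only tests that the three required files exist and does not validate their contents.

```python
import os

REQUIRED_CKPT_FILES = ("config.json", "pytorch_model.bin", "vocab.txt")


def missing_ckpt_files(ckpt_dir="./ckpt"):
    """Return the required checkpoint files that are not present in ckpt_dir."""
    return [name for name in REQUIRED_CKPT_FILES
            if not os.path.isfile(os.path.join(ckpt_dir, name))]


missing = missing_ckpt_files()
if missing:
    print("missing from ./ckpt:", missing)
```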

Initialize the hyperparameters.

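```python
# hyperparameters for decoder tuning
lr = 4e-3                # learning rate for the head and the class radii
proto_dim = 128          # dimension of the projected space and of the class prototypes (mid_dim)
model_logits_weight = 1  # weight of the frozen model's label-word logits in the class scores
max_epochs = 100         # number of training epochs (one full-batch update each)
```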

Next we build the decoder-tuning trainer. Its trainable part is a single linear layer (`self.head`). This layer projects the model's hidden state at the answer position into a low-dimensional space. In that space, each class is represented by a fixed random prototype and a learned radius. With more training data, you can add layers to the head to increase its capacity.
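The class score computed by `DecTCPM.sim` can be illustrated in plain Python. This minimal sketch mirrors that method for one example. It uses the following symbols:

- x is the projected hidden state, and x̂ is x normalized to unit length (with the same 1e-12 floor as `F.normalize`).
- p_c is the class prototype.
- r_c is the learned radius.
- logit_c is the frozen model's logit for the class's label word.
- w is `model_logits_weight`.

The score is then r_c + w·logit_c − ‖x̂ − p_c‖. The function name `class_scores` and the toy vectors are illustrative only.

```python
import math


def class_scores(x, protos, radii, model_logits, model_logits_weight=1.0):
    """Plain-Python version of the DecT class score for a single example.

    x: projected hidden state (list of floats)
    protos: one prototype vector per class
    radii: learned radius r_c per class
    model_logits: frozen model's logit for each class's label word
    """
    norm = max(math.sqrt(sum(v * v for v in x)), 1e-12)  # avoid dividing by zero
    x_hat = [v / norm for v in x]
    scores = []
    for p, r, logit in zip(protos, radii, model_logits):
        dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(x_hat, p)))
        scores.append(r + model_logits_weight * logit - dist)
    return scores


# toy example: x points exactly at the first prototype after normalization
scores = class_scores([3.0, 4.0], protos=[[0.6, 0.8], [-0.6, -0.8]],
                      radii=[1.0, 1.0], model_logits=[0.0, 0.0])
print(scores)
```

The predicted class is the one with the highest score. Because the model's own logits are added to the score, the frozen model's zero-shot preference is kept unless the trained head outweighs it.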

```python
import time
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn


class DecTCPM(object):
    r"""A runner for DecT, specially implemented for classification.

    Decoder Tuning: Efficient Language Understanding as Decoding: https://arxiv.org/pdf/2212.08408.pdf

    Args:
        model (:obj:`CPMBeeTorch`): One ``CPMBeeTorch`` object; it is kept frozen.
        test_dataloader (:obj:`FinetuneDataset`): The dataloader that batchifies and processes the test data.
        val_dataloader (:obj:`FinetuneDataset`): The dataloader for the validation data.
        tokenizer (:obj:`CPMBeeTokenizer`): The tokenizer used to encode the label words.
        verbalizer (:obj:`dict`): Maps each label id (``"0"``, ``"1"``, ...) to its label word.
        device (:obj:`torch.device`): The device to run the model on.
        calibrate_dataloader (:obj:`FinetuneDataset`, optional): A dataloader with empty inputs, meant for calibrating the output logits. Currently unused. Defaults to None.
        lr (:obj:`float`, optional): The learning rate. Defaults to 5e-3.
        hidden_size (:obj:`int`, optional): The hidden size of the model. Defaults to 4096.
        mid_dim (:obj:`int`, optional): The dimension of the prototype vectors. Defaults to 128.
        epochs (:obj:`int`, optional): The number of epochs to train. Defaults to 5.
        model_logits_weight (:obj:`float`, optional): The weight of the model logits in the class scores. Defaults to 1.
    """

    def __init__(self,
                 model,
                 test_dataloader,
                 val_dataloader,
                 tokenizer,
                 verbalizer: Dict[str, str],
                 device: Union[str, torch.device] = "cuda:0",
                 calibrate_dataloader: Optional[List] = None,
                 lr: float = 5e-3,
                 hidden_size: int = 4096,
                 mid_dim: int = 128,
                 epochs: int = 5,
                 model_logits_weight: float = 1,
                 ):
        self.model = model
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.calibrate_dataloader = calibrate_dataloader
        self.device = device
        # label words ordered by label id; only the first token of each label word is used
        self.label_list = [verbalizer[str(idx)] for idx in range(len(verbalizer))]
        self.ids = [tokenizer.encode(word)[0][0] for word in self.label_list]
        self.label_word_token_ids = self.ids
        self.num_classes = len(self.ids)
        self.lr = lr
        self.mid_dim = mid_dim
        self.epochs = epochs
        self.model_logits_weight = model_logits_weight
        self.hidden_dims = hidden_size
        self.reset_parameter()

    # reset the parameters, useful when you want to test different random seeds
    # self.head is a linear layer; replace it with a deeper network when more data is available
    def reset_parameter(self):
        self.head = nn.Linear(self.hidden_dims, self.mid_dim, bias=False).to(self.device)
        # one fixed (non-trainable) random prototype per class
        w = torch.empty((self.num_classes, self.mid_dim), device=self.device)
        nn.init.xavier_uniform_(w)
        self.proto = nn.Parameter(w, requires_grad=False)
        # one learnable radius per class
        self.proto_r = nn.Parameter(torch.ones(self.num_classes, device=self.device), requires_grad=True)
        self.optimizer = torch.optim.Adam(list(self.head.parameters()) + [self.proto_r], lr=self.lr)

    @staticmethod
    def _batches(dataloader):
        """Iterate over a FinetuneDataset; a ``None`` batch is replaced by the previous batch."""
        last_data = None
        for data in dataloader:
            if data is None:
                if last_data is None:
                    raise RuntimeError(
                        "Dataset is too small, please use a smaller batch size or sequence length!"
                    )
                data = last_data  # reuse the last batch
            else:
                last_data = data
            yield data

    # get the label-word logits and hidden states at the <ans> position; specific to CPM-Bee
    def get_logits_and_hidden(self, data):
        def to_device(key, dtype=torch.int32):
            return torch.from_numpy(data[key]).to(self.device).to(dtype)

        targets = to_device("target")
        # recover each gold label from the label-word token in the targets
        mask = torch.zeros_like(targets, dtype=torch.bool)
        for token_id in self.ids:
            mask |= targets == token_id
        final_label = [self.ids.index(t) for t in targets[mask].tolist()]
        with torch.no_grad():
            logits, hidden_states = self.model(
                to_device("inputs"),
                to_device("inputs_sub"),
                to_device("length"),
                to_device("context", torch.bool),
                to_device("sample_ids"),
                to_device("num_segments"),
                to_device("segment_ids"),
                to_device("segment_rel_offset"),
                to_device("segment_rel"),
                to_device("spans"),
                to_device("ext_ids"),
                to_device("ext_sub"),
            )
        # keep only the <ans> word positions: drop ignored (-100), zero and token-id-7 targets,
        # as in the original recipe; this must leave exactly one position per example
        valid = (targets != -100) & (targets != 7) & (targets != 0)
        filtered_logits = logits[valid]
        filtered_hiddens = hidden_states[valid]
        label_logits = filtered_logits[:, self.label_word_token_ids]
        return label_logits, filtered_hiddens, final_label

    # accuracy on `dataloader`: zero-shot if zs is True, otherwise with the trained DecT head
    def test(self, dataloader, zs):
        preds, labels = [], []
        with torch.no_grad():
            for data in self._batches(dataloader):
                logits, hidden_states, label = self.get_logits_and_hidden(data)
                if zs:
                    scores = logits
                else:
                    scores = self.sim(self.head(hidden_states.float()), self.proto, self.proto_r,
                                      logits.float(), self.model_logits_weight)
                preds.extend(torch.argmax(scores, dim=-1).cpu().tolist())
                labels.extend(label)
        return sum(int(p == y) for p, y in zip(preds, labels)) / len(preds)

    @staticmethod
    def sim(x, y, r=0, model_logits=0, model_logits_weight=1):
        """Class scores r_c + w * model_logit_c - ||normalize(x) - y_c||."""
        x = F.normalize(torch.unsqueeze(x, -2), dim=-1)  # (N, 1, mid_dim)
        d = torch.norm(x - y, dim=-1)  # (N, num_classes): distance to each prototype
        dist = d - model_logits * model_logits_weight - r
        return -dist

    # DecT loss: softmax cross-entropy over the class scores,
    # i.e. -mean(log(exp(s_y) / sum_c exp(s_c))), computed in a numerically stable way
    def loss_func(self, x, model_logits, labels):
        scores = self.sim(x, self.proto, self.proto_r, model_logits, self.model_logits_weight)
        return F.cross_entropy(scores, labels)

    # run the zero-shot setting
    def run_zs(self):
        res = self.test(self.test_dataloader, zs=True)
        print("zero shot acc:", res)

    # train with decoder tuning; train_dataloader is a FinetuneDataset
    def run(self, train_dataloader):
        # cache the frozen model's outputs for all training examples
        logits_list, hidden_states_list, labels = [], [], []
        for data in self._batches(train_dataloader):
            train_logits, train_embeds, label = self.get_logits_and_hidden(data)
            logits_list.append(train_logits)
            hidden_states_list.append(train_embeds)
            labels.extend(label)
        model_logits = torch.cat(logits_list, dim=0).float()
        train_embeds = torch.cat(hidden_states_list, dim=0).float()
        labels = torch.tensor(labels, device=train_embeds.device)  # int64 class indices
        start_time = time.time()

        # initialize each radius to the mean distance between a class's projected
        # examples and the projection of their mean
        with torch.no_grad():
            self.proto_r.data = torch.stack([
                torch.norm(self.head(e) - self.head(e.mean(0)), dim=-1).mean()
                for e in (train_embeds[labels == c] for c in range(self.num_classes))
            ])

        best_eval_res = 0.
        for epoch in range(self.epochs):
            # one full-batch update per epoch (the few-shot training set is tiny)
            x = self.head(train_embeds)
            self.optimizer.zero_grad()
            loss = self.loss_func(x, model_logits, labels)
            loss.backward()
            self.optimizer.step()
            # every 20 epochs, evaluate on the validation set and report test accuracy when it improves
            if epoch % 20 == 0 and epoch > 0:
                print("Total epoch: {}. DecT loss: {}".format(epoch, loss.item()))
                eval_res = self.test(self.val_dataloader, zs=False)
                print("val acc:", eval_res)
                if eval_res > best_eval_res:
                    best_eval_res = eval_res
                    test_res = self.test(self.test_dataloader, zs=False)
                    print("test acc at best val:", test_res)

        end_time = time.time()
        print("Total time: {}".format(end_time - start_time))
        # the final accuracy uses the head after the last epoch, not the best-validation one
        res = self.test(self.test_dataloader, zs=False)
        print("Final acc:", res)
```
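Two steps in `DecTCPM` are easy to check by hand:

- **Loss.** `loss_func` is a softmax cross-entropy over the class scores, that is −log(exp(s_y) / Σ_c exp(s_c)).
- **Radius initialization.** `run` sets each radius r_c to the mean distance between a class's projected examples and the projection of their mean. Because the head is linear without bias, that projected mean equals the mean of the projections.

Below is a plain-Python sketch of both steps. The helper names `dect_loss` and `init_radius` are hypothetical and operate on plain lists rather than tensors.

```python
import math


def dect_loss(scores, label):
    """Softmax cross-entropy for one example: -log(exp(s_y) / sum_c exp(s_c))."""
    m = max(scores)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[label]


def init_radius(projected):
    """Mean Euclidean distance of one class's projected examples to their mean vector."""
    dim = len(projected[0])
    mean = [sum(v[i] for v in projected) / len(projected) for i in range(dim)]
    return sum(math.dist(v, mean) for v in projected) / len(projected)


print(dect_loss([0.0, 0.0], 0))             # equal scores: loss is log(2)
print(init_radius([[0.0, 0.0], [2.0, 0.0]]))  # both points are 1.0 from their mean
```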

Load the model weights and move the model to the GPU.

```python
import torch
from cpm_live.models import CPMBeeConfig, CPMBeeTorch
from cpm_live.tokenizers import CPMBeeTokenizer
from cpm_live.training_tasks.bee import FinetuneDataset

config = CPMBeeConfig.from_json_file("./ckpt/config.json")
ckpt_path = "./ckpt/pytorch_model.bin"
tokenizer = CPMBeeTokenizer()
model = CPMBeeTorch(config=config)
# strict=False: checkpoint keys that do not match the model are silently ignored
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
device = torch.device("cuda:0")
model.to(device)
model.eval()  # the model is only used for inference; decoder tuning never updates its weights
```

Build the dataloaders for the training, validation and test sets.

```python
train_dataloader = FinetuneDataset(
    dataset_path="./decoder_tuning_data/bin_data/{}/train_data".format(dataset_name),
    batch_size=8,
    max_length=512,
    max_depth=8,
    tokenizer=tokenizer,
)
val_dataloader = FinetuneDataset(
    dataset_path="./decoder_tuning_data/bin_data/{}/valid_data".format(dataset_name),
    batch_size=8,
    max_length=512,
    max_depth=8,
    tokenizer=tokenizer,
)
test_dataloader = FinetuneDataset(
    dataset_path="./decoder_tuning_data/bin_data/{}/test_data".format(dataset_name),
    batch_size=8,
    max_length=512,
    max_depth=8,
    tokenizer=tokenizer,
)
```

Build the runner.

```python
runner = DecTCPM(
    model=model,
    test_dataloader=test_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    verbalizer=verbalizer,
    device=device,
    calibrate_dataloader=None,
    lr=lr,
    hidden_size=config.dim_model,  # must match the model's hidden size
    mid_dim=proto_dim,
    epochs=max_epochs,
    model_logits_weight=model_logits_weight,
)
```

Start training. The cell first prints the zero-shot accuracy of the original model. It then trains the decoder-tuning head and prints the validation and test accuracy of decoder tuning.
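Both kinds of accuracy are computed the same way as in `DecTCPM.test`. The predicted class is the argmax over the per-class scores: the label-word logits for zero-shot, or the DecT scores otherwise. Accuracy is the fraction of predictions that match the gold class index. The sketch below shows this in plain Python with hypothetical helper names and toy scores.

```python
def predict(scores_per_example):
    """Argmax over classes for each example (ties go to the lowest class index)."""
    return [max(range(len(s)), key=s.__getitem__) for s in scores_per_example]


def accuracy(preds, labels):
    """Fraction of examples whose predicted class equals the gold class."""
    if not preds or len(preds) != len(labels):
        raise ValueError("preds and labels must be non-empty and of equal length")
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(preds)


preds = predict([[0.1, 2.3], [1.5, -0.2], [0.0, 0.4]])
print(preds, accuracy(preds, [1, 0, 0]))
```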

```python
runner.run_zs()               # zero-shot accuracy of the original model
runner.run(train_dataloader)  # 16-shot decoder tuning
```

```
zero shot acc: 0.8791872597473915
Total epoch: 20. DecT loss: 0.0790218934416771
val acc: 0.9375
test acc at best val: 0.9192751235584844
Total epoch: 40. DecT loss: 0.05866580083966255
val acc: 0.9375
Total epoch: 60. DecT loss: 0.05533261597156525
val acc: 0.9375
Total epoch: 80. DecT loss: 0.0544731467962265
val acc: 0.9375
Total time: 36.721224308013916
Final acc: 0.9203734211971444
```
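To put these numbers in perspective, the final decoder-tuning accuracy can be compared with the zero-shot accuracy from the log. The two values below are copied from that log. The variable names are illustrative.

```python
# accuracies copied from the training log
zero_shot_acc = 0.8791872597473915
dect_final_acc = 0.9203734211971444

gain_points = (dect_final_acc - zero_shot_acc) * 100  # absolute gain in percentage points
error_reduction = (dect_final_acc - zero_shot_acc) / (1 - zero_shot_acc)  # share of zero-shot errors removed

print(f"gain: {gain_points:.2f} points, error reduction: {error_reduction:.1%}")
```

With 16 examples per class and only the small head trained, accuracy rises by about 4.1 points. That removes roughly a third of the zero-shot model's test errors.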
