In [1]:
import torch
import json
import numpy as np

Load a model.

In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
model_path = "../../model/llama_2.7b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).half().to(torch.device('cuda'))

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Load the datasets with the get_data_reader method, which reads from a JSON file. The JSON file must have the following format:
{ \
   'data_info':{\
    'data_name': 'rte_train',\
    'label_space': ['entailment', 'not_entailment'],\
    'columns': ['premise', 'hypothesis', 'label']\
   },\
   'data':[\
    sample1,\
    sample2,\
    ...\
   ]\
}\
\
When reading the data, pass the label_map argument to convert the dataset's default labels into the labels you want.
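To make the file layout and the effect of label_map concrete, here is a minimal sketch that writes a tiny file in this format and reads it back with remapped labels. The helper `load_with_label_map` and the toy samples are hypothetical and only illustrate the idea; the actual get_data_reader in src.data_reader may work differently.

```python
import json
import os
import tempfile


def load_with_label_map(path, label_map=None):
    """Hypothetical helper: read a file in the format above and remap labels."""
    with open(path, encoding="utf-8") as f:
        content = json.load(f)
    samples = content["data"]
    if label_map is not None:
        # Build new dicts so the parsed data is not modified in place.
        samples = [{**s, "label": label_map[s["label"]]} for s in samples]
    return content["data_info"], samples


# A tiny dataset in the expected layout (the sample text is made up).
toy = {
    "data_info": {
        "data_name": "toy_rte",
        "label_space": ["entailment", "not_entailment"],
        "columns": ["premise", "hypothesis", "label"],
    },
    "data": [
        {"premise": "A cat sleeps.", "hypothesis": "An animal sleeps.",
         "label": "entailment"},
        {"premise": "A cat sleeps.", "hypothesis": "A cat runs.",
         "label": "not_entailment"},
    ],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_rte.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(toy, f)
    info, samples = load_with_label_map(
        path, label_map={"entailment": "yes", "not_entailment": "no"}
    )

print(info["data_name"], [s["label"] for s in samples])
```

Note that the sketch writes standard JSON with double-quoted strings, which is what `json.load` expects.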

In [4]:
from src.data_reader import get_data_reader, DataReader

# label_map rewrites each dataset's original labels into the words used in the prompt
trainset = get_data_reader(file_name='rte_train.json', label_map={'entailment': 'yes', 'not_entailment': 'no'})
testset = get_data_reader(file_name='scitail.json', label_map={0: 'yes', 1: 'no'})
trainset.data_info

{'data_name': 'rte_train',
 'label_space': ['entailment', 'not_entailment'],
 'columns': ['premise', 'hypothesis', 'label']}

Each sample is a dictionary made up of several keys (label is required), for example:

In [5]:
trainset[0]

{'premise': 'No Weapons of Mass Destruction Found in Iraq Yet.',
 'hypothesis': 'Weapons of Mass Destruction Found in Iraq.',
 'label': 'no'}

Initialize a prompter to build the context. A template is required and a head is optional. In the template, mark each sample key that should be substituted with [].

In [6]:
from src.prompter import Prompter
template = "Premise:[premise]\nHypothesis:[hypothesis]\nAnswer:[label]"

prompter = Prompter(template=template, head="This is a head", sep='\n')

The prompter's generate_context method turns the demos and the sample into a prompt.

In [7]:
print(prompter.generate_context(trainset[:3], testset[0]))

This is a head
Premise:No Weapons of Mass Destruction Found in Iraq Yet.
Hypothesis:Weapons of Mass Destruction Found in Iraq.
Answer:no
Premise:A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.
Hypothesis:Pope Benedict XVI is the new leader of the Roman Catholic Church.
Answer:yes
Premise:Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.
Hypothesis:Herceptin can be used to treat breast cancer.
Answer:yes
Premise:Pluto rotates once on its axis every 6.39 Earth days;
Hypothesis:Earth rotates on its axis once times in one day.
Answer:
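The output above suggests how the prompt is assembled: the head comes first, each demo is rendered with its label, and the query sample is rendered last with the label left empty. A minimal sketch of that behaviour follows. The functions `fill_template` and `build_context` are hypothetical stand-ins, not the Prompter implementation, and the toy samples are made up.

```python
import re


def fill_template(template, sample, blank=()):
    """Replace every [key] with sample[key]; keys listed in `blank` become empty."""
    def substitute(match):
        key = match.group(1)
        return "" if key in blank else str(sample[key])
    return re.sub(r"\[(\w+)\]", substitute, template)


def build_context(template, demos, sample, head=None, sep="\n"):
    """Join the head, the filled demos, and the query with its label left blank."""
    parts = [head] if head else []
    parts += [fill_template(template, demo) for demo in demos]
    parts.append(fill_template(template, sample, blank=("label",)))
    return sep.join(parts)


template = "Premise:[premise]\nHypothesis:[hypothesis]\nAnswer:[label]"
demos = [{"premise": "A cat sleeps.", "hypothesis": "An animal sleeps.",
          "label": "yes"}]
query = {"premise": "It rains.", "hypothesis": "It is sunny.", "label": "no"}

context = build_context(template, demos, query, head="This is a head")
print(context)
```

Leaving the query's label blank is what lets the model fill in the answer after the final "Answer:".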


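Demonstration selection is implemented mainly by the following three classes; most algorithms are built on the first two.
1. selector: picks a set of demonstrations from a dataset
2. retriever: retrieves a set of demonstrations from the dataset for a given sample
3. ranker: orders a set of demonstrations with respect to a given sample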
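To illustrate how a retriever and a ranker differ from a selector, here is a toy sketch in which both use word overlap with the query sample. The classes `ToyRetriever` and `ToyRanker`, their method names, and the overlap score are all assumptions made for illustration; the retriever and ranker classes in src may use entirely different interfaces and similarity measures.

```python
def overlap(a, b):
    """Jaccard overlap between the lowercased word sets of two strings."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    union = wa | wb
    return len(wa & wb) / len(union) if union else 0.0


class ToyRetriever:
    """Hypothetical retriever: the `num` pool items most similar to the sample."""

    def __init__(self, dataset, key="premise"):
        self.dataset = list(dataset)
        self.key = key

    def retrieve(self, sample, num):
        scored = sorted(
            self.dataset,
            key=lambda d: overlap(d[self.key], sample[self.key]),
            reverse=True,
        )
        return scored[:num]


class ToyRanker:
    """Hypothetical ranker: least similar first, most similar next to the query."""

    def __init__(self, key="premise"):
        self.key = key

    def rank(self, demos, sample):
        return sorted(demos, key=lambda d: overlap(d[self.key], sample[self.key]))


pool = [
    {"premise": "The cat sat on the mat", "label": "yes"},
    {"premise": "Stocks fell sharply today", "label": "no"},
    {"premise": "A cat sat on a chair", "label": "yes"},
]
query = {"premise": "The cat sat on the sofa"}

retrieved = ToyRetriever(pool).retrieve(query, num=2)
ranked = ToyRanker().rank(retrieved, query)
print([d["premise"] for d in ranked])
```

A selector needs only the dataset, whereas the retriever and the ranker both depend on the sample being answered.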

In [10]:
from src.selector import RandomSelector

random_selector = RandomSelector()
demos_l = [random_selector.select(trainset,num=8) for _ in range(len(testset))]
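Because the demonstrations are drawn at random, results can vary from run to run. The sketch below shows one way to make random selection reproducible with a seeded generator. `ToyRandomSelector` and its `seed` parameter are hypothetical; whether RandomSelector accepts a seed is not stated here.

```python
import random


class ToyRandomSelector:
    """Hypothetical selector: draws `num` distinct demonstrations at random."""

    def __init__(self, seed=None):
        # A private generator keeps the selection independent of global state.
        self.rng = random.Random(seed)

    def select(self, dataset, num):
        # random.sample draws without replacement, so no demo is repeated.
        return self.rng.sample(list(dataset), num)


pool = [{"id": i, "label": "yes" if i % 2 else "no"} for i in range(20)]

selector = ToyRandomSelector(seed=0)
demo_sets = [selector.select(pool, num=8) for _ in range(3)]

# Re-creating the selector with the same seed reproduces the same draws.
replay = ToyRandomSelector(seed=0)
replayed_sets = [replay.select(pool, num=8) for _ in range(3)]
print(demo_sets == replayed_sets)
```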

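Initialize an inferencer to run inference. Two kinds are currently supported: the direct inferencer (which uses speculative decoding to obtain the probabilities over the labels directly) and the generation inferencer (ordinary autoregressive generation).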

In [11]:
from src.inferencer import DirectInferencer
labels = ['yes','no','maybe']
direct_inferencer = DirectInferencer(model, tokenizer, prompter, labels)
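One common way to get probabilities over a fixed label set is to take the model's next-token scores after the prompt, keep only the entries for the labels, and normalise them. The sketch below shows that step on made-up scores. It is an assumption about the general idea, not a description of how DirectInferencer is implemented; `label_probabilities`, `predict`, and the toy logits are hypothetical, and the sketch treats each label as a single token.

```python
import math


def label_probabilities(next_token_logits, labels):
    """Softmax restricted to the label entries of a {token: logit} mapping."""
    scores = [next_token_logits[label] for label in labels]
    peak = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}


def predict(next_token_logits, labels):
    """Return the label with the highest normalised probability."""
    probs = label_probabilities(next_token_logits, labels)
    return max(probs, key=probs.get)


# Made-up scores for a few candidate next tokens.
toy_logits = {"yes": 1.2, "no": 2.5, "maybe": -0.3, "the": 3.0}
labels = ["yes", "no", "maybe"]

probs = label_probabilities(toy_logits, labels)
prediction = predict(toy_logits, labels)
print(prediction, probs)
```

Restricting the softmax to the labels means a high-scoring non-label token such as "the" cannot be returned as the answer.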

Run inference on a single sample:

In [13]:
y = direct_inferencer.infer(demos=demos_l[0], sample=testset[0])
y

'no'

Run inference over the test set (here, the first 100 samples):

In [14]:
# Slice the demos as well so that each of the 100 samples is paired with its own demo set
y_p = direct_inferencer.batch_infer(demos_l[:100], testset[:100])

100%|██████████| 100/100 [00:14<00:00,  6.94it/s]


The Evaluator class can be used to evaluate the results (only accuracy and f1-score are supported for now; you can extend it yourself).

In [15]:
from src.evaluator import Evaluator
# Gold labels for the same 100 samples that were passed to batch_infer
y_t = [testset[i]['label'] for i in range(len(testset[:100]))]
evaluator = Evaluator()

In [16]:
evaluator.acc_evaluate(y_p, y_t)

0.57
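For reference, accuracy and F1 can be computed from the predicted and gold labels in a few lines. The sketch below uses macro-averaged F1, which is an assumption: it is not stated whether Evaluator reports macro, micro, or per-class F1. The functions `accuracy` and `macro_f1` and the toy labels are hypothetical and independent of the Evaluator class.

```python
def accuracy(y_pred, y_true):
    """Fraction of predictions that match the gold labels."""
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)


def macro_f1(y_pred, y_true):
    """Unweighted mean of the per-label F1 scores over the gold label set."""
    pairs = list(zip(y_pred, y_true))
    f1_scores = []
    for label in sorted(set(y_true)):
        tp = sum(p == label and t == label for p, t in pairs)
        fp = sum(p == label and t != label for p, t in pairs)
        fn = sum(p != label and t == label for p, t in pairs)
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the label never occurs.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Made-up labels for illustration.
y_true_toy = ["yes", "no", "yes", "no"]
y_pred_toy = ["yes", "yes", "yes", "no"]

acc = accuracy(y_pred_toy, y_true_toy)
f1 = macro_f1(y_pred_toy, y_true_toy)
print(acc, f1)
```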