In [12]:
import torch
import json
import os
import numpy as np
from src.utils import set_seed

In [2]:
# os.environ["CUDA_VISIBLE_DEVICES"] = '2'

加载一个模型

In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
model_path = "../model/llama_1.3b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).half().to(torch.device('cuda'))

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


读取数据集，用get_data_reader方法从一个json文件中读取，json文件格式要求如下：
{ \
   'data_info':{\
    'data_name': 'rte_train',\
    'label_space': ['entailment', 'not_entailment'],\
    'columns': ['premise', 'hypothesis', 'label']\
   },\
   'data':[\
    sample1,\
    sample2,\
    ...\
   ]\
}\
\
读取数据时传入label_map参数，将标签从默认标签转换为想要的标签

In [13]:
from src.data_reader import get_data_reader, DataReader

trainset = get_data_reader(file_name='agnews/train.json',)
set_seed(0)
testset = get_data_reader(file_name='agnews/test.json', ).get_subset(256)
trainset.data_info

{'data_name': 'agnews_train',
 'label_space': ['Business', 'Technology', 'World', 'Sports'],
 'columns': ['sentence', 'label']}

其中每个sample是一个字典，由若干关键词（必须包括label）构成，例如：

In [14]:
trainset[0]

{'sentence': "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
 'label': 'Business'}

初始化一个prompter用于生成上下文，需要给定一个模板（必须），一个提示头（可选），模板中将需要替换的sample关键词用[]标出

In [15]:
from src.prompter import Prompter
template = "Article:[sentence]\nAnswer:[label]"

prompter = Prompter(template=template, head="Classify the news based on whether their type is Sports, Business, Technology or World.", sep='\n')

prompter的generate_context方法将demos和sample处理成prompt

In [16]:
print(prompter.generate_context(trainset[4:7], testset[0]))

Classify the news based on whether their type is Sports, Business, Technology or World.
Article:AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.
Answer:Business
Article:Reuters - Stocks ended slightly higher on Friday\but stayed near lows for the year as oil prices surged past  #36;46\a barrel, offsetting a positive outlook from computer maker\Dell Inc. (DELL.O)
Answer:Business
Article:AP - Assets of the nation's retail money market mutual funds fell by  #36;1.17 billion in the latest week to  #36;849.98 trillion, the Investment Company Institute said Thursday.
Answer:Business
Article:Jermain Defoe underlined his claims for an improved contract as he inspired Tottenham to a 2-0 win against 10-man Middlesbrough. New coach Martin Jol, who secured his first win in charge, may have been helped 
Answer:


挑选示例的方法主要由下面三个类实现，其中大部分算法由前两种实现。
1. selector: 从一个数据集中挑选出一组示例
2. retriever: 为一个sample，从数据集中检索出一组示例
3. ranker: 根据一个sample，为一组示例进行排序

In [19]:
from src.selector import RandomSelector

random_selector = RandomSelector()
demos_l = [random_selector.select(trainset,num=8) for _ in range(len(testset))]

初始化一个inferencer进行推理，目前支持direct inferencer（利用推测解码直接获取label上的概率）和generation inferencer（常规的自回归生成）

In [20]:
from src.inferencer import DirectInferencer
labels = ['yes','no','maybe']
direct_inferencer = DirectInferencer(model, tokenizer, prompter, labels)

进行推理

In [21]:
y = direct_inferencer.infer(demos=demos_l[0], sample=testset[0])
y

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


'no'

对整个测试集推理

In [22]:
y_p = direct_inferencer.batch_infer(demos_l, testset[:100])

100%|██████████| 100/100 [00:07<00:00, 13.83it/s]


可以调用Evaluator类对结果进行评估（暂时只支持accuracy和f1-score，可以自己拓展）

In [24]:
from src.evaluator import Evaluator
y_t = [testset[_]['label'] for _ in range(len(testset[:100]))]
evaluator = Evaluator()

In [25]:
evaluator.acc_evaluate(y_p, y_t)

0.41