In [1]:
from utils.saver import tokenizer_loader, model_loader
from utils.constants import DEVICE
from operators.CONLLReader import CONLLReader
from operators.NERDataset import NERDataset
from transformers import BertTokenizer, BartForConditionalGeneration

DATASET_NAME = "msra.min"
dataset = NERDataset(
    reader=CONLLReader(f"./data/{DATASET_NAME}.test"),
    tokenizer=tokenizer_loader(BertTokenizer, "fnlp/bart-base-chinese")
)

model = model_loader(BartForConditionalGeneration, "fine-tune/prompt-bart")
model.to(DEVICE)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(51271, 768, padding_idx=0)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(51271, 768, padding_idx=0)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05,

In [3]:
from PromptWeaver import BartPromptOperator

def calc_labels_entity(dataset):
    labels = list(set(
        map(lambda item: item[2:],
        filter(
            lambda item: item != "O",
            dataset.id_label
        ))
    ))

    return { item: BartPromptOperator.LABEL_ENTITY[item] for item in labels }

part_labels_entity = calc_labels_entity(dataset)
part_labels_entity

{'LOC': '地点', 'PER': '人名'}

In [31]:
GRAM = 4

def generate_template(sentence_str, start_point):
    result = []
    for span_size in range(1, GRAM + 1):
        span = sentence_str[start_point:start_point + span_size]
        result.append(BartPromptOperator.NEGATIVE_TEMPLATE.format(candidate_span=span))
        for entity in part_labels_entity:
            result.append(BartPromptOperator.POSITIVE_TEMPLATE.format(
                candidate_span=span,
                entity_type=part_labels_entity[entity]
            ))
    return result

def calc_max_possible(model, tokenizer, sentence_str, templates):
    inputs = tokenizer([sentence_str] * len(templates), return_tensors="pt")["input_ids"]
    output = tokenizer(templates, return_tensors="pt", padding=True)["input_ids"]
    print(output.shape)

def predict_label(model, tokenizer, sentence_str):
    right = len(sentence_str) - GRAM + 1
    # for start_point in range(0, right):
    for start_point in range(0, 1):
        templates = generate_template(sentence_str, start_point)
        calc_max_possible(model, tokenizer, sentence_str, templates)

predict_label(model, dataset.tokenizer, "".join(dataset.reader.sentences[0]))

tensor([[  101,  2483, 17516,  2484,  4909, 11009,  4896,  4938,  6544,  6432,
          8339,  5232,   102,     0,     0,     0],
        [  101,  2483, 17516,  2484, 11009,  4896,  4938,  7229, 13670,  8339,
          5232,   102,     0,     0,     0,     0],
        [  101,  2483, 17516,  2484, 11009,  4896,  4938,  5080,  6432,  8339,
          5232,   102,     0,     0,     0,     0],
        [  101,  2483, 17516,  5080,  2484,  4909, 11009,  4896,  4938,  6544,
          6432,  8339,  5232,   102,     0,     0],
        [  101,  2483, 17516,  5080,  2484, 11009,  4896,  4938,  7229, 13670,
          8339,  5232,   102,     0,     0,     0],
        [  101,  2483, 17516,  5080,  2484, 11009,  4896,  4938,  5080,  6432,
          8339,  5232,   102,     0,     0,     0],
        [  101,  2483, 17516,  5080, 15134,  2484,  4909, 11009,  4896,  4938,
          6544,  6432,  8339,  5232,   102,     0],
        [  101,  2483, 17516,  5080, 15134,  2484, 11009,  4896,  4938,  7229,
    