In [77]:
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn, tensor, Tensor
from datasets import load_dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer
from typing import Any, AnyStr
from collections import namedtuple

%matplotlib inline

In [78]:
# Checkpoint name shared by the masked-LM model and its tokenizer.
modelCheckPoint = 'bert-large-cased'

# Load the pretrained masked-LM model first, then the matching tokenizer.
preModel, preTokenizer = (
    AutoModelForMaskedLM.from_pretrained(modelCheckPoint),
    AutoTokenizer.from_pretrained(modelCheckPoint),
)

Some weights of the model checkpoint at bert-large-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [79]:
# One filled blank: the chosen ``word`` and the ``index`` of the winning candidate.
AnswerTuple = namedtuple(typename="AnswerTuple", field_names=("word", "index"))

In [80]:
def TokenText(tokenizer, text):
    return tokenizer(text, return_tensors = "pt")

def PadSingleWord(
    model,
    tokenizer,
    text,
    options
) -> AnyStr:
    """
    根据上下文填充单个词语
    Args:
        model:          语言模型
        tokenizer:      编码器
        text:           上下文
        options:        给予的选项
    Returns:
        返回模型从给予的选项中概率最大
    """
    
    text = text.replace('_', '[MASK]')
    # 获得编码
    textToken = TokenText(tokenizer, text)
    textLogits = model(**textToken).logits

    # 获得掩码位置的概率
    maskTokenIndex = torch.where(textToken['input_ids'] == tokenizer.mask_token_id)[1]
    maskLogits = textLogits[0, maskTokenIndex, :][0]

    # 获得选项编码 并且去除头尾的标签
    optionsToken = TokenText(tokenizer, options)['input_ids'][0]
    optionsToken = optionsToken[1 : len(optionsToken) - 1]

    optionsLogits = tensor([maskLogits[index] for index in optionsToken])
    
    # 获得 答案下标
    resIndex = optionsLogits.argmax()

    return tokenizer.decode(optionsToken[resIndex]), resIndex

In [81]:
# Single-blank demo: '_' marks the blank, the options are space-separated candidates.
resIndex = PadSingleWord(
    preModel,
    preTokenizer,
    "This is _ model",
    "an fantastic Machillka a",
)

resIndex

('a', tensor(5))

# 做一整篇完形填空

## 人的解题思路

首先按顺序阅读全文，遇到每个空（mask）时，先根据上下文做出预测（Prediction）；全部填完后，再通读全文进行检查（Check）。

## Model

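模型先一次性读取整篇文本（一次前向传播），同时得到所有 mask 位置上的词表得分；然后对每个 mask，只在该空给定的选项中选出得分最高的词。注意：各空是基于同一次前向传播的结果独立选择的，已填入的答案不会反过来影响其他空，也没有人类那样的"检查"步骤。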

In [97]:
def PadWholeText(
    model,
    tokenizer,
    text,
    options
) -> tuple:
    """
    Solve a whole cloze passage: fill every '_' blank in ``text`` from its options.

    All blanks are predicted from a single forward pass over the full text, and
    each blank is then resolved independently from those same logits.

    Args:
        model:      masked language model; called as ``model(**encoding)`` and its
                    ``.logits`` are read (assumed shape [batch, seq_len, vocab])
        tokenizer:  tokenizer matching ``model`` (must define ``mask_token``,
                    ``mask_token_id`` and ``decode``)
        text:       the passage, with each blank written as '_'
        options:    one entry per blank, in order of appearance; each entry is a
                    whitespace-separated string of candidates (e.g. "it my mother")
                    or a list of candidate words
    Returns:
        (answer, filledText): ``answer`` is a list of ``AnswerTuple(word, index)``,
        one per blank, where ``index`` (0-d tensor) is the position of ``word`` in
        that blank's candidate list; ``filledText`` is the decoded passage with the
        chosen words inserted (first and last tokens of the encoding stripped)
    Raises:
        AssertionError: if the number of blanks differs from ``len(options)``
        ValueError:     if a blank has no candidates or a candidate encodes to no tokens
    Note:
        Each blank is a single mask token, so a multi-sub-word candidate is scored
        by its first sub-word; all of its sub-words are inserted into the output.
    """
    candidateLists = [
        opt.split() if isinstance(opt, str) else list(opt) for opt in options
    ]
    queryNum = len(candidateLists)

    # Use the tokenizer's own mask token instead of hard-coding '[MASK]'.
    textToken = TokenText(tokenizer, text.replace('_', tokenizer.mask_token))
    inputIds = textToken['input_ids'][0]

    # Positions of all blanks in the encoded passage.
    maskTokenIndex = torch.where(inputIds == tokenizer.mask_token_id)[0]

    # Validate before running the (expensive) model.
    assert maskTokenIndex.numel() == queryNum, (
        f"found {maskTokenIndex.numel()} blanks in text but {queryNum} option groups"
    )

    # One forward pass for the whole passage; pure inference, no autograd graph.
    with torch.no_grad():
        textLogits = model(**textToken).logits
    maskLogits = textLogits[0, maskTokenIndex, :]

    answer = []
    chosenIds = []
    for queryIdx, candidates in enumerate(candidateLists):
        if not candidates:
            raise ValueError(f"blank {queryIdx} has no candidate options")

        # Encode each candidate separately so scores align one-to-one with candidates.
        candidateIds = [
            tokenizer(word, add_special_tokens=False)['input_ids'] for word in candidates
        ]
        if any(len(ids) == 0 for ids in candidateIds):
            raise ValueError(f"an option for blank {queryIdx} encodes to no tokens")

        # Score each candidate by its first sub-word; argmax keeps the first on ties.
        firstIds = torch.tensor([ids[0] for ids in candidateIds], dtype=torch.long)
        resIndex = maskLogits[queryIdx][firstIds].argmax()
        best = int(resIndex)

        answer.append(AnswerTuple(candidates[best], resIndex))
        chosenIds.append(candidateIds[best])

    # Rebuild the sequence (without touching the encoding tensor), replacing each
    # mask with all sub-word ids of its chosen answer.
    answerAt = dict(zip(maskTokenIndex.tolist(), chosenIds))
    filledIds = []
    for position, tokenId in enumerate(inputIds.tolist()):
        filledIds.extend(answerAt.get(position, [tokenId]))

    return answer, tokenizer.decode(filledIds[1 : len(filledIds) - 1])

# Whole-passage demo: three blanks, each with its own space-separated option list.
text = 'this is _, I _ your mom, Ahh _'
options = [
    "it my mother",
    "love none test",
    "no name vision",
]

answer, text = PadWholeText(preModel, preTokenizer, text, options)

# First line: the chosen words; second tuple: their indices. Then the filled passage.
print(tuple(a.word for a in answer), tuple(a.index for a in answer))
print(text)

('it', 'love', 'no') (tensor(0), tensor(0), tensor(0))
this is it, I love your mom, Ahh no
