In [1]:
import torch
from transformers import BertTokenizer, BertForTokenClassification, BertForMaskedLM
from model import PinyinBertForMaskedLM

In [2]:
# Run on the first CUDA device when one is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The tokenizer and both models are created from the same local
# bert-base-chinese directory; both models are then moved onto `device`.
tokenizer = BertTokenizer.from_pretrained("../pretrained_models/bert-base-chinese")
detect_model = BertForTokenClassification.from_pretrained(
    "../pretrained_models/bert-base-chinese",
    use_safetensors=True,
).to(device)
correct_model = PinyinBertForMaskedLM.from_pretrained(
    "../pretrained_models/bert-base-chinese",
    use_safetensors=True,
).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../pretrained_models/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of PinyinBertForMaskedLM were not initialized from the model checkpoint at ../pretrained_models/bert-base-chinese and are newly initialized: ['bert.embeddings.pinyin_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Load the fine-tuned weights for the detector and the corrector.
#
# map_location=device remaps the stored tensors onto the device chosen above,
# so a checkpoint that contains CUDA tensors can still be loaded when the
# notebook falls back to the CPU.
#
# SECURITY NOTE: torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
detect_state_dict = torch.load("../no_pretrain/detect_model_epoch_1.ckpt", map_location=device)
correct_state_dict = torch.load("../no_pretrain/correct_model_epoch_1.ckpt", map_location=device)
# Alternative checkpoints that were tried in this cell:
# detect_state_dict = torch.load("../check_point/2kk_detect_1.ckpt")
# correct_state_dict = torch.load("../check_point/2kk_correct_1.ckpt")
detect_model.load_state_dict(detect_state_dict)
# Last expression of the cell: its return value is what the cell displays.
correct_model.load_state_dict(correct_state_dict)

<All keys matched successfully>

In [14]:
import sys
from pypinyin import lazy_pinyin
sys.path.append("..")

from build.utils import text2token, token2ids, pinyin_similarity, get_chinese_part

# Switch both networks to evaluation mode before running any predictions.
detect_model.eval()
correct_model.eval()

# Map every line of the pinyin vocabulary file (trailing newline removed)
# to its 0-based line number.
pinyin_vocab = {}
with open("../build/pinyin_vocab.txt", "r", encoding="utf-8") as vocab_file:
    pinyin_vocab.update(
        (entry.strip("\n"), line_no) for line_no, entry in enumerate(vocab_file)
    )

In [15]:
from detection import detect

def predict(text):
    """Return the correction model's argmax prediction for ``text``.

    The sentence is tokenized, ``detect`` is called with 0.67 as its fourth
    positional argument (presumably the detection threshold -- confirm in
    ``detection.detect``), and the corrector is run with the pinyin ids and
    the detector output as extra inputs.

    Args:
        text: a single sentence as a string.

    Returns:
        str: the predicted tokens joined without separators, with the first
        and last positions dropped (presumably the [CLS]/[SEP] slots).
    """
    # Pure inference: disable autograd so no graph is built for the forward
    # passes (saves memory and time; the predicted ids are unchanged).
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt').to(device)
        error_label = detect(detect_model, tokenizer, [text], 0.67, show_error=False)[0].unsqueeze(-1).to(device)
        _, pinyin_token = text2token(text, tokenizer)
        # Pad the pinyin ids with a 0 on each side so they line up with the
        # extra first/last positions of the tokenized input.
        pinyin_ids = torch.tensor([0] + token2ids(pinyin_token, pinyin_vocab) + [0], device=device)
        output = correct_model(**inputs, pinyin_ids=pinyin_ids, error_prob=error_label)

    predict_ids_for_output = output[0].argmax(-1)[0]
    return "".join(tokenizer.convert_ids_to_tokens(predict_ids_for_output)[1:-1])

In [6]:
import json
from tqdm import tqdm
# Load the SIGHAN2015 test set. Each record is read below through its
# 'original_text' and 'correct_text' keys.
with open("../SIGHAN2015/test.json", 'r', encoding='utf-8') as f:
    test_data = json.load(f)

# Confusion-matrix counters. Prefixes: cd = character-level detection,
# cc = character-level correction, sd = sentence-level detection,
# sc = sentence-level correction.
cd_tp, cd_fp, cd_tn, cd_fn = 0, 0, 0, 0
cc_tp, cc_fp, cc_tn, cc_fn = 0, 0, 0, 0
sd_tp, sd_fp, sd_tn, sd_fn = 0, 0, 0, 0
sc_tp, sc_fp, sc_tn, sc_fn = 0, 0, 0, 0

# NOTE(review): correct_num is never read or updated anywhere in this cell.
correct_num = 0
for d in test_data:
    # Both reference strings are passed through the tokenizer and re-joined,
    # presumably so they have the same form as the string predict() returns.
    correct_text = ''.join(tokenizer.tokenize(d['correct_text']))
    original_text = ''.join(tokenizer.tokenize(d['original_text']))
    predict_text = predict(d['original_text'])
    correct_array = [char for char in correct_text]
    original_array = [char for char in original_text]
    predict_array = [char for char in predict_text]
    # Character-level counting is skipped when any of the three strings
    # contains the substring "UNK".
    if "UNK" not in correct_text and "UNK" not in original_text and "UNK" not in predict_text:
        # Character-level Metrics
        # NOTE(review): the three arrays are indexed by positions of
        # correct_array; this assumes they all have the same length,
        # otherwise an IndexError is raised.
        for i in range(len(correct_array)):
            correct_char = correct_array[i]
            original_char = original_array[i]
            predict_char = predict_array[i]
            # The character is wrong in the input and should be changed.
            if original_char != correct_char:
                if original_char != predict_char:
                    # The model changed it: detection hit; correction hit
                    # only if the new character is the reference one.
                    cd_tp += 1
                    if predict_char == correct_char:
                        cc_tp += 1
                    else:
                        cc_fn += 1
                else: 
                    # The model left the wrong character untouched.
                    cd_fn += 1
                    cc_fn += 1
            # The character is already right and should be left alone.
            if original_char == correct_char:
                if original_char != predict_char:
                    cd_fp += 1
                    cc_fp += 1
                else: 
                    cd_tn += 1
                    cc_tn += 1
    # Sentence-level Metrics
    # NOTE(review): this part runs for every sample, including the ones the
    # "UNK" guard above skipped.
    # flag_tp == 1 means the sentence was already counted as sd_fn;
    # it only counts as sd_tp if the flag is still 0 afterwards.
    flag_tp = 0
    # The sentence needs correction (decided on the raw, untokenized strings).
    if d['correct_text'] != d['original_text']:
        # The model returned the (tokenized) input unchanged.
        if predict_text == original_text:
            sd_fn += 1
            flag_tp = 1
        else:
            # Scan for the first position that was handled wrongly; one such
            # position is enough to count the sentence as a miss.
            for i in range(len(correct_array)):
                correct_char = correct_array[i]
                original_char = original_array[i]
                predict_char = predict_array[i]
                # should correct
                if original_char != correct_char:
                    # no correction
                    if original_char == predict_char:
                        sd_fn += 1
                        flag_tp = 1
                        break
                # should not correct
                if original_char == correct_char:
                    if original_char != predict_char:
                        sd_fn += 1
                        flag_tp = 1
                        break        
        if flag_tp == 0:
            sd_tp += 1

    # Sentence-level correction counters, plus the sentence-level detection
    # fp/tn counters for sentences that needed no change. Whether a sentence
    # needs a change is decided on the raw strings; the prediction is
    # compared against the tokenized reference.
    if d['correct_text'] != d['original_text'] and predict_text == correct_text:
        sc_tp += 1
    elif d['correct_text'] == d['original_text'] and predict_text != correct_text:
        sd_fp += 1
        sc_fp += 1
    elif d['correct_text'] == d['original_text'] and predict_text == correct_text:
        sd_tn += 1
        sc_tn += 1
    elif d['correct_text'] != d['original_text'] and predict_text != correct_text:
        sc_fn += 1

In [7]:
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _report_metrics(title, tp, fp, tn, fn, tp_adjust=0):
    """Print one metrics section and return ``(accuracy, precision, recall, f1)``.

    Args:
        title: heading printed before the numbers.
        tp, fp, tn, fn: confusion-matrix counts; printed exactly as given.
        tp_adjust: amount subtracted from ``tp`` in the numerators (and in the
            precision denominator) before the ratios are computed.

    Ratios whose denominator is zero are reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    print(title)
    adjusted_tp = tp - tp_adjust
    accuracy = _safe_div(adjusted_tp + tn, tp + fn + tn + fp)
    precision = _safe_div(adjusted_tp, adjusted_tp + fp)
    recall = _safe_div(adjusted_tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    print(f"tp:{tp}, fp:{fp}, tn:{tn}, fn:{fn}\naccuracy:{accuracy} \nprecision: {precision} \nrecall: {recall} \nf1: {f1}")
    return accuracy, precision, recall, f1


# NOTE(review): the sentence-level detection figures subtract 3 from the
# true-positive count before computing the ratios. The reason is not recorded
# in this notebook; confirm it is intended or set this to 0.
SENTENCE_DETECTION_TP_ADJUST = 3

accuracy, precision, recall, f1 = _report_metrics(
    "Character-level Detection Metrics", cd_tp, cd_fp, cd_tn, cd_fn)

# This section reports the cc_* (correction) counters, so it is labelled
# "Correction" rather than repeating the detection heading.
accuracy, precision, recall, f1 = _report_metrics(
    "Character-level Correction Metrics", cc_tp, cc_fp, cc_tn, cc_fn)

accuracy, precision, recall, f1 = _report_metrics(
    "Sentence-level Detection Metrics", sd_tp, sd_fp, sd_tn, sd_fn,
    tp_adjust=SENTENCE_DETECTION_TP_ADJUST)

accuracy, precision, recall, f1 = _report_metrics(
    "Sentence-level Correction Metrics", sc_tp, sc_fp, sc_tn, sc_fn)

Character-level Detection Metrics
tp:547, fp:136, tn:32574, fn:138
accuracy:0.9917951789189998 
precision: 0.8008784773060029 
recall: 0.7985401459854015 
f1: 0.7997076023391813
Character-level Detection Metrics
tp:382, fp:136, tn:32574, fn:303
accuracy:0.9868543195089086 
precision: 0.7374517374517374 
recall: 0.5576642335766423 
f1: 0.6350789692435578
Sentence-level Detection Metrics
tp:413, fp:98, tn:459, fn:130
accuracy:0.79 
precision: 0.8070866141732284 
recall: 0.7550644567219152 
f1: 0.780209324452902
Sentence-level Correction Metrics
tp:309, fp:98, tn:459, fn:234
accuracy:0.6981818181818182 
precision: 0.7592137592137592 
recall: 0.569060773480663 
f1: 0.6505263157894736


In [8]:
# NOTE(review): this notebook contains another definition of `predict`;
# whichever cell was executed last is the one later cells call.
def predict(text):
    """Correct ``text`` and re-rank the top-5 candidates at flagged positions.

    Starts from the correction model's argmax prediction. Then, for every
    position where the detector output is truthy, each of the five
    highest-scoring candidate tokens is tried in turn: the candidate is
    substituted into a copy of the prediction, the detector is run on that
    copy, and the candidate's pinyin is compared with the pinyin of the
    original input token. A candidate replaces the current choice when the
    combined score test below succeeds. Per-candidate diagnostics are printed.

    Args:
        text: a single sentence as a string.

    Returns:
        str: the final tokens joined without separators, with the first and
        last positions dropped (presumably the [CLS]/[SEP] slots).
    """
    text = [text]
    inputs = tokenizer(text, return_tensors='pt').to(device)
    # 0.2 is passed as detect's fourth positional argument (presumably the
    # detection threshold -- confirm in detection.detect).
    # NOTE(review): unlike the other callers in this notebook, the result of
    # detect() is not indexed with [0] here -- confirm the intended shape.
    error_label = detect(detect_model, tokenizer, text, 0.2, show_error=False).unsqueeze(-1).to(device)
    text_token, pinyin_token = text2token(text[0], tokenizer)
    # Pad the pinyin ids with a 0 on each side so they line up with the
    # extra first/last positions of the tokenized input.
    pinyin_ids = torch.tensor([0] + token2ids(pinyin_token, pinyin_vocab) + [0], device=device)
    output = correct_model(**inputs, pinyin_ids=pinyin_ids, error_prob=error_label)
    
    predict_ids_for_output = output[0].argmax(-1)[0]
    for i, l in enumerate(error_label.squeeze()):
        if l:
            # Work on a copy so trial substitutions do not touch the result
            # until a candidate is accepted.
            predict_ids = predict_ids_for_output.clone()
            # Running state of the currently accepted candidate.
            max_sim = 0.01
            min_sco = 1e6
            max_s = 0
            for ind in output[0][:, i, :].topk(k=5).indices[0]:
                predict_ids[i] = ind
                # Detector's class-1 probability for every position of the
                # trial sentence; only position i is used below.
                score = detect_model(predict_ids.unsqueeze(0)).logits
                score = torch.nn.functional.softmax(score, dim=-1)[:, :, 1]
                # Pinyin similarity between the candidate and the original
                # input token at this position.
                sim = pinyin_similarity(lazy_pinyin(tokenizer.convert_ids_to_tokens([ind]))[0], \
                                      lazy_pinyin(tokenizer.convert_ids_to_tokens([inputs['input_ids'][0][i]]))[0])
                print(f"{score[0][i].item():.5f} {sim:.2f}, {tokenizer.convert_ids_to_tokens([ind])}")
                if (min_sco - score[0][i] + 0.1) * sim / max_sim > max_s:
                    max_sim =  sim
                    min_sco = score[0][i]
                    # NOTE(review): min_sco and max_sim were just set to this
                    # candidate's own values, so for a non-zero sim this
                    # evaluates to 0.1 -- confirm this is the intended update.
                    max_s = (min_sco - score[0][i] + 0.1) * sim / max_sim
                    predict_ids_for_output[i] = ind
    text = ["".join(tokenizer.convert_ids_to_tokens(predict_ids_for_output)[1:-1])]
    return text[0]

In [9]:
def predict_top5(text):
    """Return the argmax correction of ``text`` and print top-5 candidates.

    ``detect`` is called with 0.67 as its fourth positional argument
    (presumably the detection threshold -- confirm in ``detection.detect``).
    For every position where its output is truthy, the five highest-scoring
    tokens from the correction model are printed as a list.

    Args:
        text: a single sentence as a string.

    Returns:
        str: the argmax-predicted tokens joined without separators, with the
        first and last positions dropped (presumably the [CLS]/[SEP] slots).
    """
    batch = [text]
    # Pure inference: disable autograd so no graph is built for the forward
    # passes (saves memory and time; the predicted ids are unchanged).
    with torch.no_grad():
        inputs = tokenizer(batch, return_tensors='pt').to(device)
        error_label = detect(detect_model, tokenizer, batch, 0.67, show_error=False)[0].unsqueeze(-1).to(device)
        pinyin_token = text2token(text, tokenizer)[1]
        # Pad the pinyin ids with a 0 on each side so they line up with the
        # extra first/last positions of the tokenized input.
        pinyin_ids = torch.tensor([0] + token2ids(pinyin_token, pinyin_vocab) + [0], device=device)
        output = correct_model(**inputs, pinyin_ids=pinyin_ids, error_prob=error_label)

    logits = output[0]
    predict_ids_for_output = logits.argmax(-1)[0]
    for position, flagged in enumerate(error_label.squeeze()):
        if flagged:
            top5_ids = logits[:, position, :].topk(k=5).indices[0]
            print(tokenizer.convert_ids_to_tokens(top5_ids))

    return "".join(tokenizer.convert_ids_to_tokens(predict_ids_for_output)[1:-1])

In [10]:
def token_error_prob(text):
    """Print the detector's class-1 probability for each character of ``text``.

    The detector logits are soft-maxed over the last dimension and the value
    at index 1 is reported (presumably the "error" class -- confirm against
    the detector's label mapping). The first and last positions are dropped
    (presumably the [CLS]/[SEP] slots) and the remaining values are paired
    with the characters of ``text`` in order; this assumes one token per
    character.

    Args:
        text: a single sentence as a string.

    Returns:
        None. One line ``"<char> : <prob>"`` is printed per character.
    """
    inputs = tokenizer(text, return_tensors='pt').to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = detect_model(**inputs).logits
    error_probs = torch.nn.functional.softmax(logits, dim=-1)[:, :, 1]
    # The previous neighbour-smoothed value was computed but never used, and
    # computing it raised IndexError for one-character input; only the raw
    # per-token probability is reported.
    for char, prob in zip(text, error_probs[0][1:-1]):
        print(f"{char} : {prob:.3f}")

In [11]:
# Demo on one sample sentence: print the per-character detector
# probabilities, then print the top-5 candidates at flagged positions.
# The cell displays the string returned by predict_top5.
text = "你吃早菜了吗？"
token_error_prob(text)
predict_top5(text)

你 : 0.003
吃 : 0.006
早 : 0.087
菜 : 0.827
了 : 0.002
吗 : 0.002
？ : 0.000
['餐', '饭', '点', '起', '盘']


'你吃早餐了吗？'

In [16]:
# Demo: run predict() on a longer sample sentence; the cell displays the
# returned string.
predict("我跟小明在客厅便喝啤酒便看电视，小明的流量很好，喝了很多可是没有醉。")

'我跟小明在客厅边喝啤酒跟看电视，小明的流量很好，喝了很多可是没有醉。'