In [2]:
!nvidia-smi

Thu Jul 29 00:42:33 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 465.27       CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   57C    P0    40W / 250W |  28496MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
!pip install -U jieba pypinyin transformers --proxy http://10.8.84.123:7890



In [1]:
import json
import re
import random
import time
import jieba
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from collections import defaultdict, deque
from pypinyin import lazy_pinyin
import numpy as np
from typing import List
from tqdm import tqdm

# Mini-batch size shared by the train and test DataLoaders.
batch_size = 32
# Number of DataLoader worker processes.
num_workers = 12
# Path handed to BertTokenizer.from_pretrained / BertModel.from_pretrained.
transformers_path = "../hfl/chinese-roberta-wwm-ext"
# JSON file read by ConfusingSentenceGenerator.load_dictionary
# (and written by ConfusingSentenceGenerator.build_dictionary).
chinese_pinyin_dict_path = "./resource/chinese_words_frequency.json"
# JSON files loaded by TrainDataset and TestDataset respectively.
train_dataset_file_path = "./data/train.json"
test_dataset_file_path = "./data/test.json"
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set jieba's log level to INFO.
jieba.setLogLevel(logging.INFO)


class ConfusingSentenceGenerator:
    """Create noisy ("confusing") copies of Chinese sentences.

    A sentence is corrupted in four successive passes, each applied with its
    own probability:

    1. replace a jieba-segmented phrase by another phrase with the same pinyin;
    2. replace a single character by another character with the same pinyin;
    3. replace a single character by a character whose pinyin is a commonly
       confused neighbour (z/zh, c/ch, s/sh, l/n, f/h, r/l, -n/-ng, ...);
    4. replace a single character by a character drawn from a random pinyin.

    Replacement candidates come from ``chinese_pinyin_dict``, which maps a
    ``"_"``-joined pinyin string to ``[words, frequencies]`` (two parallel
    lists, most frequent first).  Candidates are sampled with probability
    proportional to ``frequency ** prob_power_factor``.

    Args:
        prob_power_factor: Exponent applied to the raw frequencies before
            they are normalised into sampling probabilities.
        prob_confuse_phrase: Per-phrase probability of pass 1.
        prob_confuse_word: Per-character probability of pass 2.
        prob_similar_sound_word: Per-character probability of pass 3.
        prob_random_word: Per-character probability of pass 4.
    """

    def __init__(self, prob_power_factor=0.7, prob_confuse_phrase=0.1, prob_confuse_word=0.03, prob_similar_sound_word=0.02, prob_random_word=0.02):
        self.chinese_pinyin_dict = dict()
        self.chinese_single_word_pinyin_list = []
        # Matches strings made exclusively of characters in U+4E00..U+9FFF.
        self.chinese_char_regex = re.compile("^[\u4e00-\u9fff]+$")
        self.prob_power_factor = prob_power_factor
        self.prob_confuse_phrase = prob_confuse_phrase
        self.prob_confuse_word = prob_confuse_word
        self.prob_similar_sound_word = prob_similar_sound_word
        self.prob_random_word = prob_random_word
        # Each entry holds two parallel lists; syllables at the same index are
        # treated as mutually confusable.
        similar_sounds = {
            'z_zh': [['za', 'ze', 'zi', 'zu', 'zai', 'zui', 'zao', 'zou', 'zan', 'zen', 'zun', 'zuo', 'zuan', 'zang', 'zeng', 'zong'],
                     ['zha', 'zhe', 'zhi', 'zhu', 'zhai', 'zhui', 'zhao', 'zhou', 'zhan', 'zhen', 'zhun', 'zhuo', 'zhuan', 'zhang', 'zheng', 'zhong']],
            'c_ch': [['ca', 'ce', 'ci', 'cu', 'cai', 'cui', 'cao', 'cou', 'can', 'cen', 'cun', 'cuo', 'cuan', 'cang', 'ceng', 'cong'],
                     ['cha', 'che', 'chi', 'chu', 'chai', 'chui', 'chao', 'chou', 'chan', 'chen', 'chun', 'chuo', 'chuan', 'chang', 'cheng', 'chong']],
            's_sh': [['sa', 'se', 'si', 'su', 'sai', 'sui', 'sao', 'sou', 'san', 'sen', 'sun', 'suo', 'suan', 'sang', 'seng'],
                     ['sha', 'she', 'shi', 'shu', 'shai', 'shui', 'shao', 'shou', 'shan', 'shen', 'shun', 'shuo', 'shuan', 'shang', 'sheng']],
            'l_n': [
                ['la', 'le', 'li', 'lu', 'lv', 'lai', 'lei', 'lao', 'lou', 'liu', 'lie', 'liao', 'lian', 'lve', 'lan', 'lin', 'lun', 'luo', 'luan', 'lang', 'leng', 'ling', 'long',
                 'liang'],
                ['na', 'ne', 'ni', 'nu', 'nv', 'nai', 'nei', 'nao', 'nou', 'niu', 'nie', 'niao', 'nian', 'nve', 'nan', 'nin', 'nun', 'nuo', 'nuan', 'nang', 'neng', 'ning', 'nong',
                 'niang']],
            'f_h': [['fa', 'fu', 'fei', 'fou', 'fan', 'fen', 'fang', 'feng'],
                    ['ha', 'hu', 'hei', 'hou', 'han', 'hen', 'hang', 'heng']],
            'r_l': [['re', 'ri', 'ru', 'rao', 'rou', 'ran', 'ren', 'run', 'ruo', 'rang', 'reng', 'rong'],
                    ['le', 'li', 'lu', 'lao', 'lou', 'lan', 'len', 'lun', 'luo', 'lang', 'leng', 'long']],
            'an_ang': [['ban', 'pan', 'man', 'fan', 'dan', 'tan', 'nan', 'lan', 'gan', 'kan', 'han', 'zhan', 'chan', 'shan', 'ran', 'zan', 'can', 'san', 'yan', 'wan'],
                       ['bang', 'pang', 'mang', 'fang', 'dang', 'tang', 'nang', 'lang', 'gang', 'kang', 'hang', 'zhang', 'chang', 'shang', 'rang', 'zang', 'cang', 'sang', 'yang',
                        'wang']],
            'en_eng': [['ben', 'pen', 'men', 'fen', 'den', 'nen', 'gen', 'ken', 'hen', 'zhen', 'chen', 'shen', 'ren', 'zen', 'cen', 'sen', 'wen'],
                       ['beng', 'peng', 'meng', 'feng', 'deng', 'neng', 'geng', 'keng', 'heng', 'zheng', 'cheng', 'sheng', 'reng', 'zeng', 'ceng', 'seng', 'weng']],
            'in_ing': [['bin', 'pin', 'min', 'nin', 'lin', 'jin', 'qin', 'xin', 'yin'],
                       ['bing', 'ping', 'ming', 'ning', 'ling', 'jing', 'qing', 'xing', 'ying']],
            'ian_iang': [['nian', 'lian', 'jian', 'qian', 'xian'],
                         ['niang', 'liang', 'jiang', 'qiang', 'xiang']],
            'uan_uang': [['guan', 'kuan', 'huan', 'zhuan', 'chuan', 'shuan'],
                         ['guang', 'kuang', 'huang', 'zhuang', 'chuang', 'shuang']]
        }

        # Flatten the table into a symmetric lookup: syllable -> confusable syllables.
        similar_sound_dict = defaultdict(set)
        for groups in similar_sounds.values():
            for first, second in zip(*groups):
                similar_sound_dict[first].add(second)
                similar_sound_dict[second].add(first)
        self.similar_sound_dict = {k: list(v) for k, v in similar_sound_dict.items()}

    def build_dictionary(self, corpus: List[str], chinese_pinyin_dict_path: str):
        """Build the pinyin -> candidates dictionary from ``corpus`` and save it as JSON.

        Every sentence is segmented with jieba; segments consisting solely of
        Chinese characters are counted.  The written JSON maps each
        ``"_"``-joined pinyin string to ``[words, frequencies]``, both ordered
        by descending frequency.  The in-memory dictionary of this instance is
        not modified; call :meth:`load_dictionary` to use the result.

        Args:
            corpus: Sentences to count words in.
            chinese_pinyin_dict_path: Destination path of the JSON file.
        """
        p_bar = tqdm(corpus, position=0, leave=True)
        p_bar.set_description("build_chinese_pinyin_dict")
        chinese_words_frequency = defaultdict(int)
        for corpus_sentence in p_bar:
            for word in jieba.lcut(corpus_sentence):
                if self.chinese_char_regex.match(word):
                    chinese_words_frequency[word] += 1

        chinese_pinyin_dict = defaultdict(list)
        for word in chinese_words_frequency:
            pinyin_ = "_".join(lazy_pinyin(word))
            chinese_pinyin_dict[pinyin_].append([word, chinese_words_frequency[word]])

        # Most frequent candidates first, then split into two parallel lists.
        chinese_pinyin_dict = {k: sorted(v, key=lambda x: x[1], reverse=True) for k, v in chinese_pinyin_dict.items()}
        chinese_pinyin_dict = {k: [[x[0] for x in v], [x[1] for x in v]] for k, v in chinese_pinyin_dict.items()}

        with open(chinese_pinyin_dict_path, "w", encoding="utf-8") as f:
            json.dump(chinese_pinyin_dict, f, ensure_ascii=False)

    def load_dictionary(self, chinese_pinyin_dict_path: str):
        """Load a dictionary written by :meth:`build_dictionary`.

        Also caches the keys that contain no ``"_"`` (single-syllable entries),
        which the random-replacement pass samples from.

        Args:
            chinese_pinyin_dict_path: Path of the JSON file to read.
        """
        with open(chinese_pinyin_dict_path, "r", encoding="utf-8") as f:
            self.chinese_pinyin_dict = json.load(f)
            self.chinese_single_word_pinyin_list = [x for x in self.chinese_pinyin_dict if "_" not in x]

    def _sample_replacement(self, pinyin_: str, word: str) -> str:
        """Draw one candidate stored under ``pinyin_``, strongly disfavouring ``word`` itself.

        The stored frequency lists are shared by every call, so the weights are
        computed on a private float copy and the dictionary is never modified.
        """
        words_, freq_ = self.chinese_pinyin_dict[pinyin_]
        weights = np.array(freq_, dtype=float)
        if word in words_:
            # Near-zero weight: the original token is only re-drawn when it is
            # (almost) the sole candidate.
            weights[words_.index(word)] = 1e-4
        weights = np.power(weights, self.prob_power_factor)
        weights = weights / np.sum(weights)
        return str(np.random.choice(words_, p=weights))

    def generate_sample(self, original_sentence: str):
        """Corrupt ``original_sentence`` and label the characters that changed.

        Args:
            original_sentence: Clean input sentence.

        Returns:
            A dict with keys ``"original_sentence"``, ``"confusing_sentence"``
            and ``"label"``; ``label[i]`` is 1 when character ``i`` differs
            between the two sentences and 0 otherwise.

        Raises:
            AssertionError: If the corrupted sentence ends up with a different
                length than the original.
        """
        # Pass 1: with some probability swap a phrase for one with the same pinyin.
        pieces = []
        for word in jieba.lcut(original_sentence):
            if random.random() < self.prob_confuse_phrase and self.chinese_char_regex.match(word):
                pinyin_ = "_".join(lazy_pinyin(word))
                if pinyin_ in self.chinese_pinyin_dict:
                    pieces.append(self._sample_replacement(pinyin_, word))
                    continue
            pieces.append(word)
        confusing_sentence = list("".join(pieces))

        # Pass 2: with some probability swap a character for one with the same pinyin.
        for i, word in enumerate(confusing_sentence):
            if random.random() < self.prob_confuse_word and self.chinese_char_regex.match(word):
                pinyin_ = "_".join(lazy_pinyin(word))
                if pinyin_ in self.chinese_pinyin_dict:
                    confusing_sentence[i] = self._sample_replacement(pinyin_, word)

        # Pass 3: with some probability swap a character for one with a similar-sounding pinyin.
        for i, word in enumerate(confusing_sentence):
            if random.random() < self.prob_similar_sound_word and self.chinese_char_regex.match(word):
                pinyin_ = "_".join(lazy_pinyin(word))
                if pinyin_ in self.similar_sound_dict:
                    pinyin_ = random.choice(self.similar_sound_dict[pinyin_])
                    if pinyin_ in self.chinese_pinyin_dict:
                        confusing_sentence[i] = self._sample_replacement(pinyin_, word)

        # Pass 4: with some probability swap a character for a random one.
        # Governed by prob_random_word; skipped when no single-syllable entries are loaded.
        if self.chinese_single_word_pinyin_list:
            for i, word in enumerate(confusing_sentence):
                if random.random() < self.prob_random_word and self.chinese_char_regex.match(word):
                    pinyin_ = random.choice(self.chinese_single_word_pinyin_list)
                    confusing_sentence[i] = self._sample_replacement(pinyin_, word)

        confusing_sentence = "".join(confusing_sentence)
        assert len(original_sentence) == len(confusing_sentence)
        label = [int(c1 != c2) for c1, c2 in zip(original_sentence, confusing_sentence)]
        return {
            "original_sentence": original_sentence,
            "confusing_sentence": confusing_sentence,
            "label": label
        }


class SampleEncoder:
    """Convert sentences and per-character labels into fixed-length arrays.

    Sentences are split character by character (no word-piece tokenisation),
    wrapped in the tokenizer's CLS/SEP tokens and padded or truncated to
    ``max_length`` positions.  Labels are shifted and padded the same way so
    that they stay aligned with the encoded characters.
    """

    def __init__(self, transformers_path, max_length=128):
        self.tokenizer = BertTokenizer.from_pretrained(transformers_path)
        self.max_length = max_length

    def encode_sentence(self, sentence):
        """Encode ``sentence`` into long tensors of length ``max_length``.

        Returns a dict with ``input_ids``, ``attention_mask`` (1 on every
        non-padding position) and ``token_type_ids`` (all zeros).
        """
        tok = self.tokenizer
        limit = self.max_length

        # CLS + characters, cut so that SEP still fits, then padded to `limit`.
        pieces = ([tok.cls_token] + list(sentence))[:limit - 1]
        pieces.append(tok.sep_token)
        pieces.extend([tok.pad_token] * (limit - len(pieces)))

        encoding = {
            "input_ids": tok.convert_tokens_to_ids(pieces),
            "attention_mask": [int(piece != tok.pad_token) for piece in pieces],
            "token_type_ids": [0] * len(pieces),
        }
        return {name: torch.tensor(values, dtype=torch.long) for name, values in encoding.items()}

    def encode_label(self, label):
        """Align a per-character label list with the output of ``encode_sentence``.

        A 0 is inserted for CLS, the list is truncated like the sentence, and
        zeros are appended for SEP and padding.  Returns a NumPy array.
        """
        aligned = ([0] + label)[:self.max_length - 1]
        aligned.append(0)
        aligned.extend([0] * (self.max_length - len(aligned)))
        return np.asarray(aligned)


class TrainDataset(Dataset):
    """Training set that corrupts clean sentences on the fly.

    The JSON file is loaded whole into memory.  Each lookup passes the stored
    item through the module-level ``confusing_sentence_generator`` and encodes
    the result with the module-level ``sample_encoder``, so the same index can
    yield a different noisy sentence on every access.
    """

    def __init__(self, dataset_file_path):
        with open(dataset_file_path, encoding="utf-8") as handle:
            self.dataset = json.load(handle)
        self.dataset_length = len(self.dataset)
        print(f"load train dataset size: {self.dataset_length}")

    def __getitem__(self, idx):
        """Return ``(encoded original, encoded noisy sentence, encoded labels)``."""
        noisy = confusing_sentence_generator.generate_sample(self.dataset[idx])
        original_encoding = sample_encoder.encode_sentence(noisy["original_sentence"])
        confusing_encoding = sample_encoder.encode_sentence(noisy["confusing_sentence"])
        label_encoding = sample_encoder.encode_label(noisy["label"])
        return original_encoding, confusing_encoding, label_encoding

    def __len__(self):
        return self.dataset_length


class TestDataset(Dataset):
    """Evaluation set of pre-generated noisy samples.

    The JSON file is loaded whole into memory.  Each stored item is indexed
    with the keys ``"original_sentence"``, ``"confusing_sentence"`` and
    ``"label"`` and encoded with the module-level ``sample_encoder``; no
    corruption is applied at lookup time.
    """

    def __init__(self, dataset_file_path):
        with open(dataset_file_path, encoding="utf-8") as handle:
            self.dataset = json.load(handle)
        self.dataset_length = len(self.dataset)
        print(f"load test dataset size: {self.dataset_length}")

    def __getitem__(self, idx):
        """Return ``(encoded original, encoded noisy sentence, encoded labels)``."""
        record = self.dataset[idx]
        original_encoding = sample_encoder.encode_sentence(record["original_sentence"])
        confusing_encoding = sample_encoder.encode_sentence(record["confusing_sentence"])
        label_encoding = sample_encoder.encode_label(record["label"])
        return original_encoding, confusing_encoding, label_encoding

    def __len__(self):
        return self.dataset_length


# Module-level singletons: TrainDataset.__getitem__ and TestDataset.__getitem__
# look these two names up as globals, so they must exist before any item is fetched.
confusing_sentence_generator = ConfusingSentenceGenerator()
# Reads the pinyin dictionary JSON from disk into the generator.
confusing_sentence_generator.load_dictionary(chinese_pinyin_dict_path)
sample_encoder = SampleEncoder(transformers_path)


class BertDetector(nn.Module):
    """BERT encoder topped with a one-unit linear layer that scores each position.

    ``forward`` returns raw scores (no sigmoid applied); the trailing size-1
    dimension produced by the linear layer is squeezed away.
    """

    def __init__(self, transformers_path):
        super().__init__()
        self.bert = BertModel.from_pretrained(transformers_path)
        self.config = self.bert.config
        self.linear_detector = nn.Linear(self.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 0 of the encoder output is fed to the scoring layer.
        hidden_states = encoder_outputs[0]
        scores = self.linear_detector(hidden_states)
        return torch.squeeze(scores, dim=-1)


class BertCorrector(nn.Module):
    """BERT encoder topped with a linear layer that scores every vocabulary entry per position.

    ``forward`` returns raw scores whose last dimension has
    ``config.vocab_size`` entries.
    """

    def __init__(self, transformers_path):
        super().__init__()
        self.bert = BertModel.from_pretrained(transformers_path)
        self.config = self.bert.config
        self.linear_char_predict = nn.Linear(self.config.hidden_size, self.config.vocab_size)

    def forward(self, input_ids, attention_mask, token_type_ids):
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 0 of the encoder output is fed to the vocabulary projection.
        hidden_states = encoder_outputs[0]
        return self.linear_char_predict(hidden_states)


def train(detector, corrector, dataloader, epoch, optimizer, scaler):
    """Run one training epoch, optimising detector and corrector jointly.

    Each batch from ``dataloader`` is a triple ``(original encodings,
    confusing encodings, error labels)``.  Both models read the confusing
    sentence; the detector is trained with binary cross-entropy against the
    error labels and the corrector with cross-entropy against the original
    ``input_ids``.  The two losses are summed and stepped through ``scaler``
    under CUDA autocast.

    Progress-bar statistics are running values over the last 100 steps.
    NOTE: losses and metrics cover every position of each sequence; nothing
    is masked with ``attention_mask``.

    Args:
        detector: Module returning one score per position (> 0 means "error").
        corrector: Module returning vocabulary scores per position.
        dataloader: Iterable of training batches.
        epoch: Epoch number, used only for the progress-bar caption.
        optimizer: Optimizer holding the parameters of both models.
        scaler: Gradient scaler used for the backward pass and optimizer step.
    """
    # Short pause kept from the original code — presumably lets earlier output
    # flush before the progress bar is drawn.
    time.sleep(0.2)
    detector.train()
    corrector.train()

    # Sliding windows over the most recent 100 steps.
    loss_count = deque([], maxlen=100)
    detector_tp_count = deque([], maxlen=100)
    detector_fp_count = deque([], maxlen=100)
    detector_fn_count = deque([], maxlen=100)
    corrector_accuracy_count = deque([], maxlen=100)

    pbar = tqdm(dataloader, position=0, leave=True)
    pbar.set_description("train epoch {}".format(epoch))
    for input_encodings_original, input_encodings_confusing, y_target in pbar:
        optimizer.zero_grad()
        input_encodings_original = {k: v.to(device) for k, v in input_encodings_original.items()}
        input_encodings_confusing = {k: v.to(device) for k, v in input_encodings_confusing.items()}
        y_target = y_target.to(device)
        with torch.cuda.amp.autocast():
            err_prob = detector(**input_encodings_confusing)
            char_predict = corrector(**input_encodings_confusing)
            loss_detector = F.binary_cross_entropy_with_logits(err_prob, y_target.float())
            # cross_entropy expects the class dimension second, hence the transpose.
            loss_corrector = F.cross_entropy(char_predict.transpose(-1, -2), input_encodings_original["input_ids"])
            loss = loss_detector + loss_corrector
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_count.append(loss.item())

        # Bookkeeping only: no gradients needed. True negatives are not
        # tracked because neither precision nor recall uses them.
        with torch.no_grad():
            is_error = y_target.bool()
            flagged = torch.gt(err_prob, 0)
            detector_tp_count.append((flagged & is_error).sum().item())
            detector_fp_count.append((flagged & ~is_error).sum().item())
            detector_fn_count.append((~flagged & is_error).sum().item())

            char_hits = torch.eq(torch.argmax(char_predict, dim=-1), input_encodings_original["input_ids"])
            corrector_accuracy_count.append(char_hits.sum().item() / char_hits.numel())

        cur_loss = np.mean(loss_count)
        tp_total = np.sum(detector_tp_count)
        # 1e-5 keeps the ratios defined when nothing was flagged / nothing was wrong.
        cur_precision = tp_total / (tp_total + np.sum(detector_fp_count) + 1e-5)
        cur_recall = tp_total / (tp_total + np.sum(detector_fn_count) + 1e-5)
        cur_acc = np.mean(corrector_accuracy_count)

        log_str = f"loss={cur_loss:>6.5} d_precision:{cur_precision:>8.5} d_recall:{cur_recall:>8.5}  c_acc:{cur_acc:>8.5} "
        pbar.set_postfix_str(log_str)


def test(detector, corrector, dataloader, epoch):
    time.sleep(0.2)
    detector.eval()
    corrector.eval()
    loss_count = []
    detector_tp_count = []
    detector_fp_count = []
    detector_fn_count = []
    detector_tn_count = []
    corrector_accuracy_count = []
    pbar = tqdm(dataloader, position=0, leave=True)
    pbar.set_description("test epoch {}".format(epoch))
    for input_encodings_original, input_encodings_confusing, y_target in pbar:
        input_encodings_original = {k: v.to(device) for k, v in input_encodings_original.items()}
        input_encodings_confusing = {k: v.to(device) for k, v in input_encodings_confusing.items()}
        y_target = y_target.to(device)
        with torch.cuda.amp.autocast():
            err_prob = detector(**input_encodings_confusing)
            char_predict = corrector(**input_encodings_confusing)
            loss_detector = F.binary_cross_entropy_with_logits(err_prob, y_target.float())
            loss_corrector = F.cross_entropy(char_predict.transpose(-1, -2), input_encodings_original["input_ids"])
            loss = loss_detector + loss_corrector

        loss_count.append(loss.item())

        y_detector_predict = torch.gt(err_prob, 0)
        detector_tp_count.append(torch.logical_and(y_detector_predict, y_target).sum().item())
        detector_fp_count.append(torch.logical_and(y_detector_predict, torch.logical_not(y_target)).sum().item())
        detector_fn_count.append(torch.logical_and(torch.logical_not(y_detector_predict), y_target).sum().item())
        detector_tn_count.append(torch.logical_and(torch.logical_not(y_detector_predict), torch.logical_not(y_target)).sum().item())

        y_corrector_predict = torch.eq(torch.argmax(char_predict, dim=-1), input_encodings_original["input_ids"])
        corrector_accuracy_count.append(y_corrector_predict.sum().item() / torch.ones_like(y_corrector_predict).sum().item())

        cur_loss = np.mean(loss_count)
        cur_precision = np.sum(detector_tp_count) / (np.sum(detector_tp_count) + np.sum(detector_fp_count) + 1e-5)
        cur_recall = np.sum(detector_tp_count) / (np.sum(detector_tp_count) + np.sum(detector_fn_count) + 1e-5)
        cur_acc = np.mean(corrector_accuracy_count)

        log_str = f"loss={cur_loss:>6.5} d_precision:{cur_precision:>8.5} d_recall:{cur_recall:>8.5}  c_acc:{cur_acc:>8.5} "
        pbar.set_postfix_str(log_str)


if __name__ == '__main__':
    import os

    def _seed_worker(worker_id):
        """Give each DataLoader worker its own NumPy / ``random`` seed.

        The augmentation in ConfusingSentenceGenerator draws from both
        ``random`` and ``np.random``.  Depending on the PyTorch version,
        NumPy's global RNG state can be duplicated into every worker, which
        would make all workers produce the same corruptions; reseeding from
        the per-worker torch seed avoids that.
        """
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    checkpoint_dir = "./model_4"
    num_epochs = 100
    learning_rate = 1e-5

    # Create the checkpoint directory up front: otherwise the first torch.save
    # fails only after a complete training epoch has already run.
    os.makedirs(checkpoint_dir, exist_ok=True)

    dataset_train = TrainDataset(train_dataset_file_path)
    dataset_test = TestDataset(test_dataset_file_path)

    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True,
                                  worker_init_fn=_seed_worker)
    dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    detector = BertDetector(transformers_path)
    corrector = BertCorrector(transformers_path)
    detector.to(device)
    corrector.to(device)

    # One optimizer drives both networks; train() sums their losses.
    optimizer = torch.optim.Adam(list(detector.parameters()) + list(corrector.parameters()), lr=learning_rate)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(num_epochs):
        train(detector, corrector, dataloader_train, epoch, optimizer, scaler)
        test(detector, corrector, dataloader_test, epoch)
        # Checkpoint both models after every epoch.
        torch.save(detector.state_dict(), f"{checkpoint_dir}/detector_{epoch}.pth")
        torch.save(corrector.state_dict(), f"{checkpoint_dir}/corrector_{epoch}.pth")


load train dataset size: 731254
load test dataset size: 22617


train epoch 0: 100%|██████████| 22852/22852 [1:43:33<00:00,  3.68it/s, loss=0.16045 d_precision: 0.96587 d_recall: 0.95037  c_acc: 0.96735 ]
test epoch 0: 100%|██████████| 707/707 [01:07<00:00, 10.51it/s, loss=0.14216 d_precision: 0.96874 d_recall: 0.94854  c_acc: 0.97066 ]
train epoch 1: 100%|██████████| 22852/22852 [1:43:47<00:00,  3.67it/s, loss=0.11218 d_precision: 0.97244 d_recall: 0.95821  c_acc: 0.97474 ]
test epoch 1: 100%|██████████| 707/707 [01:07<00:00, 10.52it/s, loss=0.098858 d_precision: 0.97187 d_recall: 0.95723  c_acc:  0.9772 ]
train epoch 2: 100%|██████████| 22852/22852 [1:44:06<00:00,  3.66it/s, loss=0.093734 d_precision: 0.97438 d_recall: 0.96136  c_acc: 0.97742 ]
test epoch 2: 100%|██████████| 707/707 [01:07<00:00, 10.52it/s, loss=0.082903 d_precision: 0.97502 d_recall:  0.9605  c_acc: 0.97973 ]
train epoch 3: 100%|██████████| 22852/22852 [1:43:23<00:00,  3.68it/s, loss=0.085192 d_precision: 0.97589 d_recall: 0.96353  c_acc: 0.97876 ]
test epoch 3: 100%|██████████|

KeyboardInterrupt: 