In [1]:
!nvidia-smi

Sun Jul 25 07:04:01 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 465.27       CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   48C    P0    41W / 250W |   6019MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install -U jieba pypinyin transformers --proxy http://10.8.84.123:7890



In [None]:
import json
import re
import random
from collections import defaultdict
from typing import List
import jieba
import logging
from pypinyin import lazy_pinyin
import torch.nn as nn
from transformers import BertModel
from torch.utils.data import Dataset
from transformers import BertTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
from collections import deque
import numpy as np
from tqdm import tqdm

# Number of samples per DataLoader batch (used for both the train and test loaders).
batch_size = 32
# Weight of the detector loss in the combined loss; the corrector loss is
# weighted by (1 - loss_detector_weight).
loss_detector_weight = 0.4
# Number of DataLoader worker processes.
num_workers = 12
# Location of the pretrained checkpoint passed to BertTokenizer / BertModel.from_pretrained.
transformers_path = "../hfl/chinese-roberta-wwm-ext"
# JSON dictionary read by ConfusingSentenceGenerator.load_dictionary
# (pinyin key -> [[words], [frequencies]], as written by build_dictionary).
chinese_pinyin_dict_path = "./resource/chinese_words_frequency.json"
# JSON file whose items are handed to generate_sample as sentences (see TrainDataset).
train_dataset_file_path = "./data/train.json"
# JSON file whose items are dicts with "original_sentence", "confusing_sentence"
# and "label" keys (see TestDataset).
test_dataset_file_path = "./data/test.json"
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set jieba's logger to INFO level (presumably to hide its debug start-up messages).
jieba.setLogLevel(logging.INFO)


class ConfusingSentenceGenerator:
    """Builds noisy copies of sentences by swapping words/characters for confusable ones.

    The generator relies on a pinyin dictionary that maps a pinyin key (the
    syllables of a word joined by "_") to a pair ``[words, frequencies]`` of
    parallel lists. Replacement candidates are sampled with probability
    proportional to ``frequency ** prob_power_factor``.
    """

    def __init__(self, prob_power_factor=0.7, prob_confuse_phrase=0.1, prob_confuse_word=0.03, prob_random_word=0.02):
        """Store the sampling probabilities; the dictionary starts empty.

        Args:
            prob_power_factor: exponent applied to candidate frequencies before
                they are normalised into sampling probabilities.
            prob_confuse_phrase: per segmented word, probability of trying a
                same-pinyin phrase replacement.
            prob_confuse_word: per character, probability of trying a
                same-pinyin character replacement.
            prob_random_word: per character, probability of replacing it with a
                character drawn from a randomly chosen single-syllable entry.
        """
        self.chinese_pinyin_dict = dict()
        self.chinese_single_word_pinyin_list = []
        self.chinese_char_regex = re.compile("^[\u4e00-\u9fff]+$")
        self.prob_power_factor = prob_power_factor
        self.prob_confuse_phrase = prob_confuse_phrase
        self.prob_confuse_word = prob_confuse_word
        self.prob_random_word = prob_random_word

    def build_dictionary(self, corpus: List[str], chinese_pinyin_dict_path: str):
        """Count the Chinese words of ``corpus`` and write the pinyin dictionary as JSON.

        Every sentence is segmented with jieba; only segments made entirely of
        characters in U+4E00..U+9FFF are counted. The words are grouped by
        their pinyin key and each group is ordered by descending frequency.
        The result is only written to ``chinese_pinyin_dict_path``; it is not
        stored on the instance (call ``load_dictionary`` for that).
        """
        p_bar = tqdm(corpus, position=0, leave=True)
        p_bar.set_description("build_chinese_pinyin_dict")
        chinese_words_frequency = defaultdict(int)
        for corpus_sentence in p_bar:
            for word in jieba.lcut(corpus_sentence):
                if self.chinese_char_regex.match(word):
                    chinese_words_frequency[word] += 1

        grouped = defaultdict(list)
        for word, count in chinese_words_frequency.items():
            grouped["_".join(lazy_pinyin(word))].append((word, count))

        chinese_pinyin_dict = {}
        for pinyin_, entries in grouped.items():
            # Most frequent candidate first; the sort is stable for ties.
            entries.sort(key=lambda item: item[1], reverse=True)
            chinese_pinyin_dict[pinyin_] = [[w for w, _ in entries], [c for _, c in entries]]

        with open(chinese_pinyin_dict_path, "w", encoding="utf-8") as f:
            json.dump(chinese_pinyin_dict, f, ensure_ascii=False)

    def load_dictionary(self, chinese_pinyin_dict_path: str):
        """Load the JSON pinyin dictionary and index its single-syllable keys."""
        with open(chinese_pinyin_dict_path, "r", encoding="utf-8") as f:
            self.chinese_pinyin_dict = json.load(f)
        # Keys without "_" hold exactly one syllable; they feed the random replacement pass.
        self.chinese_single_word_pinyin_list = [x for x in self.chinese_pinyin_dict if "_" not in x]

    def _sample_replacement(self, pinyin_, current):
        """Draw one candidate stored under ``pinyin_``, strongly avoiding ``current``."""
        words_, freq_ = self.chinese_pinyin_dict[pinyin_]
        # Work on a copy: writing into ``freq_`` would permanently corrupt the
        # shared dictionary for every later call.
        weights = np.array(freq_, dtype=float)
        if current in words_:
            # Make re-drawing the current text very unlikely without removing it,
            # so the distribution stays valid when it is the only candidate.
            weights[words_.index(current)] = 1e-4
        weights = np.power(weights, self.prob_power_factor)
        weights = weights / np.sum(weights)
        return str(np.random.choice(words_, p=weights))

    def generate_sample(self, original_sentence: str):
        """Return a noisy version of ``original_sentence`` with per-character labels.

        Three passes are applied in order: same-pinyin phrase replacement on
        jieba segments, same-pinyin replacement on single characters, and
        replacement by a character from a random single-syllable entry.

        Returns:
            dict with "original_sentence", "confusing_sentence" and "label",
            where ``label[i]`` is 1 when character ``i`` differs from the original.

        Raises:
            AssertionError: if a replacement changed the sentence length.
        """
        pieces = []
        # Pass 1: with some probability replace a whole word by a same-pinyin word.
        for word in jieba.lcut(original_sentence):
            if random.random() < self.prob_confuse_phrase and self.chinese_char_regex.match(word):
                pinyin_ = "_".join(lazy_pinyin(word))
                if pinyin_ in self.chinese_pinyin_dict:
                    pieces.append(self._sample_replacement(pinyin_, word))
                    continue
            pieces.append(word)
        confusing_chars = list("".join(pieces))

        # Pass 2: with some probability replace a character by a same-pinyin character.
        for i, char in enumerate(confusing_chars):
            if random.random() < self.prob_confuse_word and self.chinese_char_regex.match(char):
                pinyin_ = "_".join(lazy_pinyin(char))
                if pinyin_ in self.chinese_pinyin_dict:
                    confusing_chars[i] = self._sample_replacement(pinyin_, char)

        # Pass 3: with some probability replace a character by a random one.
        # Uses prob_random_word (the original tested prob_confuse_word here, which
        # left prob_random_word without any effect).
        if self.chinese_single_word_pinyin_list:
            for i, char in enumerate(confusing_chars):
                if random.random() < self.prob_random_word and self.chinese_char_regex.match(char):
                    pinyin_ = random.choice(self.chinese_single_word_pinyin_list)
                    confusing_chars[i] = self._sample_replacement(pinyin_, char)

        confusing_sentence = "".join(confusing_chars)
        assert len(original_sentence) == len(confusing_sentence), "replacement changed the sentence length"
        label = [int(c1 != c2) for c1, c2 in zip(original_sentence, confusing_sentence)]
        return {
            "original_sentence": original_sentence,
            "confusing_sentence": confusing_sentence,
            "label": label
        }


class SampleEncoder:
    """Turns sentences and per-character labels into fixed-length model inputs."""

    def __init__(self, transformers_path, max_length=128):
        """Load the tokenizer from ``transformers_path`` and fix the output length."""
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(transformers_path)

    def encode_sentence(self, sentence):
        """Encode ``sentence`` character by character into ``max_length`` positions.

        The sequence is ``[CLS] + characters`` cut to ``max_length - 1`` items,
        followed by ``[SEP]`` and then padding tokens up to ``max_length``.

        Returns:
            dict with "input_ids", "attention_mask" and "token_type_ids", each a
            ``torch.long`` tensor.
        """
        tok = self.tokenizer
        # One token per character; no word-piece splitting is performed here.
        kept = [tok.cls_token, *sentence][: self.max_length - 1]
        kept.append(tok.sep_token)
        tokens = kept + [tok.pad_token] * (self.max_length - len(kept))

        encoded = {}
        encoded["input_ids"] = tok.convert_tokens_to_ids(tokens)
        # Real tokens get 1, padding gets 0.
        encoded["attention_mask"] = [int(token != tok.pad_token) for token in tokens]
        encoded["token_type_ids"] = [0] * len(tokens)

        return {key: torch.tensor(values, dtype=torch.long) for key, values in encoded.items()}

    def encode_label(self, label):
        """Align ``label`` with ``encode_sentence``: 0 for [CLS], [SEP] and padding."""
        aligned = ([0] + label)[: self.max_length - 1]
        aligned.append(0)
        aligned.extend([0] * (self.max_length - len(aligned)))
        return np.asarray(aligned)


class TrainDataset(Dataset):
    """Training set that corrupts its sentences on the fly.

    The JSON file is loaded once; every item access passes the stored entry to
    the module-level ``confusing_sentence_generator`` and encodes the result
    with the module-level ``sample_encoder``, so the same index can yield a
    different noisy sentence on each access.
    """

    def __init__(self, dataset_file_path):
        """Read the whole JSON file at ``dataset_file_path`` into memory."""
        with open(dataset_file_path, "r", encoding="utf-8") as source:
            loaded = json.load(source)
        self.dataset = loaded
        self.dataset_length = len(loaded)
        print(f"load train dataset size: {self.dataset_length}")

    def __getitem__(self, idx):
        """Return (clean encoding, noisy encoding, encoded label) for entry ``idx``."""
        generated = confusing_sentence_generator.generate_sample(self.dataset[idx])
        clean_encoding = sample_encoder.encode_sentence(generated["original_sentence"])
        noisy_encoding = sample_encoder.encode_sentence(generated["confusing_sentence"])
        target = sample_encoder.encode_label(generated["label"])
        return clean_encoding, noisy_encoding, target

    def __len__(self):
        """Number of entries loaded from the file."""
        return self.dataset_length


class TestDataset(Dataset):
    """Evaluation set made of pre-generated samples.

    Each entry of the JSON file is indexed with the keys "original_sentence",
    "confusing_sentence" and "label"; unlike the training set, no noise is
    generated at access time.
    """

    def __init__(self, dataset_file_path):
        """Read the whole JSON file at ``dataset_file_path`` into memory."""
        with open(dataset_file_path, "r", encoding="utf-8") as source:
            loaded = json.load(source)
        self.dataset = loaded
        self.dataset_length = len(loaded)
        print(f"load test dataset size: {self.dataset_length}")

    def __getitem__(self, idx):
        """Return (clean encoding, noisy encoding, encoded label) for entry ``idx``."""
        record = self.dataset[idx]
        clean_encoding = sample_encoder.encode_sentence(record["original_sentence"])
        noisy_encoding = sample_encoder.encode_sentence(record["confusing_sentence"])
        target = sample_encoder.encode_label(record["label"])
        return clean_encoding, noisy_encoding, target

    def __len__(self):
        """Number of entries loaded from the file."""
        return self.dataset_length


# Module-level singletons: TrainDataset.__getitem__ and TestDataset.__getitem__
# look these names up as globals, so they must exist before any dataset item is read.
confusing_sentence_generator = ConfusingSentenceGenerator()
# The generator starts with an empty dictionary; load the pinyin dictionary from disk.
confusing_sentence_generator.load_dictionary(chinese_pinyin_dict_path)
# Uses the default max_length of 128 positions per encoded sentence.
sample_encoder = SampleEncoder(transformers_path)


class BertCorrector(nn.Module):
    """BERT encoder with two per-token heads: error detection and character prediction."""

    def __init__(self, transformers_path):
        """Load the pretrained encoder and create both linear heads on top of it."""
        super().__init__()
        encoder = BertModel.from_pretrained(transformers_path)
        self.bert = encoder
        self.config = encoder.config
        hidden_dim = self.config.hidden_size
        # One score per token: is this position an error?
        self.linear_detector = nn.Linear(hidden_dim, 1)
        # One score per vocabulary entry for every token.
        self.linear_char_predict = nn.Linear(hidden_dim, self.config.vocab_size)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the encoder and both heads.

        Returns:
            A pair ``(detector_scores, char_scores)``. ``detector_scores`` is the
            detector head output with its trailing size-1 dimension removed; the
            values are raw scores (the training loop feeds them to a
            with-logits loss and thresholds them at 0). ``char_scores`` is the
            character head output with ``vocab_size`` entries per token.
        """
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First element of the encoder output: the per-token hidden states.
        token_states = encoder_outputs[0]
        detector_scores = self.linear_detector(token_states).squeeze(dim=-1)
        char_scores = self.linear_char_predict(token_states)
        return detector_scores, char_scores


def train(model, dataloader, epoch, optimizer, scaler):
    """Run one training epoch and show rolling metrics on the progress bar.

    Each batch is a triple (encoding of the clean sentence, encoding of the
    noisy sentence, detector labels). The model sees the noisy encoding; the
    detector is trained against the labels and the corrector against the clean
    ``input_ids``. The two losses are mixed with ``loss_detector_weight``.
    The displayed metrics cover the most recent 100 batches.
    """
    time.sleep(0.2)
    model.train()

    window = 100
    recent_loss = deque(maxlen=window)
    recent_tp = deque(maxlen=window)
    recent_fp = deque(maxlen=window)
    recent_fn = deque(maxlen=window)
    recent_tn = deque(maxlen=window)
    recent_acc = deque(maxlen=window)

    progress = tqdm(dataloader, position=0, leave=True)
    progress.set_description("train epoch {}".format(epoch))
    for clean_batch, noisy_batch, y_target in progress:
        optimizer.zero_grad()
        clean_batch = {name: tensor.to(device) for name, tensor in clean_batch.items()}
        noisy_batch = {name: tensor.to(device) for name, tensor in noisy_batch.items()}
        y_target = y_target.to(device)

        # Mixed-precision forward pass and loss computation.
        with torch.cuda.amp.autocast():
            err_prob, char_predict = model(**noisy_batch)
            loss_detector = F.binary_cross_entropy_with_logits(err_prob, y_target.float())
            # cross_entropy expects the class dimension second, hence the transpose.
            loss_corrector = F.cross_entropy(char_predict.transpose(-1, -2), clean_batch["input_ids"])
            loss = loss_detector_weight * loss_detector + (1 - loss_detector_weight) * loss_corrector

        # Scaled backward pass; the scaler performs the optimizer step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        recent_loss.append(loss.item())

        # A score above 0 means the detector flags the position as an error.
        flagged = err_prob > 0
        unflagged = torch.logical_not(flagged)
        not_error = torch.logical_not(y_target)
        recent_tp.append(torch.logical_and(flagged, y_target).sum().item())
        recent_fp.append(torch.logical_and(flagged, not_error).sum().item())
        recent_fn.append(torch.logical_and(unflagged, y_target).sum().item())
        recent_tn.append(torch.logical_and(unflagged, not_error).sum().item())

        # Share of positions (all of them, padding included) whose top prediction
        # equals the clean token id.
        hits = torch.eq(torch.argmax(char_predict, dim=-1), clean_batch["input_ids"])
        recent_acc.append(hits.sum().item() / hits.numel())

        tp_total = np.sum(recent_tp)
        mean_loss = np.mean(recent_loss)
        # The small constant avoids a division by zero when nothing was flagged / present.
        precision = tp_total / (tp_total + np.sum(recent_fp) + 1e-5)
        recall = tp_total / (tp_total + np.sum(recent_fn) + 1e-5)
        mean_acc = np.mean(recent_acc)

        progress.set_postfix_str(
            f"loss={mean_loss:>6.5} d_precision:{precision:>8.5} d_recall:{recall:>8.5}  c_acc:{mean_acc:>8.5} "
        )


def test(model, dataloader, epoch):
    time.sleep(0.2)
    model.eval()
    loss_count = []
    detector_tp_count = []
    detector_fp_count = []
    detector_fn_count = []
    detector_tn_count = []
    corrector_accuracy_count = []
    pbar = tqdm(dataloader, position=0, leave=True)
    pbar.set_description("test epoch {}".format(epoch))
    for input_encodings_original, input_encodings_confusing, y_target in pbar:
        input_encodings_original = {k: v.to(device) for k, v in input_encodings_original.items()}
        input_encodings_confusing = {k: v.to(device) for k, v in input_encodings_confusing.items()}
        y_target = y_target.to(device)
        with torch.cuda.amp.autocast():
            err_prob, char_predict = model(**input_encodings_confusing)
            loss_detector = F.binary_cross_entropy_with_logits(err_prob, y_target.float())
            loss_corrector = F.cross_entropy(char_predict.transpose(-1, -2), input_encodings_original["input_ids"])
            loss = loss_detector_weight * loss_detector + (1 - loss_detector_weight) * loss_corrector

        loss_count.append(loss.item())

        y_detector_predict = torch.gt(err_prob, 0)
        detector_tp_count.append(torch.logical_and(y_detector_predict, y_target).sum().item())
        detector_fp_count.append(torch.logical_and(y_detector_predict, torch.logical_not(y_target)).sum().item())
        detector_fn_count.append(torch.logical_and(torch.logical_not(y_detector_predict), y_target).sum().item())
        detector_tn_count.append(torch.logical_and(torch.logical_not(y_detector_predict), torch.logical_not(y_target)).sum().item())

        y_corrector_predict = torch.eq(torch.argmax(char_predict, dim=-1), input_encodings_original["input_ids"])
        corrector_accuracy_count.append(y_corrector_predict.sum().item() / torch.ones_like(y_corrector_predict).sum().item())

        cur_loss = np.mean(loss_count)
        cur_precision = np.sum(detector_tp_count) / (np.sum(detector_tp_count) + np.sum(detector_fp_count) + 1e-5)
        cur_recall = np.sum(detector_tp_count) / (np.sum(detector_tp_count) + np.sum(detector_fn_count) + 1e-5)
        cur_acc = np.mean(corrector_accuracy_count)

        log_str = f"loss={cur_loss:>6.5} d_precision:{cur_precision:>8.5} d_recall:{cur_recall:>8.5}  c_acc:{cur_acc:>8.5} "
        pbar.set_postfix_str(log_str)


if __name__ == '__main__':
    import os  # only needed by this entry point

    # Create the checkpoint folder before training starts: torch.save does not
    # create missing directories, and failing after a full epoch would lose the weights.
    checkpoint_dir = "./model_1"
    os.makedirs(checkpoint_dir, exist_ok=True)

    dataset_train = TrainDataset(train_dataset_file_path)
    dataset_test = TestDataset(test_dataset_file_path)

    # Training batches are shuffled; evaluation keeps the file order.
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    model = BertCorrector(transformers_path)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Gradient scaler used together with the autocast blocks inside train().
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(100):
        train(model, dataloader_train, epoch, optimizer, scaler)
        test(model, dataloader_test, epoch)
        # One checkpoint per epoch, named after the epoch index.
        torch.save(model.state_dict(), f"{checkpoint_dir}/model_{epoch}.pth")


load train dataset size: 731254
load test dataset size: 22617


train epoch 0: 100%|██████████| 22852/22852 [1:04:47<00:00,  5.88it/s, loss=0.09164 d_precision: 0.94493 d_recall: 0.93055  c_acc: 0.96957 ] 
test epoch 0: 100%|██████████| 707/707 [00:45<00:00, 15.58it/s, loss=0.084392 d_precision: 0.95958 d_recall: 0.92572  c_acc: 0.97108 ]
train epoch 1: 100%|██████████| 22852/22852 [1:04:55<00:00,  5.87it/s, loss=0.062023 d_precision: 0.95906 d_recall: 0.94778  c_acc: 0.97616 ]
test epoch 1: 100%|██████████| 707/707 [00:44<00:00, 15.72it/s, loss=0.058077 d_precision: 0.96756 d_recall: 0.94324  c_acc: 0.97722 ]
train epoch 2: 100%|██████████| 22852/22852 [1:04:47<00:00,  5.88it/s, loss=0.052886 d_precision: 0.96546 d_recall:   0.952  c_acc: 0.97837 ]
test epoch 2: 100%|██████████| 707/707 [00:45<00:00, 15.64it/s, loss=0.049418 d_precision: 0.96817 d_recall: 0.95166  c_acc: 0.97956 ]
train epoch 3: 100%|██████████| 22852/22852 [1:05:10<00:00,  5.84it/s, loss=0.046217 d_precision:  0.9689 d_recall: 0.95628  c_acc: 0.97971 ]
test epoch 3: 100%|████████