# Chinese Grammatical Error Correction




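#### Project Overview:
* This tutorial shows how to train a simple Chinese grammatical error correction model.
* The model is based on the pretrained BERT checkpoint `bert-base-chinese`.
* It uses the PyTorch deep learning framework and the Hugging Face Transformers library.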
 
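#### Task Goal:
Correct errors in Chinese sentences. The model (`BertForMaskedLM`) predicts exactly one token for each input position, so it can only fix character substitutions such as homophone typos. It cannot insert missing characters or delete extra ones.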

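*Example:*  
> Original: 今天大阳很好，新情也很不错。所以我因该出门散不嘛？  
> Corrected: 今天太阳很好，心情也很不错。所以我应该出门散步吗？  
> ("It is sunny today and I am in a good mood, so should I go out for a walk?")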
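
In this example every error is a single wrong character (mostly homophones such as 大阳→太阳 and 因该→应该), so the original and corrected sentences have the same length and can be compared position by position. A minimal sketch that lists the differing positions; `char_diffs` is a hypothetical helper, not part of the tutorial.

```python
def char_diffs(original: str, corrected: str) -> list[tuple[int, str, str]]:
    """Return (position, wrong_char, right_char) for every position that differs.

    Assumes both strings have the same length, i.e. only substitution errors.
    """
    if len(original) != len(corrected):
        raise ValueError("expected equal-length sentences (substitutions only)")
    return [(i, a, b) for i, (a, b) in enumerate(zip(original, corrected)) if a != b]


src = "今天大阳很好，新情也很不错。所以我因该出门散不嘛？"
tgt = "今天太阳很好，心情也很不错。所以我应该出门散步吗？"
for pos, wrong, right in char_diffs(src, tgt):
    print(pos, wrong, "->", right)
```

This prints five substitutions: positions 2, 7, 17, 22 and 23.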

#### Steps:

1. Preparation
2. Load the data
3. Load the model and optimizer
4. Train the model
5. Test the results



## Preparation

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

## Load the Data

```python
# Number of samples to read from each dataset file
num = None       # use the whole dataset
# num = 50000    # use only the first `num` samples
```

```python
import os
import json
from torch.utils.data import Dataset, DataLoader


class CscDataset(Dataset):
    """Correction pairs read from a JSON list of
    {"original_text": ..., "correct_text": ...} records."""

    def __init__(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            # Keep only the first `num` records (all of them when num is None)
            self.data = json.load(f)[:num]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]['original_text'], self.data[index]['correct_text']


def make_loaders(train_path='', valid_path='', test_path='', batch_size=32):
    """Build a DataLoader for each split whose file exists; missing splits are None."""
    train_loader = None
    if train_path and os.path.exists(train_path):
        train_loader = DataLoader(CscDataset(train_path),
                                  batch_size=batch_size,
                                  shuffle=True,  # reshuffle training pairs every epoch
                                 )
    valid_loader = None
    if valid_path and os.path.exists(valid_path):
        valid_loader = DataLoader(CscDataset(valid_path),
                                  batch_size=batch_size,
                                 )
    test_loader = None
    if test_path and os.path.exists(test_path):
        test_loader = DataLoader(CscDataset(test_path),
                                 batch_size=batch_size,
                                )
    return train_loader, valid_loader, test_loader
```
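
`CscDataset` expects each JSON file to be a single list of objects with `original_text` and `correct_text` keys. Because training uses the corrected text as per-position labels, it can help to check the files before training. A minimal sketch, assuming (this is not stated in the source) that a usable pair must have equal length; `check_csc_records` is a hypothetical helper.

```python
import json


def check_csc_records(records: list) -> dict:
    """Count records that are usable for position-aligned training."""
    stats = {"total": 0, "missing_keys": 0, "length_mismatch": 0, "ok": 0}
    for rec in records:
        stats["total"] += 1
        if "original_text" not in rec or "correct_text" not in rec:
            stats["missing_keys"] += 1
        elif len(rec["original_text"]) != len(rec["correct_text"]):
            stats["length_mismatch"] += 1  # cannot be aligned character by character
        else:
            stats["ok"] += 1
    return stats


# Same structure that CscDataset reads with json.load: a list of dicts
sample = json.loads("""[
  {"original_text": "今天大阳很好", "correct_text": "今天太阳很好"},
  {"original_text": "新情不错", "correct_text": "心情很不错"},
  {"original_text": "出门散不"}
]""")
print(check_csc_records(sample))
```

On a real file, load it the same way as the dataset does, for example `with open("output/train.json", encoding="utf-8") as f: print(check_csc_records(json.load(f)))`.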

```python
# Load the data
train_loader, valid_loader, test_loader = make_loaders(train_path="output/train.json",
                                                       valid_path="output/dev.json",
                                                       test_path="output/test.json",
                                                       batch_size=32,
                                                      )
len(train_loader)  # number of training batches
```

```
7870
```
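
`len(train_loader)` counts batches, not samples. With `batch_size=32` and the `DataLoader` default `drop_last=False`, the batch count is ⌈N / 32⌉, so 7870 batches means the training set holds between 251,809 and 251,840 pairs. The arithmetic as a quick check:

```python
import math

batch_size = 32
num_batches = 7870

# ceil(N / batch_size) == num_batches
#   <=>  (num_batches - 1) * batch_size < N <= num_batches * batch_size
min_samples = (num_batches - 1) * batch_size + 1
max_samples = num_batches * batch_size
print(min_samples, max_samples)  # 251809 251840

assert math.ceil(min_samples / batch_size) == num_batches
assert math.ceil(max_samples / batch_size) == num_batches
```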

## Load the Model and Optimizer

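```python
# Load the model: BERT with its masked-language-model head
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model = model.to(device)  # assigning the result keeps the notebook from echoing the model
```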

```
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
```

This warning is expected here: `cls.seq_relationship` is the next-sentence-prediction head from BERT pre-training, which `BertForMaskedLM` does not use.


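```python
# Load the optimizer.
# transformers.AdamW is deprecated and no longer available in recent releases, so use
# PyTorch's AdamW; eps=1e-6 and weight_decay=0.0 reproduce the old transformers.AdamW defaults.
import torch

optim = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
```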
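
For reference, one AdamW update written out in plain Python for a single scalar parameter, following the decoupled-weight-decay rule documented for `torch.optim.AdamW` (decay the weight, update the two moment estimates, bias-correct them, then step). This is an illustrative sketch, not PyTorch's implementation; the defaults below (`lr=5e-5`, `eps=1e-6`, `weight_decay=0.0`) are chosen to match the optimizer settings discussed above, while `torch.optim.AdamW` itself defaults to `eps=1e-8` and `weight_decay=0.01`.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=5e-5, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
    """One AdamW update for a scalar parameter at step t (1-based); returns (theta, m, v)."""
    beta1, beta2 = betas
    theta = theta - lr * weight_decay * theta   # decoupled weight decay
    m = beta1 * m + (1 - beta1) * grad          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# On the first step the update is close to lr in size, whatever the gradient's scale
theta_small, _, _ = adamw_step(theta=1.0, grad=0.5, m=0.0, v=0.0, t=1)
theta_large, _, _ = adamw_step(theta=1.0, grad=50.0, m=0.0, v=0.0, t=1)
print(theta_small, theta_large)
```

This scale invariance is one reason a small fixed learning rate such as `5e-5` works for fine-tuning without tuning it to the gradient magnitude.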

```python
# Load the tokenizer (it must match the model checkpoint)
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
```
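
`bert-base-chinese` splits Chinese text into single characters: BERT's basic tokenizer puts spaces around CJK characters and punctuation before WordPiece runs. That is why an original sentence and its correction usually tokenize to sequences of the same length and can be aligned position by position. Latin-script words and digits stay together and are then split by WordPiece, so their token counts can differ between the two sentences. A simplified sketch of the pre-splitting step only; it checks just the main CJK Unified Ideographs block (U+4E00 to U+9FFF) and skips lowercasing and WordPiece. `pre_split` is a hypothetical helper, not the library's code.

```python
import unicodedata


def is_cjk(ch: str) -> bool:
    # Simplification: only the CJK Unified Ideographs block is checked
    return 0x4E00 <= ord(ch) <= 0x9FFF


def pre_split(text: str) -> list[str]:
    """Make each CJK character and punctuation mark its own token; keep other runs together."""
    tokens, buf = [], ""
    for ch in text:
        if is_cjk(ch) or unicodedata.category(ch).startswith("P"):
            if buf:
                tokens.append(buf)
                buf = ""
            tokens.append(ch)
        elif ch.isspace():
            if buf:
                tokens.append(buf)
                buf = ""
        else:
            buf += ch
    if buf:
        tokens.append(buf)
    return tokens


print(pre_split("今天大阳很好，新情也很不错。"))
print(pre_split("用BERT模型纠错"))
```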

## Train the Model (Fine-tune BERT)

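```python
# Training
from tqdm import tqdm
import time

num_epochs = 3

model.train()
for epoch_i in range(num_epochs):
    print('Epoch %s/%s' % (epoch_i + 1, num_epochs))
    time.sleep(0.3)  # let the print flush before tqdm draws its progress bar

    pbar = tqdm(train_loader)
    for batch in pbar:
        optim.zero_grad()
        ori_text, cor_text = batch
        # Inputs: the original (possibly erroneous) sentences
        encoded_text = tokenizer(ori_text, padding=True, truncation=True,
                                 max_length=512, return_tensors='pt').to(device)
        # Labels: the corrected sentences, padded/truncated to exactly the input length
        # so that label position i lines up with input position i
        text_labels = tokenizer(cor_text, padding='max_length', truncation=True,
                                max_length=encoded_text['input_ids'].shape[1],
                                return_tensors='pt')['input_ids'].to(device)
        # -100 tells the loss to ignore padding positions
        text_labels[text_labels == tokenizer.pad_token_id] = -100
        outputs = model(**encoded_text, labels=text_labels)
        loss = outputs.loss
        loss.backward()
        optim.step()

        # Show the current loss in the progress bar
        pbar.set_postfix({'Loss': '{:.3f}'.format(loss.item())})

    pbar.close()
```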

```
Epoch 1/3
100%|██████████| 7870/7870 [35:51<00:00,  3.66it/s, Loss=0.016]
Epoch 2/3
100%|██████████| 7870/7870 [36:33<00:00,  3.59it/s, Loss=0.005]
Epoch 3/3
100%|██████████| 7870/7870 [36:10<00:00,  3.63it/s, Loss=0.007]
```
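
`BertForMaskedLM` computes a token-level cross-entropy between its logits and `labels`, skipping every position whose label is -100. That is why it helps to set padding positions in the labels to -100: otherwise the model is also trained to predict `[PAD]` at those positions, which dilutes the loss. A plain-Python sketch of that averaging with made-up toy logits (not real model output); `masked_token_ce` is a hypothetical helper.

```python
import math


def masked_token_ce(logits: list[list[float]], labels: list[int], ignore_index: int = -100) -> float:
    """Mean cross-entropy over the positions whose label is not ignore_index."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # e.g. padding positions contribute nothing
        log_norm = math.log(sum(math.exp(x) for x in row))  # log of the softmax normaliser
        total += log_norm - row[label]                      # -log softmax(row)[label]
        count += 1
    return total / count if count else 0.0


# Toy example: vocabulary of 4 tokens, 3 positions, the last one is padding
logits = [[2.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.0],
          [9.0, 9.0, 9.0, 9.0]]
labels = [0, 3, -100]
print(masked_token_ce(logits, labels))
```

The result is the mean over the first two positions only; changing the logits of the padded third position does not change it.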


## Test the Results

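```python
text = "今天大阳很好，新情也很不错。所以我因该出门散不嘛？"
model.eval()
with torch.no_grad():  # inference only: no gradients needed
    for i in range(4):
        print("Round %d: %s" % (i, text))
        inputs = tokenizer(text, return_tensors='pt').to(device)
        logits = model(**inputs).logits
        # Take the most likely token at each position and turn the ids back into text
        pred_ids = torch.argmax(logits[0], dim=-1)
        text = tokenizer.decode(pred_ids, skip_special_tokens=True).replace(' ', '')
```

```
Round 0: 今天大阳很好，新情也很不错。所以我因该出门散不嘛？
Round 1: 今天太阳很好，心情也很不错。所以我应该出门散步嘛？
Round 2: 今天太阳很好，心情也很不错。所以我应该出门散步嘛？
Round 3: 今天太阳很好，心情也很不错。所以我应该出门散步嘛？
```

The first round fixes 大阳→太阳, 新情→心情, 因该→应该 and 散不→散步; later rounds change nothing. The model keeps 嘛 rather than the reference 吗, probably because 嘛 is itself a common sentence-final particle.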
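
The test loop always runs a fixed four rounds. A minimal sketch of stopping as soon as a round leaves the sentence unchanged, with a cap on the number of rounds. `correct_until_stable` is hypothetical, and `toy_correct` is a stand-in dictionary lookup used only so the sketch runs without the model; in the notebook, the function passed as `correct_once` would wrap the tokenizer, `model(**inputs)` and `tokenizer.decode` calls.

```python
from typing import Callable


def correct_until_stable(text: str, correct_once: Callable[[str], str],
                         max_rounds: int = 4) -> tuple[str, int]:
    """Apply correct_once repeatedly until the output stops changing.

    Returns the final text and the number of rounds that changed something.
    """
    changed_rounds = 0
    for _ in range(max_rounds):
        new_text = correct_once(text)
        if new_text == text:
            break  # fixed point reached: another round would not help
        text = new_text
        changed_rounds += 1
    return text, changed_rounds


# Stand-in for the model (illustration only): fixes one known typo per call
FIXES = {"大阳": "太阳", "新情": "心情"}


def toy_correct(text: str) -> str:
    for wrong, right in FIXES.items():
        if wrong in text:
            return text.replace(wrong, right)
    return text


print(correct_until_stable("今天大阳很好，新情也很不错。", toy_correct))
```

Stopping at a fixed point saves the redundant forward passes, while `max_rounds` still bounds the cost if the model keeps oscillating between two outputs.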
