CGED

In [1]:
import logging
import os
import codecs
from tqdm import tqdm
from bs4 import BeautifulSoup 
import pandas as pd
import zhconv

In [2]:
def read_langs_cged2014(file_name):
    logging.info(("Reading lines from {}".format(file_name)))
    total_correct=[]
    total_mistakes=[]
    with codecs.open(file_name, "r", "utf-8") as file:
 
        data = file.read()
        soup = BeautifulSoup(data, 'html.parser')
        results = soup.find_all('essay')
        for item in tqdm(results):
            text = item.find("text").text.strip()
            mistakes=text.split('\n')
            correct_text = [x.text.strip() for x in item.find_all("correction")]
            total_correct=total_correct+correct_text
            total_mistakes=total_mistakes+mistakes
    assert len(total_correct)==len(total_mistakes)
    return total_correct,total_mistakes
def rewrite_cged2014(file_name):
    total_correct,total_mistakes=read_langs_cged2014(file_name+".sgml")
    with open(file_name+".src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_mistakes]
        fp.close()
    with open(file_name+".trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_correct]
        fp.close()

In [3]:
file_name='./CGED/cged2014/nlptea14cfl_release1.1/Training/A2_CFL_training'
rewrite_cged2014(file_name)
# file_name='./CGED/cged2014/nlptea14cfl_release1.1/Training/B1_CFL_training'
# rewrite_cged2014(file_name)
# 存在编码问题
file_name='./CGED/cged2014/nlptea14cfl_release1.1/Training/B2_CFL_training'
rewrite_cged2014(file_name)
file_name='./CGED/cged2014/nlptea14cfl_release1.1/Training/C1_CFL_training'
rewrite_cged2014(file_name)

100%|█████████████████████████████████████████████████████████████████████████████| 388/388 [00:00<00:00, 43108.53it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 232/232 [00:00<00:00, 33140.74it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<?, ?it/s]


In [4]:
def read_langs_cged2015(file_name):
    logging.info(("Reading lines from {}".format(file_name)))
    total_correct=[]
    total_mistakes=[]
    with codecs.open(file_name, "r", "utf-8") as file:
 
        data = file.read()
        soup = BeautifulSoup(data, 'html.parser')
        results = soup.find_all('doc')

        for item in tqdm(results):
            text = item.find("sentence").text.strip()
            mistakes=text.split('\n')
            correct_text = [x.text.strip() for x in item.find_all("correction")]
            total_correct=total_correct+correct_text
            total_mistakes=total_mistakes+mistakes
    assert len(total_correct)==len(total_mistakes)
    return total_correct,total_mistakes
def rewrite_cged2015(file_name):
    total_correct,total_mistakes=read_langs_cged2015(file_name+".sgml")
    with open(file_name+".src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_mistakes]
        fp.close()
    with open(file_name+".trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_correct]
        fp.close()

In [5]:
file_name='./CGED/cged2015/nlptea15cged_release1.0/Training/NLPTEA15_CGED_Training'
rewrite_cged2015(file_name)

100%|███████████████████████████████████████████████████████████████████████████| 2205/2205 [00:00<00:00, 55755.82it/s]


In [6]:
def read_langs_cged2016(file_name):
    logging.info(("Reading lines from {}".format(file_name)))
    total_correct=[]
    total_mistakes=[]
    with codecs.open(file_name, "r", "utf-8") as file:
 
        data = file.read()
        soup = BeautifulSoup(data, 'html.parser')
        results = soup.find_all('doc')

        for item in tqdm(results):
            text = item.find("text").text.strip()
            mistakes=text.split('\n')
            correct_text = [x.text.strip() for x in item.find_all("correction")]
            total_correct=total_correct+correct_text
            total_mistakes=total_mistakes+mistakes
    assert len(total_correct)==len(total_mistakes)
    return total_correct,total_mistakes
def rewrite_cged2016(file_name):
    total_correct,total_mistakes=read_langs_cged2016(file_name+".txt")
    with open(file_name+".src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_mistakes]
        fp.close()
    with open(file_name+".trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_correct]
        fp.close()

In [7]:
file_name='./CGED/cged2016/nlptea16cged_release1.0/Training/CGED16_HSK_TrainingSet'
rewrite_cged2016(file_name)
# file_name='./CGED/cged2016/nlptea16cged_release1.0/Training/CGED16_TOCFL_TrainingSet'
# rewrite_cged2016(file_name)
# 存在编码问题

100%|███████████████████████████████████████████████████████████████████████████| 9602/9602 [00:00<00:00, 19258.88it/s]


In [8]:
def read_langs_cged2017(file_name):
    logging.info(("Reading lines from {}".format(file_name)))
    total_correct=[]
    total_mistakes=[]
    with codecs.open(file_name, "r", "utf-8") as file:
 
        data = file.read()
        soup = BeautifulSoup(data, 'html.parser')
        results = soup.find_all('doc')

        for item in tqdm(results):
            if item.find("text") is None:
                continue
            text = item.find("text").text.strip()
            mistakes=text.split('\n')
            correct_text = [x.text.strip() for x in item.find_all("correction")]
            total_correct=total_correct+correct_text
            total_mistakes=total_mistakes+mistakes
    assert len(total_correct)==len(total_mistakes)
    return total_correct,total_mistakes
def rewrite_cged2017(file_name):
    total_correct,total_mistakes=read_langs_cged2017(file_name+".xml")
    with open(file_name+".src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_mistakes]
        fp.close()
    with open(file_name+".trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_correct]
        fp.close()

In [9]:
file_name='./CGED/cged2017/train.release'
rewrite_cged2017(file_name)

100%|█████████████████████████████████████████████████████████████████████████| 10449/10449 [00:00<00:00, 15375.26it/s]


In [10]:
file_name='./CGED/cged2018/train_CGED2018'
rewrite_cged2017(file_name)

100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [00:00<00:00, 50246.15it/s]


In [11]:
file_name='./CGED/cged2020/test_truth_2020mk'
rewrite_cged2017(file_name)

100%|███████████████████████████████████████████████████████████████████████████| 1156/1156 [00:00<00:00, 48165.85it/s]


In [12]:
file_name='./CGED/cged2021/test_2021'
rewrite_cged2017(file_name)

100%|███████████████████████████████████████████████████████████████████████████| 2282/2282 [00:00<00:00, 45181.37it/s]


CTC2021

In [13]:
import json
CTC2021 = []
file_name='./CTC2021/train/train_large_v2'
for line in open('./CTC2021/train/train_large_v2.json','r',encoding='utf8'): 
    CTC2021.append(json.loads(line))

In [14]:
total_mistakes=[]
total_correct=[]
for item in CTC2021:
    total_correct.append(item['target'])
    total_mistakes.append(item['source'])
assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

Lang8

In [15]:
total_mistakes=[]
total_correct=[]
file_name='./Lang8/data.train'
for line in open('./Lang8/data.train','r',encoding='utf8'): 
    data=line[:-1].split('\t')
    if len(data)!=int(data[1])+3:
        continue
    if int(data[1])==0:
        total_mistakes.append(data[2])
        total_correct.append(data[2])
    else:
        for i in range(int(data[1])):
            total_mistakes.append(data[2])
            total_correct.append(data[3+i])

assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

In [16]:
len(total_mistakes)

1220098

MuCGEC

In [17]:
total_mistakes=[]
total_correct=[]
file_name='./MuCGEC/MuCGEC_dev'
for line in open(file_name+".txt",'r',encoding='utf8'): 
    data=line[:-1].split('\t')
    if data[2]=="没有错误":
        total_mistakes.append(data[1])
        total_correct.append(data[1])
    else:
        for i in range(len(data)-2):
            total_mistakes.append(data[1])
            total_correct.append(data[2+i])

assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

MuCGEC_Exp

In [18]:
total_mistakes=[]
total_correct=[]
file_name='./MuCGEC_exp_data/train/train'
for line in open(file_name+".para",'r',encoding='utf8'): 
    data=line[:-1].split('\t')
    total_mistakes.append(data[0])
    total_correct.append(data[1])
    

assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

In [19]:
total_mistakes=[]
total_correct=[]
file_name='./MuCGEC_exp_data/valid/valid'
for line in open(file_name+".para",'r',encoding='utf8'): 
    data=line[:-1].split('\t')
    total_mistakes.append(data[0])
    total_correct.append(data[1])

assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

YACLC

In [20]:
YACLC = []
file_name='./YACLC/valid'
for line in open(file_name+'.jsonl','r',encoding='utf8'): 
    YACLC.append(json.loads(line))
total_mistakes=[]
total_correct=[]
for item in YACLC:
    for correct in item['sentence_annos']:
        total_mistakes.append(item['sentence_text'])
        total_correct.append(correct['correction'])
assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

In [21]:
total_mistakes=[]
total_correct=[]
file_name='./CSC_val/val'
for line in open("./CSC_val/data.txt",'r',encoding='utf8'): 
    data=line[:-1].split('\t')
    total_mistakes.append(data[0])
    total_correct.append(data[1])
assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

In [22]:
total_mistakes=[]
file_name='./CSC_test/test'
for line in open("./CSC_test/NLPCC_TASK8_TESTDATA.txt",'r',encoding='utf8'): 
    total_mistakes.append(line[:-1])
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for item in total_mistakes]
    fp.close()

In [23]:
def read_langs_csc(file_name):
    logging.info(("Reading lines from {}".format(file_name)))
    total_correct=[]
    total_mistakes=[]
    with codecs.open(file_name, "r", "utf-8") as file:
 
        data = file.read()
        soup = BeautifulSoup(data, 'html.parser')
        results = soup.find_all('essay')
        for item in tqdm(results):
            correct_sentences = []
            wrong_sentences = []
            for passage in item.find_all('passage'):
                text = passage.text
                mistake = item.find('mistake', {'id': passage['id']})
                if mistake:
                    wrong_text = mistake.wrong.text
                    correct_text = mistake.correction.text
                    text = text.replace(wrong_text, correct_text)
                    wrong_sentences.append(passage.text)
                    correct_sentences.append(text)
            total_correct=total_correct+correct_sentences
            total_mistakes=total_mistakes+wrong_sentences
    assert len(total_correct)==len(total_mistakes)
    return total_correct,total_mistakes
def rewrite_csc(file_name):
    total_correct,total_mistakes=read_langs_csc(file_name+".sgml")
    with open(file_name+".src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_mistakes]
        fp.close()
    with open(file_name+".trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_correct]
        fp.close()

In [24]:
file_name='./clp14csc/Training/C1_training'
rewrite_csc(file_name)
# file_name='./clp14csc/Training/B1_training'
# rewrite_csc(file_name)
# 存在编码问题

100%|██████████████████████████████████████████████████████████████████████████████| 114/114 [00:00<00:00, 3166.62it/s]


In [25]:
file_name='./sighan8csc/Training/SIGHAN15_CSC_A2_Training'
rewrite_csc(file_name)
# file_name='./sighan8csc/Training/SIGHAN15_CSC_B2_Training'
# rewrite_csc(file_name)
# 存在编码问题

100%|██████████████████████████████████████████████████████████████████████████████| 442/442 [00:00<00:00, 5848.62it/s]


In [26]:
def read_langs_hybrid(file_name):
    total_mistakes=[]
    total_correct=[]
    for line in open(file_name+"_1.src",'r',encoding='utf8'): 
        total_mistakes.append(line[:-1].split('\t')[1])
    for line in open(file_name+"_1.trg",'r',encoding='utf8'): 
        total_correct.append(line[:-1].split('\t')[1])
    assert len(total_correct)==len(total_mistakes)
    return total_correct,total_mistakes
def rewrite_hybrid(file_name):
    total_correct,total_mistakes=read_langs_hybrid(file_name)
    with open(file_name+".src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_mistakes]
        fp.close()
    with open(file_name+".trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in total_correct]
        fp.close()

In [27]:
file_name='./hybrid/hybird_train'
rewrite_hybrid(file_name)
file_name='./hybrid/hybird_test'
rewrite_hybrid(file_name)

In [28]:
total_mistakes=[]
total_correct=[]
file_name='./sighan15/sighan.train.ccl22'
for line in open(file_name+".para",'r',encoding='utf8'): 
    data=line[:-1].split(' ||| ')
    total_mistakes.append(data[0])
    total_correct.append(data[1])
    

assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

In [29]:
total_mistakes=[]
total_correct=[]
file_name='./WANG27/wang.train.ccl22'
for line in open(file_name+".para",'r',encoding='utf8'): 
    data=line[:-1].split(' ||| ')
    total_mistakes.append(data[0])
    total_correct.append(data[1])
    

assert len(total_correct)==len(total_mistakes)
with open(file_name+".src",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_mistakes]
    fp.close()
with open(file_name+".trg",'w', encoding="utf-8") as fp:
    [fp.write(str(item)+'\n') for  item in total_correct]
    fp.close()

In [30]:
file_names=['./CGED/cged2014/nlptea14cfl_release1.1/Training/A2_CFL_training','./CGED/cged2014/nlptea14cfl_release1.1/Training/B2_CFL_training','./CGED/cged2014/nlptea14cfl_release1.1/Training/C1_CFL_training','./CGED/cged2015/nlptea15cged_release1.0/Training/NLPTEA15_CGED_Training','./CGED/cged2016/nlptea16cged_release1.0/Training/CGED16_HSK_TrainingSet','./CGED/cged2017/train.release','./CGED/cged2018/train_CGED2018','./CGED/cged2020/test_truth_2020mk','./CGED/cged2021/test_2021','./CTC2021/train/train_large_v2','./HSK/hsk','./Lang8/data.train','./MuCGEC/MuCGEC_dev','./MuCGEC_exp_data/train/train','./MuCGEC_exp_data/valid/valid','./YACLC/valid']
file_names.append('./clp14csc/Training/C1_training')
file_names.append('./sighan8csc/Training/SIGHAN15_CSC_A2_Training')
file_names.append('./hybrid/hybird_train')
file_names.append('./hybrid/hybird_test')
file_names.append('./sighan15/sighan.train.ccl22')
file_names.append('./WANG27/wang.train.ccl22')

In [31]:
def convert_1(file_name):
    srcs=[]
    trgs=[]
    for line in open(file_name+".src",'r',encoding='utf8'): 
        srcs.append(line[:-1])
    for line in open(file_name+".trg",'r',encoding='utf8'): 
        trgs.append(line[:-1])
    res={}
    for src,trg in zip(srcs,trgs):
        if src not in res:
            res[src]=[trg]
        else:
            res[src].append(trg)
    with open(file_name+"_1.txt", "w", encoding="utf-8") as file:
        i = 0
        for key, items in res.items():
            line = str(i) + "\t" + key + "\t" + "\t".join(items) + "\n"
            file.write(line)
            i=i+1

for file_name in file_names:
    convert_1(file_name)

In [32]:
def convert_2(file_name):
    srcs=[]
    trgs=[]
    task=file_name.split('/')[1]
    name=file_name.split('/')[-1]
    for line in open(file_name+".src",'r',encoding='utf8'): 
        srcs.append(line[:-1])
    for line in open(file_name+".trg",'r',encoding='utf8'): 
        trgs.append(line[:-1])
    
    with open(file_name+"_1.src", "w", encoding="utf-8") as file:
        i = 0
        for item in srcs:
            line = task+"_"+name+"_"+str(i) + "\t" + item + "\n"
            file.write(line)
            i=i+1
    with open(file_name+"_1.trg", "w", encoding="utf-8") as file:
        i = 0
        for item in trgs:
            line = task+"_"+name+"_"+str(i) + "\t" + item + "\n"
            file.write(line)
            i=i+1

for file_name in file_names:
    convert_2(file_name)

In [33]:
file_name='./CSC_val/val'
convert_1(file_name)
convert_2(file_name)
file_name='./CSC_test/test'
srcs=[]
for line in open(file_name+".src",'r',encoding='utf8'): 
    srcs.append(line[:-1])
with open(file_name+"_1.src", "w", encoding="utf-8") as file:
    i = 0
    for item in srcs:
        line = "CSC_test_"+str(i) + "\t" + item + "\n"
        file.write(line)
        i=i+1
with open(file_name+".txt", "w", encoding="utf-8") as file:
    for item in srcs:
        line = item + "\n"
        file.write(line)

In [34]:
def convert_all_1(file_names):
    srcs=[]
    trgs=[]
    for file_name in file_names:
        for line in open(file_name+".src",'r',encoding='utf8'): 
            srcs.append(line[:-1])
        for line in open(file_name+".trg",'r',encoding='utf8'): 
            trgs.append(line[:-1])
    res={}
    for src,trg in zip(srcs,trgs):
        if src not in res:
            res[src]=[trg]
        else:
            res[src].append(trg)
    with open("./CSC_train/train_1.txt", "w", encoding="utf-8") as file:
        i = 0
        for key, items in res.items():
            line = str(i) + "\t" + key + "\t" + "\t".join(items) + "\n"
            file.write(line)
            i=i+1

convert_all_1(file_names)

In [35]:
def convert_all_2(file_names):
    srcs=[]
    trgs=[]
    for file_name in file_names:
        for line in open(file_name+"_1.src",'r',encoding='utf8'): 
            srcs.append(line[:-1])
        for line in open(file_name+"_1.trg",'r',encoding='utf8'): 
            trgs.append(line[:-1])
        if len(srcs)!=len(trgs):
            print(file_name)
    with open("./CSC_train/train_1.src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in srcs]
        fp.close()
    with open("./CSC_train/train_1.trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in trgs]
        fp.close()

convert_all_2(file_names)

In [36]:
def convert_all_3(file_names):
    srcs=[]
    trgs=[]
    for file_name in file_names:
        for line in open(file_name+".src",'r',encoding='utf8'): 
            srcs.append(line[:-1])
        for line in open(file_name+".trg",'r',encoding='utf8'): 
            trgs.append(line[:-1])
    with open("./CSC_train/train.src",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in srcs]
        fp.close()
    with open("./CSC_train/train.trg",'w', encoding="utf-8") as fp:
        [fp.write(str(item)+'\n') for  item in trgs]
        fp.close()

convert_all_3(file_names)

In [37]:
from zhconv import convert

def E_trans_to_C(string):
    E_pun = u',!?[]()<>"\';:'
    C_pun = u'，！？【】（）《》“‘；：'
    table= {ord(f):ord(t) for f,t in zip(E_pun,C_pun)}
    return string.translate(table)

def normalize_text(text):
    # Convert text from traditional to simplified Chinese
    text = convert(text, 'zh-cn')
    
    # Normalize punctuation
    text = unicodedata.normalize('NFKC', text)
    text = E_trans_to_C(text)
    # Remove special characters
    text = re.sub(r"[^a-zA-Z0-9\u4e00-\u9fa5，。？！；：“”‘’《》【】（）〔〕…—-]", '', text)

    return text

In [38]:
import re
import unicodedata
pairs6=[]
total_mistakes=[]
total_correct=[]
for line in open("CSC_train/train_1.src",'r',encoding='utf8'): 
    total_mistakes.append(normalize_text(line[:-1].split("\t")[1]))
for line in open("CSC_train/train_1.trg",'r',encoding='utf8'): 
    total_correct.append(normalize_text(line[:-1].split("\t")[1]))
pairs=[]
for srcs, trgs in zip(total_mistakes,total_correct):
    srcs2=re.sub(r'[^\u4e00-\u9fa5]', '', srcs)
    trgs2=re.sub(r'[^\u4e00-\u9fa5]', '', trgs)
    if len(srcs)==len(trgs) and len(trgs)!=0 and len(srcs2)==len(trgs2) and len(trgs2)!=0:
        pairs.append((srcs, trgs))
    else:
        pairs6.append((srcs, trgs))

In [39]:
from collections import Counter

def is_same_elements(list1, list2):
    return Counter(list1) == Counter(list2)


In [40]:
total_mistakes=[]
total_correct=[]
for item in pairs:
    total_mistakes.append(item[0])
    total_correct.append(item[1])
pairs0=[]
pairs1=[]
pairs2=[]
pairs3=[]
pairs4=[]
pairs5=[]
#pairs6=[]
for srcs, trgs in zip(total_mistakes,total_correct):
    i=0
    for index, (src, trg) in enumerate(zip(srcs, trgs)):
        if src != trg:
            i=i+1
    if is_same_elements(srcs, trgs) and srcs!=trgs:
        pairs6.append((srcs, trgs))
        continue
    if i>int(len(srcs)/5):
        pairs6.append((srcs, trgs))
        continue
    if i==0:
        pairs0.append((srcs, trgs))
    if i==1:
        pairs1.append((srcs, trgs))
    if i==2:
        pairs2.append((srcs, trgs))
    if i==3:
        pairs3.append((srcs, trgs))
    if i==4:
        pairs4.append((srcs, trgs))
    if i==5:
        pairs5.append((srcs, trgs))
    if i>5:
        pairs6.append((srcs, trgs))
print(len(pairs0),len(pairs1),len(pairs2),len(pairs3),len(pairs4),len(pairs5),len(pairs6))

223757 481665 204466 61297 20812 8238 2491930


In [41]:
pairs=pairs0+pairs1+pairs2+pairs3+pairs4+pairs5
total_mistakes=[]
total_correct=[]
for item in pairs:
    total_mistakes.append(item[0])
    total_correct.append(item[1])

In [42]:
len(total_mistakes)

1000235

In [43]:
with open("CSC_train/train_no_Aug.src", "w", encoding="utf-8") as file:
    i = 0
    for item in total_mistakes:
        line = "train_CSC_"+str(i) + "\t" + item + "\n"
        file.write(line)
        i=i+1
with open("CSC_train/train_no_Aug.trg", "w", encoding="utf-8") as file:
    i = 0
    for item in total_correct:
        line = "train_CSC_"+str(i) + "\t" + item + "\n"
        file.write(line)
        i=i+1

In [44]:
tgt=set()
for item in pairs6:
    tgt.add(item[1])
tgt.remove("")
with open("CSC_train/train_Aug.txt", "w", encoding="utf-8") as file:
    i = 0
    for item in tgt:
        line = "train_CSC_Aug_"+str(i) + "\t" + item + "\n"
        file.write(line)
        i=i+1

In [45]:
len(tgt)

1192788

In [46]:
total_mistakes=[]
total_correct=[]
for line in open("./CSC_train/augmentation.csv",'r',encoding='utf8'): 
    lines=line[:-1].split("\t")
    if len(lines)!=3:
        continue
    total_mistakes.append(lines[2])
    total_correct.append(lines[1])

In [47]:
len(total_mistakes)

490685

In [48]:
total_mistakes=total_mistakes[1:]
total_correct=total_correct[1:]
for line in open("./CSC_train/train_no_Aug.src",'r',encoding='utf8'): 
    total_mistakes.append(line[:-1].split("\t")[1])
for line in open("./CSC_train/train_no_Aug.trg",'r',encoding='utf8'): 
    total_correct.append(line[:-1].split("\t")[1])

In [49]:
len(total_mistakes)

1490919

In [50]:
pairs=list(zip(total_mistakes,total_correct))
import random
random.shuffle(pairs)
srcs1,trgs1=zip(*pairs)
from collections import Counter

def is_same_elements(list1, list2):
    return Counter(list1) == Counter(list2)

pairs0=[]
pairs1=[]
pairs2=[]
pairs3=[]
pairs4=[]
pairs5=[]
pairs6=[]
for srcs, trgs in zip(srcs1,trgs1):
    i=0
    for index, (src, trg) in enumerate(zip(srcs, trgs)):
        if src != trg:
            i=i+1
    if is_same_elements(srcs, trgs) and srcs!=trgs:
        pairs6.append((srcs, trgs))
        continue
    if i>int(len(srcs)/5):
        pairs6.append((srcs, trgs))
        continue
    if i==0:
        pairs0.append((srcs, trgs))
    if i==1:
        pairs1.append((srcs, trgs))
    if i==2:
        pairs2.append((srcs, trgs))
    if i==3:
        pairs3.append((srcs, trgs))
    if i==4:
        pairs4.append((srcs, trgs))
    if i==5:
        pairs5.append((srcs, trgs))
    if i>5:
        pairs6.append((srcs, trgs))
print(len(pairs0),len(pairs1),len(pairs2),len(pairs3),len(pairs4),len(pairs5),len(pairs6))

224018 721007 389507 101539 36082 13735 5031


In [51]:
with open("./CSC_train/train_Aug.src", "w", encoding="utf-8") as file:
    i = 0
    for item in srcs1:
        line = "train_Aug_"+str(i) + "\t" + item + "\n"
        file.write(line)
        i=i+1
with open("./CSC_train/train_Aug.trg", "w", encoding="utf-8") as file:
    i = 0
    for item in trgs1:
        line = "train_Aug_"+str(i) + "\t" + item + "\n"
        file.write(line)
        i=i+1