In [None]:
!pip install synonyms
!pip install jieba
!pip install jionlp
!pip install wget

In [1]:
import json
import os
import random
import jieba
from jionlp import homophone_substitution, swap_char_position
import wget

# jionlp - 微信公众号: JioNLP  Github: `https://github.com/dongrixinyu/JioNLP`.
# jiojio - `http://www.jionlp.com/jionlp_online/cws_pos` is available for online use.


In [2]:
# The notebook name is resolved against the current working directory, so
# dir_path ends up being the directory the kernel was started in.
dir_path = os.path.dirname(os.path.abspath('CAIL2019-SCM数据预处理.ipynb'))
data_path = dir_path+"/CAIL2019-SCM"
if not os.path.exists(data_path):
    # First run: create the target folder and fetch the archive into it.
    # NOTE(review): only the .zip is downloaded here; nothing in this cell
    # unpacks it — confirm the archive is extracted before the later cells
    # read CAIL2019-SCM/*.json.
    os.makedirs(data_path)
    url = "https://cail.oss-cn-qingdao.aliyuncs.com/cail2019/CAIL2019-SCM.zip"
    wget.download(url, out=data_path)
    print("Download!")
else:
    # The folder is already present, so the download is skipped entirely.
    print("Data already exists!")

Data already exists!


In [5]:
def data_augment(sentence, seed=None, alpha=0.1, nums=5):
    """Return a same-length, character-level perturbed variant of ``sentence``.

    Candidates are produced by jionlp's ``homophone_substitution`` and
    ``swap_char_position``, shuffled, and the first candidate that has the
    same length as ``sentence`` but differs from it is returned.

    Args:
        sentence: The sentence to perturb.
        seed: Integer gate in the range 1-10. When it is greater than 7 the
            sentence is returned unchanged. When ``None`` (the default) a
            fresh value is drawn with ``random.randint(1, 10)`` on every call.
        alpha: Intended ratio of characters to change per sentence. It is
            accepted for interface compatibility but is currently unused.
        nums: Number of candidates requested from each augmentation method.

    Returns:
        The perturbed sentence, or ``sentence`` itself when augmentation is
        skipped, the input is empty, or no suitable candidate was produced.
    """
    # Draw the gate per call. A `random.randint(...)` default expression is
    # evaluated only once, at function definition time, which would make
    # every default call share the same value.
    if seed is None:
        seed = random.randint(1, 10)

    # Nothing to perturb for an empty sentence; also skip when gated out.
    if seed > 7 or not sentence:
        return sentence

    # Collect candidates from both augmentation methods.
    candidates = []
    candidates.extend(homophone_substitution(sentence, augmentation_num=nums))
    candidates.extend(swap_char_position(sentence, augmentation_num=nums))

    # Shuffle so neither method is systematically preferred.
    random.shuffle(candidates)

    # Callers compare the result with the input position by position, so only
    # a candidate of identical length that actually differs is acceptable.
    for candidate in candidates:
        if len(candidate) == len(sentence) and candidate != sentence:
            return candidate

    # No usable candidate: fall back to the untouched sentence.
    return sentence

In [4]:
# Manual spot check of data_augment on a sample sentence; the result is
# random and may be identical to the input.
print(data_augment("经新沂市劳动争议仲裁委员会和贵院调动调解"))

经新沂市动劳争议仲裁委员会和贵院调动调解


In [6]:
# Accumulator for the record dicts appended by the following processing cell.
raw_data = []

In [7]:
# Turn every case text of the test split into (perturbed, correct) sentence
# records and append them to raw_data.
with open("CAIL2019-SCM/test.json","r",encoding='utf-8') as f:
    num = 0  # running record index, shared by the A/B/C fields of the whole file
    for line in f:  # stream the file instead of materialising readlines()
        if not line.strip():
            continue  # tolerate blank lines; json.loads would raise on them
        content = json.loads(line)  # one JSON object per line

        # Fields A, B and C are processed identically; only the id prefix differs.
        for field in ("A", "B", "C"):
            if field not in content:
                continue

            # Keep only the second blank-line-separated paragraph of the field.
            # NOTE(review): assumes every field contains at least one "\n\n";
            # an IndexError is raised otherwise.
            text = content[field].split("\n\n")[1].strip()

            # Walk backwards and peel off the tail after each "。"/"；", so
            # sentence_lst holds the sentences in last-to-first order.
            sentence_lst = []
            for k in range(len(text) - 1, 0, -1):
                if text[k] in ("。", "；"):
                    sentence_lst.append(text[k + 1:])
                    text = text[:k + 1]
            # What remains is the leading sentence; cap it at 512 characters.
            sentence_lst.append(text[:512])

            for correct in sentence_lst:
                if correct == '':
                    continue  # empty piece, e.g. the text ended with a delimiter
                perturbed = data_augment(correct)
                # Positions where the perturbed text differs from the source
                # (data_augment only returns text of the same length).
                wrong_ids = [
                    pos
                    for pos, (new_ch, old_ch) in enumerate(zip(perturbed, correct))
                    if new_ch != old_ch
                ]
                idx = "CAIL2019-SCM-" + field + str(num)
                raw_data.append({'id':idx, 'original_text':perturbed, 'wrong_ids':wrong_ids, 'correct_text':correct})
                num += 1

In [8]:
def save_json(data, json_path, mode='w', encoding='utf-8'):
    """Write ``data`` to ``json_path`` as indented, non-ASCII-preserving JSON.

    Args:
        data: JSON-serialisable object to store.
        json_path: Destination file. Its parent directory is created (and its
            path printed) when it does not exist yet.
        mode: File mode handed to ``open``.
        encoding: Text encoding handed to ``open``.
    """
    target_dir = os.path.dirname(os.path.abspath(json_path))
    parent_missing = not os.path.exists(target_dir)
    if parent_missing:
        # Report which directory is being created before creating it.
        print(target_dir)
        os.makedirs(target_dir)
    with open(json_path, mode=mode, encoding=encoding) as out_file:
        # ensure_ascii=False keeps non-ASCII characters literal in the output.
        serialised = json.dumps(data, ensure_ascii=False, indent=4)
        out_file.write(serialised)

In [9]:
save_json(raw_data,"CAIL2019-SCM/output/test.json") # persist the accumulated records as the test split

In [10]:
# Start a fresh accumulator so the next split does not inherit earlier records.
raw_data = []

In [11]:
# Turn every case text of the train split into (perturbed, correct) sentence
# records and append them to raw_data.
with open("CAIL2019-SCM/train.json","r",encoding='utf-8') as f:
    # Initialise the counter here so this cell does not depend on a `num`
    # left in the kernel by an earlier cell; ids of this split start at 0.
    num = 0
    for line in f:  # stream the file instead of materialising readlines()
        if not line.strip():
            continue  # tolerate blank lines; json.loads would raise on them
        content = json.loads(line)  # one JSON object per line

        # Fields A, B and C are processed identically; only the id prefix differs.
        for field in ("A", "B", "C"):
            if field not in content:
                continue

            # Keep only the second blank-line-separated paragraph of the field.
            # NOTE(review): assumes every field contains at least one "\n\n";
            # an IndexError is raised otherwise.
            text = content[field].split("\n\n")[1].strip()

            # Walk backwards and peel off the tail after each "。"/"；", so
            # sentence_lst holds the sentences in last-to-first order.
            sentence_lst = []
            for k in range(len(text) - 1, 0, -1):
                if text[k] in ("。", "；"):
                    sentence_lst.append(text[k + 1:])
                    text = text[:k + 1]
            # What remains is the leading sentence; cap it at 512 characters.
            sentence_lst.append(text[:512])

            for correct in sentence_lst:
                if correct == '':
                    continue  # empty piece, e.g. the text ended with a delimiter
                perturbed = data_augment(correct)
                # Positions where the perturbed text differs from the source
                # (data_augment only returns text of the same length).
                wrong_ids = [
                    pos
                    for pos, (new_ch, old_ch) in enumerate(zip(perturbed, correct))
                    if new_ch != old_ch
                ]
                idx = "CAIL2019-SCM-" + field + str(num)
                raw_data.append({'id':idx, 'original_text':perturbed, 'wrong_ids':wrong_ids, 'correct_text':correct})
                num += 1

In [12]:
save_json(raw_data,"CAIL2019-SCM/output/train.json") # persist the accumulated records as the train split

In [13]:
# Turn every case text of the valid split into (perturbed, correct) sentence
# records and collect them in raw_data.

# Reset the accumulator: without this, records appended by the preceding
# processing cell would still be in raw_data and would be written into the
# dev output together with the valid-split records.
raw_data = []

with open("CAIL2019-SCM/valid.json","r",encoding='utf-8') as f:
    # Initialise the counter here so this cell does not depend on a `num`
    # left in the kernel by an earlier cell; ids of this split start at 0.
    num = 0
    for line in f:  # stream the file instead of materialising readlines()
        if not line.strip():
            continue  # tolerate blank lines; json.loads would raise on them
        content = json.loads(line)  # one JSON object per line

        # Fields A, B and C are processed identically; only the id prefix differs.
        for field in ("A", "B", "C"):
            if field not in content:
                continue

            # Keep only the second blank-line-separated paragraph of the field.
            # NOTE(review): assumes every field contains at least one "\n\n";
            # an IndexError is raised otherwise.
            text = content[field].split("\n\n")[1].strip()

            # Walk backwards and peel off the tail after each "。"/"；", so
            # sentence_lst holds the sentences in last-to-first order.
            sentence_lst = []
            for k in range(len(text) - 1, 0, -1):
                if text[k] in ("。", "；"):
                    sentence_lst.append(text[k + 1:])
                    text = text[:k + 1]
            # What remains is the leading sentence; cap it at 512 characters.
            sentence_lst.append(text[:512])

            for correct in sentence_lst:
                if correct == '':
                    continue  # empty piece, e.g. the text ended with a delimiter
                perturbed = data_augment(correct)
                # Positions where the perturbed text differs from the source
                # (data_augment only returns text of the same length).
                wrong_ids = [
                    pos
                    for pos, (new_ch, old_ch) in enumerate(zip(perturbed, correct))
                    if new_ch != old_ch
                ]
                idx = "CAIL2019-SCM-" + field + str(num)
                raw_data.append({'id':idx, 'original_text':perturbed, 'wrong_ids':wrong_ids, 'correct_text':correct})
                num += 1

In [14]:
save_json(raw_data,"CAIL2019-SCM/output/dev.json") # persist the accumulated records as the dev split