# Dataset Creator -- English and Chinese  
## Import libraries

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline
import os
import json
import random
import pickle
import torch
import numpy as np
import re
from tqdm.notebook import tqdm
from sklearn.utils import shuffle
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")

## Open IWSLT Files

In [4]:
model_type = 'xlm-roberta-base' #albert-base-v1, bert-base-cased, bert-base-uncased
data_path_zh = "../dataset/zh-en/"

with open(data_path_zh + 'train_texts_zh.txt', 'r', encoding='utf-8') as f:
    train_text_zh = f.readlines()
with open(data_path_zh + 'dev_texts_zh.txt', 'r', encoding='utf-8') as f:
    valid_text_zh = f.readlines()
with open(data_path_zh + 'test_texts_zh.txt', 'r', encoding='utf-8') as f:
    test_text_zh = f.readlines()

data_path_en = "../dataset/en-fr/"

with open(data_path_en + 'train_texts.txt', 'r', encoding='utf-8') as f:
    train_text_en = f.readlines()
with open(data_path_en + 'dev_texts.txt', 'r', encoding='utf-8') as f:
    valid_text_en = f.readlines()
with open(data_path_en + 'test_texts_2012.txt', 'r', encoding='utf-8') as f:
    test_text_en = f.readlines()

'''
train_text = random.shuffle(train_text_en+train_text_zh)
valid_text = random.shuffle(valid_text_en+valid_text_zh)
test_text  = random.shuffle(test_text_en+test_text_zh)
'''

'\ntrain_text = random.shuffle(train_text_en+train_text_zh)\nvalid_text = random.shuffle(valid_text_en+valid_text_zh)\ntest_text  = random.shuffle(test_text_en+test_text_zh)\n'

In [5]:
datasets = train_text_en, valid_text_en, test_text_en
datasets_zh = train_text_zh, valid_text_zh, test_text_zh

In [6]:
def clean_text_zh(text):
    text = text.replace('！', '。')
    text = text.replace('：', '，')
    text = text.replace('——', '，')
    
    #reg = "(?<=[a-zA-Z])-(?=[a-zA-Z]{2,})"
    #r = re.compile(reg, re.DOTALL)
    #text = r.sub(' ', text)
    
    text = re.sub(r'\s—\s', ' ， ', text)
    
#     text = text.replace('-', ',')
    text = text.replace(';', '。')    # replace symbols with the most relevant counterparts
    text = text.replace('、', '，')
    text = text.replace('♫', '')
    text = text.replace('……', '')
    text = text.replace('。”', '')
    text = text.replace('”', '，')
    text = text.replace('“','，')
    text = text.replace(',','，')
    

    text = re.sub(r'——\s?——', '', text) # replace --   -- to ''
    text = re.sub(r'\s+', ' ', text)    # strip all whitespaces
    
    text = re.sub(r'，\s?，', '，', text)  # merge commas separating only whitespace
    text = re.sub(r'，\s?。', '。', text) # , . -> ,
    text = re.sub(r'？\s?。', '？', text)# ? . -> ?
    text = re.sub(r'\s+', ' ', text)    # strip all redundant whitespace that could have been caused by preprocessing
    
    text = re.sub(r'\s+？', '？', text)
    text = re.sub(r'\s+，', '，', text)
    text = re.sub(r'。[\s+。]+', '。 ', text)
    text = re.sub(r'\s+。', '。 ', text)
    
    return text.strip().lower()

def clean_text_en(text):
    text = text.replace('!', '.')
    text = text.replace(':', ',')
    text = text.replace('--', ',')
    
    reg = "(?<=[a-zA-Z])-(?=[a-zA-Z]{2,})"
    r = re.compile(reg, re.DOTALL)
    text = r.sub(' ', text)
    
    text = re.sub(r'\s-\s', ' , ', text)
    
#     text = text.replace('-', ',')
    text = text.replace(';', '.')    # replace symbols with the most relevant counterparts
    text = text.replace(' ,', ',')
    text = text.replace('♫', '')
    text = text.replace('...', '')
    text = text.replace('.\"', ',')
    text = text.replace('"', ',')

    text = re.sub(r'--\s?--', '', text) # replace --   -- to ''
    text = re.sub(r'\s+', ' ', text)    # strip all whitespaces
    
    text = re.sub(r',\s?,', ',', text)  # merge commas separating only whitespace
    text = re.sub(r',\s?\.', '.', text) # , . -> ,
    text = re.sub(r'(?<=[a-zA-Z0-9]),(?=[a-zA-Z0-9])',', ',text) # say,you -> say, you
    text = re.sub(r'\?\s?\.', '?', text)# ? . -> ?
    text = re.sub(r'\s+', ' ', text)    # strip all redundant whitespace that could have been caused by preprocessing
    
    text = re.sub(r'\s+\?', '?', text)
    text = re.sub(r'\s+,', ',', text)
    text = re.sub(r'\.[\s+\.]+', '. ', text)
    text = re.sub(r'\s+\.', '.', text)
    
    return text.strip().lower()

In [7]:
datasets_en = train_text_en, valid_text_en, test_text_en
datasets_zh = train_text_zh, valid_text_zh, test_text_zh

datasets_zh = [[clean_text_zh(text) for text in ds] for ds in datasets_zh]
datasets_en = [[clean_text_en(text) for text in ds] for ds in datasets_en]

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_type)

In [15]:

target_token2id_en = {t: tokenizer.encode(t)[-2] for t in ".?,"}
target_token2id_zh = {t: tokenizer.encode(t)[-2] for t in "。？，"}
target_ids_en = list(target_token2id_en.values())
target_ids_zh = list(target_token2id_zh.values())
target_ids_en, target_ids_zh

([5, 705, 4], [30, 705, 4])

In [16]:
id2target_zh = {
    0: 0,
    -1: -1,
}
for i, ti in enumerate(target_ids_zh):
    id2target_zh[ti] = i+1
target2id_zh = {value: key for key, value in id2target_zh.items()}
print(id2target_zh, target2id_zh)


{0: 0, -1: -1, 30: 1, 705: 2, 4: 3} {0: 0, -1: -1, 1: 30, 2: 705, 3: 4}


In [17]:
import jieba
id2target_en = {
    0: 0,
    -1: -1,
}
for i, ti in enumerate(target_ids_en):
    id2target_en[ti] = i+1
target2id_en = {value: key for key, value in id2target_en.items()}

def create_target_zh(text):
    encoded_words, targets = [], []
    
    words = list(jieba.cut(text,HMM=True)) ## ignore the first space
    words2 = []
    for i in range(len(words)):
        encoded_word = tokenizer.encode(words[i])
        #print(words[i],encoded_word)
        if (len(encoded_word[1:-1]) > 1 and encoded_word[1] != 6) or (len(encoded_word[1:-1]) > 2 and encoded_word[1] == 6):
            for word in encoded_word[1:-1]:
                if word != 6:
                    encoded_words.append(word)
                    targets.append(-1)
            targets = targets[:-1]   
        elif len(encoded_word[1:-1]) == 0:
            continue
        else:
            #print("Here! ",encoded_word)
            s = 2 if encoded_word[1] == 6 else 1
            encoded_words.append(encoded_word[s])
                
            
            
        if words[i] not in ["。","？","，"," ","▁"]:
            if i < len(words) -1 and words[i+1] in ["。","？","，"," ","▁"]:
                ##words2.append(words[i])
                targets.append(0)
                pass
            else:
                targets.append(0)
                encoded_words.append(6)
                targets.append(0)
        else:
            if words[i] in ["▁"," "]:
                if i > 0 and words[i-1] not in ["。","？","，"," ","▁"]:
                    #encoded_words.append(" ")
                    #targets.append(0)
                    pass
            else:
                #print("YES",words[i])
                targets.append(id2target_zh[target_token2id_zh[words[i]]])
                # words2.append(words[i])
    
    encoded_words = [tokenizer.cls_token_id or tokenizer.bos_token_id] +\
                    encoded_words +\
                    [tokenizer.sep_token_id or tokenizer.eos_token_id]
    
    targets = [-1]+ targets + [-1]    
    
    return encoded_words, targets

def create_target_en(text):
    encoded_words, targets = [], []
    
    words = text.split(' ')

    for word in words:
        target = 0
        for target_token, target_id in target_token2id_en.items():
            if word.endswith(target_token):
                word = word.rstrip(target_token)
                target = id2target_en[target_id]

        encoded_word = tokenizer.encode(word, add_special_tokens=False)
        
        for w in encoded_word:
            encoded_words.append(w)
        for _ in range(len(encoded_word)-1):
            targets.append(-1)
        targets.append(0)
        
        if target != 0:
            encoded_words.append(target2id_en[target])
        else:
            encoded_words.append(6)
        targets.append(target)
        
        
#         print([tokenizer._convert_id_to_token(ew) for ew in encoded_word], target)
        assert(len(encoded_word)>0)

    encoded_words = [tokenizer.cls_token_id or tokenizer.bos_token_id] +\
                    encoded_words +\
                    [tokenizer.sep_token_id or tokenizer.eos_token_id]
    targets = [-1] + targets + [-1]
    
    return encoded_words, targets

In [18]:
# sentence endings split
encoded_texts_zh, targets_zh = [], []

for ds in datasets_zh:
    trgts = []
    '''
    for ts in ds:
        prev = 0
        init = 0
        #print("Length of sequence: ",len(ts))
        for i in range(len(ts)):
            if  ts[i] in ["。",".","？","?"]:
                if i > init+511:
                    if prev == 0:
                        #print("truncating first sentence")
                        trgts.append(ts[0:512])
                        prev = 511
                        init = 511
                    else:
                        if prev == init: 
                            prev = i
                        #print("appending from ",init," to ",prev)
                        if prev - init > 511:
                            #print("CHUNKing sentence")
                            ls = ts[init+1:prev+1]
                            trgts.extend([ls[i:i+512] for i in range(0,prev-init,512)])
                        else:
                            trgts.append(ts[init+1:prev+1])
                        init = prev
                        prev = init
                else:
                    prev = i
        
        if prev < len(ts)-1:
            #print("appending last sentence from ",prev," to ",len(ts)-1)
            #if(len(ts)-1 - prev > 511):
            #    #print("chunking last sentence")
            trgts.extend([ts[i:i+512] for i in range(prev,len(ts)-1,512)])
            #trgts.append(ts[prev:len(ts)])
    '''
    x = list(zip(*(create_target_zh(trgt) for trgt in tqdm(ds)))) # use "trgts" instead of "ds" if you want 512, the warning can be ignored
    encoded_texts_zh.append(x[0])
    targets_zh.append(x[1])
    
encoded_texts_en, targets_en = [], []

for ds in datasets_en:
    x = list(zip(*(create_target_en(ts) for ts in tqdm(ds))))
    encoded_texts_en.append(x[0])
    targets_en.append(x[1])

  0%|          | 0/1017 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/1029 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

In [19]:
print(id2target_zh)
s = "谁能猜一猜：你大脑里神经元的总长有多少？ ”西班牙厨师被控告……“ 非常坚硬的土地。西班牙厨师被控告"
print(s)
s = clean_text_zh(s)
print(s)
data, tgts = create_target_zh(s)
#print(data)
#print(targets)
print([(tokenizer._convert_id_to_token(d), ta) for d,ta in zip(data[2:-1], tgts[2:-1])])

print(id2target_en)
# s = "Tyranosaurus: kill me? Not enough, rumplestilskin -- said the co-pilot -- ..."
s = "it  can  be  a  very  complicated  thing, the  ocean. and  it  can  be  a  very  complicated  thing, what  human  health  is."
print(s)
s = clean_text_en(s)
print(s)
data, tgts = create_target_en(s)
print(data)
print(tgts)
print([(tokenizer._convert_id_to_token(d), ta) for d,ta in zip(data[1:-1], tgts[1:-1])])

{0: 0, -1: -1, 30: 1, 705: 2, 4: 3}
谁能猜一猜：你大脑里神经元的总长有多少？ ”西班牙厨师被控告……“ 非常坚硬的土地。西班牙厨师被控告
谁能猜一猜，你大脑里神经元的总长有多少？，西班牙厨师被控告， 非常坚硬的土地。西班牙厨师被控告
[('▁', 0), ('能', 0), ('▁', 0), ('猜', 0), ('▁', 0), ('▁一', 0), ('▁', 0), ('猜', 0), (',', 3), ('▁你', 0), ('▁', 0), ('大脑', 0), ('▁', 0), ('里', 0), ('▁', 0), ('神经', -1), ('元', 0), ('▁', 0), ('的', 0), ('▁', 0), ('总', -1), ('长', 0), ('▁', 0), ('有', 0), ('▁', 0), ('多少', 0), ('▁?', 2), (',', 3), ('西班牙', 0), ('▁', 0), ('厨', -1), ('师', 0), ('▁', 0), ('被', 0), ('▁', 0), ('控', -1), ('告', 0), (',', 3), ('非常', 0), ('▁', 0), ('坚', -1), ('硬', 0), ('▁', 0), ('的', 0), ('▁', 0), ('土地', 0), ('。', 1), ('西班牙', 0), ('▁', 0), ('厨', -1), ('师', 0), ('▁', 0), ('被', 0), ('▁', 0), ('控', -1), ('告', 0), ('▁', 0)]
{0: 0, -1: -1, 5: 1, 705: 2, 4: 3}
it  can  be  a  very  complicated  thing, the  ocean. and  it  can  be  a  very  complicated  thing, what  human  health  is.
it can be a very complicated thing, the ocean. and it can be a very complicated thing, what human health is.
[0, 4

In [20]:
def return_counts(encoded_texts, targets):
    # encoded_words, targets
    comma_count = 0
    word_count = 0
    q_count = 0
    p_count = 0
    space_count = 0
    for target in targets:
        for tar in target:
            for ta in tar:
                comma_count += 1 if (ta == 3) else 0
                q_count += 1 if (ta == 2) else 0
                p_count += 1 if (ta == 1) else 0
    sc = 0
    mwc = 0
    for text,target in zip(encoded_texts, targets):
        for tex,tar  in zip(text,target):
            en = 0
            for t,ta in zip(tex,tar):
                if t not in [6,5,0,-1,1,2,4,705] and ta != -1:
                    word_count+=1
                    en+=1
                elif t == 6 and ta != -1: # space
                    space_count+=1
                elif t in [705, 5]:
                    mwc*=sc
                    sc += 1
                    mwc += en
                    mwc /= sc
                    en = 0
    return space_count, p_count, q_count, comma_count

In [21]:
data_path_dual = "../dataset/en-zh-dual/"
os.makedirs(data_path_dual + model_type, exist_ok=True)
space_count, p_count, q_count, comma_count = map(sum, list(zip(*(return_counts(encoded_texts_zh,targets_zh),return_counts(encoded_texts_en,targets_en)))))
space_count, p_count, q_count, comma_count

(3816779, 228107, 20588, 355804)

In [22]:
encoded_texts = [encoded_texts_en[i]+encoded_texts_zh[i] for i in range(len(encoded_texts_en))]
targets = [targets_en[i]+targets_zh[i] for i in range(len(encoded_texts_en))]
enc = [list(zip(encoded_texts[i],targets[i])) for i in range(len(encoded_texts))]
temp = [random.sample(enc[i], len(enc[i])) for i in range(len(encoded_texts))]
encoded_texts = []
targets = []
for i in temp:
    a,b = list(zip(*i))
    encoded_texts.append(a)
    targets.append(b)

In [23]:
for i in range(0,100):
    print(len(encoded_texts[0][i]),len(targets[0][i]))

6552 6552
6125 6125
7686 7686
4376 4376
7896 7896
724 724
5651 5651
3404 3404
9993 9993
10220 10220
5471 5471
4050 4050
4825 4825
4539 4539
7217 7217
5693 5693
1280 1280
6319 6319
1458 1458
5048 5048
5652 5652
1314 1314
7198 7198
5268 5268
6420 6420
5861 5861
6543 6543
6680 6680
7071 7071
5676 5676
6281 6281
7599 7599
6721 6721
6895 6895
6978 6978
767 767
3594 3594
4955 4955
7110 7110
6703 6703
1238 1238
1716 1716
1149 1149
2477 2477
4825 4825
1227 1227
7223 7223
5428 5428
2627 2627
6713 6713
3954 3954
6011 6011
6039 6039
3757 3757
5739 5739
4217 4217
6762 6762
4703 4703
2699 2699
6641 6641
7470 7470
3385 3385
5544 5544
6143 6143
2133 2133
2731 2731
3083 3083
6739 6739
3991 3991
807 807
6388 6388
6555 6555
565 565
2120 2120
2374 2374
5060 5060
4371 4371
5314 5314
6263 6263
7172 7172
4360 4360
7054 7054
810 810
3295 3295
7335 7335
4130 4130
4123 4123
4643 4643
4185 4185
5472 5472
6845 6845
4747 4747
5866 5866
3494 3494
6565 6565
7718 7718
3390 3390
6592 6592
340 340
5899 5899


In [24]:
for i, name in enumerate(('train', 'valid', 'test')):
    with open(data_path_dual + f'{model_type}/{name}_data.pkl', 'wb') as f:
        pickle.dump((encoded_texts[i], targets[i], space_count, p_count, q_count, comma_count), f)

In [25]:
print(list(zip([tokenizer.convert_ids_to_tokens(token) for token in encoded_texts[0][9][0:20]],targets[0][9][0:20])))

[('<s>', -1), ('▁thank', 0), ('▁', 0), ('▁you', 0), ('▁', 0), ('▁so', 0), ('▁', 0), ('▁much', 0), ('▁', 0), ('▁everyone', 0), ('▁', 0), ('▁from', 0), ('▁', 0), ('▁', -1), ('ted', 0), (',', 3), ('▁and', 0), ('▁', 0), ('▁chr', -1), ('is', 0)]
