In [108]:
import matplotlib.pyplot as plt
%matplotlib inline
import os
import json
import pickle
import torch
import numpy as np
import re
from tqdm.notebook import tqdm
from sklearn.utils import shuffle
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")

In [110]:
model_type = 'xlm-roberta-base'  # alternatives: albert-base-v1, bert-base-cased, bert-base-uncased
data_path = "../dataset/zh-en/"


def _read_lines(file_name):
    """Return every line (trailing newlines kept) of a UTF-8 file under `data_path`."""
    with open(data_path + file_name, 'r', encoding='utf-8') as handle:
        return handle.readlines()


train_text = _read_lines('train_texts_zh.txt')
valid_text = _read_lines('dev_texts_zh.txt')
test_text = _read_lines('test_texts_zh.txt')

In [111]:
# Keep the three splits together, always in the order train, valid, test.
datasets = (train_text, valid_text, test_text)

In [112]:
# Number of documents (lines read from file) in each split.
list(map(len, datasets))

[1017, 8, 11]

In [113]:
def clean_text(text):
    """Normalise punctuation and whitespace so that only ``。``, ``？`` and ``，`` remain as marks.

    Other punctuation is either mapped onto one of those three marks or
    removed, whitespace runs are collapsed, and the result is stripped and
    lower-cased.

    Args:
        text: Raw text (one document / line).

    Returns:
        The cleaned, lower-cased string.
    """
    # --- map other marks onto the three target marks -----------------------
    text = text.replace('！', '。')
    text = text.replace('：', '，')
    text = text.replace('——', '，')

    text = re.sub(r'\s—\s', ' ， ', text)  # a lone dash surrounded by whitespace acts as a comma

    text = text.replace(';', '。')    # replace symbols with the most relevant counterparts
    text = text.replace('、', '，')
    text = text.replace('♫', '')
    text = text.replace('……', '')
    # NOTE(review): this drops the full stop together with the closing quote,
    # so the sentence end is lost -- confirm that this is intended.
    text = text.replace('。”', '')
    text = text.replace('”', '，')
    text = text.replace('“', '，')
    text = text.replace(',', '，')

    # NOTE(review): a `——\s?——` -> '' substitution used to sit here, but it
    # could never match because every '——' has already become '，' above.

    text = re.sub(r'\s+', ' ', text)     # collapse every whitespace run to one space

    text = re.sub(r'，\s?，', '，', text)  # merge commas separating only whitespace
    text = re.sub(r'，\s?。', '。', text)  # ， 。 -> 。
    text = re.sub(r'？\s?。', '？', text)  # ？ 。 -> ？
    text = re.sub(r'\s+', ' ', text)     # collapse again after the merges above

    # --- spacing around the marks ------------------------------------------
    text = re.sub(r'\s+？', '？', text)
    text = re.sub(r'\s+，', '，', text)
    # Collapse a full stop followed by further full stops / whitespace.
    # The class is [\s。]: the previous [\s+。] also contained a literal '+',
    # which silently deleted a '+' that followed a full stop.
    text = re.sub(r'。[\s。]+', '。 ', text)
    text = re.sub(r'\s+。', '。 ', text)

    return text.strip().lower()

In [114]:
# Rebinds `datasets` from the raw lines to their cleaned versions, split by split.
datasets = [list(map(clean_text, split_docs)) for split_docs in datasets]

In [115]:
# Count the documents per split that are non-empty after cleaning (nothing is removed here).
[sum(1 for doc in split_docs if len(doc) > 0) for split_docs in datasets]

[1017, 8, 11]

In [116]:
# Per split: join all documents with single spaces and count the space-delimited pieces.
[len(' '.join(split_docs).split(' ')) for split_docs in datasets]

[316676, 2307, 3608]

In [117]:
# Instantiate the tokenizer for the checkpoint named by `model_type`.
tokenizer = AutoTokenizer.from_pretrained(model_type)

In [118]:
# Encode the three target marks in one string and drop the first and last ids
# (presumably the special tokens added by encode() -- TODO confirm), then show
# the resulting pieces.
# NOTE: `target_ids` is reassigned a few cells further down from `target_token2id`.
target_ids = tokenizer.encode("。？，")[1:-1]
tokenizer.convert_ids_to_tokens(target_ids)

['▁', '。', '?', ',']

In [119]:
# Encode each mark on its own and keep the second-to-last id that encode() returns for it.
target_token2id = dict((mark, tokenizer.encode(mark)[-2]) for mark in "。？，")
target_token2id

{'。': 30, '？': 705, '，': 4}

In [120]:
# Token ids of the three marks, in the insertion order of `target_token2id`.
target_ids = [*target_token2id.values()]
target_token2id.items()

dict_items([('。', 30), ('？', 705), ('，', 4)])

In [121]:
import jieba
# Token id -> label. 0 and -1 map to themselves; the punctuation token ids in
# `target_ids` receive the labels 1, 2, 3, ... in list order.
id2target = {0: 0, -1: -1}
id2target.update(
    (token_id, label) for label, token_id in enumerate(target_ids, start=1)
)
# Inverse lookup: label -> token id.
target2id = {label: token_id for token_id, label in id2target.items()}

def create_target(text):
    """Turn one text into token ids plus a punctuation label per token.

    The text is segmented with ``jieba.cut`` and each segment is encoded on its
    own with the module-level ``tokenizer``.  Labels are produced as follows:

    * ``-1`` -- every kept piece of a multi-piece segment except the last one,
      and the two boundary positions added at the start and the end;
    * ``0`` -- the last (or only) piece of a segment that is not punctuation or
      whitespace, and the separator id ``6`` inserted after such a segment when
      the next segment is not punctuation or whitespace;
    * ``id2target[target_token2id[mark]]`` -- the token of a ``。`` / ``？`` /
      ``，`` segment.

    Args:
        text: Text to convert (the notebook passes output of ``clean_text``).

    Returns:
        ``(encoded_words, targets)``: the token-id list wrapped in the
        tokenizer's start/end ids, and the label list wrapped in ``-1``.

    Raises:
        KeyError: if a segment listed as punctuation is missing from
            ``target_token2id`` / ``id2target``.
    """
    # Segments that never receive a content label.
    non_content = ("。", "？", "，", " ", "▁")
    whitespace_like = ("▁", " ")

    encoded_words, targets = [], []
    words = list(jieba.cut(text, HMM=True))

    for i, word in enumerate(words):
        encoded_word = tokenizer.encode(word)
        n_pieces = len(encoded_word[1:-1])  # pieces between the first and last id

        # 6 is the id appended as the word separator below; here it is skipped
        # when it shows up inside a segment's own encoding.
        # NOTE(review): the id is hard-coded for the current tokenizer --
        # confirm before changing `model_type`.
        if (n_pieces > 1 and encoded_word[1] != 6) or (n_pieces > 2 and encoded_word[1] == 6):
            for piece in encoded_word[1:-1]:
                if piece != 6:
                    encoded_words.append(piece)
                    targets.append(-1)
            # The final piece gets its real label further down, so take back
            # the provisional -1 that was just appended for it.
            targets = targets[:-1]
        elif n_pieces == 0:
            # Nothing was produced for this segment: no token, no label.
            continue
        else:
            # Single useful piece; skip a leading 6 if there is one.
            # NOTE(review): when the only piece is 6 itself, index 2 is the
            # encoding's last id, and a whitespace-like segment then adds a
            # token without a label -- confirm this cannot happen in practice.
            s = 2 if encoded_word[1] == 6 else 1
            encoded_words.append(encoded_word[s])

        if word not in non_content:
            targets.append(0)
            next_is_boundary = i < len(words) - 1 and words[i + 1] in non_content
            if not next_is_boundary:
                # Explicit separator between two consecutive content segments.
                encoded_words.append(6)
                targets.append(0)
        elif word not in whitespace_like:
            targets.append(id2target[target_token2id[word]])

    # Use explicit None checks: a start/end id of 0 is valid but falsy, so
    # `cls_token_id or bos_token_id` would wrongly discard it.
    first_id = tokenizer.cls_token_id
    if first_id is None:
        first_id = tokenizer.bos_token_id
    last_id = tokenizer.sep_token_id
    if last_id is None:
        last_id = tokenizer.eos_token_id

    encoded_words = [first_id] + encoded_words + [last_id]
    targets = [-1] + targets + [-1]

    return encoded_words, targets

In [122]:
# Smoke test: clean one sample sentence and show every token next to its label.
print(id2target)
sample_raw = "谁能猜一猜：你大脑里神经元的总长有多少？ ”西班牙厨师被控告……“ 非常坚硬的土地。西班牙厨师被控告"
#sample_raw = "小明硕士毕业于中国科学院计算所，后在日本京都大学深造"
#sample_raw = "算所， 日本"
print(sample_raw)
sample_clean = clean_text(sample_raw)
print(sample_clean)
# Use demo-specific names so the global `targets` list that the dataset cells
# fill is not overwritten by this check.
demo_ids, demo_targets = create_target(sample_clean)
# Same slice as before: the first two positions and the last one are left out.
list(zip(tokenizer.convert_ids_to_tokens(demo_ids[2:-1]), demo_targets[2:-1]))

{0: 0, -1: -1, 30: 1, 705: 2, 4: 3}
谁能猜一猜：你大脑里神经元的总长有多少？ ”西班牙厨师被控告……“ 非常坚硬的土地。西班牙厨师被控告
谁能猜一猜，你大脑里神经元的总长有多少？，西班牙厨师被控告， 非常坚硬的土地。西班牙厨师被控告


[('▁', 0),
 ('能', 0),
 ('▁', 0),
 ('猜', 0),
 ('▁', 0),
 ('▁一', 0),
 ('▁', 0),
 ('猜', 0),
 (',', 3),
 ('▁你', 0),
 ('▁', 0),
 ('大脑', 0),
 ('▁', 0),
 ('里', 0),
 ('▁', 0),
 ('神经', -1),
 ('元', 0),
 ('▁', 0),
 ('的', 0),
 ('▁', 0),
 ('总', -1),
 ('长', 0),
 ('▁', 0),
 ('有', 0),
 ('▁', 0),
 ('多少', 0),
 ('▁?', 2),
 (',', 3),
 ('西班牙', 0),
 ('▁', 0),
 ('厨', -1),
 ('师', 0),
 ('▁', 0),
 ('被', 0),
 ('▁', 0),
 ('控', -1),
 ('告', 0),
 (',', 3),
 ('非常', 0),
 ('▁', 0),
 ('坚', -1),
 ('硬', 0),
 ('▁', 0),
 ('的', 0),
 ('▁', 0),
 ('土地', 0),
 ('。', 1),
 ('西班牙', 0),
 ('▁', 0),
 ('厨', -1),
 ('师', 0),
 ('▁', 0),
 ('被', 0),
 ('▁', 0),
 ('控', -1),
 ('告', 0),
 ('▁', 0)]

In [97]:
# sentence endings split
def _split_at_sentence_ends(text, max_len=512, enders=("。", ".", "？", "?")):
    """Cut `text` into consecutive chunks of at most `max_len` characters.

    Each cut is placed right after the last sentence-ending character inside
    the current window; if the window holds none, it is cut at `max_len`.
    No character is dropped or duplicated: ``''.join(chunks) == text``.

    Args:
        text: String to split.
        max_len: Maximum chunk length in characters.
        enders: Characters that count as a sentence end.

    Returns:
        List of chunks in original order (empty for an empty string).
    """
    chunks = []
    start = 0
    while start < len(text):
        stop = min(start + max_len, len(text))
        if stop < len(text):
            # More text follows, so try to end this chunk on a sentence end.
            for pos in range(stop - 1, start - 1, -1):
                if text[pos] in enders:
                    stop = pos + 1
                    break
        chunks.append(text[start:stop])
        start = stop
    return chunks


encoded_texts, targets = [], []

for ds in datasets:
    trgts = []
    for ts in ds:
        trgts.extend(_split_at_sentence_ends(ts))
    # Encode every chunk, then regroup into (all id lists, all label lists).
    x = list(zip(*(create_target(trgt) for trgt in tqdm(trgts))))
    encoded_texts.append(x[0])
    targets.append(x[1])

  0%|          | 0/9357 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

In [None]:
## basic split
# Cut every document into consecutive 512-character windows, ignoring sentence boundaries.
encoded_texts, targets = [], []

for split_docs in datasets:
    windows = [
        doc[offset:offset + 512]
        for doc in split_docs
        for offset in range(0, len(doc), 512)
    ]
    # Regroup the (ids, labels) pairs into one tuple of id lists and one of label lists.
    pairs = list(zip(*map(create_target, tqdm(windows))))
    encoded_texts.append(pairs[0])
    targets.append(pairs[1])

In [123]:
#no split
# Encode each document as a whole, without chunking.
encoded_texts, targets = [], []

for split_docs in datasets:
    # Regroup the (ids, labels) pairs into one tuple of id lists and one of label lists.
    pairs = list(zip(*map(create_target, tqdm(split_docs))))
    encoded_texts.append(pairs[0])
    targets.append(pairs[1])

  0%|          | 0/1017 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

In [131]:
# Corpus statistics over all splits: label totals, number of content tokens,
# and the mean number of content tokens in front of each punctuation token.

# Ids that are not counted as content tokens. 6 is the separator id and
# 30 / 705 / 4 are the ids printed for 。 ？ ， earlier in the notebook.
# NOTE(review): 0, 1, 2 are presumably special-token ids and -1 a sentinel -- confirm.
non_word_ids = (6, 30, 0, -1, 1, 2, 4, 705)
punct_ids = (705, 30, 4)

comma_count = 0
q_count = 0
p_count = 0
for target in targets:
    for tar in target:
        for ta in tar:
            if ta == 3:
                comma_count += 1
            elif ta == 2:
                q_count += 1
            elif ta == 1:
                p_count += 1

word_count = 0
sc = 0             # punctuation tokens seen so far
closed_words = 0   # content tokens that were followed by a punctuation token
for text in encoded_texts:
    for tex in text:
        en = 0     # content tokens since the last punctuation token in this sequence
        for t in tex:
            if t not in non_word_ids:
                word_count += 1
                en += 1
            elif t in punct_ids:
                sc += 1
                closed_words += en
                en = 0

# Plain total/count mean; equals the former incremental mean up to rounding.
mwc = closed_words / sc if sc else 0
print(mwc)

print(comma_count, word_count, q_count, p_count)

10.086594393462324
157004 2591643 10256 86479


'\nfor te, ta in zip(encoded_texts[0][0], targets[0][0]):\n    print(f"{tokenizer._convert_id_to_token(te):15}\t{ta}")\n'

In [128]:
def return_counts(encoded_texts, targets):
    """Count separator tokens and punctuation labels over all splits.

    Args:
        encoded_texts: Per split, a sequence of token-id sequences.
        targets: Per split, a sequence of label sequences, nested the same way
            as `encoded_texts`.

    Returns:
        ``(space_count, p_count, q_count, comma_count)`` where

        * ``space_count`` is the number of positions whose token id is ``6``
          and whose label is not ``-1`` (counted over the zipped pairs of
          `encoded_texts` and `targets`);
        * ``p_count`` / ``q_count`` / ``comma_count`` are the numbers of labels
          equal to ``1`` / ``2`` / ``3`` anywhere in `targets`.
    """
    # The word / mean-segment-length bookkeeping that used to live here was
    # never returned, so it has been removed; the result is unchanged.
    p_count = q_count = comma_count = 0
    for split_labels in targets:
        for labels in split_labels:
            for label in labels:
                if label == 1:
                    p_count += 1
                elif label == 2:
                    q_count += 1
                elif label == 3:
                    comma_count += 1

    space_count = 0
    for split_ids, split_labels in zip(encoded_texts, targets):
        for ids, labels in zip(split_ids, split_labels):
            for token_id, label in zip(ids, labels):
                if token_id == 6 and label != -1:
                    space_count += 1

    return space_count, p_count, q_count, comma_count

In [126]:
# Create the per-model output directory if it does not exist yet.
os.makedirs(f'{data_path}{model_type}', exist_ok=True)
# One set of counts, computed over every split held in `encoded_texts` / `targets`.
space_count, p_count, q_count, comma_count = return_counts(encoded_texts, targets)

In [127]:
# Write one pickle per split; each also carries the same four counts.
for split_index, name in enumerate(('train', 'valid', 'test')):
    with open(data_path + f'{model_type}/{name}_data.pkl', 'wb') as f:
        payload = (
            encoded_texts[split_index],
            targets[split_index],
            space_count,
            p_count,
            q_count,
            comma_count,
        )
        pickle.dump(payload, f)

9004


In [102]:
from collections import Counter

# One row per split: how often each label occurs, in the order 1, 2, 3, 0, -1.
for ds_targets in targets:
    c = Counter(label for labels in ds_targets for label in labels)
    print('\t'.join(map(str, (c[label] for label in (1, 2, 3, 0, -1)))))

80006	9499	144329	3701035	373362
395	88	1151	32002	3611
848	144	2067	49241	4530


In [None]:
# Debug view for one validation document: print each labelled token group,
# the token its label maps to, and the raw space-delimited word with the same
# running index.
pending_ids = []
word_index = 0

raw_words = datasets[1][2].split(' ')

for te, ta in zip(encoded_texts[1][2], targets[1][2]):
    pending_ids.append(te)
    if ta == -1:
        # Still inside a group; keep collecting ids.
        continue
    # There can be more labelled groups than space-delimited raw words, so
    # fall back to an empty string instead of raising IndexError.
    raw_word = raw_words[word_index] if word_index < len(raw_words) else ''
    print(f"{tokenizer.decode(pending_ids):15}\t{tokenizer.decode(target2id[ta]):10}\t{raw_word}")
    pending_ids = []
    word_index += 1

# Ids are only left over when the last label was -1. target2id[-1] is -1, which
# is not a vocabulary id, so the label column is left blank here.
if pending_ids:
    print(f"{tokenizer.decode(pending_ids):15}\t{'':10}\t")

In [107]:
# Inspect one encoded validation document: decoded text with its labels, then
# every token next to its label.
sample_ids, sample_labels = encoded_texts[1][2], targets[1][2]
print(tokenizer.decode(sample_ids), sample_labels)
# Public API, consistent with the earlier convert_ids_to_tokens() call.
list(zip(tokenizer.convert_ids_to_tokens(sample_ids), sample_labels))

<s>这个 项目 叫  photosynth,它 实际上 融合 了 两个 不同 的 技术,一个 是  seadragon, 而 另 一个 则 是 源自 华盛顿大学 的 研究生  noah snavely所 进行 的 计算机 视觉 研究 的 成果。这项 研究 还 得到 了 华盛顿大学  steve seitz和 微软 研究院 rickszeliski 的 协助。这是 一个 非常 漂亮 的 合作 成果。这个 项目  在 互联网 上 已经 得到 应用 了,它 是 基于  seadragon 技术 构建 的。 你 可以 看到, 我们 轻松 地 对 图片 进行 多种 方式 的 查看,从而 能够 对 图片 进行 细致 的 剖析并且 拥有 多 分辨率 的 浏览 体验。 不过,这些 图片  在  三维空间 的 排列 事实上 是 非常 有 意义 的。计算机 视觉 算法 将 这些 图片 联系 到 一起,那么 这些 图片 就 能够 将 真实 空间 呈现 出来 了, 而  我们 正是  在 这个 空间 里 拍 下 了 上述 的 照片,这些 照片 都 是  在加拿大 落基 山脉 的 格拉西 湖  (  grassi lakes  ) 附近 拍下 的, ( 所有 照片  ) 都 是  在 这里 拍下 的。因此  你 可以 看到 这里 的 元素 是 稳定 的 幻灯 放映 或者 全景 成像, 而 这些 内容  在 空间 上 都 是 关联 的。 我 不 确定  我们 是否 有 时间 来 展示 更 多 的 环境 全景。有 很多 例子 比 这个 的 空间感 还要 强。</s> [-1, 0, 0, 0, 0, 0, 0, -1, -1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, -1, -1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0

[('<s>', -1),
 ('这个', 0),
 ('▁', 0),
 ('项目', 0),
 ('▁', 0),
 ('叫', 0),
 ('▁', 0),
 ('▁photos', -1),
 ('yn', -1),
 ('th', 0),
 (',', 3),
 ('它', 0),
 ('▁', 0),
 ('实际上', 0),
 ('▁', 0),
 ('融合', 0),
 ('▁', 0),
 ('了', 0),
 ('▁', 0),
 ('两个', 0),
 ('▁', 0),
 ('不同', 0),
 ('▁', 0),
 ('的', 0),
 ('▁', 0),
 ('技术', 0),
 (',', 3),
 ('一个', 0),
 ('▁', 0),
 ('是', 0),
 ('▁', 0),
 ('▁se', -1),
 ('ad', -1),
 ('ragon', 0),
 (',', 3),
 ('▁而', 0),
 ('▁', 0),
 ('另', 0),
 ('▁', 0),
 ('一个', 0),
 ('▁', 0),
 ('则', 0),
 ('▁', 0),
 ('是', 0),
 ('▁', 0),
 ('源', -1),
 ('自', 0),
 ('▁', 0),
 ('华盛顿', -1),
 ('大学', 0),
 ('▁', 0),
 ('的', 0),
 ('▁', 0),
 ('研究生', 0),
 ('▁', 0),
 ('▁no', -1),
 ('ah', 0),
 ('▁sna', -1),
 ('ve', -1),
 ('ly', 0),
 ('所', 0),
 ('▁', 0),
 ('进行', 0),
 ('▁', 0),
 ('的', 0),
 ('▁', 0),
 ('计算机', 0),
 ('▁', 0),
 ('视觉', 0),
 ('▁', 0),
 ('研究', 0),
 ('▁', 0),
 ('的', 0),
 ('▁', 0),
 ('成果', 0),
 ('。', 1),
 ('这项', 0),
 ('▁', 0),
 ('研究', 0),
 ('▁', 0),
 ('还', 0),
 ('▁', 0),
 ('得到', 0),
 ('▁', 0),
 ('了', 0),
 ('▁'