In [1]:
import json, os, re, random, string
from tqdm import tqdm
from ckiptagger import data_utils, construct_dictionary, WS, POS, NER
from rank_bm25 import BM25Okapi
from opencc import OpenCC
cc_t2s = OpenCC('t2s')
cc_s2t = OpenCC('s2t')



In [2]:
# path = "/home/zchen/encyclopedia-text-style-transfer/tools/ckip/"
# data_utils.download_data_url(path)
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# path = path + "data"
# ws = WS(path, disable_cuda=False)
# pos = POS(path, disable_cuda=False)
# ner = NER(path, disable_cuda=False)

In [3]:
def load(file):
    with open(file, 'r', encoding="utf-8") as f:
        return f.read()
        
def save(file, string):
    with open(file, 'w', encoding="utf-8") as f:
        f.write(string)
        
def load_json(file):
    with open(file, 'r', encoding="utf-8") as f:
        return json.load(f)

def save_json(file, obj):
    with open(file, 'w', encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False)

In [4]:
def clean_ltn_news(news):
    news = news.strip()
    
    # Remove 〔...〕 at the beginning.
    m = re.match(r"〔.*〕", news)
    if m:
        end = m.end()
        news = news[end:].strip()
    
    news = news.split('\n')
    
    # Remove (...) at the end.
    news[-1] = re.sub(r"[(（].*[)）]", "", news[-1]).strip()
    
    return '\n'.join(news)

def clean_cn_news(news):
    news = news.strip()
    
    # Remove ...日电 at the beginning.
    m = re.match(r".*日电", news)
    if m:
        end = m.end()
        news = news[end:].strip()
    
    news = news.split('\n')
    
    # Remove (...) at the beginning.
    m = re.search(r"[(（].*[)）]", news[0])
    if m:
        end = m.end()
        news[0] = news[0][end:].strip()
    
    # Remove (...) at the end.
    news[-1] = re.sub(r"[(（].*[)）]", "", news[-1]).strip()
        
    return '\n'.join(news)

In [5]:
file = "/home/zchen/encyclopedia-text-style-transfer/data/political_news/raw/Chinanews_articles.json"
cn_arts = [{"headline": news["headline"], "text": clean_cn_news(news["text"])} for news in load_json(file)]
print(len(cn_arts))

files = [
    "/home/zchen/encyclopedia-text-style-transfer/data/political_news/raw/articles_3800k-3825k.json",
    "/home/zchen/encyclopedia-text-style-transfer/data/political_news/raw/articles_3825k-3850k.json",
    "/home/zchen/encyclopedia-text-style-transfer/data/political_news/raw/articles_3850k-3900k.json"
]
tw_arts = []
for file in files:
    tw_arts += [{"headline": news["headline"], "text": clean_ltn_news(news["text"])} for news in load_json(file)]
print(len(tw_arts))

3849
7496


In [6]:
def remove_punctuation(line):
    rule = re.compile(r"[^a-zA-Z0-9\u4e00-\u9fa5]")
    line = rule.sub(' ',line)
    return line

def tokenize_arts(arts, file):
    if os.path.exists(file):
        return load(file).strip().split('\n')
    else:
        tok_arts = []
        for i, art in enumerate(tqdm(arts)):
            art = cc_s2t.convert(art["headline"] + '\n' + art["text"])
            art = [sent for sent in art.split('\n') if sent]
            art_ = []
            for sent in ws(art):
                art_ += sent
            tok_arts.append(remove_punctuation(' '.join(art_)))    # Saving space.
        save(file, '\n'.join(tok_arts))
        return tok_arts

def split_sent(arts):
    """ For space saving purpose. """
    return [art.split() for art in arts]

In [7]:
# 1hr
file = "/home/zchen/encyclopedia-text-style-transfer/data/political_news/tok_cn_arts"
tok_cn_arts = tokenize_arts(cn_arts, file)
assert len(tok_cn_arts) == len(cn_arts), (cn_arts, tok_cn_arts)

file = "/home/zchen/encyclopedia-text-style-transfer/data/political_news/tok_tw_arts"
tok_tw_arts = tokenize_arts(tw_arts, file)
assert len(tok_tw_arts) == len(tw_arts), (tw_arts, tok_tw_arts)

In [8]:
def get_art_pairs(src_arts, tok_src_arts, targ_arts, tok_targ_arts, file):
    if os.path.exists(file):
        return load_json(file)
    else:
        bm25 = BM25Okapi(split_sent(tok_targ_arts))
        
        # 30min
        iterable = zip(src_arts, split_sent(tok_src_arts))
        art_pairs = [
            (art, bm25.get_top_n(tok_art, list(range(len(targ_arts))), n=5))
            for art, tok_art in tqdm(list(iterable))
        ]
        save_json(file, art_pairs)
        return art_pairs
    
file = "/home/zchen/encyclopedia-text-style-transfer/data/political_news/art_pairs.json"
art_pairs = get_art_pairs(cn_arts, tok_cn_arts, tw_arts, tok_tw_arts, file)

In [9]:
print(len(art_pairs))

3849


In [10]:
def art2sents_str(art):
    return art["headline"] + '\n' + art["text"]
    
def get_s_t(s):
    s_s, s_t = cc_t2s.convert(s), cc_s2t.convert(s)
    s_s = [sent for sent in s_s.split('\n') if len(sent) > 5]
    s_t = [sent for sent in s_t.split('\n') if len(sent) > 5]
    assert len(s_s) == len(s_t), (len(s_s), len(s_t))
    
    zipped = list(zip(s_s, s_t))
    zipped = list(set(zipped))
    s_s, s_t = list(zip(*zipped))
    assert len(s_s) == len(s_t), (len(s_s), len(s_t))
    return s_s, s_t
    
def sents2ents(sentence_list):
    word_sentence_list = ws(
        sentence_list,
        # sentence_segmentation = True, # To consider delimiters
        # segment_delimiter_set = {",", "。", ":", "?", "!", ";"}), # This is the defualt set of delimiters
        # recommend_dictionary = dictionary1, # words in this dictionary are encouraged
        # coerce_dictionary = dictionary2, # words in this dictionary are forced
    )
    pos_sentence_list = pos(word_sentence_list)
    entity_sentence_list = ner(word_sentence_list, pos_sentence_list)
    
    ent_sents = []
    for sent, word_sent, pos_sent, ent_sent in zip(sentence_list, word_sentence_list, pos_sentence_list, entity_sentence_list):
        ents = [word for word, pos in zip(word_sent, pos_sent) if re.match(r"N[a-d]", pos)]
        # re.match() determines whether the pattern matches the beginning of the string,
        # returns the matching object if it matches, and returns None if it does not match.
        
        ents += [ent[3] for ent in ent_sent]
        ent_sents.append(set(ents))
    return ent_sents
    
def get_score(src_ent_sent, targ_ent_sent):
    intersection_len = len(src_ent_sent & targ_ent_sent)
    if intersection_len == 0:
        return 0
    precision = intersection_len / len(targ_ent_sent)
    recall = intersection_len / len(src_ent_sent)
    targ_len = len(targ_ent_sent)
    return (2*precision*recall) / (precision+recall)
    
def filter_sent(zipped, p):
    """ filtering strategy, e.g., highest k sents or a threshold score p """
#     return zipped[:k]
    
    sents = []
    for sent, score in zipped:
        if score >= p:
            sents.append((sent, score))
        else:
            break
    return sents

def get_pair(src_sent, targ_sents, src_ent_sent, targ_ent_sents):
    scores = [get_score(src_ent_sent, targ_ent_sent) for targ_ent_sent in targ_ent_sents]
    zipped = list(zip(targ_sents, scores))
    zipped.sort(key=lambda x: x[1], reverse=True)
    matched = filter_sent(zipped, p=0.3)
    return (src_sent, matched)

def get_pairs(src_sents_str, targ_sents_str):
    src_sents_s, src_sents_t = get_s_t(src_sents_str)
    targ_sents_s, targ_sents_t = get_s_t(targ_sents_str)
    src_ent_sents, targ_ent_sents = sents2ents(src_sents_t), sents2ents(targ_sents_t)
    return [get_pair(src_sent, targ_sents_s, src_ent_sent, targ_ent_sents) for src_sent, src_ent_sent in zip(src_sents_s, src_ent_sents)]

def get_sent_pairs(art_pairs, targ_arts, file):
    if os.path.exists(file):
        return load_json(file)
    else:
        sent_pairs = []
        """
        [
            (src_sent_1, [(targ_sent_n, score), ...]),
            ...
            ]
        """

        for i, (src_art, targ_art_ids) in enumerate(tqdm(art_pairs)):
            src_sents_str = art2sents_str(src_art)
            targ_sents_str = '\n'.join([art2sents_str(targ_arts[i]) for i in targ_art_ids])
            sent_pairs += get_pairs(src_sents_str, targ_sents_str)
        save_json(file, sent_pairs)
        return sent_pairs

# 2hr
file = "/home/zchen/encyclopedia-text-style-transfer/data/political_news/sent_pairs.json"
sent_pairs = get_sent_pairs(art_pairs, tw_arts, file)

In [20]:
n_src = sum([len(targs) > 0 for src, targs in sent_pairs])
n_pairs = sum([len(targs) for src, targs in sent_pairs])
avg_pairs = n_pairs / n_src
print(len(sent_pairs), n_src, n_pairs, avg_pairs)
print(sent_pairs[:10])

31744 8819 22782 2.5832860868579206
[['俄乌代表在谈判开始前握手', [['首轮停战谈判无果！ 乌俄第二回合会谈2日举行', 0.4], ['乌俄谈判第一回合结束！下次会谈选在波兰、白俄交界', 0.3333333333333333], ['乌俄第三轮谈判今晚10点开始 双方外长周四将会面', 0.30769230769230765]]], ['当地时间3月3日，俄罗斯与乌克兰的第二轮谈判正式开始。据俄罗斯卫星通讯社消息，乌克兰总统办公室主任顾问波多利亚克表示，列入与俄谈判议程的有停火、休战、开放疏散人员的人道主义走廊。俄罗斯代表团和乌克兰代表团的成员在会谈时相互握手。', [['《CNN》指出，稍早网路上一段影片显示，乌俄双方代表团已碰面，在坐下会谈前先相互握手。乌克兰代表团成员之一、乌克兰总统办公室主任顾问波多利亚克在推特分享一张乌俄代表已就定位准备谈判的照片指出，「我们已开始与俄罗斯代表会谈，此次主要议程有三，『立即停火』、『停战撤军』、『设人道走廊将无辜平民从被毁或不断遭受砲火的城镇救出』。」', 0.4827586206896552], ['乌克兰总统泽伦斯基（Volodymyr Zelenskiy）的助理在台湾时间3日晚间透露，乌克兰代表团已搭乘直升机赴与俄谈判地点。稍早，乌克兰代表团成员之一的波多利亚克（Mykhailo Podoliak）在推特PO出照片，显示乌俄第二轮谈判已经展开！', 0.4230769230769231], ['乌克兰总统顾问波多利亚克（Mikhailo Podolyak）今日稍早在推特上表示，「与俄罗斯联邦的第三轮谈判，将在基辅时间16点开始。代表人员没有变动。」不过，波多利亚克没有透露谈判地点。', 0.35555555555555557], ['乌克兰与俄罗斯在2月28进行首轮停战谈判无果，俄官媒昨（3/1）声称2日会举行第2轮会谈，克里姆林宫也于周三表示希望乌克兰代表团能依约前去谈判，稍早乌克兰总统助理证实，双方将于当地时间2日开展第二回合谈判，但确切时间没透露。', 0.35294117647058826], ['乌俄首轮谈判约在2月28日当地时间下午1时（台湾时间晚间7时）在白俄罗斯邻近乌克兰边境的戈梅利州（Gomel region）展开，乌克兰总统办公室主任顾问波多利亚克（My

In [28]:
def split_data(pairs, maxlen=254, n_test=1000, seed="TST"):
    # remove long sentences
    _pairs = pairs
    pairs = []
    for src, targs in _pairs:
        if len(src) <= maxlen:
            targs = [(targ, score) for targ, score in targs if len(targ) <= maxlen]
            if targs:
                pairs.append((src, targs))
    
    random.seed(seed)
    random.shuffle(pairs)
    random.seed()
            
    sep1 = -2 * n_test
    sep2 = -n_test
    return pairs[:sep1], pairs[sep1:sep2], pairs[sep2:]

def pairs2datasets(pairs):
    train_pairs, valid_pairs, test_pairs = split_data(pairs)
    
    train_ds = []
    for src, targs in train_pairs:
        train_ds += [(src, targ) for targ, score in targs]
    
    valid_ds = [(src, targs[0][0]) for src, targs in valid_pairs if targs]
    test_ds = [(src, targs[0][0]) for src, targs in test_pairs if targs]
    return train_ds, valid_ds, test_ds

In [29]:
train_ds, valid_ds, test_ds = pairs2datasets(sent_pairs)
print(len(train_ds))
print(train_ds[:10])
print()
print(len(valid_ds))
print(valid_ds[:10])
print()
print(len(test_ds))
print(test_ds[:10])

17384
[('中国国务委员兼外长王毅29日同欧盟外交与安全政策高级代表博雷利举行视频会见。王毅强调，处理复杂的安全问题，不应采取非敌即友、非黑即白的简单化做法。各国都有权独立自主地决定自己的对外政策。事实证明，冷战思维、阵营对抗的老路在欧洲已经行不通了，选边站队、分裂世界的做法更不可取。极限制裁只会导致互相伤害，使形势更加复杂，矛盾更加激化。让非当事方的国家和人民为冲突埋单，既不公正，也不合法。中方愿同各方一道，照顾各方正当合理关切，朝着争取俄乌冲突尽快解决、欧洲尽快恢复和平的大方向做出努力。', '中国与俄罗斯同样反对北约扩张，但并不等同于支持俄乌开战。中方多次在公开声明中促进各方缓和乌克兰紧张局势，支持俄乌开启谈判对话。欧盟外交与安全政策高级代表波瑞尔（Josep Borrell）日前表示，只有中国能充当俄乌调停人，中国驻欧盟使团当地时间5日回应，中方鼓励俄乌直接谈判。中国驻欧盟使团的回应，维持了不介入姿态。'), ('东西问·中外对话 | 斯蒂芬·佩里：为什么说美国“错过”了中国的崛起？', '与此同时，中国敦促美国尊重并解决俄罗斯对于安全保证的要求。'), ('3月4日，十三届全国人大五次会议新闻发布会在人民大会堂举行，大会发言人张业遂说——', '吴谦今以全国人大五次会议解放军和武警部队代表团新闻发言人身分，讲解军费安排，他称，中国国防支出预算每年都纳入政府预算草案，并由人民代表大会审查后依法使用，也对外公开预算总额。'), ('俄媒报道称，乌克兰总统办公室顾问证实，乌方代表团尚未抵达谈判地点，并敦促等待对话的真正开始。', '乌俄第2轮谈判开始！乌克兰总统顾问曝议程3重点'), ('俄媒报道称，乌克兰总统办公室顾问证实，乌方代表团尚未抵达谈判地点，并敦促等待对话的真正开始。', '乌克兰总统泽伦斯基（Volodymyr Zelenskiy）的助理在台湾时间3日晚间透露，乌克兰代表团已搭乘直升机赴与俄谈判地点。稍早，乌克兰代表团成员之一的波多利亚克（Mykhailo Podoliak）在推特PO出照片，显示乌俄第二轮谈判已经展开！'), ('俄媒报道称，乌克兰总统办公室顾问证实，乌方代表团尚未抵达谈判地点，并敦促等待对话的真正开始。', '地点曝光！ 乌克兰证实2日与俄罗斯续谈 谈判代表与首轮相同'), ('俄媒报道称，乌克兰总统办公室顾问证实，乌方代

In [30]:
def save_dataset(dataset, src_file, targ_file):
    src, targ = zip(*dataset)
    save_corpus(src, src_file)
    save_corpus(targ, targ_file)
    
def save_corpus(sent_list, file):
    corpus = '\n'.join(sent_list) + '\n'
    with open(file, 'w', encoding="utf-8") as f:
        f.write(corpus)
    
for dataset, split in [(train_ds, "train"), (valid_ds, "valid"), (test_ds, "test")]:
    src_file = f"/home/zchen/XLM_ETST/data/cn-tw_1k/txt/cn-tw.cn.{split}_raw"
    targ_file = f"/home/zchen/XLM_ETST/data/cn-tw_1k/txt/cn-tw.tw.{split}_raw"
    save_dataset(dataset, src_file, targ_file)

In [6]:
def print_word_pos_sentence(word_sentence, pos_sentence):
    assert len(word_sentence) == len(pos_sentence)
    for word, pos in zip(word_sentence, pos_sentence):
        print(f"{word}({pos})", end="\u3000")
    print()
    return
    
for i, sentence in enumerate(sentence_list):
    print()
    print(f"'{sentence}'")
    print_word_pos_sentence(word_sentence_list[i],  pos_sentence_list[i])
    for entity in sorted(entity_sentence_list[i]):
        print(entity)

NameError: name 'sentence_list' is not defined