In [1]:
import gc
import glob
import json
import os
import re
import subprocess
from collections import Counter
from os.path import join as pjoin

import torch
from multiprocess import Pool

In [2]:
from others.logging import logger
from others.tokenization import BertTokenizer

from prepro.utils import _get_word_ngrams

In [3]:
# Module-level BERT tokenizer; its vocab is used below to look up the ids of
# the special tokens. BertData builds its own tokenizer instance separately.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [4]:
# Special tokens and their vocabulary ids at module level.
# The same values are set again as attributes in BertData.__init__.
sep_token = '[SEP]'
cls_token = '[CLS]'
pad_token = '[PAD]'
# Reserved "[unusedN]" vocab entries used as target-side markers:
# begin-of-sequence, end-of-sequence and sentence separator.
tgt_bos = '[unused0]'
tgt_eos = '[unused1]'
tgt_sent_split = '[unused2]'
sep_vid = tokenizer.vocab[sep_token]
cls_vid = tokenizer.vocab[cls_token]
pad_vid = tokenizer.vocab[pad_token]

In [5]:
# Length limits read as globals by BertData.preprocess.
# Source sentences are kept only if len(sentence) > min_src_ntokens_per_sent
# (strict comparison), and each kept sentence is truncated to
# max_src_ntokens_per_sent tokens.
min_src_ntokens_per_sent = 1
max_src_ntokens_per_sent = 600
# A non-test example with fewer than min_src_nsents kept sentences is dropped;
# at most max_src_nsents sentences are kept per document.
min_src_nsents = 1
max_src_nsents = 200

# The target sub-token list (including the [unused*] markers) is cut to
# max_tgt_ntokens; non-test examples shorter than min_tgt_ntokens are dropped.
min_tgt_ntokens = 1
max_tgt_ntokens = 30

In [6]:
# Not read by any cell visible in this notebook; format_to_bert takes its
# worker count from args.n_cpus instead.
n_cpus = 0

In [8]:
# lower: lowercase document and subject tokens before preprocessing.
lower = True
# Forwarded to BertData.preprocess, which passes it to the tokenizer when
# tokenizing the target text.
use_bert_basic_tokenizer = True

# Directory prefixes (trailing slash required: they are concatenated via
# str.format, not os.path.join). toke_path holds the tokenized input text
# files, save_bath receives the generated .pt files.
toke_path = "random_tokened/token/"
save_bath = "random_tokened/bert/"

In [9]:
class BertData():
    """Turn tokenized source/target text into BERT sub-token id features.

    Holds a 'bert-base-uncased' tokenizer plus the special tokens (and their
    vocabulary ids) needed to lay out the source and target sequences.
    """

    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        # Source-side special tokens.
        self.sep_token = '[SEP]'
        self.cls_token = '[CLS]'
        self.pad_token = '[PAD]'
        # Target-side markers: begin, end, sentence separator.
        self.tgt_bos = '[unused0]'
        self.tgt_eos = '[unused1]'
        self.tgt_sent_split = '[unused2]'
        # Vocabulary ids of the source-side special tokens.
        self.sep_vid = self.tokenizer.vocab[self.sep_token]
        self.cls_vid = self.tokenizer.vocab[self.cls_token]
        self.pad_vid = self.tokenizer.vocab[self.pad_token]

    def preprocess(self, src, tgt, sent_labels, use_bert_basic_tokenizer=False, is_test=False):
        """Build the feature tuple for one (document, target) pair.

        Args:
            src: list of source sentences, each a list of token strings.
            tgt: list of target sentences, each a list of token strings.
            sent_labels: indices into ``src`` of the sentences labelled 1.
            use_bert_basic_tokenizer: forwarded to the tokenizer for the
                target text only.
            is_test: when true, the "too short" checks never drop the example.

        Returns:
            ``None`` if the example is dropped, otherwise the tuple
            ``(src_subtoken_idxs, sent_labels, tgt_subtoken_idxs,
            segments_ids, cls_ids, src_txt, tgt_txt)``.
        """
        if not is_test and len(src) == 0:
            return None

        original_src_txt = [' '.join(tokens) for tokens in src]

        # Positions of the sentences that are long enough to keep.
        idxs = [pos for pos, tokens in enumerate(src) if len(tokens) > min_src_ntokens_per_sent]

        # Expand the selected indices into a 0/1 flag per original sentence.
        label_flags = [0] * len(src)
        for chosen in sent_labels:
            label_flags[chosen] = 1

        # Truncate every kept sentence, then cap the number of sentences.
        src = [src[pos][:max_src_ntokens_per_sent] for pos in idxs][:max_src_nsents]
        sent_labels = [label_flags[pos] for pos in idxs][:max_src_nsents]

        if not is_test and len(src) < min_src_nsents:
            return None

        # Sentences are glued with "[SEP] [CLS]" so each one ends up wrapped
        # in its own [CLS] ... [SEP] pair once the outer pair is added.
        glue = ' {} {} '.format(self.sep_token, self.cls_token)
        text = glue.join([' '.join(tokens) for tokens in src])

        src_subtokens = self.tokenizer.tokenize(text)
        src_subtokens = [self.cls_token] + src_subtokens + [self.sep_token]
        src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens)

        # Segment ids alternate 0/1 per sentence; a sentence spans everything
        # up to and including its [SEP].
        segments_ids = []
        previous_sep = -1
        sentence_no = 0
        for pos, token_id in enumerate(src_subtoken_idxs):
            if token_id == self.sep_vid:
                segments_ids += (pos - previous_sep) * [sentence_no % 2]
                previous_sep = pos
                sentence_no += 1

        cls_ids = [pos for pos, token_id in enumerate(src_subtoken_idxs) if token_id == self.cls_vid]
        sent_labels = sent_labels[:len(cls_ids)]

        # Target: "[unused0] sent [unused2] sent ... [unused1]", then truncated.
        tgt_pieces = []
        for tt in tgt:
            pieces = self.tokenizer.tokenize(' '.join(tt), use_bert_basic_tokenizer=use_bert_basic_tokenizer)
            tgt_pieces.append(' '.join(pieces))
        tgt_subtokens_str = '[unused0] ' + ' [unused2] '.join(tgt_pieces) + ' [unused1]'
        tgt_subtoken = tgt_subtokens_str.split()[:max_tgt_ntokens]
        if not is_test and len(tgt_subtoken) < min_tgt_ntokens:
            return None

        tgt_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(tgt_subtoken)

        tgt_txt = '<q>'.join([' '.join(tt) for tt in tgt])
        # Raw text of every kept sentence (not capped by max_src_nsents).
        src_txt = [original_src_txt[pos] for pos in idxs]

        return src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt


In [10]:
def format_to_bert(args):
    """Convert every ``*<corpus>.*.json`` shard under ``args.raw_path`` to a ``.bert.pt`` file.

    For each of the 'train', 'test' and 'dev' corpora the matching JSON files
    are collected and handed to ``_format_to_bert`` through a worker pool of
    ``args.n_cpus`` processes.

    Args:
        args: object providing ``raw_path``, ``save_path`` and ``n_cpus``;
            it is also forwarded unchanged to ``_format_to_bert``.
    """
    datasets = ['train', 'test', 'dev']
    for corpus_type in datasets:
        a_lst = []
        for json_f in glob.glob(pjoin(args.raw_path, '*' + corpus_type + '.*.json')):
            # basename() instead of split('/') so Windows separators work too.
            real_name = os.path.basename(json_f)
            # The glob pattern guarantees a trailing "json"; swap only that
            # suffix so a "json" elsewhere in the file name is left intact.
            save_name = real_name[:-len('json')] + 'bert.pt'
            a_lst.append((corpus_type, json_f, args, pjoin(args.save_path, save_name)))
        pool = Pool(args.n_cpus)
        try:
            # Drain the iterator; the workers write their own output files.
            for _ in pool.imap(_format_to_bert, a_lst):
                pass
        finally:
            # Release the worker processes even if a task raised.
            pool.close()
            pool.join()

In [11]:
def _format_to_bert(params):
    """Convert one JSON shard into a list of BERT feature dicts and save it.

    Args:
        params: tuple ``(corpus_type, json_file, args, save_file)``. ``args``
            must provide ``max_src_nsents``, ``lower`` and
            ``use_bert_basic_tokenizer``.

    The function returns ``None``. If ``save_file`` already exists the shard
    is skipped; otherwise the resulting list is written with ``torch.save``.
    """
    corpus_type, json_file, args, save_file = params
    is_test = corpus_type == 'test'
    if os.path.exists(save_file):
        logger.info('Ignore %s' % save_file)
        return

    # BertData.__init__ takes no arguments in this notebook.
    bert = BertData()

    logger.info('Processing %s' % json_file)
    with open(json_file) as json_handle:
        jobs = json.load(json_handle)
    datasets = []
    for d in jobs:
        source, tgt = d['src'], d['tgt']

        # greedyselect expects the reference as one flat token list, while
        # tgt is handled below as a list of token lists, so flatten it here.
        flat_tgt = [token for sent in tgt for token in sent]
        sent_labels = greedyselect(source[:args.max_src_nsents], flat_tgt, 3)
        if args.lower:
            source = [' '.join(s).lower().split() for s in source]
            tgt = [' '.join(s).lower().split() for s in tgt]
        b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args.use_bert_basic_tokenizer,
                                 is_test=is_test)

        if b_data is None:
            continue
        src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt = b_data
        b_data_dict = {"src": src_subtoken_idxs, "tgt": tgt_subtoken_idxs,
                       "src_sent_labels": sent_labels, "segs": segments_ids, 'clss': cls_ids,
                       'src_txt': src_txt, "tgt_txt": tgt_txt}
        datasets.append(b_data_dict)
    logger.info('Processed instances %d' % len(datasets))
    logger.info('Saving to %s' % save_file)
    torch.save(datasets, save_file)
    # Drop the reference and collect so a long-lived worker frees the shard.
    datasets = []
    gc.collect()

In [12]:
def cal_rouge(evaluated_ngrams, reference_ngrams):
    """Score the overlap between two n-gram sets.

    Args:
        evaluated_ngrams: set of n-grams from the candidate text.
        reference_ngrams: set of n-grams from the reference text.

    Returns:
        Dict with keys ``"f"``, ``"p"`` and ``"r"`` (F1, precision, recall).
        Precision/recall are 0.0 when the corresponding set is empty; the
        F1 denominator carries a 1e-8 term so it never divides by zero.
    """
    n_reference = len(reference_ngrams)
    n_evaluated = len(evaluated_ngrams)
    n_shared = len(evaluated_ngrams.intersection(reference_ngrams))

    precision = n_shared / n_evaluated if n_evaluated != 0 else 0.0
    recall = n_shared / n_reference if n_reference != 0 else 0.0

    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}

In [13]:
def greedyselect(doc_sent_list, sub_sent, summary_size):
    """Greedily pick the document sentences that best match a reference.

    In each round the sentence whose addition gives the highest combined
    unigram + bigram F1 (via ``cal_rouge``) against ``sub_sent`` is added.
    Selection stops early when no remaining sentence improves the score.

    Args:
        doc_sent_list: list of sentences, each a list of token strings.
        sub_sent: reference text as a flat list of token strings.
        summary_size: maximum number of sentences to select.

    Returns:
        Sorted list of selected sentence indices (possibly empty). The list
        is sorted on the early-exit path as well, so callers always receive
        indices in document order.
    """
    def _rouge_clean(s):
        # Keep only ASCII letters, digits and spaces.
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    subject = _rouge_clean(' '.join(sub_sent)).lower().split()
    sents = [_rouge_clean(' '.join(s)).lower().split() for s in doc_sent_list]
    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [subject])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [subject])

    selected = []
    for _ in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):
            if i in selected:
                continue
            # Score the already-selected sentences plus candidate i together.
            c = selected + [i]
            candidates_1 = [evaluated_1grams[idx] for idx in c]
            candidates_1 = set.union(*map(set, candidates_1))
            candidates_2 = [evaluated_2grams[idx] for idx in c]
            candidates_2 = set.union(*map(set, candidates_2))

            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2

            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if cur_id == -1:
            # No candidate improved the score: stop, same ordering as below.
            return sorted(selected)
        selected.append(cur_id)
        max_rouge = cur_max_rouge
    return sorted(selected)

In [16]:
# Convert the tokenized AESLC dev/test/train text files into BERT feature
# files. Each line of the "doc" file is one document and the same line of the
# "sub" file is its subject; sentences are separated by "<sentsep>" and
# tokens by "<tokesep>".
for corpus_type in ["dev","test","train"]:
    # Glob-style name covering both token files; informational only, it is
    # not read below.
    toke_file = "{}AESLC_{}_token_{}.txt".format(toke_path, corpus_type,"*")
    save_file = "{}AESLC.bert.{}.pt".format(save_bath, corpus_type)

    is_test = corpus_type == 'test'

    bert = BertData()
    datasets = []

    # The context manager closes both files exactly once, also on errors.
    with open("{}AESLC_{}_token_{}.txt".format(toke_path, corpus_type, "doc")) as reader_doc, \
            open("{}AESLC_{}_token_{}.txt".format(toke_path, corpus_type, "sub")) as reader_sub:
        while True:
            doc_token = reader_doc.readline().replace("\n","")
            sub_token = reader_sub.readline().replace("\n","")
            # An empty string means end of file, but a blank document line
            # also ends the loop here.
            if not doc_token:
                break

            doc = [tokens.split("<tokesep>") for tokens in doc_token.split("<sentsep>")]
            # The subject is treated as one sentence: flatten its sentences.
            sub = sub_token.replace("<sentsep>","<tokesep>").split("<tokesep>")

            sent_labels = greedyselect(doc[:max_src_nsents], sub, 1)
            if (lower):
                doc = [' '.join(s).lower().split() for s in doc]
                sub = ' '.join(sub).lower().split()
            # preprocess expects the target as a list of sentences. Wrap the
            # token list regardless of `lower`; otherwise each word would be
            # iterated character by character.
            sub = [sub]

            b_data = bert.preprocess(doc, sub, sent_labels, use_bert_basic_tokenizer=use_bert_basic_tokenizer, is_test=is_test)

            if (b_data is None):
                continue
            src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt = b_data
            b_data_dict = {"src": src_subtoken_idxs, "tgt": tgt_subtoken_idxs,
                           "src_sent_labels": sent_labels, "segs": segments_ids, 'clss': cls_ids,
                           'src_txt': src_txt, "tgt_txt": tgt_txt}
            datasets.append(b_data_dict)

    logger.info('Processed instances %d' % len(datasets))
    logger.info('Saving to %s' % save_file)
    torch.save(datasets, save_file)
    # Free the corpus before moving on to the next one.
    datasets = []
    gc.collect()

In [8]:
# NOTE(review): this cell rebinds both `BertTokenizer` (previously imported
# from others.tokenization) and the module-level `tokenizer`; cells run after
# it see the pytorch_transformers versions.
from pytorch_transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir='../temp')
# Vocabulary ids of the target-side markers and the padding token.
symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
           'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}

100%|███████████████████████████████████████████████████████████████████████| 231508/231508 [00:00<00:00, 332968.00B/s]


In [17]:
def decode(line):
    """Print the vocabulary tokens for a sequence of token ids.

    Each element of ``line`` is converted with ``int()`` and looked up in the
    module-level ``tokenizer.ids_to_tokens``; the resulting list is printed.
    """
    tokens = []
    for token_id in line:
        tokens.append(tokenizer.ids_to_tokens[int(token_id)])
    print(tokens)

In [19]:
# Inspect the test split written by the conversion loop. The path is built
# from save_bath with the same pattern the writer uses, so the two cannot
# drift apart through a typo.
# NOTE: torch.load unpickles the file; only load files you produced yourself.
test = torch.load("{}AESLC.bert.{}.pt".format(save_bath, "test"))
decode(test[0]['tgt'])
print(len(test))

FileNotFoundError: [Errno 2] No such file or directory: 'randim_tokened/bert/AESLC.bert.test.pt'

In [18]:
# Load a previously generated test split from another directory, print the
# decoded target tokens of its first example and the number of examples.
# NOTE(review): the recorded output shows the target split into single
# characters; presumably that file was built with the subject passed to
# preprocess as a flat list of strings rather than a list of sentences —
# confirm before using it as a reference.
test = torch.load("../bertAeslc/bert_data/bert.test.pt")
decode(test[0]['tgt'])
print(len(test))

['[unused0]', 'h', 'u', 'n', 't', 'l', 'e', 'y', '[unused2]', '/', '[unused2]', 'q', 'u', 'e', 's', 't', 'i', 'o', 'n', '[unused1]']
1906
