# BERT Key Phrase Extractor

In this notebook we aim to realize the Bottom-Up Summarization Paper's extractor with BERT as the contextual embedding and see if we are able to extract phrases that maximizes the ROGUE scores. Our first goal in this project is to generate non-sensical summaries that maximizes the ROGUE score. Then, we aim to train an additional language model-like network to generate abstractive summaries. 

In [2]:
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

import os
import subprocess
import json
import pickle
from multiprocessing import Pool

import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

from sklearn.model_selection import train_test_split

import re

from rouge import Rouge 

In [3]:
CNN_STORY_DIR = os.path.join('data', 'cnn', 'stories')
DM_STORY_DIR = os.path.join('data', 'dailymail', 'stories')

CNN_STORY_TOKENIZED = os.path.join('data', 'cnn', 'stories-tokenized')
DM_STORY_TOKENIZED = os.path.join('data', 'dailymail', 'stories-tokenized')

SRC_JSON = os.path.join('data', 'src.pk')
TGT_JSON = os.path.join('data', 'tgt.pk')

In [4]:
DOCS = os.path.join('data', 'docs.pk')
TAGS = os.path.join('data', 'tags.pk')
GOLD_SUMS = os.path.join('data', 'gold_sums.pk')
IDS = os.path.join('data', 'idx.pk')

## Preprocessing

We will first read in the files and process them into tokenized sentences and words, and separate out the source document and the abstract. Here, we heavily borrowed code from Pointer Generator code

In [None]:
dirs = [d for d in os.listdir(CNN_STORY_TOKENIZED)]

In [None]:
dm_single_close_quote = u'\u2019' # unicode
dm_double_close_quote = u'\u201d'
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
def process_json(filename):
    src, tgt = [], [] # a document is a list of list of words
    highlight = False # highlights are always at the end of the document 
    f = open(filename, 'r')
    parsed = json.load(f)
    for sent in parsed['sentences']:
        words = [word['word'] for word in sent['tokens']]
        if words[-1] not in END_TOKENS:
            words += ['.']
        if words[0] == '@highlight':
            highlight = True
        elif highlight:
            tgt += [words]
        else:
            src += [words]
    return src, tgt

src, tgt = process_json(os.path.join(CNN_STORY_TOKENIZED, dirs[0]))

In [None]:
def percentage_in_src_vocab(src, tgt):
    src_vocab = set()
    for sent in src:
        src_vocab |= set(sent)
    count = 0
    total_len = 0
    for sent in tgt:
        for word in sent:
            if word in src_vocab:
                count += 1
            total_len += 1
    return count / total_len

In [None]:
def process_all_json(file_dir):
    pool = Pool(processes=10)
    srcs, tgts = [], []
    percentages = []
    file_paths = [os.path.join(file_dir, file_name) for file_name in os.listdir(file_dir)]
    for tup in pool.imap_unordered(process_json, file_paths):
        src, tgt = tup
        srcs.append(src)
        tgts.append(tgt)
        percentages.append(percentage_in_src_vocab(src, tgt))
    print(np.mean(percentages))
    return srcs, tgts

In [None]:
srcs_cnn, tgts_cnn = process_all_json(CNN_STORY_TOKENIZED)

In [None]:
srcs_dm, tgts_dm = process_all_json(DM_STORY_TOKENIZED)

In [None]:
src, tgt = srcs_cnn + srcs_dm, tgts_cnn + tgts_dm

In [None]:
f = open(SRC_JSON, 'wb')
pickle.dump(src, f)
f.close()

f = open(TGT_JSON, 'wb')
pickle.dump(tgt, f)
f.close()

In [None]:
f = open(SRC_JSON, 'rb')
src = pickle.load(f)
f.close()

f = open(TGT_JSON, 'rb')
tgt = pickle.load(f)
f.close()

## Preprocess to BERT

To use BERT, we must format our data into one that BERT is able to use. We also have to redefine the problem as a sequence tagging problem presented in the Bottom-Up paper.

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=510)

In [None]:
def range_dist(tup1, tup2):
    """
    This function calculates the distance between 2 ranges.
    """
    start1, end1 = min(tup1), max(tup1)
    start2, end2 = min(tup2), max(tup2)
    if start1 < start2 < end1 or start1 < end2 < end1: # overlap
        return 0
    return min(abs(end1 - start2), abs(end2 - start1))

In [None]:
def tag(doc, tgt):
    """
    doc: a list of src tokens
    tgt: a list of tgt tokens that we will look for in the doc
    
    returns:
    decode_label: a list of size tgt (or less) denoting the positions of the tokens at each summ step
    """
    if len(tgt) == 0:
        print('zero sized tgt')
        return None

    decode_label = []
    l, r, last_range = 0, 0, (0, 0) # last step is the index into the src where we chose last
    while r < len(tgt):
        old_idxs = []
        idxs = [(i,i+1) for i, token in enumerate(doc) if token == tgt[r]]
        while len(idxs) > 0: # found a match
            r += 1
            old_idxs, idxs = idxs, []
            for start, end in old_idxs:
                if end < len(doc) and r < len(tgt) and doc[end] == tgt[r]:
                    idxs.append((start, end + 1))
        idx_to_look = old_idxs if len(idxs) == 0 else idxs
        if len(idx_to_look) == 0:
            r += 1
        else:
            best_i = min(range(len(idx_to_look)), key=lambda i: range_dist(last_range, idx_to_look[i]))
            last_range = idx_to_look[best_i]
            decode_label.extend(list(range(last_range[0], last_range[1])))
    return decode_label

In [35]:
def remove_bert_tokens(sent):
    return re.sub(r'( ##)|(\[CLS\] )|(\s*\[SEP\])','', sent)

In [None]:
def process_src_tgt(srcs, tgts, start_idx=0, end_idx=-1):
    assert len(srcs) == len(tgts)
    docs, tags, gold_sums_bert, ranges = [], [], [], []
    rn = range(len(srcs)) if end_idx == -1 else range(start_idx, end_idx)
    for i in rn:
        ## process src
        doc = '[CLS] ' + ' '.join(' '.join(sent) + ' [SEP]' for sent in srcs[i])
        doc = tokenizer.tokenize(doc)[:510]

        ## process tgt
        tgt = '[CLS] ' + ' '.join(' '.join(sent) + ' [SEP]' for sent in tgts[i])
#         tgt = ' '.join([' '.join(sent) for sent in tgts[i]]) ## Tried adding SEP and CLS, doesn't really work..
        tgt = tokenizer.tokenize(tgt)[:110]
        label = tag(doc, tgt)

        ## Add both to list
        docs.append(tokenizer.convert_tokens_to_ids(doc))
        tags.append(label)
        gold_sums_bert.append(remove_bert_tokens(' '.join(tgt)))
        ranges.append(i)
    return docs, tags, gold_sums_bert, ranges
docs, tags, gold_sums_bert, ranges = process_src_tgt(src, tgt, 9, 10)

In [None]:
gold_sums_bert[0]

In [None]:
' '.join(tokenizer.convert_ids_to_tokens([docs[0][i] for i in tags[0]]))

In [None]:
def process_ranges(args):
    return process_src_tgt(src, tgt, args[0], args[1])

In [None]:
n = 35
pool = Pool(n)
k = len(src)//n
result = pool.map(process_ranges, [(start * k, (start+1) * k) for start in range(n)])

In [None]:
def check_strictly_increasing(nested_sequence):
    counter = 0
    for sequence in nested_sequence:
        for i in sequence:
            if i != counter:
                return False
            counter += 1
    return True
check_strictly_increasing([tup[-1] for tup in result])

In [None]:
src, tgt = None, None

In [None]:
def clean(lst, valid_ids):
    return [lst[i] for i in valid_ids]

In [None]:
docs, tags, gold_sums_bert, ids = [], [], [], []
for a, b, c, d in result:
    valid_ids = [i for i in range(len(a)) if len(b[i]) > 0 and len(c[i]) > 0]
    docs.extend(clean(a, valid_ids))
    tags.extend(clean(b, valid_ids))
    gold_sums_bert.extend(clean(c, valid_ids))
    ids.extend(clean(d, valid_ids))

In [24]:
gold_sums_bert[2039]

'rafael nadal beats leonardo mayer in straight sets . andy murray locked in five set struggle when play halted . gael monfils wins epic five - setter against fabio fognini . sloane stephens to face simona halep in last 16 .'

In [25]:
' '.join(tokenizer.convert_ids_to_tokens([docs[2039][i] for i in tags[2039]]))

'[CLS] rafael nad ##al leonardo mayer in straight sets . [SEP] andy murray locked in five set when play halted . [SEP] gael mon ##fi ##ls five - set against fabio fog ##nin ##i . [SEP] sloane stephens to face simon ##a hale ##p in last . [SEP]'

In [28]:
for obj, fname in zip([docs, tags, gold_sums_bert, ids], 
                      [DOCS, TAGS, GOLD_SUMS,      IDS]):
    with open(fname, 'wb') as f:
        print("saving to %s" % fname)
        pickle.dump(obj, f)

saving to data/docs.pk
saving to data/tags.pk
saving to data/gold_sums.pk
saving to data/idx.pk


## Gold Rouge

The oracle rouge score can be calculated here

In [38]:
tagged_sums = [remove_bert_tokens(' '.join(tokenizer.convert_ids_to_tokens([docs[i][j] for j in tags[i]]))) \
               for i in range(len(docs))]

In [41]:
rouge = Rouge()
scores = rouge.get_scores(tagged_sums, gold_sums_bert, avg=True)
print(scores)

{'rouge-1': {'f': 0.881164808423798, 'p': 0.9928848305179249, 'r': 0.8004520423694316}, 'rouge-2': {'f': 0.7637302213937581, 'p': 0.8396110056078727, 'r': 0.7056012489332818}, 'rouge-l': {'f': 0.8552061121823076, 'p': 0.9928827635932947, 'r': 0.8004505631941881}}


{'rouge-1': {'f': 0.881164808423798, 'p': 0.9928848305179249, 'r': 0.8004520423694316}, 'rouge-2': {'f': 0.7637302213937581, 'p': 0.8396110056078727, 'r': 0.7056012489332818}, 'rouge-l': {'f': 0.8552061121823076, 'p': 0.9928827635932947, 'r': 0.8004505631941881}}

## Bert Model

We have calculated the "oracle" score above, and now we would like to fit a model that accurately predicts the tags defined above.

Later, we might change how the tags are defined and see if we can achieve better results than "first occurance tagging"

We will split 90/5/5 with a 5k tiny dataset selected from the train set for faster development

In [5]:
docs, tags, gold_sums_bert, ids = \
    [pickle.load(open(file_path, 'rb')) for file_path in [DOCS, TAGS, GOLD_SUMS, IDS]]

In [7]:
X_train, X_dev_test, y_train, y_dev_test, gold_train, gold_dev_test, ids_train, ids_dev_test = \
        train_test_split(docs, tags, gold_sums_bert, ids, test_size=0.1)
X_dev, X_test, y_dev, y_test, gold_dev, gold_test, ids_dev, ids_test = \
        train_test_split(X_dev_test, y_dev_test, gold_dev_test, ids_dev_test, test_size=0.5)
X_tiny, y_tiny, gold_tiny, ids_tiny = \
        X_train[:500], y_train[:500], gold_train[:500], ids_train[:500]
data = dict()
data['train'] = {'X': X_train, 'y': y_train, 'gold': gold_train, 'ids':ids_train}
data['dev'] = {'X': X_dev, 'y': y_dev, 'gold': gold_dev, 'ids':ids_dev}
data['test'] = {'X': X_test, 'y': y_test, 'gold': gold_test, 'ids':ids_test}
data['tiny'] = {'X': X_tiny, 'y': y_tiny, 'gold': gold_tiny, 'ids':ids_tiny}

In [8]:
PROCESSED_DATA = os.path.join('data', 'data.pk')

In [9]:
with open(PROCESSED_DATA, 'wb') as f:
    pickle.dump(data, f)

In [None]:
with open(PROCESSED_DATA, 'rb') as f:
    data = pickle.load(f)

In [12]:
X_tiny, y_tags_tiny, gold_tiny, ids_tiny = \
    data['tiny']['X'], data['tiny']['y'], data['tiny']['gold'], data['tiny']['ids'],
super_tiny = {'tiny':{'X':X_tiny[:10], 'y':y_tags_tiny[:10], 'gold':gold_tiny[:10],
        'ids':ids_tiny[:10]}}

In [13]:
SUPER_TINY = os.path.join('data', 'super_tiny.pk')
with open(SUPER_TINY, 'wb') as f:
    pickle.dump(super_tiny, f)

## GloVe (NOT USED)

Turns out BERT is too heavy weight and instead we would try to use GloVe + LSTM instead. We will first process the glove embeddings

In [None]:
# def tag(doc, tgt):
#     """
#     doc: a list of src tokens
#     tgt: a list of tgt tokens that we will look for in the doc
#     """
#     if len(tgt) == 0:
#         print('zero sized tgt')
#         return None
#     vocab = set(tgt)
#     doc = np.array(doc)
#     tgt = np.array(tgt)

#     label = np.zeros(len(doc), dtype=bool)
    
#     ## The following tags all tokens present in both the source and target
# #     for i in range(len(doc)):
# #         if doc[i] in vocab:
# #             label[i] = 1
#     ## The following does the max tagging thingy the original paper did
#     l, r = 0, 0
#     while r < len(tgt):
#         old_idxs = []
#         idxs = [(i,i+1) for i, token in enumerate(doc) if token == tgt[r]]
#         while len(idxs) > 0 and r + 1 < len(tgt):
#             r += 1
#             old_idxs, idxs = idxs, []
#             for idx in old_idxs:
#                 if idx[-1] < len(doc) and doc[idx[-1]] == tgt[r]:
#                     idxs.append((idx[0], idx[-1] + 1))
#         if len(idxs) > 0: ## we ran out of tgt
#             label[idxs[0][0]:idxs[0][-1]] = 1
#             break
#         elif len(old_idxs) > 0: ## we found longest seq
#             label[old_idxs[0][0]:old_idxs[0][-1]] = 1
#         else: ## this token does not exist
#             r += 1
#     idxs = []
    
#     return label, idxs

In [None]:
src_vocab = {word for doc in src for sent in doc for word in sent}

In [None]:
len(src_vocab)

In [None]:
DATA_BASE = 'data-decode'
GLOVE_HOME = os.path.join(DATA_BASE, 'glove.840B.300d.txt')
def glove2dict(src_filename, model_vocab):
    """GloVe Reader.
    Parameters
    ----------
    src_filename : str
        Full path to the GloVe file to be processed.
    Returns
    -------
    dict
        Mapping words to their GloVe vectors.
    """
    data = {}
    original_len, src_only_len = 0, 0
    with open(src_filename, 'r', newline="") as f:
        while True:
            try:
                line = next(f)
                line = line.strip().split()
                line[0] = line[0].lower()
                if line[0] in src_vocab:
                    data[line[0]] = np.array(line[1: ], dtype=np.float)
                    src_only_len += 1
                original_len += 1
            except StopIteration:
                break
            except UnicodeDecodeError:
                pass
            except:
                pass
    return data
glove = glove2dict(GLOVE_HOME, src_vocab)

In [None]:
len(glove)

In [None]:
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
def build_vocab(srcs):
    vocab = {SENTENCE_START, SENTENCE_END}
    for src in srcs:
        # src is a list of list of words
        for sent in src:
            vocab |= set([word.lower() for word in sent])
    return {word:i for i, word in enumerate(vocab)}
vocab = build_vocab(src)

In [None]:
len(vocab)

In [None]:
len(glove)

In [None]:
glove_vocab = {word:i for i, word in enumerate(glove.keys())}

In [None]:
len(vocab) - len(glove_vocab)

In [None]:
not_in_vocab = set()
for t in tgt:
    for sent in t:
        for word in sent:
            word = word.lower()
            if word not in vocab:
                not_in_vocab.add(word)

In [None]:
reverse_vocab = {i:word for word, i in vocab.items()}

In [None]:
' '.join([reverse_vocab[i] for i in gold_sum_idxs[0]])