# BERT Key Phrase Extractor

In this notebook we aim to realize the Bottom-Up Summarization Paper's extractor with BERT as the contextual embedding and see if we are able to extract phrases that maximizes the ROGUE scores. Our first goal in this project is to generate non-sensical summaries that maximizes the ROGUE score. Then, we aim to train an additional language model-like network to generate abstractive summaries. 

In [1]:
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

import os
import subprocess
import json
import pickle
from multiprocessing import Pool

import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

from sklearn.model_selection import train_test_split

from rouge import Rouge 

In [4]:
CNN_STORY_DIR = os.path.join('data', 'cnn', 'stories')
DM_STORY_DIR = os.path.join('data', 'dailymail', 'stories')

CNN_STORY_TOKENIZED = os.path.join('data', 'cnn', 'stories-tokenized')
DM_STORY_TOKENIZED = os.path.join('data', 'dailymail', 'stories-tokenized')

SRC_JSON = os.path.join('data', 'src.pk')
TGT_JSON = os.path.join('data', 'tgt.pk')

In [5]:
DOCS = os.path.join('data', 'docs.pk')
TAGS = os.path.join('data', 'tags.pk')
TAGGED_SUMS = os.path.join('data', 'tagged_sums.pk')
GOLD_SUMS = os.path.join('data', 'gold_sums.pk')
IDX_TAGS = os.path.join('data', 'idx_tags.pk')
IDS = os.path.join('data', 'idx.pk')

## Preprocessing

We will first read in the files and process them into tokenized sentences and words, and separate out the source document and the abstract. Here, we heavily borrowed code from Pointer Generator code

In [None]:
dirs = [d for d in os.listdir(CNN_STORY_TOKENIZED)]

In [None]:
dm_single_close_quote = u'\u2019' # unicode
dm_double_close_quote = u'\u201d'
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
def process_json(filename):
    src, tgt = [], [] # a document is a list of list of words
    highlight = False # highlights are always at the end of the document 
    f = open(filename, 'r')
    parsed = json.load(f)
    for sent in parsed['sentences']:
        words = [word['word'] for word in sent['tokens']]
        if words[-1] not in END_TOKENS:
            words += ['.']
        if words[0] == '@highlight':
            highlight = True
        elif highlight:
            tgt += [words]
        else:
            src += [words]
    return src, tgt

src, tgt = process_json(os.path.join(CNN_STORY_TOKENIZED, dirs[0]))

In [None]:
def percentage_in_src_vocab(src, tgt):
    src_vocab = set()
    for sent in src:
        src_vocab |= set(sent)
    count = 0
    total_len = 0
    for sent in tgt:
        for word in sent:
            if word in src_vocab:
                count += 1
            total_len += 1
    return count / total_len

In [None]:
def process_all_json(file_dir):
    pool = Pool(processes=10)
    srcs, tgts = [], []
    percentages = []
    file_paths = [os.path.join(file_dir, file_name) for file_name in os.listdir(file_dir)]
    for tup in pool.imap_unordered(process_json, file_paths):
        src, tgt = tup
        srcs.append(src)
        tgts.append(tgt)
        percentages.append(percentage_in_src_vocab(src, tgt))
    print(np.mean(percentages))
    return srcs, tgts

In [None]:
srcs_cnn, tgts_cnn = process_all_json(CNN_STORY_TOKENIZED)

In [None]:
srcs_dm, tgts_dm = process_all_json(DM_STORY_TOKENIZED)

In [None]:
src, tgt = srcs_cnn + srcs_dm, tgts_cnn + tgts_dm

In [None]:
f = open(SRC_JSON, 'wb')
pickle.dump(src, f)
f.close()

f = open(TGT_JSON, 'wb')
pickle.dump(tgt, f)
f.close()

In [4]:
f = open(SRC_JSON, 'rb')
src = pickle.load(f)
f.close()

f = open(TGT_JSON, 'rb')
tgt = pickle.load(f)
f.close()

## Preprocess to BERT

To use BERT, we must format our data into one that BERT is able to use. We also have to redefine the problem as a sequence tagging problem presented in the Bottom-Up paper.

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=510)

In [6]:
def tag(doc, tgt):
    """
    doc: a list of src tokens
    tgt: a list of tgt tokens that we will look for in the doc
    """
    if len(tgt) == 0:
        print('zero sized tgt')
        return None
    vocab = set(tgt)
    doc = np.array(doc)
    tgt = np.array(tgt)

    label = np.zeros(len(doc), dtype=bool)
    ## The following tags all tokens present in both the source and target
#     for i in range(len(doc)):
#         if doc[i] in vocab:
#             label[i] = 1
    ## The following does the max tagging thingy the original paper did
    l, r = 0, 0
    while r < len(tgt):
        old_idxs = []
        idxs = [(i,i+1) for i, token in enumerate(doc) if token == tgt[r]]
        while len(idxs) > 0 and r + 1 < len(tgt):
            r += 1
            old_idxs, idxs = idxs, []
            for idx in old_idxs:
                if idx[-1] < len(doc) and doc[idx[-1]] == tgt[r]:
                    idxs.append((idx[0], idx[-1] + 1))
        if len(idxs) > 0: ## we ran out of tgt
            label[idxs[0][0]:idxs[0][-1]] = 1
            break
        elif len(old_idxs) > 0: ## we found longest seq
            label[old_idxs[0][0]:old_idxs[0][-1]] = 1
        else: ## this token does not exist
            r += 1
    idxs = []
    for i in range(len(tgt)):
        idxs.append(list(np.argwhere(doc == tgt[i]).flatten()))
    return label, idxs

In [7]:
def process_src_tgt(srcs, tgts, start_idx=0, end_idx=-1):
    assert len(srcs) == len(tgts)
    docs, tags = [], []
    tagged_sum, gold_sum_bert, gold_sum_idxs = [], [], []
    ranges = []
    rn = range(len(srcs)) if end_idx == -1 else range(start_idx, end_idx)
    for i in rn:
        ## process src
        sents = [' '.join(sent) + ' [SEP]' for sent in srcs[i]]
        doc = ' '.join(['[CLS]'] + sents)
        doc = tokenizer.tokenize(doc)[:510]

        ## process tgt
        tgt = ' '.join([' '.join(sent) for sent in tgts[i]])
        tgt = tokenizer.tokenize(tgt)[:110]
        label, idxs = tag(doc, tgt)
        
        ## generate tagged_summary for oracle rouge
        tagged = []
        for idx in idxs:
            doc = np.array(doc)
            if len(doc[idx]) > 0:
                tagged.append(doc[idx][0])

        ## Add both to list
        docs.append(tokenizer.convert_tokens_to_ids(doc))
        tags.append(label)
        tagged_sum.append((' '.join(tagged)).replace(' ##', ''))
        gold_sum_bert.append(' '.join(tgt).replace(' ##', ''))
        gold_sum_idxs.append(idxs)
        ranges.append(i)
    return docs, tags, tagged_sum, gold_sum_bert, np.array(gold_sum_idxs), ranges
docs, tags, tagged_sum, gold_sum_bert, gold_sum_idxs, ranges = process_src_tgt(src, tgt, 9, 10)

In [8]:
gold_sum_bert

['sami al - hajj arrives home in sudan and is taken to hospital , network says . pakistani intelligence officers captured him in afghanistan in december 2001 . al - hajj was transferred to u . s . custody and held without charges or trial . al - jazeera said he was on an assignment when he was apprehended .']

In [9]:
tagged_sum

['sami al - hajj home in sudan and is taken to hospital , network . pakistani intelligence officers captured him in afghanistan in december 2001 . al - hajj was to u . s . and held without or trial . al - jazeera said he was on an assignment he was .']

In [10]:
def process_ranges(args):
    return process_src_tgt(src, tgt, args[0], args[1])

In [11]:
n = 35
pool = Pool(n)
k = len(src)//n
result = pool.map(process_ranges, [(start * k, (start+1) * k) for start in range(n)])

In [12]:
def check_strictly_increasing(nested_sequence):
    counter = 0
    for sequence in nested_sequence:
        for i in sequence:
            if i != counter:
                return False
            counter += 1
    return True
check_strictly_increasing([tup[-1] for tup in result])

True

In [13]:
src, tgt = None, None

In [14]:
def clean(lst, valid_ids):
    return [lst[i] for i in valid_ids]

In [15]:
docs, tags, tagged_sums, gold_sums_bert, gold_sums_idxs, ids = [], [], [], [], [], []
for a, b, c, d, e, f in result:
    valid_ids = [i for i in range(len(a)) if len(c[i]) > 0 and len(d[i]) > 0]
    docs.extend(clean(a, valid_ids))
    tags.extend(clean(b, valid_ids))
    tagged_sums.extend(clean(c, valid_ids))
    gold_sums_bert.extend(clean(d, valid_ids))
    gold_sums_idxs.extend(clean(e, valid_ids))
    ids.extend(clean(f, valid_ids))

In [16]:
for obj, fname in zip([docs, tags, tagged_sums, gold_sums_bert, gold_sums_idxs, ids], 
                      [DOCS, TAGS, TAGGED_SUMS, GOLD_SUMS, IDX_TAGS, IDS]):
    with open(fname, 'wb') as f:
        pickle.dump(obj, f)

In [17]:
rouge = Rouge()
scores = rouge.get_scores(tagged_sums, gold_sums_bert, avg=True)
print(scores)

{'rouge-1': {'f': 0.8811618935491741, 'p': 0.9932318864311466, 'r': 0.8002525427630327}, 'rouge-2': {'f': 0.7636215390100871, 'p': 0.8396691890160917, 'r': 0.7053841629669394}, 'rouge-l': {'f': 0.8550852296513152, 'p': 0.993229688715533, 'r': 0.8002509720851069}}


## Bert Model

We have calculated the "oracle" score above, and now we would like to fit a model that accurately predicts the tags defined above.

Later, we might change how the tags are defined and see if we can achieve better results than "first occurance tagging"

We will split 90/5/5 with a 5k tiny dataset selected from the train set for faster development

In [None]:
docs, tags, tagged_sums, gold_sums_bert, gold_sums_idxs, ids = \
    [pickle.load(open(file_path, 'rb')) for file_path in [DOCS, TAGS, TAGGED_SUMS, GOLD_SUMS, IDX_TAGS, IDS]]

In [None]:
scores = {'rouge-1': {'f': 0.8811621861531733, 'p': 0.9932318101053599, 'r': 0.8002531077407594}, 'rouge-2': {'f': 0.763622203917866, 'p': 0.839669677462507, 'r': 0.7053849661328021}, 'rouge-l': {'f': 0.8550856135961422, 'p': 0.9932296123897465, 'r': 0.8002515370628336}}

In [19]:
X_train, X_dev_test, y_tags_train, y_tags_dev_test, y_decode_train, y_decode_dev_test, ids_train, ids_dev_test = \
        train_test_split(docs, tags, gold_sums_idxs, ids, test_size=0.1)
X_dev, X_test, y_tags_dev, y_tags_test, y_decode_dev, y_decode_test, ids_dev, ids_test =\
        train_test_split(X_dev_test, y_tags_dev_test, y_decode_dev_test, ids_dev_test, test_size=0.5)
X_tiny, y_tags_tiny, y_decode_tiny, ids_tiny = \
        X_train[:5000], y_tags_train[:5000], y_decode_train[:5000], ids_train[:5000]
processed = dict()
processed['train'] = {'X':X_train, 'y_tag':y_tags_train, 'y_decode':y_decode_train,
        'ids':ids_train}
processed['dev'] = {'X':X_dev, 'y_tag':y_tags_dev, 'y_decode':y_decode_dev,
        'ids':ids_dev}
processed['test'] = {'X':X_test, 'y_tag':y_tags_test, 'y_decode':y_decode_test,
        'ids':ids_test}
processed['tiny'] = {'X':X_tiny, 'y_tag':y_tags_tiny, 'y_decode':y_decode_tiny,
        'ids':ids_tiny}

In [3]:
PROCESSED_DATA = os.path.join('data', 'data.pk')

In [20]:
with open(PROCESSED_DATA, 'wb') as f:
    pickle.dump(processed, f)

In [4]:
with open(PROCESSED_DATA, 'rb') as f:
    data = pickle.load(f)

In [5]:
X_tiny, y_tags_tiny, y_decode_tiny, ids_tiny = \
    data['tiny']['X'], data['tiny']['y_tag'], data['tiny']['y_decode'], data['tiny']['ids'],
super_tiny = {'tiny':{'X':X_tiny[:10], 'y_tag':y_tags_tiny[:10], 'y_decode':y_decode_tiny[:10],
        'ids':ids_tiny[:10]}}

In [6]:
SUPER_TINY = os.path.join('data', 'super_tiny.pk')
with open(SUPER_TINY, 'wb') as f:
    pickle.dump(super_tiny, f)

In [8]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=510)

In [12]:
' '.join(tokenizer.convert_ids_to_tokens(super_tiny['tiny']['X'][0])).replace(' ##', '')

"[CLS] a grieving dental nurse who piled on the pounds grieving the loss of her best friend has shed more than four stone to turn herself into a muscle - bound beauty queen . [SEP] sarah jayne maher , 26 , from denton , greater manchester , ballooned to 13 stone 12 lbs and suffered severe depression after the loss of her best friend simone hill in a car crash in january 2011 . [SEP] but after seeing a picture of herself looking overweight on a night out with friends , the 5ft 3in blonde shed the pounds and bulked up into a body - building beauty queen . [SEP] sarah jane maher piled on the pounds after the death of her best friend , but has since gone on to lose 4st and now competes as a body building beauty queen - lrb - right - rrb - . [SEP] and she is set to compete as miss manchester in the miss galaxy uk pageant on february 8 . [SEP] sarah was left severely depressed after the death of simone , who was tragically killed in a car crash in reddish , greater manchester , while travell

In [23]:
from rouge import Rouge

In [29]:
rouge = Rouge()
rouge.get_scores(['.', 'world dafsd'], [' adfasdf', 'world fdaskf'], avg=True)

ValueError: Collections must contain at least 1 sentence.

In [16]:
tagged = []
for idx in super_tiny['tiny']['y_decode'][0]:
    doc = np.array(super_tiny['tiny']['X'][0])
    if len(doc[idx]) > 0:
        tagged.append(doc[idx][0])
' '.join(tokenizer.convert_ids_to_tokens(tagged)).replace(' ##', '')

'sarah jane maher was after her friend died in a car crash . the 26 - - a , on for . after meeting body , sarah was determined to lose weight . she has now four stone and competes in beauty pageants .'

In [17]:
doc = np.array(super_tiny['tiny']['X'][0])
tagged = doc[np.array(super_tiny['tiny']['y_tag'][0])]
' '.join(tokenizer.convert_ids_to_tokens(tagged)).replace(' ##', '')

'a on the her friend four stone - beauty . , 26 and after in a car crash body sarah jane maher now competes pageant was for meeting weight sarah was determined to lose she has dieds'

## Generative model (Unused)

To train a generative model, we need to process the Glove embeddings.

NO MORE GLOVE!!

In [None]:
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
def build_vocab(srcs):
    vocab = {SENTENCE_START, SENTENCE_END}
    for src in srcs:
        # src is a list of list of words
        for sent in src:
            vocab |= set([word.lower() for word in sent])
    return {word:i for i, word in enumerate(vocab)}
vocab = build_vocab(src)

In [None]:
len(vocab)

In [None]:
GLOVE_HOME = os.path.join('data', 'glove.840B.300d.txt')
def glove2dict(src_filename, model_vocab):
    """GloVe Reader.
    Parameters
    ----------
    src_filename : str
        Full path to the GloVe file to be processed.
    Returns
    -------
    dict
        Mapping words to their GloVe vectors.
    """
    data = {}
    with open(src_filename, 'r', newline="") as f:
        while True:
            try:
                line = next(f)
                line = line.strip().split()
                line[0] = line[0].lower()
                if line[0] in model_vocab:
                    data[line[0]] = np.array(line[1: ], dtype=np.float)
            except StopIteration:
                break
            except UnicodeDecodeError:
                pass
            except:
                pass
    return data
glove = glove2dict(GLOVE_HOME, vocab)

In [None]:
len(glove)

In [None]:
glove_vocab = {word:i for i, word in enumerate(glove.keys())}

In [None]:
len(vocab) - len(glove_vocab)

In [None]:
not_in_vocab = set()
for t in tgt:
    for sent in t:
        for word in sent:
            word = word.lower()
            if word not in vocab:
                not_in_vocab.add(word)

In [None]:
reverse_vocab = {i:word for word, i in vocab.items()}

In [None]:
' '.join([reverse_vocab[i] for i in gold_sum_idxs[0]])