# BERT Key Phrase Extractor

In this notebook we implement the extractor from the Bottom-Up Summarization paper, using BERT as the contextual embedding, and check whether we can extract phrases that maximize ROUGE scores. Our first goal is to generate summaries that maximize the ROUGE score, even if they are nonsensical. We then aim to train an additional language-model-like network to generate abstractive summaries.

In [1]:
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

import os
import subprocess
import json
import pickle
from multiprocessing import Pool

import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

from rouge import Rouge 

In [2]:
CNN_STORY_DIR = os.path.join('data', 'cnn', 'stories')
DM_STORY_DIR = os.path.join('data', 'dailymail', 'stories')

CNN_STORY_TOKENIZED = os.path.join('data', 'cnn', 'stories-tokenized')
DM_STORY_TOKENIZED = os.path.join('data', 'dailymail', 'stories-tokenized')

SRC_JSON = os.path.join('data', 'src.pk')
TGT_JSON = os.path.join('data', 'tgt.pk')

PREDICTED_SUMS = os.path.join('out', 'predicted')
GOLD_SUMS = os.path.join('out', 'gold')

## Preprocessing

We first read in the files, split them into tokenized sentences and words, and separate the source document from the abstract. This code borrows heavily from the Pointer-Generator code.
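
As a rough illustration of this step, the sketch below splits a toy story into source and highlight sentences. The dictionary layout is inferred only from the keys that `process_json` reads (`sentences`, `tokens`, `word`); the helper names `make_sentence` and `split_story`, the shortened end-token set, and the sentences themselves are assumptions made for the example.

```python
# Toy stand-in for one parsed story file. Only the keys 'sentences',
# 'tokens' and 'word' come from the notebook; the content is made up.
def make_sentence(text):
    return {'tokens': [{'word': w} for w in text.split()]}

toy_story = {
    'sentences': [
        make_sentence('The cat sat on the mat .'),
        make_sentence('It purred'),
        make_sentence('@highlight'),
        make_sentence('Cat sits on mat'),
    ]
}

END_TOKENS = {'.', '!', '?'}  # shortened list, for this sketch only

def split_story(parsed):
    """Split a parsed story into source sentences and highlight sentences."""
    src, tgt = [], []
    in_highlights = False
    for sent in parsed['sentences']:
        words = [tok['word'] for tok in sent['tokens']]
        if words[0] == '@highlight':
            in_highlights = True   # everything after the first marker is summary
            continue
        if words[-1] not in END_TOKENS:
            words.append('.')      # close sentences that lack end punctuation
        (tgt if in_highlights else src).append(words)
    return src, tgt

toy_src, toy_tgt = split_story(toy_story)
print(toy_src)
print(toy_tgt)
```

The marker sentence itself is dropped, and every sentence after the first marker is treated as part of the abstract.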

In [3]:
dirs = os.listdir(CNN_STORY_TOKENIZED)

In [4]:
dm_single_close_quote = u'\u2019' # unicode
dm_double_close_quote = u'\u201d'
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
def process_json(filename):
    src, tgt = [], [] # a document is a list of list of words
    highlight = False # highlights are always at the end of the document 
    with open(filename, 'r') as f:  # close the file handle once parsed
        parsed = json.load(f)
    for sent in parsed['sentences']:
        words = [word['word'] for word in sent['tokens']]
        if words[-1] not in END_TOKENS:
            words += ['.']
        if words[0] == '@highlight':
            highlight = True
        elif highlight:
            tgt += [words]
        else:
            src += [words]
    return src, tgt

src, tgt = process_json(os.path.join(CNN_STORY_TOKENIZED, dirs[0]))

In [5]:
def percentage_in_src_vocab(src, tgt):
    src_vocab = set()
    for sent in src:
        src_vocab |= set(sent)
    count = 0
    total_len = 0
    for sent in tgt:
        for word in sent:
            if word in src_vocab:
                count += 1
            total_len += 1
    return count / total_len
percentage_in_src_vocab(src, tgt)

0.7301587301587301

In [7]:
def process_all_json(file_dir):
    pool = Pool(processes=10)
    srcs, tgts = [], []
    percentages = []
    file_paths = [os.path.join(file_dir, file_name) for file_name in os.listdir(file_dir)]
    for tup in pool.imap_unordered(process_json, file_paths):
        src, tgt = tup
        srcs.append(src)
        tgts.append(tgt)
        percentages.append(percentage_in_src_vocab(src, tgt))
    print(np.mean(percentages))
    return srcs, tgts

In [8]:
srcs_cnn, tgts_cnn = process_all_json(CNN_STORY_TOKENIZED)

0.8373284430595709


In [9]:
srcs_dm, tgts_dm = process_all_json(DM_STORY_TOKENIZED)

0.8751616451112775


In [10]:
src, tgt = srcs_cnn + srcs_dm, tgts_cnn + tgts_dm

In [14]:
with open(SRC_JSON, 'wb') as f:
    pickle.dump(src, f)

with open(TGT_JSON, 'wb') as f:
    pickle.dump(tgt, f)

In [3]:
with open(SRC_JSON, 'rb') as f:
    src = pickle.load(f)

with open(TGT_JSON, 'rb') as f:
    tgt = pickle.load(f)


To use BERT, we must convert our data into a format that BERT accepts. We also recast the problem as the sequence-tagging problem presented in the Bottom-Up paper.
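
A minimal sketch of the input format, assuming whitespace tokens in place of real WordPiece tokens: each document becomes a single `[CLS] ... [SEP] ... [SEP]` sequence capped at BERT's 512 positions. The helper `format_for_bert` is hypothetical, and it differs slightly from `process_src_tgt`, which always appends a final `[SEP]` after truncating.

```python
MAX_LEN = 512  # bert-base-uncased has 512 position embeddings

def format_for_bert(sentences, max_len=MAX_LEN):
    """Flatten tokenized sentences into one [CLS] ... [SEP] sequence.

    `sentences` is a list of token lists. WordPiece tokenization is
    skipped here, so lengths differ from what the real tokenizer gives.
    """
    tokens = ['[CLS]']
    for sent in sentences:
        tokens.extend(sent)
        tokens.append('[SEP]')
    if len(tokens) > max_len:
        # Truncate and make sure the sequence still ends with [SEP].
        tokens = tokens[:max_len - 1] + ['[SEP]']
    return tokens

toy_doc = [['the', 'cat', 'sat', '.'], ['it', 'purred', '.']]
full = format_for_bert(toy_doc)
short = format_for_bert(toy_doc, max_len=6)
print(full)
print(short)
```

With `max_len=6` the second sentence is cut off entirely, which mirrors what happens to the tail of long news articles.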

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
def tag(doc, tgt):
    """
    doc: a list of src tokens
    tgt: a list of tgt tokens that we will look for in the doc
    """
    if len(tgt) == 0:
        print('zero sized tgt')
        return None
    label = np.zeros(len(doc), dtype=int)
    r = 0  # index of the next summary token to match
    while r < len(tgt):
        old_idxs = []
        idxs = [(i,i+1) for i, token in enumerate(doc) if token == tgt[r]]
        while len(idxs) > 0 and r + 1 < len(tgt):
            r += 1
            old_idxs, idxs = idxs, []
            for idx in old_idxs:
                # guard against a candidate span that already ends at the last doc token
                if idx[-1] < len(doc) and doc[idx[-1]] == tgt[r]:
                    idxs.append((idx[0], idx[-1] + 1))
        if len(idxs) > 0: ## we ran out of tgt
            label[idxs[0][0]:idxs[0][-1]] = 1
            break
        elif len(old_idxs) > 0: ## we found longest seq
            label[old_idxs[0][0]:old_idxs[0][-1]] = 1
        else: ## this token does not exist
            r += 1
    return label
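
To make the tagging scheme concrete, here is a simplified restatement of the same greedy idea on a toy example: starting at each unmatched summary position, find the longest run of summary tokens that appears contiguously in the document and mark its first occurrence. The function `greedy_tag` and the toy tokens are illustrative assumptions, not the notebook's implementation.

```python
def greedy_tag(doc, tgt):
    """Mark doc tokens covered by greedily matched summary spans."""
    labels = [0] * len(doc)
    r = 0  # next summary position to match
    while r < len(tgt):
        best_start, best_len = -1, 0
        for i in range(len(doc)):
            n = 0
            # Extend the match while doc and summary tokens agree.
            while (i + n < len(doc) and r + n < len(tgt)
                   and doc[i + n] == tgt[r + n]):
                n += 1
            if n > best_len:  # strict '>' keeps the first occurrence on ties
                best_start, best_len = i, n
        if best_len == 0:
            r += 1            # this summary token never appears in doc
        else:
            for j in range(best_start, best_start + best_len):
                labels[j] = 1
            r += best_len     # continue after the matched span
    return labels

doc = ['[CLS]', 'the', 'cat', 'sat', 'on', 'the', 'mat', '.', '[SEP]']
tgt = ['the', 'mat', 'was', 'sat', 'on']
labels = greedy_tag(doc, tgt)
selected = [tok for tok, lab in zip(doc, labels) if lab]
print(labels)
print(selected)
```

The summary starts with "the mat", but the selected tokens come back as "sat on the mat" because labels are read off in document order. This is why a tagged summary can look scrambled next to its gold summary.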

In [6]:
def process_src_tgt(srcs, tgts, start_idx=0, end_idx=-1):
    assert len(srcs) == len(tgts)
    docs, tags = [], []
    tagged_sum, gold_sum = [], []
    # end_idx == -1 means "run to the end"; start_idx is honored in both cases
    rn = range(start_idx, len(srcs) if end_idx == -1 else end_idx)
    for i in rn:
        ## process src
        sents = [' '.join(sent) + ' [SEP]' for sent in srcs[i]]
        doc = ' '.join(['[CLS]'] + sents)
        doc = tokenizer.tokenize(doc)[:510] + ['[SEP]']
        
        ## process tgt
        tgt = ' '.join([' '.join(sent) for sent in tgts[i]])
        tgt = tokenizer.tokenize(tgt)[:110]
        label = tag(doc, tgt)
        if label is None:  # empty summary: nothing to tag, so skip this story
            continue
        
        ## Add both to list
        docs.append(tokenizer.convert_tokens_to_ids(doc))
        tags.append(label)
        tagged_sum.append((' '.join(np.array(doc)[label.astype(bool)])).replace(' ##', ''))
        gold_sum.append(' '.join(tgt).replace(' ##', ''))
    return docs, tags, tagged_sum, gold_sum
docs, tags, tagged_sums, gold_sums = process_src_tgt(src, tgt, 1000, 1001)

In [25]:
def process_ranges(args):
    return process_src_tgt(src, tgt, args[0], args[1])

In [None]:
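n = 40
k = len(src) // n
# The last chunk runs to len(src) so the remainder is not silently dropped.
ranges = [(i * k, len(src) if i == n - 1 else (i + 1) * k) for i in range(n)]
with Pool(n) as pool:
    result = pool.map(process_ranges, ranges)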

In [30]:
docs, tags, tagged_sums, gold_sums = [], [], [], []
for a, b, c, d in result:
    docs.extend(a)
    tags.extend(b)
    tagged_sums.extend(c)
    gold_sums.extend(d)

In [51]:
docs_cleaned, tags_cleaned, tagged_sums_cleaned, gold_sums_cleaned = [], [], [], []
for i in range(len(docs)):
    if len(tagged_sums[i]) > 0 and len(gold_sums[i]) > 0:
        docs_cleaned.append(docs[i])
        tags_cleaned.append(tags[i])
        tagged_sums_cleaned.append(tagged_sums[i])
        gold_sums_cleaned.append(gold_sums[i])
docs, tags, tagged_sums, gold_sums = docs_cleaned, tags_cleaned, tagged_sums_cleaned, gold_sums_cleaned

In [52]:
DOCS = os.path.join('data', 'docs.pk')
TAGS = os.path.join('data', 'tags.pk')
TAGGED_SUMS = os.path.join('data', 'tagged_sums.pk')
GOLD_SUMS = os.path.join('data', 'gold_sums.pk')

In [53]:
for file_path, obj in zip([DOCS, TAGS, TAGGED_SUMS, GOLD_SUMS], [docs, tags, tagged_sums, gold_sums]):
    with open(file_path, 'wb') as f:  # ensure each file is flushed and closed
        pickle.dump(obj, f)


In [48]:
tagged_sums[1]

", for . jonas cuaron son of mexican director alfonso cuaron the a feature film and is a montage of stills have , ` ` ano una ' ' ' s father"

In [49]:
gold_sums[1]

"jonas cuaron ' s debut feature , ` ` ano una ' ' is a montage of photographic stills . jonas is the son of mexican new wave director , alfonso cuaron . father and son have launched a short film competition for young filmmakers ."

In [54]:
rouge = Rouge()
scores = rouge.get_scores(tagged_sums, gold_sums, avg=True)
print(scores)

{'rouge-1': {'f': 0.8792832549923678, 'p': 0.9920658394023293, 'r': 0.798113827257415}, 'rouge-2': {'f': 0.5025017957813821, 'p': 0.5673329261939167, 'r': 0.4560521901503767}, 'rouge-l': {'f': 0.7111830800695053, 'p': 0.82231380861409, 'r': 0.6666003987211477}}
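
For intuition about the `p`, `r`, and `f` fields, the sketch below computes ROUGE-1 from its textbook definition with clipped unigram counts. It is an assumption that this matches the `rouge` package closely; the package may differ in tokenization and in how repeated n-grams are counted. The function `rouge_1` and the toy strings are illustrative only.

```python
from collections import Counter

def rouge_1(hypothesis, reference):
    """Unigram-overlap precision, recall and F1 between two token lists."""
    hyp_counts = Counter(hypothesis)
    ref_counts = Counter(reference)
    # Counter '&' takes the minimum count per token (clipped matches).
    overlap = sum((hyp_counts & ref_counts).values())
    precision = overlap / len(hypothesis) if hypothesis else 0.0
    recall = overlap / len(reference) if reference else 0.0
    if precision + recall == 0:
        return {'p': 0.0, 'r': 0.0, 'f': 0.0}
    f1 = 2 * precision * recall / (precision + recall)
    return {'p': precision, 'r': recall, 'f': f1}

hyp = 'sat on the mat'.split()
ref = 'the cat sat on the mat'.split()
toy_scores = rouge_1(hyp, ref)
print(toy_scores)
```

Every hypothesis token here also occurs in the reference, so precision is 1 while recall is 4/6. The tagged summaries are built the same way, from tokens that matched the gold summary, which is consistent with the high ROUGE-1 precision and lower recall reported above.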


## BERT Model

We have calculated the "oracle" score above, and now we would like to fit a model that accurately predicts the tags defined above.

Later, we might change how the tags are defined and see whether we can achieve better results than "first occurrence tagging".

We will split the data 90/5/5 into train, validation, and test sets, and select a tiny 5k-example subset of the training set for faster development.
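
One possible way to produce such a split is sketched below. The shuffling, the fixed seed, and the helper name `split_indices` are assumptions; the notebook does not specify how the split is drawn.

```python
import random

def split_indices(n_examples, tiny_size=5000, seed=0):
    """Shuffle indices and split them 90/5/5, plus a small subset of train."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = n_examples * 90 // 100      # integer arithmetic avoids float rounding
    n_val = n_examples * 5 // 100
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]      # remainder, roughly 5%
    tiny = train[:min(tiny_size, len(train))]  # development subset of train
    return train, val, test, tiny

train, val, test, tiny = split_indices(100000)
print(len(train), len(val), len(test), len(tiny))
```

Splitting indices rather than the data itself means the same split can be applied to `docs`, `tags`, `tagged_sums`, and `gold_sums` so that they stay aligned.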

In [45]:
model = BertModel.from_pretrained('bert-base-uncased')

100%|██████████| 407873900/407873900 [00:38<00:00, 10586104.93B/s]


In [53]:
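model.eval()
# BertModel expects a 2-D batch of token ids with shape [batch_size, seq_len];
# here we encode the first processed document as a batch of one.
tokens_tensor = torch.tensor([docs[0]])
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, output_all_encoded_layers=False)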

In [55]:
encoded_layers[0]

tensor([[ 0.0216,  0.2807, -0.0808,  ...,  0.0314,  0.6625,  0.6284],
        [ 0.9894,  0.4365,  0.1380,  ...,  0.8988,  1.3305,  0.3496],
        [ 0.2269, -0.3134,  0.4247,  ..., -0.3216,  0.7174,  0.7578],
        ...,
        [ 0.5443, -0.2824, -0.1584,  ..., -0.1154,  0.3110,  0.3094],
        [ 0.6949,  0.4653, -0.3678,  ...,  0.0592, -0.4090, -0.4708],
        [ 0.4271,  0.3990,  0.0077,  ...,  0.2882, -0.3808, -0.4419]])

In [None]:
processed_srcs = process_src(src)

In [None]:
processed_srcs[0]