Imports

In [1]:
import gzip
import itertools
import json

import nltk
import numpy as np
import pandas as pd
import sklearn
import torch
import transformers

## Data pipeline

Specify paths

In [2]:
# Modify this to wherever you locally downloaded the data
data_base_path = './data/newsroom-release/release/'

train_path = data_base_path + 'train.jsonl.gz'
validation_path = data_base_path + 'dev.jsonl.gz'
# Held-out test split; must be a different file from the dev/validation split
test_path = data_base_path + 'test.jsonl.gz'

In [3]:
# Count the samples (one JSON object per line) in the training file.
# sum() over the lines gives the true count (0 for an empty file); taking the last
# enumerate() index would be one short.
# NOTE: this counts samples of every density category, not only the category the
# dataset below keeps, so it over-estimates the number of items the dataset can serve.
with gzip.open(train_path, "rb") as f:
    dset_len = sum(1 for _ in f)

Create dataset

In [4]:
class NewsroomDataset(torch.utils.data.Dataset):
    '''
    Streams Newsroom samples of a single density category from a gzipped JSONL file.

    The file is read sequentially: ``__getitem__`` ignores ``idx`` and returns the
    next sample whose ``density_bin`` equals ``category``. Use it with the default
    sequential sampler and ``num_workers=0``; shuffling or worker processes would
    not return the requested index.

    Attributes:
        file: text-mode handle to the gzipped JSONL dataset file
        category: category of the data summarization to keep, i.e. 'extractive'
        length: number of items reported by ``__len__`` (supplied by the caller;
            if it counts samples of every category, reading ends with an
            IndexError once the file has no more ``category`` samples)
    '''
    def __init__(self, path, category: str, dataset_len: int):
        self.file = gzip.open(path, "rt")
        self.category = category
        self.length = dataset_len

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # idx is ignored: samples are served in file order (see class docstring)
        return self.load_data()

    def load_data(self):
        '''
        Returns the next sample (parsed JSON object) whose 'density_bin' equals ``self.category``.

        Raises:
            IndexError: if the end of the file is reached before another matching sample.
        '''
        while True:
            line = self.file.readline()
            if not line:
                # readline() returns '' only at EOF; json.loads('') would raise an opaque JSONDecodeError
                raise IndexError(f"No more '{self.category}' samples left in the dataset file")
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            sample = json.loads(line)
            if sample.get('density_bin') == self.category:
                return sample

    def close(self):
        '''Closes the underlying dataset file.'''
        self.file.close()

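The dataset is too big to load into memory, so it is streamed from disk and grouped into minibatches with a DataLoader. Loading must stay single-process (`num_workers=0`) and unshuffled, because samples are read sequentially from one open file handle.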

In [5]:
# Stream only 'extractive' samples. The dataset reads its file sequentially, so
# the loader keeps the default in-order, single-process loading (spelled out here).
train_dset = NewsroomDataset(train_path, "extractive", dataset_len=dset_len)
trainloader = torch.utils.data.DataLoader(
    train_dset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
)

Get batch for visualization

## Tokenization

Import the dependencies and initialize tokenizer and model

In [6]:
from transformers import BertTokenizer, BertModel
from nltk import tokenize
from sklearn.cluster import KMeans
from operator import itemgetter
# sent_tokenize needs the Punkt sentence models. Older NLTK releases load them from
# 'punkt', newer ones from 'punkt_tab', so fetch both (an unknown package name is
# reported by nltk.download, not raised).
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\teemu\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [7]:
# Both the tokenizer and the encoder are loaded from the same cased BERT checkpoint
bert_model = 'bert-base-cased'
bert_tokenizer, bert_base_model = (
    BertTokenizer.from_pretrained(bert_model),
    BertModel.from_pretrained(bert_model),
)

Split the text into sentences and drop very short ones (BERT tokenization of the sentences happens later)

In [8]:
def tokenize_sentences(sentences, min_length: int = 2):
    """Splits the first document of a batch into sentences and drops very short ones.

    Args:
        sentences: either a single document string, or a sequence of documents
            (e.g. the 'text' field of a DataLoader batch). For a sequence, only the
            first document is used (the rest of the batch is not yet handled).
        min_length: sentences with ``len(sentence) <= min_length`` characters are removed.

    Returns:
        List of sentence strings; empty when there is no document.
    """
    if isinstance(sentences, str):
        # A bare string must not be iterated character by character
        document = sentences
    elif len(sentences) == 0:
        return []
    else:
        # Tokenize only the document we keep instead of the whole batch
        document = sentences[0]
    return [s for s in tokenize.sent_tokenize(document) if len(s) > min_length]

## Test metrics

Import dependencies

In [9]:
from rouge import Rouge

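Prediction loop: extract a summary for each batch and compute its ROUGE score against the reference summary. Note that this cell calls `best_n_summaries`, which is defined in the K-Means section further down — run that cell first.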

In [27]:
rouge = Rouge()
scores = []

# Debugging: evaluate only this many batches; set to None for a full pass over the loader.
MAX_EVAL_BATCHES = 1

# NOTE: `best_n_summaries` is defined in the K-Means section further down; run that
# cell before this one (on a fresh kernel this loop otherwise raises NameError).
for batch in itertools.islice(trainloader, MAX_EVAL_BATCHES):
    ### TODO ###
    # currently we only get single sample from batch, we need better indexing from the batch

    # Get text and summary from batch
    text, summary = itemgetter('text', 'summary')(batch)
    # Preprocess: split into sentences
    summary_tokens = tokenize_sentences(summary)
    text_tokens = tokenize_sentences(text)

    # One cluster per reference-summary sentence, but KMeans cannot form more
    # clusters than there are article sentences
    k = min(len(summary_tokens), len(text_tokens))
    if k == 0 or len(text_tokens) < 2:
        continue  # nothing meaningful to cluster or rank

    # Create embeddings; pass the attention mask too so padding tokens are ignored
    encoded_text = bert_tokenizer(text_tokens, return_tensors='pt', padding='longest')
    print(f"text_embeddings: {encoded_text['input_ids'].shape}")
    with torch.no_grad():  # inference only: don't build an autograd graph
        model_out = bert_base_model(**encoded_text)
    embeddings = model_out.pooler_output.numpy()
    print(f"embeddings: {embeddings.shape}")

    # Do K-means clustering (fixed seed for reproducible centroids)
    kmeans = KMeans(n_clusters=k, random_state=0).fit(embeddings)
    # Find corresponding summary sentences
    text_summaries = best_n_summaries(centroids=kmeans.cluster_centers_,
                                      pooler_outputs=embeddings,
                                      model_outputs=np.array(text_tokens))

    # Evaluation: score the extracted summary as a whole against the reference.
    # Cluster order is arbitrary, so pairing hypothesis and reference sentences
    # one-to-one (list inputs to get_scores) would compare unrelated sentences.
    print(f"text_summaries: {text_summaries}")
    score = rouge.get_scores(' '.join(text_summaries), ' '.join(summary_tokens))
    scores.append(score)

{'url': ['http://www.cnbc.com/2013/12/09/hard-to-say-if-airline-consolidation-will-pause-american-airlines-ceo.html', 'http://www.cnbc.com/2014/09/10/beware-wild-cards-to-watch-for-in-2015.html'], 'archive': ['http://web.archive.org/web/20151102042147id_/http://www.cnbc.com/2013/12/09/hard-to-say-if-airline-consolidation-will-pause-american-airlines-ceo.html', 'http://web.archive.org/web/20151103051855id_/http://www.cnbc.com/2014/09/10/beware-wild-cards-to-watch-for-in-2015.html'], 'title': ['American Airlines CEO', 'Beware: Wild cards to watch for in 2015'], 'date': ['20151102042147', '20151103051855'], 'text': ['"The three of us have now the ability to take people pretty much anywhere in the world. What used to be a business where it was purely on schedule, if you have the ability to take people everywhere, you have to compete on product. And we\'re prepared to do that," said Parker, who had been the CEO of US Airways before the merger.\n\nThe newly combined carrier began trading Mon

In [25]:
# Inspect the reference summaries (`summary`) of the last batch processed by the
# evaluation loop; what is shown depends on which cells ran before this one.
summary

["POPEYE-WORTHY PIE. PHYLLO DOUGH WRAPS SPINACH IN PURE GOLD BY ROSEMARY BLACK Spinach has terrorized generations of veggie-phobic kids, and many grownups don't much like it, either. But when it's combined with seasonings and feta cheese and wrapped in a golden crisp phyllo dough crust, even those who despise Popeye's Â\xadfavorite food ask for seconds. The spinach pie at Kebab House II on Orchard St. is a specialty of owner and executive chef Ramazan Ay,",
 'All day, every day, Cheryl Bernstein thanks her 16-month-old son. "I gave life to Reid, but he gave me life - a reason to get clean and go on,"she said yesterday after graduating from the Manhattan Family Treatment Court program. Bernstein, 41, and her husband, Doug Flaumenbaum, 33, both recovering crack and heroin addicts, were among three dozen men and women who regained custody of their children.']

Test the K-Means clustering on a single sample

In [None]:
# NOTE: the dataset reads its file sequentially, so a fresh iterator yields the
# next unread samples, not the first ones in the file.
sample_train, sample_test = itemgetter('text', 'summary')(next(iter(trainloader)))

# Split the article into sentences
sample_sentences = tokenize_sentences(sample_train)

# Tokenize the sentences for BERT (padded to the longest sentence)
sample_train_tokens = bert_tokenizer(sample_sentences, return_tensors='pt', padding='longest')

# Keep the sentences as an array so best_n_summaries can index them with a list of indices
sample_sentences = np.array(sample_sentences)

## Get BERT pooled [CLS] embeddings (inference only, so no autograd graph)
with torch.no_grad():
    model_output = bert_base_model(**sample_train_tokens)
pooler_output = model_output.pooler_output.numpy() # Get numpy array

# Choose as many centroids as the actual summary has sentences, capped by the
# number of article sentences (KMeans cannot form more clusters than samples)
sample_sentences_test = tokenize_sentences(sample_test)
k = min(len(sample_sentences_test), len(sample_sentences))
assert k > 0, "need at least one sentence in both the article and the summary"

# Cluster embeddings to find centroids (fixed seed for reproducibility)
kmeans = KMeans(n_clusters=k, random_state=0).fit(pooler_output)

Find the sentence closest to each centroid

In [11]:
def best_n_summaries(centroids, pooler_outputs, model_outputs: np.ndarray, n_summaries=1):
    '''
    Returns, for each centroid, the sentence(s) whose embeddings are closest to it.

    Args:
        centroids: centroids of the K clusters, one row per centroid
            (e.g. ``KMeans.cluster_centers_``)
        pooler_outputs: sentence embeddings from the model's [CLS] token,
            a 2-D array with one row per sentence
        model_outputs: the sentences (array or list of strings) in the same order
            as ``pooler_outputs``; used for retrieving sentences from embeddings
        n_summaries: how many sentences to choose for a single centroid (for debugging)

    Returns:
        List with one string per centroid: its ``n_summaries`` nearest sentences
        (Euclidean distance), closest first, joined with single spaces. Ties are
        broken by sentence order. Different centroids may select the same sentence.

    Raises:
        ValueError: if ``n_summaries`` exceeds the number of sentences.
    '''
    pooler_outputs = np.asarray(pooler_outputs)
    sentences = np.asarray(model_outputs)  # fancy indexing below needs an ndarray
    n_sentences = pooler_outputs.shape[0]
    if n_summaries > n_sentences:
        raise ValueError(
            f"n_summaries ({n_summaries}) must not exceed the number of sentences ({n_sentences})"
        )

    summarizations = []
    for centroid in np.asarray(centroids):
        # Distance from every sentence embedding to this centroid, in one vectorized step
        distances = np.linalg.norm(pooler_outputs - centroid, axis=1)
        # Stable sort keeps the lower sentence index first on equal distances
        nearest = np.argsort(distances, kind='stable')[:n_summaries]
        summarizations.append(' '.join(sentences[nearest]))

    return summarizations

Compare the sentences found for each centroid with the actual summary

In [None]:
def print_top_n_summaries(summaries, test_summary):
    """Prints the extracted summary sentences (numbered from 1), then the reference summary.

    Args:
        summaries: sentences selected from the article, one entry per centroid
        test_summary: the reference summary (here, its list of sentences)
    """
    print("Generated summaries: \n")
    for i, summary in enumerate(summaries, start=1):
        print(f"{i}. {summary}")
    print(f"\n\nActual summary: \n {test_summary}")

# Select the sentence nearest to each centroid of the single-sample K-Means run above
summaries = best_n_summaries(centroids=kmeans.cluster_centers_,
                             pooler_outputs=pooler_output,
                             model_outputs=sample_sentences)
print_top_n_summaries(summaries, sample_sentences_test)