# Based on the best model we have trained

### Imports

In [None]:
import numpy as np
import pandas as pd
import os
from collections import Counter
import pickle
from itertools import chain

import gc
import warnings
warnings.simplefilter('ignore')

from gensim.models import Word2Vec
from gensim.test.utils import get_tmpfile
from gensim.models import KeyedVectors

import tensorflow as tf
from keras.models import Model, Sequential
from keras.layers.recurrent import LSTM
from keras.layers import Embedding, Dense, Input, RepeatVector, TimeDistributed, concatenate, add, Dropout
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

%config InteractiveShell.ast_node_interactivity = 'all'

# Build a TensorFlow session configuration with GPU memory growth enabled,
# i.e. GPU memory is requested as needed instead of being reserved up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# NOTE(review): nothing visible in this notebook hands `session` to Keras
# explicitly - confirm that this configuration actually takes effect.
session = tf.Session(config=config)

### Helper

In [None]:
# Defaults for params

MAX_INPUT_SEQ_LENGTH = 500
MAX_TARGET_SEQ_LENGTH = 50
MAX_INPUT_VOCAB_SIZE = 5000
MAX_TARGET_VOCAB_SIZE = 2000


def fit_text(X, Y, input_seq_max_length=None, target_seq_max_length=None,
             *, max_input_vocab_size=None, max_target_vocab_size=None):
    """Build the word <-> id lookup tables and sequence statistics.

    Every line is split on single spaces. Input lines are lower-cased;
    target lines are lower-cased and wrapped as ``'START ... END'``.
    Sequences longer than the respective limit are cut to their first
    ``*_seq_max_length`` tokens before counting.

    Args:
        X: iterable of input strings.
        Y: iterable of target strings.
        input_seq_max_length: token cap per input line; ``None`` means
            ``MAX_INPUT_SEQ_LENGTH``.
        target_seq_max_length: token cap per target line (START/END
            included); ``None`` means ``MAX_TARGET_SEQ_LENGTH``.
        max_input_vocab_size: number of most frequent input words kept;
            ``None`` means ``MAX_INPUT_VOCAB_SIZE``.
        max_target_vocab_size: number of most frequent target words kept;
            ``None`` means ``MAX_TARGET_VOCAB_SIZE``.

    Returns:
        dict with the keys ``input_word2idx``, ``input_idx2word``,
        ``target_word2idx``, ``target_idx2word``, ``num_input_tokens``,
        ``num_target_tokens``, ``max_input_seq_length`` and
        ``max_target_seq_length``. Input ids reserve 0 for ``'PAD'`` and 1
        for ``'UNK'``; target ids reserve 0 for ``'UNK'``.
    """
    if input_seq_max_length is None:
        input_seq_max_length = MAX_INPUT_SEQ_LENGTH
    if target_seq_max_length is None:
        target_seq_max_length = MAX_TARGET_SEQ_LENGTH
    if max_input_vocab_size is None:
        max_input_vocab_size = MAX_INPUT_VOCAB_SIZE
    if max_target_vocab_size is None:
        max_target_vocab_size = MAX_TARGET_VOCAB_SIZE

    input_counter = Counter()
    max_input_seq_length = 0
    for line in X:
        # Slicing is a no-op for lines that are already short enough.
        words = line.lower().split(' ')[:input_seq_max_length]
        input_counter.update(words)
        max_input_seq_length = max(max_input_seq_length, len(words))

    target_counter = Counter()
    max_target_seq_length = 0
    for line in Y:
        words = ('START ' + line.lower() + ' END').split(' ')
        words = words[:target_seq_max_length]
        target_counter.update(words)
        # Updated once per line (not once per word), like the input loop.
        max_target_seq_length = max(max_target_seq_length, len(words))

    # Ids 0 and 1 are reserved, so frequent words start at 2.
    input_word2idx = {
        word: rank + 2
        for rank, (word, _) in enumerate(
            input_counter.most_common(max_input_vocab_size))
    }
    input_word2idx['PAD'] = 0
    input_word2idx['UNK'] = 1
    input_idx2word = {idx: word for word, idx in input_word2idx.items()}

    # Id 0 is reserved for unknown words, so frequent words start at 1.
    target_word2idx = {
        word: rank + 1
        for rank, (word, _) in enumerate(
            target_counter.most_common(max_target_vocab_size))
    }
    target_word2idx['UNK'] = 0
    target_idx2word = {idx: word for word, idx in target_word2idx.items()}

    return {
        'input_word2idx': input_word2idx,
        'input_idx2word': input_idx2word,
        'target_word2idx': target_word2idx,
        'target_idx2word': target_idx2word,
        'num_input_tokens': len(input_word2idx),
        'num_target_tokens': len(target_word2idx),
        'max_input_seq_length': max_input_seq_length,
        'max_target_seq_length': max_target_seq_length,
    }

def summarize(input_text):
    """Generate a summary for ``input_text`` by greedy decoding.

    The text is lower-cased, split on single spaces and mapped to ids via
    the notebook-level ``input_word2idx`` (unknown words get id 1). The
    decoder is then fed the ids produced so far, starting with 'START',
    and at every step the highest-scoring token of ``model.predict`` is
    appended. Decoding stops when 'END' is produced or when the id list
    reaches ``max_target_seq_length``.

    Relies on the notebook globals ``model``, ``input_word2idx``,
    ``target_word2idx``, ``target_idx2word``, ``num_target_tokens``,
    ``max_input_seq_length``, ``max_target_seq_length`` and
    ``MAX_DECODER_SEQ_LENGTH``.

    Returns the generated words joined by single spaces, without the
    'START' / 'END' markers and with surrounding whitespace stripped.
    """
    unk_id = 1  # id of [UNK] in the input vocabulary
    encoded = [input_word2idx.get(token, unk_id)
               for token in input_text.lower().split(' ')]
    input_seq = pad_sequences([encoded], max_input_seq_length)

    wid_list = [target_word2idx['START']]
    # Length the decoder input is padded/truncated to at every step.
    decoder_window = min(num_target_tokens, MAX_DECODER_SEQ_LENGTH)
    sum_input_seq = pad_sequences([wid_list], decoder_window)

    generated_words = []
    while True:
        output_tokens = model.predict([input_seq, sum_input_seq])
        next_id = np.argmax(output_tokens[0, :])
        next_word = target_idx2word[next_id]
        wid_list.append(next_id)

        # The markers steer decoding but are not part of the summary text.
        if next_word not in ('START', 'END'):
            generated_words.append(next_word)

        if next_word == 'END' or len(wid_list) >= max_target_seq_length:
            break
        sum_input_seq = pad_sequences([wid_list], decoder_window)

    return ' '.join(generated_words).strip()

### Load model

In [None]:
from keras.models import load_model
# Load the saved Keras model that summarize() calls through the global `model`.
model = load_model(r'model4_pretrained') #if exists; point to your path if different
# Re-compile with categorical cross-entropy loss, the Adam optimizer and an
# accuracy metric.
# NOTE(review): the visible cells only call model.predict - confirm whether
# this re-compile is needed at all for inference.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

### Load data

In [None]:
import pickle
from itertools import chain

# NOTE: pickle.load can execute arbitrary code - only open trusted files here.
with open('titlesAbstracts_AT.pkl', 'rb') as fh:
    titles = pickle.load(fh)

with open('abstractsCorpus_ATsigns.pkl', 'rb') as fh:
    text = pickle.load(fh)

# Each value of `text` is a nested iterable of strings; flatten it and join
# the pieces with single spaces to get one string per entry.
textConcat = [' '.join(chain.from_iterable(doc)) for doc in text.values()]

# Full corpus. Inputs and titles are paired purely by position.
# NOTE(review): assumes `text` and `titles` iterate in the same order - confirm.
Xfull = textConcat
Yfull = list(titles.values())

# First 10000 pairs, used to fit the vocabulary.
X = Xfull[:10000]
Y = Yfull[:10000]

### Get params

In [None]:
# Fixed hyper-parameters; MAX_DECODER_SEQ_LENGTH caps the decoder input
# length used inside summarize().
HIDDEN_UNITS = 100
MAX_DECODER_SEQ_LENGTH = 4

# Vocabulary tables and sequence statistics fitted on the (X, Y) subset.
conf = fit_text(X, Y)

# Expose the fitted values as notebook globals (summarize() reads them).
input_word2idx, input_idx2word = conf['input_word2idx'], conf['input_idx2word']
target_word2idx, target_idx2word = conf['target_word2idx'], conf['target_idx2word']
num_input_tokens, num_target_tokens = conf['num_input_tokens'], conf['num_target_tokens']
max_input_seq_length, max_target_seq_length = (
    conf['max_input_seq_length'],
    conf['max_target_seq_length'],
)

### Get random article to test the prediction on

In [None]:
from random import randint

# Pick an article outside the first 10000 entries (those were used to fit the
# vocabulary). randint includes BOTH endpoints, so the upper bound has to be
# the last valid index; len(Xfull) itself would make Xfull[c] raise IndexError.
# Raises ValueError if the corpus has 10000 entries or fewer.
c = randint(10000, len(Xfull) - 1)

textPredict = Xfull[c]
labelPredict = Yfull[c]

# Show the generated headline followed by the original title.
summarize(textPredict)
labelPredict

# Evaluate

In [None]:
from sumeval.metrics.rouge import RougeCalculator
# ROUGE scorer shared by the evaluation cells of this notebook.
# NOTE(review): stopwords=True presumably drops English stop words before
# scoring - confirm against the sumeval documentation.
rouge = RougeCalculator(stopwords=True, lang="en")

In [None]:
from random import randint

# rouge1 keeps its original shape (score -> (generated, reference)) for code
# that reads its keys, but a dict keyed by score silently drops every sample
# whose score was already seen. rouge1_samples records every draw instead.
rouge1 = dict()
rouge1_samples = []

for _ in range(50):
    # randint includes both endpoints, so stop at the last valid index;
    # len(Xfull) itself would make Xfull[c] raise IndexError.
    c = randint(10000, len(Xfull) - 1)

    textPredict = Xfull[c]
    labelPredict = Yfull[c]

    generated = summarize(textPredict)
    reference = labelPredict

    score = rouge.rouge_n(
                summary=generated,
                references=reference,
                n=1)

    rouge1[score] = (generated, reference)
    rouge1_samples.append((score, generated, reference))

# Report every sampled article, best score first (ties keep sampling order).
for s, gen, orig in sorted(rouge1_samples, key=lambda sample: sample[0], reverse=True):
    print(s)
    print(f'Generated headline:{gen}')
    print(f'Original headline:{orig}')

# Mean over all sampled articles, including those with duplicate scores.
print(f'ROUGE-1 mean over {len(rouge1_samples)} samples: '
      f'{sum(sample[0] for sample in rouge1_samples) / len(rouge1_samples)}')

### Rouge1 avg:

In [None]:
# Mean of the keys of rouge1, i.e. of the ROUGE-1 scores stored there.
# NOTE(review): rouge1 is keyed by score, so articles with identical scores
# share one key and are counted once - this is the mean of the distinct
# scores, not of all sampled articles.
np.mean(list(rouge1.keys()))

# (For reference) ROUGE-1 on an identical sentence pair and on a partially overlapping pair

In [None]:
# Summary identical to the reference sentence.
rouge.rouge_n(summary='I would like an apple.', references='I would like an apple.', n=1)
# Summary that shares only part of its wording with the reference.
rouge.rouge_n(summary='I would like to eat an apple.', references='I feel like having an apple.', n=1)