#Paper: Model Explorations in Abstractive Summarisation

GROUP 15:  
Youning Xia (19054254)
Yan Song (9898418)
Shuyi Han (19060915)
Wan Jing Song (17130843)


This paper sets out to assess the performance of Deep Reinforcement Learning (DRL) based abstractive summarization models. Four different model variants are applied to three datasets and evaluated on ROUGE and BERTScore. Working on the novel WikiHow dataset (Koupaee and Wang, 2018), which is slightly more complex to train on, has magnified the characteristics of the models. It exposes the instability of training on ROUGE-L scores and suggests BERTScore as an alternative.

#Source code



1.   Data pre-processing
2.   Dataset files
3.   Package installed
4.   Model configuration
5.   Vocabulary class
6.   Batcher
7.   Model with global attention mechanism
8.   Training utility
9.   Training class with ROUGE reward function
10.  Training class with BERTScore reward function
11.  Beam search class
12.  ROUGE Evaluation class(validating and testing)
13.  BERTScore Evaluation class 
14.  Model with local attention mechanism
15.  Experiment on three datasets.



#1.Data pre-processing

In [0]:
import os
import shutil
import collections
import tqdm
from tensorflow.core.example import example_pb2
import struct
import random
import shutil

finished_path = "/Users/youning/Documents/UCL_CSML/CSML_Lecture/COMP0087_NLP/Text-Summarizer-Pytorch-master/data/finished"      
unfinished_path = "/Users/youning/Documents/UCL_CSML/CSML_Lecture/COMP0087_NLP/Text-Summarizer-Pytorch-master/data/unfinished"
chunk_path = "/Users/youning/Documents/UCL_CSML/CSML_Lecture/COMP0087_NLP/Text-Summarizer-Pytorch-master/data/chunked"

vocab_path = "/Users/youning/Documents/UCL_CSML/CSML_Lecture/COMP0087_NLP/Text-Summarizer-Pytorch-master/data/vocab"
VOCAB_SIZE = 200000

CHUNK_SIZE = 1500                           # num examples per chunk, for the chunked data
train_bin_path = os.path.join(finished_path, "train.bin")
valid_bin_path = os.path.join(finished_path, "valid.bin")
test_bin_path = os.path.join(finished_path, "test.bin")

def make_folder(folder_path):
    """Create ``folder_path`` (and any missing parents) unless something already exists there."""
    already_present = os.path.exists(folder_path)
    if already_present:
        return
    os.makedirs(folder_path)

def delete_folder(folder_path):
    """Recursively remove ``folder_path``; do nothing when the path does not exist."""
    if not os.path.exists(folder_path):
        return
    shutil.rmtree(folder_path)

def shuffle_text_data(unshuffled_art, unshuffled_abs, shuffled_art, shuffled_abs):
    """Shuffle a line-aligned article/abstract corpus, keeping each pair together.

    All four arguments are file names relative to the module-level
    ``unfinished_path`` directory. Line *k* of ``unshuffled_art`` is paired
    with line *k* of ``unshuffled_abs``; the pairs are shuffled with
    ``random.shuffle`` and written, whitespace-stripped and one per line, to
    ``shuffled_art`` / ``shuffled_abs``.

    Abstract lines beyond the last article line are ignored.

    Raises:
        ValueError: if the abstract file has fewer lines than the article file.
    """
    def in_unfinished(file_name):
        return os.path.join(unfinished_path, file_name)

    list_of_pairs = []
    # ``with`` guarantees the handles are closed even if reading fails midway.
    with open(in_unfinished(unshuffled_art), "r") as article_file, \
            open(in_unfinished(unshuffled_abs), "r") as abstract_file:
        for line_no, article in enumerate(article_file, start=1):
            abstract = next(abstract_file, None)
            if abstract is None:
                # Fail with a readable message instead of a bare StopIteration.
                raise ValueError(
                    "%s has no line %d to pair with %s"
                    % (unshuffled_abs, line_no, unshuffled_art))
            list_of_pairs.append((article.strip(), abstract.strip()))

    random.shuffle(list_of_pairs)

    with open(in_unfinished(shuffled_art), "w") as article_out, \
            open(in_unfinished(shuffled_abs), "w") as abstract_out:
        for article, abstract in list_of_pairs:
            article_out.write(article + "\n")
            abstract_out.write(abstract + "\n")

def write_to_bin(article_path, abstract_path, out_file, vocab_counter = None):
    """Serialise a line-aligned article/abstract corpus into one binary file.

    Each pair of (stripped) lines becomes a ``tf.Example`` with the byte
    features ``'article'`` and ``'abstract'``. Records are written to
    ``out_file`` as an 8-byte native ``'q'`` length followed by the serialised
    example, which is the layout ``chunk_file`` and ``example_generator`` read.

    Args:
        article_path: text file with one article per line.
        abstract_path: text file with the matching abstract on the same line.
        out_file: path of the binary file to create (overwritten).
        vocab_counter: optional ``collections.Counter``. When given, it is
            updated with every non-empty space-separated token of both texts,
            and afterwards its ``VOCAB_SIZE`` most common words are written to
            the module-level ``vocab_path`` as ``"<word> <count>"`` lines.
    """
    # All three handles live in one ``with`` so the two input files are closed
    # as well (previously they were left open).
    with open(out_file, 'wb') as writer, \
            open(article_path, 'r') as article_itr, \
            open(abstract_path, 'r') as abstract_itr:
        for article in tqdm.tqdm(article_itr):
            article = article.strip()
            abstract = next(abstract_itr).strip()

            tf_example = example_pb2.Example()
            tf_example.features.feature['article'].bytes_list.value.extend([bytes(article, encoding= 'utf-8')])
            tf_example.features.feature['abstract'].bytes_list.value.extend([bytes(abstract, encoding= 'utf-8')])
            tf_example_str = tf_example.SerializeToString()
            str_len = len(tf_example_str)
            writer.write(struct.pack('q', str_len))
            writer.write(struct.pack('%ds' % str_len, tf_example_str))

            if vocab_counter is not None:
                # NOTE(review): sentence tags such as <s> / </s> are NOT filtered
                # out here; if the abstracts contain them they end up in the
                # vocab file -- confirm the downstream vocab loader accepts that.
                tokens = article.split(' ') + abstract.split(' ')
                tokens = [t.strip() for t in tokens]
                vocab_counter.update(t for t in tokens if t != "")  # skip empties

    if vocab_counter is not None:
        with open(vocab_path, 'w') as writer:
            for word, count in vocab_counter.most_common(VOCAB_SIZE):
                writer.write(word + ' ' + str(count) + '\n')


def creating_finished_data():
    """Build train/valid/test .bin files from the shuffled text files.

    Only the training split feeds the vocabulary counter, so the vocab file
    written by ``write_to_bin`` is derived from training data alone.
    """
    make_folder(finished_path)

    vocab_counter = collections.Counter()

    # (article file, abstract file, output .bin, counter or None)
    jobs = (
        ("train.art.shuf.txt", "train.abs.shuf.txt", train_bin_path, vocab_counter),
        ("valid.art.shuf.txt", "valid.abs.shuf.txt", valid_bin_path, None),
        ("test.art.shuf.txt", "test.abs.shuf.txt", test_bin_path, None),
    )
    for art_name, abs_name, bin_path, counter in jobs:
        art_file = os.path.join(unfinished_path, art_name)
        abs_file = os.path.join(unfinished_path, abs_name)
        if counter is None:
            write_to_bin(art_file, abs_file, bin_path)
        else:
            write_to_bin(art_file, abs_file, bin_path, counter)


def chunk_file(set_name, chunks_dir, bin_file):
    """Split a length-prefixed .bin file into files of at most CHUNK_SIZE records.

    Chunks are written to ``chunks_dir`` as ``<set_name>_0000.bin``,
    ``<set_name>_0001.bin``, ... using the same record layout as the input
    (8-byte native ``'q'`` length followed by the payload).

    A chunk file is only created once there is a record to put in it, so no
    empty trailing chunk appears when the record count is a multiple of
    CHUNK_SIZE, and an empty input produces no chunk files at all.

    Args:
        set_name: prefix for the chunk file names (e.g. ``"train"``).
        chunks_dir: output directory; created if missing.
        bin_file: path of the binary file to split.
    """
    make_folder(chunks_dir)
    # ``with`` closes the input file even if a record is malformed.
    with open(bin_file, "rb") as reader:
        chunk = 0
        len_bytes = reader.read(8)          # header of the next pending record
        while len_bytes:
            chunk_fname = os.path.join(chunks_dir, '%s_%04d.bin' % (set_name, chunk)) # new chunk
            with open(chunk_fname, 'wb') as writer:
                for _ in range(CHUNK_SIZE):
                    str_len = struct.unpack('q', len_bytes)[0]
                    # unpack raises struct.error on a truncated payload
                    example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
                    writer.write(struct.pack('q', str_len))
                    writer.write(struct.pack('%ds' % str_len, example_str))
                    len_bytes = reader.read(8)
                    if not len_bytes:       # end of input reached
                        break
            chunk += 1


if __name__ == "__main__":
    # 1) Shuffle every split, keeping article/abstract lines paired.
    shuffle_text_data("train.article.txt", "train.title.txt", "train.art.shuf.txt", "train.abs.shuf.txt")
    shuffle_text_data("valid.article.filter.txt", "valid.title.filter.txt", "valid.art.shuf.txt", "valid.abs.shuf.txt")
    shuffle_text_data("test.article.filter.txt", "test.title.filter.txt", "test.art.shuf.txt", "test.abs.shuf.txt")
    print("Completed shuffling train & valid & test text files")

    # 2) Rebuild the .bin files (and the vocab file) from scratch.
    delete_folder(finished_path)
    creating_finished_data()        #create bin files
    print("Completed creating bin file for train & valid & test")

    # 3) Split each .bin file into CHUNK_SIZE-record chunks.
    delete_folder(chunk_path)
    chunk_file("train", os.path.join(chunk_path, "train"), train_bin_path)
    chunk_file("valid", os.path.join(chunk_path, "main_valid"), valid_bin_path)
    chunk_file("test", os.path.join(chunk_path, "main_test"), test_bin_path)
    print("Completed chunking main bin files into smaller ones")

    # 4) ROUGE evaluation on the full ~190k-sentence (1.9 lakh) set is slow, so build a
    #    mini validation set and a mini test set from one chunk each.
    # NOTE(review): both mini sets are drawn from the main_valid chunks (not main_test) -- confirm intended.
    make_folder(os.path.join(chunk_path, "valid"))
    make_folder(os.path.join(chunk_path, "test"))
    bin_chunks = os.listdir(os.path.join(chunk_path, "main_valid"))
    bin_chunks.sort()
    # random.sample() needs a sequence: sampling from a set was deprecated in
    # Python 3.9 and raises TypeError from 3.11 on. os.listdir() returns no
    # duplicate names, so the sorted list can be sampled directly.
    samples = random.sample(bin_chunks[:-1], 2)  # Exclude last bin file; contains only 9k sentences
    valid_chunk, test_chunk = samples[0], samples[1]
    shutil.copyfile(os.path.join(chunk_path, "main_valid", valid_chunk),os.path.join(chunk_path, "valid", "valid_00.bin"))
    shutil.copyfile(os.path.join(chunk_path, "main_valid", test_chunk), os.path.join(chunk_path, "test", "test_00.bin"))

    # Optional clean-up, left disabled:
    # delete_folder(finished_path)
    # delete_folder(os.path.join(chunk_path, "main_valid"))

#2.Dataset

In [0]:
from google.colab import drive                  #Most of these files are shared on google drive, you can simply save them into you personal drive and have access directly
drive.mount('/content/drive')

##A. Gigaword dataset(3.8M) 



1.   Binary data files: containing the chunked training, validation and testing data file in binary format. **LINK**:  https://drive.google.com/drive/folders/1se96ql8HQx1Sg1EiJ66NchqH2BWm3vEM?usp=sharing

2.   pre-trained model: ML training for 10,000 iterations. **LINK**:  https://drive.google.com/file/d/1tEiDx77a9Tf6AA8vYHC6t2tIF966Uj2f/view?usp=sharing



3.   **ML** all training checkpointed models for validating purpose:  10,000 - 20,000 iterations. **LINK**: https://drive.google.com/drive/folders/1HFqaVSc56CFwAk7S9AT9yv29bzlZNrVA?usp=sharing



4.   **ML+RL(r)** all training checkpointed models for validating purpose:  10,000-20,000 iterations. **LINK**: https://drive.google.com/open?id=16u6iaVKmSg636V1VlVHIAxiy0b17svDb



5.   **RL(r), RL(b)**: all training checkpointed models for validating purpose: these models files are currently not on google drive, available upon request.

6.   **Testing models files**: contain testing models for ROUGE and BERTScore testing. **LINK**: https://drive.google.com/drive/folders/1HxaIm1VuCu7tQKBplQvA-WU1NpijGv0t?usp=sharing





##B. CNN/DM dataset(287K)



1.   Binary data files: chunked training, validation, testing binary data files. **LINK**: https://drive.google.com/drive/folders/1_-qeitl6QmFoa_At8bZXlvvfsEVxHoNV?usp=sharing ;  vocab file: https://drive.google.com/file/d/1INJZ7F5vwqISG5lgbn-3FXIz8hqDr390/view?usp=sharing

2.   Pre-trained models(**ML**): 10,000 iterations. LINK: https://drive.google.com/file/d/1-aOEDsjrYNaQed_HfgFZmI19x1Wu-mCo/view?usp=sharing

3.   **ML** training files: LINK: https://drive.google.com/drive/folders/155ldJSInimq06Xo7cdhkfCeSI8QNypem?usp=sharing

4.   **ML+RL** training files:  LINK: https://drive.google.com/drive/folders/1WuESSVflSyt92G28yu0pfEzlksxRgOt3?usp=sharing



5.   **RL(r)** training models files:  LINK: https://drive.google.com/drive/folders/1j1GKlRhX3VtkKEnKINtM3IpU3QkIFk_3?usp=sharing


6.  **RL(b)** training models files: LINK:  https://drive.google.com/drive/folders/1HeF2NOK8u9b9a1Q-bZS6lNvliPGlQApc?usp=sharing

7.  **Testing models files**:  https://drive.google.com/drive/folders/1w8dqO4IMQVP6LCGosTb0cJ_3095VRTk4?usp=sharing








##C. WikiHow dataset (161K)

1.   Binary data files:  https://drive.google.com/drive/folders/1oaYyf3NPYYbrnJCRXt6OAb4ngAX8UsTZ?usp=sharing

2.   Pre-trained(**ML**) models:  LINK: https://drive.google.com/file/d/1-_AIyVz-4T0VOYE7PSSaUdrw_MG6lonT/view?usp=sharing

3.   **ML** training model files:  https://drive.google.com/drive/folders/15w-hSaHmhGkTzFlABct6LnYjjgHZjs-0?usp=sharing


4.   **ML+RL(r)** training model files:  Available upon request.


5.   **RL(r)** training model files:  LINK: https://drive.google.com/drive/folders/1gPPBGrYQkGd06kcCkJ-TdCtilFPPPkGv?usp=sharing


6.  **RL(b)** training models files:  available upon request.

7.  **Testing models files**: https://drive.google.com/drive/folders/1pmJ7BfxENJTjp8RdjBK_S8DEYL3bnmih?usp=sharing





#3. Package

In [0]:
pip install rouge           #ROUGE metric

In [0]:
pip install bert-score      #BERTScore

In [0]:
import glob
import random
import struct
import csv
from tensorflow.core.example import example_pb2
import numpy as np

import queue as Queue
import time
from random import shuffle
from threading import Thread


import tensorflow as tf

import torch as T
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F

import os

from rouge import Rouge
import argparse


import shutil
import collections
import tqdm

import gc
import bert_score
import argparse



#4. Config

In [0]:
train_data_path = 	'drive/My Drive/final WIKIHOW/WikiHow_bin/wiki_chunked/train/train_*'   #this is the location of trunked training files                         #"drive/My Drive/WikiHow_bin/wiki_chunked/train/train_*"
valid_data_path =  'drive/My Drive/final WIKIHOW/WikiHow_bin/wiki_chunked/main_valid/valid_0000.bin'

vocab_path = 		'drive/My Drive/final WIKIHOW/WikiHow_bin/vocab_wiki'
   #this is the location of vocab                                                            #"drive/My Drive/WikiHow_bin/vocab_wiki"


# Hyperparameters
# NOTE(review): the model/training code that consumes most of these values is
# not part of this cell; descriptions marked "presumably" need confirming there.
hidden_dim = 512            # presumably the recurrent hidden-state size -- TODO confirm
emb_dim = 256               # presumably the word-embedding size -- TODO confirm
batch_size =  10          # presumably the batch size handed to the Batcher -- TODO confirm
max_enc_steps = 800         # Example truncates each article to this many whitespace-separated tokens
max_dec_steps = 150         # decoder input/target sequences are truncated and padded to this length
beam_size = 4               # presumably the beam width used by the beam search -- TODO confirm
min_dec_steps= 3            # presumably the minimum generated length during decoding -- TODO confirm
vocab_size = 20000          # presumably the max_size given to Vocab (distinct from VOCAB_SIZE used in pre-processing) -- TODO confirm

lr = 0.001                  # presumably the optimiser learning rate -- TODO confirm
rand_unif_init_mag = 0.02   # presumably the magnitude for uniform weight initialisation -- TODO confirm
trunc_norm_init_std = 1e-4  # presumably the std for (truncated) normal weight initialisation -- TODO confirm
max_batch_queue = 30        # max number of batches the batch_queue can hold

eps = 1e-12                 # presumably a small constant for numerical stability -- TODO confirm
max_iterations = 15000      # presumably the iteration count at which training stops -- TODO confirm


save_model_path = 'drive/My Drive/local pre'  #this is the location of pre-trained model                                          #'drive/My Drive/RL_wiki_model'             #"drive/My Drive/wiki_model"
new_save_model_path = 'drive/My Drive/local RLb'  #this is the location you want your new model to be saved    #'drive/My Drive/RL_wiki_model'

intra_encoder = False
intra_decoder = True
local_attention_d = 100

#5. Vocab

In [0]:

SENTENCE_START = '<s>'      # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
SENTENCE_END = '</s>'       # or <t>, </t>

PAD_TOKEN = '[PAD]'         # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNKNOWN_TOKEN = '[UNK]'     # This has a vocab id, which is used to represent out-of-vocabulary words
START_DECODING = '[START]'  # This has a vocab id, which is used at the start of every decoder input sequence
STOP_DECODING = '[STOP]'    # This has a vocab id, which is used at the end of untruncated target sequences

# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.

class Vocab(object):
  """Two-way lookup table between words (strings) and integer ids."""

  def __init__(self, vocab_file, max_size):
    """Load the vocabulary from vocab_file.

    param  vocab_file: path to a text file with one "<word> <frequency>" entry per line, most frequent word first. The frequency column is ignored.
    param  max_size: integer cap on the number of entries; 0 means read the whole file. The cap is only checked after a word from the file has been added.

    Raises Exception if the file contains one of the reserved tokens or a duplicated word. Lines that do not split into exactly two fields are reported and skipped."""
    self._word_to_id = {}     # word -> id
    self._id_to_word = {}     # id -> word
    self._count = 0           # number of entries so far; also the next free id

    # The special tokens always take ids 0, 1, 2, 3 in this order.
    for special in (UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING):
      self._add(special)

    reserved = [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]
    with open(vocab_file, 'r') as vocab_f:
      for line in vocab_f:
        fields = line.split()
        if len(fields) != 2:
          print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
          continue
        word = fields[0]
        if word in reserved:
          raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % word)
        if word in self._word_to_id:
          raise Exception('Duplicated word in vocabulary file: %s' % word)
        self._add(word)
        if max_size != 0 and self._count >= max_size:
          print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
          break

    print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))

  def _add(self, word):
    """Register word under the next free id."""
    self._word_to_id[word] = self._count
    self._id_to_word[self._count] = word
    self._count += 1

  def word2id(self, word):
    """Return the id of word, or the [UNK] id when the word is out of vocabulary."""
    if word in self._word_to_id:
      return self._word_to_id[word]
    return self._word_to_id[UNKNOWN_TOKEN]

  def id2word(self, word_id):
    """Return the word stored under word_id; raise ValueError for an unknown id."""
    if word_id in self._id_to_word:
      return self._id_to_word[word_id]
    raise ValueError('Id not found in vocab: %d' % word_id)

  def size(self):
    """Return the number of entries, special tokens included."""
    return self._count

  # Disabled helper, kept as an inert string literal exactly as in the original.
  '''
  def write_metadata(self, fpath):
    """Writes metadata file for Tensorboard word embedding visualizer as described here:
      https://www.tensorflow.org/get_started/embedding_viz
    Args:
      fpath: place to write the metadata file
    """
    print("Writing word embedding metadata file to %s..." % (fpath))
    with open(fpath, "w") as f:
      fieldnames = ['word']
      writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
      for i in range(self.size()):
        writer.writerow({"word": self._id_to_word[i]})
  '''




def example_generator(data_path, single_pass):
  """Generates tf.Examples from data files.
    Binary data format: <length><blob>. <length> is an 8-byte native 'q'
    integer giving the byte size of <blob>. <blob> is a serialized tf.Example
    proto.
    https://www.tensorflow.org/tutorials/load_data/tfrecord

    param  data_path:
      Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
    param  single_pass:
      Boolean. If True, go through the files once in sorted order, then return. Otherwise loop forever, re-globbing and shuffling the FILE order on every pass (records inside a file always come out in stored order).
    Yields:
      Deserialized tf.Example.
    Raises:
      AssertionError if data_path matches no file.
  """
  while True:
    filelist = glob.glob(data_path)                                             # get the list of datafiles
    assert filelist, ('Error: Empty filelist at %s' % data_path)                # check filelist isn't empty
    if single_pass:
      filelist = sorted(filelist)
    else:
      random.shuffle(filelist)                                                  # randomise the file order if not single_pass
    for f in filelist:
      # ``with`` closes each chunk file once it is exhausted, or when the
      # generator is closed early (previously every handle was left open).
      with open(f, 'rb') as reader:
        while True:
          len_bytes = reader.read(8)
          if not len_bytes:
            break # finished reading this file
          str_len = struct.unpack('q', len_bytes)[0]
          example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
          yield example_pb2.Example.FromString(example_str)
    if single_pass:
      print("example_generator completed reading all datafiles. No more data.")
      break


def article2ids(article_words, vocab):
  """Map the article words to their ids. Also return a list of OOVs in the article.
  Args:
    article_words: list of words (strings)
    vocab: Vocabulary object (needs word2id() and size())
  Returns:
    ids:
      A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002.
    oovs:
      A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers."""
  ids = []
  oovs = []
  oov_position = {}                         # OOV word -> its index in oovs; avoids rescanning the list per token
  unk_id = vocab.word2id(UNKNOWN_TOKEN)
  vocab_size = vocab.size()
  for w in article_words:
    i = vocab.word2id(w)
    if i == unk_id:                         # If w is OOV
      if w not in oov_position:             # First time this OOV is seen
        oov_position[w] = len(oovs)
        oovs.append(w)
      # Temporary id: e.g. 50000 for the first article OOV, 50001 for the second...
      ids.append(vocab_size + oov_position[w])
    else:
      ids.append(i)
  return ids, oovs


def abstract2ids(abstract_words, vocab, article_oovs):
  """Convert the abstract's words to ids, reusing the article's temporary OOV ids.
  Args:
    abstract_words: list of words (strings)
    vocab: Vocabulary object
    article_oovs: list of in-article OOV words (strings), ordered by their temporary article OOV numbers
  Returns:
    ids: List of ids (integers). A word unknown to the vocab gets vocab.size() + its position in article_oovs when it occurs there, and the UNK id otherwise."""
  unk_id = vocab.word2id(UNKNOWN_TOKEN)

  def encode(word):
    word_id = vocab.word2id(word)
    if word_id != unk_id:                   # in-vocabulary word
      return word_id
    if word not in article_oovs:            # OOV that the article does not contain
      return unk_id
    return vocab.size() + article_oovs.index(word)   # in-article OOV -> temporary id

  return [encode(word) for word in abstract_words]


def outputids2words(id_list, vocab, article_oovs):
  """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode).

    param  id_list: list of ids (integers)
            vocab: Vocabulary object
    param  article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode)
    Returns:
            words: list of words (strings)
    Raises:
            ValueError: if an id is neither in the vocabulary nor a valid temporary OOV id for this example.
  """
  words = []
  for i in id_list:
    try:
      w = vocab.id2word(i) # might be [UNK]
    except ValueError: # i is not a vocabulary id, so it must be a temporary article-OOV id
      assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
      article_oov_idx = i - vocab.size()
      # Explicit range check: list indexing raises IndexError (not ValueError),
      # and a negative index would silently wrap around to the wrong word.
      if not 0 <= article_oov_idx < len(article_oovs):
        raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
      w = article_oovs[article_oov_idx]
    words.append(w)
  return words


def abstract2sents(abstract):
  """Extract the sentences enclosed by SENTENCE_START / SENTENCE_END tags.
  Args:
    abstract: string in which each sentence is wrapped in <s> ... </s>
  Returns:
    sents: list of the text found between each tag pair (tags removed, inner whitespace untouched). Scanning stops at the first missing opening or closing tag."""
  sentences = []
  cursor = 0
  while True:
    open_at = abstract.find(SENTENCE_START, cursor)
    if open_at == -1:                       # no further opening tag
      return sentences
    close_at = abstract.find(SENTENCE_END, open_at + 1)
    if close_at == -1:                      # opening tag without a closing tag
      return sentences
    sentences.append(abstract[open_at + len(SENTENCE_START):close_at])
    cursor = close_at + len(SENTENCE_END)

'''
def show_art_oovs(article, vocab):
  """Returns the article string, highlighting the OOVs by placing __underscores__ around them"""
  unk_token = vocab.word2id(UNKNOWN_TOKEN)
  words = article.split(' ')
  words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
  out_str = ' '.join(words)
  return out_str


def show_abs_oovs(abstract, vocab, article_oovs):
  """Returns the abstract string, highlighting the article OOVs with __underscores__.
  If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!.
  Args:
    abstract: string
    vocab: Vocabulary object
    article_oovs: list of words (strings), or None (in baseline mode)
  """
  unk_token = vocab.word2id(UNKNOWN_TOKEN)
  words = abstract.split(' ')
  new_words = []
  for w in words:
    if vocab.word2id(w) == unk_token: # w is oov
      if article_oovs is None: # baseline mode
        new_words.append("__%s__" % w)
      else: # pointer-generator mode
        if w in article_oovs:
          new_words.append("__%s__" % w)
        else:
          new_words.append("!!__%s__!!" % w)
    else: # w is in-vocab word
      new_words.append(w)
  out_str = ' '.join(new_words)
  return out_str
'''

#6. Batcher

In [0]:

import queue as Queue
import time
from random import shuffle
from threading import Thread

import numpy as np
import tensorflow as tf


import random
random.seed(1234)

class Example(object):
  '''
  Read an article, its abstract sentences and the vocab, and turn them into
  the id sequences the Batch class needs.

  Attributes set by __init__:
    enc_len, enc_input: truncated article length and its ids (OOVs -> UNK id)
    enc_input_extend_vocab, article_oovs: article ids with temporary OOV ids,
      and the OOV words themselves (see article2ids)
    dec_input, dec_len: decoder input ids ([START] + abstract) and their length
    target: decoder target ids (abstract with temporary OOV ids, + [STOP] when not truncated)
    original_article, original_abstract, original_abstract_sents: source strings
  '''
  # NOTE: the docstring above must share the 2-space indentation of the
  # methods below; indenting it differently makes the class body an
  # IndentationError.

  def __init__(self, article, abstract_sentences, vocab):
    """Build one example.

    param  article: article text; tokens are separated by whitespace
    param  abstract_sentences: list of abstract sentence strings
    param  vocab: Vocab object
    """
    # Get ids of special tokens
    start_decoding = vocab.word2id(START_DECODING)
    stop_decoding = vocab.word2id(STOP_DECODING)

    # Process the article
    article_words = article.split()
    if len(article_words) > max_enc_steps:              # truncate the input article to max_enc_steps tokens
      article_words = article_words[ : max_enc_steps]
    self.enc_len = len(article_words)                   # store the length after truncation but before padding
    self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token

    # Process the abstract
    abstract = ' '.join(abstract_sentences)                 # string
    abstract_words = abstract.split()                       # list of strings, split on whitespace
    abs_ids = [vocab.word2id(w) for w in abstract_words]    # list of word ids; OOVs are represented by the id for UNK token

    # Decoder input uses plain vocab ids (OOVs -> UNK)
    self.dec_input, _ = self.get_dec_inp_targ_seqs(abs_ids, max_dec_steps, start_decoding, stop_decoding)
    self.dec_len = len(self.dec_input)

    # Pointer-generator extras: a version of enc_input where in-article OOVs
    # carry their temporary OOV id, plus the in-article OOV words themselves
    self.enc_input_extend_vocab, self.article_oovs = article2ids(article_words, vocab)

    # Reference summary where in-article OOVs are represented by their temporary article OOV id
    abs_ids_extend_vocab = abstract2ids(abstract_words, vocab, self.article_oovs)

    # Decoder target uses the extended-vocab ids
    _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, max_dec_steps, start_decoding, stop_decoding)

    # Store the original strings
    self.original_article = article
    self.original_abstract = abstract
    self.original_abstract_sents = abstract_sentences

  def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
    """Return (decoder input, target) for sequence, both of equal length <= max_len.

    The input is start_id followed by the sequence. The target is the sequence
    followed by stop_id, except that the stop id is dropped when the input had
    to be truncated to max_len.
    """
    inp = [start_id] + sequence[:]
    target = sequence[:]
    if len(inp) > max_len: # truncate
      inp = inp[:max_len]
      target = target[:max_len] # no end_token
    else: # no truncation
      target.append(stop_id) # end token
    assert len(inp) == len(target)
    return inp, target

  def pad_decoder_inp_targ(self, max_len, pad_id):
    """Pad dec_input and target in place with pad_id up to max_len (longer lists are left as they are)."""
    self.dec_input.extend([pad_id] * (max_len - len(self.dec_input)))
    self.target.extend([pad_id] * (max_len - len(self.target)))

  def pad_encoder_input(self, max_len, pad_id):
    """Pad enc_input and enc_input_extend_vocab in place with pad_id up to max_len (longer lists are left as they are)."""
    self.enc_input.extend([pad_id] * (max_len - len(self.enc_input)))
    self.enc_input_extend_vocab.extend([pad_id] * (max_len - len(self.enc_input_extend_vocab)))



#-------------------------------------------------------------------------------------------------


class Batch(object):
  """Numpy view of a list of Example objects: padded encoder/decoder id matrices plus the original strings."""

  def __init__(self, example_list, vocab, batch_size):
    """Pad the examples in place and pack them into arrays with batch_size rows."""
    self.batch_size = batch_size
    self.pad_id = vocab.word2id(PAD_TOKEN)      # id used for every padded position
    self.init_encoder_seq(example_list)         # encoder-side arrays
    self.init_decoder_seq(example_list)         # decoder inputs and targets
    self.store_orig_strings(example_list)       # untouched source strings

  def init_encoder_seq(self, example_list):
    """Build enc_batch, enc_lens, enc_padding_mask and the extended-vocab encoder batch."""
    # Width of the encoder arrays = longest (already truncated) article in this batch
    max_enc_seq_len = max(ex.enc_len for ex in example_list)

    # Bring every example up to that width
    for ex in example_list:
      ex.pad_encoder_input(max_enc_seq_len, self.pad_id)

    # The second dimension therefore differs from batch to batch
    enc_shape = (self.batch_size, max_enc_seq_len)
    self.enc_batch = np.zeros(enc_shape, dtype=np.int32)
    self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
    self.enc_padding_mask = np.zeros(enc_shape, dtype=np.float32)

    for row, ex in enumerate(example_list):
      self.enc_batch[row, :] = ex.enc_input[:]
      self.enc_lens[row] = ex.enc_len
      self.enc_padding_mask[row, :ex.enc_len] = 1   # 1 on real tokens, 0 on padding

    # Pointer-generator extras
    self.max_art_oovs = max(len(ex.article_oovs) for ex in example_list)   # most in-article OOVs of any example
    self.art_oovs = [ex.article_oovs for ex in example_list]               # the OOV words themselves
    # Encoder batch in which in-article OOVs carry their temporary ids
    self.enc_batch_extend_vocab = np.zeros(enc_shape, dtype=np.int32)
    for row, ex in enumerate(example_list):
      self.enc_batch_extend_vocab[row, :] = ex.enc_input_extend_vocab[:]

  def init_decoder_seq(self, example_list):
    """Build dec_batch, target_batch and dec_lens, all max_dec_steps wide."""
    for ex in example_list:
      ex.pad_decoder_inp_targ(max_dec_steps, self.pad_id)

    dec_shape = (self.batch_size, max_dec_steps)
    self.dec_batch = np.zeros(dec_shape, dtype=np.int32)
    self.target_batch = np.zeros(dec_shape, dtype=np.int32)
    self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)

    for row, ex in enumerate(example_list):
      self.dec_batch[row, :] = ex.dec_input[:]
      self.target_batch[row, :] = ex.target[:]
      self.dec_lens[row] = ex.dec_len

  def store_orig_strings(self, example_list):
    """Keep the untokenised article, abstract and abstract sentences of every example."""
    articles, abstracts, abstract_sents = [], [], []
    for ex in example_list:
      articles.append(ex.original_article)
      abstracts.append(ex.original_abstract)
      abstract_sents.append(ex.original_abstract_sents)
    self.original_articles = articles
    self.original_abstracts = abstracts
    self.original_abstracts_sents = abstract_sents


#-------------------------------------------------------------------------------------------------
class Batcher(object):
  """Produce Batch objects from the data files using background threads.

  "Example" threads read (article, abstract) string pairs, wrap them in
  Example objects and push them onto an example queue. "Batch" threads take
  Examples off that queue, group them into Batch objects and push those onto
  the batch queue that next_batch() drains. All threads are daemon threads.
  """
  BATCH_QUEUE_MAX = max_batch_queue  #30 # max number of batches the batch_queue can hold

  def __init__(self, data_path, vocab, mode, batch_size, single_pass):
    """Create the queues and start the filler (and optional watcher) threads.

    Args:
      data_path: forwarded to example_generator() to locate the data.
      vocab: vocabulary object handed to every Example and Batch.
      mode: 'decode' makes each batch a single example repeated batch_size
        times; any other value produces regular batches.
      batch_size: number of examples per batch.
      single_pass: if True the dataset is read once and the reader thread
        stops; otherwise a watcher thread restarts dead filler threads.
    """
    self._data_path = data_path
    self._vocab = vocab
    self._single_pass = single_pass
    self.mode = mode
    self.batch_size = batch_size
    # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
    self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
    self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)

    # Different settings depending on whether we're in single_pass mode or not
    if single_pass:
      self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
      self._num_batch_q_threads = 1  # just one thread to batch examples
      self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing
      self._finished_reading = False # this will tell us when we're finished reading the dataset
    else:
      self._num_example_q_threads = 1 #16 # num threads to fill example queue
      self._num_batch_q_threads = 1 #4  # num threads to fill batch queue
      self._bucketing_cache_size = 1 #100 # how many batches-worth of examples to load into cache before bucketing

    # Start the threads that load the queues
    self._example_q_threads = []
    for _ in range(self._num_example_q_threads):
      self._example_q_threads.append(Thread(target=self.fill_example_queue))
      self._example_q_threads[-1].daemon = True
      self._example_q_threads[-1].start()
    self._batch_q_threads = []
    for _ in range(self._num_batch_q_threads):
      self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
      self._batch_q_threads[-1].daemon = True
      self._batch_q_threads[-1].start()

    # Start a thread that watches the other threads and restarts them if they're dead
    if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
      self._watch_thread = Thread(target=self.watch_threads)
      self._watch_thread.daemon = True
      self._watch_thread.start()

  def next_batch(self):
    """Return the next Batch, blocking until one is available.

    Returns None only in single_pass mode, when the batch queue is empty and
    the reader thread has reported that the dataset is exhausted.
    """
    if self._batch_queue.qsize() == 0:
      if self._single_pass and self._finished_reading:
        tf.logging.info("Finished reading dataset in single_pass mode.")
        return None

    batch = self._batch_queue.get() # get the next Batch
    return batch

  def fill_example_queue(self):
    """Thread target: read text pairs, build Examples, enqueue them.

    In single_pass mode the loop ends (and _finished_reading is set) when the
    input is exhausted.

    Raises:
      Exception: if the input runs out while single_pass is off.
    """
    input_gen = self.text_generator(example_generator(self._data_path, self._single_pass))

    while True:
      try:
        (article, abstract) = next(input_gen) # read the next example from file. article and abstract are both strings.
      except StopIteration: # if there are no more examples:
        tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
        if self._single_pass:
          tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
          self._finished_reading = True
          break
        else:
          raise Exception("single_pass mode is off but the example generator is out of data; error.")

      # The whole abstract is treated as a single "sentence".
      abstract_sentences = [abstract.strip()]
      example = Example(article, abstract_sentences, self._vocab) # Process into an Example.
      self._example_queue.put(example) # place the Example in the example queue.

  def fill_batch_queue(self):
    """Thread target: turn queued Examples into Batches, forever."""
    while True:
      if self.mode == 'decode':
        # beam search decode mode single example repeated in the batch
        ex = self._example_queue.get()
        b = [ex for _ in range(self.batch_size)]
        self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
      else:
        # Get bucketing_cache_size-many batches of Examples into a list, then sort
        inputs = []
        for _ in range(self.batch_size * self._bucketing_cache_size):
          inputs.append(self._example_queue.get())
        inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True) # sort by length of encoder sequence

        # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
        batches = []
        for i in range(0, len(inputs), self.batch_size):
          batches.append(inputs[i:i + self.batch_size])
        if not self._single_pass:
          shuffle(batches)
        for b in batches:  # each b is a list of Example objects
          self._batch_queue.put(Batch(b, self._vocab, self.batch_size))

  def watch_threads(self):
    """Thread target: every 60 s log queue sizes and restart dead filler threads."""
    while True:
      tf.logging.info(
        'Bucket queue size: %i, Input queue size: %i',
        self._batch_queue.qsize(), self._example_queue.qsize())

      time.sleep(60)
      for idx,t in enumerate(self._example_q_threads):
        if not t.is_alive(): # if the thread is dead
          tf.logging.error('Found example queue thread dead. Restarting.')
          new_t = Thread(target=self.fill_example_queue)
          self._example_q_threads[idx] = new_t
          new_t.daemon = True
          new_t.start()
      for idx,t in enumerate(self._batch_q_threads):
        if not t.is_alive(): # if the thread is dead
          tf.logging.error('Found batch queue thread dead. Restarting.')
          new_t = Thread(target=self.fill_batch_queue)
          self._batch_q_threads[idx] = new_t
          new_t.daemon = True
          new_t.start()


  def text_generator(self, example_generator):
    """Yield (article_text, abstract_text) string pairs from tf.Example records.

    Args:
      example_generator: iterable of tf.Example objects. (This parameter
        shadows the module-level function of the same name inside this method.)

    Records whose text cannot be extracted/decoded, or whose article is empty,
    are skipped. The generator finishes when the input iterable is exhausted.
    """
    # Iterate with `for` rather than calling next() in a `while True` loop:
    # under PEP 479 a StopIteration escaping next() inside a generator is
    # turned into RuntimeError, which fill_example_queue's
    # `except StopIteration` would never see.
    for e in example_generator: # e is a tf.Example
      try:
        article_text = e.features.feature['article'].bytes_list.value[0] # the article text was saved under the key 'article' in the data files
        abstract_text = e.features.feature['abstract'].bytes_list.value[0] # the abstract text was saved under the key 'abstract' in the data files
        article_text = article_text.decode()
        abstract_text = abstract_text.decode()
      except (ValueError, IndexError):
        # IndexError: the feature has no value; ValueError covers decode failures.
        tf.logging.error('Failed to get article or abstract from example')
        continue
      if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1
        continue
      yield (article_text, abstract_text)

#7. Model with global attention 

In [0]:
'''
import torch as T
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
'''

def init_lstm_wt(lstm):
    """Initialise the parameters of an LSTM module in place.

    Parameters whose name contains 'weight' are redrawn uniformly from
    [-rand_unif_init_mag, rand_unif_init_mag]. Parameters whose name contains
    'bias' are zeroed, then the slice [n // 4, n // 2) is set to 1 (the
    forget-gate bias, per the original comment).
    """
    for param_name, _ in lstm.named_parameters():
        if 'weight' in param_name:
            weight = getattr(lstm, param_name)
            weight.data = weight.data.uniform_(-rand_unif_init_mag, rand_unif_init_mag)
            continue
        if 'bias' not in param_name:
            continue

        bias = getattr(lstm, param_name)
        size = bias.size(0)
        lo, hi = size // 4, size // 2
        bias.data = bias.data.fill_(0.)
        # set forget bias to 1
        bias.data[lo:hi].fill_(1.)

def init_linear_wt(linear):
    """Re-initialise a linear layer in place from N(0, trunc_norm_init_std**2).

    The weight is always redrawn; the bias only when the layer has one.
    """
    weight = linear.weight
    weight.data = weight.data.normal_(std=trunc_norm_init_std)

    bias = linear.bias
    if bias is None:
        return
    bias.data = bias.data.normal_(std=trunc_norm_init_std)

def init_wt_normal(wt):
    """Redraw tensor/parameter `wt` in place from a zero-mean normal with std trunc_norm_init_std."""
    values = wt.data
    wt.data = values.normal_(std=trunc_norm_init_std)


class Encoder(nn.Module):
    """
    Single-layer bidirectional LSTM encoder, followed by two linear layers
    that reduce the concatenated final states of both directions to hidden_dim.
    """
    def __init__(self):
        super().__init__()

        # batch_first=True: inputs/outputs are laid out as (batch, seq, feature)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        init_lstm_wt(self.lstm)

        # Reduce the 2*hidden_dim (forward + backward) hidden state ...
        self.reduce_h = nn.Linear(hidden_dim * 2, hidden_dim)
        init_linear_wt(self.reduce_h)
        # ... and cell state to hidden_dim.
        self.reduce_c = nn.Linear(hidden_dim * 2, hidden_dim)
        init_linear_wt(self.reduce_c)

    def forward(self, x, seq_lens):
        """
        param  x:  embedded input sequence (batch first)
        param  seq_lens:  length of each sequence in the batch

        return:
        enc_out:  Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        (h_reduced, c_reduced):  Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        """
        packed_input = pack_padded_sequence(x, seq_lens, batch_first=True)
        packed_output, (h_n, c_n) = self.lstm(packed_input)           # h_n / c_n: 2, bs, n_hid
        padded_output, _ = pad_packed_sequence(packed_output, batch_first=True)
        enc_out = padded_output.contiguous()                          # bs, n_seq, 2*n_hid

        # Put the two directions side by side: bs, 2*n_hid
        h_both = T.cat(h_n.unbind(0), dim=1)
        c_both = T.cat(c_n.unbind(0), dim=1)

        h_reduced = F.relu(self.reduce_h(h_both))                     # bs, n_hid
        c_reduced = F.relu(self.reduce_c(c_both))
        return enc_out, (h_reduced, c_reduced)


class encoder_attention(nn.Module):
    """Attention over the encoder outputs, optionally with intra-temporal scaling."""

    def __init__(self):
        super().__init__()
        self.W_h = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)    # projects encoder states; weight matrix only
        self.W_s = nn.Linear(hidden_dim * 2, hidden_dim * 2)                # projects the decoder state
        self.v = nn.Linear(hidden_dim * 2, 1, bias=False)                   # maps each position to a scalar score

    def forward(self, st_hat, h, enc_padding_mask, sum_temporal_srcs):
        ''' Perform attention over encoder hidden states
        :param st_hat: decoder hidden state at current time step
        :param h: encoder hidden states
        :param enc_padding_mask: multiplied into the attention weights so that padded positions get weight 0
        :param sum_temporal_srcs: if using intra-temporal attention, contains summation of attention weights from previous decoder time steps (None on the first step)
        return
        ct_e: encoder context vector
        at:   attention weight
        sum_temporal_srcs:  sum of attention weights from previous decoder time steps, will need to be input again for next iteration
        '''
        # Standard attention technique (eq 1 in https://arxiv.org/pdf/1704.04368.pdf)
        enc_fea = self.W_h(h)                           # bs,n_seq,2*n_hid
        dec_fea = self.W_s(st_hat).unsqueeze(1)         # bs,1,2*n_hid
        scores = self.v(T.tanh(enc_fea + dec_fea)).squeeze(2)   # bs,n_seq

        if not intra_encoder:
            weights = F.softmax(scores, dim=1)
        else:
            # intra-temporal attention     (eq 3 in https://arxiv.org/pdf/1705.04304.pdf)
            exp_scores = T.exp(scores)
            if sum_temporal_srcs is None:
                # First decoder step: nothing to divide by yet; start the running sum.
                weights = exp_scores
                sum_temporal_srcs = get_cuda(T.FloatTensor(scores.size()).fill_(1e-10)) + exp_scores
            else:
                weights = exp_scores / sum_temporal_srcs    # bs, n_seq
                sum_temporal_srcs = sum_temporal_srcs + exp_scores

        # assign 0 probability for padded elements, then renormalise
        masked = weights * enc_padding_mask
        at = masked / masked.sum(1, keepdim=True)           # bs,n_seq

        # Compute encoder context vector: (bs,1,n_seq) x (bs,n_seq,2*n_hid) -> bs,2*n_hid
        ct_e = T.bmm(at.unsqueeze(1), h).squeeze(1)

        return ct_e, at, sum_temporal_srcs

class decoder_attention(nn.Module):
    """Intra-decoder attention over the decoder's own previous hidden states."""

    def __init__(self):
        super().__init__()
        # The layers only exist when intra-decoder attention is switched on.
        if intra_decoder:
            self.W_prev = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.W_s = nn.Linear(hidden_dim, hidden_dim)
            self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, s_t, prev_s):
        '''Perform intra_decoder attention
        Args
        :param s_t: hidden state of decoder at current time step
        :param prev_s: If intra_decoder attention, contains list of previous decoder hidden states
        Returns
        ct_d: decoder context vector (all zeros when intra-decoder attention is off or on the first step)
        prev_s: previous decoder hidden states, extended with s_t when intra-decoder attention is on
        '''
        if intra_decoder is False:
            return get_cuda(T.zeros(s_t.size())), prev_s

        if prev_s is None:
            # First step: no history to attend over; start it with s_t.
            return get_cuda(T.zeros(s_t.size())), s_t.unsqueeze(1)     # history: bs, 1, n_hid

        # Standard attention technique (eq 1 in https://arxiv.org/pdf/1704.04368.pdf)      e = v tanh(Wh + Ws + b)
        hist_fea = self.W_prev(prev_s)                      # bs,t-1,n_hid
        cur_fea = self.W_s(s_t).unsqueeze(1)                # bs,1,n_hid
        scores = self.v(T.tanh(hist_fea + cur_fea)).squeeze(2)   # bs,t-1

        # intra-decoder attention     (eq 7 & 8 in https://arxiv.org/pdf/1705.04304.pdf)
        alpha = F.softmax(scores, dim=1).unsqueeze(1)       # bs, 1, t-1
        ct_d = T.bmm(alpha, prev_s).squeeze(1)              # bs, n_hid
        history = T.cat([prev_s, s_t.unsqueeze(1)], dim=1)  # bs, t, n_hid

        return ct_d, history


class Decoder(nn.Module):
    """One-step LSTM decoder with encoder attention, intra-decoder attention and a pointer/copy distribution."""

    def __init__(self):
        super().__init__()
        self.enc_attention = encoder_attention()
        self.dec_attention = decoder_attention()
        # Merges the token embedding with the previous encoder context vector.
        self.x_context = nn.Linear(hidden_dim*2 + emb_dim, emb_dim)

        self.lstm = nn.LSTMCell(emb_dim, hidden_dim)
        init_lstm_wt(self.lstm)

        # Input: ct_e (2h) + ct_d (h) + st_hat (2h) + x (emb)
        self.p_gen_linear = nn.Linear(hidden_dim * 5 + emb_dim, 1)

        # p_vocab: dec_h (h) + ct_e (2h) + ct_d (h) -> h -> vocab
        self.V = nn.Linear(hidden_dim*4, hidden_dim)
        self.V1 = nn.Linear(hidden_dim, vocab_size)
        init_linear_wt(self.V1)

    def forward(self, x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s):
        """Run a single decoding step.

        Returns (final_dist, s_t, ct_e, sum_temporal_srcs, prev_s): the output
        distribution over the (possibly extended) vocabulary followed by the
        updated LSTM state, encoder context vector and the two attention
        histories, all of which are fed back in on the next step.
        """
        lstm_in = self.x_context(T.cat([x_t, ct_e], dim=1))
        s_t = self.lstm(lstm_in, s_t)

        dec_h, dec_c = s_t
        st_hat = T.cat([dec_h, dec_c], dim=1)
        ct_e, attn_dist, sum_temporal_srcs = self.enc_attention(st_hat, enc_out, enc_padding_mask, sum_temporal_srcs)

        # intra-decoder attention
        ct_d, prev_s = self.dec_attention(dec_h, prev_s)

        # Generation probability                                   bs,1
        p_gen = T.sigmoid(self.p_gen_linear(T.cat([ct_e, ct_d, st_hat, lstm_in], 1)))

        # Vocabulary distribution
        features = T.cat([dec_h, ct_e, ct_d], dim=1)              # bs, 4*n_hid
        logits = self.V1(self.V(features))                        # bs, n_vocab
        vocab_dist = p_gen * F.softmax(logits, dim=1)
        copy_dist = (1 - p_gen) * attn_dist

        # pointer mechanism (as suggested in eq 9 https://arxiv.org/pdf/1704.04368.pdf)
        if extra_zeros is not None:
            vocab_dist = T.cat([vocab_dist, extra_zeros], dim=1)
        final_dist = vocab_dist.scatter_add(1, enc_batch_extend_vocab, copy_dist)

        return final_dist, s_t, ct_e, sum_temporal_srcs, prev_s



class Model(nn.Module):
    """Container for the encoder, the decoder and the shared word embedding table."""

    def __init__(self):
        super().__init__()
        # Sub-modules are built in this order (encoder, decoder, embeddings)
        # and each is moved to the GPU when one is available.
        self.encoder = get_cuda(Encoder())
        self.decoder = get_cuda(Decoder())

        embeds = nn.Embedding(vocab_size, emb_dim)
        init_wt_normal(embeds.weight)
        self.embeds = get_cuda(embeds)





#8. Training utility

In [0]:
import numpy as np
import torch as T


def get_cuda(tensor):
    """Return `tensor` moved to the GPU if CUDA is available, otherwise unchanged.

    Anything with a `.cuda()` method (tensor or module) can be passed.
    """
    return tensor.cuda() if T.cuda.is_available() else tensor

def get_enc_data(batch):
    """Convert the encoder-side numpy arrays of `batch` into (GPU) tensors.

    Returns, in order:
      enc_batch               long tensor built from batch.enc_batch
      enc_lens                batch.enc_lens, returned as-is
      enc_padding_mask        float tensor built from batch.enc_padding_mask
      enc_batch_extend_vocab  long tensor, or None if the batch has none
      extra_zeros             zeros of shape (batch_size, max_art_oovs), or
                              None when the batch has no in-article OOVs
      ct_e                    zeros of shape (batch_size, 2 * hidden_dim),
                              the initial encoder context vector
    """
    enc_lens = batch.enc_lens
    batch_size = len(enc_lens)

    enc_batch = get_cuda(T.from_numpy(batch.enc_batch).long())
    enc_padding_mask = get_cuda(T.from_numpy(batch.enc_padding_mask).float())
    ct_e = get_cuda(T.zeros(batch_size, 2 * hidden_dim))

    if batch.enc_batch_extend_vocab is None:
        enc_batch_extend_vocab = None
    else:
        enc_batch_extend_vocab = get_cuda(T.from_numpy(batch.enc_batch_extend_vocab).long())

    if batch.max_art_oovs > 0:
        extra_zeros = get_cuda(T.zeros(batch_size, batch.max_art_oovs))
    else:
        extra_zeros = None

    return enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e


def get_dec_data(batch):
    """Convert the decoder-side numpy arrays of `batch` into (GPU) tensors.

    Returns (dec_batch, max_dec_len, dec_lens, target_batch) where dec_batch
    and target_batch are long tensors, dec_lens is a float tensor and
    max_dec_len is the largest value in batch.dec_lens (a numpy scalar).
    """
    max_dec_len = np.max(batch.dec_lens)

    dec_batch = get_cuda(T.from_numpy(batch.dec_batch).long())
    dec_lens = get_cuda(T.from_numpy(batch.dec_lens).float())
    target_batch = get_cuda(T.from_numpy(batch.target_batch).long())

    return dec_batch, max_dec_len, dec_lens, target_batch

#9. Training (ROUGE reward function for RL training)

In [0]:
# Seed the `random` module and PyTorch with the same fixed value (123);
# the CUDA generators are seeded as well when a GPU is available.
random.seed(123)
T.manual_seed(123)
if T.cuda.is_available():
    T.cuda.manual_seed_all(123)

class Train(object):
    """Trainer combining maximum-likelihood (MLE) and self-critical RL losses.

    The RL reward is the ROUGE-L F1 score of a sampled summary, with the score
    of the greedily decoded summary as baseline. `opt` must provide the
    attributes load_model, new_lr, train_mle, train_rl, mle_weight and
    rl_weight.
    """

    def __init__(self, opt):
        """Load the vocabulary, start the training batcher and cache special token ids."""
        self.vocab = Vocab(vocab_path, vocab_size)
        self.batcher = Batcher(train_data_path, self.vocab, mode='train',
                               batch_size=batch_size, single_pass=False)
        self.opt = opt
        self.start_id = self.vocab.word2id(START_DECODING)
        self.end_id = self.vocab.word2id(STOP_DECODING)
        self.pad_id = self.vocab.word2id(PAD_TOKEN)
        self.unk_id = self.vocab.word2id(UNKNOWN_TOKEN)
        # Presumably gives the batcher's background threads time to pre-fill
        # their queues before training starts.
        time.sleep(5)

    def save_model(self, iter):
        """Write model and optimizer state to new_save_model_path/<iter, 7 digits>.tar."""
        save_path = new_save_model_path + "/%07d.tar" % iter
        T.save({
            "iter": iter + 1,
            "model_dict": self.model.state_dict(),
            "trainer_dict": self.trainer.state_dict()
        }, save_path)

    def setup_train(self):
        """Create the model and optimizer, optionally restoring a checkpoint.

        If opt.load_model is set, the checkpoint of that name is read from
        save_model_path. If opt.new_lr is set, a fresh Adam optimizer with
        that learning rate replaces the (possibly restored) one.

        Returns:
            The iteration number to start from (0 without a checkpoint).
        """
        self.model = Model()
        self.model = get_cuda(self.model)
        self.trainer = T.optim.Adam(self.model.parameters(), lr=lr)
        start_iter = 0
        if self.opt.load_model is not None:
            load_model_path = os.path.join(save_model_path, self.opt.load_model)
            # Without a GPU, remap tensors saved from CUDA onto the CPU so a
            # GPU-trained checkpoint can still be loaded; with a GPU keep
            # torch.load's default placement.
            map_location = None if T.cuda.is_available() else "cpu"
            checkpoint = T.load(load_model_path, map_location=map_location)
            start_iter = checkpoint["iter"]
            self.model.load_state_dict(checkpoint["model_dict"])
            self.trainer.load_state_dict(checkpoint["trainer_dict"])
            print("Loaded model at " + load_model_path)
        if self.opt.new_lr is not None:
            self.trainer = T.optim.Adam(self.model.parameters(), lr=self.opt.new_lr)
        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, batch):
        ''' Calculate Negative Log Likelihood Loss for the given batch. In order to reduce exposure bias,
                pass the previous generated token as input with a probability of 0.25 instead of ground truth label
        Args:
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param batch: batch object
        Returns:
        :mle_loss: length-normalised NLL, averaged over the batch
        '''
        # Input and target batches for training the decoder
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(batch)
        step_losses = []
        s_t = (enc_hidden[0], enc_hidden[1])                                # Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(self.start_id))     # Input to the decoder
        prev_s = None               # Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None    # Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        for t in range(min(max_dec_len, max_dec_steps)):
            # 1 => feed the ground-truth token, 0 => feed the previously sampled token
            use_gound_truth = get_cuda((T.rand(len(enc_out)) > 0.25)).long()
            x_t = use_gound_truth * dec_batch[:, t] + (1 - use_gound_truth) * x_t
            x_t = self.model.embeds(x_t)
            final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            target = target_batch[:, t]
            log_probs = T.log(final_dist + eps)
            step_loss = F.nll_loss(log_probs, target, reduction="none", ignore_index=self.pad_id)
            step_losses.append(step_loss)
            # Sample the next input from the output distribution. squeeze(1)
            # (rather than squeeze()) keeps the batch dimension when the
            # batch holds a single example.
            x_t = T.multinomial(final_dist, 1).squeeze(1)
            is_oov = (x_t >= vocab_size).long()                             # Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t.detach() + (is_oov) * self.unk_id      # Replace OOVs with [UNK] token

        losses = T.sum(T.stack(step_losses, 1), 1)      # unnormalized losses for each example in the batch; (batch_size)
        batch_avg_loss = losses / dec_lens              # Normalized losses; (batch_size)
        mle_loss = T.mean(batch_avg_loss)               # Average batch loss
        return mle_loss

    def train_batch_RL(self, enc_out, enc_hidden, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, article_oovs, greedy):
        '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are used for ROUGE evaluation
        Args
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param article_oovs: Batch containing list of OOVs in each example
        :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
        Returns:
        :decoded_strs: List of decoded sentences
        :log_probs: Length-normalised log probabilities of the sampled sentences, shape (batch_size,);
                    an empty list when greedy is True
        '''
        s_t = enc_hidden                                                    # Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(self.start_id))     # Input to the decoder
        prev_s = None               # Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None    # Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        inds = []                   # Sampled indices for each time step
        decoder_padding_mask = []   # Padding masks of generated samples
        log_probs = []              # Log probabilities of generated samples
        # 1 => [STOP] not yet produced for this example, 0 => already produced
        mask = get_cuda(T.LongTensor(len(enc_out)).fill_(1))

        for t in range(max_dec_steps):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            if greedy is False:
                multi_dist = Categorical(probs)
                x_t = multi_dist.sample()                                   # perform multinomial sampling
                log_prob = multi_dist.log_prob(x_t)
                log_probs.append(log_prob)
            else:
                _, x_t = T.max(probs, dim=1)                                # perform greedy sampling
            x_t = x_t.detach()
            inds.append(x_t)
            # Padding mask for this step: 1 while [STOP] had not been seen
            # before this step (the [STOP] token itself is still counted).
            mask_t = get_cuda(T.zeros(len(enc_out)))
            mask_t[mask == 1] = 1
            # Clear the flag for examples that emit [STOP] at this step. Use a
            # logical AND: comparison results are bool tensors, and adding
            # two bool tensors never yields 2, so the previous
            # `(a) + (b) == 2` test could never match.
            mask[(mask == 1) & (x_t == self.end_id)] = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= vocab_size).long()                             # Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t + (is_oov) * self.unk_id               # Replace OOVs with [UNK] token

        inds = T.stack(inds, dim=1)
        decoder_padding_mask = T.stack(decoder_padding_mask, dim=1)
        if greedy is False:
            # Sentence log-probability, ignoring tokens generated after [STOP]
            log_probs = T.stack(log_probs, dim=1)
            log_probs = log_probs * decoder_padding_mask
            lens = T.sum(decoder_padding_mask, dim=1)                       # Length of sampled sentence
            log_probs = T.sum(log_probs, dim=1) / lens                      # (bs,)

        decoded_strs = []
        for i in range(len(enc_out)):
            id_list = inds[i].cpu().numpy()
            oovs = article_oovs[i]
            S = outputids2words(id_list, self.vocab, oovs)                  # Words corresponding to the sampled ids
            try:
                S = S[:S.index(STOP_DECODING)]                              # Cut at the first [STOP], if any
            except ValueError:
                pass
            # Sentences shorter than 2 words (e.g. ".") make ROUGE fail;
            # replace them with a placeholder.
            if len(S) < 2:
                S = ["xxx"]
            decoded_strs.append(" ".join(S))

        return decoded_strs, log_probs

    def reward_function(self, decoded_sents, original_sents):
        """Return the ROUGE-L F1 score of each decoded sentence as a float tensor.

        The whole batch is scored in one call first; if that fails, pairs are
        scored one by one and a pair that still fails gets a reward of 0.
        """
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            # Deliberate best effort: a scoring failure must not abort training.
            print("Rouge failed for multi sentence evaluation.. Finding exact pair")
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i], original_sents[i])
                except Exception:
                    print("Error occured at:")
                    print("decoded_sents:", decoded_sents[i])
                    print("original_sents:", original_sents[i])
                    score = [{"rouge-l":{"f":0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        rouge_l_f1 = get_cuda(T.FloatTensor(rouge_l_f1))
        return rouge_l_f1

    def train_one_batch(self, batch, iter):
        """Run one optimisation step on `batch`.

        Computes the MLE loss if opt.train_mle == "yes" and the self-critical
        RL loss if opt.train_rl == "yes" (a disabled loss is a zero tensor),
        back-propagates their weighted sum and steps the optimizer.

        Returns:
            (mle_loss as float, mean sampled reward of the batch -- 0 when RL is off)
        """
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(batch)

        enc_batch = self.model.embeds(enc_batch)                            # Get embeddings for encoder input
        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        # -------------------------------Summarization-----------------------
        if self.opt.train_mle == "yes":
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch)
        else:
            mle_loss = get_cuda(T.FloatTensor([0]))
        # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == "yes":
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs, greedy=False)
            with T.autograd.no_grad():
                # greedy sampling (baseline; no gradients needed)
                greedy_sents, _ = self.train_batch_RL(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs, greedy=True)

            sample_reward = self.reward_function(sample_sents, batch.original_abstracts)
            baseline_reward = self.reward_function(greedy_sents, batch.original_abstracts)
            # Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = -(sample_reward - baseline_reward) * RL_log_probs
            rl_loss = T.mean(rl_loss)

            batch_reward = T.mean(sample_reward).item()
        else:
            rl_loss = get_cuda(T.FloatTensor([0]))
            batch_reward = 0

        # ------------------------------------------------------------------------------------
        self.trainer.zero_grad()
        (self.opt.mle_weight * mle_loss + self.opt.rl_weight * rl_loss).backward()
        self.trainer.step()

        return mle_loss.item(), batch_reward

    def trainIters(self):
        """Main training loop.

        Trains until the iteration counter exceeds max_iterations, printing
        the running average MLE loss and reward every 200 iterations and
        saving a checkpoint every 1000. Ctrl-C exits the process.
        """
        step = self.setup_train()
        count = mle_total = r_total = 0
        while step <= max_iterations:
            batch = self.batcher.next_batch()
            try:
                mle_loss, r = self.train_one_batch(batch, step)
            except KeyboardInterrupt:
                print("-------------------Keyboard Interrupt------------------")
                # Actually call exit: a bare `sys.exit` is a no-op, after which
                # the loop would continue with undefined/stale loss values.
                sys.exit()

            mle_total += mle_loss
            r_total += r
            count += 1
            step += 1

            if step % 200 == 0:
                mle_avg = mle_total / count
                r_avg = r_total / count
                print("iter:", step, "mle_loss:", "%.3f" % mle_avg, "reward:", "%.4f" % r_avg)
                count = mle_total = r_total = 0

            if step % 1000 == 0:
                self.save_model(step)



#9. Training (BERTScore)

In [0]:
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"    #Set cuda device
import sys
import time
import gc

import torch as T
import torch.nn as nn
import torch.nn.functional as F
#from model import Model

#from data_util import config, data
#from data_util.batcher import Batcher
#from data_util.data import Vocab
#from train_util import *
from torch.distributions import Categorical
import bert_score
#from rouge import Rouge
from numpy import random
import argparse

# Fix the random seeds so repeated runs draw the same random numbers.
# ``random`` here is ``numpy.random`` (see the import above), not the stdlib module.
random.seed(123)
T.manual_seed(123)
if T.cuda.is_available():
    # Seed the CUDA generators as well when a GPU is present.
    T.cuda.manual_seed_all(123)

class Train(object):
    """Trainer combining maximum-likelihood and self-critical RL objectives.

    The RL reward is the third value returned by ``bert_score.score`` (the
    BERTScore F1) between generated and reference summaries, see
    ``reward_function``.
    """

    def __init__(self, opt):
        """Create the vocabulary and training batcher and cache special-token ids.

        :param opt: options object; this class reads ``load_model``, ``new_lr``,
            ``train_mle``, ``train_rl``, ``mle_weight`` and ``rl_weight`` from it
        """
        self.vocab = Vocab(vocab_path, vocab_size)
        self.batcher = Batcher(train_data_path, self.vocab, mode='train',
                               batch_size=batch_size, single_pass=False)
        self.opt = opt
        self.start_id = self.vocab.word2id(START_DECODING)
        self.end_id = self.vocab.word2id(STOP_DECODING)
        self.pad_id = self.vocab.word2id(PAD_TOKEN)
        self.unk_id = self.vocab.word2id(UNKNOWN_TOKEN)
        # NOTE(review): presumably gives the batcher time to pre-fill its queue - confirm.
        time.sleep(5)

    def save_model(self, iter):
        """Write model and optimiser state to ``new_save_model_path/<iter>.tar``.

        The stored ``"iter"`` value is ``iter + 1``, i.e. the iteration that
        ``setup_train`` will resume from.
        """
        save_path = new_save_model_path + "/%07d.tar" % iter
        T.save({
            "iter": iter + 1,
            "model_dict": self.model.state_dict(),
            "trainer_dict": self.trainer.state_dict()
        }, save_path)

    def setup_train(self):
        """Create the model and Adam optimiser, optionally restoring a checkpoint.

        Checkpoints are loaded from ``save_model_path`` (note that ``save_model``
        writes to ``new_save_model_path``). When ``opt.new_lr`` is set, a fresh
        Adam optimiser with that learning rate replaces the (possibly restored)
        one.

        :return: iteration to start from (0 when no checkpoint is loaded)
        """
        self.model = Model()
        self.model = get_cuda(self.model)
        self.trainer = T.optim.Adam(self.model.parameters(), lr=lr)
        start_iter = 0
        if self.opt.load_model is not None:
            load_model_path = os.path.join(save_model_path, self.opt.load_model)
            checkpoint = T.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.load_state_dict(checkpoint["model_dict"])
            self.trainer.load_state_dict(checkpoint["trainer_dict"])
            print("Loaded model at " + load_model_path)
        if self.opt.new_lr is not None:
            self.trainer = T.optim.Adam(self.model.parameters(), lr=self.opt.new_lr)
        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, batch):
        '''Calculate the negative log likelihood loss for the given batch.

        The decoder is always fed the ground-truth token (pure teacher forcing);
        no scheduled sampling is performed.
        Args:
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param batch: batch object
        Returns:
        :mle_loss: Scalar tensor; length-normalised loss averaged over the batch
        '''
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(batch)                        #Get input and target batchs for training decoder
        step_losses = []
        s_t = (enc_hidden[0], enc_hidden[1])                                                        #Decoder hidden states
        prev_s = None                                                                               #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None                                                                    #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        for t in range(min(max_dec_len, max_dec_steps)):
            x_t = self.model.embeds(dec_batch[:, t])                                                #Ground-truth token is the decoder input at every step
            final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            log_probs = T.log(final_dist + eps)
            step_loss = F.nll_loss(log_probs, target_batch[:, t], reduction="none", ignore_index=self.pad_id)
            step_losses.append(step_loss)
            # The previous revision also sampled a token from final_dist here, but the
            # sample was always overwritten by the ground truth on the next step.

        losses = T.sum(T.stack(step_losses, 1), 1)                                                  #unnormalized losses for each example in the batch; (batch_size)
        batch_avg_loss = losses / dec_lens                                                          #Normalized losses; (batch_size)
        mle_loss = T.mean(batch_avg_loss)                                                           #Average batch loss
        return mle_loss

    def train_batch_RL(self, enc_out, enc_hidden, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, article_oovs, greedy):
        '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are scored by the reward function
        Args
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param article_oovs: Batch containing list of OOVs in each example
        :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
        Returns:
        :decoded_strs: List of decoded sentences
        :log_probs: Length-normalised log probabilities of sampled sentences (an empty list when greedy)
        '''
        s_t = enc_hidden                                                                            #Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(self.start_id))                             #Input to the decoder
        prev_s = None                                                                               #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None                                                                    #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        inds = []                                                                                   #Stores sampled indices for each time step
        decoder_padding_mask = []                                                                   #Stores padding masks of generated samples
        log_probs = []                                                                              #Stores log probabilites of generated samples
        mask = get_cuda(T.LongTensor(len(enc_out)).fill_(1))                                        #Values that indicate whether [STOP] token has already been encountered; 1 => Not encountered, 0 otherwise

        for t in range(max_dec_steps):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            if greedy is False:
                multi_dist = Categorical(probs)
                x_t = multi_dist.sample()                                                           #perform multinomial sampling
                log_probs.append(multi_dist.log_prob(x_t))
            else:
                _, x_t = T.max(probs, dim=1)                                                        #perform greedy sampling
            x_t = x_t.detach()
            inds.append(x_t)
            mask_t = get_cuda(T.zeros(len(enc_out)))                                                #Padding mask of batch for current time step
            mask_t[mask == 1] = 1                                                                   #If [STOP] is not encountered till previous time step, mask_t = 1 else mask_t = 0
            # Use a logical AND: summing two comparison results and testing "== 2"
            # never matches when comparisons return bool tensors, which left the
            # mask permanently at 1.
            mask[(mask == 1) & (x_t == self.end_id)] = 0                                            #If [STOP] is not encountered till previous time step and current word is [STOP], make mask = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= vocab_size).long()                                                     #Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t + is_oov * self.unk_id                                         #Replace OOVs with [UNK] token

        inds = T.stack(inds, dim=1)
        decoder_padding_mask = T.stack(decoder_padding_mask, dim=1)
        if greedy is False:                                                                         #If multinomial based sampling, compute log probabilites of sampled words
            log_probs = T.stack(log_probs, dim=1)
            log_probs = log_probs * decoder_padding_mask                                            #Not considering sampled words with padding mask = 0
            lens = T.sum(decoder_padding_mask, dim=1)                                               #Length of sampled sentence
            log_probs = T.sum(log_probs, dim=1) / lens  # (bs,)                                     #compute normalizied log probability of a sentence
        decoded_strs = []
        for i in range(len(enc_out)):
            id_list = inds[i].cpu().numpy()
            S = outputids2words(id_list, self.vocab, article_oovs[i])                               #Generate sentence corresponding to sampled words
            try:
                S = S[:S.index(STOP_DECODING)]                                                      #Cut the sentence at the first [STOP]
            except ValueError:
                pass                                                                                #No [STOP] generated; keep the full sentence
            if len(S) < 2:                                                                          #Sentences shorter than 2 words are replaced with "xxx" so scoring does not fail on them
                S = ["xxx"]
            decoded_strs.append(" ".join(S))

        return decoded_strs, log_probs

    def reward_function(self, decoded_sents, original_sents):
        """Score each decoded sentence against its reference with BERTScore.

        :param decoded_sents: list of generated sentences
        :param original_sents: list of reference sentences, aligned with ``decoded_sents``
        :return: third value returned by ``bert_score.score`` (one score per
            sentence pair), moved to the GPU when available
        """
        _, _, scores = bert_score.score(decoded_sents, original_sents, lang='en', verbose=False, batch_size=50)
        gc.collect()
        bertscore_f = get_cuda(scores)
        print('pass reward function')
        return bertscore_f

    def train_one_batch(self, batch, iter):
        """Run one optimisation step on ``batch``.

        The loss is ``opt.mle_weight * mle_loss + opt.rl_weight * rl_loss``; each
        term is a zero tensor when the corresponding objective is disabled.

        :param batch: batch object produced by the batcher
        :param iter: current iteration number (unused here)
        :return: tuple ``(mle_loss, batch_reward)`` as Python numbers;
            ``batch_reward`` is the mean reward of the sampled sentences, or 0
            when RL training is disabled
        """
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(batch)

        enc_batch = self.model.embeds(enc_batch)                                                    #Get embeddings for encoder input
        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        # -------------------------------Summarization-----------------------
        if self.opt.train_mle == "yes":                                                             #perform MLE training
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch)
        else:
            mle_loss = get_cuda(T.FloatTensor([0]))
        # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == "yes":                                                              #perform reinforcement learning training
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs, greedy=False)
            with T.autograd.no_grad():
                # greedy sampling provides the baseline and needs no gradients
                greedy_sents, _ = self.train_batch_RL(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs, greedy=True)

            sample_reward = self.reward_function(sample_sents, batch.original_abstracts)
            baseline_reward = self.reward_function(greedy_sents, batch.original_abstracts)
            rl_loss = -(sample_reward - baseline_reward) * RL_log_probs                             #Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = T.mean(rl_loss)

            batch_reward = T.mean(sample_reward).item()
        else:
            rl_loss = get_cuda(T.FloatTensor([0]))
            batch_reward = 0

        # ------------------------------------------------------------------------------------
        self.trainer.zero_grad()
        (self.opt.mle_weight * mle_loss + self.opt.rl_weight * rl_loss).backward()
        self.trainer.step()

        return mle_loss.item(), batch_reward

    def trainIters(self):
        """Run the training loop until ``max_iterations`` is exceeded.

        The loss and reward of every batch are printed, and a checkpoint is
        saved every 100 iterations. A ``KeyboardInterrupt`` raised while a batch
        is being trained terminates the process.

        :return: list with the reward of every trained batch
        """
        step = self.setup_train()
        reward_list = []
        while step <= max_iterations:
            batch = self.batcher.next_batch()
            try:
                mle_loss, r = self.train_one_batch(batch, step)
            except KeyboardInterrupt:
                print("-------------------Keyboard Interrupt------------------")
                sys.exit()

            step += 1

            # Statistics are reported per batch (the averaging window is one iteration).
            reward = float(r)
            print("iter:", step, "mle_loss:", "%.3f" % mle_loss, "reward:", "%.4f" % reward)
            reward_list.append(reward)

            if step % 100 == 0:
                self.save_model(step)
        return reward_list


#10.Beam search

In [0]:
import numpy as np
import torch as T


class Beam(object):
    """State of the beam search for a single example.

    Holds, for each of the ``beam_size`` hypotheses, the generated tokens, the
    accumulated log-probability score and the decoder state needed to extend it.
    Hypotheses are kept sorted by score in descending order.
    """

    def __init__(self, start_id, end_id, unk_id, hidden_state, context):
        """Initialise all hypotheses from the encoder state of one example.

        :param start_id: id of the start-of-decoding token
        :param end_id: id of the end-of-decoding token
        :param unk_id: id of the unknown-word token
        :param hidden_state: tuple ``(h, c)`` of (n_hid,) tensors
        :param context: encoder context vector of the example
        """
        h, c = hidden_state                                             #(n_hid,)
        self.tokens = T.LongTensor(beam_size, 1).fill_(start_id)        #(beam, t) after t time steps
        self.scores = T.FloatTensor(beam_size, 1).fill_(-30)            #beam,1; Initial score of beams = -30
        self.tokens, self.scores = get_cuda(self.tokens), get_cuda(self.scores)
        self.scores[0][0] = 0                                           #At time step t=0, all beams should extend from a single beam, so only the 1st beam gets a high initial score
        self.hid_h = h.unsqueeze(0).repeat(beam_size, 1)                #beam, n_hid
        self.hid_c = c.unsqueeze(0).repeat(beam_size, 1)                #beam, n_hid
        self.context = context.unsqueeze(0).repeat(beam_size, 1)        #beam, 2*n_hid
        self.sum_temporal_srcs = None
        self.prev_s = None
        self.done = False
        self.end_id = end_id
        self.unk_id = unk_id

    def get_current_state(self):
        """Return the last token of every hypothesis, with extended-vocabulary ids mapped to [UNK]."""
        tokens = self.tokens[:, -1].clone()
        tokens[tokens >= vocab_size] = self.unk_id
        return tokens

    def advance(self, prob_dist, hidden_state, context, sum_temporal_srcs, prev_s):
        '''Perform beam search: Considering the probabilites of given n_beam x n_extended_vocab words, select first n_beam words that give high total scores
        :param prob_dist: (beam, n_extended_vocab)
        :param hidden_state: Tuple of (beam, n_hid) tensors
        :param context:   (beam, 2*n_hidden)
        :param sum_temporal_srcs:   (beam, n_seq)
        :param prev_s:  (beam, t, n_hid)
        '''
        n_extended_vocab = prob_dist.size(1)
        h, c = hidden_state
        log_probs = T.log(prob_dist + eps)                              #beam, n_extended_vocab

        scores = log_probs + self.scores                                #beam, n_extended_vocab
        scores = scores.view(-1, 1)                                     #beam*n_extended_vocab, 1
        best_scores, best_scores_id = T.topk(input=scores, k=beam_size, dim=0)   #will be sorted in descending order of scores
        self.scores = best_scores                                       #(beam,1); sorted
        # Floor division keeps the parent-beam indices integral; "/" is true
        # division on current PyTorch and yields a float tensor that cannot be
        # used for indexing.
        beams_order = best_scores_id.squeeze(1) // n_extended_vocab     #(beam,); sorted
        best_words = best_scores_id % n_extended_vocab                  #(beam,1); sorted
        self.hid_h = h[beams_order]                                     #(beam, n_hid); sorted
        self.hid_c = c[beams_order]                                     #(beam, n_hid); sorted
        self.context = context[beams_order]
        if sum_temporal_srcs is not None:
            self.sum_temporal_srcs = sum_temporal_srcs[beams_order]     #(beam, n_seq); sorted
        if prev_s is not None:
            self.prev_s = prev_s[beams_order]                           #(beam, t, n_hid); sorted
        self.tokens = self.tokens[beams_order]                          #(beam, t); sorted
        self.tokens = T.cat([self.tokens, best_words], dim=1)           #(beam, t+1); sorted

        #End condition is when top-of-beam is EOS.
        if best_words[0][0] == self.end_id:
            self.done = True

    def get_best(self):
        """Return the token ids of the best hypothesis, without the start token and cut before the first end token."""
        best_token = self.tokens[0].cpu().numpy().tolist()              #Since beams are always in sorted (descending) order, 1st beam is the best beam
        try:
            end_idx = best_token.index(self.end_id)
        except ValueError:
            end_idx = len(best_token)
        return best_token[1:end_idx]

    def get_all(self):
        """Return the token ids of every hypothesis as a list of numpy arrays."""
        return [hypothesis.cpu().numpy() for hypothesis in self.tokens]


def beam_search(enc_hid, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, model, start_id, end_id, unk_id):
    """Decode every example of a batch with beam search.

    One ``Beam`` is created per example. At every step the beams that are not
    done yet are stacked into a single decoder call, each of them is advanced
    with its slice of the result, and finished beams are dropped from the
    active set. Decoding stops after ``max_dec_steps`` steps or as soon as no
    beam is active.

    :return: list with one entry per example, as produced by ``Beam.get_best``
    """
    n_examples = len(enc_hid[0])
    beam_idx = T.LongTensor(list(range(n_examples)))                    #Index of beams that are active, i.e: didn't generate [STOP] yet
    beams = []
    for i in range(n_examples):                                         #For each example in batch, create Beam object
        beams.append(Beam(start_id, end_id, unk_id, (enc_hid[0][i], enc_hid[1][i]), ct_e[i]))
    n_rem = n_examples                                                  #Number of examples in batch that didn't generate [STOP] yet
    sum_temporal_srcs = None
    prev_s = None

    for t in range(max_dec_steps):
        live = [beam for beam in beams if not beam.done]                #beams still being extended

        last_tokens = T.stack([beam.get_current_state() for beam in live])      #remaining(rem),beam
        x_t = model.embeds(last_tokens.contiguous().view(-1))                   #rem*beam, n_emb

        dec_h = T.stack([beam.hid_h for beam in live]).contiguous().view(-1, hidden_dim)        #rem*beam,n_hid
        dec_c = T.stack([beam.hid_c for beam in live]).contiguous().view(-1, hidden_dim)        #rem*beam,n_hid
        ct_e = T.stack([beam.context for beam in live]).contiguous().view(-1, 2*hidden_dim)     #rem*beam,2*n_hid

        if sum_temporal_srcs is not None:
            stacked_srcs = T.stack([beam.sum_temporal_srcs for beam in live])
            sum_temporal_srcs = stacked_srcs.contiguous().view(-1, enc_out.size(1))             #rem*beam, n_seq

        if prev_s is not None:
            stacked_prev = T.stack([beam.prev_s for beam in live])
            prev_s = stacked_prev.contiguous().view(-1, t, hidden_dim)                          #rem*beam, t, n_hid

        s_t = (dec_h, dec_c)

        # Repeat the encoder-side tensors of each active example once per beam.
        enc_out_beam = enc_out[beam_idx].view(n_rem,-1).repeat(1, beam_size).view(-1, enc_out.size(1), enc_out.size(2))
        enc_pad_mask_beam = enc_padding_mask[beam_idx].repeat(1, beam_size).view(-1, enc_padding_mask.size(1))
        if extra_zeros is None:
            extra_zeros_beam = None
        else:
            extra_zeros_beam = extra_zeros[beam_idx].repeat(1, beam_size).view(-1, extra_zeros.size(1))
        enc_extend_vocab_beam = enc_batch_extend_vocab[beam_idx].repeat(1, beam_size).view(-1, enc_batch_extend_vocab.size(1))

        final_dist, (dec_h, dec_c), ct_e, sum_temporal_srcs, prev_s = model.decoder(x_t, s_t, enc_out_beam, enc_pad_mask_beam, ct_e, extra_zeros_beam, enc_extend_vocab_beam, sum_temporal_srcs, prev_s)              #final_dist: rem*beam, n_extended_vocab

        # Split the flat decoder outputs back into one slice per active example.
        final_dist = final_dist.view(n_rem, beam_size, -1)              #rem, beam, n_extended_vocab
        dec_h = dec_h.view(n_rem, beam_size, -1)                        #rem, beam, n_hid
        dec_c = dec_c.view(n_rem, beam_size, -1)                        #rem, beam, n_hid
        ct_e = ct_e.view(n_rem, beam_size, -1)                          #rem, beam, 2*n_hid
        if sum_temporal_srcs is not None:
            sum_temporal_srcs = sum_temporal_srcs.view(n_rem, beam_size, -1)        #rem, beam, n_seq
        if prev_s is not None:
            prev_s = prev_s.view(n_rem, beam_size, -1, hidden_dim)                  #rem, beam, t, n_hid

        # Advance every active beam and remember those that are still running.
        still_active = []
        for i, b in enumerate(beam_idx.tolist()):
            beam = beams[b]
            if beam.done:
                continue

            srcs_i = None if sum_temporal_srcs is None else sum_temporal_srcs[i]    #beam, n_seq
            prev_i = None if prev_s is None else prev_s[i]                          #beam, t, n_hid
            beam.advance(final_dist[i], (dec_h[i], dec_c[i]), ct_e[i], srcs_i, prev_i)
            if not beam.done:
                still_active.append(b)

        if not still_active:
            break

        beam_idx = T.LongTensor(still_active)
        n_rem = len(beam_idx)

    return [beam.get_best() for beam in beams]

#11. Evaluation and testing class using ROUGE

In [0]:
'''
import os

import time

import torch as T
import torch.nn as nn
import torch.nn.functional as F


from rouge import Rouge
import argparse
'''

def get_cuda(tensor):
    """Return ``tensor`` moved to the GPU when CUDA is available, otherwise unchanged."""
    return tensor.cuda() if T.cuda.is_available() else tensor

class Evaluate(object):
    """Decode a dataset with beam search and report ROUGE scores."""

    def __init__(self, data_path, opt, batch_size = batch_size):
        """Create the vocabulary and a single-pass evaluation batcher over ``data_path``."""
        self.vocab = Vocab(vocab_path, vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='eval',
                               batch_size=batch_size, single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        """Instantiate the model and load the weights of ``opt.load_model``."""
        self.model = get_cuda(Model())
        state = T.load(os.path.join(save_model_path, self.opt.load_model))
        self.model.load_state_dict(state["model_dict"])

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents, loadfile):
        """Write article / reference / decoded triples to ``test_<checkpoint>.txt`` in ``save_example_path``."""
        stem = loadfile.split(".")[0]
        out_path = os.path.join(save_example_path, "test_" + stem + ".txt")

        with open(out_path, "w") as f:
            for idx, decoded in enumerate(decoded_sents):
                f.write("article: "+article_sents[idx] + "\n")
                f.write("ref: " + ref_sents[idx] + "\n")
                f.write("dec: " + decoded + "\n\n")

    def evaluate_batch(self, print_sents = False):
        """Decode every batch of the dataset and print the averaged ROUGE scores.

        :param print_sents: when true, also dump the decoded examples to a file
        """
        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(START_DECODING)
        end_id = self.vocab.word2id(STOP_DECODING)
        unk_id = self.vocab.word2id(UNKNOWN_TOKEN)
        decoded_sents, ref_sents, article_sents = [], [], []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(batch)

            # Encoding and beam-search decoding both run without gradient tracking.
            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, self.model, start_id, end_id, unk_id)

            for idx, ids in enumerate(pred_ids):
                words = outputids2words(ids, self.vocab, batch.art_oovs[idx])
                # Outputs shorter than two words are replaced by a placeholder.
                decoded_sents.append("xxx" if len(words) < 2 else " ".join(words))
                abstract = batch.original_abstracts[idx]
                article = batch.original_articles[idx]
                ref_sents.append(abstract)
                article_sents.append(article)

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg = True)
        if self.opt.task == "test":
            print(load_file, "scores:", scores)
            return

        print(load_file, "rouge_l:", "%.4f" % scores["rouge-l"]["f"])
        print('Scores:', scores)





#12. Evaluation and testing using BERTScore

In [0]:
import os

import time

import torch as T
import torch.nn as nn
import torch.nn.functional as F

import gc
import bert_score
import argparse

def get_cuda(tensor):
    """Move ``tensor`` to the GPU if one is available; return it untouched otherwise."""
    if not T.cuda.is_available():
        return tensor
    return tensor.cuda()

class Evaluate(object):
    """Decode a dataset with beam search and report the mean BERTScore."""

    def __init__(self, data_path, opt, batch_size = batch_size):
        """Create the vocabulary and a single-pass evaluation batcher over ``data_path``."""
        self.vocab = Vocab(vocab_path, vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='eval',
                               batch_size=batch_size, single_pass=True)
        self.opt = opt

        time.sleep(5)

    def setup_valid(self):
        """Instantiate the model and load the weights of ``opt.load_model``."""
        self.model = get_cuda(Model())
        state = T.load(os.path.join(save_model_path, self.opt.load_model))
        self.model.load_state_dict(state["model_dict"])

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents, loadfile):
        """Write article / reference / decoded triples to ``test_<checkpoint>.txt`` in ``save_example_path``."""
        stem = loadfile.split(".")[0]
        out_path = os.path.join(save_example_path, "test_" + stem + ".txt")

        with open(out_path, "w") as f:
            for idx, decoded in enumerate(decoded_sents):
                f.write("article: "+article_sents[idx] + "\n")
                f.write("ref: " + ref_sents[idx] + "\n")
                f.write("dec: " + decoded + "\n\n")

    def evaluate_batch(self, print_sents = False):
        """Decode every batch of the dataset and print the mean BERTScore.

        :param print_sents: when true, also dump the decoded examples to a file
        """
        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(START_DECODING)
        end_id = self.vocab.word2id(STOP_DECODING)
        unk_id = self.vocab.word2id(UNKNOWN_TOKEN)
        decoded_sents, ref_sents, article_sents = [], [], []
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(batch)

            # Encoding and beam-search decoding both run without gradient tracking.
            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, self.model, start_id, end_id, unk_id)

            for idx, ids in enumerate(pred_ids):
                words = outputids2words(ids, self.vocab, batch.art_oovs[idx])
                # Outputs shorter than two words are replaced by a placeholder.
                decoded_sents.append("xxx" if len(words) < 2 else " ".join(words))
                abstract = batch.original_abstracts[idx]
                article = batch.original_articles[idx]
                ref_sents.append(abstract)
                article_sents.append(article)

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        # Only the third value returned by bert_score.score is reported.
        _, _, f = bert_score.score(decoded_sents, ref_sents, lang='en', verbose = False)

        print('{}  score:{}'.format(load_file, f.mean()))
        gc.collect()

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)


#13. Model with local attention

In [0]:
import torch as T
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torch.nn.functional as F


def init_lstm_wt(lstm):
    """Re-initialise every parameter of an LSTM / LSTMCell module in place.

    Parameters whose name contains 'weight' are resampled uniformly from
    [-rand_unif_init_mag, rand_unif_init_mag]. Parameters whose name contains
    'bias' are zeroed, then the slice [n//4, n//2) is set to 1.

    :param lstm: module whose parameters are reachable via ``getattr(lstm, name)``
    """
    for param_name, _ in lstm.named_parameters():
        if 'weight' in param_name:
            weight = getattr(lstm, param_name)
            weight.data = weight.data.uniform_(-rand_unif_init_mag, rand_unif_init_mag)
            continue

        if 'bias' not in param_name:
            continue

        bias = getattr(lstm, param_name)
        quarter = bias.size(0) // 4
        half = bias.size(0) // 2
        bias.data = bias.data.zero_()
        # set forget bias to 1 (second quarter of the bias vector;
        # presumably the forget gate in PyTorch's gate layout -- confirm)
        bias.data[quarter:half].fill_(1.)

def init_linear_wt(linear):
    """Resample a linear layer's weight (and bias, if present) from N(0, trunc_norm_init_std^2).

    :param linear: module exposing ``weight`` and an optional ``bias``
    """
    targets = [linear.weight]
    if linear.bias is not None:
        targets.append(linear.bias)
    # weight first, then bias, so random draws happen in the same order as before
    for param in targets:
        param.data = param.data.normal_(std=trunc_norm_init_std)

def init_wt_normal(wt):
    """Overwrite ``wt`` in place with samples from N(0, trunc_norm_init_std^2)."""
    resampled = wt.data.normal_(std=trunc_norm_init_std)
    wt.data = resampled


class Encoder(nn.Module):
    """Single-layer bidirectional LSTM encoder.

    The final hidden and cell states of the two directions are concatenated
    and projected from 2*hidden_dim down to hidden_dim (followed by ReLU).
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # batch_first=True: inputs and outputs are laid out as (batch, seq, feature)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        init_lstm_wt(self.lstm)

        # reduction of the concatenated hidden state
        self.reduce_h = nn.Linear(hidden_dim * 2, hidden_dim)
        init_linear_wt(self.reduce_h)
        # reduction of the concatenated cell state
        self.reduce_c = nn.Linear(hidden_dim * 2, hidden_dim)
        init_linear_wt(self.reduce_c)

    def forward(self, x, seq_lens):
        """Encode a padded batch.

        :param x: embedded input batch, batch-first
        :param seq_lens: per-example lengths handed to ``pack_padded_sequence``
        :return: ``(enc_out, (h_reduced, c_reduced))`` where ``enc_out`` is the
                 padded, contiguous LSTM output and the reduced states come from
                 ``reduce_h`` / ``reduce_c`` followed by ReLU
        """
        packed_in = pack_padded_sequence(x, seq_lens, batch_first=True)
        packed_out, (h_n, c_n) = self.lstm(packed_in)
        enc_out = pad_packed_sequence(packed_out, batch_first=True)[0].contiguous()   # bs, n_seq, 2*n_hid

        # h_n / c_n are indexed by direction on dim 0; join the directions feature-wise
        h_cat = T.cat(tuple(h_n), dim=1)        # bs, 2*n_hid
        c_cat = T.cat(tuple(c_n), dim=1)

        reduced = (F.relu(self.reduce_h(h_cat)), F.relu(self.reduce_c(c_cat)))   # each bs, n_hid
        return enc_out, reduced

'''
class encoder_attention(nn.Module):

    def __init__(self):
        super(encoder_attention, self).__init__()
        self.W_h = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)    #no bias just the weight matrix
        self.W_s = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.v = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.W_p = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
        self.v_p = nn.Linear(hidden_dim * 2, 1, bias = False)

    def forward(self, st_hat, h, enc_padding_mask, sum_temporal_srcs, local_attention_d=None):              #INTRA-TEMPORAL ATTENTION ON INPUT SEQUENCE in paper?

        Perform attention over encoder hidden states
        :param st_hat: decoder hidden state at current time step        (batch_size, 2*hidden_size)
        :param h: encoder hidden states     (batch_size, length_input_sequence, 2*hidden_size), length_input is fixed = max_enc_length
        :param enc_padding_mask:
        :param sum_temporal_srcs: if using intra-temporal attention, contains summation of attention weights from previous decoder time steps
        :param local_attention_d: if not None, implement local attention with d instead (eq 9&10 in https://arxiv.org/pdf/1508.04025.pdf)


        if local_attention_d is not None:       #perform local  attention
            d = local_attention_d
            temp_1 = self.W_p(st_hat)           #bs, 2*hid
            temp_2 = T.tanh(temp_1)
            temp_3 = self.v_p(temp_2)           #bs,1
            temp_4 = T.sigmoid(temp_3).squeeze(1)

            with T.no_grad():
                S = get_cuda(T.zeros(st_hat.shape[0]))        #bs
                for i in range(enc_padding_mask.shape[0]):
                    temp_len = len(enc_padding_mask[i].nonzero())
                    S[i] = temp_len

            pt = S * temp_4     #bs
            #print('pt before:', pt)
            #print('pt',pt)
            #construct padding matrix
            with T.no_grad():
                pt_padding = get_cuda(T.zeros(h.shape[0], h.shape[1]))
                for j in range(pt.shape[0]):
                    if int(pt[j])-d < 0:
                        pt_padding[j, : int(pt[j])+d].fill_(1.)
                    elif int(pt[j])+d >= h.shape[1]:
                        pt_padding[j, int(pt[j])-d:].fill_(1.)
                    else:
                        pt_padding[j, int(pt[j])-d: int(pt[j])+d].fill_(1.)
            
            #print('pt after:', pt_padding)
        #print('size of decoder hiddent state:', st_hat.size())
        # Standard global attention technique (eq 1 in https://arxiv.org/pdf/1704.04368.pdf)
        et = self.W_h(h)                        # bs,n_seq,2*n_hid
        dec_fea = self.W_s(st_hat).unsqueeze(1) # bs,1,2*n_hid
        et = et + dec_fea
        et = T.tanh(et)                         # bs,n_seq,2*n_hid
        et = self.v(et).squeeze(2)              # bs,n_seq


        # intra-temporal attention     (eq 3 in https://arxiv.org/pdf/1705.04304.pdf)
        if intra_encoder:
            exp_et = T.exp(et)
            if local_attention_d is not None:
                exp_et = exp_et * pt_padding        #zero-out the entry outside the window
                exp_et = exp_et/ T.sum(exp_et, axis = 1).unsqueeze(1)   #normalised within the window

            if sum_temporal_srcs is None:
                et1 = exp_et
                sum_temporal_srcs  = get_cuda(T.FloatTensor(et.size()).fill_(1e-10)) + exp_et           #defined temporal score
            else:
                et1 = exp_et/sum_temporal_srcs  #bs, n_seq
                sum_temporal_srcs = sum_temporal_srcs + exp_et          #the sum_temporal_src is also returned so that it can be reuse for following encoder attention
        else:
            et1 = F.softmax(et, dim=1)
            if local_attention_d is not None:
                et1 = et1 * pt_padding
                et1 = et1/T.sum(et1, axis = 1).unsqueeze(1)


        # assign 0 probability for padded elements
        at = et1 * enc_padding_mask
        normalization_factor = at.sum(1, keepdim=True)
        at = at / normalization_factor

        #gaussian normalisation




        at = at.unsqueeze(1)                    #bs,1,n_seq
        # Compute encoder context vector
        ct_e = T.bmm(at, h)                     #bs, 1, 2*n_hid, batch matrix-matrix product of matrices (temporal score and encoder hidden state)
        ct_e = ct_e.squeeze(1)
        at = at.squeeze(1)
        #print('at',at)
        return ct_e, at, sum_temporal_srcs
'''

class encoder_attention(nn.Module):
    """Attention over encoder hidden states, optionally restricted to a local window.

    Global scores follow e = v^T tanh(W_h h + W_s s + b). When
    ``local_attention_d`` is given, a window centre ``pt`` is predicted per
    example, scores outside ``[round(pt - d), round(pt - d) + 2d + 1)`` are
    zeroed, and the normalised attention is multiplied by a Gaussian centred
    on ``pt`` with sigma = d/2 (eq 9 & 10 in https://arxiv.org/pdf/1508.04025.pdf).
    """

    def __init__(self):
        super(encoder_attention, self).__init__()
        self.W_h = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)    # no bias, just the weight matrix
        self.W_s = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.v = nn.Linear(hidden_dim * 2, 1, bias=False)

        # parameters of the predictive alignment (window centre) used by local attention
        self.W_p = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)
        self.v_p = nn.Linear(hidden_dim * 2, 1, bias=False)

    def _local_window(self, st_hat, h, enc_padding_mask, d):
        """Build the 0/1 window mask and the Gaussian weights for local attention.

        :return: ``(window_mask, gaussian)``, both of shape (bs, n_seq); the
                 Gaussian is already zeroed outside the window
        """
        n_seq = h.shape[1]
        std_squared = (d / 2) ** 2                                          # sigma^2
        window_length = 2 * d + 1

        # predicted relative position in (0, 1) for each example
        pos_prob = T.sigmoid(self.v_p(T.tanh(self.W_p(st_hat)))).squeeze(1)   # bs

        # source length = number of non-zero entries of the padding mask
        src_len = (enc_padding_mask != 0).sum(dim=1).to(pos_prob.dtype)     # bs
        pt = src_len * pos_prob                                             # bs, window centre

        window_start = T.round(pt - d)                                      # bs
        window_end = window_start + window_length

        # Comparing against every position clamps the window to [0, n_seq)
        # implicitly, so a window overhanging both ends of a short sequence
        # no longer triggers a shape mismatch.
        idx = T.arange(n_seq, device=h.device, dtype=pt.dtype).unsqueeze(0)   # 1, n_seq
        inside = (idx >= window_start.unsqueeze(1)) & (idx < window_end.unsqueeze(1))
        window_mask = inside.to(pt.dtype)                                   # bs, n_seq

        gaussian = T.exp(-(idx - pt.unsqueeze(1)) ** 2 / (2 * std_squared))
        gaussian = gaussian * window_mask       # zero-out the entries outside the window
        return window_mask, gaussian

    def forward(self, st_hat, h, enc_padding_mask, sum_temporal_srcs, local_attention_d=None):
        '''
        Perform attention over encoder hidden states
        :param st_hat: decoder hidden state at current time step        (batch_size, 2*hidden_size)
        :param h: encoder hidden states     (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_padding_mask: (batch_size, length_input_sequence) mask, non-zero on real tokens
        :param sum_temporal_srcs: if using intra-temporal attention, contains summation of attention weights from previous decoder time steps
        :param local_attention_d: if not None, half-width d of the local attention window; None means plain global attention
        :return: (ct_e, at, sum_temporal_srcs): encoder context vector (batch_size, 2*hidden_size),
                 attention weights (batch_size, length_input_sequence), updated temporal sum
        '''
        use_local = local_attention_d is not None
        if use_local:
            window_mask, gaussian = self._local_window(st_hat, h, enc_padding_mask, local_attention_d)

        # Standard global attention technique (eq 1 in https://arxiv.org/pdf/1704.04368.pdf)
        et = self.W_h(h)                        # bs,n_seq,2*n_hid
        dec_fea = self.W_s(st_hat).unsqueeze(1) # bs,1,2*n_hid
        et = T.tanh(et + dec_fea)               # bs,n_seq,2*n_hid
        et = self.v(et).squeeze(2)              # bs,n_seq

        # intra-temporal attention     (eq 3 in https://arxiv.org/pdf/1705.04304.pdf)
        if intra_encoder:
            exp_et = T.exp(et)
            if use_local:
                exp_et = exp_et * window_mask   # zero-out the entries outside the window
            if sum_temporal_srcs is None:
                et1 = exp_et
                sum_temporal_srcs = get_cuda(T.FloatTensor(et.size()).fill_(1e-10)) + exp_et
            else:
                et1 = exp_et / sum_temporal_srcs    # bs, n_seq
                # returned so that it can be reused by the following decoding steps
                sum_temporal_srcs = sum_temporal_srcs + exp_et
        else:
            et1 = F.softmax(et, dim=1)
            if use_local:
                et1 = et1 * window_mask

        # assign 0 probability for padded elements, then renormalise
        at = et1 * enc_padding_mask
        at = at / at.sum(1, keepdim=True)

        # The Gaussian only exists for local attention; applying it
        # unconditionally raised NameError in the global-attention case.
        if use_local:
            at = at * gaussian

        at = at.unsqueeze(1)                    # bs,1,n_seq
        # Compute encoder context vector
        ct_e = T.bmm(at, h).squeeze(1)          # bs, 2*n_hid
        at = at.squeeze(1)
        return ct_e, at, sum_temporal_srcs






class decoder_attention(nn.Module):
    """Intra-decoder attention over the decoder's own previous hidden states."""

    def __init__(self):
        super(decoder_attention, self).__init__()
        # the projection layers are only created when intra-decoder attention is enabled
        if intra_decoder:
            self.W_prev = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.W_s = nn.Linear(hidden_dim, hidden_dim)
            self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, s_t, prev_s):
        '''Perform intra_decoder attention

        :param s_t: hidden state of decoder at current time step
        :param prev_s: previous decoder hidden states stacked along dim 1, or None at the first step
        :return: (ct_d, prev_s): decoder context vector and the (possibly extended) state history
        '''
        # attention switched off: zero context, history passed through untouched
        if intra_decoder is False:
            return get_cuda(T.zeros(s_t.size())), prev_s

        # first decoding step: nothing to attend to yet, start the history with s_t
        if prev_s is None:
            return get_cuda(T.zeros(s_t.size())), s_t.unsqueeze(1)     # history: bs, 1, n_hid

        # e = v tanh(W_prev * prev_s + W_s * s_t + b)   (eq 1 in https://arxiv.org/pdf/1704.04368.pdf)
        query = self.W_s(s_t).unsqueeze(1)                      # bs, 1, n_hid
        energy = T.tanh(self.W_prev(prev_s) + query)            # bs, t-1, n_hid
        scores = self.v(energy).squeeze(2)                      # bs, t-1

        # intra-decoder attention     (eq 7 & 8 in https://arxiv.org/pdf/1705.04304.pdf)
        alpha = F.softmax(scores, dim=1).unsqueeze(1)           # bs, 1, t-1
        ct_d = T.bmm(alpha, prev_s).squeeze(1)                  # bs, n_hid
        history = T.cat([prev_s, s_t.unsqueeze(1)], dim=1)      # bs, t, n_hid
        return ct_d, history


class Decoder(nn.Module):
    """One-step pointer-generator decoder with encoder and intra-decoder attention."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.enc_attention = encoder_attention()
        self.dec_attention = decoder_attention()
        # merges the input embedding with the previous encoder context vector
        self.x_context = nn.Linear(hidden_dim * 2 + emb_dim, emb_dim)

        self.lstm = nn.LSTMCell(emb_dim, hidden_dim)
        init_lstm_wt(self.lstm)

        # input: ct_e (2h) + ct_d (h) + st_hat (2h) + x (emb)
        self.p_gen_linear = nn.Linear(hidden_dim * 5 + emb_dim, 1)

        # p_vocab: two-layer projection to the vocabulary
        self.V = nn.Linear(hidden_dim * 4, hidden_dim)
        self.V1 = nn.Linear(hidden_dim, vocab_size)
        init_linear_wt(self.V1)

    def forward(self, x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s):
        """Run a single decoding step.

        :return: (final_dist, s_t, ct_e, sum_temporal_srcs, prev_s): the mixed
                 vocabulary/copy distribution, the new LSTM state, the new encoder
                 context vector and the updated attention bookkeeping
        """
        lstm_in = self.x_context(T.cat([x_t, ct_e], dim=1))
        s_t = self.lstm(lstm_in, s_t)
        dec_h, dec_c = s_t
        st_hat = T.cat([dec_h, dec_c], dim=1)

        # local_attention_d is read from module scope
        ct_e, attn_dist, sum_temporal_srcs = self.enc_attention(
            st_hat, enc_out, enc_padding_mask, sum_temporal_srcs, local_attention_d)
        ct_d, prev_s = self.dec_attention(dec_h, prev_s)        # intra-decoder attention

        # generation probability
        gen_features = T.cat([ct_e, ct_d, st_hat, lstm_in], 1)
        p_gen = T.sigmoid(self.p_gen_linear(gen_features))      # bs,1

        # vocabulary distribution
        vocab_features = T.cat([dec_h, ct_e, ct_d], dim=1)      # bs, 4*n_hid
        logits = self.V1(self.V(vocab_features))                # bs, n_vocab
        vocab_dist = p_gen * F.softmax(logits, dim=1)
        copy_dist = (1 - p_gen) * attn_dist

        # pointer mechanism (as suggested in eq 9 https://arxiv.org/pdf/1704.04368.pdf)
        if extra_zeros is not None:
            vocab_dist = T.cat([vocab_dist, extra_zeros], dim=1)
        final_dist = vocab_dist.scatter_add(1, enc_batch_extend_vocab, copy_dist)

        return final_dist, s_t, ct_e, sum_temporal_srcs, prev_s



class Model(nn.Module):
    """Container holding the encoder, the decoder and the shared embedding table."""

    def __init__(self):
        super(Model, self).__init__()
        # build in the original order so random initialisation draws are unchanged
        encoder = Encoder()
        decoder = Decoder()
        embeds = nn.Embedding(vocab_size, emb_dim)
        init_wt_normal(embeds.weight)

        # hand each sub-module to get_cuda before registering it
        self.encoder = get_cuda(encoder)
        self.decoder = get_cuda(decoder)
        self.embeds = get_cuda(embeds)

#14. Experiment

Procedure for each dataset:


1.   ML pre-training
2.   Validating the pre-trained models and selecting the one with the best score
3.   Starting from the same pre-trained model, training ML, ML+RL, RL(r) and RL(b)
4.   Validating the models
5.   Model testing


## Gigaword

###Hyperparameter

In [0]:
# ---- Gigaword: data locations ----
train_data_path = "giga/chunked/train/train_*"          # training data files (glob pattern)
valid_data_path = "giga/chunked/valid/valid_00.bin"     # validation data file
test_data_path = "giga/chunked/test/test_00.bin"        # testing data file
vocab_path = "giga/vocab"                               # vocab data file

# ---- Gigaword: model / decoding sizes ----
hidden_dim = 512
emb_dim = 256
batch_size = 200
max_enc_steps = 55
max_dec_steps = 15
beam_size = 4
min_dec_steps = 3
vocab_size = 50000

# ---- Gigaword: optimisation and initialisation ----
lr = 0.001
rand_unif_init_mag = 0.02
trunc_norm_init_std = 1e-4
eps = 1e-12
max_iterations = 16000
max_batch_queue = 1000

# ---- Gigaword: output locations ----
# previously used alternatives: 'giga_pre_model' (load), 'giga_ML' / 'new_models' (save)
save_model_path = "RL_BERT"
new_save_model_path = "RL_BERT"
save_example_path = "example"

# ---- Gigaword: attention switches ----
intra_encoder = intra_decoder = True

###Pre-training

In [0]:
if __name__ == "__main__":
    # ML-only pre-training: no checkpoint is loaded and the learning rate is not overridden.
    parser = argparse.ArgumentParser()
    for _flag, _type, _default in (
        ('--train_mle', str, "yes"),
        ('--train_rl', str, "no"),
        ('--mle_weight', float, 1.0),
        ('--load_model', str, None),
        ('--new_lr', float, None),
    ):
        parser.add_argument(_flag, type=_type, default=_default)

    # parse_known_args tolerates unrecognised argv entries (they end up in `unknown`)
    opt, unknown = parser.parse_known_args()
    opt.rl_weight = 1 - opt.mle_weight      # the two loss weights sum to 1

    settings = (opt.train_mle, opt.train_rl, opt.mle_weight, opt.rl_weight)
    print("Training mle: %s, Training rl: %s, mle weight: %.2f, rl weight: %.2f" % settings)
    print("intra_encoder:", intra_encoder, "intra_decoder:", intra_decoder)          # config

    train_processor = Train(opt)
    train_processor.trainIters()




11

###Validation of pre-trained model

In [0]:
if __name__ == "__main__":
    # Evaluate every checkpoint in save_model_path from --start_from onwards
    # (validate), or run a single evaluation on the test data (test).
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="validate", choices=["validate", "test"])
    parser.add_argument("--start_from", type=str, default="0001000.tar")
    parser.add_argument("--load_model", type=str, default=None)
    opt, unknown = parser.parse_known_args()

    if opt.task != "validate":   # test
        eval_processor = Evaluate(test_data_path, opt)
        eval_processor.evaluate_batch()
    else:
        # checkpoints in sorted filename order, skipping those before --start_from
        saved_models = sorted(os.listdir(save_model_path))
        file_idx = saved_models.index(opt.start_from)
        saved_models = saved_models[file_idx:]
        for f in saved_models:
            opt.load_model = f
            eval_processor = Evaluate(valid_data_path, opt)
            eval_processor.evaluate_batch()

###Keep on ML training

In [0]:
if __name__ == "__main__":
    # Continue ML-only training from a checkpoint
    # ('#######.tar' is a placeholder for the chosen checkpoint file).
    parser = argparse.ArgumentParser()
    for _flag, _type, _default in (
        ('--train_mle', str, "yes"),
        ('--train_rl', str, "no"),
        ('--mle_weight', float, 1.),
        ('--load_model', str, '#######.tar'),
        ('--new_lr', float, None),
    ):
        parser.add_argument(_flag, type=_type, default=_default)

    # parse_known_args tolerates unrecognised argv entries (they end up in `unknown`)
    opt, unknown = parser.parse_known_args()
    opt.rl_weight = 1 - opt.mle_weight      # the two loss weights sum to 1

    settings = (opt.train_mle, opt.train_rl, opt.mle_weight, opt.rl_weight)
    print("Training mle: %s, Training rl: %s, mle weight: %.2f, rl weight: %.2f" % settings)
    print("intra_encoder:", intra_encoder, "intra_decoder:", intra_decoder)          # config

    train_processor = Train(opt)
    train_processor.trainIters()




###KEEP on ML+RL training

In [0]:
if __name__ == "__main__":
    # Continue training with an equally weighted ML + RL loss from a checkpoint
    # ('#####.tar' is a placeholder for the chosen checkpoint file).
    parser = argparse.ArgumentParser()
    for _flag, _type, _default in (
        ('--train_mle', str, "yes"),
        ('--train_rl', str, "yes"),
        ('--mle_weight', float, 0.5),
        ('--load_model', str, '#####.tar'),
        ('--new_lr', float, 0.0001),
    ):
        parser.add_argument(_flag, type=_type, default=_default)

    # parse_known_args tolerates unrecognised argv entries (they end up in `unknown`)
    opt, unknown = parser.parse_known_args()
    opt.rl_weight = 1 - opt.mle_weight      # the two loss weights sum to 1

    settings = (opt.train_mle, opt.train_rl, opt.mle_weight, opt.rl_weight)
    print("Training mle: %s, Training rl: %s, mle weight: %.2f, rl weight: %.2f" % settings)
    print("intra_encoder:", intra_encoder, "intra_decoder:", intra_decoder)          # config

    train_processor = Train(opt)
    train_processor.trainIters()




###Keep on RL(r/b) training

In [0]:
if __name__ == "__main__":
    # Continue RL-only training (mle weight 0) from a checkpoint
    # ('#####.tar' is a placeholder for the chosen checkpoint file).
    parser = argparse.ArgumentParser()
    for _flag, _type, _default in (
        ('--train_mle', str, "no"),
        ('--train_rl', str, "yes"),
        ('--mle_weight', float, 0.),
        ('--load_model', str, '#####.tar'),
        ('--new_lr', float, 0.0001),
    ):
        parser.add_argument(_flag, type=_type, default=_default)

    # parse_known_args tolerates unrecognised argv entries (they end up in `unknown`)
    opt, unknown = parser.parse_known_args()
    opt.rl_weight = 1 - opt.mle_weight      # the two loss weights sum to 1

    settings = (opt.train_mle, opt.train_rl, opt.mle_weight, opt.rl_weight)
    print("Training mle: %s, Training rl: %s, mle weight: %.2f, rl weight: %.2f" % settings)
    print("intra_encoder:", intra_encoder, "intra_decoder:", intra_decoder)          # config

    train_processor = Train(opt)
    train_processor.trainIters()




###Validation

In [0]:
if __name__ == "__main__":
    # Evaluate every checkpoint in save_model_path from --start_from onwards
    # (validate), or run a single evaluation on the test data (test).
    # '######.tar' is a placeholder for the first checkpoint to evaluate.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="validate", choices=["validate", "test"])
    parser.add_argument("--start_from", type=str, default="######.tar")
    parser.add_argument("--load_model", type=str, default=None)
    opt, unknown = parser.parse_known_args()

    if opt.task != "validate":   # test
        eval_processor = Evaluate(test_data_path, opt)
        eval_processor.evaluate_batch()
    else:
        # checkpoints in sorted filename order, skipping those before --start_from
        saved_models = sorted(os.listdir(save_model_path))
        file_idx = saved_models.index(opt.start_from)
        saved_models = saved_models[file_idx:]
        for f in saved_models:
            opt.load_model = f
            eval_processor = Evaluate(valid_data_path, opt)
            eval_processor.evaluate_batch()

###Testing and print out examples

In [0]:
if __name__ == "__main__":
    # Run the evaluation in "test" mode by default and print example summaries.
    # '######.tar' is a placeholder for the first checkpoint when validating.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="test", choices=["validate","test"])
    parser.add_argument("--start_from", type=str, default="######.tar")
    parser.add_argument("--load_model", type=str, default=None)
    #opt = parser.parse_args()
    # parse_known_args tolerates unrecognised argv entries (they end up in `unknown`)
    opt, unknown = parser.parse_known_args()

    if opt.task == "validate":
        # evaluate every checkpoint from --start_from onwards, in sorted filename order
        saved_models = os.listdir(save_model_path)
        saved_models.sort()
        file_idx = saved_models.index(opt.start_from)
        saved_models = saved_models[file_idx:]
        for f in saved_models:
            opt.load_model = f
            eval_processor = Evaluate(valid_data_path, opt)
            eval_processor.evaluate_batch()
    else:   #test
        eval_processor = Evaluate(test_data_path, opt)
        # The evaluation body reads a flag named `print_sents`; the previous
        # keyword `print=True` does not match it.
        # NOTE(review): confirm against the evaluate_batch signature.
        eval_processor.evaluate_batch(print_sents=True)

##CNN/DM datasets

###Hyperparameter

In [0]:
# ---- CNN/DM: data locations ----
train_data_path = 	"/content/drive/My Drive/NLPproject/CNNDM/finished_files/finished_files/chunked/train_*"               #this is where you save the training data
valid_data_path = 	"drive/My Drive/NLPProject/valid.bin"                           #this is where you save the validation data
test_data_path = 	"drive/My Drive/NLPProject/chunked/main_test/test_*"            #this is where you save the testing data
vocab_path = 		"drive/My Drive/NLPProject/vocab"                               #this is where you save the vocab file


# Hyperparameters
hidden_dim = 400
emb_dim = 200
batch_size = 50
max_enc_steps = 400
max_dec_steps = 100
beam_size = 5
min_dec_steps= 3
vocab_size = 30000

lr = 0.001
rand_unif_init_mag = 0.02
trunc_norm_init_std = 1e-4

eps = 1e-12
max_iterations = 10000
max_batch_queue = 100


save_model_path = "drive/My Drive/NLPProject/CNNDM"                                 #this is where you save your pre-trained model
# The other hyperparameter cells name this setting `new_save_model_path`; the
# misspelt `new_save_mode_path` left the value from a previously run cell in
# effect. '....' is a placeholder to be replaced with a real directory.
new_save_model_path = '....'                                                        #this is where you save your new training model
new_save_mode_path = new_save_model_path                                            #old (misspelt) name kept as an alias

intra_encoder = True
intra_decoder = True

##WikiHow dataset

###Hyperparameter

In [0]:
# ---- WikiHow: data locations ----
train_data_path = "/datasets/wiki_train/train_*"
valid_data_path = "wiki_valid/valid_0000.bin"
test_data_path = "wiki_test/test_0000.bin"
vocab_path = "wiki_vocab/vocab_wiki"

# ---- WikiHow: model / decoding sizes ----
hidden_dim = 512
emb_dim = 256
batch_size = 10
max_enc_steps = 500
max_dec_steps = 100
beam_size = 4
min_dec_steps = 3
vocab_size = 20000

# ---- WikiHow: optimisation and initialisation ----
lr = 0.001
rand_unif_init_mag = 0.02
trunc_norm_init_std = 1e-4
eps = 1e-12
max_iterations = 16000
max_batch_queue = 30

# ---- WikiHow: output locations ----
# new_save_model_path is intentionally not set in this cell
# (previously: 'wiki RL BERT')
save_model_path = "wiki RL BERT"
save_example_path = "example"

# ---- WikiHow: attention switches ----
intra_encoder = intra_decoder = True

###Local attention model hyperparameter

In [0]:
# ---- Local attention (WikiHow): data locations ----
train_data_path = "/datasets/wiki_train/train_*"
valid_data_path = "wiki_valid/valid_0000.bin"
test_data_path = "wiki_test/test_0004.bin"
vocab_path = "wiki_vocab/vocab_wiki"        # location of the vocab file (previously "drive/My Drive/WikiHow_bin/vocab_wiki")

# ---- Local attention: model / decoding sizes ----
hidden_dim = 512
emb_dim = 256
batch_size = 10
max_enc_steps = 500
max_dec_steps = 150
beam_size = 4
min_dec_steps = 3
vocab_size = 20000

# ---- Local attention: optimisation and initialisation ----
lr = 0.001
rand_unif_init_mag = 0.02
trunc_norm_init_std = 1e-4
eps = 1e-12
max_iterations = 15000

# ---- Local attention: output locations ----
save_model_path = "local pre-train"         # location of the pre-trained model
new_save_model_path = "local RLb"           # location where the new model is saved
save_example_path = "example"

# ---- Local attention: attention switches ----
intra_encoder = False
intra_decoder = True
local_attention_d = 50                      # or 100