<a href="https://colab.research.google.com/github/MorenoSara/Summarizing-Long-Form-Document-with-Rich-Discourse-Information/blob/main/Summarizing_Long_Form_Documents.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install rouge
!pip install dgl
!pip install dgl-cu101

In [None]:
import torch
from keras.preprocessing.text import Tokenizer
from rouge import Rouge
import numpy as np
import sys
import json
from keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm
from torch.utils.data import Dataset
import re
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence, pad_sequence
import torch.optim as optim
from collections import Counter
from itertools import combinations
import dgl
import torch.nn.init as init
import torch.nn.functional as F
import os

## Dataset

In [None]:
# Shared ROUGE scorer, used by Sentence.true_scoring and Section.true_scoring below.
rouge = Rouge()

class Sentence:
  """A single sentence: its raw text, its padded token ids, and its labels/scores."""

  def __init__(self, sentence, tokenized_sentence):
    """
    :param sentence: raw sentence text
    :param tokenized_sentence: token ids of the sentence, where 0 marks padding
    """
    self.sentence, self.tokenized_sentence = sentence, tokenized_sentence
    self.label = 0

  def true_scoring(self, abstract):
    """
    Store in self.y_s the mean of ROUGE-1, ROUGE-2 and ROUGE-L F-scores between this
    sentence and the concatenated text of `abstract` (a Section).
    """
    first = rouge.get_scores(self.sentence, abstract.all_sentences)[0]
    self.y_s = np.mean([first[metric]['f'] for metric in ('rouge-1', 'rouge-2', 'rouge-l')])

  def __len__(self):
    """Number of non-zero (i.e. non-padding) token ids."""
    return np.count_nonzero(self.tokenized_sentence)

  def set_label(self):
    """Mark this sentence as positive (label 1)."""
    self.label = 1


class Section:
  """A titled group of Sentence objects together with their concatenated text."""

  def __init__(self, title, sentences):
    """
    :param title: section title (a Sentence in this notebook, or None for the abstract)
    :param sentences: list of Sentence-like objects exposing `.sentence`
    """
    self.title = title
    self.sentences = sentences
    # Space-joined raw text of all sentences, used as ROUGE hypothesis/reference.
    self.all_sentences = ' '.join(s.sentence for s in sentences)

  def true_scoring(self, abstract):
    """Store in self.y_S the ROUGE-2 recall of this section's text against `abstract`'s text."""
    scores = rouge.get_scores(self.all_sentences, abstract.all_sentences)
    self.y_S = scores[0]['rouge-2']['r']

  def __len__(self):
    """Number of sentences in the section."""
    return len(self.sentences)


class Document:
  """A document: an ordered list of sections plus its abstract."""

  def __init__(self, sections, abstract):
    """
    :param sections: list of Section objects, in document order
    :param abstract: the abstract, as a Section
    """
    self.sections, self.abstract = sections, abstract

  def __len__(self):
    """Number of sections."""
    return len(self.sections)

  def __getitem__(self, index):
    """Section(s) at `index`; slices return a plain list of sections."""
    return self.sections[index]

  def num_sentences(self):
    """Total number of sentences over all sections."""
    total = 0
    for section in self.sections:
      total += len(section)
    return total

In [None]:
class ScientificPapaerDataset(Dataset): # map-style dataset
  def __init__(self, datapath, max_sent_len, tokenizer, fit_tokenizer=True):
    """
    ScientificPapaerDataset reads the data from the file and builds the data structure.

    Each non-blank line of the input file is one JSON document with the format:
        {
            'article_id': str,
            'abstract_text': List[str],
            'article_text': List[str],
            'section_names': List[str],
            'sections': List[List[str]]
        }

    Every sentence is converted with `tokenizer.texts_to_sequences` and post-padded with 0
    (or truncated) to `max_sent_len` token ids. Each kept sentence is scored against the
    abstract with Sentence.true_scoring, and each kept section with Section.true_scoring.

    Filtering rules:
      - sections whose title encodes to no non-zero token id are skipped;
      - sections that are empty, or whose first sentence contains no ASCII letter, are skipped;
      - sentences that encode to no non-zero token id are dropped;
      - sections left with no sentences are skipped (ROUGE cannot score an empty text and
        the model cannot pack a zero-length section);
      - documents left with no sections are skipped.

    NOTE(review): pad_sequences truncates from the front by default (truncating='pre'),
    so sentences longer than max_sent_len lose their first tokens -- confirm this is intended.

    :param datapath: path of the input file (JSON lines, UTF-8)
    :param max_sent_len: length to which each sentence is padded/truncated
    :param tokenizer: Keras-style tokenizer
    :param fit_tokenizer: if True (default), fit `tokenizer` on every text of this file before
                          encoding; pass False to reuse a vocabulary fitted on another split
    """
    super(ScientificPapaerDataset, self).__init__()

    data = []
    with open(datapath, 'r', encoding='utf-8') as f:
      for line in f:
        if line.strip():  # tolerate blank lines (e.g. a trailing empty line)
          data.append(json.loads(line))

    if fit_tokenizer:
      def _all_texts():
        # Yields exactly the texts (and order) the per-document fitting used; a single
        # fit_on_texts call avoids re-sorting the whole vocabulary once per document.
        for doc in tqdm(data, desc='Building tokenizer'):
          yield from doc['section_names']
          for sec in doc['sections']:
            yield from sec
          yield from doc['abstract_text']
      tokenizer.fit_on_texts(_all_texts())

    def _encode(text):
      """Token ids of `text`, post-padded with 0 to max_sent_len."""
      return pad_sequences(tokenizer.texts_to_sequences([text]), max_sent_len, padding='post')[0]

    self.documents = []
    for doc in tqdm(data, desc='Scanning documents and building data structure'):
      # Abstract: the reference text for all ROUGE scores of this document.
      abstract = Section(None, [Sentence(abs_sent, _encode(abs_sent)) for abs_sent in doc['abstract_text']])

      # Sections
      sections = []
      for sec_title, sec in zip(doc['section_names'], doc['sections']):
        title = Sentence(sec_title, _encode(sec_title))
        if len(title) == 0:  # avoid sections without title
          continue
        # discard empty sections and sections that start with an empty sentence
        if not sec or not re.match(r'^(?=.*[a-zA-Z])', sec[0]):
          continue
        sentences = []
        for sent in sec:
          sentence = Sentence(sent, _encode(sent))
          if len(sentence) > 0:  # avoid empty sentences
            sentence.true_scoring(abstract)
            sentences.append(sentence)
        if not sentences:  # every sentence was dropped: nothing to score or to feed the model
          continue
        section = Section(title, sentences)
        section.true_scoring(abstract)
        sections.append(section)
      if sections:  # skip documents without usable sections
        self.documents.append(Document(sections, abstract))

  def __getitem__(self, index):
    """Return the Document at `index`."""
    return self.documents[index]

  def __len__(self):
    """Number of kept documents."""
    return len(self.documents)


## DataLoader

In [None]:
class CRDataLoader(DataLoader):
  def __init__(self, dataset, batch_size, num_workers = 0, shuffle = True, max_sections_per_doc = 8, max_sentences_per_section = 18, max_sent_len = 36):
    """
    CRDataLoader builds the dataloader for the Content Ranking Module.

    :param dataset: dataset previously created (items are Document objects)
    :param batch_size: size of the batch
    :param num_workers: number of processes that generate batches in parallel.
    :param shuffle: True to permute the indices of all samples
    :param max_sections_per_doc: maximum number of sections kept for each doc
    :param max_sentences_per_section: maximum number of sentences kept for each section
    :param max_sent_len: number of token ids per sentence/title; must equal the length the
                         dataset padded sentences to
    """
    super(CRDataLoader,self).__init__(dataset, batch_size, num_workers=num_workers, shuffle=shuffle, collate_fn = self.my_collate_fn)
    self.max_sent_len = max_sent_len
    self.max_sentences_per_section = max_sentences_per_section
    self.max_sections_per_doc = max_sections_per_doc


  def my_collate_fn(self, batch): # no default collate because batch must contain tensors, numpy arrays, numbers, dicts or lists;
    """
    my_collate_fn creates the data structure from each batch used in the training phase of the Content Ranking Module.

    Sections beyond max_sections_per_doc and sentences beyond max_sentences_per_section are
    dropped; every slot not filled by real data stays 0.

    :param batch: the batch received in input that contains instances of the Document class
    :return dataset: a dictionary containing all the information used in the training
    :return sections_importances: a 2D matrix (doc, section) containing the ground truth scores for each section
    :return sentences_importances: a 3D matrix (doc, section, sentence) containing the ground truth scores for each sentence
    """
    n_docs = len(batch)
    max_secs = self.max_sections_per_doc
    max_sents = self.max_sentences_per_section
    max_words = self.max_sent_len

    dataset = {
        'section_titles': torch.zeros(n_docs, max_secs, max_words), # (doc, title, word)
        'section_texts': torch.zeros(n_docs, max_secs, max_sents, max_words), # (doc, section, sentence, word)
        'sections_per_doc': torch.zeros(n_docs), # (doc)
        'sentences_per_section': torch.zeros(n_docs, max_secs), # (doc, section)
        'words_per_title': torch.zeros(n_docs, max_secs), # (doc, section)
        'words_per_sentence': torch.zeros(n_docs, max_secs, max_sents) # (doc, section, sentence)
               }
    sections_importances = torch.zeros(n_docs, max_secs) # (doc, section)
    sentences_importances = torch.zeros(n_docs, max_secs, max_sents) # (doc, section, sentence)

    for doc_id, doc in enumerate(batch):
      kept_sections = doc[:max_secs]
      dataset['sections_per_doc'][doc_id] = len(kept_sections)
      for sec_id, sec in enumerate(kept_sections):
        dataset['sentences_per_section'][doc_id, sec_id] = min(len(sec), max_sents)
        title_len = len(sec.title)
        if title_len == 0:
          continue
        # A title is a tokenized sentence, so its length is capped by max_sent_len
        # (previously it was wrongly capped by max_sentences_per_section).
        dataset['words_per_title'][doc_id, sec_id] = min(title_len, max_words)
        sections_importances[doc_id, sec_id] = float(sec.y_S)
        # Write token ids directly into the preallocated tensors (cast to float).
        dataset['section_titles'][doc_id, sec_id] = torch.as_tensor(sec.title.tokenized_sentence)
        for sent_id, sent in enumerate(sec.sentences[:max_sents]):
          sentences_importances[doc_id, sec_id, sent_id] = float(sent.y_s)
          dataset['section_texts'][doc_id, sec_id, sent_id] = torch.as_tensor(sent.tokenized_sentence)
          dataset['words_per_sentence'][doc_id, sec_id, sent_id] = min(len(sent), max_words)
    return dataset, sections_importances, sentences_importances

## Content ranking module

### Embedding layer

In [None]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip -q glove.6B.zip

In [None]:
class embedding_layer(nn.Module):
  def __init__(self, tokenizer, embedding_dimension=300, device = 'cpu', glove_path=None):
    """
    embedding_layer creates frozen embeddings for the Content Ranking Module from pretrained GloVe vectors.

    Row i of the embedding matrix holds the GloVe vector of the word with tokenizer index i.
    Row 0 (padding), words with index >= the vocabulary size, and words missing from GloVe
    keep all-zero rows.

    :param tokenizer: tokenizer previously instantiated (uses `num_words` and `word_index`)
    :param embedding_dimension: embedding dimension, possible values: {50, 100, 200, 300}
    :param device: cpu or cuda; the layer is moved there after construction
    :param glove_path: path of the GloVe text file; defaults to "glove.6B.<dim>d.txt" in the working directory
    """
    super(embedding_layer, self).__init__()
    self.d = device

    # Tokenizer(num_words=None) keeps every word; indices then run from 1 to len(word_index).
    if tokenizer.num_words is not None:
      vocab_size = tokenizer.num_words
    else:
      vocab_size = len(tokenizer.word_index) + 1
    if glove_path is None:
      glove_path = f"glove.6B.{embedding_dimension}d.txt"

    embedding_matrix = np.zeros((vocab_size, embedding_dimension), dtype=np.float32)
    # Stream the file and keep only vectors for words in the vocabulary, instead of
    # materializing every GloVe vector in a dict. GloVe files are UTF-8.
    with open(glove_path, encoding='utf-8') as f:
      for line in f:
        word, coefs = line.split(maxsplit=1)
        i = tokenizer.word_index.get(word)
        if i is not None and i < vocab_size:
          embedding_matrix[i] = np.asarray(coefs.split(), dtype=np.float32)

    # Frozen: the weight has requires_grad=False and is still stored as `embedding.weight`.
    self.embedding = nn.Embedding.from_pretrained(torch.from_numpy(embedding_matrix), freeze=True)
    # Must run after the embedding exists, otherwise there is nothing to move.
    self.to(device)

  def forward(self, item):
    """
    Look up embeddings for a tensor of (possibly float-typed) token ids.

    :param item: tensor of token ids of any shape
    :return: float tensor of shape item.shape + (embedding_dimension,)
    """
    # Follow the weights' device (the parent module may have moved them with .to()).
    item = item.to(self.embedding.weight.device)
    return self.embedding(item.long()).float()

### Word Attention layer

In [None]:
class WordAttention(nn.Module):
  def __init__(self, tokenizer, device, embedding_dimension=300, hidden_size=512):
    """
    WordAttention performs the word level attention mechanism: it embeds the tokens of each
    sentence, runs a BiLSTM over them and pools the outputs with additive attention into one
    vector of size 2*hidden_size per sentence.

    :param tokenizer: tokenizer previously instantiated
    :param device: cpu or cuda
    :param embedding_dimension: embedding dimension, possible values: {50, 100, 200, 300}, default 300
    :param hidden_size: hidden size used for bidirectional LSTM layer, default 512
    """
    super(WordAttention, self).__init__()
    self.embedder = embedding_layer(tokenizer, embedding_dimension)
    # Bidirectional, so each word is represented by 2*hidden_size features.
    self.BiLSTM = nn.LSTM(input_size=embedding_dimension, hidden_size=hidden_size, batch_first=True, bidirectional=True)
    self.word_attention = nn.Linear(hidden_size*2, 2*hidden_size)
    self.word_context_vector = nn.Linear(2*hidden_size, 1, bias = False)
    self.to(device)

  def forward(self, sentences, words_per_sentence):
    """
    WordAttention forward creates the sentence embeddings.

    :param sentences: padded token ids, one row per sentence
    :param words_per_sentence: number of real (non-padding) tokens of each sentence; all must be > 0
    :return: sentence embeddings of shape (n_sentences, 2 * hidden_size)
    """
    sentences = self.embedder(sentences) # (num_sentences, max_num_word_per_sentence, emb_dim)
    # Pack so the BiLSTM and the attention only see real tokens.
    packed_words = pack_padded_sequence(
        sentences,
        lengths= words_per_sentence.tolist(),
        batch_first=True,
        enforce_sorted=False
    ) # 'data' is the flattened words (n_words, word_emb)
    packed_words, _ = self.BiLSTM(packed_words) # 'data' is the BiLSTM output (n_words, 2 * hidden_size)
    att_w = self.word_attention(packed_words.data)  # (n_words, 2 * hidden_size)

    att_w = torch.tanh(att_w)
    att_w = self.word_context_vector(att_w).squeeze(1)  # (n_words)

    # Subtracting one global max leaves every per-sentence softmax unchanged but avoids overflow in exp.
    max_value = att_w.max()
    att_w = torch.exp(att_w - max_value)  # (n_words)

    # Re-arrange as sentences by re-padding with 0s, so padding contributes nothing to the sums below.
    att_w, _ = pad_packed_sequence(PackedSequence(data=att_w,
                                                  batch_sizes=packed_words.batch_sizes,
                                                  sorted_indices=packed_words.sorted_indices,
                                                  unsorted_indices=packed_words.unsorted_indices),
                                    batch_first=True)  # (n_sentences, max(words_per_sentence))

    # Softmax within each sentence.
    word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True)  # (n_sentences, max(words_per_sentence))

    # Similarly re-arrange word-level BiLSTM outputs as sentences by re-padding with 0s
    sentences, _ = pad_packed_sequence(packed_words,
                                        batch_first=True)  # (n_sentences, max(words_per_sentence), 2 * hidden_size)

    sentences = sentences * word_alphas.unsqueeze(2)  # (n_sentences, max(words_per_sentence), 2 * hidden_size)
    sentences = sentences.sum(dim=1)  # (n_sentences, 2 * hidden_size)

    return sentences

### Sentence Attention layer
Returns section representations and sentence scores

In [None]:
class SentenceAttention(nn.Module):
  def __init__(self, tokenizer, device, embedding_dimension=300, hidden_size=512):
    """
    SentenceAttention performs the sentence level attention mechanism.

    :param tokenizer: tokenizer previously instantiated
    :param device: cpu or cuda
    :param embedding_dimension: embedding dimension, possible values: {50, 100, 200, 300}, default 300
    :param hidden_size: hidden size used for bidirectional LSTM layer, default 512
    """
    super(SentenceAttention, self).__init__()
    # Forward the sizes: sentence_BiLSTM expects 2*hidden_size inputs, which only matches
    # the word encoder's output when both use the same hidden_size.
    self.word_attention = WordAttention(tokenizer, device, embedding_dimension=embedding_dimension, hidden_size=hidden_size)
    self.sentence_BiLSTM = nn.LSTM(input_size=2*hidden_size, hidden_size=hidden_size, batch_first=True, bidirectional=True)
    self.sentence_attention = nn.Linear(hidden_size*2, 2*hidden_size)
    self.sentence_context_vector = nn.Linear(2*hidden_size, 1, bias = False)
    self.sentence_scores_layer = nn.Linear(2*hidden_size, 1)
    self.to(device)

  def forward(self, data, sentences_per_sections, words_per_sentence):
    """
    SentenceAttention forward creates the section embeddings.

    :param data: token ids, shape (n_sections, max_sentences_per_section, max_words_per_sentence)
    :param sentences_per_sections: number of real sentences in each section; all must be > 0
    :param words_per_sentence: number of real tokens of each sentence, shape (n_sections, max_sentences_per_section)
    :return sections: section embeddings, shape (n_sections, 2 * hidden_size)
    :return sentence_scores: sigmoid ranking score of each sentence, shape
                             (n_sections, max(sentences_per_sections), 1), 0 for padding
    """
    # Both packings use identical lengths, so their `.data` rows line up sentence by sentence.
    lengths = sentences_per_sections.tolist()

    # from (n_sections, max_num_sentences_per_section) to (n_sentences)
    packed_words_per_sent = pack_padded_sequence(words_per_sentence,
                                      lengths=lengths,
                                      batch_first=True,
                                      enforce_sorted=False)

    # from (n_sections, max_num_sentences_per_section, max_num_words_per_sentence) to (n_sentences, max_num_words_per_sentence)
    packed_sections_by_sentence = pack_padded_sequence(data,
                                      lengths=lengths,
                                      batch_first=True,
                                      enforce_sorted=False)

    sentences = self.word_attention(packed_sections_by_sentence.data, packed_words_per_sent.data)  # (n_sentences, 2*hidden_size)

    packed_sentences, _ = self.sentence_BiLSTM(PackedSequence(data=sentences,
                                                           batch_sizes=packed_sections_by_sentence.batch_sizes,
                                                           sorted_indices=packed_sections_by_sentence.sorted_indices,
                                                           unsorted_indices=packed_sections_by_sentence.unsorted_indices))

    # Attention vectors from the sentence BiLSTM outputs
    att_s = self.sentence_attention(packed_sentences.data)  # (n_sentences, 2*hidden_size)
    att_s = torch.tanh(att_s)
    # Dot-product with the context vector (the weight of a bias-free linear layer)
    att_s = self.sentence_context_vector(att_s).squeeze(1)  # (n_sentences)

    # One global max subtraction keeps exp stable without changing the per-section softmax.
    max_value = att_s.max()
    att_s = torch.exp(att_s - max_value)  # (n_sentences)

    # Re-arrange as sections by re-padding with 0s
    att_s, _ = pad_packed_sequence(PackedSequence(data=att_s,
                                                  batch_sizes=packed_sentences.batch_sizes,
                                                  sorted_indices=packed_sentences.sorted_indices,
                                                  unsorted_indices=packed_sentences.unsorted_indices),
                                    batch_first=True)  # (n_sections, max(sentences_per_sections))

    # Softmax within each section
    sentence_alphas = att_s / torch.sum(att_s, dim=1, keepdim=True)  # (n_sections, max(sentences_per_sections))

    # Re-arrange sentence-level BiLSTM outputs as sections by re-padding with 0s
    sections, _ = pad_packed_sequence(packed_sentences,
                                        batch_first=True)  # (n_sections, max(sentences_per_sections), 2 * hidden_size)

    sentence_scores = torch.sigmoid(self.sentence_scores_layer(packed_sentences.data)) # (n_sentences, 1)

    sentence_scores, _ = pad_packed_sequence(PackedSequence(data=sentence_scores,
                                                  batch_sizes=packed_sentences.batch_sizes,
                                                  sorted_indices=packed_sentences.sorted_indices,
                                                  unsorted_indices=packed_sentences.unsorted_indices),
                                    batch_first=True) # (n_sections, max(sentences_per_sections), 1)

    # Section embeddings: attention-weighted sum of sentence representations
    sections = sections * sentence_alphas.unsqueeze(2)  # (n_sections, max(sentences_per_sections), 2 * hidden_size)
    sections = sections.sum(dim=1)  # (n_sections, 2 * hidden_size)

    return sections, sentence_scores


### Complete module
Computes section and sentence scores

In [None]:
class ContentRankingModule(nn.Module):
  def __init__(self, tokenizer, device, embedding_dimension=300, hidden_size=512):
    """
    Content Ranking Module ranks each sentence and each section.

    :param tokenizer: tokenizer previously instantiated
    :param device: cpu or cuda
    :param embedding_dimension: embedding dimension, possible values: {50, 100, 200, 300}, default 300
    :param hidden_size: hidden size used for bidirectional LSTM layer, default 512
    """
    super(ContentRankingModule, self).__init__()

    # Forward the sizes so every sub-module produces/consumes 2*hidden_size features.
    self.title_attention = WordAttention(tokenizer, device, embedding_dimension=embedding_dimension, hidden_size=hidden_size)
    self.section_and_sentence_attention = SentenceAttention(tokenizer, device, embedding_dimension=embedding_dimension, hidden_size=hidden_size)
    self.section_and_title_attention = nn.Linear(hidden_size*2, 2*hidden_size)
    self.section_and_title_context_vector = nn.Linear(2*hidden_size, 1, bias = False)
    self.section_scores_layer = nn.Linear(2*hidden_size, 1)
    self.to(device)

  @staticmethod
  def _pack(padded, lengths):
    """Pack a batch-first padded tensor, keeping only the first lengths[b] rows of each batch item."""
    return pack_padded_sequence(padded, lengths=lengths, batch_first=True, enforce_sorted=False)

  @staticmethod
  def _unpack(data, like):
    """Re-pad flat per-element `data` (ordered like `like.data`) into a batch-first tensor, padding with 0."""
    padded, _ = pad_packed_sequence(PackedSequence(data=data,
                                                   batch_sizes=like.batch_sizes,
                                                   sorted_indices=like.sorted_indices,
                                                   unsorted_indices=like.unsorted_indices),
                                    batch_first=True)
    return padded

  def forward(self, data):
    """
    Content Ranking Module forward call.

    Every _pack below uses the same per-document section counts, so all packed tensors
    share one element order and can be re-padded with each other's layout.

    :param data: the dictionary returned by CRDataLoader.my_collate_fn
    :return section_scores: (batch, max(sections_per_doc)) ranking score of each section, 0 for padding
    :return sentence_scores: (batch, max(sections_per_doc), max sentences per section) ranking score
                             of each sentence, 0 for padding
    """
    lengths = data['sections_per_doc'].tolist()

    # Titles: (batch, max_num_sections, max_sent_len) -> one embedding per real section
    packed_titles = self._pack(data['section_titles'], lengths)
    packed_words_per_title = self._pack(data['words_per_title'], lengths)
    titles = self.title_attention(packed_titles.data, packed_words_per_title.data)  # (n_sections, 2*hidden_size)
    titles = self._unpack(titles, packed_titles)  # (batch, max_num_sections, 2*hidden_size)

    # Section contents -> section embeddings and per-sentence scores
    packed_sections = self._pack(data['section_texts'], lengths)
    packed_words_per_sentence = self._pack(data['words_per_sentence'], lengths)
    packed_sentences_per_section = self._pack(data['sentences_per_section'], lengths)
    sections, sentence_scores = self.section_and_sentence_attention(packed_sections.data,
                                                                    packed_sentences_per_section.data,
                                                                    packed_words_per_sentence.data)
    # sections: (n_sections, 2*hidden_size); sentence_scores: (n_sections, max_sentences, 1)
    sentence_scores = self._unpack(sentence_scores, packed_sections)  # (batch, max_num_sections, max_sentences, 1)
    sections = self._unpack(sections, packed_sections)  # (batch, max_num_sections, 2*hidden_size)

    # Pair each section's title and content representations. Stacking without any
    # squeeze keeps the batch/section dims intact even when either of them is 1.
    sections_with_titles = torch.stack((titles, sections), dim=2)  # (batch, max_num_sections, 2, 2*hidden_size)
    packed_pairs = self._pack(sections_with_titles, lengths)  # data: (n_sections, 2, 2*hidden_size)

    att_S = torch.tanh(self.section_and_title_attention(packed_pairs.data))  # (n_sections, 2, 2*hidden_size)
    att_S = self.section_and_title_context_vector(att_S)  # (n_sections, 2, 1)
    # Attention is normalized over the (title, content) pair of each section. Working on
    # packed data means padded sections never enter the softmax.
    section_alphas = torch.softmax(att_S, dim=1)  # (n_sections, 2, 1)
    fused_sections = (packed_pairs.data * section_alphas).sum(dim=1)  # (n_sections, 2*hidden_size)

    section_scores = torch.sigmoid(self.section_scores_layer(fused_sections))  # (n_sections, 1)
    section_scores = self._unpack(section_scores, packed_pairs)  # (batch, max_num_sections, 1)

    # Drop only the trailing singleton dim so batch==1 or a single section keeps its axis.
    return section_scores.squeeze(-1), sentence_scores.squeeze(-1)



### Loss

In [None]:
class CRLoss(nn.Module):
  def __init__(self, device):
    """
    CRLoss is the training loss of the Content Ranking Module: binary cross entropy of the
    section scores plus binary cross entropy of the sentence scores, each against its
    ground-truth importance.

    :param device: cpu or cuda
    """
    super().__init__()
    self.section_loss = nn.BCELoss()
    self.sentence_loss = nn.BCELoss()
    self.to(device)

  def forward(self, sections_importance, sentences_importance, sections_gold, sentences_gold):
    """
    CRLoss forward call.

    :param sections_importance: predicted section scores, shape (batch, sections)
    :param sentences_importance: predicted sentence scores, shape (batch, sections, sentences)
    :param sections_gold: ground-truth section scores, padded to the dataloader's maxima
    :param sentences_gold: ground-truth sentence scores, padded to the dataloader's maxima
    :return: section loss + sentence loss
    """
    # Predictions are padded only up to the largest counts present in the batch
    # (pad_packed_sequence), while the gold tensors use the fixed maxima, so the
    # gold tensors are cropped to the prediction shapes.
    n_sections = sections_importance.shape[1]
    n_sentences = sentences_importance.shape[2]
    section_term = self.section_loss(sections_importance, sections_gold[:, :n_sections])
    sentence_term = self.sentence_loss(sentences_importance,
                                       sentences_gold[:, :sentences_importance.shape[1], :n_sentences])
    return section_term + sentence_term

## Training the content ranking module

In [None]:
# Training configuration. These globals are read by the cells below (and possibly by later ones).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 25       # documents per training batch
eval_batch_size = 10  # documents per evaluation batch
lr = 1e-4             # SGD learning rate
epochs = 5
workers = 2           # DataLoader worker processes
m = 2                 # NOTE(review): not used in the visible cells; presumably consumed later -- confirm
n = 4                 # NOTE(review): not used in the visible cells; presumably consumed later -- confirm
train_datapath = '/content/drive/MyDrive/arx_pub-dataset/train_subset.txt'
eval_datapath = '/content/drive/MyDrive/arx_pub-dataset/val_subset.txt'
CR_model_path = '/content/drive/MyDrive/CR_model_arx_pub.pkl'
# Tokens per sentence after padding/truncation. The dataloaders below use CRDataLoader's
# default max_sent_len (36), so keep the two values equal.
max_sent_len = 36
tokenizer = Tokenizer(num_words=50000) # keep only word indices < 50000 (the 49,999 most frequent words; index 0 is padding)

In [None]:
# NOTE(review): very high recursion limit; the reason is not visible in this notebook -- confirm it is still needed.
sys.setrecursionlimit(500000)
documents = ScientificPapaerDataset(train_datapath, max_sent_len, tokenizer)
# NOTE(review): the dataset constructor fits the shared tokenizer on every file it reads,
# so the vocabulary also includes words that appear only in the validation split.
eval_docs = ScientificPapaerDataset(eval_datapath, max_sent_len, tokenizer)

# Both loaders keep CRDataLoader's default limits (8 sections, 18 sentences, 36 tokens).
content_ranking_dataloader = CRDataLoader(documents, batch_size=batch_size, num_workers=workers)
CR_eval_dataloader = CRDataLoader(eval_docs, batch_size=eval_batch_size, num_workers=workers)

In [None]:
# Content ranking model, its loss and the optimizer.
model = ContentRankingModule(tokenizer=tokenizer, device=device)
model.to(device)  # redundant: ContentRankingModule.__init__ already calls self.to(device)

CR_loss = CRLoss(device)
CR_loss.to(device)  # redundant: CRLoss.__init__ already calls self.to(device)

# Plain SGD over all model parameters; the frozen GloVe embedding (requires_grad=False) gets no gradient.
optimizer = optim.SGD(model.parameters(), lr=lr)

In [None]:
import warnings
# Silences every warning for the rest of the session (including deprecation and
# performance warnings); narrow the filter if those need to be seen.
warnings.filterwarnings('ignore')

In [None]:
def save_model(model, model_path):
    """Write the model's state_dict (parameters and buffers only) to model_path with torch.save."""
    state_dict = model.state_dict()
    torch.save(state_dict, model_path)

def load_model(model, model_path, use_cuda=True):
    """Load the ``state_dict`` saved at ``model_path`` into ``model``.

    Stored tensors are mapped to ``'cuda:0'`` when ``use_cuda`` is set and CUDA
    is available, otherwise to ``'cpu'``.

    :return: the same ``model`` instance, with its parameters replaced
    """
    want_gpu = use_cuda and torch.cuda.is_available()
    state = torch.load(model_path, 'cuda:0' if want_gpu else 'cpu')
    model.load_state_dict(state)
    return model

In [None]:
# Train the Content Ranking module; after every epoch evaluate it and keep the
# checkpoint with the lowest (summed) evaluation loss.
best_eval_loss = np.inf
for epoch in range(epochs):
  # ------------------------------------------------------------------ training
  training_loss = 0
  model.train()

  for batch, (dataset, sections_reference_importances, sentences_reference_importances) in enumerate(content_ranking_dataloader):
    # move every input tensor and both targets to the training device
    for k in dataset.keys():
      dataset[k] = dataset[k].to(device)
    sections_reference_importances = sections_reference_importances.to(device)
    sentences_reference_importances = sentences_reference_importances.to(device)

    sections_scores, sentences_scores = model(dataset)
    loss = CR_loss(sections_scores, sentences_scores, sections_reference_importances, sentences_reference_importances)
    training_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # running (cumulative) training loss of the current epoch
    print(f'Epoch {epoch + 1}/{epochs}, batch {batch + 1}. Training loss: {training_loss:.3f}.')

  # ---------------------------------------------------------------- evaluation
  model.eval()
  eval_loss = 0
  # `eval_dataset` (not `eval`) so the builtin is not shadowed
  for eval_batch, (eval_dataset, sections_reference_importances_eval, sentences_reference_importances_eval) in enumerate(CR_eval_dataloader):
    for k in eval_dataset.keys():
      eval_dataset[k] = eval_dataset[k].to(device)
    sections_reference_importances_eval = sections_reference_importances_eval.to(device)
    sentences_reference_importances_eval = sentences_reference_importances_eval.to(device)

    with torch.no_grad():
      sections_scores_eval, sentences_scores_eval = model(eval_dataset)
      eval_batch_loss = CR_loss(sections_scores_eval, sentences_scores_eval, sections_reference_importances_eval, sentences_reference_importances_eval)
    eval_loss += eval_batch_loss.item()
  print("Evaluation loss: ", eval_loss)

  if eval_loss < best_eval_loss: # save the model that reaches the lowest loss
    print("Saving best model")
    best_eval_loss = eval_loss
    save_model(model, CR_model_path)

## Predict

In [None]:
def predict(model, data, batch_size, workers, max_num_sec = 8, max_num_sent = 18):
  """
  predict runs an already trained Content Ranking Module over ``data``, batch by batch
  and in order (the internal loader does not shuffle).

  pad_packed_sequence pads only up to the longest sequence of each batch. Each batch's
  scores are therefore copied into the top-left corner of a zero tensor of fixed size, so
  that all batches can be concatenated.

  :param model: trained Content Ranking Module
  :param data: dataset wrapped by an internal CRDataLoader
  :param batch_size: batch size of the internal data loader
  :param workers: number of worker processes of the internal data loader
  :param max_num_sec: maximum number of sections per document, default 8
  :param max_num_sent: maximum number of sentences per section, default 18
  :return: (section_scores [docs, max_num_sec], sentence_scores [docs, max_num_sec, max_num_sent]),
           one row per document yielded by the loader, zero-filled beyond each batch's own size
  """
  def pad(scores, *dims):
    # zero buffer of shape (batch_size, *dims) holding `scores` in its leading corner,
    # trimmed back to the real number of documents of the batch
    buf = torch.zeros(batch_size, *dims)
    buf[tuple(slice(0, size) for size in scores.shape)] = scores
    return buf[:scores.shape[0]]

  model.eval()
  section_chunks, sentence_chunks = [], []

  loader = CRDataLoader(data, batch_size=batch_size, num_workers=workers, shuffle = False)
  for inputs, _, _ in loader:
    for key in inputs.keys():
      inputs[key] = inputs[key].to(device)

    with torch.no_grad():
      sec_out, sent_out = model(inputs)

    section_chunks.append(pad(sec_out, max_num_sec))
    sentence_chunks.append(pad(sent_out, max_num_sec, max_num_sent))

  return torch.cat(section_chunks, 0), torch.cat(sentence_chunks, 0)


Apply the trained model to the training and evaluation datasets.

In [None]:
model = load_model(ContentRankingModule(tokenizer=tokenizer, device=device), CR_model_path) # load the previously saved best model
all_section_scores, all_sentence_scores= predict(model, documents, batch_size, workers)
eval_section_scores, eval_sentence_scores = predict(model, eval_docs, batch_size, workers)

## Digested dataset

In [None]:
def select_top_sections_and_sentences(section_scores, sentence_scores, m, n):
  """
  select_top_sections_and_sentences picks, for every document, the indices of its m
  best-scoring sections and, for every section, the indices of its n best-scoring sentences.

  The selected indices are returned in ascending order, i.e. in the order in which the
  sections/sentences appear in their document, not in score order.

  :param section_scores: Content Ranking Module section scores, [docs, sections]
  :param sentence_scores: Content Ranking Module sentence scores, [docs, sections, sentences]
  :param m: number of sections to select
  :param n: number of sentences to select in each section
  :return: (section ids [docs, m], sentence ids [docs, sections, n])
  """
  ranked_sections = torch.argsort(section_scores, dim=1, descending=True)
  best_sections, _ = torch.sort(ranked_sections[:, :m])

  ranked_sentences = torch.argsort(sentence_scores, dim=2, descending=True)
  best_sentences, _ = torch.sort(ranked_sentences[:, :, :n])

  return best_sections, best_sentences

In [None]:
def compute_boundary_distance(l, pos): 
  """
  compute_boundary_distance measures how close a sentence is to the boundaries of its section.

  The value is |pos - l/2| / (l/2): 1.0 for the first sentence (pos = 0), 0.0 for a sentence
  exactly in the middle of the section, and growing again towards the end of the section.

  :param l: number of sentences in the section (must be non-zero)
  :param pos: 0-based position of the sentence within the section
  :return: float distance from the section centre, normalised by half the section length
  """
  half_len = l / 2
  return abs(pos - half_len) / half_len

In [None]:
class DigestedDataset(Dataset): # map-style dataset
  def __init__(self, data, top_sections_ids, top_sentences_ids, m = 2, n = 4, n_best_sent_per_abs_sent = 2, disable_bound_dist = False):
    """
    DigestedDataset builds the digested documents. For every document only the top-ranked
    sections are kept, and within them only the top-ranked sentences. Each item is the
    heterogeneous graph of one digested document together with its index.

    :param data: list of Document instances
    :param top_sections_ids: best ranked section ids returned by select_top_sections_and_sentences
    :param top_sentences_ids: best ranked sentence ids returned by select_top_sections_and_sentences
    :param m: number of sections to select, default 2
    :param n: number of sentences to select, default 4
    :param n_best_sent_per_abs_sent: number of digested-document sentences labelled 1 for each
                                     abstract sentence, chosen by their average ROUGE-1/2/L F1
                                     score against that abstract sentence
    :param disable_bound_dist: default False; if True the boundary distance feature is replaced by
                               a constant 1 (useful for the ablation study)
    """
    super(DigestedDataset, self).__init__()
    self.m = m
    self.n = n
    self.k = n_best_sent_per_abs_sent
    # Set once, before the loop, so the attribute exists even when `data` is empty
    # (it used to be re-assigned inside the per-document loop).
    self.dis_bd = disable_bound_dist
    self.documents = []
    for doc_id, doc in enumerate(tqdm(data, desc="Reorganizing dataset")):
      sections = []
      for sec_id in top_sections_ids[doc_id]:
        if sec_id >= len(doc):
          continue # document has fewer sections than the maximum: skip padded section ids
        sentences = []
        for i in top_sentences_ids[doc_id][sec_id].tolist():
          if i >= len(doc[sec_id].sentences):
            continue # section has fewer sentences than the maximum: skip padded sentence ids
          sentences.append(doc[sec_id].sentences[i])
        sections.append(Section(None, sentences))
      sections = self.define_labels(doc.abstract, sections)
      self.documents.append(Document(sections, doc.abstract))

  def __getitem__(self, index):
    # the graph is built on every access, it is not cached
    G = self.createGraph(self.documents[index])
    return G, index

  def get_doc(self, index):
    return self.documents[index]

  def __len__(self):
    return len(self.documents)

  def define_labels(self, abstract, sections):
    """
    Set the ground-truth labels of the digested sentences (in place, via Sentence.set_label).

    For each abstract sentence, every digested sentence is scored with the mean of its ROUGE-1,
    ROUGE-2 and ROUGE-L F1 against it, and the top-k sentences are labelled 1. Scores live in
    slot sec_id * n + sent_id. Slots of missing sections/sentences stay 0 and are skipped when
    labelling.

    :return: the same list of sections
    """
    for abs_sent in abstract.sentences:
      sentence_scores = torch.zeros(self.m * self.n)
      for sec_id, sec in enumerate(sections):
        for sent_id, sent in enumerate(sec.sentences):
          score = rouge.get_scores(sent.sentence, abs_sent.sentence)
          score = np.mean([score[0]['rouge-1']['f'], score[0]['rouge-2']['f'], score[0]['rouge-l']['f']])
          sentence_scores[sent_id+sec_id*self.n] = score
      best_indices = sentence_scores.argsort(descending = True)[:self.k]
      for index in best_indices:
        if index//self.n < len(sections):
          if index%self.n < len(sections[index//self.n].sentences):
            sections[index//self.n].sentences[index%self.n].set_label()
    return sections

  def AddWordNode(self, G, document):
    """
    Add one word node per distinct non-zero token id of the document (0 is skipped, as in
    Sentence.__len__, which counts only non-zero ids).

    :return: dict mapping token ids to word node ids (0 .. n_words - 1)
    """
    wid2nid = {}
    nid = 0
    for sec in document.sections:
      for sent in sec.sentences:
          for wid in sent.tokenized_sentence:
              if wid not in wid2nid.keys() and wid != 0:
                  wid2nid[wid] = nid
                  nid += 1

    w_nodes = len(wid2nid)

    G.add_nodes(w_nodes)
    G.set_n_initializer(dgl.init.zero_initializer)
    G.ndata["id"] = torch.LongTensor(list(wid2nid.keys()))
    G.ndata["semantic_type"] = torch.zeros(w_nodes) # semantic_type of word nodes = 0

    return wid2nid

  def MapSent2Sec(self, doc, num_sentences):
    """
    :return: dict mapping consecutive sentence numbers (over the whole document) to the index
             of their section, limited to the first num_sentences sentences
    """
    sent2sec = {}
    sentNo = 0
    for i, sec in enumerate(doc.sections):
      for j in range(len(sec)):
        sent2sec[sentNo] = i
        sentNo += 1
        if sentNo >= num_sentences:
          return sent2sec
    return sent2sec


  def createGraph(self, document):
    """
    Build the heterogeneous graph of one digested document.

    Node ids: word nodes first, then one node per sentence, then one node per section; ndata
    "semantic_type" is 0 / 1 / 2 respectively. Edge types (edata "edge_type"):
    0 word->sentence, 1 sentence->word, 2 sentence<->sentence within a section,
    3 section->sentence across sections, 4 sentence->own section, 5 section<->section.
    """
    G = dgl.DGLGraph()
    wid2nid = self.AddWordNode(G, document)
    w_nodes = len(wid2nid)

    N = document.num_sentences()
    G.add_nodes(N)
    G.ndata["semantic_type"][w_nodes:] = torch.ones(N) # semantic_type of sentence nodes = 1
    sentid2nid = [i + w_nodes for i in range(N)]
    ws_nodes = w_nodes + N

    sent2sec = self.MapSent2Sec(document, N)
    sec_num = len(set(sent2sec.values()))
    G.add_nodes(sec_num)
    G.ndata["semantic_type"][ws_nodes:] = torch.ones(sec_num) * 2 # semantic_type of section nodes = 2
    secid2nid = [i + ws_nodes for i in range(sec_num)]

    sent_id = 0 # sentences of all sections are numbered consecutively
    for sec_id, sec in enumerate(document.sections):
      secid = sent2sec[sent_id]
      secnid = secid2nid[secid] # node id of the section containing the current sentences
      for sent_count, sent in enumerate(sec.sentences):
        c = Counter(sent.tokenized_sentence)
        sent_nid = sentid2nid[sent_id] # node id of the current sentence
        G.nodes[sent_nid].data["label"] = torch.LongTensor([sent.label])
        G.nodes[sent_nid].data["words"] = torch.LongTensor([sent.tokenized_sentence]) # list of token ids
        G.nodes[sent_nid].data["num_words"] = torch.LongTensor([len(sent)])

        # Boundary distance is a float in [0, 1]. Storing it in a LongTensor truncated it
        # to an integer (0 for every sentence but the first), so keep it as a float.
        if self.dis_bd:
          G.nodes[sent_nid].data["boundary_dist"] = torch.FloatTensor([1.0]) # constant when the feature is disabled
        else:
          G.nodes[sent_nid].data["boundary_dist"] = torch.FloatTensor([compute_boundary_distance(len(sec), sent_count)])

        G.add_edges(sent_nid, secnid,
                            data={"edge_type": torch.Tensor([4])}) # intra-section sentence2section: 4
        for other_sec_nid in secid2nid:
          if other_sec_nid != secnid:
            G.add_edges(other_sec_nid, sent_nid,
                            data={"edge_type": torch.Tensor([3])}) # cross-section section2sentence: 3
        for wid, cnt in c.items():
            if wid in wid2nid.keys():
                G.add_edges(wid2nid[wid], sent_nid,
                            data={"edge_type": torch.Tensor([0])}) # word2sentence: 0
                G.add_edges(sent_nid, wid2nid[wid],
                            data={"edge_type": torch.Tensor([1])}) # sentence2word: 1
        sent_id += 1

      # fully connect (both directions) the sentences of the current section
      for (sent_nid1, sent_nid2) in combinations([sentid2nid[sid] for sid,Sid in sent2sec.items() if Sid == sec_id], 2):
        G.add_edges(sent_nid1, sent_nid2, data={"edge_type": torch.Tensor([2])}) # intra-section sentence2sentence: 2
        G.add_edges(sent_nid2, sent_nid1, data={"edge_type": torch.Tensor([2])})

    for (sec_nid1, sec_nid2) in combinations(secid2nid, 2):
      G.add_edges(sec_nid1, sec_nid2,
                          data={"edge_type": torch.Tensor([5])}) # section2section: 5
      G.add_edges(sec_nid2, sec_nid1,
                          data={"edge_type": torch.Tensor([5])}) # section2section: 5

    return G

In [None]:
# Keep the m best sections and, inside each, the n best sentences. Use the configured
# m / n (not literals) so the selection agrees with DigestedDataset, which lays out its
# per-sentence score slots as sec_id * n + sent_id.
top_sections_ids, top_sentences_ids = select_top_sections_and_sentences(all_section_scores, all_sentence_scores, m, n)
eval_top_sections_ids, eval_top_sentences_ids = select_top_sections_and_sentences(eval_section_scores, eval_sentence_scores, m, n)

In [None]:
# Digested training / evaluation datasets built from the selected section and sentence ids.
digested_docs = DigestedDataset(documents, top_sections_ids, top_sentences_ids, m=m, n=n)
eval_digested_docs = DigestedDataset(eval_docs, eval_top_sections_ids, eval_top_sentences_ids, m=m, n=n)

## Extractive summarization dataloader

In [None]:
def graph_collate_fn(samples):
  '''
  graph_collate_fn is the DataLoader collate function for DigestedDataset: it merges the
  per-document graphs of a batch into a single batched DGL graph.

  :param samples: list of (G, index) pairs, one per document of the batch
  :return: (batched graph, list of the document indices in the same order)
  '''
  graphs, indices = (list(column) for column in zip(*samples))
  return dgl.batch(graphs), indices

# Extractive summarization module

## Create $H^{0}_{w|s|S}$

In [None]:
class sentence_cnn_encoder(nn.Module):
  def __init__(self, embedding_layer, device = 'cpu', embedding_dimension = 300):
    """
    sentence_cnn_encoder creates sentence embeddings with a multi-width CNN over token embeddings.

    Six convolutions with kernel heights 2..7 (50 output channels each) slide along the token
    axis. Each output is max-pooled over time and the six results are concatenated, giving a
    50 * 6 = 300 dimensional sentence vector.

    :param embedding_layer: layer mapping token ids to token embeddings
    :param device: cpu or cuda; the whole module (convolutions included) is moved there
    :param embedding_dimension: size of the token embeddings produced by embedding_layer
                                (the convolution kernel width), default 300
    """
    super(sentence_cnn_encoder, self).__init__()
    self.sent_embedder = embedding_layer

    input_channels = 1
    out_channels = 50
    min_kernel_size = 2
    max_kernel_size = 7
    width = embedding_dimension

    # one convolution per n-gram height; every kernel spans the full embedding width
    self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                for height in range(min_kernel_size, max_kernel_size + 1)])

    init_weight_value = 6.0
    for conv in self.convs:
      init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))

    # Move to the device only after every sub-module exists: calling self.to(device)
    # before creating the convolutions left them on the CPU.
    self.to(device)

  def forward(self, sent_tokens):
    # input: [s_nodes, max_sent_len] token ids; max_sent_len must be >= max kernel height (7)
    enc_embed_input = self.sent_embedder(sent_tokens)  # [s_nodes, max_sent_len, embedding_dimension]

    enc_conv_input = enc_embed_input.unsqueeze(1)  # [s_nodes, 1, max_sent_len, embedding_dimension]
    enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # each [s_nodes, out_channels, max_sent_len - (kernel_height - 1)]
    enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # each [s_nodes, out_channels]
    sent_embedding = torch.cat(enc_maxpool_output, 1)  # [s_nodes, 50 * 6]
    return sent_embedding   # [s_nodes, 300]

In [None]:
class H_encoder(nn.Module):
  def __init__(self, embedding_layer, device = 'cpu', ouput_size = 264, embedding_size=300, hidden_size=256):
    """
    H_encoder creates the initial word, sentence and section node embeddings (H^0) for the GAT.

    - words: token embedding projected to ouput_size;
    - sentences: sentence CNN -> BiLSTM over the sentences of each document -> projection,
      concatenated with a 64-d embedding of the boundary distance feature and re-projected;
    - sections: attention-weighted sum of their sentences' embeddings -> BiLSTM over the
      sections of each document -> projection.

    :param embedding_layer: layer mapping token ids to token embeddings (shared with the sentence CNN)
    :param device: cpu or cuda; the whole module is moved there
    :param ouput_size: final embedding size, must be a common multiple of the attention heads used
                       for words, sentences and sections, default 264
    :param embedding_size: token embedding size, also the sentence BiLSTM input size, default 300
    :param hidden_size: hidden size of the LSTMs, default 256
    """
    super().__init__()
    self._emb_lay = embedding_layer
    self.word_projection = nn.Linear(embedding_size, ouput_size)

    self.sent_encoder = sentence_cnn_encoder(embedding_layer, device)
    # fed with the 300-d sentence CNN vectors, which matches the default embedding_size
    self.BiLSTM = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, bidirectional=True, num_layers=2, dropout=0.1)
    self.lstm_projection = nn.Linear(hidden_size*2, ouput_size)

    # attention pooling of sentence embeddings into intermediate section embeddings
    self.sentence_attention = nn.Linear(ouput_size, ouput_size)
    self.sentence_context_vector = nn.Linear(ouput_size, 1, bias = False)
    # scalar boundary distance feature -> 64-d embedding
    self.additional_embedding = nn.Linear(in_features = 1,out_features=64)
    self.final_proj = nn.Linear(in_features=ouput_size+64, out_features=ouput_size)  # 64 = boundary feature embedding size
    self.sec_BiLSTM = nn.LSTM(input_size=ouput_size, hidden_size=hidden_size, batch_first=True, bidirectional=True, num_layers=2, dropout=0.1)
    self.sec_projection = nn.Linear(hidden_size*2, ouput_size)

    # Move to the device only once every sub-module exists: calling self.to(device)
    # before creating them left all of the layers above on the CPU.
    self.to(device)

  def forward(self, graph): # graph is the batched graph
    """
    Compute the H^0 embeddings and store them as ndata "initial_embeddings".
    Sentences must be encoded before sections, which pool the sentence embeddings.

    :return: (word embeddings, sentence embeddings, section embeddings)
    """
    word_emb = self.set_word_embeddings(graph)
    sent_emb = self.set_sentence_embedding(graph)
    sec_emb = self.set_section_embedding(graph)
    graph.ndata.pop('intermediate_embeddings')  # scratch feature, only needed while encoding
    return word_emb, sent_emb, sec_emb

  def set_word_embeddings(self, graph): # graph is the batched graph
    wnode_id = graph.filter_nodes(lambda nodes: nodes.data["semantic_type"]==0) # word nodes
    wid = graph.nodes[wnode_id].data["id"]  # [n_wnodes]
    w_embed = self._emb_lay(wid)  # [n_wnodes, embedding_size]
    w_embed = self.word_projection(w_embed) # [n_wnodes, ouput_size]
    graph.nodes[wnode_id].data["initial_embeddings"] = w_embed
    return w_embed

  def set_sentence_embedding(self, graph): # graph is the batched graph
    snode_id = graph.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence nodes
    cnn_feature = self.sent_encoder(graph.nodes[snode_id].data["words"])
    graph.nodes[snode_id].data["intermediate_embeddings"] = cnn_feature # intermediate embeddings from the CNN

    features, glen = self.get_sentence_features_and_len(graph, 1) # per-graph sentence features for the BiLSTM
    lstm_feature = self.sent_lstm_feature(features, glen, self.BiLSTM, self.lstm_projection) # [n_snodes, ouput_size]

    # column vector with the boundary distance of each sentence, on the same device as the
    # features (the global `device` was used before, ignoring the constructor argument)
    boundary = graph.nodes[snode_id].data["boundary_dist"].reshape(lstm_feature.size()[0], 1)
    boundary = boundary.float().to(lstm_feature.device)
    custom = self.additional_embedding(boundary) # [n_snodes, 64]

    complete_embedding = torch.cat([lstm_feature, custom], 1)
    initial_embedding = self.final_proj(complete_embedding) # back to the size shared by all node types
    graph.nodes[snode_id].data["initial_embeddings"] = initial_embedding

    return initial_embedding

  def set_section_embedding(self, graph): # graph is the batched graph
    secnode_id = graph.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 2) # section nodes
    node_feature_list = []
    for Snode in secnode_id:
        snodes = [nid for nid in graph.predecessors(Snode) if graph.nodes[nid].data["semantic_type"]==1] # sentence predecessors
        sec_sent_feature = graph.nodes[snodes].data["initial_embeddings"]

        # attention over the sentence embeddings -> intermediate section embedding
        ui = self.sentence_attention(sec_sent_feature)
        ui = torch.tanh(ui)
        uiuatt = self.sentence_context_vector(ui)
        max_value = uiuatt.max()  # scalar, for numerical stability of the exponent
        alphas = F.softmax(uiuatt-max_value, dim = 0)
        section = sec_sent_feature * alphas
        section = section.sum(dim=0)
        node_feature_list.append(section)

    node_feature = torch.stack(node_feature_list)
    graph.ndata.pop('intermediate_embeddings')
    graph.nodes[secnode_id].data["intermediate_embeddings"] = node_feature
    features, glen = self.get_sentence_features_and_len(graph, 2) # per-graph section features for the BiLSTM
    lstm_feature = self.sent_lstm_feature(features, glen, self.sec_BiLSTM, self.sec_projection) # [n_secnodes, ouput_size]
    graph.nodes[secnode_id].data["initial_embeddings"] = lstm_feature
    return lstm_feature

  def get_sentence_features_and_len(self, G, sem_type):
    # returns 2 lists, one entry per unbatched graph:
    # *feature* holds the "intermediate_embeddings" of its sem_type nodes,
    # *glen* holds the number of its sem_type nodes
    glist = dgl.unbatch(G)
    feature = []
    glen = []
    for g in glist:
        snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == sem_type)
        feature.append(g.nodes[snode_id].data["intermediate_embeddings"])
        glen.append(len(snode_id))
    return feature, glen

  def sent_lstm_feature(self, features, glen, BiLSTM, lstm_projection):
    """
    Run a BiLSTM over each graph's node sequence and project the outputs.

    :param features: list of [glen[i], feat_dim] tensors, one per graph
    :param glen: list of sequence lengths (all > 0)
    :return: [sum(glen), projection size], rows in the same graph-by-graph order as `features`
    """
    pad_seq = pad_sequence(features, batch_first=True)  # [n_graphs, max(glen), feat_dim]
    lstm_input = pack_padded_sequence(pad_seq,
                                      glen,
                                      batch_first=True,
                                      enforce_sorted=False)
    lstm_output, _ = BiLSTM(lstm_input)
    # PackedSequence.data is laid out time step by time step over the length-sorted sequences,
    # so it cannot be matched row for row with the graph nodes. Unpack it (which restores the
    # original sequence order) and strip each sequence's padding before concatenating.
    unpacked, lengths = pad_packed_sequence(lstm_output, batch_first=True)
    lstm_embedding = torch.cat([seq[:int(length)] for seq, length in zip(unpacked, lengths)], dim=0)
    lstm_feature = lstm_projection(lstm_embedding)
    return lstm_feature

## GAT layers

In [None]:
class WSGATLayer(nn.Module):
  def __init__(self, in_dim, out_dim, device = 'cpu'):
    """
    WSGATLayer: single-head graph attention from word nodes to sentence nodes.

    Attention logits are computed on the word->sentence edges (edge_type 0), and every
    sentence aggregates the projected embeddings of its word neighbours along those edges only.

    :param in_dim: dimension of the input embeddings
    :param out_dim: output dimension of the linear projection
    :param device: cpu or cuda; the layer is moved there once its parameters exist
    """
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    # moved after the parameters are created (they stayed on the CPU otherwise)
    self.to(device)

  def edge_attention(self, edges):
    z2 = torch.cat([edges.src['wh'], edges.dst['wh']], dim=1)  # [edge_num, 2 * out_dim]
    wa = F.leaky_relu(self.attn_fc(z2))  # [edge_num, 1]
    return {'e': wa}

  def message_func(self, edges):
    return {'wh': edges.src['wh'], 'e': edges.data['e']}

  def reduce_func(self, nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)  # attention over each node's incoming messages
    h = torch.sum(alpha * nodes.mailbox['wh'], dim=1)
    return {'sh': h}

  def forward(self, g, h):
    # h: presumably [word rows; sentence rows], given the slicing below — TODO confirm with the caller
    wnode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 0) # words
    snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentences
    wsedge_id = g.filter_edges(lambda edges: (edges.src["semantic_type"] == 0) & (edges.dst["semantic_type"] == 1) & (edges.data['edge_type'] == 0)) # word to sentence edges
    wh = self.fc(h)
    g.nodes[wnode_id].data['wh'] = wh[:len(wnode_id)]
    g.nodes[snode_id].data['wh'] = wh[-len(snode_id):]
    g.apply_edges(self.edge_attention, edges=wsedge_id)
    # Message only along the word->sentence edges. g.pull(snode_id, ...) would also aggregate
    # the sentences' other in-edges (sentence->sentence, section->sentence), whose attention
    # logits were not computed here and whose source embeddings were not set.
    g.send_and_recv(wsedge_id, self.message_func, self.reduce_func)
    g.ndata.pop('wh')
    h = g.ndata.pop('sh') # attention-weighted sum of the word embeddings
    return h[snode_id]


In [None]:
class SWGATLayer(nn.Module):
  def __init__(self, in_dim, out_dim, device = 'cpu'):
    """
    SWGATLayer: single-head graph attention from sentence nodes to word nodes.

    Attention logits are computed on the sentence->word edges (edge_type 1), and every word
    aggregates the projected embeddings of the sentences containing it along those edges.

    :param in_dim: dimension of the input embeddings
    :param out_dim: output dimension of the linear projection
    :param device: cpu or cuda; the layer is moved there once its parameters exist
    """
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    # moved after the parameters are created (they stayed on the CPU otherwise)
    self.to(device)

  def edge_attention(self, edges):
    z2 = torch.cat([edges.src['wh'], edges.dst['wh']], dim=1)  # [edge_num, 2 * out_dim]
    wa = F.leaky_relu(self.attn_fc(z2))  # [edge_num, 1]
    return {'e': wa}

  def message_func(self, edges):
    return {'wh': edges.src['wh'], 'e': edges.data['e']}

  def reduce_func(self, nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)  # attention over each node's incoming messages
    h = torch.sum(alpha * nodes.mailbox['wh'], dim=1)
    return {'sh': h}

  def forward(self, g, h):
    # h: presumably [word rows; sentence rows], given the slicing below — TODO confirm with the caller
    wnode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 0) # word
    snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence
    swedge_id = g.filter_edges(lambda edges: (edges.src["semantic_type"] == 1) & (edges.dst["semantic_type"] == 0) & (edges.data['edge_type'] == 1)) # sentence to word edges
    wh = self.fc(h)
    g.nodes[wnode_id].data['wh'] = wh[:len(wnode_id)]
    g.nodes[snode_id].data['wh'] = wh[-len(snode_id):]
    g.apply_edges(self.edge_attention, edges=swedge_id)
    # message only along the edges whose attention was just computed
    g.send_and_recv(swedge_id, self.message_func, self.reduce_func)
    g.ndata.pop('wh')
    h = g.ndata.pop('sh') # attention-weighted sum of the sentence embeddings
    return h[wnode_id]

In [None]:
class SSGATLayer(nn.Module):
  def __init__(self, in_dim, out_dim, device = 'cpu'):
    """
    SSGATLayer: single-head graph attention between sentences of the same section.

    Attention logits are computed on the intra-section sentence->sentence edges (edge_type 2),
    and every sentence aggregates its same-section neighbours along those edges only.

    :param in_dim: dimension of the input embeddings
    :param out_dim: output dimension of the linear projection
    :param device: cpu or cuda; the layer is moved there once its parameters exist
    """
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    # moved after the parameters are created (they stayed on the CPU otherwise)
    self.to(device)

  def edge_attention(self, edges):
    z2 = torch.cat([edges.src['wh'], edges.dst['wh']], dim=1)  # [edge_num, 2 * out_dim]
    wa = F.leaky_relu(self.attn_fc(z2))  # [edge_num, 1]
    return {'e': wa}

  def message_func(self, edges):
    return {'wh': edges.src['wh'], 'e': edges.data['e']}

  def reduce_func(self, nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)  # attention over each node's incoming messages
    h = torch.sum(alpha * nodes.mailbox['wh'], dim=1)
    return {'sh': h}

  def forward(self, g, h):
    # h: one row per sentence node (it is assigned to all sentence nodes at once)
    snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence nodes
    ssedge_id = g.filter_edges(lambda edges: edges.data['edge_type'] == 2) # intra-section sentence to sentence edges
    wh = self.fc(h)
    g.nodes[snode_id].data['wh'] = wh
    g.apply_edges(self.edge_attention, edges=ssedge_id)
    # Message only along the sentence->sentence edges. g.pull(snode_id, ...) would also
    # aggregate word->sentence and section->sentence in-edges, whose attention logits were not
    # computed here and whose source embeddings were not set.
    g.send_and_recv(ssedge_id, self.message_func, self.reduce_func)
    g.ndata.pop('wh')
    h = g.ndata.pop('sh') # attention-weighted sum of the neighbouring sentence embeddings
    return h[snode_id]


In [None]:
class SecSGATLayer(nn.Module):
  def __init__(self, in_dim, out_dim, device = 'cpu'):
    """
    SecSGATLayer: single-head graph attention from section nodes to sentence nodes.

    Attention logits are computed on the cross-section section->sentence edges (edge_type 3),
    and every sentence aggregates the other sections' embeddings along those edges only.

    :param in_dim: dimension of the input embeddings
    :param out_dim: output dimension of the linear projection
    :param device: cpu or cuda; the layer is moved there once its parameters exist
    """
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    # moved after the parameters are created (they stayed on the CPU otherwise)
    self.to(device)

  def edge_attention(self, edges):
    z2 = torch.cat([edges.src['wh'], edges.dst['wh']], dim=1)  # [edge_num, 2 * out_dim]
    wa = F.leaky_relu(self.attn_fc(z2))  # [edge_num, 1]
    return {'e': wa}

  def message_func(self, edges):
    return {'wh': edges.src['wh'], 'e': edges.data['e']}

  def reduce_func(self, nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)  # attention over each node's incoming messages
    h = torch.sum(alpha * nodes.mailbox['wh'], dim=1)
    return {'sh': h}

  def forward(self, g, h):
    # h: presumably [section rows; sentence rows], given the slicing below — TODO confirm with the caller
    Snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 2) # section nodes
    snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence nodes
    Ssedge_id = g.filter_edges(lambda edges: (edges.src["semantic_type"] == 2) & (edges.dst["semantic_type"] == 1) & (edges.data['edge_type'] == 3)) # cross-section section to sentence edges
    wh = self.fc(h)
    g.nodes[Snode_id].data['wh'] = wh[:len(Snode_id)]
    g.nodes[snode_id].data['wh'] = wh[-len(snode_id):]
    g.apply_edges(self.edge_attention, edges=Ssedge_id)
    # Message only along the section->sentence edges. g.pull(snode_id, ...) would also
    # aggregate word->sentence and sentence->sentence in-edges, whose attention logits were
    # not computed here.
    g.send_and_recv(Ssedge_id, self.message_func, self.reduce_func)
    g.ndata.pop('wh')
    h = g.ndata.pop('sh') # attention-weighted sum of the section embeddings
    return h[snode_id]

In [None]:
class SSecGATLayer(nn.Module):
  def __init__(self, in_dim, out_dim, device = 'cpu'):
    """
    SSecGATLayer: single-head graph attention from sentence nodes to their section node.

    Attention logits are computed on the intra-section sentence->section edges (edge_type 4),
    and every section aggregates its own sentences along those edges only.

    :param in_dim: dimension of the input embeddings
    :param out_dim: output dimension of the linear projection
    :param device: cpu or cuda; the layer is moved there once its parameters exist
    """
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    # moved after the parameters are created (they stayed on the CPU otherwise)
    self.to(device)

  def edge_attention(self, edges):
    z2 = torch.cat([edges.src['wh'], edges.dst['wh']], dim=1)  # [edge_num, 2 * out_dim]
    wa = F.leaky_relu(self.attn_fc(z2))  # [edge_num, 1]
    return {'e': wa}

  def message_func(self, edges):
    return {'wh': edges.src['wh'], 'e': edges.data['e']}

  def reduce_func(self, nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)  # attention over each node's incoming messages
    h = torch.sum(alpha * nodes.mailbox['wh'], dim=1)
    return {'sh': h}

  def forward(self, g, h):
    # h: presumably [section rows; sentence rows], given the slicing below — TODO confirm with the caller
    Snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 2) # section nodes
    snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence nodes
    sSedge_id = g.filter_edges(lambda edges: (edges.src["semantic_type"] == 1) & (edges.dst["semantic_type"] == 2) & (edges.data['edge_type'] == 4)) # intra-section sentence to section edges
    wh = self.fc(h)
    g.nodes[Snode_id].data['wh'] = wh[:len(Snode_id)]
    g.nodes[snode_id].data['wh'] = wh[-len(snode_id):]
    g.apply_edges(self.edge_attention, edges=sSedge_id)
    # Message only along the sentence->section edges. g.pull(Snode_id, ...) would also
    # aggregate the section->section in-edges, whose attention logits were not computed here.
    g.send_and_recv(sSedge_id, self.message_func, self.reduce_func)
    g.ndata.pop('wh')
    h = g.ndata.pop('sh') # attention-weighted sum of the section's sentence embeddings
    return h[Snode_id]

In [None]:
class SecSecGATLayer(nn.Module):
  def __init__(self, in_dim, out_dim, device = 'cpu'):
    """
    SecSecGATLayer: single-head graph attention between section nodes.

    Attention logits are computed on the section->section edges (edge_type 5), and every
    section aggregates the other sections along those edges only.

    :param in_dim: dimension of the input embeddings
    :param out_dim: output dimension of the linear projection
    :param device: cpu or cuda; the layer is moved there once its parameters exist
    """
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    # moved after the parameters are created (they stayed on the CPU otherwise)
    self.to(device)

  def edge_attention(self, edges):
    z2 = torch.cat([edges.src['wh'], edges.dst['wh']], dim=1)  # [edge_num, 2 * out_dim]
    wa = F.leaky_relu(self.attn_fc(z2))  # [edge_num, 1]
    return {'e': wa}

  def message_func(self, edges):
    return {'wh': edges.src['wh'], 'e': edges.data['e']}

  def reduce_func(self, nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)  # attention over each node's incoming messages
    h = torch.sum(alpha * nodes.mailbox['wh'], dim=1)
    return {'sh': h}

  def forward(self, g, h):
    # h: one row per section node (it is assigned to all section nodes at once)
    Snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 2) # section nodes
    SSedge_id = g.filter_edges(lambda edges: edges.data['edge_type'] == 5) # section to section edges
    wh = self.fc(h)
    g.nodes[Snode_id].data['wh'] = wh
    g.apply_edges(self.edge_attention, edges=SSedge_id)
    # Message only along the section->section edges. g.pull(Snode_id, ...) would also
    # aggregate the sentence->section in-edges, whose attention logits were not computed here
    # and whose source embeddings were not set.
    g.send_and_recv(SSedge_id, self.message_func, self.reduce_func)
    g.ndata.pop('wh')
    h = g.ndata.pop('sh') # attention-weighted sum of the other sections' embeddings
    return h[Snode_id]

## Multi-head GAT


In [None]:
class MultiHeadGAT(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads, attn_drop_out, layer, device = 'cpu', merge='cat'):
        """
        MultiHeadGAT runs several independent GAT heads and merges their outputs.

        :param in_dim: dimension of the input embeddings of each head
        :param out_dim: output dimension of each head
        :param num_heads: number of heads
        :param attn_drop_out: dropout probability applied to the input of every head
        :param layer: GAT layer class, instantiated as layer(in_dim, out_dim) once per head
        :param device: cpu or cuda
        :param merge: 'cat' to concatenate the heads (output size num_heads * out_dim)
                      or 'avg' to average them element-wise (output size out_dim)
        :raises ValueError: if merge is neither 'cat' nor 'avg'
        """
        super(MultiHeadGAT, self).__init__()
        if merge not in ('cat', 'avg'):
            raise ValueError("merge must be 'cat' or 'avg', got %r" % (merge,))
        self.heads = nn.ModuleList([layer(in_dim, out_dim) for _ in range(num_heads)]) # num_heads GAT layers
        self.merge = merge
        self.dropout = nn.Dropout(attn_drop_out)
        self.to(device)

    def forward(self, g, h):
        # every head sees its own dropout mask of the input
        head_outs = [attn_head(g, self.dropout(h)) for attn_head in self.heads]  # n_head * [n_nodes, out_dim]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)  # [n_nodes, out_dim * n_head]
        if self.merge == 'avg':
            # average over the head axis only; without dim=0 the mean collapses to a scalar
            return torch.mean(torch.stack(head_outs), dim=0)  # [n_nodes, out_dim]
        raise ValueError("merge must be 'cat' or 'avg', got %r" % (self.merge,))

## Complete GAT

In [None]:
class GAT(nn.Module):
  def __init__(self, in_dim, out_dim, num_heads, attn_drop_out, layerType, device = 'cpu'):
    """
    GAT wraps one multi-head GAT of the requested type and applies an ELU to its output.

    :param in_dim: dimension of the input embeddings
    :param out_dim: total output dimension; the heads are concatenated, so it must be divisible by num_heads
    :param num_heads: number of attention heads
    :param attn_drop_out: dropout probability
    :param layerType: one of "W2S", "S2W", "S2S", "Sec2S", "S2Sec", "Sec2Sec"
    :param device: cpu or cuda
    :raises ValueError: if out_dim is not divisible by num_heads (the concatenated heads
                        would otherwise silently have a smaller size than out_dim)
    :raises NotImplementedError: if layerType is unknown
    """
    super().__init__()
    if out_dim % num_heads != 0:
      raise ValueError("out_dim (%d) must be divisible by num_heads (%d)" % (out_dim, num_heads))
    self.layerType = layerType
    if layerType == "W2S":
      layer = WSGATLayer
    elif layerType == "S2W":
      layer = SWGATLayer
    elif layerType == "S2S":
      layer = SSGATLayer
    elif layerType == "Sec2S":
      layer = SecSGATLayer
    elif layerType == "S2Sec":
      layer = SSecGATLayer
    elif layerType == "Sec2Sec":
      layer = SecSecGATLayer
    else:
      raise NotImplementedError("GAT Layer has not been implemented!")
    self.layer = MultiHeadGAT(in_dim, out_dim // num_heads, num_heads, attn_drop_out, layer=layer, device=device)
    # moved after the sub-module exists so that .to() actually has something to move
    self.to(device)

  def forward(self, g, w, s):
    """
    :param g: batched DGL graph
    :param w: first node set: words (W2S, S2W), sections (Sec2S, S2Sec, Sec2Sec) or sentences (S2S)
    :param s: second node set: sentences, or sections for Sec2Sec
    :return: ELU of the merged multi-head output of the wrapped layer
    """
    if self.layerType in ("S2S", "Sec2Sec"):
      # homogeneous layers: both arguments must hold the same node set
      assert torch.equal(w, s), "w and s must coincide for %s GAT" % self.layerType
      total = s
    else:
      # bipartite layers receive the stacked features [w; s]
      total = torch.cat([w, s])
    return F.elu(self.layer(g, total))

## Fusion

In [None]:
class fusionLayer(nn.Module):
  def __init__(self, embedding_size, device = 'cpu'):
    """
    Gated fusion of two embeddings of the same size.

    A per-dimension gate g = sigmoid(W [x; y] + b) is computed from the concatenated inputs,
    and the result is g * x + (1 - g) * y.

    :param embedding_size: size of each of the two embeddings being fused
    :param device: cpu or cuda
    """
    super().__init__()
    self.lin = nn.Linear(2*embedding_size, embedding_size)
    self.to(device)

  def forward(self, x, y):
    gate = torch.sigmoid(self.lin(torch.cat([x, y], dim = 1)))
    return gate*x + (1-gate)*y

## Position-wise Feed-Forward Network

In [None]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self, d_in, d_hid, dropout=0.1, device = 'cpu'):
    """
    PositionwiseFeedForward applies a position-wise feed-forward network with a residual
    connection followed by layer normalisation: LayerNorm(x + Dropout(W_2 ReLU(W_1 x))).

    :param d_in: dimension of the input (and output) embedding
    :param d_hid: hidden dimension of the module
    :param dropout: dropout probability, default 0.1
    :param device: cpu or cuda
    """
    super().__init__()
    self.w_1 = nn.Linear(d_in, d_hid)
    self.w_2 = nn.Linear(d_hid, d_in)
    self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
    self.dropout = nn.Dropout(dropout)
    # .to() must run after the layers exist; called first (as before) it moved nothing
    self.to(device)

  def forward(self, x):
    """Return LayerNorm(x + Dropout(FFN(x))); the output has the same shape as x."""
    update = self.dropout(self.w_2(F.relu(self.w_1(x))))
    return self.layer_norm(update + x)

## Complete extractive summarization module

In [None]:
class ExtractiveSummarizazionModule(nn.Module):
  def __init__(self, embedding_layer, num_heads1, num_heads2, num_iter, FFN_size, dropout, device = 'cpu', embedding_size = 264):
    """
    ExtractiveSummarizazionModule scores the sentences of a (batched) document graph.

    The H0 encoder initialises the word, sentence and section node states. These are then
    updated num_iter times with GAT message passing, gated fusion and position-wise FFNs.
    A final linear layer produces two scores per sentence node.

    :param embedding_layer: embedding layer used by the H0 encoder
    :param num_heads1: number of heads of the word-to-sentence GAT
    :param num_heads2: number of heads of the sentence-to-word, sentence and section GATs
    :param num_iter: the T parameter, number of iterative updates of each node representation
    :param FFN_size: hidden dimension of the position-wise FFNs
    :param dropout: dropout probability
    :param device: cpu or cuda
    :param embedding_size: node state size, must be a common multiple of num_heads1 and num_heads2
    """
    super().__init__()
    self.H0 = H_encoder(embedding_layer, device)
    self.T = num_iter

    self.word2sentGat = GAT(embedding_size, embedding_size, num_heads1, dropout, "W2S", device)
    self.sent2wordGat = GAT(embedding_size, embedding_size, num_heads2, dropout, "S2W", device)
    self.sent2sentGat = GAT(embedding_size, embedding_size, num_heads2, dropout, "S2S", device)
    self.sec2sentGat = GAT(embedding_size, embedding_size, num_heads2, dropout, "Sec2S", device)
    self.sent2secGat = GAT(embedding_size, embedding_size, num_heads2, dropout, "S2Sec", device)
    self.sec2secGat = GAT(embedding_size, embedding_size, num_heads2, dropout, "Sec2Sec", device)

    self.ws_Ss_fusion = fusionLayer(embedding_size, device)
    self.prevFus_ss_fusion = fusionLayer(embedding_size, device)
    self.sS_SS_fusion = fusionLayer(embedding_size, device)

    self.FFN_word = PositionwiseFeedForward(embedding_size, FFN_size, dropout, device)
    self.FFN_sent = PositionwiseFeedForward(embedding_size, FFN_size, dropout, device)
    self.FFN_sec = PositionwiseFeedForward(embedding_size, FFN_size, dropout, device)

    self.classification_layer = nn.Linear(embedding_size, 2)
    # .to() only moves sub-modules that already exist, so it must come last
    self.to(device)

  def forward(self, graph):
    """
    :param graph: batched DGL graph of one or more documents
    :return: unnormalised 2-class scores (logits) for every sentence node, shape
             (num_sentences in batched graph, 2); column 1 is the "extract" class.
             No sigmoid/softmax is applied: nn.CrossEntropyLoss, which consumes this output
             during training and evaluation, expects raw logits and applies log-softmax
             itself. Use torch.softmax(logits, dim=1) to obtain probabilities.
    """
    word_emb, sent_emb, sec_emb = self.H0(graph) # set initial embeddings

    word_state = word_emb
    sent_state = sent_emb
    sec_state = sec_emb

    for _ in range(self.T):
      U_s2w = self.sent2wordGat(graph, word_state, sent_state) # U_s2w has shape (num_words in batched graph, embedding_size)
      h_w2s = self.word2sentGat(graph, word_state, sent_state) # h_w2s has shape (num_sentences in batched graph, embedding_size)
      h_s2s = self.sent2sentGat(graph, sent_state, sent_state) # h_s2s has shape (num_sentences in batched graph, embedding_size)
      h_sec2s = self.sec2sentGat(graph, sec_state, sent_state) # h_sec2s has shape (num_sentences in batched graph, embedding_size)
      h_s2sec = self.sent2secGat(graph, sec_state, sent_state) # h_s2sec has shape (num_sections in batched graph, embedding_size)
      h_sec2sec = self.sec2secGat(graph, sec_state, sec_state) # h_sec2sec has shape (num_sections in batched graph, embedding_size)

      U_s = self.prevFus_ss_fusion(self.ws_Ss_fusion(h_w2s, h_sec2s), h_s2s) # shape (num_sentences in batched graph, embedding_size)
      U_Sec = self.sS_SS_fusion(h_s2sec, h_sec2sec) # shape (num_sections in batched graph, embedding_size)

      word_state = self.FFN_word(U_s2w + word_state) # shape (num_words in batched graph, embedding_size)
      sent_state = self.FFN_sent(U_s + sent_state) # shape (num_sentences in batched graph, embedding_size)
      sec_state = self.FFN_sec(U_Sec + sec_state) # shape (num_sections in batched graph, embedding_size)

    logits = self.classification_layer(sent_state)
    return logits # raw scores for the 2 classes of each sentence

## Tester

In [None]:
class Tester():
  def __init__(self, model, n_max, dir = None):
    """
    Tester evaluates the model and collects extracted (hypothesis) and reference summaries.
    When dir is given, SaveDecodeFile writes them to a file inside dir.

        :param model: the model
        :param n_max: word budget of a summary (the running word count must stay below it)
        :param dir: directory for saving decode files
    """
    self.model = model
    self.n_max = n_max
    self.save_dir = dir
    self.extracts = []

    self.batch_number = 0
    self.running_loss = 0
    self.example_num = 0
    self.total_sentence_num = 0
    self.rougePairNum = 0

    self.hypothesis = []
    self.reference = []

    self.pred, self.true, self.match, self.match_true = 0, 0, 0, 0
    self._F = 0
    self.criterion = torch.nn.CrossEntropyLoss(reduction='none')

  def SaveDecodeFile(self):
    """
    Write every (reference, hypothesis) pair collected so far to a new file, named after
    the current timestamp, inside save_dir.

    :raises ValueError: if the Tester was created without a directory
    """
    import datetime
    if self.save_dir is None:
      raise ValueError("Tester has no save directory: pass dir=... to save decode files")
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_dir = os.path.join(self.save_dir, nowTime)
    with open(log_dir, "wb") as resfile:
      for i in range(self.rougePairNum):
        resfile.write(b"[Reference]\t")
        resfile.write(self.reference[i].encode('utf-8'))
        resfile.write(b"\n")
        resfile.write(b"[Hypothesis]\t")
        resfile.write(self.hypothesis[i].encode('utf-8'))
        resfile.write(b"\n")
        resfile.write(b"\n")
        resfile.write(b"\n")

  def running_avg_loss(self):
    """Mean of the per-batch evaluation losses accumulated so far."""
    return self.running_loss / self.batch_number

  def _build_hypothesis(self, sentences, ranked_ids, num_words):
    """
    Greedily take sentences in ranking order while the running word count stays below n_max.

    :param sentences: sentence strings of the document (assumed to line up with the graph's
                      sentence-node order — TODO confirm against DigestedDataset)
    :param ranked_ids: sentence indices ordered from highest to lowest extraction score
    :param num_words: number of words of every sentence, indexed like ranked_ids' values
    :return: the selected sentences joined by single spaces; the order of the original
             document is not preserved (ranking order is used)
    """
    selected = []
    total_words = 0
    for sent_id in ranked_ids:
      total_words += num_words[sent_id]
      # word counts are non-negative, so once the budget is reached no later sentence can fit
      if total_words >= self.n_max:
        break
      # ids beyond the available text are skipped (their words are still counted)
      if sent_id < len(sentences):
        selected.append(sentences[sent_id])
    # join with a space: plain concatenation glued the last word of a sentence to the next one
    return ' '.join(selected)

  def evaluation(self, G, index, dataset):
    """
    Run the model on a batched graph, accumulate the loss and collect, for every document,
    the extracted summary and its reference.

      :param G: the batched graph
      :param index: list containing the dataset index of each graph in the batched graph
      :param dataset: dataset exposing get_doc(idx); each document has
                      .sections[*].sentences[*].sentence and .abstract.all_sentences
    """
    self.batch_number += 1
    outputs = self.model(G)
    snode_id = G.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence nodes
    label = G.ndata["label"][snode_id] # [sent_nodes]
    G.nodes[snode_id].data["loss"] = self.criterion(outputs, label).unsqueeze(-1) # [sent_nodes, 1] loss of each sentence
    loss = dgl.sum_nodes(G, "loss")  # [batch_size, 1] sums the losses of all the sentences of a single graph
    loss = loss.mean() # mean of the losses of each graph
    self.running_loss += float(loss.data)
    G.nodes[snode_id].data["p"] = outputs
    for j, g in enumerate(dgl.unbatch(G)): # each unbatched graph is a document
      example = dataset.get_doc(index[j])
      original_article_sents = [sent.sentence for sec in example.sections for sent in sec.sentences]

      g_snode_id = g.filter_nodes(lambda nodes: nodes.data["semantic_type"] == 1) # sentence nodes of a document
      p_sent = g.ndata["p"][g_snode_id] # [snode, 2]
      num_words = g.ndata["num_words"][g_snode_id]
      # order the sentences by their score for the "extract" class (column 1)
      _, pred_idx = torch.topk(p_sent[:, 1], len(g_snode_id))
      self.hypothesis.append(self._build_hypothesis(original_article_sents, pred_idx, num_words))
      self.reference.append(example.abstract.all_sentences)
      self.rougePairNum += 1

## Define training parameters for extractive summarization

In [None]:
# --- checkpoint path and optimisation hyper-parameters ---
ES_model_path = '/content/drive/MyDrive/ES_model_arx_pub.pkl'
ES_batch_size = 25
lr = 1e-4
epochs = 5
max_grad_norm = 2  # gradient-norm clipping threshold
workers = 2  # DataLoader worker processes

# --- architecture hyper-parameters ---
num_heads_wordGat = 6
num_heads_sent_secGat = 8
T = 2  # number of iterative updates of the node representations
FFN_hidden_size = 2048
dropout = 0.1
output_size = 264  # must be a common multiple of the attention heads
max_num_words_in_summary = 200

# True: keep the checkpoint with the best validation ROUGE against the golden summary;
# False: keep the checkpoint with the lowest validation cross entropy loss.
best_rouge = True

embedder = embedding_layer(tokenizer, 300, device)
criterion = nn.CrossEntropyLoss(reduction='none')
model = ExtractiveSummarizazionModule(
    embedding_layer=embedder,
    num_heads1=num_heads_wordGat,
    num_heads2=num_heads_sent_secGat,
    num_iter=T,
    FFN_size=FFN_hidden_size,
    dropout=dropout,
    device=device,
    embedding_size=output_size,
)
model.to(device)

train_loader = torch.utils.data.DataLoader(
    dataset=digested_docs,
    batch_size=ES_batch_size,
    shuffle=True,
    num_workers=workers,
    collate_fn=graph_collate_fn,
)
eval_loader = torch.utils.data.DataLoader(
    dataset=eval_digested_docs,
    batch_size=ES_batch_size,
    shuffle=True,
    num_workers=workers,
    collate_fn=graph_collate_fn,
)
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=lr,
    betas=[0.9, 0.999],
    eps=1e-8,
)

# Training of extractive summarization module

In [None]:
# np.float was removed in NumPy 1.24; the builtin float works on every version
best_eval_loss = float('inf')
best_eval_rouge = 0.0
for epoch in range(epochs):

  # Training
  training_loss = 0
  model.train()

  for batch, (batched_graph, index) in enumerate(train_loader):
    batched_graph = batched_graph.to(device)
    predictions = model(batched_graph)
    sent_nodes = batched_graph.filter_nodes(lambda node: node.data['semantic_type'] == 1) # sentence nodes
    labels = batched_graph.nodes[sent_nodes].data['label']
    batched_graph.nodes[sent_nodes].data["loss"] = criterion(predictions, labels).unsqueeze(-1)  # [n_nodes, 1], cross entropy loss between model outputs and ground truth labels
    loss = dgl.sum_nodes(batched_graph, "loss")  # [batch_size, 1], per-document sum of the sentence losses
    loss = loss.mean()
    print(f'Epoch {epoch}/{epochs}, batch {batch}. Batch loss: {loss.item():.3f}.')
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()

    training_loss += float(loss.data)

  # Evaluation, at each epoch
  model.eval()
  with torch.no_grad():
    tester = Tester(model, max_num_words_in_summary)
    for G, index in eval_loader:
      G = G.to(device)
      tester.evaluation(G, index, eval_digested_docs)

  # running_avg_loss is a method: without the call the loss comparison below raises a TypeError
  running_avg_loss = tester.running_avg_loss()
  rouge_scorer = Rouge()
  scores_all = rouge_scorer.get_scores(tester.hypothesis, tester.reference, avg=True)
  avg_rouge = np.mean([scores_all['rouge-1']['f'], scores_all['rouge-2']['f'], scores_all['rouge-l']['f']]) # mean of all ROUGE f1 scores
  if best_rouge: # save the model that reaches the highest ROUGE score on the validation set
    if avg_rouge > best_eval_rouge:
      print("Save model with higher rouge score. Average F rouge score: ", avg_rouge, "Previous best avg rouge score: ", best_eval_rouge )
      best_eval_rouge = avg_rouge
      save_model(model, ES_model_path)
  else: # save the model that reaches the lowest loss
    if running_avg_loss < best_eval_loss:
      print("Save model with lower loss. Evaluation loss: ", running_avg_loss)
      best_eval_loss = running_avg_loss
      save_model(model, ES_model_path)

  print(f'Epoch {epoch}. Training loss: {training_loss:.3f}.')

# Test

In [None]:
test_batch_size = 10  # batch size for the content-ranking prediction on the test split
test_results = './test' # directory for the document containing reference and extracted summary for each test document
test_datapath = '/content/drive/MyDrive/arx_pub-dataset/test_subset.txt'

In [None]:
# Rebuild the content-ranking and extractive-summarization modules and pass them to load_model
# together with their checkpoint paths (presumably restoring the saved weights), then read the test split.
CR_model = load_model(ContentRankingModule(tokenizer=tokenizer, device=device), CR_model_path)
# NOTE(review): `embedder` is the same instance given to the training `model`; loading weights here
# may also change that model's embedding layer — confirm this sharing is intended.
ES_model = load_model(ExtractiveSummarizazionModule(embedder, num_heads_wordGat, num_heads_sent_secGat, T, FFN_hidden_size, dropout, device, output_size), ES_model_path)
ES_model.to(device)
test_docs = ScientificPapaerDataset(test_datapath, max_sent_len, tokenizer)

In [None]:
# Score sections and sentences of the test documents with the content-ranking module, then
# (presumably) keep the ids of the m best sections and n best sentences; m and n come from an earlier cell.
test_section_scores, test_sentence_scores = predict(CR_model, test_docs, test_batch_size, workers)
test_top_sections_ids, test_top_sentences_ids = select_top_sections_and_sentences(test_section_scores, test_sentence_scores, m, n)

In [None]:
# Test dataset built from the sections/sentences selected by the content-ranking module.
test_digested_docs = DigestedDataset(test_docs, test_top_sections_ids, test_top_sentences_ids)

In [None]:
def ES_predict(model, data, ES_batch_size, max_num_words_in_summary, test_results, num_workers=4):
  """
  ES_predict extracts and evaluates the summaries of the documents in `data`.

  Prints the average evaluation loss and the document-averaged ROUGE-1/2/L precision,
  recall and F1. Writes every reference/extracted summary pair to a timestamped file
  inside test_results.

      :param model: the Extractive Summarization Module
      :param data: documents in input
      :param ES_batch_size: batch size for the internal dataloader
      :param max_num_words_in_summary: word budget of the extracted summary
      :param test_results: directory in which to store extracted and reference summary of each document in input
      :param num_workers: number of DataLoader worker processes (default 4, as before)
  """
  os.makedirs(test_results, exist_ok=True)
  model.eval()
  dataloader = torch.utils.data.DataLoader(data, batch_size=ES_batch_size, num_workers=num_workers, collate_fn=graph_collate_fn)

  with torch.no_grad():
    tester = Tester(model, max_num_words_in_summary, test_results)

    for batched_graph, index in dataloader:
      batched_graph = batched_graph.to(device)
      tester.evaluation(batched_graph, index, data)

  # running_avg_loss is a method; the original assigned the bound method and never used it
  print("Average evaluation loss: %.6f" % tester.running_avg_loss())
  scores_all = Rouge().get_scores(tester.hypothesis, tester.reference, avg=True)

  res = ''.join(
      "%s:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (label, scores_all[key]['p'], scores_all[key]['r'], scores_all[key]['f'])
      for label, key in (("Rouge1", 'rouge-1'), ("Rouge2", 'rouge-2'), ("Rougel", 'rouge-l'))
  )
  print(res)
  tester.SaveDecodeFile()

In [None]:
# Evaluate the full model on the content-ranking digest of the test split.
ES_predict(
    model=ES_model,
    data=test_digested_docs,
    ES_batch_size=ES_batch_size,
    max_num_words_in_summary=max_num_words_in_summary,
    test_results=test_results,
)

# Ablation study

## Disable content ranking module

In [None]:
def select_lead_digests(section_scores, sentence_scores, m, n):
  """
  Ablation baseline: ignore the content-ranking scores and take the first (lead) m sections
  of every document and the first n sentences of every section.

  :param section_scores: array whose first dimension is the number of documents
  :param sentence_scores: array whose first two dimensions are (documents, sections)
  :param m: number of lead sections to keep
  :param n: number of lead sentences to keep per section
  :return: (section ids of shape (documents, m), sentence ids of shape (documents, sections, n))
  """
  num_docs = section_scores.shape[0]
  num_sections = sentence_scores.shape[1]
  sec_ids = np.tile(range(m), (num_docs, 1))
  sent_ids = np.tile(range(n), (sentence_scores.shape[0], num_sections, 1))
  return sec_ids, sent_ids

In [None]:
# Replace the content-ranking selection with the lead sections/sentences of each document.
sec_lead_ids, sent_lead_ids = select_lead_digests(
    section_scores=test_section_scores,
    sentence_scores=test_sentence_scores,
    m=m,
    n=n,
)

In [None]:
# Test dataset built from the lead sections/sentences instead of the content-ranking selection.
lead_test_digested_docs = DigestedDataset(test_docs, sec_lead_ids, sent_lead_ids)

In [None]:
# Evaluate the full model with the content-ranking module disabled (lead digests).
ES_predict(
    model=ES_model,
    data=lead_test_digested_docs,
    ES_batch_size=ES_batch_size,
    max_num_words_in_summary=max_num_words_in_summary,
    test_results=test_results,
)

## Disable iterative update in the extractive summarization module

In [None]:
# Ablation with a single iterative update (T = 1). The iteration count is passed directly
# instead of temporarily overwriting the global T and then resetting it to a hard-coded 2,
# which silently discarded whatever T had been configured.
ES_model_path_T1 = '/content/drive/MyDrive/ES_model_T1.pkl'
model = ExtractiveSummarizazionModule(embedder, num_heads_wordGat, num_heads_sent_secGat, 1, FFN_hidden_size, dropout, device, output_size)
model.to(device)

Retrain the extractive summarization module with a single iterative update (T = 1).

In [None]:
# `optimizer` was created for the parameters of the first model; without a new one,
# optimizer.step() would never update this model's weights.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas = [0.9, 0.999], eps=1e-8)

# np.float was removed in NumPy 1.24; the builtin float works on every version
best_eval_loss = float('inf')
best_eval_rouge = 0.0
for epoch in range(epochs):
  training_loss = 0
  model.train()

  for batch, (batched_graph, index) in enumerate(train_loader):
    batched_graph = batched_graph.to(device)
    predictions = model(batched_graph)
    sent_nodes = batched_graph.filter_nodes(lambda node: node.data['semantic_type'] == 1) # sentence nodes
    labels = batched_graph.nodes[sent_nodes].data['label']
    batched_graph.nodes[sent_nodes].data["loss"] = criterion(predictions, labels).unsqueeze(-1)  # [n_nodes, 1]
    loss = dgl.sum_nodes(batched_graph, "loss")  # [batch_size, 1]
    loss = loss.mean()
    print(f'Epoch {epoch}/{epochs}, batch {batch}. Batch loss: {loss.item():.3f}.')
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()

    training_loss += float(loss.data)

  # Evaluation
  model.eval()
  with torch.no_grad():
    tester = Tester(model, max_num_words_in_summary)
    for G, index in eval_loader:
      G = G.to(device)
      tester.evaluation(G, index, eval_digested_docs)

  # running_avg_loss is a method: without the call the loss comparison below raises a TypeError
  running_avg_loss = tester.running_avg_loss()
  rouge_scorer = Rouge()
  scores_all = rouge_scorer.get_scores(tester.hypothesis, tester.reference, avg=True)
  avg_rouge = np.mean([scores_all['rouge-1']['f'], scores_all['rouge-2']['f'], scores_all['rouge-l']['f']])
  if best_rouge:
    if avg_rouge > best_eval_rouge:
      print("Save model with higher rouge score. Average F rouge score: ", avg_rouge, "Previous best avg rouge score: ", best_eval_rouge )
      best_eval_rouge = avg_rouge
      save_model(model, ES_model_path_T1)
  else:
    if running_avg_loss < best_eval_loss:
      print("Save model with lower loss. Evaluation loss: ", running_avg_loss)
      best_eval_loss = running_avg_loss
      save_model(model, ES_model_path_T1)

  print(f'Epoch {epoch}. Training loss: {training_loss:.3f}.')

In [None]:
# This checkpoint was trained with a single iterative update, so the architecture must be
# rebuilt with num_iter=1; the global T is 2 at this point and would run two iterations.
ES_model_t1 = load_model(ExtractiveSummarizazionModule(embedder, num_heads_wordGat, num_heads_sent_secGat, 1, FFN_hidden_size, dropout, device, output_size), ES_model_path_T1)
ES_model_t1.to(device)

In [None]:
# Evaluate the single-iteration ablation model on the content-ranking digest.
ES_predict(
    model=ES_model_t1,
    data=test_digested_docs,
    ES_batch_size=ES_batch_size,
    max_num_words_in_summary=max_num_words_in_summary,
    test_results=test_results,
)

## Disable boundary distance feature
The boundary distance feature is set to 1 for every sentence.

In [None]:
ES_model_path_no_bd = '/content/drive/MyDrive/ES_model_no_bd.pkl'  # checkpoint of the no-boundary-distance ablation
model = ExtractiveSummarizazionModule(
    embedding_layer=embedder,
    num_heads1=num_heads_wordGat,
    num_heads2=num_heads_sent_secGat,
    num_iter=T,
    FFN_size=FFN_hidden_size,
    dropout=dropout,
    device=device,
    embedding_size=output_size,
)
model.to(device)

In [None]:
# Training and validation digests with the boundary distance feature disabled.
# NOTE(review): these calls pass m and n explicitly while other DigestedDataset calls in this
# notebook do not — confirm that DigestedDataset's defaults agree.
no_bd_digested_docs = DigestedDataset(documents, top_sections_ids, top_sentences_ids, m, n, disable_bound_dist = True)
no_bd_eval_digested_docs = DigestedDataset(eval_docs, eval_top_sections_ids, eval_top_sentences_ids, m, n, disable_bound_dist = True)

In [None]:
train_loader = torch.utils.data.DataLoader(
    dataset=no_bd_digested_docs,
    batch_size=ES_batch_size,
    shuffle=True,
    num_workers=workers,
    collate_fn=graph_collate_fn,
)
eval_loader = torch.utils.data.DataLoader(
    dataset=no_bd_eval_digested_docs,
    batch_size=ES_batch_size,
    shuffle=True,
    num_workers=workers,
    collate_fn=graph_collate_fn,
)

In [None]:
# `optimizer` still holds the parameters of a previously trained model; without a new one,
# optimizer.step() would never update this model's weights.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas = [0.9, 0.999], eps=1e-8)

# np.float was removed in NumPy 1.24; the builtin float works on every version
best_eval_loss = float('inf')
best_eval_rouge = 0.0
for epoch in range(epochs):
  training_loss = 0
  model.train()

  for batch, (batched_graph, index) in enumerate(train_loader):
    batched_graph = batched_graph.to(device)
    predictions = model(batched_graph)
    sent_nodes = batched_graph.filter_nodes(lambda node: node.data['semantic_type'] == 1) # sentence nodes
    labels = batched_graph.nodes[sent_nodes].data['label']
    batched_graph.nodes[sent_nodes].data["loss"] = criterion(predictions, labels).unsqueeze(-1)  # [n_nodes, 1]
    loss = dgl.sum_nodes(batched_graph, "loss")  # [batch_size, 1]
    loss = loss.mean()
    print(f'Epoch {epoch}/{epochs}, batch {batch}. Batch loss: {loss.item():.3f}.')
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()

    training_loss += float(loss.data)

  # Evaluation: look documents up in the same dataset the eval_loader iterates over
  model.eval()
  with torch.no_grad():
    tester = Tester(model, max_num_words_in_summary)
    for G, index in eval_loader:
      G = G.to(device)
      tester.evaluation(G, index, no_bd_eval_digested_docs)

  # running_avg_loss is a method: without the call the loss comparison below raises a TypeError
  running_avg_loss = tester.running_avg_loss()
  rouge_scorer = Rouge()
  scores_all = rouge_scorer.get_scores(tester.hypothesis, tester.reference, avg=True)
  avg_rouge = np.mean([scores_all['rouge-1']['f'], scores_all['rouge-2']['f'], scores_all['rouge-l']['f']])
  if best_rouge:
    if avg_rouge > best_eval_rouge:
      print("Save model with higher rouge score. Average F rouge score: ", avg_rouge, "Previous best avg rouge score: ", best_eval_rouge )
      best_eval_rouge = avg_rouge
      save_model(model, ES_model_path_no_bd)
  else:
    if running_avg_loss < best_eval_loss:
      print("Save model with lower loss. Evaluation loss: ", running_avg_loss)
      best_eval_loss = running_avg_loss
      save_model(model, ES_model_path_no_bd)

  print(f'Epoch {epoch}. Training loss: {training_loss:.3f}.')

In [None]:
ES_model_no_bd = load_model(
    ExtractiveSummarizazionModule(
        embedding_layer=embedder,
        num_heads1=num_heads_wordGat,
        num_heads2=num_heads_sent_secGat,
        num_iter=T,
        FFN_size=FFN_hidden_size,
        dropout=dropout,
        device=device,
        embedding_size=output_size,
    ),
    ES_model_path_no_bd,
)
ES_model_no_bd.to(device)

In [None]:
# Test digest with the boundary distance feature disabled, matching the no-bd ablation model.
test_digested_docs_no_bd = DigestedDataset(test_docs, test_top_sections_ids, test_top_sentences_ids, disable_bound_dist = True)

In [None]:
# Evaluate the no-boundary-distance model on the test digest built with disable_bound_dist=True;
# using the regular test digest would feed the ablated model the boundary feature it was trained without.
ES_predict(ES_model_no_bd, test_digested_docs_no_bd, ES_batch_size, max_num_words_in_summary, test_results)