In [1]:
'''
This notebook (as it is) tests the best extractive model on the validation and test sets
In order to train a new model uncomment the line containing the call to the train function (make sure the load variable is set to False)
The usage of a GPU is recommended

The model is a reimplementation of the paper: https://arxiv.org/pdf/1611.04230.pdf
'''

'\nThis notebook (as it is) tests the best extractive model on the validation and test sets\nIn order to train a new model uncomment the line containing the call to the train function (make sure the load variable is set to False)\nThe usage of a GPU is recommended\n\nThe model is a reimplementation of the paper: https://arxiv.org/pdf/1611.04230.pdf\n'

In [2]:
# Install packages
!pip install -U torchtext
!pip install Rouge
!pip install datasets

Collecting Rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: Rouge
Successfully installed Rouge-1.0.1
Collecting datasets
  Downloading datasets-1.11.0-py3-none-any.whl (264 kB)
[K     |████████████████████████████████| 264 kB 5.3 MB/s 
Collecting xxhash
  Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)
[K     |████████████████████████████████| 243 kB 38.5 MB/s 
[?25hCollecting huggingface-hub<0.1.0
  Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)
[K     |████████████████████████████████| 50 kB 6.3 MB/s 
Collecting fsspec>=2021.05.0
  Downloading fsspec-2021.7.0-py3-none-any.whl (118 kB)
[K     |████████████████████████████████| 118 kB 48.8 MB/s 
Installing collected packages: xxhash, huggingface-hub, fsspec, datasets
Successfully installed datasets-1.11.0 fsspec-2021.7.0 huggingface-hub-0.0.16 xxhash-2.0.2


In [3]:
# imports
from google_drive_downloader import GoogleDriveDownloader as gdd
import os, struct
import glob
import random
import csv
from tensorflow.core.example import example_pb2
import torch, torch.nn as nn
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import defaultdict
from torch.nn.utils.rnn import pack_padded_sequence
from rouge import Rouge
import gc, math
import json
from nltk.tokenize import RegexpTokenizer
torch.set_printoptions(4)

In [4]:
# To train the model with a batch size of 500 15 GB of GPU memory are required
from pynvml import *
nvmlInit()
h = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(h)
print(f'total    : {info.total/pow(10,9)}')
print(f'free     : {info.free/pow(10,9)}')
print(f'used     : {info.used/pow(10,9)}')

total    : 11.996954624
free     : 11.996954624
used     : 0.0


In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

In [6]:
OOV_WORD, PADDING, START_DEC, STOP_DEC = "[UNK]", "[PAD]", "[START]", "[STOP]"
''' 
Vocabulary class
The vocabulary is taken from the dataset: "https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail" "FINISHED FILES"
This class encodes the vocabulary in a hot encoding.  
'''
class Vocab():
  def __init__(self, path="/content/dataset/cnn_dm_extractive_compressed_5000/vocab", max_size = None):
    self.count = 0
    self.word2id = {} # translate a word to its hot encoding
    self.id2word = {} # translate an id to word (since this setting is extractive this dictionary is not used)
    for w in [OOV_WORD, PADDING, START_DEC, STOP_DEC]:
      self.word2id[w] = self.count
      self.id2word[self.count] = w
      self.count+=1
    # populate the vocabulary
    with open(path, 'r') as f:
      for l in f:
        line = l.split()
        if(len(line) != 2):
          #print("Error: wrong voc format (different then 2)")
          pass
        else:
          if((max_size != None) and (self.get_size()==max_size)): # break if max size is reached
            break
          self.word2id[line[0]] = self.count
          self.id2word[self.count] = line[0]
          self.count += 1
  
  # get id given a word
  def get_id(self, word):
    if(word not in self.word2id):
      return self.word2id[OOV_WORD]
    return self.word2id[word]

  # get word given the id
  def get_word(self, id):
    return self.id2word[id]
  
  # get the size of the vocabulary
  def get_size(self):
    return len(self.word2id)
  
  # get word2id dictionary
  def get_vocab(self):
    return self.word2id

In [7]:
# MAX_TEXT_LENGTH: 2882
# MAX_ABS_LENGTH: 1726
# TRUNCATE_TEXT_LENGTH = 580
# TRUNCATE_ABSTRACT_LENGTH = 173

# Init the maximum number of sentences and tokens per sentence
MAX_SENTENCES = 50 # could be set lower
TOKEN_PER_SENT = 30

In [8]:
'''
The Dataset class
This class manages the articles and their respective summaries
The dataset loaded is taken from "https://paperswithcode.com/sota/extractive-document-summarization-on-cnn"
Input:
  - mode: the subset to load (train, val or test)
  - vocab: the vocabulary class
  - max_size: maximum articles to load (None -> All dataset)

The class reads and processes the specific subset of the dataset
'''
class CNN_dailymail(Dataset):
  def __init__(self, mode, vocab, max_size = None, path = "/content/dataset/cnn_dm_extractive_compressed_5000/"):
    if(mode == "train"):
      self.path = path+"/train.*"
    elif(mode == "val"):
      self.path = path+"/val.*"
    else:
      self.path = path+"/test.*"

    self.tokenizer = RegexpTokenizer(r'\w+')
    self.vocab = vocab
    self.examples = {} # The dictionary composed by the samples of the dataset 
    self.summaries = {} # Contains the strings of the summaries 
    self.sentences = {} # contains the articles split in sentences
    count = 0

    files = glob.glob(self.path)
    exit = False
    for _f in tqdm(files, position=0, leave=True):
      with open(_f) as f:
        data = json.load(f)
        for row in data:
          # process the data
          no_target, article, summary, str_summary, sentences = self.process(row)
          if(not no_target):
            self.examples[count] = (article, summary)
            self.summaries[count] = str_summary
            self.sentences[count] = sentences
            count+=1
          if((max_size != None) and (count >= max_size)):
            exit = True
            break
      if(exit):
        break

  # this function process the articles and summaries, it outputs the tensor article and summary, the string summary and the sentences of the article
  def process(self, data):
    sentences_list = data["src"]
    article, sentences = [], []
    str_summary = ""
    
    article = [[self.vocab.get_id("[PAD]")]*TOKEN_PER_SENT]*MAX_SENTENCES

    for dim1, sentence in enumerate(sentences_list):
      sent = " ".join(sentence)
      sentences.append(sent)
      sentence = self.tokenizer.tokenize(sent) # remove punctuation
      sent_len = 0
      # populate the article
      if(len(sentences) < MAX_SENTENCES):
        for dim2, word in enumerate(sentence):
          if(sent_len >= TOKEN_PER_SENT):
            break
          article[dim1][dim2] = self.vocab.get_id(word)
          sent_len += 1

    # create the string summary
    for idx, target in enumerate(data["labels"]):
      if(target == 1):
        str_summary += sentences[idx]
    
    # if the summary is empty, skip this article
    if(len(str_summary) == 0):
      return True, None, None, None, None
    
    # add sentences set to 0 to reach MAX_SENTENCES
    if(len(data["labels"]) < MAX_SENTENCES):
      data["labels"].extend([0 for i in range(MAX_SENTENCES-len(data["labels"]))])
    return False, torch.tensor(article), torch.tensor(data["labels"][:MAX_SENTENCES]), str_summary, sentences

  # return the (article, summary), the string summary and the sentences (joined with "|" since this list can have a different size and the dataloader would get an error)
  def __getitem__(self, idx):
    return self.examples[idx], self.summaries[idx], "|".join(self.sentences[idx])
  
  # return the size of the dataset
  def __len__(self):
    return len(self.examples)

In [9]:
'''
Download the dataset (from my google drive)
This will take some seconds
'''
def download_dataset(drive_id="1zcUZtgVpabefM0Q9CEh26HPJFYSpqFmR", file_name="CNN_extractive.zip"):
  path = "dataset/" + file_name
  os.system(f'mkdir {"dataset/"}')
  GDRIVE_ID = drive_id # id of file in google drive
  gdd.download_file_from_google_drive(file_id=GDRIVE_ID,
                                dest_path=f"dataset/{file_name}",
                                unzip=True)
  os.remove("dataset/CNN_extractive.zip")

if(not os.path.isdir("dataset/")):
      download_dataset()

Downloading 1zcUZtgVpabefM0Q9CEh26HPJFYSpqFmR into dataset/CNN_extractive.zip... Done.
Unzipping...Done.


In [10]:
'''
The Encoder, this model has to encode the input sentence
The input sentence is processed via two bidirectional GRU layers, the first one at word level and the second one at sentence level
The hidden states from the word level layer are concatenated and used as input to the second GRU layer with the previous sentence hidden state (sentence_hidden_state) as hidden state
The output is the sentence embedding and the hidden state from the GRU sentence layer
'''
class Encoder(nn.Module):
  def __init__(self, input_size, emb, hidden_size, sentence_length, n_layer, bidirectional, batch_size, vocab, dropout):
    super().__init__()
    self.hidden_size = hidden_size
    self.sentence_length = sentence_length
    self.batch_size = batch_size
    self.n_layer = n_layer
    self.embedding = nn.Embedding(input_size, emb, padding_idx = vocab.get_id("[PAD]"))
    self.dropout = nn.Dropout(dropout)
    self.bidirectional = 1
    if(bidirectional):
      self.bidirectional = 2
    self.word_layer = nn.GRU(emb, hidden_size, num_layers = n_layer, bidirectional = bidirectional, batch_first = True) # initialize the word level GRU layer
    self.sentence_layer = nn.GRU(hidden_size * sentence_length, hidden_size, num_layers = n_layer, bidirectional = bidirectional, batch_first=False, dropout=0.4) # initialize the sentence level GRU layer

  def forward(self, sentence, sentence_hidden_state):
    # sentence [2, 20]
    hidden = torch.zeros(self.bidirectional * self.n_layer, self.batch_size, self.hidden_size).to(device)
    concat_hidden = None
    # for each word in the sentence
    for idx in range(self.sentence_length):
      word = sentence[:,idx] # word -> [2]
      emb = self.dropout(self.embedding(word)).unsqueeze(1) # [2, 1, 256]
      word_enc, hidden = self.word_layer(emb, hidden)
      if(concat_hidden == None):
        concat_hidden = hidden
      else:
        concat_hidden = torch.cat((concat_hidden, hidden), 2)

    sentence_emb, hidden = self.sentence_layer(concat_hidden, sentence_hidden_state)
    return sentence_emb, hidden

  # init the hidden layer
  def init_layers(self):
    return torch.zeros(self.bidirectional * self.n_layer, self.batch_size, self.hidden_size).to(device)

'''
The decoder takes in input the sentence hidden state from the encoder, the hidden sum (H on the report) and the current summary state
The decoder performs a non linear transformation on H generating the representation of the entire article d
Then the content, saliency and novelty are computed via linear and bilinear transformations
Content + saliency - novelty are the input to the final classifier, the ouput of the decoder is the probability for the specific sentence to be selected or not (two classes)
'''
class Decoder(nn.Module):
  def __init__(self, output_size, hidden_size, input_size, n_layer, vocab, batch_size, dropout):
    super().__init__()
    self.device = device
    self.batch_size = batch_size
    self.representation = nn.Linear(hidden_size*2, hidden_size)
    self.sentence_transform = nn.Linear(hidden_size*2, hidden_size*2)
    self.content = nn.Linear(hidden_size*2, hidden_size, bias = False)
    self.salience = nn.Bilinear(hidden_size*2, hidden_size, hidden_size, bias = False)
    self.novelty = nn.Bilinear(hidden_size*2, hidden_size*2, hidden_size, bias = False)
    self.classifier = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Linear(hidden_size, 2))

  def forward(self, sentence_hidden, hidden_sum, current_summary_hidden):
    d = torch.tanh(self.representation(hidden_sum))
    sentence_hidden = torch.tanh(self.sentence_transform(sentence_hidden))
    cont = self.content(sentence_hidden)
    salience = self.salience(sentence_hidden, d)
    novelty = self.novelty(sentence_hidden, torch.tanh(current_summary_hidden))
    # d [3, 256]
    # sentence [3, 512]
    # current_summary_hidden [3, 512]
    # cont [3, 256]
    # salience [3, 256]
    # novelty [3, 256]
    return torch.sigmoid(self.classifier(cont+salience-novelty))

In [11]:
'''
Converts the predicted output into a list of sentences (strings)
Used to compute the Rouge metric
'''
def convert_pred(output, sentences, vocab):
  pred = []
  for idx_sample, sample in enumerate(output):
    s = ''
    sent = sentences[idx_sample].split("|")
    for idx_sentence, sentence in enumerate(sample):
      if(idx_sentence >= len(sent)):
        break
      if(sentence == 1):
        s += sent[idx_sentence]
    if(len(s) == 0):
      s = '[PAD]'
    pred.append(s)
  return pred

'''
Decode one article (for debugging purposes)
'''
def print_art(art, vocab):
  article = ""
  for id in art:
    article += vocab.get_word(id.item()) + " "
  return article

'''
Decode one output as sentence (for debugging purposes)
'''
def print_pred(sample, vocab):
  s = ""
  for word in sample:
    if(vocab.get_word(word.item()) == STOP_DEC):
      break
    s += vocab.get_word(word.item()) + " "
  return s

'''
Compute the Rouge score given the outputs and abstracts of a batch
'''
def compute_accuracy(outputs, abstracts, sentences, vocab):
  acc = Rouge().get_scores(tuple(convert_pred(outputs, sentences, vocab)), tuple(abstracts), avg=True)
  return torch.tensor([acc['rouge-1']['r'], 
         acc['rouge-1']['p'], 
         acc['rouge-1']['f'],
         acc['rouge-2']['r'],
         acc['rouge-2']['p'],
         acc['rouge-2']['f'],
         acc['rouge-l']['r'],
         acc['rouge-l']['p'],
         acc['rouge-l']['f']])

'''
Print the metrics in a readable way
'''
def print_accuracies(acc):
  print(f"----ACC----\nRouge-1: recall {acc[0]}, precision {acc[1]}, f1 {acc[2]}\nRouge-2: recall {acc[3]}, precision {acc[4]}, f1 {acc[5]}\nRouge-l: recall {acc[6]}, precision {acc[7]}, f1 {acc[8]}\n")

In [12]:
'''
Save the encoder and decoder for continuing training
'''
def save_model(encoder, decoder, epoch, enc_opt, dec_opt):
  torch.save({
            'epoch': epoch,
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'enc_optimizer_state_dict': enc_opt.state_dict(),
            'dec_optimizer_state_dict': dec_opt.state_dict()
            }, "/content/model_extractive.ckp")

'''
Load a saved model (the best one in this case)
'''
def load_model(v_length, emb, hidden_size, n_layers, batch_size, vocab, learning_rate=0.01):
  encoder = Encoder(v_length, emb, hidden_size, TOKEN_PER_SENT, n_layers, True, batch_size, vocab, dropout = 0.1).to(device)
  decoder = Decoder(v_length, hidden_size, MAX_SENTENCES, n_layers, vocab, batch_size, dropout = 0.1).to(device)

  enc_opt = torch.optim.SGD(encoder.parameters(), lr=learning_rate)
  dec_opt = torch.optim.SGD(decoder.parameters(), lr=learning_rate)

  checkpoint = torch.load("/content/model_extractive_best.ckp")
  encoder.load_state_dict(checkpoint['encoder_state_dict'])
  enc_opt.load_state_dict(checkpoint['enc_optimizer_state_dict'])

  decoder.load_state_dict(checkpoint['decoder_state_dict'])
  dec_opt.load_state_dict(checkpoint['dec_optimizer_state_dict'])

  epoch = checkpoint['epoch']

  return encoder, enc_opt, decoder, dec_opt, epoch

In [13]:
'''
Download the best model
'''
def download_model(drive_id="1aphX8f02Vl7-kJboOXmUylTt3h3kO6b7", file_name="model_extractive_best.zip"):
  gdd.download_file_from_google_drive(file_id=drive_id,
                                dest_path=f"/content/{file_name}",
                                unzip=True)
  os.remove(f"/content/{file_name}")

download_model()

Downloading 1aphX8f02Vl7-kJboOXmUylTt3h3kO6b7 into /content/model_extractive_best.zip... Done.
Unzipping...Done.


In [14]:
'''
Train the encoder and decoder
This function executes N° epochs to train the model
Input:
  - encoder: the encoder to train
  - decoder: the decoder to train
  - enc_opt: the optimizer of the encoder
  - dec_opt: the optimizer of the decoder
  - loss_fn: loss function 
  - data_loader: the dataloader managing the dataset
  - check_val: a dataloader managing a subset of the validation set (1000 samples), to check every "print_acc" times
  - vocab: the vocabulary object
  - batch_size: the size of each batch
  - hidden_size: the size of the hidden layers (equal for both encoder and decoder)
  - n_sentences: the number of sentences
  - epochs: the number of epochs to perform
  - epoch: the epoch where to start (0 or the last epoch if the model was loaded)
  - device: gpu or cpu
'''
import sklearn.metrics as metrics
def train(encoder, decoder, enc_opt, dec_opt, loss_fn, data_loader, check_val, vocab, batch_size, hidden_size, n_sentences, epochs, epoch, device = "cuda"):
  torch.cuda.empty_cache()

  print_acc = 10 # print the accuracy every 10 epochs
  # Put the models in training mode
  encoder.train()
  decoder.train()

  recall_loss, precision_loss, accuracy_loss = 0, 0, 0

  # Epochs
  for e in range(epoch, epochs):
    print(f"---------EPOCH {e}---------")
    avg_loss = torch.zeros(1) # Average the loss
    c = 0
    recall_avg, precision_avg, accuracy_avg = 0, 0, 0
    acc_avg = torch.zeros(9) # Contains the accuracies (Rouge)
    
    # For each batch (data = (article, summary), summaries = raw summaries in strings, sentences = sentences of each article delimited by "|")
    for data, summaries, sentences in tqdm(data_loader, position=0, leave=True):
      precision, recall, accuracy = 0, 0, 0
      articles, abstracts = data[0].to(device), data[1].to(device)
      sentence_hidden = encoder.init_layers() # init the encoder hidden layer
      sum_sentence_hidden_states = torch.zeros(batch_size, hidden_size*2).to(device) # set the initial summation of the sentence hidden states to 0
      sentence_hidden_states = torch.zeros(n_sentences, batch_size, hidden_size*2).to(device) # set the sentence hidden states to 0
      acc = None
      # For each sentence of the article computes the encoding
      for i in range(n_sentences):
        _, sentence_hidden = encoder(articles[:,i].clone().to(device), sentence_hidden)
        sentence_hidden_states[i,:] = sentence_hidden.view(1,batch_size,-1).squeeze(0) # save the sentence hidden state
        sum_sentence_hidden_states[:] += sentence_hidden.view(1,batch_size,-1).squeeze(0) # sum over the hidden states

      H = sum_sentence_hidden_states[:]/n_sentences # normalize the summation -> H
      summary_state = torch.zeros(batch_size, hidden_size*2).to(device) # init the summary state to 0

      loss = 0

      outputs = torch.zeros(batch_size, n_sentences, dtype = torch.int32).to(device) # init the outputs to 0
      
      # for each sentence of the article
      for step in range(n_sentences):
        decoder_output = decoder(sentence_hidden_states[step,:].clone().to(device), H.to(device), summary_state.to(device))
        outputs[:,step] = decoder_output.argmax(1).detach() # take the highest predicted class
        summary_state += sentence_hidden_states[step,:].clone()*decoder_output[:,1].unsqueeze(1).clone() # update the summary state with the hidden state of the current sentence weighted with the probability of being part of the summary
        loss += loss_fn(decoder_output.to(device), abstracts[:,step].clone(), recall_loss, precision_loss) # compute loss

      enc_opt.zero_grad()
      dec_opt.zero_grad()
      loss.backward() # compute backpropagation
      # clipping
      torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, encoder.parameters()), 2.) 
      torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, decoder.parameters()), 2.) 
      enc_opt.step() # apply backpropagation
      dec_opt.step()    

      # compute precision, recall and accuracy
      for idx in range(batch_size):
        precision += metrics.precision_score(abstracts[idx].cpu(), outputs[idx].cpu(), zero_division=True)
        recall += metrics.recall_score(abstracts[idx].cpu(), outputs[idx].cpu(), zero_division=True)
        accuracy += metrics.accuracy_score(abstracts[idx].cpu(), outputs[idx].cpu())

      # update avgs
      precision_avg += precision/batch_size
      recall_avg += recall/batch_size
      accuracy_avg += accuracy/batch_size
      # compute the accuracy every "print_acc" epochs
      if((e+1)%print_acc == 0):
        acc = compute_accuracy(outputs, summaries, sentences, vocab)
        acc_avg += acc

      avg_loss += loss.item()/n_sentences
      
      c += 1

    recall_loss = recall_avg/c
    precision_loss = precision_avg/c

    if((e+1)%print_acc == 0):
      # print some prediction and ground truth
      # print(outputs[batch_size-1])
      # print(abstracts[batch_size-1])
      # print(outputs[batch_size-2])
      # print(abstracts[batch_size-2])
      # print(outputs[batch_size-10])
      # print(abstracts[batch_size-10])
      # print the accuracy
      print_accuracies(acc_avg/c) 
      # check current performance on a subest of the validation set
      print("ACC val subset:\n")
      test(encoder, decoder, check_val, vocab, batch_size, loss_fn) # compute the accuracy in the subset of the validation set
      encoder.train()
      decoder.train()
    print("Precision", precision_loss)
    print("Recall", recall_loss)
    print("Accuracy", accuracy_avg/c)
    save_model(encoder, decoder, e, enc_opt, dec_opt) # save the current model
    print("\nLoss: ", (avg_loss/c).item()) # print loss

In [15]:
'''
This function aims at testing the model, the dataloader should be the validation or test set
Input:
 - encoder: the trained encoder
 - decoder: the trained decoder
 - loader: the dataloader where to test the model
 - vocab: the vocabulary
 - batch_size: the size of each batch
 - device: cpu or gpu (cuda)
'''
def test(encoder, decoder, loader, vocab, batch_size, loss_fn, device = "cuda"):
  torch.cuda.empty_cache()
  # evaluation mode
  encoder.eval()
  decoder.eval()

  with torch.no_grad():
    c = 0
    acc_avg = torch.zeros(9)
    recall_avg, precision_avg, accuracy_avg = 0, 0, 0
    for data, summaries, sentences in tqdm(loader, position=0, leave=True): # for each batch
      precision, recall, accuracy = 0, 0, 0
      articles, abstracts = data[0].to(device), data[1].to(device)

      sentence_hidden = encoder.init_layers()

      n_sentences = articles.size(1)
      target_length = abstracts.size(1)

      sum_sentence_hidden_states = torch.zeros(batch_size, encoder.hidden_size*2).to(device)
      sentence_hidden_states = torch.zeros(n_sentences, batch_size, encoder.hidden_size*2).to(device)

      loss = 0

      for i in range(n_sentences):
        _, sentence_hidden = encoder(articles[:,i].clone().to(device), sentence_hidden)
        sentence_hidden_states[i,:] = sentence_hidden.view(1,batch_size,-1).squeeze(0)
        sum_sentence_hidden_states[:] += sentence_hidden.view(1,batch_size,-1).squeeze(0)

      d = sum_sentence_hidden_states[:]/n_sentences
      summary_state = torch.zeros(batch_size, encoder.hidden_size*2).to(device)

      outputs = torch.zeros(batch_size, n_sentences, dtype = torch.int32).to(device)

      for step in range(target_length):
        decoder_output = decoder(sentence_hidden_states[step,:].clone().to(device), d.to(device), summary_state.to(device))
        # outputs[:,step] = torch.where(decoder_output > 0.5, 1, 0).squeeze(1)
        outputs[:,step] = decoder_output.argmax(1).detach()
        summary_state += sentence_hidden_states[step,:].clone()*decoder_output[:,1].unsqueeze(1).clone() 
        loss += loss_fn(decoder_output.to(device), abstracts[:,step].clone(), recall/batch_size, precision/batch_size)
      
      for idx in range(batch_size):
        precision_avg += metrics.precision_score(abstracts[idx].cpu(), outputs[idx].cpu(), zero_division=True)
        recall_avg += metrics.recall_score(abstracts[idx].cpu(), outputs[idx].cpu(), zero_division=True)
        accuracy_avg += metrics.accuracy_score(abstracts[idx].cpu(), outputs[idx].cpu())

      precision_avg += precision/batch_size
      recall_avg += recall/batch_size
      accuracy_avg += accuracy/batch_size

      acc_avg += compute_accuracy(outputs, summaries, sentences, vocab)
      c+=1

    print_accuracies(acc_avg/c) 

In [16]:
# Compute the baselines using a random approach
# select K random sentences for each article (K = 4)
def random_baseline(loader, vocab, batch_size, n_sentences, K = 4):
  tot_avg = 0
  for i in range(10):
    acc_avg = 0
    c = 0
    for data, summaries, sentences in tqdm(loader, position=0, leave=True):
      outputs = torch.zeros(batch_size, n_sentences)
      rand = torch.rand(batch_size, n_sentences).topk(K, dim=1)[1]
      for idx, sample in enumerate(rand):
        for sentence in sample:
          outputs[idx, sentence] = 1
      acc = compute_accuracy(outputs, summaries, sentences, vocab)
      acc_avg += acc
      c+=1
    print_accuracies(acc_avg/c)
    tot_avg+=acc_avg/c
  print("Rouge score averaged over 10 runs:")
  print_accuracies(tot_avg/10)

In [17]:
# Init the vocabulary
vocab = Vocab(max_size = 50000)

# Load the datasets
t = CNN_dailymail("train", vocab, max_size=5000)
v = CNN_dailymail("val", vocab)
check_val = CNN_dailymail("val", vocab, max_size=1000) # to check validation performances during training
te = CNN_dailymail("test", vocab)

  0%|          | 0/58 [00:05<?, ?it/s]
100%|██████████| 3/3 [00:11<00:00,  3.93s/it]
  0%|          | 0/3 [00:01<?, ?it/s]
100%|██████████| 3/3 [00:09<00:00,  3.31s/it]


In [18]:
# init the device to use
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 500 # need to have 15 GB of gpu memory available

# init the dataloaders
train_ds = DataLoader(t, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_ds = DataLoader(v, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
check_val = DataLoader(check_val, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_ds = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

In [19]:
# loss function
def loss(outputs, targets, recall, precision):
  weight_class_0 = 0.1
  weight_class_1 = 1.0
  return nn.CrossEntropyLoss(weight=torch.tensor([weight_class_0, weight_class_1])).to(device)(outputs, targets) + (1-(recall*precision))**2
  # return nn.CrossEntropyLoss(weight=torch.tensor([weight_class_0, weight_class_1])).to(device)(outputs, targets)

In [20]:
v_length = vocab.get_size()
emb, hidden_size = 64, 128

load = True # load the best model
lr = 0.001 # learning rate
epoch = 0 # starting epoch
force_training = False # start from epoch 0 (if already trained model)
n_layers = 1 # number of layers for the GRU units

if(not load):
  # init the models
  encoder = Encoder(v_length, emb, hidden_size, TOKEN_PER_SENT, n_layers, True, batch_size, vocab, dropout = 0.1).to(device)
  decoder = Decoder(v_length, hidden_size, MAX_SENTENCES, 1, vocab, batch_size, dropout = 0.1).to(device)
  opt_enc = torch.optim.Adam(encoder.parameters(), lr=lr)
  opt_dec = torch.optim.Adam(decoder.parameters(), lr=lr)
  # init optimizers
  # opt_enc = torch.optim.SGD(encoder.parameters(), lr=lr)
  # opt_dec = torch.optim.SGD(decoder.parameters(), lr=lr)
else:
  # load the last saved models
  encoder, opt_enc, decoder, opt_dec, epoch = load_model(vocab.get_size(), emb, hidden_size, n_layers, batch_size, vocab, learning_rate=lr)

loss_fn = loss

if(force_training):
  epoch = 0

print("\nN° Parameters: ", (sum(p.numel() for p in encoder.parameters() if p.requires_grad) + sum(p.numel() for p in decoder.parameters() if p.requires_grad)))

  "num_layers={}".format(dropout, num_layers))



N° Parameters:  19129090


In [21]:
# To train a new model uncomment the below line (make sure the load variable is set to False)
# train(encoder, decoder, opt_enc, opt_dec, loss_fn, train_ds, check_val, vocab, batch_size, hidden_size, MAX_SENTENCES, 90, epoch, device)

In [22]:
# Eventually test the final model

test(encoder, decoder, val_ds, vocab, batch_size, loss)

test(encoder, decoder, test_ds, vocab, batch_size, loss)

100%|██████████| 26/26 [03:14<00:00,  7.49s/it]


----ACC----
Rouge-1: recall 0.5485537648200989, precision 0.47357696294784546, f1 0.4781179130077362
Rouge-2: recall 0.40462473034858704, precision 0.3306196630001068, f1 0.3338182270526886
Rouge-l: recall 0.5215343832969666, precision 0.44730716943740845, f1 0.45264753699302673



100%|██████████| 22/22 [02:46<00:00,  7.57s/it]

----ACC----
Rouge-1: recall 0.5492520332336426, precision 0.46647870540618896, f1 0.47440892457962036
Rouge-2: recall 0.4034233093261719, precision 0.3237416446208954, f1 0.32917407155036926
Rouge-l: recall 0.5217207074165344, precision 0.4402707815170288, f1 0.44875916838645935






In [None]:
# baselines
random_baseline(val_ds, vocab, batch_size, MAX_SENTENCES)
random_baseline(test_ds, vocab, batch_size, MAX_SENTENCES)