In [None]:
!pip install datasets

In [None]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from nltk import tokenize
import nltk
from datasets import load_dataset, Dataset
nltk.download('punkt')
from transformers import AutoTokenizer


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# 1. Specify the model

In [None]:
class HierarchicalGRU(nn.Module):
    """Two-level bidirectional GRU encoder: tokens -> utterance vectors -> contextual utterance states.

    The word-level GRU runs over the tokens of every utterance independently; its output at
    each utterance's last real token (index ``w_len - 1``) becomes that utterance's vector,
    and the sequence of utterance vectors is fed to the utterance-level GRU.

    Args:
        vocab_size: number of rows of the token embedding table.
        embedding_dim: token embedding size.
        rnn_hidden_size: output size of both GRUs. Each direction uses half of it, so it must be even.
        num_layers: number of stacked layers in each GRU.
        dropout: dropout between stacked GRU layers.
        device: device on which intermediate tensors are allocated.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_hidden_size, num_layers, dropout, device):
        super(HierarchicalGRU, self).__init__()
        if rnn_hidden_size % 2 != 0:
            # Each direction gets rnn_hidden_size // 2 units; an odd size would make the
            # bidirectional output narrower than rnn_hidden_size and break the reshapes in forward().
            raise ValueError(f"rnn_hidden_size must be even, got {rnn_hidden_size}")
        self.device          = device
        self.vocab_size      = vocab_size
        self.embedding_dim   = embedding_dim
        self.rnn_hidden_size = rnn_hidden_size

        # Token embeddings; id 0 is the padding id written by the batchers in this notebook.
        self.embedding = nn.Embedding(self.vocab_size, embedding_dim=self.embedding_dim, padding_idx=0)

        # Word-level GRU: word embeddings -> per-token states (half the size per direction
        # because it is bi-directional).
        self.gru_wlevel = nn.GRU(input_size=self.embedding_dim, hidden_size=self.rnn_hidden_size // 2,
                                 num_layers=num_layers, bias=True, batch_first=True,
                                 dropout=dropout, bidirectional=True)

        # Utterance-level GRU: utterance vectors -> contextual utterance states.
        self.gru_ulevel = nn.GRU(input_size=self.rnn_hidden_size, hidden_size=self.rnn_hidden_size // 2,
                                 num_layers=num_layers, bias=True, batch_first=True,
                                 dropout=dropout, bidirectional=True)

    def forward(self, input, u_len, w_len):
        """Encode a batch of documents.

        Args:
            input: token ids, shape [batch_size, num_utterances, num_words].
            u_len: number of real utterances per document (length batch_size).
            w_len: number of real tokens per utterance, shape [batch_size, num_utterances].

        Returns:
            dict with
              - 'u_output': utterance-level states, [batch_size, num_utterances, rnn_hidden_size]
              - 'u_len': ``u_len`` passed through unchanged
              - 'w_output': word states of each document's real utterances concatenated into one
                sequence and zero-padded, [batch_size, max(total words), rnn_hidden_size]
              - 'w_len': list with each document's total real-token count
              - 'utt_indices': per document, the position in 'w_output' of each utterance's last token
        """
        batch_size, num_utterances, num_words = input.size()

        # [B, U, W] -> [B, U, W, E] -> [B*U, W, E]
        embed = self.embedding(input).view(batch_size * num_utterances, num_words, self.embedding_dim)

        # word-level GRU
        self.gru_wlevel.flatten_parameters()
        w_output, _ = self.gru_wlevel(embed)                       # [B*U, W, H]
        flat_w_len = w_len.reshape(-1).to(w_output.device)

        # Gather every utterance's state at its last real token in one indexing op.
        # The modulo reproduces Python's wrap-around for empty (w_len == 0) padding
        # utterances, which read position -1, i.e. the last (padding) position.
        rows = torch.arange(w_output.size(0), device=w_output.device)
        last_pos = (flat_w_len - 1) % num_words
        utt_input = w_output[rows, last_pos].reshape(batch_size, num_utterances, self.rnn_hidden_size)

        # utterance-level GRU
        self.gru_ulevel.flatten_parameters()
        utt_output, _ = self.gru_ulevel(utt_input)                 # [B, U, H]

        # Concatenate the word states of each document's real utterances into one sequence
        # so the decoder can attend over all words of a document at once.
        w_output = w_output.view(batch_size, num_utterances, num_words, -1)
        w_len_2d = flat_w_len.view(batch_size, -1)
        w2_len = [w_len_2d[bn, :n_utt].sum().item() for bn, n_utt in enumerate(u_len)]

        w2_output = torch.zeros((batch_size, max(w2_len), w_output.size(-1)), device=self.device)
        utt_indices = [[] for _ in range(batch_size)]
        for bn, n_utt in enumerate(u_len):
            offset = 0
            for j, n_words in enumerate(w_len_2d[bn, :n_utt]):
                n_words = n_words.item()
                w2_output[bn, offset:offset + n_words, :] = w_output[bn, j, :n_words, :]
                offset += n_words
                utt_indices[bn].append(offset - 1)  # position of this utterance's last token

        encoder_output_dict = {
            'u_output': utt_output, 'u_len': u_len,
            'w_output': w2_output, 'w_len': w2_len, 'utt_indices': utt_indices
        }
        return encoder_output_dict

In [None]:
class DecoderGRU(nn.Module):
    """A conditional GRU decoder with hierarchical (utterance -> word) attention.

    The decoder state scores the encoder's utterance states ('u_output') and word states
    ('w_output'). Word scores are normalised within each utterance and weighted by that
    utterance's attention, giving one distribution over all words of the document; the
    resulting context vector is concatenated with the decoder state to predict the next token.

    Args:
        vocab_size: output vocabulary size.
        embedding_dim: target-token embedding size.
        dec_hidden_size: GRU hidden size.
        mem_hidden_size: size of the encoder memory vectors. ``forward`` copies utterance
            states into the initial hidden state, so it must equal ``dec_hidden_size`` there.
        num_layers: number of GRU layers.
        dropout: dropout between GRU layers (``dropout_layer`` is created but not applied here).
        device: device on which intermediate tensors are allocated.
    """

    def __init__(self, vocab_size, embedding_dim, dec_hidden_size, mem_hidden_size,
                num_layers, dropout, device):
        super(DecoderGRU, self).__init__()
        self.device      = device
        self.vocab_size  = vocab_size
        self.dec_hidden_size = dec_hidden_size
        self.mem_hidden_size = mem_hidden_size
        self.num_layers  = num_layers
        self.dropout     = dropout

        self.embedding = nn.Embedding(vocab_size, embedding_dim=embedding_dim, padding_idx=0)

        self.rnn = nn.GRU(embedding_dim, dec_hidden_size, num_layers, batch_first=True, dropout=dropout)

        self.dropout_layer = nn.Dropout(p=dropout)

        self.attention_u = nn.Linear(mem_hidden_size, dec_hidden_size)
        self.attention_w = nn.Linear(mem_hidden_size, dec_hidden_size)

        self.output_layer = nn.Linear(dec_hidden_size+mem_hidden_size, vocab_size, bias=True)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def _masked_scores(self, query, memory, projection, lengths):
        """Dot-product scores [B, T, M] of ``query`` against projected ``memory``; positions >= length are -inf."""
        scores = torch.bmm(query, projection(memory).permute(0, 2, 1))
        for bn, n in enumerate(lengths):
            scores[bn, :, n:].fill_(float('-inf'))
        return scores

    @staticmethod
    def _utterance_spans(end_indices):
        """Yield (utterance index, start, stop) word spans given each utterance's last-word position."""
        start = 0
        for i, end in enumerate(end_indices):
            yield i, start, end + 1
            start = end + 1

    def forward(self, target, encoder_output_dict, logsoftmax=True):
        """Teacher-forced decoding over a whole target sequence.

        Args:
            target: input token ids, [batch_size, T].
            encoder_output_dict: dict produced by the encoder ('u_output', 'u_len',
                'w_output', 'w_len', 'utt_indices').
            logsoftmax: apply log-softmax to the output logits.

        Returns:
            (output [B, T, vocab], word attention [B, T, W], utterance attention [B, T, U]);
            both attentions are probabilities.
        """
        u_output = encoder_output_dict['u_output']
        u_len    = encoder_output_dict['u_len']
        w_output = encoder_output_dict['w_output']
        w_len    = encoder_output_dict['w_len']
        utt_indices = encoder_output_dict['utt_indices']

        batch_size = target.size(0)
        embed = self.embedding(target)

        # Initial hidden state: each document's last real utterance state, copied to every layer.
        initial_h = torch.zeros((self.num_layers, batch_size, self.dec_hidden_size),
                                dtype=torch.float, device=self.device)
        for bn, n in enumerate(u_len):
            initial_h[:, bn, :] = u_output[bn, n - 1, :].unsqueeze(0)

        self.rnn.flatten_parameters()
        rnn_output, _ = self.rnn(embed, initial_h)

        # Utterance-level attention, kept in log space.
        scores_u = F.log_softmax(self._masked_scores(rnn_output, u_output, self.attention_u, u_len), dim=-1)
        # Word-level raw scores; normalised per utterance below.
        scores_w = self._masked_scores(rnn_output, w_output, self.attention_w, w_len)

        # log p(word) = log p(utterance) + log p(word | utterance). Positions outside every
        # utterance stay at -inf so they get zero weight after exp().
        scores_uw = torch.full(scores_w.shape, float('-inf'), device=self.device)
        for bn in range(batch_size):
            for i, start, stop in self._utterance_spans(utt_indices[bn]):
                scores_uw[bn, :, start:stop] = (scores_u[bn, :, i].unsqueeze(-1)
                                                + F.log_softmax(scores_w[bn, :, start:stop], dim=-1))

        scores_uw = torch.exp(scores_uw)
        context_vec = torch.bmm(scores_uw, w_output)

        dec_output = self.output_layer(torch.cat((context_vec, rnn_output), dim=-1))
        if logsoftmax:
            dec_output = self.logsoftmax(dec_output)

        # For multi-GPU training scores_uw cannot be returned (its size differs per device).
        return dec_output, scores_uw, torch.exp(scores_u)

    def forward_step(self, xt, ht, encoder_output_dict, d_prev=None, eu_prev=None, logsoftmax=True):
        """Run a single decoding step (used by beam search).

        Args:
            xt: token ids for this step, [batch_size, 1].
            ht: GRU hidden state, [num_layers, batch_size, dec_hidden_size].
            encoder_output_dict: dict produced by the encoder.
            d_prev, eu_prev: unused; kept for call compatibility.
            logsoftmax: if True the first returned value is log-softmaxed, otherwise raw logits.

        Returns:
            (scores for the last step [B, vocab], new hidden state, word attention [B, 1, W],
             utterance attention [B, 1, U], raw logits for the last step [B, vocab]).
        """
        u_output = encoder_output_dict['u_output']
        u_len    = encoder_output_dict['u_len']
        w_output = encoder_output_dict['w_output']
        w_len    = encoder_output_dict['w_len']
        utt_indices = encoder_output_dict['utt_indices']

        batch_size = xt.size(0)

        xt = self.embedding(xt)                 # [batch_size, 1, embedding_dim]
        rnn_output, ht1 = self.rnn(xt, ht)      # ht, ht1: [num_layers, batch_size, hidden_size]

        # Utterance-level attention, in probability space.
        scores_u = F.softmax(self._masked_scores(rnn_output, u_output, self.attention_u, u_len), dim=-1)
        # Word-level raw scores; normalised per utterance below.
        scores_w = self._masked_scores(rnn_output, w_output, self.attention_w, w_len)

        # p(word) = p(utterance) * p(word | utterance). Positions outside every utterance must
        # be 0, not -inf: they multiply the zero-padded memory in the bmm below and
        # -inf * 0 would turn the context vector into NaN.
        scores_uw = torch.zeros(scores_w.shape, device=self.device)
        for bn in range(batch_size):
            for i, start, stop in self._utterance_spans(utt_indices[bn]):
                scores_uw[bn, :, start:stop] = (scores_u[bn, :, i].unsqueeze(-1)
                                                * F.softmax(scores_w[bn, :, start:stop], dim=-1))

        context_vec = torch.bmm(scores_uw, w_output)
        dec_output = self.output_layer(torch.cat((context_vec, rnn_output), dim=-1))
        last_logits = dec_output[:, -1, :]

        if logsoftmax:
            return self.logsoftmax(dec_output)[:, -1, :], ht1, scores_uw, scores_u, last_logits
        return last_logits, ht1, scores_uw, scores_u, last_logits

In [None]:
class EXTLabeller(nn.Module):
    """Per-utterance extractive labeller: dropout -> linear -> sigmoid probability.

    Args:
        dropout: dropout probability applied to the utterance states.
        rnn_hidden_size: size of each utterance state (last dimension of the input).
        device: device the module is moved to after initialisation.
    """

    def __init__(self, dropout, rnn_hidden_size, device):
        super(EXTLabeller, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(rnn_hidden_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

        # Xavier-normal for the weight matrix, zeros for the (1-D) bias.
        nn.init.xavier_normal_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.to(device)

    def forward(self, utt_output):
        """Map states [..., rnn_hidden_size] to probabilities [...] in (0, 1)."""
        logits = self.linear(self.dropout(utt_output))
        return self.sigmoid(logits).squeeze(-1)

In [None]:
class EncoderDecoder(nn.Module):
    """Hierarchical GRU encoder + attentive GRU decoder, with a beam-search decoding routine.

    Args:
        args: dict with 'vocab_size', 'embedding_dim', 'rnn_hidden_size',
            'num_layers_enc', 'num_layers_dec' and 'dropout'.
        device: device the model is moved to.
    """

    def __init__(self, args, device):
        super(EncoderDecoder, self).__init__()
        self.device = device

        # Encoder - Hierarchical GRU
        self.encoder = HierarchicalGRU(args['vocab_size'], args['embedding_dim'], args['rnn_hidden_size'],
                                       num_layers=args['num_layers_enc'], dropout=args['dropout'], device=device)

        # Decoder - GRU with attention; its memory size equals the encoder output size
        self.decoder = DecoderGRU(args['vocab_size'], args['embedding_dim'], args['rnn_hidden_size'], args['rnn_hidden_size'],
                                       num_layers=args['num_layers_dec'], dropout=args['dropout'], device=device)

        self.param_init()

        self.to(device)

    def param_init(self):
        """Xavier-normal init for every parameter with more than one dimension; zeros for 1-D biases.

        Other 1-D parameters (e.g. LayerNorm weights) are left untouched. The embedding
        tables are 2-D, so they (including their padding_idx rows) are re-initialised too.
        """
        for module in (self.encoder, self.decoder):
            for name, p in module.named_parameters():
                if p.dim() > 1:
                    nn.init.xavier_normal_(p)
                elif 'bias' in name:
                    nn.init.zeros_(p)

    def forward(self, input, u_len, w_len, target):
        """Teacher-forced pass.

        Returns (decoder log-probabilities [B, T, vocab], encoder utterance states,
        word-level attention [B, T, W], utterance-level attention [B, T, U]).
        """
        enc_output_dict = self.encoder(input, u_len, w_len)
        dec_output, attn_scores, u_attn_scores = self.decoder(target, enc_output_dict)
        return dec_output, enc_output_dict['u_output'], attn_scores, u_attn_scores

    def decode_beamsearch(self, input, u_len, w_len, decode_dict):
        """
        Beam-search decoding of a single document (inference only).

        This method does not switch the model to eval mode or disable autograd; callers
        should do both.

            input = input to the encoder
            u_len = utterance lengths
            w_len = word lengths
            decode_dict:
                - k                = beamwidth for beamsearch
                - batch_size       = batch_size (must be 1)
                - time_step        = max_summary_length
                - vocab_size       = size of the output vocabulary
                - device           = cpu or cuda
                - start_token_id   = ID of the start token
                - stop_token_id    = ID of the stop token
                - alpha            = length-normalisation exponent
                - penalty_ug       = weight of the repeated-unigram penalty (0 disables it)
                - length_offset    = not read by this method
                - keypadmask_dtype = not read by this method

        Returns:
            (summary_ids, attn_score, attn_score_u): token ids of the best hypothesis as a numpy
            array (starting with the start token) and its word- and utterance-level attention
            matrices (one row per decoding step, on the CPU).
        """
        k              = decode_dict['k']
        batch_size     = decode_dict['batch_size']
        time_step      = decode_dict['time_step']
        vocab_size     = decode_dict['vocab_size']
        device         = decode_dict['device']
        start_token_id = decode_dict['start_token_id']
        stop_token_id  = decode_dict['stop_token_id']
        alpha          = decode_dict['alpha']
        penalty_ug     = decode_dict['penalty_ug']

        if batch_size != 1: raise ValueError("batch size must be 1")

        beam_scores = np.zeros((k,))

        # Run the encoder once; every beam attends to the same memory.
        enc_output_dict = self.encoder(input, u_len, w_len)
        u_output = enc_output_dict['u_output']
        w_output = enc_output_dict['w_output']
        enc_time_step   = w_output.size(1)
        enc_time_step_u = u_output.size(1)

        # Every beam starts as [start_token, 0, 0, ...]; later positions are filled while decoding.
        tgt_ids = torch.zeros((time_step,), dtype=torch.int64).to(device)
        tgt_ids[0] = start_token_id
        beams = [tgt_ids for _ in range(k)]

        finished_beams = []
        finished_beams_scores = []
        finished_attn = []
        finished_attn_u = []

        # Initial decoder state = state of the last real utterance, copied to every layer.
        ht = torch.zeros((self.decoder.num_layers, 1, self.decoder.dec_hidden_size), dtype=torch.float).to(self.device)
        last_utt = u_len[0]
        ht[:, 0, :] = u_output[0, last_utt - 1, :].unsqueeze(0)
        beam_ht = [ht.clone() for _ in range(k)]

        attn_scores_array = [torch.zeros((time_step, enc_time_step)) for _ in range(k)]
        attn_scores_u_array = [torch.zeros((time_step, enc_time_step_u)) for _ in range(k)]

        for t in range(time_step - 1):
            # Candidate scores for every (beam, next token) pair, flattened beam-major.
            decoder_output_t_array = torch.zeros((k * vocab_size,))

            for i, beam in enumerate(beams):
                decoder_output, beam_ht[i], attn_scores, attn_scores_u, _ = self.decoder.forward_step(
                    beam[t:t+1].unsqueeze(0), beam_ht[i], enc_output_dict, logsoftmax=True)

                attn_scores_array[i][t, :] = attn_scores[0, 0, :]
                attn_scores_u_array[i][t, :] = attn_scores_u[0, 0, :]

                # A beam whose input at this step is the stop token has already finished:
                # force it to emit the stop token again at no extra cost.
                if beam[t] == stop_token_id:
                    decoder_output[0, :] = float('-inf')
                    decoder_output[0, stop_token_id] = 0.0

                block = slice(i * vocab_size, (i + 1) * vocab_size)
                decoder_output_t_array[block] = decoder_output[0]
                # add the beam's accumulated log-probability
                decoder_output_t_array[block] += beam_scores[i]

                if penalty_ug > 0.0:
                    # Penalise tokens already in this beam, proportionally to their frequency.
                    unigram_counts = {}
                    for tt in range(t + 1):
                        v = beam[tt].cpu().numpy().item()
                        unigram_counts[v] = unigram_counts.get(v, 0) + 1
                    for vocab_id, vocab_count in unigram_counts.items():
                        decoder_output_t_array[(i * vocab_size) + vocab_id] -= penalty_ug * vocab_count / (t + 1)

                # At t == 0 all beams are identical: expand only the first one.
                if t == 0:
                    decoder_output_t_array[(i + 1) * vocab_size:] = float('-inf')
                    break

            topk_scores, topk_ids = torch.topk(decoder_output_t_array, k, dim=-1)
            scores = topk_scores.double().cpu().numpy()
            indices = topk_ids.cpu().numpy()

            new_beams = [torch.zeros((time_step,), dtype=torch.int64).to(device) for _ in range(k)]
            new_attn_scores_array = [torch.zeros((time_step, enc_time_step)) for _ in range(k)]
            new_attn_scores_u_array = [torch.zeros((time_step, enc_time_step_u)) for _ in range(k)]
            new_beam_ht = [None for _ in range(k)]

            for c_idx, node in enumerate(indices):
                beam_idx, vocab_idx = divmod(int(node), vocab_size)

                new_beams[c_idx][:t+1] = beams[beam_idx][:t+1]
                new_beams[c_idx][t+1] = vocab_idx
                new_beam_ht[c_idx] = beam_ht[beam_idx]
                new_attn_scores_array[c_idx][:t+1, :] = attn_scores_array[beam_idx][:t+1, :]
                new_attn_scores_u_array[c_idx][:t+1, :] = attn_scores_u_array[beam_idx][:t+1, :]

                # A hypothesis that just emitted the stop token is stored and its beam retired.
                if vocab_idx == stop_token_id:
                    finished_beams.append(new_beams[c_idx][:t+2])
                    # Length-normalise by the t tokens before the stop token. A hypothesis that
                    # stops at t == 0 is empty: give it -inf explicitly instead of dividing by 0**alpha.
                    finished_beams_scores.append(scores[c_idx] / t**alpha if t > 0 else float('-inf'))
                    finished_attn.append(new_attn_scores_array[c_idx][:t+1, :])
                    finished_attn_u.append(new_attn_scores_u_array[c_idx][:t+1, :])
                    scores[c_idx] = float('-inf')

            beams = new_beams
            beam_ht = new_beam_ht
            attn_scores_array = new_attn_scores_array
            attn_scores_u_array = new_attn_scores_u_array
            beam_scores = scores

            # Once every beam is retired (score -inf), later steps can only add -inf
            # hypotheses, which never win the final selection below: stop early.
            if np.isneginf(beam_scores).all():
                break

        if finished_beams_scores:
            max_id = finished_beams_scores.index(max(finished_beams_scores))
            summary_ids = finished_beams[max_id].cpu().numpy()
            attn_score = finished_attn[max_id]
            attn_score_u = finished_attn_u[max_id]
        else:
            summary_ids = beams[0].cpu().numpy()
            attn_score = attn_scores_array[0]
            attn_score_u = attn_scores_u_array[0]

        return summary_ids, attn_score, attn_score_u

# 2. Training

In [None]:
# Model / data hyper-parameters.
config = dict(
    vocab_size=50265,        # rows of the embedding tables and of the output layer
    embedding_dim=256,
    rnn_hidden_size=512,     # bidirectional GRU output size (half per direction)
    num_layers_enc=2,
    num_layers_dec=1,
    dropout=0.1,
    gamma=0.2,               # weight of the extractive loss in the multitask objective
    summary_length=512,      # max decoder-target length in tokens
    num_words=50,            # max tokens kept per sentence
    num_utterances=1000,     # max sentences kept per document
)

# Training-loop settings; the loss is printed every `display_step` steps.
num_epochs, batch_size, display_step = 1, 8, 5

## Import the datasets

In [None]:
# Training articles come from the "train" split. Without a split argument, load_dataset
# returns a DatasetDict for the validation repository.
train_dataset = load_dataset("LA1512/train_pubmed_ORC_4096_20k")["train"]
val_dataset = load_dataset("LA1512/val_pubmed_ORC_4096_1592")

# The batcher/training/evaluation cells read `data_small` and `data_val_small`, which no
# cell defined (NameError on Restart & Run All). Build them here as optional subsets.
# 48 rows = 6 batches of 8, consistent with the 6-step run in the saved outputs;
# set a size to None to use the whole split.
TRAIN_SUBSET_SIZE = 48
VAL_SUBSET_SIZE = 48


def take_subset(split, size):
    """Return the first ``size`` rows of a ``datasets.Dataset`` (the whole split when size is None)."""
    if size is None:
        return split
    return split.select(range(min(size, len(split))))


# Use the first split of the validation DatasetDict (presumably its only one — TODO confirm).
val_split = val_dataset[next(iter(val_dataset))]
data_small = take_subset(train_dataset, TRAIN_SUBSET_SIZE)
data_val_small = take_subset(val_split, VAL_SUBSET_SIZE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
def length2mask(length, batch_size, max_len, torch_device):
    """Build a [batch_size, max_len] float mask with ones in the first ``length[b]`` positions of row b.

    ``length`` must support ``length[b].item()`` (e.g. a 1-D integer tensor).
    """
    mask = torch.zeros((batch_size, max_len), dtype=torch.float, device=torch_device)
    for row in range(batch_size):
        mask[row, :length[row].item()] = 1.0
    return mask

def shift_decoder_target(target, tgt_len, torch_device, mask_offset=False):
    """Shift targets left by one for teacher forcing and build the matching loss mask.

    Returns ``(decoder_target, decoder_mask)``. ``decoder_target[:, i] = target[:, i+1]``,
    with the last column set to 0. ``decoder_mask`` row b has ones on its first
    ``tgt_len[b] - 1`` positions, or ``tgt_len[b] - 1 + 10`` positions (clipped to the
    row) when ``mask_offset`` is set.
    """
    n_rows, n_cols = target.size(0), target.size(1)

    decoder_target = torch.zeros((n_rows, n_cols), dtype=target.dtype, device=torch_device)
    decoder_target[:, :-1] = target[:, 1:]

    # With mask_offset, 10 extra positions after the end of the target are also scored.
    extra = 10 if mask_offset else 0
    decoder_mask = torch.zeros((n_rows, n_cols), dtype=torch.float, device=torch_device)
    for row, n in enumerate(tgt_len):
        decoder_mask[row, :n - 1 + extra] = 1.0

    return decoder_target, decoder_mask

In [None]:
class HierArticleBatcher():
    """Builds padded hierarchical batches (document -> sentences -> tokens) from an indexable dataset.

    Each row ``document[i]`` must provide "article" (text), "ext_target" (per-sentence labels,
    1 = positive) and "abstract" (joined with spaces before sentence splitting — presumably a
    list of sentences; TODO confirm the schema).

    Args:
        tokenizer: object with ``encode(text, add_special_tokens=..., max_length=..., truncation=...)``.
        config: dict with 'num_utterances', 'num_words' and 'summary_length'.
        document: indexable dataset of rows.
        torch_device: device the returned tensors are moved to.
        start_token_id, sep_token_id, target_pad_id: special ids used to build the decoder
            target. The defaults reproduce the previous hard-coded values (0, 1, 50264).
            NOTE(review): the inference code in this notebook stops decoding on token id 2,
            which these targets never contain with sep_token_id=1 — confirm the intended id.
    """

    def __init__(self, tokenizer, config, document, torch_device,
                 start_token_id=0, sep_token_id=1, target_pad_id=50264):
        self.num_utterances = config['num_utterances']
        self.num_words      = config['num_words']
        self.summary_length = config['summary_length']
        self.max_sum_len    = config['summary_length']
        self.tokenizer = tokenizer
        self.document = document
        self.torch_device = torch_device
        self.device = torch_device
        self.start_token_id = start_token_id
        self.sep_token_id = sep_token_id
        self.target_pad_id = target_pad_id

    def _encode(self, text):
        """Tokenize ``text`` without special tokens.

        ``truncation=True`` makes the (very generous) max_length explicit; without it the
        tokenizer warns and falls back to the same 'longest_first' truncation.
        """
        return self.tokenizer.encode(text, add_special_tokens=False, max_length=50000, truncation=True)

    # Override
    def get_a_batch(self, batch_size, batch_index):
        """
        return input, u_len, w_len, target, tgt_len, ext_label

        Rows ``batch_index*batch_size`` .. ``batch_index*batch_size + batch_size - 1`` are used.
        input / w_len / ext_target are trimmed to the longest document and sentence in the
        batch; target keeps the full summary_length. Everything is a tensor on ``self.device``
        except ``tgt_len``, which stays a numpy array.
        """
        input = np.zeros((batch_size, self.num_utterances, self.num_words), dtype=np.int64)
        u_len = np.zeros((batch_size), dtype=np.int64)
        w_len = np.zeros((batch_size, self.num_utterances), dtype=np.int64)
        ext_target = np.zeros((batch_size, self.num_utterances), dtype=np.float32)

        target = np.full((batch_size, self.summary_length), self.target_pad_id, dtype=np.int64)
        tgt_len = np.zeros((batch_size), dtype=np.int64)

        for batch_count in range(batch_size):
            # Fetch the row once instead of re-indexing the dataset for every field.
            example = self.document[batch_count + batch_index * batch_size]

            # ENCODER: at most num_utterances sentences of at most num_words tokens each.
            sentences = tokenize.sent_tokenize(example["article"])[:self.num_utterances]
            num_sentences = len(sentences)
            u_len[batch_count] = num_sentences
            for j, sent in enumerate(sentences):
                token_ids = self._encode(sent.lower())[:self.num_words]
                input[batch_count, j, :len(token_ids)] = token_ids
                w_len[batch_count, j] = len(token_ids)

            # Extractive label: 1.0 for kept sentences labelled 1. Sentences beyond the end of
            # ext_target are treated as 0 instead of raising IndexError.
            labels = example["ext_target"]
            positive_positions = [i for i in range(min(num_sentences, len(labels))) if labels[i] == 1]
            ext_target[batch_count, positive_positions] = 1.0

            # DECODER: start token, then every abstract sentence followed by the separator id,
            # truncated to summary_length.
            description = " ".join(example["abstract"]).lower()
            concat_tokens = [self.start_token_id]
            for sent in tokenize.sent_tokenize(description):
                concat_tokens.extend(self._encode(sent))
                concat_tokens.append(self.sep_token_id)
            concat_tokens = concat_tokens[:self.summary_length]
            target[batch_count, :len(concat_tokens)] = concat_tokens
            tgt_len[batch_count] = len(concat_tokens)

        u_len_max = u_len.max()
        w_len_max = w_len.max()

        input = torch.from_numpy(input[:, :u_len_max, :w_len_max]).to(self.device)
        u_len = torch.from_numpy(u_len).to(self.device)
        w_len = torch.from_numpy(w_len[:, :u_len_max]).to(self.device)
        target = torch.from_numpy(target).to(self.device)
        ext_target = torch.from_numpy(ext_target[:, :u_len_max]).to(self.device)

        return input, u_len, w_len, target, tgt_len, ext_target

In [None]:
def evaluate(model, ext_labeller, gamma, val_batcher, batch_size, config, torch_device):
    """Compute the multitask validation loss over all full batches of ``val_batcher``.

    Returns ``(1 - gamma) * mean token NLL + gamma * mean per-sentence BCE``. Each mean is
    taken over the unmasked positions of the whole validation set. The loss is computed with
    dropout disabled and without autograd; the previous train/eval mode of both modules is
    restored afterwards.

    Raises:
        ValueError: if the validation data holds fewer than ``batch_size`` examples.
    """
    print("start validating")
    criterion = nn.NLLLoss(reduction='none')
    ext_criterion = nn.BCELoss(reduction='none')

    eval_total_loss1 = 0.0
    eval_total_loss2 = 0.0
    eval_total_tokens1 = 0
    eval_total_tokens2 = 0

    # Size the loop from the batcher's own data rather than a notebook-global variable;
    # a trailing partial batch is skipped.
    num_batch = len(val_batcher.document) // batch_size
    if num_batch == 0:
        raise ValueError(f"validation data has fewer than batch_size={batch_size} examples")

    model_was_training = model.training
    labeller_was_training = ext_labeller.training
    model.eval()
    ext_labeller.eval()
    try:
        with torch.no_grad():
            for batch_index in range(num_batch):
                input, u_len, w_len, target, tgt_len, ext_target = val_batcher.get_a_batch(batch_size, batch_index)

                # abstractive loss on the shifted (teacher-forced) target, padding masked out
                decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, torch_device)
                decoder_target = decoder_target.view(-1)
                decoder_mask = decoder_mask.view(-1)
                decoder_output, enc_u_output, _, _ = model(input, u_len, w_len, target)
                loss1 = criterion(decoder_output.view(-1, config['vocab_size']), decoder_target)

                # extractive loss over the real sentences of each document
                loss_utt_mask = length2mask(u_len, batch_size, u_len.max().item(), torch_device)
                ext_output = ext_labeller(enc_u_output).squeeze(-1)
                loss2 = ext_criterion(ext_output, ext_target)

                eval_total_loss1 += (loss1 * decoder_mask).sum().item()
                eval_total_loss2 += (loss2 * loss_utt_mask).sum().item()
                eval_total_tokens1 += decoder_mask.sum().item()
                eval_total_tokens2 += loss_utt_mask.sum().item()

                print("#", end="")
                sys.stdout.flush()
    finally:
        model.train(model_was_training)
        ext_labeller.train(labeller_was_training)

    print()
    avg_eval_loss1 = eval_total_loss1 / eval_total_tokens1
    avg_eval_loss2 = eval_total_loss2 / eval_total_tokens2
    # Reset counters kept by some batchers (HierArticleBatcher does not read them).
    val_batcher.epoch_counter = 0
    val_batcher.cur_id = 0
    print("finish validating")
    avg_eval_loss = (1-gamma)*avg_eval_loss1 + gamma*avg_eval_loss2
    print(avg_eval_loss)
    return avg_eval_loss

In [None]:
# Loss functions. reduction='none' keeps per-element losses so padding can be masked out
# before averaging.
criterion, ext_criterion = nn.NLLLoss(reduction='none'), nn.BCELoss(reduction='none')

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

In [None]:
# Specify the model.
# Seed PyTorch's RNG first so the weight initialisation below is reproducible.
SEED = 0
torch.manual_seed(SEED)

model = EncoderDecoder(config, device=torch_device)
ext_labeller = EXTLabeller(rnn_hidden_size=config['rnn_hidden_size'], dropout=config['dropout'], device=torch_device)



In [None]:
# Number of full training batches per epoch (a trailing partial batch is dropped).
num_batch = len(data_small) // batch_size

In [None]:
# Specify the optimizers: one Adam for the encoder-decoder (trainable parameters only) and
# one for the extractive labeller, with identical hyper-parameters.
import torch.optim as optim

adam_kwargs = dict(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], **adam_kwargs)
ext_optimizer = optim.Adam(ext_labeller.parameters(), **adam_kwargs)

optimizer.zero_grad()
ext_optimizer.zero_grad()


In [None]:
import gc
import torch

# Collect unreachable Python objects first so the tensors they hold are freed, then hand
# the now-unused cached CUDA blocks back to the driver.
collected = gc.collect()
torch.cuda.empty_cache()
collected

30

In [None]:
from tqdm.auto import tqdm

tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-base-book-summary")

batcher = HierArticleBatcher(tokenizer, config, data_small, torch_device)
val_batcher = HierArticleBatcher(tokenizer, config, data_val_small, torch_device)

# One optimizer step per batch (`num_training_steps` was never defined).
num_training_steps = num_epochs * num_batch
progress_bar = tqdm(range(num_training_steps))

model.to(torch_device)
ext_labeller.to(torch_device)

step = 1
running_loss = 0.0
for epoch in range(num_epochs):
    # Make sure dropout is active for training, whatever mode the modules are currently in.
    model.train()
    ext_labeller.train()
    for batch_index in range(num_batch):
        input, u_len, w_len, target, tgt_len, ext_target = batcher.get_a_batch(batch_size, batch_index)

        # Teacher forcing: predict token i+1 from tokens <= i; padding is masked out.
        decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, torch_device, mask_offset=False)
        decoder_target = decoder_target.view(-1)
        decoder_mask = decoder_mask.view(-1)

        # Forward pass
        decoder_output, enc_u_output, attn_scores, u_attn_scores = model(input, u_len, w_len, target)

        # Multitask Learning: Task 1 - predicting the summary tokens (masked mean NLL)
        loss1 = criterion(decoder_output.view(-1, config['vocab_size']), decoder_target)
        loss1 = (loss1 * decoder_mask).sum() / decoder_mask.sum()

        # Multitask Learning: Task 2 - extractive summarisation (masked mean BCE over real sentences)
        loss_utt_mask = length2mask(u_len, batch_size, u_len.max().item(), torch_device)
        ext_output = ext_labeller(enc_u_output).squeeze(-1)
        loss2 = ext_criterion(ext_output, ext_target)
        loss2 = (loss2 * loss_utt_mask).sum() / loss_utt_mask.sum()

        loss = (1 - config["gamma"]) * loss1 + config["gamma"] * loss2

        loss.backward()
        # ext_optimizer owns the labeller's parameters; without its step() the extractive
        # head would never be trained.
        optimizer.step()
        ext_optimizer.step()
        optimizer.zero_grad()
        ext_optimizer.zero_grad()

        running_loss += loss.item()
        if step % display_step == 0:
            # Mean loss over all steps so far.
            print(f"Step: {step} The loss: {running_loss / step}")

        step += 1
        progress_bar.update(1)

    evaluate(model, ext_labeller, config["gamma"], val_batcher, batch_size, config, torch_device)

  0%|          | 0/6 [00:00<?, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Step: 5 The loss: 1.7597471475601196
start validating
######
finish validating
8.798362437948713


In [None]:
import gc
import torch

# Collect unreachable Python objects first so the tensors they hold are freed, then hand
# the now-unused cached CUDA blocks back to the driver.
collected = gc.collect()
torch.cuda.empty_cache()
collected

0

# 3. Inference

In [None]:
# Inference / decoding settings.
args = dict(
    inference_mode="mcs",
    max_abssum_len=4096,
    max_num_sent=1000,       # max sentences kept per document
    max_word_in_sent=120,    # max tokens kept per sentence
    beam_width=4,
    time_step=512,           # max decoding steps
    penalty_ug=0,            # repeated-unigram penalty weight (0 disables it)
    alpha=1.25,              # length-normalisation exponent
    length_offset=5,
    use_gpu=True,
)

In [None]:
def get_enc_input(tokenizer, list_sentences,
        max_num_sent, max_word_in_sent, use_gpu=True):
    """Pack pre-split documents into zero-padded encoder inputs.

    Args:
        tokenizer: object whose ``encode(text, max_length=..., truncation=...)`` returns token ids.
            The first and last ids are dropped, on the assumption that they are the special
            start/end tokens ([CLS]/[SEP] or <s>/</s>) — TODO confirm for other tokenizers.
        list_sentences: list of documents, each a list of sentence strings.
        max_num_sent: sentences kept per document (extra ones are dropped).
        max_word_in_sent: tokens kept per sentence (extra ones are dropped).
        use_gpu: move the returned tensors to CUDA.

    Returns:
        (input [B, max_num_sent, max_word_in_sent], u_len [B], w_len [B, max_num_sent]) int64 tensors.

    NOTE(review): the training batcher lower-cases sentences before encoding; this function
    does not — confirm that callers lower-case their input.
    """
    batch_size = len(list_sentences)
    input = np.zeros((batch_size, max_num_sent, max_word_in_sent), dtype=np.int64)
    u_len = np.zeros((batch_size), dtype=np.int64)
    w_len = np.zeros((batch_size, max_num_sent), dtype=np.int64)

    for i, sentences in enumerate(list_sentences):
        sentences = sentences[:max_num_sent]
        u_len[i] = len(sentences)
        for j, sent in enumerate(sentences):
            # Drop the leading/trailing special tokens, then truncate. truncation=True makes the
            # generous max_length explicit and avoids the tokenizer's truncation warning.
            token_ids = tokenizer.encode(sent, max_length=500000, truncation=True)[1:-1][:max_word_in_sent]
            input[i, j, :len(token_ids)] = token_ids
            w_len[i, j] = len(token_ids)

    input = torch.from_numpy(input)
    u_len = torch.from_numpy(u_len)
    w_len = torch.from_numpy(w_len)

    if use_gpu:
        input = input.cuda()
        u_len = u_len.cuda()
        w_len = w_len.cuda()

    return input, u_len, w_len


In [None]:
def get_utt_attn_without_ref(model, enc_batch, beam_width=4, time_step=240,
                            penalty_ug=0.0, alpha=1.25, length_offset=5, torch_device='cpu'):
    """Sentence-level attention distribution from a reference-free beam-search decode.

    Runs `model.decode_beamsearch` on a single-document batch and collapses the
    returned utterance-attention scores over their first axis into one weight per
    real (non-padded) sentence, normalised so the weights sum to 1.

    Args:
        model: hierarchical encoder-decoder exposing `decode_beamsearch`.
        enc_batch: dict with "input", "u_len" and "w_len" tensors (as built by
            get_enc_input); must contain exactly one document.
        beam_width, time_step, penalty_ug, alpha, length_offset: beam-search
            settings forwarded to the decoder.
        torch_device: device string forwarded to the decoder.

    Returns:
        numpy array of length u_len[0] with the normalised attention per sentence.
    """
    # NOTE(review): vocab size and start/stop ids (0/2) are hard-coded; they
    # presumably match the BART-family (LED) tokenizer used in this notebook.
    search_cfg = {
        'k': beam_width,
        'time_step': time_step,
        'vocab_size': 50265,
        'device': torch_device,
        'start_token_id': 0,
        'stop_token_id': 2,
        'alpha': alpha,
        'length_offset': length_offset,
        'penalty_ug': penalty_ug,
        'keypadmask_dtype': torch.bool,
        'memory_utt': False,
        'batch_size': 1,   # the decoder is driven one document at a time
    }

    with torch.no_grad():
        _, _, utt_attn = model.decode_beamsearch(
            enc_batch["input"], enc_batch["u_len"], enc_batch["w_len"], search_cfg)

    num_real = enc_batch["u_len"][0].item()
    # Drop padded sentence slots, sum over the first axis, then normalise.
    real_attn = utt_attn[:, :num_real]
    distribution = real_attn.sum(dim=0) / real_attn.sum()
    return distribution.cpu().numpy()

In [None]:
def compute_ranking_score(score):
    """
        Convert raw scores into normalised ranks in [0.0, 1.0].

        The item with the lowest score gets 0.0, the item with the highest score
        gets 1.0, and every other item gets its rank position divided by (N - 1).
        Ties are broken deterministically by original position (stable sort), so
        equal scores always produce the same ranking.

        Args:
            score: 1-D sequence / numpy array of scores.

        Returns:
            float64 numpy array of the same length ([1.0] for a single item,
            an empty array for empty input).
    """
    score = np.asarray(score)
    N = len(score)
    if N == 1:
        return np.array([1.0])
    # Stable sort: the default quicksort gives no guarantee on the order of ties.
    rank_ascending = np.argsort(score, kind='stable')
    ranking_score = np.empty(N, dtype=np.float64)
    # Scatter rank positions back to the original item order.
    ranking_score[rank_ascending] = np.arange(N) / (N - 1)
    return ranking_score

In [None]:
from tqdm import tqdm

def _load_hier_models(args, torch_device, use_gpu):
    """Build the EncoderDecoder and EXTLabeller, load weights from args['load'], set eval mode."""
    # NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
    map_location = None if use_gpu else torch.device('cpu')
    state = torch.load(args['load'], map_location=map_location)

    model = EncoderDecoder(args, device=torch_device)
    model.load_state_dict(state['model'])

    ext_labeller = EXTLabeller(rnn_hidden_size=args['rnn_hidden_size'],
                               dropout=args['dropout'], device=torch_device)
    # The extractive labeller's weights are only needed when its scores are used.
    if args['inference_mode'] in ['ext', 'mcs']:
        ext_labeller.load_state_dict(state['ext_labeller'])

    model.eval()
    ext_labeller.eval()
    return model, ext_labeller


def _select_sentences(args, model, ext_labeller, tokenizer, sentences, use_gpu, torch_device):
    """Keep the top-scoring sentences of one document until the token budget is reached."""
    inference_mode = args['inference_mode']
    enc_input, u_len, w_len = get_enc_input(tokenizer, [sentences], args['max_num_sent'],
                                            args['max_word_in_sent'], use_gpu=use_gpu)

    # ------ MODULE1: Extractive Sum ------ #
    with torch.no_grad():
        encoder_output_dict = model.encoder(enc_input, u_len, w_len)
        enc_u_output = encoder_output_dict['u_output']
        ext_output = ext_labeller(enc_u_output).squeeze(-1)
    ext_output = ext_output[0].cpu().numpy()

    # ------ MODULE2: Sentence-Level Attn ------ #
    batch = {"input": enc_input, "u_len": u_len, "w_len": w_len}
    attention = get_utt_attn_without_ref(model, batch, beam_width=args['beam_width'],
                                         time_step=args['time_step'], penalty_ug=args['penalty_ug'],
                                         alpha=args['alpha'], length_offset=args['length_offset'],
                                         torch_device=torch_device)
    if len(sentences) != attention.shape[0]:
        # The encoder only saw the first max_num_sent sentences.
        if len(sentences) > args['max_num_sent']:
            sentences = sentences[:args['max_num_sent']]
        else:
            raise ValueError("shape error #1")

    # ext_output may cover padded sentence slots beyond the real sentences; drop them.
    if len(ext_output) > len(attention):
        ext_output = ext_output[:len(attention)]
    ext_score = compute_ranking_score(ext_output)
    attn_score = compute_ranking_score(attention)

    if inference_mode == 'mcs':
        # taking geometric mean --- (simple mean works too!)
        total_score = np.sqrt(attn_score * ext_score)
    elif inference_mode == 'ext':
        total_score = ext_score
    elif inference_mode == 'attn':
        total_score = attn_score
    else:
        raise Exception("inference mode error!")

    # Greedily take sentences from best to worst score. The sentence that crosses
    # the budget is still kept, so the result can slightly exceed max_abssum_len.
    keep_idx = []
    total_length = 0
    for sent_i in np.argsort(total_score)[::-1]:
        if total_length >= args['max_abssum_len']:
            break
        # ignore <s> and </s>
        total_length += len(tokenizer.encode(sentences[sent_i], max_length=50000, truncation=True)[1:-1])
        keep_idx.append(sent_i)

    # Restore the original document order of the kept sentences.
    return [sentences[j] for j in sorted(keep_idx)]


def inference_hiermodel(args, data):
    """Shorten each article by MCS-style sentence selection.

    Articles whose tokenized length is below args['max_abssum_len'] are kept as-is.
    Longer ones are reduced to their highest-scoring sentences until that token
    budget is reached, keeping the original sentence order. Sentences are scored
    by the extractive labeller ('ext'), the decoder's sentence-level attention
    ('attn'), or the geometric mean of both rank scores ('mcs').

    Args:
        args: configuration dict. Required keys: 'inference_mode', 'load'
            (checkpoint path), 'max_abssum_len', 'max_num_sent',
            'max_word_in_sent', 'beam_width', 'time_step', 'penalty_ug',
            'alpha', 'length_offset', 'use_gpu'. Optional 'tokenizer_name'
            overrides the Hugging Face tokenizer. The dict is NOT modified.
        data: `datasets.Dataset` with an "article" column.

    Returns:
        A new `datasets.Dataset` with all original columns plus "article_MCS".

    Raises:
        Exception: args['inference_mode'] is not 'mcs', 'ext' or 'attn'.
        KeyError: args['load'] is missing.
        ValueError: the attention length does not match the sentence count.
    """
    # Work on a copy so the model-config keys added below don't leak into the caller's dict.
    args = dict(args)
    inference_mode = args['inference_mode']
    # Fail fast instead of after loading the model and processing documents.
    if inference_mode not in ('mcs', 'ext', 'attn'):
        raise Exception("inference mode error!")
    if 'load' not in args:
        raise KeyError("args['load'] must be set to the trained checkpoint path")

    # Fall back to CPU when CUDA is unavailable. Every later GPU decision follows the
    # resolved device, so checkpoint loading and input tensors never call CUDA on a CPU host.
    torch_device = 'cuda' if torch.cuda.is_available() and args['use_gpu'] else 'cpu'
    use_gpu = torch_device == 'cuda'

    # ----- Hierarchical Model Configurations ----- #
    # TODO: replace this part to be not hard coded
    args['num_utterances']  = args['max_num_sent']
    args['num_words']       = args['max_word_in_sent']
    args['vocab_size']      = 50265  # vocabulary size of the BART-family (LED) tokenizer
    args['embedding_dim']   = 256    # word embedding dimension
    args['rnn_hidden_size'] = 512    # RNN hidden size
    args['dropout']         = 0.0
    args['num_layers_enc']  = 2      # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec']  = 1
    # --------------------------------------------- #

    model, ext_labeller = _load_hier_models(args, torch_device, use_gpu)
    print('model loaded!')

    tokenizer = AutoTokenizer.from_pretrained(
        args.get('tokenizer_name', "pszemraj/led-base-book-summary"))

    data_list = data.to_list()
    for doc in tqdm(data_list):
        sentences = tokenize.sent_tokenize(doc["article"])

        try:
            article_len = len(tokenizer.encode(doc["article"], max_length=500000, truncation=True))
        except IndexError:
            article_len = 0

        if article_len < args['max_abssum_len']:
            # the length is within the limit --> no selection needed
            filtered_sentences = sentences
        else:
            # perform MCS
            filtered_sentences = _select_sentences(args, model, ext_labeller, tokenizer,
                                                   sentences, use_gpu, torch_device)
        doc["article_MCS"] = " ".join(filtered_sentences)

    return Dataset.from_list(data_list)



In [None]:
# Keep the filtered dataset: the MCS pass is expensive (~15 s per document in the run below),
# so its result must not be discarded after being displayed.
data_val_small_mcs = inference_hiermodel(args, data_val_small)
data_val_small_mcs

  0%|          | 0/50 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 50/50 [12:35<00:00, 15.11s/it]


Dataset({
    features: ['article', 'abstract', 'section_names', 'article_CS', 'ext_target', 'article_MCS'],
    num_rows: 50
})