In [1]:
import torch
from torch import nn
from torch import optim

from tqdm import tqdm

In [2]:
from datasets import load_dataset

cnn = load_dataset("cnn_dailymail", "3.0.0")

Found cached dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def rem(x):
    s = x.split("--")
    if len(s) < 2:
        return x
    else:
        return "--".join(s[1:])

In [18]:
def random_num(max_num):
    return torch.randint(max_num,(1,))[0]

def random_place(max_idx, num):
    rand = torch.randperm(max_idx)[:num].sort().values
    return rand[5:] if len(rand) > 8 else rand

def random_token_number():
    #return tokenizer.encode(f"{random_num(10000)} ")[1:-1]
    return tokenizer.encode(f"{random_num(10000)/(3+random_num(100)):.2} ")[1:-1]

In [5]:
def bos_eos(list_seq):
    return torch.cat([torch.tensor([tokenizer.bos_token_id]), torch.tensor(list_seq), torch.tensor([tokenizer.eos_token_id])], dim=0)

In [6]:
def noise(input_text, max_seq, rto):
    enc = tokenizer.encode(input_text)[1:-1]
    ll = min(round(max_seq * (1-rto)), len(enc))
    enc = enc[:ll]
    rd = random_place(ll, random_num(min(max_seq, round(ll / (1-rto))) - ll))

    collect = []
    onehot = []

    idx = 0
    l = len(rd)
    for i in range(ll):
        if i < l and i == rd[idx]:
            for j in range(random_num(20)):
                rand = random_token_number()
                collect += rand
                onehot += [1 for _ in rand]
            idx += 1
        
        collect += [enc[i]]
        onehot += [0]

    input_ids = torch.ones(max_seq, dtype=torch.int64) * tokenizer.pad_token_id
    e_col = bos_eos(collect[:max_seq-2])
    input_ids[:len(e_col)] = e_col
    input_ids = input_ids.unsqueeze(0)
    onehot = torch.tensor([0] + onehot[:max_seq-2] + [0], dtype=torch.float32).unsqueeze(0)

    label = torch.zeros((onehot.shape[0], max_seq))
    label[:, :onehot.shape[1]] = onehot

    return input_ids, label

In [7]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')

class RoBERTa_Denoiser(nn.Module):
    def __init__(self, device="cuda"):
        super().__init__()
        self.device = device
        self.roberta_encoder = AutoModel.from_pretrained("xlm-roberta-large")
        self.head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(1024, 1)
        )
    
    def forward(self, input_ids, attention_mask=None):
        hidden_state = self.roberta_encoder(input_ids, attention_mask).last_hidden_state

        output = torch.zeros((hidden_state.shape[0], 512, 1024)).to(self.device)
        output[:, :hidden_state.shape[1], :] = hidden_state

        output = self.head(output).squeeze(-1)
        
        return output

In [8]:
device = "cuda"

denoiser = RoBERTa_Denoiser(device).to(device)
denoiser.load_state_dict(torch.load("denoiser_roberta_rto_10000.pth"))

Some weights of the model checkpoint at xlm-roberta-large were not used when initializing XLMRobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [9]:
def criterion(pred, labels, weight):
    sigpred = pred.sigmoid()
    return (-(labels * sigpred.log() + (1-labels) * (1-sigpred).log()) * (labels * (weight-1) + 1)).sum()

optimizer = optim.Adagrad(denoiser.parameters(), lr=3e-05)

In [10]:
def dataload(i, n, max_seq, noise_rto):
    b_input_ids = []
    b_labels = []
    for d in cnn["train"]["article"][i:i+n]:
        input_ids, label = noise(rem(d), max_seq, noise_rto)
        b_input_ids += [input_ids]
        b_labels += [label]

    return torch.cat(b_input_ids, dim=0), torch.cat(b_labels, dim=0)

In [19]:
batch_size = 8
max_seq_len = 512
print_size = 50
weight = 25
noise_rto = 0.2

start_point = 12000
num = 2000

In [20]:
buf = 0
denoiser.train()
for i in tqdm(range(start_point, start_point+num, batch_size)):
    input_ids, labels = dataload(i, batch_size, max_seq_len, noise_rto)
    
    input_ids = input_ids.to(device)
    labels = labels.to(device)
    
    pred = denoiser(input_ids)
    loss = criterion(pred, labels, weight)
    buf += loss.item()
    if i % (batch_size * print_size) == 0:
        print(buf / print_size)
        buf = 0

    loss.backward()
    optimizer.step()

    optimizer.zero_grad()

  0%|          | 0/250 [00:00<?, ?it/s]

626.77640625


 20%|██        | 50/250 [01:14<05:01,  1.51s/it]

301.2696774291992


 40%|████      | 100/250 [02:30<03:44,  1.50s/it]

156.21316438674927


 60%|██████    | 150/250 [03:45<02:32,  1.52s/it]

119.1456770324707


 80%|████████  | 200/250 [05:00<01:15,  1.50s/it]

97.73045713424682


100%|██████████| 250/250 [06:15<00:00,  1.50s/it]


In [25]:
torch.save(denoiser.state_dict(), "./denoiser_roberta_rto_num_14000.pth")

In [13]:
def bos_eos(list_seq):
    return torch.cat([torch.tensor([tokenizer.bos_token_id]), torch.tensor(list_seq), torch.tensor([tokenizer.eos_token_id])], dim=0)

def denoise(text, max_seq_len=512):
    enc = tokenizer.encode(text)[1:-1]
    ll = len(enc)
    chunk = max_seq_len - 2

    ret = []
    garb = []
    for i in range(0, ll, chunk):
        input_ids = bos_eos(enc[i:i+chunk]).unsqueeze(0).to(device)
        
        with torch.no_grad():
            remove = (denoiser(input_ids) > 0).to("cpu")
        
        for j, k in zip(input_ids[0,1:-1], remove[0,1:-1]):
            if not k:
                ret += [j]
            else:
                garb += [tokenizer.decode([j])]

    return tokenizer.decode(ret), garb
        

In [29]:
test = """   'text': 'Some input sequences still exceed LoBART’s longer fixed-span limit. Further extending the input span would lead to a small local attention span, a diminishing improvement, or GPU running out of memory. Alternatively, it has been shown that a better content selection improves abstractive summarization in news ( Chen and Bansal ,  2018 ; Gehrmann et al. ,  2018 ;  Hsu et al. ,  2018 ), multi doc- uments ( Liu and Lapata ,  2019a ;  Liu et al. ,  2018 ), and scientific articles ( Pilault et al. ,  2020 ). Thus, we propose to tackle the excess length by content selection. Here, we distinguish between two phases of content selection: training time and test time. 5.1 Training-time Content Selection During training, ground-truth targets are available. We categorize selection methods in this phase into two types: ground-truth based (model-free), which is also referred to as  oracle ; and model-based. Ground-truth based methods cannot be used at test time, while model-based methods can be applied at both phases. Although model-based methods do not rely on ground-truth targets, they have the advantage of matching in training and test phases. Existing oracle methods include using ROUGE-2 recall ( Liu et al. ,  2018 ) or the average of ROUGE-1,2,L recall ( Pilault et al. ,  2020 ). We discuss model-based methods in Section  5.2 , where we propose the MCS method. Let the subscript  ( i, j )  denote the position of the  j -th word in the  i -th input sentence, the full input  X = { x 1 , ...,  x i , ...,  x N 1 } = � [ x 1 , 1 , x 1 , 2 , x 1 ,J 1 �� � sent  1 , ..., x i, 1 , x i,J i � �� � sent  i , ..., x N 1 , 1 , x N 1 ,J N 1 � �� � sent  N 1 ] . Content selection re-ranks, truncates, and sorts  X to get  X cs   for training BART/LoBART as follows:  ̄ X  =  { x r 1 ,  x r 2 ,  x r 3 , ...,  x r R } (2) X cs   =  SortOrig ( TruncateN (   ̄ X )) (3) where  r i  is the index of the sentence of rank  i , the TruncateN  operation filters    ̄ X  such that the total of number of words is less than  N , and  SortOrig retains the original sentence order. The following ranking methods are considered: • Truncation (TRC):  r k  =  k . • Model-based: Given the score  f  of model  φ , r k  =  { i  ∈  N 1  :  f φ ( i | X )  is ranked  k -th } •  Oracle (ORC): Given the ground-truth sum- mary  y  and similarity measure  d , r k  =  { i  ∈  N 1  :  d ( x i ,  y )  is ranked  k -th } In this work, we use ROUGE-2 recall as the sim- ilarity measure  d . For the ORC method, first, we retain only sentences with positive  d , leading to R  ≤  N 1 . We found that the number of sentences with positive  d  is low at 21.3% of the total number of sentences in average on podcast data. This cor- responds to 56% of training instances being shorter than BART input span of 1024. 6   This no-padding oracle method (ORC no-pad ) is highly  aggressive , potentially preventing the downstream summarizer 6 We refer to this percentage as %AgORC no-pad  (the per- centage of inputs aggressively extracted by the oracle method). from learning complex abstraction. Hence, we propose variants of oracle methods to extend the ORC no-pad -selected input to the max input span  N : •  ORC pad-lead : Pad by leading unselected sen- tences and keep the original sentence order. •  ORC pad-rand : Pad by random unselected sen- tences and keep the original sentence order. TRC MCS ORC-pad-lead ORC-pad-rand ORC-no-pad 22.0 24.0 26.0 28.0 30.0 32.0 34.0 ROUGE-1 (F1) 27.88 28.14 29.99 30.39 32.39 26.82 27.24 26.34 27.28 25.26 26.43 26.32 24.78 25.54 22.71 Abstractive Generation Performance of Downsteam BART TestTime: Oracle (UpperBound) TestTime: MCS (CurrentBest) TestTime: Truncate (Baseline) In Figure  5 , since any oracle method is consid- ered cheating at test time, the best performance is obtained by MCS (in blue), and the upper bound performance is obtained by optimal oracle method (in green). The results show that although ORC no-pad  yields the highest upper bound, the ab- stractive model in fact does not learn how to per- form abstraction. For instance, with TRC or MCS at test time, ORC no-pad  yields the lowest perfor- mance level. The best way to fine-tune the abstrac- tive model shown in Figure  5  is using ORC pad-rand . Compared to ORC pad-lead , ORC pad-rand  is better as it introduces more diversity to the abstractive model. Compared to the model-based method, ORC pad-rand is also computationally less expensive. In addition, Table  5  shows that when there is no content selection at test time (i.e. TRC ap- plied), LoBART(4k) and LoBART(8k) benefit from ORC pad-rand , whereas BART(1k) does not. This is because in the 1k setting, content selection is more aggressive; as a result, the large mismatch between training and test leads to a poor result. Thus, we suggest that the best content selection during train- ing is ORC pad-rand  given that content selection will be used at test time, or model’s input span is long. 5.2 Multitask Content Selection (MCS) To process long input sequences entirely, we con- sider RNN, whose memory requirement grows lin- early with the sequence length, and hierarchical architectures which have been shown effective for long seq2seq tasks ( Cohan et al. ,  2018 ;  Li et al. , 2019 ). In this work, the hierarchical RNN model described in Section  3.2  has memory requirement given the target length of 144 during training of 0 . 83+ B (3 . 96 × 10 − 5 +3 . 33 × 10 − 5 N 2 ) N 1 , 7   where N 1  is #sentences, and  N 2  is the maximum number of words in a sentence, and  B  is batch size. By setting  N 1 =1000 and  N 2 =50, only 2% of podcast data exceeds this limit, while taking GPU memory to only 2.53GiB for  B =1. Thus, this shows that this model can cover long sequences. Previous model-based methods treat content se- lection as extractive labelling and create labels heuristically ( Pilault et al. ,  2020 ), or using encoder- decoder attention mechanism ( Manakul and Gales , 2020 ). To utilize both of these in one framework, we propose a Multitask Content Selection (MCS) method where we train the hierarchical encoder- decoder with attention mechanism and a classifi- cation layer on top of the encoder (described in Section  3.2 ). First, the model is trained on seq2seq abstractive summarization objective: L seq2seq  =  − M � m =1 log  P ( y m | y <m ,  X ) (4) Second, we create binary labels as follows: for sentence  i , the label  z i  is 1 if  d ( x i ,  y )  >  0 ; else  z i is 0, and  d  is the ROUGE-2 recall measure. The extractive labelling task objective is: L label  =  −   � N 1 i =1   ( z i  log ˆ z i  + (1  −  z i ) log(1  −  ˆ z i ))  (5) ˆ z i  =  sigmoid ( W T cls h i   +  b cls ) (6) where  h i  is the sentence-level encoder output as- sociated with sentence  i , and  W cls ,  b cls  are the parameters of the classification layer. Thus, the MCS training loss is defined as follows: L MCS  =  γ L label  + (1  −  γ ) L seq2seq (7) At inference stage, there are two modes: (i) stan- dard abstractive summary generation, e.g. via beam search decoding; (ii) ranking input sentences via labelling score and seq2seq attention score. The latter is how we use MCS during inference. 8   For sentence  i , the scores are: score i, ( label )  = ˆ z i ,  score i, ( seq2seq )  =   � M m =1   α s m,i (8) 7 Obtained by least-squares regression with 20 samples. 8 In practice, we run beam search decoding of width 4, and we obtain the attention score from the top beam. where  α s m,i   is the sentence-level attention weight at decoder step  m  over input sentence  i . Since the scores are on different scales, rather than using the scores defined in Eq.  8 , we simply rank the scores, and then normalize the score ranks into the range 0.0 to 1.0. Let nscore denote the normalized ranking score, the MCS inference score is: f φ ( i | X ) =  nscore i, ( label )  +  nscore i, ( seq2seq ) (9) In our preliminary experiments, we vary the amount of selected sentences from the limit of BART/LoBART to a few sentences, and we found that more aggressive selection at test time degrades the performance. Therefore, our MCS selects input sentences up to the limit of BART/LoBART. By setting  γ =0.0, our method is comparable to the attention-based method in  Manakul and Gales ( 2020 ). By setting  γ =1.0, our method is similar to the extractive models in  Hsu et al.  ( 2018 );  Pi- lault et al.  ( 2020 ). In Table  4 , we show that when coupled with BART, MCS yields better summariza- tion performance than both Attn-only and Ext-only baselines. MCS also achieves higher recall rate of sentences with  d ( x i ,  y )  >  0  than the two baselines. System %Recall R1 R2 RL Attn ( L seq2seq ) 38.85 26.90 9.70 18.78 Ext ( L label ) 35.26 26.39 8.90 18.03 MCS ( L MCS ) 40.50 27.28 9.82 19.00 '},
"""

In [30]:
result, garb = denoise(test)

In [31]:
result, garb

("'text': 'Some input sequences still exceed LoBART’s longer fixed-span limit. Further extending the input span would lead to a small local attention span, a diminishing improvement, or running out of memory. Alternatively, it has been shown that a better content selection improves abstractive summarization in news ( Chen and Ban, 2018 ; Gehrmann et al., 2018 ; Hsu et al., 2018 ), multi docments ( Liu and La, ; Liu et al., 2018), and scientific articles ( Pilault et al., 2020). Thus, we propose to tackle the excess length by content selection. Here, we distinguish between two phases of content selection: training time and test time. 5.1 Training-time Content Selection During training, ground-truth targets are available. We categorize selection methods in this phase into two types: ground-truth based (model-free), which is also referred to as oracle ; and model-based. Ground-truth based methods cannot be used at test time, while model-based methods can be applied at both phases. Althoug

In [24]:
test

"  'text': '6.1 Spotify Podcast results In Table  5 , a performance gain is obtained in all settings by adding MCS. By comparing different configurations with MCS, it can be seen that the gain from MCS in LoBART(8k) system is the low- est. This is because the average length is 5,727, meaning that many Podcasts inputs to LoBART(8k) do not benefit from content selection. CUED-filt, the best single-model system in  Man- akul and Gales  ( 2020 ), uses an attention-based con- tent selection at both training and test time, and it is combined with fine-tuned vanilla BART. Our approach outperforms CUED-filt by improved con- tent selection at both training time and test time as demonstrated by BART(1k)-ORC+MCS. Addition- ally, local self-attention allows training on longer sequences, and our LoBART(4k)-ORC+MCS sys- tem has yielded the best results. Lastly, even though LoBART(8k) requires more resource to train, it does not perform as well as LoBART(4k) due to its smaller attention window, and i