In [1]:
from datasets import list_datasets
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from IPython import get_ipython
import pandas as pd
from tqdm import tqdm
from models.MergeModel import MergeModel

get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', 2)

In [2]:
from transformers import BartTokenizer, BartForConditionalGeneration, BartForSequenceClassification

In [3]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [4]:
vocab_size = tokenizer.vocab_size
vocab_size

50265

In [5]:
df = pd.read_csv("dataset/test.csv")

# Report the number of sentences.
print('Number of training sentences: {:,}\n'.format(df.shape[0]))

# Display 10 random rows from the data.
df.sample(10)

Number of training sentences: 11,490



Unnamed: 0,article,highlights,id
9988,Martin 'Mad Dog' Allen returns with his latest...,Martin Allen believes it's a crucial time for ...,da39f68c5d9a413e0759fc6bff8260f76ed1ac59
9960,Things haven't been going well for Manchester ...,Manchester United flop Anderson was sent off f...,d987f48e741373781c322a3c533cc9de7b6b3f0b
4151,What's worse than finding a cockroach in your ...,A woman has half eaten a large cockroach after...,4ac39a81e6ea5d2438c837b8f54340320a2eb292
3935,"As every celebrity and fashionista knows, no l...",Madonna's new handbag is emblazoned with the s...,4519e97cdc1c0805caf5ee032f8522536fed7c00
4840,A women's rights group has slammed a popular e...,Sydney cafe Lowenbrau Keller has released an a...,5a58966875bc965b431a80c5f9a2acb13abb36c8
10398,Labour received more than £1 every second from...,"Union barons gave more than £700,000 to Ed Mil...",e4edb27f5703e35f7e7ddeb4589a490c99ef841b
7971,Hidden along a dusty dirt track across several...,Town of Evansville in Comanche County was the ...,a86c25533105c4f53d264d72bcbdc358e8fff6b4
10235,"Dan Price, the CEO of Gravity Payments, is sla...","Dan Price, founder of Gravity Payments, will s...",e0f3dae20e8c05af54d04f211189bd9702d20e5b
5703,The Queen was today joined by her husband the ...,The Queen attends reception as part of First W...,6f7ff9f7773657a5944758ebf96842c09445b7f5
1437,"Janet Faal, who has suffered from agoraphobia ...","Janet Faal, 57, was out with a friend as part ...",07f3d8c1dca494ea281f2ec843ec730218db3023


In [6]:
train_sentence = df.article.values
train_target = df.highlights.values

In [7]:
num_data_points = 10
train_sentence = list(train_sentence[:num_data_points])
train_target = list(train_target[:num_data_points])

In [8]:
article_encoding = tokenizer(train_sentence, return_tensors='pt', padding=True, truncation = True)
summary_encoding = tokenizer(train_target, return_tensors='pt', padding=True, truncation = True)

In [9]:
article_input_ids = article_encoding['input_ids']
article_attention_mask = article_encoding['attention_mask']

In [10]:
summary_input_ids = summary_encoding['input_ids']
summary_attention_mask = summary_encoding['attention_mask']

In [11]:
print(article_input_ids.shape, article_attention_mask.shape)

torch.Size([10, 1024]) torch.Size([10, 1024])


In [12]:
print(summary_input_ids.shape, summary_attention_mask.shape)

torch.Size([10, 57]) torch.Size([10, 57])


In [13]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [14]:
batch_size = 1
train_data = TensorDataset(article_input_ids, article_attention_mask,\
                           summary_input_ids, summary_attention_mask)

In [15]:
train_sampler = RandomSampler(train_data)

In [16]:
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [17]:
def loss_fn(lm_logits, labels):
    loss_fct = CrossEntropyLoss(ignore_index = tokenizer.pad_token_id)
    loss = loss_fct(lm_logits.view(-1, vocab_size), labels.view(-1))
    return loss

In [18]:
mse_loss = nn.MSELoss()

In [19]:
class Temp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        return 0

In [None]:
# -------------- Testing Scatter -----------

In [20]:
a = torch.randint(0,10, (2,3,3))
a, a.shape

(tensor([[[2, 2, 5],
          [4, 9, 4],
          [3, 9, 6]],
 
         [[8, 9, 4],
          [8, 8, 9],
          [1, 8, 6]]]),
 torch.Size([2, 3, 3]))

In [21]:
idx =  torch.argmax(a, dim=2, keepdims=  True)

In [22]:
mask = torch.zeros_like(a).scatter_(2, idx, 1.)

In [23]:
(a * mask)

tensor([[[0, 0, 5],
         [0, 9, 0],
         [0, 9, 0]],

        [[0, 9, 0],
         [0, 0, 9],
         [0, 8, 0]]])

In [24]:
(a * mask).sum(axis=2).long().type()

'torch.LongTensor'

In [None]:
# --------------- End ------------------------

In [25]:
import torch

class MergeModel(torch.nn.Module):
    '''
            Todo
    '''
    def __init__(self, summary_model, sentiment_model):
        super().__init__()
        self.summary_model = summary_model
        self.sentiment_model = sentiment_model
    
    def forward(self, article_id, article_mask, summary_id, summary_mask):
        self.summary_out = self.summary_model(article_id,  article_mask, summary_id, summary_mask)

        idx =  torch.argmax(self.summary_out.logits, dim=2, keepdims=  True)
        mask = torch.zeros_like(self.summary_out.logits).scatter_(2, idx, 1.)
        
        pred_summary_id = (self.summary_out.logits * mask).sum(axis=2).long()
        pred_summary_mask = torch.ones_like(summary_mask).long()
        pred_summary_id[:, -1] = 2
        
        self.article_sentiment = self.sentiment_model(article_id,  article_mask, labels=None)
        self.summary_sentiment = self.sentiment_model(pred_summary_id,  pred_summary_mask, labels=None)
        
        return [self.summary_out, self.article_sentiment, self.summary_sentiment]

In [26]:
summary_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [27]:
senti_model = BartForSequenceClassification.from_pretrained('facebook/bart-base')

Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartForSequenceClassification: ['final_logits_bias']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.weight', 'classification_head.dense.bias', 'classification_head.out_proj.weight', 'classification_head.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to u

In [28]:
summary_model = summary_model.cuda()
senti_model = senti_model.cuda()

In [30]:
merge_model = MergeModel(summary_model, senti_model)

In [31]:
optimizer = torch.optim.SGD(merge_model.parameters(),1e-3,momentum=0.9,weight_decay=0.01)

In [32]:
epochs = 500

for eps in tqdm(range(epochs)):
#     print('Epoch: ', eps)
    epoch_loss = 0
    for step, batch in enumerate(train_dataloader):
        
        batch[0] = batch[0].cuda()
        batch[1] = batch[1].cuda()
        batch[2] = batch[2].cuda()
        batch[3] = batch[3].cuda()
        
        summary_model.train()
        
        optimizer.zero_grad()
        
        summary_out,*sentiments = merge_model(batch[0],batch[1], batch[2], batch[3])
        
        loss1 = loss_fn(summary_out.logits, batch[2])
        loss2 = mse_loss(sentiments[0].logits, sentiments[1].logits)
        
        final_loss = loss1 + loss2
        
        epoch_loss = epoch_loss + final_loss
        
        final_loss.backward()
        optimizer.step()
    
    if not eps%10:
        print("Epoch: %d, Loss: %.3f" % (eps, epoch_loss.item()))


  0%|          | 0/500 [00:00<?, ?it/s]

tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(1, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])


  0%|          | 1/500 [00:06<53:15,  6.40s/it]

Epoch: 0, Loss: 75.720
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(1, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])


  0%|          | 2/500 [00:12<52:08,  6.28s/it]

tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])
tensor(0, device='cuda:0') torch.Size([1, 57])


  0%|          | 2/500 [00:14<1:00:53,  7.34s/it]


KeyboardInterrupt: 

# Testing

In [None]:
b1,b2,b3,b4 = next(iter(train_dataloader))
b1 = b1.cuda()
b2 = b2.cuda()
b3 = b3.cuda()
b4 = b4.cuda()

In [None]:
summary_out,*sentiments= merge_model(b1, b2, b3, b4)

In [None]:
tokenizer.decode(list(b1.view(-1)))

In [None]:
tokenizer.decode(list(summary_out.logits.argmax(2).view(-1)))