In [1]:
from datasets import list_datasets
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from IPython import get_ipython
import pandas as pd
from tqdm import tqdm
from models.MergeModel import MergeModel

# Enable IPython's autoreload extension so edits to local modules such as
# models/MergeModel.py are picked up without restarting the kernel.
_ipython = get_ipython()
for _magic, _arg in (('load_ext', 'autoreload'), ('autoreload', 2)):
    _ipython.run_line_magic(_magic, _arg)

In [2]:
from transformers import BartTokenizer, BartForConditionalGeneration, BartForSequenceClassification

In [3]:
# Tokenizer for the facebook/bart-base checkpoint (the same checkpoint the models below load).
tokenizer = BartTokenizer.from_pretrained(pretrained_model_name_or_path='facebook/bart-base')

In [4]:
# Size of the tokenizer's base vocabulary (shown as the cell output).
vocab_size: int = tokenizer.vocab_size
vocab_size

50265

In [7]:
# Load the article/summary pairs; columns include article, highlights and id.
# NOTE(review): this reads test.csv but the rows are used as training data below.
df = pd.read_csv("dataset/test.csv")

n_rows = len(df)
print('Number of training sentences: {:,}\n'.format(n_rows))

# Peek at 10 random rows (unseeded, so the rows shown differ between runs).
df.sample(10)

Number of training sentences: 11,490



Unnamed: 0,article,highlights,id
1568,St. Louis Blues forward Ryan Reaves proved on ...,St. Louis Blues forward Ryan Reaves was checke...,0b1072939b9c0af0cfc29bcf92a1334630fb6325
2720,A guard at the U.S. Census Bureau headquarters...,A guard was shot and critically injured outsid...,281d340a0b9ba6fa2f5526ecd59db7cdcc1de4cc
4661,"In the dock: Victorino Chua, 49, has given evi...","Victorino Chua, 49, denies murdering patients ...",5681a060c743561b9718ac6d2b54ee0104e03fb0
6164,The Philadelphia Office of Transportation aims...,Road safety video stars walk streets in bizarr...,7b3060da24c633f153916892a0feb99135130c31
2400,A woman in China showed her strength by using ...,Woman performed stunt at a birthday party in S...,1fb282d29c002a41e08640a71c6c02bc54e65ba8
4970,"Bitten: Austin Hatfield, 18, found the venomou...",Austin Hatfield of Wimauma was keeping the pot...,5d20cbef6b0db055b0938f5dc7535d8dbe2324c4
9691,Baby Malakai has been alive for seven months a...,Baby boy Malakai suffers rare medical conditio...,d2c86883150136b6affe3a12aa6bcb7d2388591a
2204,A young mother diagnosed with dementia at only...,Kelly Watson began having problems with coordi...,1aa2504255a25e7497c3ba735665211fe22142b0
4648,"Forget the must-have wine cellar, hot-tub or £...",Sir Philip Green and Benedict Cumberbatch also...,564ec07adee8078b31a5a6479e1bb16297022a3d
6671,A 10-year-old boy who beat a younger child to ...,"Lee Allan Bonneau, 6, was attacked by a 10-yea...",88ab1b6e345f7e16003f87ea53ffcbb746f33e21


In [8]:
# Model input (full article) and target (reference highlights) columns.
train_sentence = df['article'].values
train_target = df['highlights'].values

In [9]:
# Keep only the first num_data_points pairs so the experiment runs quickly.
num_data_points = 10
train_sentence, train_target = [
    list(column[:num_data_points]) for column in (train_sentence, train_target)
]

In [10]:
# Same settings for both sides: PyTorch tensors, pad to the longest sequence
# in each list, truncate to the tokenizer's maximum length.
_encode_kwargs = dict(return_tensors='pt', padding=True, truncation=True)
article_encoding = tokenizer(train_sentence, **_encode_kwargs)
summary_encoding = tokenizer(train_target, **_encode_kwargs)

In [11]:
# Token ids and padding mask for the articles.
article_input_ids, article_attention_mask = (
    article_encoding['input_ids'],
    article_encoding['attention_mask'],
)

In [12]:
# Token ids and padding mask for the reference summaries.
summary_input_ids, summary_attention_mask = (
    summary_encoding['input_ids'],
    summary_encoding['attention_mask'],
)

In [13]:
# Sanity check: ids and mask must have identical shapes.
print(*(t.shape for t in (article_input_ids, article_attention_mask)))

torch.Size([10, 1024]) torch.Size([10, 1024])


In [14]:
# Sanity check: ids and mask must have identical shapes.
print(*(t.shape for t in (summary_input_ids, summary_attention_mask)))

torch.Size([10, 57]) torch.Size([10, 57])


In [15]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [16]:
batch_size = 1
# Each sample is (article ids, article mask, summary ids, summary mask);
# batches are passed to merge_model in this same order.
train_data = TensorDataset(
    article_input_ids,
    article_attention_mask,
    summary_input_ids,
    summary_attention_mask,
)

In [17]:
# Shuffled sampling; each pass over the loader draws a new permutation (no seed is set).
train_sampler = RandomSampler(data_source=train_data)

In [18]:
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=train_sampler,
)

In [19]:
def loss_fn(lm_logits, labels, ignore_index=None):
    """Token-level cross-entropy between language-model logits and target ids.

    Args:
        lm_logits: Logits whose last dimension is the vocabulary (class) axis;
            all leading dimensions are flattened together.
        labels: Target token ids whose shape matches ``lm_logits`` without its
            last dimension.
        ignore_index: Target id excluded from the loss. Defaults to
            ``tokenizer.pad_token_id`` so padding positions do not contribute.

    Returns:
        Scalar tensor: mean cross-entropy over the non-ignored positions.

    NOTE(review): assumes logits at position t are meant to predict labels[t]
    (i.e. MergeModel handles the decoder shift) — confirm in models/MergeModel.py.
    """
    if ignore_index is None:
        ignore_index = tokenizer.pad_token_id
    # Take the class count from the logits themselves instead of the global
    # tokenizer vocab size (the model's output size can differ), and use
    # reshape so non-contiguous tensors are accepted.
    num_classes = lm_logits.size(-1)
    return F.cross_entropy(
        lm_logits.reshape(-1, num_classes),
        labels.reshape(-1),
        ignore_index=ignore_index,
    )

In [20]:
# Mean-squared error used to compare the two sentiment outputs during training.
mse_loss = nn.MSELoss(reduction='mean')

In [5]:
# Pretrained BART with its language-modelling head: the summariser being fine-tuned.
summary_model = BartForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path='facebook/bart-base'
)

In [6]:
# BART with a sequence-classification head. The head weights are freshly
# initialised (see the warning in the output), not loaded from the checkpoint.
senti_model = BartForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path='facebook/bart-base'
)

Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartForSequenceClassification: ['final_logits_bias']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.weight', 'classification_head.dense.bias', 'classification_head.out_proj.weight', 'classification_head.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to u

In [21]:
# Move both models to the default CUDA device (requires a GPU).
summary_model, senti_model = summary_model.cuda(), senti_model.cuda()

In [22]:
# Combine the summariser and the sentiment classifier into one model. The training
# loop unpacks its output as (summary_out, *sentiments).
# NOTE(review): MergeModel lives in models/MergeModel.py (not visible here); confirm
# that it registers both models as submodules so optimizer/train() reach them.
merge_model = MergeModel(summary_model, senti_model)

In [23]:
# SGD with momentum and weight decay over every parameter merge_model exposes.
optimizer = torch.optim.SGD(
    merge_model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.01,
)

In [None]:
epochs = 500

for eps in tqdm(range(epochs)):
    # Plain float accumulator: adding .item() values does not keep autograd
    # graph nodes alive, and stays valid even if the loader yields no batches.
    epoch_loss = 0.0
    for batch in train_dataloader:
        # Batch order matches the TensorDataset: article ids/mask, summary ids/mask.
        article_ids, article_mask, summary_ids, summary_mask = (t.cuda() for t in batch)

        # Kept inside the loop (as before) in case anything between steps switches modes.
        # NOTE(review): only summary_model is set to train mode; senti_model keeps
        # the mode it was loaded in — confirm that is intended.
        summary_model.train()

        optimizer.zero_grad()

        summary_out, *sentiments = merge_model(article_ids, article_mask, summary_ids, summary_mask)

        # Summarisation loss plus a penalty on the difference between the two sentiment outputs.
        loss1 = loss_fn(summary_out.logits, summary_ids)
        loss2 = mse_loss(sentiments[0].logits, sentiments[1].logits)
        final_loss = loss1 + loss2

        final_loss.backward()
        optimizer.step()

        epoch_loss += final_loss.item()

    if eps % 10 == 0:
        # tqdm.write prints without breaking the progress bar.
        tqdm.write("Epoch: %d, Loss: %.3f" % (eps, epoch_loss))


  0%|          | 1/500 [00:04<36:15,  4.36s/it]

Epoch: 0, Loss: 78.049


  2%|▏         | 11/500 [00:45<33:26,  4.10s/it]

Epoch: 10, Loss: 62.404


  4%|▍         | 21/500 [01:25<31:59,  4.01s/it]

Epoch: 20, Loss: 47.068


  6%|▌         | 31/500 [02:05<31:22,  4.01s/it]

Epoch: 30, Loss: 42.038


  8%|▊         | 41/500 [02:45<30:26,  3.98s/it]

Epoch: 40, Loss: 39.869


 10%|█         | 51/500 [03:25<30:10,  4.03s/it]

Epoch: 50, Loss: 37.978


 11%|█         | 53/500 [03:33<30:01,  4.03s/it]

# Testing (qualitative check on one batch drawn from the training dataloader)

In [None]:
# Draw one (randomly sampled) batch and move all four tensors to the GPU.
b1, b2, b3, b4 = (t.cuda() for t in next(iter(train_dataloader)))

In [None]:
# Inference: eval() disables dropout (the training loop left summary_model in
# train mode), and no_grad() avoids building an autograd graph.
summary_model.eval()
senti_model.eval()
with torch.no_grad():
    summary_out, *sentiments = merge_model(b1, b2, b3, b4)

In [None]:
# Decode the input article. Convert ids to plain Python ints and drop special
# tokens so the padding up to the batch length does not swamp the text.
tokenizer.decode(b1.view(-1).tolist(), skip_special_tokens=True)

In [None]:
# Greedy per-position prediction (argmax over the vocabulary axis), decoded
# without special/padding tokens.
tokenizer.decode(summary_out.logits.argmax(-1).view(-1).tolist(), skip_special_tokens=True)