In [1]:
import os
import time
import json

from utils.file_utils import *
from datasets import list_datasets
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from IPython import get_ipython
import pandas as pd
from tqdm import tqdm
from models.MergeModel import MergeModel
from models.ClassifierModel import ClassifierModel
import matplotlib.pyplot as plt

# Reload edited project modules (utils/, models/) automatically before each cell runs.
# The magic's argument line must be a string: the original passed the int 2, which
# IPython either ignores silently or rejects, depending on the version.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

In [2]:
from transformers import BartTokenizer, BartForConditionalGeneration, BartForSequenceClassification

In [3]:
# Pretrained BART-base tokenizer; used below to encode the articles and summaries.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [4]:
# Size of the tokenizer's base vocabulary (echoed as the cell output).
vocab_size = tokenizer.vocab_size
vocab_size

50265

In [5]:
# Read the article/summary pairs from disk.
df = pd.read_csv("dataset/test.csv")

# Show how many rows were loaded.
print(f'Number of training sentences: {len(df):,}\n')

# Preview a random handful of rows.
df.sample(10)

Number of training sentences: 11,490



Unnamed: 0,article,highlights,id
8768,A surrender order issued by Hitler's successor...,Typed dispatch sent by German president Karl D...,bcbdb86d4eba1a7fac7e262cfb661f0113b755b5
6087,The world’s third largest cruise liner today s...,Two large cruise ships - Anthem of the Seas an...,795ec7ced1ec15f19b38e88aea63079cb996dc51
6901,Big-hearted Ipswich Town left back Tyrone Ming...,Left back posted screenshot of text message co...,8e2a91eb8fecd055aac9ff9f9ecead6e09287ebd
2339,"As Prime Minister David Cameron puts it, on St...",Fish and chips has believed to be partly Portu...,1dd5816322e998cb7b4745363767ccee5ffd639d
6214,Spain's 2-0 defeat by Holland on Tuesday broug...,Holland beat Spain 2-0 at the Amsterdam Arena ...,7c9201a07ffd4647e3c37af77b69571e0fad1c45
9996,"Sometimes, putting up a billboard just isn't q...",Two Russian companies have started offering un...,da645b6de7f2f0b4d96b7b5dab962540d78d6aad
2891,A mother has released images of her teenage so...,The teenager suffered a broken jaw in 2 places...,2c39cb95345ead8ea9ff4f9e2229faddfebc8db4
802,(CNN)The killing of an employee at Wayne Commu...,Relatives of Wayne Community College shooting ...,c0340e53445e1d38aaf9a2681c2ae2e950a98860
1212,A woman has received the shocking news that th...,Sarah appeared on Monday's Jeremy Kyle Show to...,0312bea2586ef3a65a1b9a3d25328d1b417e2871
9150,EU leaders have agreed a package of measures t...,Funding increased to £7m a month for EU's bord...,c61946875505531b55fb6d23decb7f7186f66a40


In [6]:
# Source texts and their reference summaries, as NumPy arrays.
train_sentence = df["article"].values
train_target = df["highlights"].values

In [7]:
# Keep only the first `num_data_points` examples and hand the tokenizer plain lists.
num_data_points = 10_000
train_sentence, train_target = (
    list(column[:num_data_points]) for column in (train_sentence, train_target)
)

In [8]:
# Encode to padded PyTorch tensors; articles are truncated at 500 tokens, summaries at 100.
article_encoding = tokenizer(
    train_sentence,
    max_length=500,
    truncation=True,
    padding=True,
    return_tensors='pt',
)
summary_encoding = tokenizer(
    train_target,
    max_length=100,
    truncation=True,
    padding=True,
    return_tensors='pt',
)

In [9]:
# Token ids and the matching padding mask for the articles.
article_input_ids, article_attention_mask = (
    article_encoding['input_ids'],
    article_encoding['attention_mask'],
)

In [10]:
# Token ids and the matching padding mask for the reference summaries.
summary_input_ids, summary_attention_mask = (
    summary_encoding['input_ids'],
    summary_encoding['attention_mask'],
)

In [11]:
# Sanity check: ids and mask must share one (num_examples, seq_len) shape.
print(*(tensor.shape for tensor in (article_input_ids, article_attention_mask)))

torch.Size([10000, 500]) torch.Size([10000, 500])


In [12]:
# Sanity check: ids and mask must share one (num_examples, seq_len) shape.
print(*(tensor.shape for tensor in (summary_input_ids, summary_attention_mask)))

torch.Size([10000, 100]) torch.Size([10000, 100])


In [13]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [14]:
# Each dataset item is (article ids, article mask, summary ids, summary mask).
batch_size = 10
train_data = TensorDataset(
    article_input_ids,
    article_attention_mask,
    summary_input_ids,
    summary_attention_mask,
)

In [15]:
# Draw training examples in random order.
train_sampler = RandomSampler(train_data)

In [16]:
# Batches of `batch_size` examples, ordered by the random sampler above.
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [17]:
def loss_fn(lm_logits, labels, ignore_index=None):
    """Token-level cross-entropy between language-model logits and target ids.

    Args:
        lm_logits: Float tensor whose last dimension indexes the vocabulary;
            all leading dimensions are flattened together.
        labels: Integer tensor of target ids with one entry per logit row
            (``lm_logits.shape[:-1]`` elements in total).
        ignore_index: Target id excluded from the loss. When omitted, the
            notebook tokenizer's ``pad_token_id`` is used so that padding
            positions do not contribute.

    Returns:
        Scalar tensor holding the mean cross-entropy over non-ignored positions.
    """
    if ignore_index is None:
        ignore_index = tokenizer.pad_token_id
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
    # Take the class count from the logits themselves rather than a global, so a
    # model whose output width differs from the tokenizer's vocab_size still works.
    # reshape (unlike view) also accepts non-contiguous tensors.
    flat_logits = lm_logits.reshape(-1, lm_logits.size(-1))
    return loss_fct(flat_logits, labels.reshape(-1))

In [18]:
# Mean-squared-error criterion; train() applies it to the two extra outputs of the merged model.
mse_loss = nn.MSELoss()

In [19]:
# Load the saved sentiment classifier (the repr below shows a whole ClassifierModel module).
# NOTE(review): torch.load unpickles arbitrary objects - only load checkpoints from a trusted source.
# NOTE(review): no map_location is given, so tensors are restored to the device they were saved from.
sent_model = torch.load('experiment/classifier/classifier.pt')



In [20]:
# Display the loaded classifier's architecture.
sent_model

ClassifierModel(
  (embedding): Embedding_(
    (embedding): Embedding(50265, 64)
  )
  (lstm): LSTM(64, 512, num_layers=2, batch_first=True, dropout=0.5)
  (fc1): Linear(in_features=512, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=2, bias=True)
  (loss_fn): CrossEntropyLoss()
)

In [21]:
# Pretrained BART-base with its language-modelling head; this is the summarizer being fine-tuned.
summary_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [22]:
# Move both models to the GPU (requires a CUDA device).
# nn.Module.cuda() moves the module in place and returns it, so `senti_model`
# below is the same object as `sent_model`, just under a second name.
summary_model = summary_model.cuda()
senti_model = sent_model.cuda()

In [23]:
# Run bookkeeping: experiment name, output location and epoch budget.
model_name = 'BART_classifier'
model_dir = './experiment'
model_path = os.path.join(model_dir, model_name)
epochs = 100

# Wrap the BART summarizer and the sentiment classifier in a single module.
merge_model = MergeModel(summary_model, senti_model)

In [24]:
# Settings persisted next to the checkpoints by train() as config.json.
config = dict(
    model_name=model_name,
    epochs=epochs,
)

In [25]:
def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False, model_dir='./experiment', optimizer='ADAM', config=None, lr=None):
    """Train `model` on summarization cross-entropy plus a sentiment MSE term.

    Every batch from `train_dataloader` must hold four tensors in the order
    (article ids, article mask, summary ids, summary mask). The model is called
    with those four tensors and must return a tuple whose first element exposes
    `.logits` and whose next two elements are compared with `mse_loss`.

    Args:
        model: The module to optimise (called directly; its parameters are
            handed to the optimizer).
        train_dataloader: Iterable of training batches as described above.
        val_dataloader: Validation batches; required when `evaluation` is True.
        epochs: Number of passes over `train_dataloader`.
        evaluation: When True, run `evaluate` every 20th epoch (skipping epoch
            0) and on the last epoch, and keep the best checkpoint by
            validation loss.
        model_dir: Parent directory for the run folder.
        optimizer: 'ADAM' selects Adam; any other value selects SGD with
            momentum 0.9 and weight decay 0.01.
        config: Dict written to `config.json`; must contain 'model_name',
            which names the run folder and the checkpoint files.
        lr: Learning rate for either optimizer. None keeps the previous
            hard-coded value of 1e-3.
            NOTE(review): 1e-3 is much larger than the rates commonly used to
            fine-tune pretrained transformers - consider passing a smaller lr.

    Returns:
        Tuple `(train_losses, train_accs, val_losses, val_accs)`.
        `train_losses` has one entry per batch; `train_accs` is never filled
        by this function and is returned empty.

    Raises:
        ValueError: If `config` lacks 'model_name', or `evaluation` is True
            without a `val_dataloader`.

    Side effects:
        Creates `<model_dir>/<model_name>/`, writes `config.json`, appends to
        `train_log.txt` (and `val_log.txt` when evaluating), and saves the
        whole model with `torch.save`.
    """
    # Fail fast on inputs that would otherwise only blow up later (possibly
    # after hours of training).
    if config is None or 'model_name' not in config:
        raise ValueError("config must be a dict containing at least 'model_name'")
    if evaluation and val_dataloader is None:
        raise ValueError("evaluation=True requires a val_dataloader")

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    min_val_loss = np.inf

    model_name = config['model_name']
    print(f"Start training for Model {model_name}...\n")

    # makedirs also creates a missing parent directory and tolerates reruns.
    model_path = os.path.join(model_dir, model_name)
    os.makedirs(model_path, exist_ok=True)
    write_to_file_in_dir(model_path, 'config.json', config)

    train_log = 'train_log.txt'
    # NOTE(review): the header names a Train_Acc column, but the rows written
    # below only contain epoch and loss.
    write_string_train = f"Epoch, Train_Loss, Train_Acc"
    log_to_file_in_dir(model_path, train_log, write_string_train)

    if evaluation:
        val_log = 'val_log.txt'
        write_string_val = f"Epoch, Val_Loss, Val_Acc"
        log_to_file_in_dir(model_path, val_log, write_string_val)

    # Keep the `optimizer` argument (a string) distinct from the optimizer object.
    learning_rate = 1e-3 if lr is None else lr
    if optimizer == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optim = torch.optim.SGD(model.parameters(), learning_rate, momentum=0.9, weight_decay=0.01)

    # Batches follow the model's device instead of assuming CUDA. The optimizer
    # constructor above has already rejected a model with no parameters.
    device = next(model.parameters()).device

    # NOTE(review): every parameter of `model` is optimised, including any
    # sentiment classifier it wraps - confirm the classifier is not meant to be frozen.
    for epoch_i in range(epochs):
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)
        t0_epoch, t0_batch = time.time(), time.time()

        total_loss, batch_loss, batch_counts, num_batches = 0, 0, 0, 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            num_batches += 1

            article_ids, article_mask, summary_ids, summary_mask = (
                tensor.to(device) for tensor in batch[:4]
            )

            model.zero_grad()

            # Use the `model` argument, not the notebook-level `merge_model`.
            summary_out, *sentiments = model(article_ids, article_mask, summary_ids, summary_mask)

            # NOTE(review): the summary ids are used as targets as-is; whether the
            # decoder input is shifted relative to them depends on the model's
            # forward(), which is not visible here - confirm.
            loss1 = loss_fn(summary_out.logits, summary_ids)
            loss2 = mse_loss(sentiments[0], sentiments[1])

            loss = loss1 + loss2
            # One .item() call per step: each call forces a device synchronisation.
            loss_value = loss.item()
            batch_loss += loss_value
            total_loss += loss_value

            train_losses.append(loss_value)

            write_string_train = f"{epoch_i}, {loss_value}"
            log_to_file_in_dir(model_path, train_log, write_string_train)

            loss.backward()

            optim.step()

            if (step % 100 == 0 and step != 0):
                time_elapsed = time.time() - t0_batch

                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")

                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

                print("-"*70)

        is_eval_epoch = ((epoch_i % 20 == 0) and (epoch_i != 0)) or (epoch_i == epochs - 1)
        if evaluation and is_eval_epoch:
            # NOTE(review): `evaluate` is not defined in the visible notebook;
            # presumably it comes from another cell or a star import - confirm.
            val_loss, val_accuracy = evaluate(model, val_dataloader)
            val_losses.append(val_loss)
            val_accs.append(val_accuracy)

            write_string_val = f"{epoch_i}, {val_loss}, {val_accuracy}"
            log_to_file_in_dir(model_path, val_log, write_string_val)

            time_elapsed = time.time() - t0_epoch

            if val_loss < min_val_loss:
                min_val_loss = val_loss

                torch.save(model, os.path.join(model_path, f'{model_name}.pt'))

            # Fill all six columns of the header, including the epoch's mean train loss.
            avg_train_loss = total_loss / num_batches if num_batches else float('nan')
            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")

    torch.save(model, os.path.join(model_path, f'{model_name}_final.pt'))

    return train_losses, train_accs, val_losses, val_accs

In [None]:
# Launch training of the merged model, with no validation pass.
stats = train(
    merge_model,
    train_dataloader,
    val_dataloader=None,
    epochs=epochs,
    evaluation=False,
    optimizer='ADAM',
    config=config,
)

Start training for Model BART_classifier...

 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   1    |   100   |  12.846584   |     -      |     -     |   50.12  
----------------------------------------------------------------------
   1    |   200   |   9.330114   |     -      |     -     |   49.89  
----------------------------------------------------------------------
   1    |   300   |   8.523876   |     -      |     -     |   49.99  
----------------------------------------------------------------------
   1    |   400   |   8.219645   |     -      |     -     |   50.13  
----------------------------------------------------------------------
   1    |   500   |   8.053046   |     -      |     -     |   50.42  
----------------------------------------------------------------------
   1    |   600   |   7.934429   |     -      |     -     |   50.29  
---------------------------------------

In [None]:
# Unpack the four histories returned by train(), in the order it returns them.
train_losses, train_accs, val_losses,val_accs = stats

In [None]:
# Per-batch training loss across the whole run (x axis = batch index).
plt.plot(range(len(train_losses)), train_losses)

# Testing

# Pull a single batch and move it to the GPU for ad-hoc checks.
# Order follows the TensorDataset: article ids, article mask, summary ids, summary mask.
b1,b2,b3,b4 = next(iter(train_dataloader))
b1 = b1.cuda()
b2 = b2.cuda()
b3 = b3.cuda()
b4 = b4.cuda()

b1.shape, b2.shape, b3.shape, b4.shape

# NOTE(review): z1, z2, z1_dash and z2_dash are not referenced again in the visible notebook.
# .long() turns the Parameter into a plain integer tensor, which cannot carry gradients.
z1 = torch.nn.Parameter(torch.ones_like(b1).float()).long()
z2 = torch.ones_like(b2)

z1_dash = z1 + 5
z2_dash = z2 + 5

# Forward pass through BART alone, given only input ids and the attention mask.
t = summary_model(b1, b2)

t.logits.shape

merge_model.train()

# Forward pass through the merged model; element 0 exposes .logits, and elements
# 1 and 2 are the pair compared by the MSE below.
t_dash = merge_model(b1, b2, b3,b4)

t_dash[0].logits.shape

loss = mse_loss(t_dash[1], t_dash[2])

# Gradient of the MSE term with respect to the summarizer's parameters
# (presumably a check that this loss reaches the summarizer - confirm intent).
# NOTE(review): torch.autograd.grad raises if any listed parameter is not part of
# the graph unless allow_unused=True is passed, and printing every gradient tensor
# produces very large output.
print(torch.autograd.grad(loss,merge_model.summary_model.parameters(), retain_graph=True))