In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

In [2]:
import pandas as pd

# Read the CSV of articles and their highlights into a DataFrame.
# NOTE(review): the file is called test.csv but its rows are used as training
# data further down - confirm that this is the intended split.
df = pd.read_csv("test.csv")

# Print the row count with thousands separators, followed by a blank line.
print('Number of training sentences: {:,}\n'.format(len(df)))

# Show ten randomly chosen rows as the cell output.
df.sample(n=10)

Number of training sentences: 11,490



Unnamed: 0,article,highlights,id
3602,Ben Affleck is apologizing after it came to li...,Ben Affleck has apologized after he demanded i...,3d9c0b6d1e1093ef2918ef6c797c7f9b7b821fb5
4075,With six games to go - five for QPR and seven ...,"Leicester, Burnley and QPR hoping to avoid goi...",48c5580be9e2b8b83881b1a8fd889fa13bb418c3
4041,The BBC has refused to hand over the emails of...,Mother-of-two died following a 10-year battle ...,4822423290fd24a1a8546e7311e4830302eb55ea
2513,These eerie photographs show the inside of an ...,Mysterious history of Grade II listed asylum ...,22ac75d7c650abe94f1abbcf9309556647844136
6860,A man dubbed 'New Zealand's worst ever drink d...,Man set to walk free from jail just three year...,8cff9d18d04ea583c4062ac4003b31679ca7885c
11396,Jonathan Davies has warned Saracens that Clerm...,Clermont Auvergne take on Saracens in Champion...,fd62146fe8c3cf7eb37c6266ea0c81578148dd15
5251,Andres Iniesta has responded to his critics by...,Midfielder fires a warning to his detractors t...,6435a2011b44e57382d9baa8b1f98a2b3e377d7a
8108,A young fisherman in Thailand reeled in the ca...,The young boy holds a tiny blue and yellow to...,ac4c516da1ddbf17dd3e57039bbe83d5cf3a09c8
7563,Shops offering everything for £1 or less have ...,Number of Britain's highest earners going to p...,9ef8de0159d987fba67911dc6595dbcf667b76e2
1215,A devastating fire has caused serious damage t...,Witnesses reported massive plumes of smoke bil...,031b2b600b555f1aa0a31189babea20fea8c3316


In [3]:
# Model inputs come from the `article` column, targets from `highlights`.
train_sentence = df["article"].values
train_target = df["highlights"].values

In [4]:
# Sanity check: both arrays should have the same length (one target per article).
print(train_sentence.shape, train_target.shape)

(11490,) (11490,)


In [5]:
# Keep only the first `num_data_points` article/target pairs.
num_data_points = 100
train_sentence, train_target = (
    train_sentence[:num_data_points],
    train_target[:num_data_points],
)

In [6]:
# Convert both sequences to plain Python lists; these lists are what the
# tokenizer is called with in the cells below.
train_sentence = list(train_sentence)
train_target = list(train_target)

### Tokenize

In [7]:
# !conda install -c huggingface transformers

In [8]:
from transformers import BartTokenizer, BartForConditionalGeneration
# Load the pretrained tokenizer for the 'facebook/bart-base' checkpoint
# (the same checkpoint name the model is loaded from further down).
print('Loading BART tokenizer...')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

Loading BART tokenizer...


In [9]:
# Vocabulary size as reported by the tokenizer.
vocab_size = tokenizer.vocab_size

In [10]:
# Tokenize all articles in a single call, with padding and truncation enabled,
# and get the result back as PyTorch tensors.
article_encoding = tokenizer(
    train_sentence,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

In [11]:
# Unpack the token ids and the matching attention mask for the articles.
article_input_ids, article_attention_mask = (
    article_encoding['input_ids'],
    article_encoding['attention_mask'],
)

In [12]:
# Sanity check: ids and mask must have identical shapes.
print(article_input_ids.shape, article_attention_mask.shape)

torch.Size([100, 1024]) torch.Size([100, 1024])


In [13]:
# Tokenize all target summaries with the same settings used for the articles.
target_encoding = tokenizer(
    train_target,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

In [14]:
# Unpack the token ids and the matching attention mask for the targets.
target_input_ids, target_attention_mask = (
    target_encoding['input_ids'],
    target_encoding['attention_mask'],
)

In [15]:
# Sanity check: ids and mask must have identical shapes.
print(target_input_ids.shape, target_attention_mask.shape)

torch.Size([100, 72]) torch.Size([100, 72])


### Dataloader

In [16]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
# Mini-batch size; the DataLoader built in a later cell reads this value.
# NOTE(review): this comment used to quote the BERT authors' recommendation of
# 16 or 32. This notebook fine-tunes BART and uses a batch size of 4.
batch_size = 4
# Bundle the four aligned tensors (article ids, article mask, target ids,
# target mask) into one dataset; each item is a 4-tuple in that order.
train_data = TensorDataset(article_input_ids, article_attention_mask, target_input_ids, target_attention_mask)

In [17]:
# Sampler that yields the dataset indices in random order.
train_sampler = RandomSampler(train_data)

In [18]:
# Batch the dataset, drawing examples through the random sampler.
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=train_sampler,
)

### Loss

In [19]:
# The loss function
def loss_fn(lm_logits, labels, ignore_index=None):
    """Token-level cross-entropy between decoder logits and target token ids.

    Args:
        lm_logits: Float tensor whose last dimension indexes the vocabulary
            (the training loop passes ``out.logits``).
        labels: Integer tensor of target token ids holding one id per logit
            row, i.e. ``lm_logits.shape[:-1]`` elements in total.
        ignore_index: Label value that is excluded from the loss. Defaults to
            ``tokenizer.pad_token_id`` so padded target positions do not
            contribute.

    Returns:
        Scalar tensor with the mean cross-entropy over non-ignored positions.
    """
    if ignore_index is None:
        ignore_index = tokenizer.pad_token_id
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
    # Take the class count from the logits themselves rather than from the
    # tokenizer-derived global `vocab_size`, so the flattening stays correct
    # even if the model head and the tokenizer disagree on size. `reshape`
    # (unlike `view`) also accepts non-contiguous tensors.
    flat_logits = lm_logits.reshape(-1, lm_logits.size(-1))
    flat_labels = labels.reshape(-1)
    return loss_fct(flat_logits, flat_labels)
    

In [20]:
# Load the pretrained 'facebook/bart-base' conditional-generation model and
# move it to the GPU.
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').cuda()
# Switch to training mode; as the last expression this also displays the model.
model.train()

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
   

In [21]:
# SGD with momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.01,
)

In [22]:
epochs = 500

# Put the model in training mode once; nothing inside the loop changes it.
model.train()

# NOTE(review): the recorded run shows the summed loss jumping from 11.466
# (epoch 21) to 58.834 (epoch 22) and 185.245 (epoch 23). Consider gradient
# clipping or a smaller learning rate if that instability matters.
for eps in range(epochs):
    print('Epoch: ', eps)

    # Sum of the per-batch losses for this epoch. It is kept as a Python float
    # so the accumulator does not hold on to the autograd history of every
    # batch, and so the final print also works when the dataloader is empty.
    epoch_loss = 0.0

    for batch in train_dataloader:

        # Each batch is (article ids, article mask, target ids, target mask);
        # move all four tensors to the GPU.
        input_ids, attention_mask, labels, decoder_attention_mask = (
            tensor.cuda() for tensor in batch
        )

        optimizer.zero_grad()

        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        # Recompute the loss from the logits with the notebook's own loss_fn
        # rather than using the loss the model returns.
        loss = loss_fn(out.logits, labels)

        # .item() detaches the value from the graph before accumulating.
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Report the summed (not averaged) batch loss for the epoch.
    print("Epoch: %d, Loss: %.3f" % (eps, epoch_loss))


Epoch:  0
Epoch: 0, Loss: 86.300
Epoch:  1
Epoch: 1, Loss: 65.039
Epoch:  2
Epoch: 2, Loss: 55.634
Epoch:  3
Epoch: 3, Loss: 50.394
Epoch:  4
Epoch: 4, Loss: 46.293
Epoch:  5
Epoch: 5, Loss: 41.135
Epoch:  6
Epoch: 6, Loss: 38.265
Epoch:  7
Epoch: 7, Loss: 34.882
Epoch:  8
Epoch: 8, Loss: 32.798
Epoch:  9
Epoch: 9, Loss: 29.365
Epoch:  10
Epoch: 10, Loss: 27.291
Epoch:  11
Epoch: 11, Loss: 25.083
Epoch:  12
Epoch: 12, Loss: 23.044
Epoch:  13
Epoch: 13, Loss: 20.923
Epoch:  14
Epoch: 14, Loss: 19.129
Epoch:  15
Epoch: 15, Loss: 17.382
Epoch:  16
Epoch: 16, Loss: 16.303
Epoch:  17
Epoch: 17, Loss: 14.827
Epoch:  18
Epoch: 18, Loss: 14.197
Epoch:  19
Epoch: 19, Loss: 12.953
Epoch:  20
Epoch: 20, Loss: 11.955
Epoch:  21
Epoch: 21, Loss: 11.466
Epoch:  22
Epoch: 22, Loss: 58.834
Epoch:  23
Epoch: 23, Loss: 185.245
Epoch:  24
Epoch: 24, Loss: 170.444
Epoch:  25
Epoch: 25, Loss: 160.739
Epoch:  26
Epoch: 26, Loss: 121.642
Epoch:  27
Epoch: 27, Loss: 65.060
Epoch:  28
Epoch: 28, Loss: 35.696
E

In [24]:
# The training loop leaves the model in train mode; switch to eval mode so
# dropout is disabled while decoding.
model.eval()

# Summarize the first article with 4-beam search, capped at 72 tokens (the
# padded length of the target tensor). The attention mask is passed explicitly
# so the padded positions of the article are not attended to.
first_article_ids = article_input_ids[0].view(1, -1).cuda()
first_article_mask = article_attention_mask[0].view(1, -1).cuda()
summary_ids = model.generate(
    first_article_ids,
    attention_mask=first_article_mask,
    num_beams=4,
    max_length=72,
    early_stopping=True,
)

In [25]:
# Turn every generated id sequence back into text, dropping special tokens and
# leaving the tokenizer's spacing untouched.
generated_text = [
    tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    for ids in summary_ids
]
print(generated_text)

['James Best, who played the sheriff on "The Dukes of Hazzard," died Monday at 88 .\n"Hazzard" ran from 1979 to 1985 and was among the most popular shows on TV .']


In [26]:
# Decode the reference summary of the first example the same way, for
# comparison with the generated text.
reference_ids = target_input_ids[0].view(1, -1)
reference_text = [
    tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    for ids in reference_ids
]
print(reference_text)

['James Best, who played the sheriff on "The Dukes of Hazzard," died Monday at 88 .\n"Hazzard" ran from 1979 to 1985 and was among the most popular shows on TV .']
