In [1]:
import os
import sys
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy

In [2]:
!pip install transformers



In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

In [4]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [5]:
# Batching / sequence-shape settings shared by the dataset and the loaders.
BATCH_SIZE = 4
SEQ_LENGTH = 512  # every encoded sequence is padded or truncated to this many tokens

# Name of the pretrained checkpoint the tokenizer is loaded from.
PRETRAINED_MODEL = 't5-base'
# Folder that holds the train / validation CSV files.
DIR = "/content/drive/My Drive/ml_hw/NLP/question_generator/"

# Load the pretrained vocabulary, then register the two extra marker tokens.
tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
tokenizer.add_special_tokens({'additional_special_tokens': ['<answer>', '<context>']})

class QGDataset(Dataset):
    """Question-generation examples read from a CSV file.

    Each item is a pair ``(encoded_text, encoded_question)`` of tokenizer
    outputs built from the row's ``text`` and ``question`` columns. Both are
    padded or truncated to ``SEQ_LENGTH`` tokens and moved to ``device``.
    """

    def __init__(self, csv):
        """Load the CSV file at path ``csv`` into a DataFrame."""
        self.df = pd.read_csv(csv, engine='python')

    def __len__(self):
        """Return the number of rows (examples) in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Tokenize and return the example at position ``idx``.

        Args:
            idx: Integer row position, or a tensor holding one.

        Returns:
            Tuple ``(encoded_text, encoded_question)``; each is a tokenizer
            output with ``input_ids`` and ``attention_mask`` on ``device``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The first CSV column is skipped; only 'text' and 'question' are used.
        row = self.df.iloc[idx, 1:]

        encoded_text = self._encode(row['text'])
        encoded_question = self._encode(row['question'])
        return (encoded_text.to(device), encoded_question.to(device))

    @staticmethod
    def _encode(text):
        """Tokenize one string to fixed-length tensors without a batch axis."""
        encoded = tokenizer(
            text,
            # `pad_to_max_length=True` is deprecated; this is its replacement.
            padding='max_length',
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch axis of size 1. Drop exactly
        # that axis (dim 0) so the DataLoader can stack items into a batch.
        # The mask is squeezed for both text and question so their shapes agree.
        encoded['input_ids'] = encoded['input_ids'].squeeze(0)
        encoded['attention_mask'] = encoded['attention_mask'].squeeze(0)
        return encoded

# Build both datasets first, then wrap them in loaders.
# Only the training loader shuffles; validation keeps the file order.
train_set = QGDataset(os.path.join(DIR, 'qg_train_3.csv'))
valid_set = QGDataset(os.path.join(DIR, 'qg_valid_3.csv'))
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=BATCH_SIZE)
valid_loader = DataLoader(dataset=valid_set, shuffle=False, batch_size=BATCH_SIZE)

In [7]:
LR = 0.001
EPOCHS = 20
LOG_INTERVAL = 5000  # batches between progress reports / temporary checkpoints

# `from_pretrained` is a classmethod: calling it on `T5ForConditionalGeneration(config)`
# built a throwaway randomly-initialised model and silently discarded `config`.
# Passing the attribute as a keyword argument applies the override to the
# configuration of the pretrained model that is actually returned.
model = T5ForConditionalGeneration.from_pretrained(
    PRETRAINED_MODEL,
    decoder_start_token_id=tokenizer.pad_token_id,
)
config = model.config  # kept so later cells that refer to `config` still work
model.resize_token_embeddings(len(tokenizer)) # to account for new special tokens
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-base and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Derive both checkpoint locations from DIR (as the CSV paths are) so the
# project folder is defined in exactly one place.
# Best-on-validation checkpoint, written at the end of an improving epoch.
SAVED_MODEL_PATH = os.path.join(DIR, "qg_pretrained_t5_model_trained_3.pth")
# Rolling mid-epoch checkpoint, overwritten every LOG_INTERVAL batches by train().
TEMP_SAVE_PATH = os.path.join(DIR, "qg_pretrained_t5_model_trained_3_TEMP.pth")

def train(epoch, best_val_loss):
    """Run one training pass over ``train_loader``, updating ``model`` in place.

    Every ``LOG_INTERVAL`` batches the mean loss since the previous report is
    printed and a checkpoint is written to ``TEMP_SAVE_PATH``.

    Args:
        epoch: Epoch number; used in the log line and stored in the checkpoint.
        best_val_loss: Best validation loss so far; only stored in the
            temporary checkpoint, never modified here.
    """
    model.train()
    running_loss = 0.
    batches_since_log = 0
    for batch_index, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        # Padding positions in the target are replaced by the ignore index so
        # they do not contribute to the loss.
        masked_labels = mask_label_padding(target['input_ids'])
        output = model(
            input_ids=data['input_ids'],
            attention_mask=data['attention_mask'],
            # `lm_labels` is the deprecated (later removed) name of this argument.
            labels=masked_labels
        )
        loss = output[0]
        loss.backward()
        # Clip the global gradient norm to 0.5 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        batches_since_log += 1
        if batch_index % LOG_INTERVAL == 0 and batch_index > 0:
            # Divide by the number of batches actually accumulated: the first
            # window covers batches 0..LOG_INTERVAL, i.e. LOG_INTERVAL + 1 of them.
            cur_loss = running_loss / batches_since_log
            print('| epoch {:3d} | '
                  '{:5d}/{:5d} batches | '
                  'loss {:5.2f}'.format(
                    epoch,
                    batch_index, len(train_loader),
                    cur_loss))
            save(
                TEMP_SAVE_PATH,
                epoch,
                model.state_dict(),
                optimizer.state_dict(),
                best_val_loss
            )
            running_loss = 0.
            batches_since_log = 0

def evaluate(eval_model, data_loader):
    """Return the mean per-batch loss of ``eval_model`` over ``data_loader``.

    The model is switched to eval mode and no gradients are recorded.

    Args:
        eval_model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; the first element of its output is taken as the loss.
        data_loader: Iterable of ``(data, target)`` batches.

    Returns:
        Sum of the batch losses divided by ``len(data_loader)``.
    """
    eval_model.eval()
    total_loss = 0.
    with torch.no_grad():
        for data, target in data_loader:
            masked_labels = mask_label_padding(target['input_ids'])
            output = eval_model(
                input_ids=data['input_ids'],
                attention_mask=data['attention_mask'],
                # `lm_labels` is the deprecated (later removed) name of this argument.
                labels=masked_labels
            )
            total_loss += output[0].item()
    return total_loss / len(data_loader)

def generate_attention_mask(input_ids):
    """Build a 0/1 mask from token ids: 0 where the id is the pad token, 1 elsewhere.

    The comparison is done on a CPU copy of ``input_ids``; the resulting mask
    has the same shape as ``input_ids`` and is moved to ``device``.
    """
    shape = input_ids.size()
    pad_ids = torch.full(shape, tokenizer.pad_token_id, dtype=int)
    is_padding = input_ids.cpu() == pad_ids
    masked_value = torch.zeros(shape)
    visible_value = torch.ones(shape)
    attention = torch.where(is_padding, masked_value, visible_value)
    return attention.to(device)

def mask_label_padding(labels, pad_token_id=None):
    """Return a copy of ``labels`` with padding ids replaced by -100.

    Unlike an in-place assignment, the tensor passed in is left untouched, so
    the caller can still decode or reuse the original token ids afterwards.

    Args:
        labels: Integer tensor of token ids.
        pad_token_id: Id to treat as padding. Defaults to
            ``tokenizer.pad_token_id`` when omitted.

    Returns:
        A new tensor, same shape as ``labels``, with -100 at padding positions.
    """
    MASK_ID = -100
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    return labels.masked_fill(labels == pad_token_id, MASK_ID)

def save(path, epoch, model_state_dict, optimizer_state_dict, loss):
    """Write a training checkpoint to ``path`` and print a confirmation.

    The checkpoint is a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``best_loss`` (the latter holds ``loss``).
    """
    checkpoint = {}
    checkpoint['epoch'] = epoch
    checkpoint['model_state_dict'] = model_state_dict
    checkpoint['optimizer_state_dict'] = optimizer_state_dict
    checkpoint['best_loss'] = loss
    torch.save(checkpoint, path)

    print("| Model saved.")
    print_line()

def load(path, map_location=None):
    """Load a checkpoint written by ``save`` and return its contents.

    Args:
        path: File to read.
        map_location: Where to place the stored tensors. Defaults to the
            notebook's ``device`` so that a checkpoint saved on a GPU can
            still be opened on a CPU-only runtime.

    Returns:
        Whatever object was stored at ``path``.
    """
    if map_location is None:
        map_location = device
    # NOTE: torch.load unpickles the file -- only open checkpoints you trust.
    return torch.load(path, map_location=map_location)

def print_line():
    """Print a horizontal rule of 60 dashes, used to separate log sections."""
    rule_width = 60
    print(''.ljust(rule_width, '-'))

In [None]:
best_val_loss = float("inf")
best_model = None

# Baseline validation loss before any fine-tuning.
val_loss = evaluate(model, valid_loader)
print_line()
print('| Before training | valid loss {:5.2f}'.format(
    val_loss)
)
print_line()

for epoch in range(1, EPOCHS + 1):

    # train() requires the epoch number and the best validation loss so far;
    # it stores both in the mid-epoch checkpoints it writes.
    train(epoch, best_val_loss)
    val_loss = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | valid loss {:5.2f}'.format(
        epoch,
        val_loss)
    )
    print_line()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # NOTE: this is a reference to the live model, not a snapshot of its
        # weights; the best weights are the ones written to SAVED_MODEL_PATH.
        best_model = model
        save(
             SAVED_MODEL_PATH,
             epoch,
             model.state_dict(),
             optimizer.state_dict(),
             best_val_loss
        )

------------------------------------------------------------
| Before training | valid loss  5.68
------------------------------------------------------------
| epoch   1 |  1000/51003 batches | loss  4.34
| epoch   1 |  2000/51003 batches | loss  3.61
| epoch   1 |  3000/51003 batches | loss  3.52
| epoch   1 |  4000/51003 batches | loss  3.41
| epoch   1 |  5000/51003 batches | loss  3.38
| epoch   1 |  6000/51003 batches | loss  3.33
| epoch   1 |  7000/51003 batches | loss  3.32
| epoch   1 |  8000/51003 batches | loss  3.26
| epoch   1 |  9000/51003 batches | loss  3.20
| epoch   1 | 10000/51003 batches | loss  3.18
| epoch   1 | 11000/51003 batches | loss  3.13
| epoch   1 | 12000/51003 batches | loss  3.09
| epoch   1 | 13000/51003 batches | loss  3.08
| epoch   1 | 14000/51003 batches | loss  3.03
| epoch   1 | 15000/51003 batches | loss  3.01
| epoch   1 | 16000/51003 batches | loss  2.97
| epoch   1 | 17000/51003 batches | loss  2.96
| epoch   1 | 18000/51003 batches | loss  

In [None]:
# let's re-load the model and continue training.
# it timed out mid-epoch so we're going to save more regularly

# Restore model, optimizer and bookkeeping from the best end-of-epoch checkpoint.
checkpoint = load(SAVED_MODEL_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
current_epoch = checkpoint['epoch']
best_val_loss = checkpoint['best_loss']
model.to(device)

for epoch in range(current_epoch + 1, EPOCHS + 1):

    train(epoch, best_val_loss)
    val_loss = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | valid loss {:5.2f}'.format(
        epoch,
        val_loss)
    )
    print_line()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # save() takes the destination path as its first argument.
        save(
            SAVED_MODEL_PATH,
            epoch,
            model.state_dict(),
            optimizer.state_dict(),
            best_val_loss
        )

| epoch   2 |  5000/51003 batches | loss  2.48
| Model saved.
------------------------------------------------------------
| epoch   2 | 10000/51003 batches | loss  2.47
| Model saved.
------------------------------------------------------------


In [None]:
#let's re-load the model and continue training AGAIN.

# Resume from the best end-of-epoch checkpoint. The mid-epoch file at
# TEMP_SAVE_PATH records the epoch that was still in progress, so starting at
# `current_epoch + 1` from it would skip the rest of that epoch.
checkpoint = load(SAVED_MODEL_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
current_epoch = checkpoint['epoch']
best_val_loss = checkpoint['best_loss']
model.to(device)

# Despite its name this is the number of the last epoch to run, not a count.
additional_epochs = EPOCHS + 5

for epoch in range(current_epoch + 1, additional_epochs + 1):

    train(epoch, best_val_loss)
    val_loss = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | valid loss {:5.2f}'.format(
        epoch,
        val_loss)
    )
    print_line()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save(
            SAVED_MODEL_PATH,
            epoch,
            model.state_dict(),
            optimizer.state_dict(),
            best_val_loss
        )

In [None]:
#let's re-load the model and continue training AGAIN AGAIN

# Resume from the best end-of-epoch checkpoint. The mid-epoch file at
# TEMP_SAVE_PATH records the epoch that was still in progress, so starting at
# `current_epoch + 1` from it would skip the rest of that epoch.
checkpoint = load(SAVED_MODEL_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
current_epoch = checkpoint['epoch']
best_val_loss = checkpoint['best_loss']
model.to(device)

# Despite its name this is the number of the last epoch to run, not a count.
additional_epochs = EPOCHS + 5

for epoch in range(current_epoch + 1, additional_epochs + 1):

    train(epoch, best_val_loss)
    val_loss = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | valid loss {:5.2f}'.format(
        epoch,
        val_loss)
    )
    print_line()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save(
            SAVED_MODEL_PATH,
            epoch,
            model.state_dict(),
            optimizer.state_dict(),
            best_val_loss
        )

In [None]:
#let's re-load the model and continue training AGAIN AGAIN AGAIN (last time I promise)

# Resume from the best end-of-epoch checkpoint. The mid-epoch file at
# TEMP_SAVE_PATH records the epoch that was still in progress, so starting at
# `current_epoch + 1` from it would skip the rest of that epoch.
checkpoint = load(SAVED_MODEL_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
current_epoch = checkpoint['epoch']
best_val_loss = checkpoint['best_loss']
model.to(device)

# Despite its name this is the number of the last epoch to run, not a count.
additional_epochs = EPOCHS + 10

for epoch in range(current_epoch + 1, additional_epochs + 1):

    train(epoch, best_val_loss)
    val_loss = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | valid loss {:5.2f}'.format(
        epoch,
        val_loss)
    )
    print_line()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save(
            SAVED_MODEL_PATH,
            epoch,
            model.state_dict(),
            optimizer.state_dict(),
            best_val_loss
        )

In [None]:
# Load the best checkpoint and print generated questions for a few
# validation batches, each followed by the (padded) input it came from.
checkpoint = load(SAVED_MODEL_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

NUM_PREVIEW_BATCHES = 10

model.eval()
with torch.no_grad():
    for batch_index, batch in enumerate(valid_loader):
        data, _ = batch
        output = model.generate(
            input_ids=data['input_ids'],
            # Inputs are padded to SEQ_LENGTH; pass the mask explicitly so the
            # padding positions are not attended to.
            attention_mask=data['attention_mask']
        )
        # Only the first example of each batch is shown.
        print(tokenizer.decode(output[0]))
        print(tokenizer.decode(data['input_ids'][0]))
        if batch_index + 1 >= NUM_PREVIEW_BATCHES:
            break

In [None]:
# Load the best checkpoint and generate a question for a few hand-written
# '<answer> ... <context> ...' inputs.
checkpoint = load(SAVED_MODEL_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Make this cell independent of earlier ones: disable dropout before generating.
model.eval()

inputs = [
      '<answer> The US virus death toll has surpassed 50,000 <context> The US virus death toll has surpassed 50,000, according to data from Johns Hopkins University.',
      '<answer> in the last 24 hours <context> More than 3,000 deaths came in the last 24 hours, and there are now over 870,000 confirmed cases nationwide.',
      '<answer> The US <context> The US still has a lower mortality rate than most European nations based on current case counts.'
]

for input_text in inputs:
    encoded_input = tokenizer(input_text, return_tensors='pt').to(device)
    output = model.generate(
        input_ids=encoded_input['input_ids'],
        attention_mask=encoded_input['attention_mask']
    )
    print(tokenizer.decode(output[0]))