In [1]:
import random
import time
import numpy as np
import torch
import torch.autograd
from torch.autograd import Variable
import pandas as pd
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from IPython import get_ipython
get_ipython().run_line_magic("load_ext", "autoreload")
get_ipython().run_line_magic("autoreload", "2")

from datasets import load_dataset
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [2]:
# Load the "imdb" dataset through Hugging Face `datasets`; the cells below read its 'train' and 'test' splits.
dataset = load_dataset("imdb")

Reusing dataset imdb (/tmp/xdg-cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)


In [3]:
# Training split: the 'text' column as-is, and the 'label' column converted to a tensor.
train_batch = dataset['train']['text']
train_labels = torch.tensor(dataset['train']['label'])

In [4]:
# The 'test' split is used as validation data: the 'text' column as-is, 'label' as a tensor.
val_batch = dataset['test']['text']
val_labels = torch.tensor(dataset['test']['label'])

In [5]:
# Restrict the validation set to its first 10,000 examples.
# NOTE(review): if the IMDB test split is stored grouped by label, a leading slice
# contains only one class -- confirm, and draw a random subset instead if so.
val_batch = val_batch[:10000]
# val_labels is already a tensor: slice and copy it directly rather than re-wrapping it
# in torch.tensor(), which emits PyTorch's copy-construct UserWarning.
val_labels = val_labels[:10000].clone().detach()

  


In [6]:
from transformers import  BartTokenizer

In [7]:
# Load the pretrained BART tokenizer from the 'facebook/bart-base' checkpoint.
print('Loading BART tokenizer...')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

Loading BART tokenizer...


In [8]:
# Tokenizer vocabulary size; used later as the one-hot width in train() and as the
# first positional argument of ClassifierModel.
vocab_size = tokenizer.vocab_size

In [None]:
# Maximum number of tokens kept per review; also passed to ClassifierModel further down.
seq_length = 150

# Encode the training texts as PyTorch tensors, truncated to `seq_length` tokens and padded.
input_encoding = tokenizer(
    train_batch,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=seq_length,
)

In [None]:
# Encode the validation texts with the same tokenizer settings as the training texts.
val_encoding = tokenizer(
    val_batch,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=seq_length,
)

In [None]:
# Pull the token-id and attention-mask tensors out of the training encoding.
input_ids = input_encoding['input_ids']
input_mask = input_encoding['attention_mask']

In [None]:
# Pull the token-id and attention-mask tensors out of the validation encoding.
val_ids = val_encoding['input_ids']
val_mask = val_encoding['attention_mask']

In [None]:
# Sanity check: display the shapes of the encoded training and validation tensors.
input_ids.shape, input_mask.shape, val_ids.shape, val_mask.shape

In [None]:
# Creating DataLoaders

# TRAINING DATALOADER
# batch_size is also read by train() (to skip batches of a different size) and is
# passed to ClassifierModel when the model is built.
batch_size = 80

# Each dataset item is the tuple (token ids, attention mask, label).
train_data = TensorDataset(input_ids, input_mask, train_labels)
# RandomSampler draws the training examples in random order.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [None]:
# VALIDATION DATALOADER
# Each dataset item is the tuple (token ids, attention mask, label).

val_data = TensorDataset(val_ids, val_mask, val_labels)
# Evaluation gains nothing from shuffling; a sequential sampler gives a fixed,
# repeatable order so per-example results can be compared between runs.
val_sampler = SequentialSampler(val_data)
# NOTE(review): batch_size=1 here, while the model is constructed with
# batch_size=batch_size -- confirm the model accepts single-example batches.
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=1)

In [None]:
# Fetch a single batch from the training loader for inspection.
b = next(iter(train_dataloader))

In [None]:
# Display the shapes of the batch's three tensors (token ids, attention mask, labels).
b[0].shape, b[1].shape, b[2].shape

In [None]:
from ClassifierModel import ClassifierModel

In [None]:
# Build the classifier and display its module summary.
# NOTE(review): the positional arguments 64, 2, 2, 512 are not explained in this
# notebook -- confirm their meaning against ClassifierModel's signature.
clf = ClassifierModel(vocab_size, 64, 2, 2, 512, batch_size = batch_size, seq_length=seq_length)
clf

In [None]:
# Cross-entropy loss; train() applies it to the model's logits and the batch labels.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
def set_seed(seed_value=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed_value: seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as a plain sequence of calls: random, numpy, torch, torch.cuda.
    for seed_rng in seeders:
        seed_rng(seed_value)

In [None]:
# Move the model to the device chosen in the first cell (cuda:0 when available, else CPU)
# instead of a hard-coded GPU index, so the notebook also runs on CPU-only machines.
clf.to(device)

In [None]:
def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
    """Train `model` with Adam on one-hot encoded token ids, then save it to disk.

    Every batch is expected to be a tuple whose element 0 holds the token ids and
    element 2 the labels (element 1 is not used here). Token ids are expanded to
    one-hot vectors of width `vocab_size` before being passed to the model.
    Batches whose size differs from the notebook-level `batch_size` are skipped.

    Relies on the notebook-level names `loss_fn`, `vocab_size` and `batch_size`,
    and -- only when `evaluation` is true -- on an `evaluate(model, dataloader)`
    callable returning `(val_loss, val_accuracy)`, which must be defined before
    this function is called that way.

    Args:
        model: torch.nn.Module that maps a one-hot batch to logits.
        train_dataloader: training batches; must support `len()`.
        val_dataloader: validation batches; required when `evaluation` is true.
        epochs: number of passes over `train_dataloader`.
        evaluation: when true, call `evaluate` after every epoch and print a summary row.

    Returns:
        list[float]: loss of every batch that was trained on, in order.

    Raises:
        ValueError: if `evaluation` is true but `val_dataloader` is None.
    """
    if evaluation and val_dataloader is None:
        raise ValueError("evaluation=True requires a val_dataloader")

    print("Start training...\n")
    train_loss = []

    # Run on whatever device the model already lives on rather than a hard-coded GPU index.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # A single optimizer for the whole run: re-creating Adam at every epoch would
    # throw away its running moment estimates.
    optimizer = torch.optim.Adam(model.parameters())
    last_step = len(train_dataloader) - 1

    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)

        t0_epoch, t0_batch = time.time(), time.time()

        # total_loss / trained_batches cover the epoch; batch_loss / batch_counts
        # cover only the window since the last progress row.
        total_loss, batch_loss, batch_counts, trained_batches = 0, 0, 0, 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            # Skip incomplete batches *before* building the (very large) one-hot
            # tensor and before counting them, so they cost nothing and do not
            # distort the averages.
            # NOTE(review): presumably needed because the model is built for a
            # fixed batch size -- confirm against ClassifierModel.
            if batch[0].shape[0] == batch_size:
                b_input_ids, b_labels = batch[0].to(run_device), batch[2].to(run_device)
                model.zero_grad()
                b_input_ids = torch.nn.functional.one_hot(b_input_ids, num_classes=vocab_size)

                logits = model(b_input_ids)
                loss = loss_fn(logits, b_labels)
                loss_value = loss.item()

                batch_loss += loss_value
                total_loss += loss_value
                train_loss.append(loss_value)
                batch_counts += 1
                trained_batches += 1

                loss.backward()
                optimizer.step()

            # Progress row every 20 steps and at the last step. This check sits
            # outside the training branch so the final row is still printed when
            # the last batch was skipped.
            report_due = (step % 20 == 0 and step != 0) or (step == last_step)
            if report_due and batch_counts > 0:
                time_elapsed = time.time() - t0_batch

                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")

                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

        # Average over the batches that were actually trained on (skipped ones excluded).
        avg_train_loss = total_loss / max(trained_batches, 1)

        print("-"*70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation:
            # After each training epoch, measure performance on the validation set.
            val_loss, val_accuracy = evaluate(model, val_dataloader)

            # Summary row: epoch-average training loss next to the validation metrics.
            time_elapsed = time.time() - t0_epoch

            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*70)
        print("\n")

    print("Training complete!")
    # Saves the whole module object to a fixed file name in the working directory.
    torch.save(model, 'classifier_model.pt')
    return train_loss

In [None]:
# Keep the per-batch losses under a name: a bare call as the cell's last expression
# would dump the whole list as cell output and leave it reachable only through `_`.
train_loss_history = train(clf, train_dataloader, epochs=50)

In [None]:
# def evaluate(model, val_dataloader):
#     """After the completion of each training epoch, measure the model's performance
#     on our validation set.
#     """
#     model.eval()

#     val_accuracy = []
#     val_loss = []

#     for batch in val_dataloader:
        
#         b_input_ids, b_labels = batch[0].to(0), batch[2].to(0)
#         b_input_ids = torch.nn.functional.one_hot(b_input_ids, num_classes=vocab_size)
# #         if not b_input_ids.shape[0] == 1:
# #             continue

#         # Compute logits
#         with torch.no_grad():
#             logits = model(b_input_ids)

#         # Compute loss
#         loss = loss_fn(logits, b_labels)
#         val_loss.append(loss.item())
#         print(loss)
#         # Get the predictions
#         preds = torch.argmax(logits, dim=1).flatten()

#         # Calculate the accuracy rate
#         accuracy = (preds == b_labels).cpu().numpy().mean() * 100
#         val_accuracy.append(accuracy)

#     # Compute the average accuracy and loss over the validation set.
#     val_loss = np.mean(val_loss)
#     val_accuracy = np.mean(val_accuracy)

#     return val_loss, val_accuracy

In [None]:
# evaluate(clf, val_dataloader)

# ---------- Testing differentiability ----------------
# Checks that a gradient reaches the parameter `a` even though the classifier is
# fed a hard one-hot selection. The forward value of `mask` is the one-hot of
# argmax(c); the `+ c - c.detach()` term adds zero to that value but keeps `c`
# (and therefore `a`) in the autograd graph.

# Create the test tensors on the model's device so clf(mask) does not hit a device mismatch.
clf_device = next(clf.parameters()).device

# NOTE(review): clf was constructed with batch_size=batch_size and
# seq_length=seq_length; confirm it accepts this (4, 47, vocab_size) input.
b_dash = torch.randn((4, 47, vocab_size), device=clf_device)

# Float parameter of ones with the same shape as b_dash.
a = torch.nn.Parameter(torch.ones_like(b_dash))

c = 2 * a

# Index of the selected vocabulary entry at every position (last dimension kept for scatter_).
idx = torch.argmax(c, dim=-1, keepdim=True)

mask = torch.zeros_like(c).scatter_(-1, idx, 1.).float().detach() + c - c.detach()

o1 = clf(mask)

mse = torch.nn.MSELoss()

z = mse(o1, torch.zeros_like(o1))

# Raises if `a` is not connected to `z`; otherwise displays the gradient tuple.
torch.autograd.grad(z, a)

# ---------------------- Fin ----------------------------