In [3]:
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup, LRScheduler
from ignite.handlers import EarlyStopping, Checkpoint, DiskSaver, global_step_from_engine
from ignite.engine.events import EventEnum
from tqdm.notebook import tqdm
import pandas as pd
import matplotlib.pyplot as plt

from models.mgru import mGRU

In [3]:
# Select the compute device once; later cells reuse `use_cuda` and `device`.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Ask cuDNN to benchmark kernels and pick the fastest one for the shapes it sees.
torch.backends.cudnn.benchmark = True

## Load BERT

In [4]:
# Pretrained BERT is used purely as a frozen feature extractor: its weights
# never receive gradients, only the downstream mGRU is trained.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")
model.base_model.requires_grad_(False)

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [5]:
# Move the BERT encoder onto the device chosen earlier (GPU when available, else CPU).
model = model.to(device)

## Load data

In [6]:
# Input CSV and the two text columns that get tokenized downstream.
SNLI_CSV = '/kaggle/input/stanford-natural-language-inference-corpus/snli_1.0_test.csv'
TEXT_COLS = ["sentence1", "sentence2"]

df = pd.read_csv(SNLI_CSV, usecols=["gold_label"] + TEXT_COLS)
# Keep only rows whose label is one of the three classes mapped by
# Dataset_from_encoding ("-" is not among them), drop rows with a missing
# sentence (a NaN would make the tokenizer raise later), and take an explicit
# copy so the assignment below writes to an independent frame rather than a
# slice of the raw one.
df = df.loc[df["gold_label"] != "-"].dropna(subset=TEXT_COLS).copy()
# Lower-case the text to match the uncased BERT vocabulary.
df.loc[:, TEXT_COLS] = df.loc[:, TEXT_COLS].apply(lambda col: col.str.lower())

In [7]:
class Dataset_from_encoding(Dataset):
    """Paired premise/hypothesis dataset built on pre-computed tokenizer encodings.

    `p_encodings` and `h_encodings` are mappings from field name to a tensor
    indexed by example; `labels` is a sequence of label strings of the same
    length.
    """
    def __init__(self, p_encodings, h_encodings, labels):
        self.p_encodings = p_encodings
        self.h_encodings = h_encodings
        self.labels = labels

    def __len__(self):
        # One example per label.
        return len(self.labels)

    @staticmethod
    def _row(encodings, idx):
        """Return detached copies of every encoding field for example `idx`."""
        return {name: tensor[idx].detach().clone() for name, tensor in encodings.items()}

    def __getitem__(self, idx):
        """Return a dict with keys "p", "h" (field dicts) and "labels" (class-id tensor)."""
        premise = self._row(self.p_encodings, idx)
        hypothesis = self._row(self.h_encodings, idx)
        class_id = self._get_label(self.labels[idx])
        return {"p": premise, "h": hypothesis, "labels": torch.tensor(class_id)}

    def _get_label(self, x):
        """Map a label string to its integer class id; unknown labels raise KeyError."""
        class_ids = dict(contradiction=0, neutral=1, entailment=2)
        return class_ids[x]

In [8]:
def _encode_sentences(sentences, max_length=64):
    """Tokenize a list of sentences into padded/truncated tensors placed on `device`."""
    return tokenizer(sentences,
                     return_tensors="pt",
                     max_length=max_length,
                     truncation=True,
                     padding=True).to(device)


def get_train_test(df, test_size=0.2, max_length=64, random_state=None):
    """Train/test split & create Dataset

    Args:
        df: frame with "sentence1", "sentence2" and "gold_label" columns.
        test_size: fraction of rows held out for the test split.
        max_length: token length at which sentences are truncated.
        random_state: seed forwarded to `train_test_split`; pass an int for a
            reproducible split (the default None keeps the shuffle random).

    Returns:
        (train_ds, test_ds): two `Dataset_from_encoding` instances.
    """
    train, test = train_test_split(df, test_size=test_size, shuffle=True,
                                   random_state=random_state)

    def build(split):
        # Premises (sentence1) and hypotheses (sentence2) are encoded separately.
        p_encodings = _encode_sentences(split.sentence1.tolist(), max_length)
        h_encodings = _encode_sentences(split.sentence2.tolist(), max_length)
        return Dataset_from_encoding(p_encodings, h_encodings, split["gold_label"].tolist())

    train_ds = build(train)
    test_ds = build(test)

    return train_ds, test_ds

In [9]:
# Split the cleaned frame and tokenize both splits (default 80/20 split).
train_ds, test_ds = get_train_test(df)

In [10]:
# Mini-batches of 32 examples; only the training loader reshuffles every epoch.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)

## Model

In [2]:
# Configuration handed to the project-local mGRU; the meaning of each key is
# defined in models/mgru.py, which is not shown here.
# NOTE(review): EMBED_DIM is 300 while bert-base-uncased emits 768-dim hidden
# states — confirm mGRU handles that input size.
options = {
    "EMBED_DIM": 300,
    "HIDDEN_DIM": 150,
    "CLASSES": 3,
    "DROPOUT": 0.2,
    "LAST_NON_LINEAR": True,
    "CUDA": use_cuda,
}
mgru = mGRU(options)

In [86]:
# Place the loss on the selected device instead of calling .cuda()
# unconditionally, which raises on CPU-only machines even though `device`
# already falls back to CPU.
criterion = torch.nn.CrossEntropyLoss().to(device)
# Only the mGRU parameters are optimized (BERT is frozen). torch.optim.AdamW
# replaces the deprecated transformers.AdamW; weight_decay=0.0 matches that
# class's default so no weight decay is introduced.
# NOTE: the name `optim` shadows the `torch.optim` module imported at the top
# of the notebook; it is kept because later cells refer to the optimizer by
# this name.
optim = torch.optim.AdamW(mgru.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)

In [87]:
# Linear warmup (10000 steps) followed by linear decay over len(train_dl)*10 total steps.
# NOTE(review): every scheduler.step() call in this notebook is commented out,
# so this schedule currently has no effect on the learning rate.
# NOTE(review): the 10000 warmup steps may exceed the total step count
# len(train_dl)*10, depending on dataset size — confirm before enabling it.
scheduler = transformers.get_linear_schedule_with_warmup(optim, 10000, len(train_dl)*10)

## Training

In [None]:
# Maximum number of passes over train_dl handed to trainer.run().
EPOCHS = 10

In [88]:
def train_step(engine, batch):
    """Run one optimization step of the mGRU on a single batch.

    Args:
        engine: the ignite Engine driving training (unused here).
        batch: dict with "p" / "h" tokenizer encodings and "labels".

    Returns:
        The batch loss as a Python float (stored in `engine.state.output`).
    """
    mgru.train()
    optim.zero_grad()
    y = batch["labels"].to(device)
    # BERT is frozen, so its forward pass needs no autograd graph.
    with torch.no_grad():
        # last_hidden_state is (batch, seq, hidden); permute to (seq, batch, hidden).
        p_encode = model(**batch["p"])["last_hidden_state"].permute(1, 0, 2)
        h_encode = model(**batch["h"])["last_hidden_state"].permute(1, 0, 2)
    y_pred = mgru(p_encode, h_encode, training=True)
    loss = criterion(y_pred, y)
    loss.backward()
    # Clip the gradients of the network actually being trained. Clipping
    # `model` (the frozen BERT) had no effect on the mGRU update.
    torch.nn.utils.clip_grad_norm_(mgru.parameters(), max_norm=1)
    optim.step()
    # scheduler.step()  # learning-rate scheduling is currently disabled
    return loss.item()


def validation_step(engine, batch):
    """Compute mGRU predictions for one evaluation batch.

    Args:
        engine: the ignite Engine driving evaluation (unused here).
        batch: dict with "p" / "h" tokenizer encodings and "labels".

    Returns:
        (y_pred, y): model outputs and target labels, the pair consumed by
        the attached ignite metrics.
    """
    model.eval()
    # Put the classifier itself in eval mode too; train_step switches it back
    # with mgru.train() at the start of every training iteration.
    mgru.eval()
    y = batch["labels"].to(device)
    # No gradients are needed for evaluation; skipping the graph saves memory.
    with torch.no_grad():
        p_encode = model(**batch["p"])["last_hidden_state"].permute(1, 0, 2)
        h_encode = model(**batch["h"])["last_hidden_state"].permute(1, 0, 2)
        y_pred = mgru(p_encode, h_encode, training=False)
    return y_pred, y
    

def score_function(engine):
    """Return the 'accuracy' metric from the engine's state (higher is better)."""
    computed_metrics = engine.state.metrics
    return computed_metrics['accuracy']

In [None]:
# The progress bar is refreshed once every `log_interval` training iterations.
log_interval = 10
# One bar per epoch: its total is the number of batches in train_dl, and it is
# reset by hand at the end of each epoch (see log_validation_results).
pbar = tqdm(initial=0, leave=False, total=len(train_dl), desc=f"ITERATION - loss: {0:.2f}")

# Engine that calls train_step on every batch.
trainer = Engine(train_step)

# Metrics computed over the (y_pred, y) pairs returned by validation_step.
val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion)
}
# Separate engine for evaluation; each metric is exposed under its dict key in
# evaluator.state.metrics.
evaluator = Engine(validation_step)
for name, metric in val_metrics.items():
    metric.attach(evaluator, name)

# Early stopping is driven by validation accuracy (score_function) with
# patience=5; it is checked each time an evaluator run completes and acts on `trainer`.
handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    """Show the latest batch loss on the progress bar and advance it by `log_interval`."""
    latest_loss = engine.state.output
    pbar.desc = f"ITERATION - loss: {latest_loss:.2f}"
    pbar.update(log_interval)

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    """Evaluate on test_dl after each epoch, report the metrics, and reset the bar."""
    evaluator.run(test_dl)
    results = evaluator.state.metrics
    report = "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
        trainer.state.epoch, results["accuracy"], results["loss"])
    tqdm.write(report)

    # Rewind the progress bar so the next epoch starts counting from zero.
    pbar.n = 0
    pbar.last_print_n = 0

# @evaluator.on(Events.EPOCH_COMPLETED)
# def reduct_step(engine):
#     scheduler.step()

@trainer.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def log_time(engine):
    """Report the time recorded by the trainer for the event that just fired."""
    event_name = trainer.last_event_name.name
    elapsed = trainer.state.times[event_name]
    tqdm.write(f"{event_name} took {elapsed} seconds")

# Train for at most EPOCHS epochs (early stopping may end the run sooner).
# The progress bar is closed even when training raises or is interrupted, so
# no stale widget is left behind in the notebook.
try:
    trainer.run(train_dl, max_epochs=EPOCHS)
finally:
    pbar.close()

## Plot gradient flow

In [None]:
def plot_grad_flow():
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Reads the `.grad` of every trainable, non-bias parameter of `mgru`, so it
    must be called after at least one backward pass. Parameters without a
    gradient are skipped. Bars show the mean and max of the absolute gradient
    per parameter; magnitudes are used because a signed mean can cancel to
    roughly zero even when gradients are large.
    '''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in mgru.named_parameters():
        if not p.requires_grad or "bias" in n or p.grad is None:
            continue
        grad_abs = p.grad.detach().abs()
        layers.append(n)
        # .item() yields plain floats, so matplotlib never sees (CUDA) tensors.
        ave_grads.append(grad_abs.mean().item())
        max_grads.append(grad_abs.max().item())

    if not layers:
        print("No gradients available - run a backward pass before plotting.")
        return

    positions = range(len(layers))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(positions, layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    # Zoom in on the small-gradient region; larger bars are clipped by this limit.
    plt.ylim(bottom = -0.001, top=0.02)
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.show()

# Inspect the gradients left on mgru's parameters by the most recent backward pass.
plot_grad_flow()