## Loading models

In [None]:
import torch

from modeling.pretrained_bert import PretrainedBertModule
from modeling.lstm import LSTMModule
from trainer import load_model

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint files of the trained models.
bert_model_path = "mqnli_models/bert-easy-best.pt"
lstm_model_path = "mqnli_models/lstm-easy-best.pt"

# load_model returns a tuple (<model>, <checkpoint>). The checkpoint only holds
# information and configurations for training, which is not very relevant in
# our experiments, so it is discarded into `_`.

# Bert: load, move to the target device, then switch to eval mode (disables dropout).
bert_model, _ = load_model(PretrainedBertModule, bert_model_path, device=device)
bert_model = bert_model.to(device)
bert_model.eval()

# LSTM: same three steps.
lstm_model, _ = load_model(LSTMModule, lstm_model_path, device=device)
lstm_model = lstm_model.to(device)
lstm_model.eval()

## Loading data (Bert)

In [None]:
import torch
# Location of the pytorch checkpoint that stores the preprocessed data.
bert_easy_data_path = "mqnli_models/bert-preprocessed-data.pt"

# Deserialize the data object.
# NOTE(review): torch.load unpickles Python objects, so only point it at trusted
# files. Newer PyTorch releases may additionally need `weights_only=False` to
# load a non-tensor object like this one -- confirm against the installed version.
bert_data = torch.load(bert_easy_data_path)
print(type(bert_data)) # datasets.mqnli.MQNLIBertData

# An MQNLIBertData object bundles the train, dev, and test sets together with
# other tools and data structures for tokenization.
bert_train_set, bert_dev_set, bert_test_set = (
    bert_data.train,
    bert_data.dev,
    bert_data.test,
)

print(type(bert_dev_set)) # datasets.mqnli.MQNLIBertDataset

# Note: MQNLIBertDataset (not MQNLIBertData) inherits from torch.utils.data.Dataset
# and uses pytorch's Dataset interface.


## Iterating through examples (Bert)

In [None]:
# run the model on the first ten batches (16 examples each) of the dev set
from torch.utils.data import DataLoader

# Serve the dev set in its stored order, 16 examples per batch.
dataloader = DataLoader(bert_dev_set, shuffle=False, batch_size=16)


# Each item yielded by the dataloader is one batch of examples, packaged as a
# tuple/list that is passed to the model as a whole. Its layout is:
# (
#   input_ids,       # numeric token ids for the sentence itself. shape=(batch_size, 27)
#   token_type_ids,  # For bert: a 0/1 tag per token marking premise (0) or hypothesis (1). shape=(batch_size, 27)
#   attention_masks, # For bert: a 0./1. float per token; 0. marks a padding token that
#                    # shouldn't be attended to. shape=(batch_size, 27)
#   original_input,  # numeric token ids for the input, but without the [CLS] and [SEP] tokens
#                    # and without breaking up doesnot, notevery into two words. shape=(batch_size, 18)
#   label,           # gold label, 0 (neutral), 1 (entailment), 2 (contradiction). shape=(batch_size,)
# )
# The function that generates this tuple is datasets.mqnli.MQNLIBertDataset.__getitem__()

with torch.no_grad():
    for batch_idx, input_tuple in enumerate(dataloader):
        # stop once ten batches have been processed
        if batch_idx == 10:
            break

        # move every tensor of the batch onto the target device
        input_tuple = [tensor.to(device) for tensor in input_tuple]
        labels = input_tuple[-1]  # the gold labels are the last element

        logits = bert_model(input_tuple)  # call the forward function of the Bert model
        pred = logits.argmax(dim=1)  # get label predictions
        print(pred == labels)

## Loading data (LSTM)

This part is pretty much the same as for Bert.

In [None]:
# Location of the pytorch checkpoint that stores the preprocessed data.
lstm_easy_data_path = "mqnli_models/lstm-preprocessed-data.pt"

# Deserialize the data object.
# NOTE(review): torch.load unpickles Python objects, so only point it at trusted
# files. Newer PyTorch releases may additionally need `weights_only=False` to
# load a non-tensor object like this one -- confirm against the installed version.
lstm_data = torch.load(lstm_easy_data_path)
print(type(lstm_data)) # datasets.mqnli.MQNLIData

# An MQNLIData object bundles the train, dev, and test sets together with other
# tools and data structures for tokenization, just like the Bert data object.
lstm_train_set, lstm_dev_set, lstm_test_set = (
    lstm_data.train,
    lstm_data.dev,
    lstm_data.test,
)

print(type(lstm_dev_set)) # datasets.mqnli.MQNLIDataset
# Up to this point everything mirrors loading the Bert data.

## Iterating through examples (LSTM)

This part is slightly different. The LSTM model only accepts inputs whose first dimension (dim=0) is the sentence length and whose second dimension (dim=1) is the batch size. To get this layout we need a special `collate_fn` that performs the transposition, because the `DataLoader` produces batch-first tensors by default. Other than that, the rest remains the same.

In [None]:
# run the model on the first ten batches (16 examples each) of the dev set
from torch.utils.data import DataLoader
from datasets.mqnli import get_collate_fxn

# Build the collate function automatically; batch_first=False requests the
# (sentence_length, batch_size) layout.
collate_fn = get_collate_fxn(lstm_dev_set, batch_first=False)
dataloader = DataLoader(
    lstm_dev_set,
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=16,
)


# Each item yielded by the dataloader is one batch of examples, packaged as a
# tuple/list that is passed to the model as a whole. Its layout is:
# (
#   input_ids,       # numeric token ids for the sentence itself. ***shape=(19, batch_size)***,
#                    # where 19 is the sentence length, and contains a [SEP] token.
#   label,           # gold label, 0 (neutral), 1 (entailment), 2 (contradiction). shape=(batch_size,)
# )
# The function that generates this tuple is datasets.mqnli.MQNLIDataset.__getitem__()

# The evaluation loop has the same structure as the Bert one.
with torch.no_grad():
    for batch_idx, input_tuple in enumerate(dataloader):
        # stop once ten batches have been processed
        if batch_idx == 10:
            break

        # move every tensor of the batch onto the target device
        input_tuple = [tensor.to(device) for tensor in input_tuple]
        print(len(input_tuple))
        labels = input_tuple[-1]  # the gold labels are the last element

        logits = lstm_model(input_tuple)  # call the forward function of the LSTM model
        pred = logits.argmax(dim=1)  # get label predictions
        print(pred == labels)

In [None]:
# Display the shape of the first element (input_ids) of the batch that
# `input_tuple` was last bound to by the loop in the previous cell.
# NOTE(review): this depends on a variable leaked from that loop. When the loop
# exits through `break`, `input_tuple` is the batch fetched on that final
# iteration, which has not been moved to `device`.
input_tuple[0].shape