In [1]:
# Fetch the shared-task repository into ./repo; later cells import
# `repo.dataset` and read ./repo/split-dataset.json from this checkout.
!git clone https://github.com/timurtripp/CSCI5832-Shared-Task.git repo

Cloning into 'repo'...
remote: Enumerating objects: 1074, done.[K
remote: Counting objects: 100% (1074/1074), done.[K
remote: Compressing objects: 100% (1061/1061), done.[K
remote: Total 1074 (delta 23), reused 1047 (delta 9), pack-reused 0[K
Receiving objects: 100% (1074/1074), 3.35 MiB | 11.65 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [2]:
# Install the third-party packages imported in the next cell (transformers, tqdm).
! pip install transformers tqdm



In [3]:
# Imports for most of the notebook
import torch
from transformers import BertModel
from transformers import AutoTokenizer
from typing import Dict, List
import random
from tqdm.autonotebook import tqdm

In [4]:
# from repo.dataset import write_split_dataset
from repo.dataset import read_split_dataset
from repo.dataset import label_distribution

# write_split_dataset('./repo/training_data/train.json', './repo/training_data/dev.json', './repo/split-dataset.json')
# Load the pre-split data: sentences and labels for the train, dev and
# validation splits, in that order.
(
    train_sentences,
    train_labels,
    dev_sentences,
    dev_labels,
    validation_sentences,
    validation_labels,
) = read_split_dataset('./repo/split-dataset.json')

# Report the label balance of every split under its own heading.
for heading, split_labels in (
    ('Train Label Distribution:\n', train_labels),
    ('\n\nDev Label Distribution:\n', dev_labels),
    ('\n\nValidation Label Distribution:\n', validation_labels),
):
    print(heading)
    label_distribution(split_labels)

Train Label Distribution:

--- Entailment Count ---
765

--- Contradiction Count ---
765


Dev Label Distribution:

--- Entailment Count ---
85

--- Contradiction Count ---
85


Validation Label Distribution:

--- Entailment Count ---
100

--- Contradiction Count ---
100


In [5]:
# Report whether PyTorch can see a CUDA GPU, then pick the compute device
# accordingly: the first GPU when one is available, otherwise the CPU. Selecting
# the device programmatically keeps the notebook runnable on GPU-less machines
# without hand-editing this cell.
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

True


In [6]:
class BatchTokenizer:
    """Tokenizes and pads batches of sentence pairs with a HuggingFace tokenizer."""

    def __init__(self, model_name):
        """Loads the pretrained tokenizer that belongs to ``model_name``.

        Args:
            model_name (str): Model name or path handed to
                ``AutoTokenizer.from_pretrained``.
        """
        self.hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model_name = model_name

    def get_sep_token(self,):
        """Returns the separator token of the wrapped HuggingFace tokenizer."""
        return self.hf_tokenizer.sep_token

    def __call__(self, prem_batch: List[str], hyp_batch: List[str]) -> Dict:
        """Uses the huggingface tokenizer to tokenize and pad a batch of pairs.

        The two lists are parallel: element ``i`` of ``prem_batch`` is paired
        with element ``i`` of ``hyp_batch``.

        Args:
            prem_batch (List[str]): First sequence of every pair.
            hyp_batch (List[str]): Second sequence of every pair.

        Returns:
            Dict: The mapping of token specifications produced by the
                HuggingFace tokenizer (PyTorch tensors, padded to the longest
                pair in the batch, without token type ids).
        """
        # The HF tokenizer pads for us and combines the two sequences of each
        # pair into a single input, delimited by the [SEP] token.
        return self.hf_tokenizer(
            prem_batch,
            hyp_batch,
            padding=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )

In [7]:
def chunk(lst, n):
    """Lazily split ``lst`` into consecutive slices holding at most ``n`` items each."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def chunk_multi(lst1, lst2, n):
    """Yield aligned n-sized chunks from two parallel sequences.

    Args:
        lst1: First sequence (must support ``len`` and slicing).
        lst2: Second sequence, parallel to ``lst1``.
        n (int): Maximum number of items per chunk.

    Yields:
        tuple: ``(lst1[i:i + n], lst2[i:i + n])`` for successive offsets ``i``.

    Raises:
        ValueError: If the two sequences differ in length. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # Chunk boundaries are derived from lst1 only, so a length mismatch would
    # otherwise silently produce mis-paired or truncated batches.
    if len(lst1) != len(lst2):
        raise ValueError(
            f"chunk_multi expects sequences of equal length, got {len(lst1)} and {len(lst2)}"
        )
    for i in range(0, len(lst1), n):
        yield lst1[i: i + n], lst2[i: i + n]

# Number of sentence pairs per mini-batch.
batch_size = 10
tokenizer = BatchTokenizer('prajjwal1/bert-small')

# First sequence of every pair: the statement text.
train_statements = [entry['statement'] for entry in train_sentences]
# Second sequence of every pair: 'pid' alone, or "pid, sid" when 'sid' is set.
train_pair_ids = [
    entry['pid'] if entry['sid'] is None else entry['pid'] + ', ' + entry['sid']
    for entry in train_sentences
]

# Since we use huggingface, tokenizing and encoding happen in a single call:
# each batch of raw strings goes straight to padded tensors.
train_input_batches = [
    tokenizer(*batch)
    for batch in chunk_multi(train_statements, train_pair_ids, batch_size)
]

config.json:   0%|          | 0.00/286 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [8]:
def encode_labels(labels: List[int]) -> torch.LongTensor:
    """Turns the batch of labels into an integer tensor.

    Args:
        labels (List[int]): List of all labels in the batch. Each entry is
            passed through ``int()``, so anything ``int()`` accepts works.

    Returns:
        torch.LongTensor: 1-D tensor of all labels in the batch (64-bit
            integers, the dtype the classification losses in this notebook
            are fed).
    """
    return torch.LongTensor([int(label) for label in labels])


# Cut the training labels into batches of `batch_size` and turn every batch
# into a tensor in one pass.
train_label_batches = [
    encode_labels(label_batch) for label_batch in chunk(train_labels, batch_size)
]

In [9]:
# For making predictions at test time
def predict(model: torch.nn.Module, sents: torch.Tensor) -> List:
    """Return the predicted class index for every example in a batch.

    The model is switched to evaluation mode (so dropout is disabled) and run
    without gradient tracking; its previous train/eval mode is restored before
    returning, so calling this between training epochs is safe.

    Args:
        model (torch.nn.Module): Model whose output has shape
            (batch, 1, num_classes), as NLIClassifier in this notebook produces.
        sents: One encoded input batch; must provide ``.to(device)``.

    Returns:
        List: One predicted class index (numpy integer) per example.
    """
    # Run on whatever device the model lives on; fall back to the notebook-level
    # `device` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(sents.to(target_device))
    finally:
        model.train(was_training)

    # reshape(-1) rather than a bare squeeze(): a batch holding a single example
    # must stay 1-D, otherwise list() fails on a 0-d array.
    return list(torch.argmax(logits, dim=2).reshape(-1).cpu().numpy())

In [10]:
class NLIClassifier(torch.nn.Module):
    """BERT encoder topped with a one-hidden-layer classification head.

    The head maps the first-token ([CLS]) embedding through
    Linear -> ReLU -> Linear -> LogSoftmax.
    """

    def __init__(self, output_size: int, hidden_size: int, model_name='prajjwal1/bert-small'):
        """Builds the encoder and the classification head.

        Args:
            output_size (int): Number of output classes.
            hidden_size (int): Width of the hidden layer in the head.
            model_name (str): Pretrained checkpoint given to
                ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size

        # Pretrained BERT takes the place of a plain embedding layer.
        self.bert = BertModel.from_pretrained(model_name)

        # TODO [OPTIONAL]: Updating every BERT parameter can be slow and memory
        # hungry. To train only the classification head, uncomment the two
        # lines below (the learning rate probably needs adjusting in that case).
        # for param in self.bert.parameters():
        #    param.requires_grad = False

        bert_width = self.bert.config.hidden_size
        self.bert_hidden_dimension = bert_width

        # Head: project the BERT hidden dimension to `hidden_size`, apply a
        # ReLU, then project to the class scores.
        self.hidden_layer = torch.nn.Linear(bert_width, hidden_size)
        self.relu = torch.nn.ReLU()
        self.classifier = torch.nn.Linear(hidden_size, output_size)
        # dim=2 because the head operates on (batch, 1, features) tensors.
        self.log_softmax = torch.nn.LogSoftmax(dim=2)

    def encode_text(
        self,
        symbols: Dict
    ) -> torch.Tensor:
        """Runs BERT on the batch and extracts the [CLS] representation.

        Args:
            symbols (Dict): The Dict of token specifications provided by the
                HuggingFace tokenizer; unpacked as keyword arguments to BERT.

        Returns:
            torch.Tensor: First-token embedding of the last hidden state, of
                shape batch_size x 1 x bert_hidden_dimension.
        """
        # BERT already encodes context, so no recurrent layer is needed: the
        # first position ([CLS]) serves as a single vector for the sequence.
        bert_output = self.bert(**symbols)
        # Slice (instead of index) to keep the middle dimension of size 1.
        return bert_output.last_hidden_state[:, :1, :]

    def forward(
        self,
        symbols: Dict,
    ) -> torch.Tensor:
        """Computes per-class log-probabilities for a batch.

        Args:
            symbols (Dict): The Dict of token specifications provided by the
                HuggingFace tokenizer.

        Returns:
            torch.Tensor: Log-softmax scores of shape
                batch_size x 1 x output_size.
        """
        cls_embedding = self.encode_text(symbols)
        hidden = self.relu(self.hidden_layer(cls_embedding))
        return self.log_softmax(self.classifier(hidden))

In [11]:
def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    model,
    label_map=None
):
    """Train `model` with NLL loss and report dev macro-F1 after every epoch.

    Args:
        num_epochs: Number of passes over the training batches.
        train_features: Sequence of encoded input batches (each must support
            ``.to(device)``).
        train_labels: Sequence of label tensors, parallel to `train_features`.
        dev_sents: Sequence of encoded development input batches.
        dev_labels: Sequence of label tensors, parallel to `dev_sents`.
        optimizer: Optimizer already bound to the model's parameters.
        model: Model returning log-probabilities of shape
            (batch, 1, num_classes).
        label_map: Optional mapping forwarded to `macro_f1` to translate
            predicted indices before scoring.

    Returns:
        The trained model. If at least one epoch ran it is left in evaluation
        mode (it was last used for dev scoring).
    """
    print("Training...")
    loss_func = torch.nn.NLLLoss()
    batches = list(zip(train_features, train_labels))
    for i in range(num_epochs):
        # Re-shuffle at the start of every epoch so each pass visits the
        # batches in a fresh order (shuffling once would fix the order).
        random.shuffle(batches)
        # Evaluation below switches to eval mode; switch back for training so
        # dropout is active again.
        model.train()
        losses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            optimizer.zero_grad()
            # Drop the singleton middle dimension: (batch, 1, C) -> (batch, C)
            preds = model(features.to(device)).squeeze(1)
            loss = loss_func(preds, labels.to(device))
            # Backpropagate the loss through our model
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Guard against an empty training set instead of dividing by zero.
        mean_loss = sum(losses) / len(losses) if losses else float("nan")
        print(f"epoch {i}, loss: {mean_loss}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        # Score with dropout disabled and without building autograd graphs.
        model.eval()
        with torch.no_grad():
            for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
                pred = predict(model, sents)
                all_preds.extend(pred)
                all_labels.extend(list(labels.cpu().numpy()))

        dev_f1 = macro_f1(all_preds, all_labels, possible_labels=set(all_labels), label_map = label_map)
        print(f"Dev F1 {dev_f1}")

    # Return the trained model
    return model

In [12]:
def class_predict(model, sents):
    """Predict class indices with a model whose output exposes ``.logits``.

    The model is put in evaluation mode (dropout disabled) for the forward pass
    and its previous train/eval mode is restored afterwards.

    Args:
        model: Module called as ``model(**batch)`` whose result has a
            ``logits`` attribute of shape (batch, num_classes).
        sents: Encoded input batch; must support ``.to(device)`` and ``**``
            unpacking.

    Returns:
        torch.Tensor: 1-D tensor of predicted class indices, on the device the
            model ran on.
    """
    # Run on whatever device the model lives on; fall back to the notebook-level
    # `device` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            logits = model(**sents.to(target_device)).logits
            predictions = torch.argmax(logits, dim=1)
    finally:
        model.train(was_training)

    return predictions

def prediction_loop(model, dev_sents, dev_labels, label_map=None):
    """Evaluate `model` on a set of batches and print its macro-F1.

    Args:
        model: Model accepted by `class_predict`.
        dev_sents: Sequence of encoded input batches.
        dev_labels: Sequence of label tensors, parallel to `dev_sents`.
        label_map: Optional mapping forwarded to `macro_f1` to translate
            predicted indices before scoring.
    """
    print("Evaluating...")
    all_preds = []
    all_labels = []
    for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
        pred = class_predict(model, sents).cpu()
        # Store plain Python ints rather than 0-d tensors so the metric code
        # compares ordinary scalars.
        all_preds.extend(pred.tolist())
        all_labels.extend(list(labels.cpu().numpy()))

    dev_f1 = macro_f1(all_preds, all_labels, possible_labels=set(all_labels), label_map = label_map)
    print(f"F1 {dev_f1}")

def class_training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    model,
    label_map=None
):
    """Train a ``.logits``-style model with cross-entropy; report dev macro-F1 per epoch.

    Args:
        num_epochs: Number of passes over the training batches.
        train_features: Sequence of encoded input batches (each must support
            ``.to(device)`` and ``**`` unpacking).
        train_labels: Sequence of label tensors, parallel to `train_features`.
        dev_sents: Sequence of encoded development input batches.
        dev_labels: Sequence of label tensors, parallel to `dev_sents`.
        optimizer: Optimizer already bound to the model's parameters.
        model: Module called as ``model(**batch)`` whose result exposes
            ``logits`` of shape (batch, num_classes).
        label_map: Optional mapping forwarded to `macro_f1` to translate
            predicted indices before scoring.

    Returns:
        The trained model. If at least one epoch ran it is left in evaluation
        mode (it was last used for dev scoring).
    """
    print("Training...")
    loss_func = torch.nn.CrossEntropyLoss()
    batches = list(zip(train_features, train_labels))
    for i in range(num_epochs):
        # Re-shuffle at the start of every epoch so each pass visits the
        # batches in a fresh order (shuffling once would fix the order).
        random.shuffle(batches)
        # Evaluation below switches to eval mode; switch back for training so
        # dropout is active again.
        model.train()
        losses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            optimizer.zero_grad()
            logits = model(**features.to(device)).logits
            loss = loss_func(logits, labels.to(device))
            # Backpropagate the loss through our model
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Guard against an empty training set instead of dividing by zero.
        mean_loss = sum(losses) / len(losses) if losses else float("nan")
        print(f"epoch {i}, loss: {mean_loss}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        # Score with dropout disabled.
        model.eval()
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            pred = class_predict(model, sents).cpu()
            # Store plain Python ints rather than 0-d tensors.
            all_preds.extend(pred.tolist())
            all_labels.extend(list(labels.cpu().numpy()))

        dev_f1 = macro_f1(all_preds, all_labels, possible_labels=set(all_labels), label_map = label_map)
        print(f"Dev F1 {dev_f1}")

    # Return the trained model
    return model

In [13]:
import numpy as np
from numpy import sum as t_sum
from numpy import logical_and


def precision(predicted_labels, true_labels, which_label=1):
    """
    Precision for `which_label`: true positives divided by the number of
    examples predicted as that label. Returns 0. when the label is never
    predicted.
    """
    predicted_mask = np.array([p == which_label for p in predicted_labels])
    actual_mask = np.array([t == which_label for t in true_labels])
    n_predicted = t_sum(predicted_mask)
    # Nothing was predicted as this label: define precision as zero.
    if not n_predicted:
        return 0.
    true_positives = t_sum(logical_and(predicted_mask, actual_mask))
    return true_positives / n_predicted


def recall(predicted_labels, true_labels, which_label=1):
    """
    Recall for `which_label`: true positives divided by the number of examples
    whose true label is that label. Returns 0. when the label never occurs in
    `true_labels`.
    """
    predicted_mask = np.array([p == which_label for p in predicted_labels])
    actual_mask = np.array([t == which_label for t in true_labels])
    n_actual = t_sum(actual_mask)
    # The label never occurs in the gold data: define recall as zero.
    if not n_actual:
        return 0.
    true_positives = t_sum(logical_and(predicted_mask, actual_mask))
    return true_positives / n_actual


def f1_score(
    predicted_labels: List[int],
    true_labels: List[int],
    which_label: int
):
    """
    F1 for `which_label`: the harmonic mean of its precision and recall,
    or 0. when either of them is zero.
    """
    prec = precision(predicted_labels, true_labels, which_label=which_label)
    rec = recall(predicted_labels, true_labels, which_label=which_label)

    # A zero precision or recall makes the harmonic mean zero (and would
    # otherwise risk a 0/0 division).
    if not (prec and rec):
        return 0.
    return 2*prec*rec/(prec+rec)


def macro_f1(
    predicted_labels: List[int],
    true_labels: List[int],
    possible_labels: List[int],
    label_map=None
):
    """Macro-averaged F1: the unweighted mean of the per-label F1 scores.

    Args:
        predicted_labels (List[int]): Predicted label per example.
        true_labels (List[int]): Gold label per example.
        possible_labels (List[int]): Labels to average over. Any iterable is
            accepted (callers in this notebook pass a set).
        label_map: Optional mapping applied to each predicted label (indexed by
            ``int(prediction)``) before scoring. Ignored when falsy.

    Returns:
        float: Mean F1 over `possible_labels`, or 0. when there are no labels
            to score.
    """
    converted_prediction = [label_map[int(x)] for x in predicted_labels] if label_map else predicted_labels
    # Materialize once so one-shot iterables work and the length is known.
    labels = list(possible_labels)
    # No labels to score (e.g. an empty evaluation set): avoid dividing by zero.
    if not labels:
        return 0.
    scores = [f1_score(converted_prediction, true_labels, l) for l in labels]
    # Macro, so we take the uniform avg.
    return sum(scores) / len(scores)

In [14]:
# Number of passes over the training data; raise it if need be.
epochs = 10

# Learning rate handed to AdamW below.
# NOTE(review): the logged loss only moves from ~0.708 to ~0.696 over 10 epochs
# at this value — a larger rate is probably worth trying.
LR = 1e-7

possible_labels = set(train_labels)

# Build the development batches the same way as the training batches.
# NOTE(review): the second sequence of each pair is assembled from 'pid'/'sid' —
# confirm these fields hold text rather than bare identifiers.
dev_statements = [entry['statement'] for entry in dev_sentences]
dev_pair_ids = [
    entry['pid'] if entry['sid'] is None else entry['pid'] + ', ' + entry['sid']
    for entry in dev_sentences
]

# Tokenize + encode the inputs, and tensorize the labels batch by batch.
dev_input_batches = [
    tokenizer(*batch)
    for batch in chunk_multi(dev_statements, dev_pair_ids, batch_size)
]
dev_batch_labels = [
    encode_labels(label_batch) for label_batch in chunk(dev_labels, batch_size)
]

# One output per distinct dev label; the classifier head is 10 units wide.
model = NLIClassifier(len(set(dev_labels)), 10)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

model = training_loop(
    num_epochs=epochs,
    train_features=train_input_batches,
    train_labels=train_label_batches,
    dev_sents=dev_input_batches,
    dev_labels=dev_batch_labels,
    optimizer=optimizer,
    model=model,
)

pytorch_model.bin:   0%|          | 0.00/116M [00:00<?, ?B/s]

Training...


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 0, loss: 0.7079135226268395
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.37916219119226635


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 1, loss: 0.7049532353488448
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.3837535014005602


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 2, loss: 0.702763591716492
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.4126057428809723


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 3, loss: 0.7010523614540599
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.41934482223257263


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 4, loss: 0.6997412454848196
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.41594136509390744


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 5, loss: 0.6987081345389871
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.4414575866188769


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 6, loss: 0.6978953278142642
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.48268366903347343


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 7, loss: 0.6972492683946697
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.46819603753910327


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 8, loss: 0.6966989468904882
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.46489832007073384


  0%|          | 0/153 [00:00<?, ?it/s]

epoch 9, loss: 0.6962099016881457
Evaluating dev...


  0%|          | 0/17 [00:00<?, ?it/s]

Dev F1 0.4698482777952976


In [15]:
# Build the held-out validation batches the same way as the train/dev batches.
validation_statements = [entry['statement'] for entry in validation_sentences]
validation_pair_ids = [
    entry['pid'] if entry['sid'] is None else entry['pid'] + ', ' + entry['sid']
    for entry in validation_sentences
]
validation_input_batches = [
    tokenizer(*batch)
    for batch in chunk_multi(validation_statements, validation_pair_ids, batch_size)
]
validation_batch_labels = [
    encode_labels(label_batch) for label_batch in chunk(validation_labels, batch_size)
]

# Alternative evaluation path, for models whose output exposes `.logits`:
# prediction_loop(model, validation_input_batches, validation_batch_labels)

# Collect predictions and gold labels over all validation batches, then score.
all_preds = []
all_labels = []
validation_pairs = zip(validation_input_batches, validation_batch_labels)
for sents, labels in tqdm(validation_pairs, total=len(validation_input_batches)):
    pred = predict(model, sents)
    all_preds += pred
    all_labels += list(labels.cpu().numpy())
dev_f1 = macro_f1(all_preds, all_labels, possible_labels=set(all_labels), label_map=None)
print(f"Validation F1 {dev_f1}")

  0%|          | 0/20 [00:00<?, ?it/s]

Validation F1 0.4463492063492063
