# Attention LSTM with BERT Embeddings (Train & Eval)


## Setup

### Mount to Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
%cd /content/drive/MyDrive/Colab\ Notebooks/group_project

/content/drive/MyDrive/Colab Notebooks/group_project


### Necessary Package Installation

In [None]:
!pip install accelerate

### Necessary Package Import

In [12]:
from typing import Dict
from transformers import PreTrainedTokenizer, BertTokenizer, BertModel, TrainingArguments, Trainer
import torch
from torch.utils.data import Dataset
from pandas import read_csv
from torch import nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

## Dataset

### Deffinition

In [5]:
# premise,hypothesis,label
class NILDataset(Dataset):
    def __init__(self, root: str, *, tokenizer: PreTrainedTokenizer, max_length: int, is_testing: bool = False) -> None:
        self._df = read_csv(root).fillna("")
        self._tokenizer = tokenizer
        self._max_length = max_length
        self._is_test = is_testing

    def __len__(self) -> int:
        return len(self._df)

    def __getitem__(self, index) -> Dict:
        row = self._df.loc[index]

        premise = row["premise"]
        hypothesis = row["hypothesis"]
        if not self._is_test:
            label = row["label"]

        encoding = self._tokenizer.encode_plus(
            premise, hypothesis,
            add_special_tokens=True,
            max_length=self._max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        target = {
            "input_ids": encoding['input_ids'].squeeze(0),  # Remove the batch dimension
            "attention_mask": encoding['attention_mask'].squeeze(0)  # Remove the batch dimension
        }

        if not self._is_test:
            target["label"] = torch.tensor(label, dtype=torch.float32)

        return target

### Initialisation

In [8]:
bert_model = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_model)
max_length = 256

train_dataset = NILDataset("./data/NLI/train.csv", tokenizer=tokenizer, max_length=max_length)
evaluation_dataset = NILDataset("./data/NLI/dev.csv", tokenizer=tokenizer, max_length=max_length)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

## Model

### Architecure

In [20]:
class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_output):
        # lstm_output shape: (batch_size, seq_len, hidden_dim)
        attention_scores = self.attention(lstm_output).squeeze(2)  # (batch_size, seq_len)
        attention_weights = torch.softmax(attention_scores, dim=1).unsqueeze(2)  # (batch_size, seq_len, 1)
        weighted_output = lstm_output * attention_weights  # (batch_size, seq_len, hidden_dim)
        context_vector = weighted_output.sum(1)  # Sum over the sequence dimension (batch_size, hidden_dim)

        return context_vector

class NLIModelDL(nn.Module):
    def __init__(self, *, bert_model:str, lstm_hidden_dim: int):
        super(NLIModelDL, self).__init__()
        self._bert = BertModel.from_pretrained(bert_model)
        self._lstm = nn.LSTM(self._bert.config.hidden_size, lstm_hidden_dim, batch_first=True)
        self._attention = AttentionLayer(lstm_hidden_dim)
        self._fc = nn.Linear(lstm_hidden_dim, 2)

    def forward(self, input_ids, attention_mask, labels = None):
        input_ids = input_ids.to(self._bert.device)
        attention_mask = attention_mask.to(self._bert.device)
        if labels is not None:
            labels = labels.to(self._bert.device).long()

        with torch.no_grad():  # Freeze BERT during training
            encoded_layers = self._bert(input_ids, attention_mask=attention_mask)

        lstm_out, _ = self._lstm(encoded_layers.last_hidden_state)
        lstm_out = lstm_out[:, -1, :]
        logits = self._fc(lstm_out)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))

        return {'loss': loss, 'logits': logits} if loss is not None else logits

### Initialisation

In [21]:
lstm_hidden_dim = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = NLIModelDL(bert_model=bert_model, lstm_hidden_dim=lstm_hidden_dim)
model = model.to(device)

## Training and Evaluation

### Metrics

In [22]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
    acc = accuracy_score(labels, preds)
    return {
        "accuracy": acc,
        "f1": f1,
        "precision": precision,
        "recall": recall
    }

### Setup

In [23]:
# Training arguments
# Output directory and logging directory can be modified as necessary
training_args = TrainingArguments(
    output_dir='./out',  # Change as necessary when running
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',  # Change as necessary when running
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=0.01,
    max_grad_norm=1.0,
    save_strategy="epoch",
    save_total_limit=2,
    lr_scheduler_type="linear"
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=evaluation_dataset,
    compute_metrics=compute_metrics
)

### Experiment

In [24]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Train the model on tran.csv
trainer.train()

# Evaluate the model on dev.csv
trainer.evaluate()

Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.6474,0.615466,0.6635,0.67386,0.674345,0.673376
2,0.599,0.596135,0.684726,0.728736,0.655561,0.820299
3,0.5765,0.578266,0.693335,0.716287,0.685594,0.749856
4,0.5626,0.584596,0.697937,0.741061,0.664688,0.837263
5,0.5509,0.569843,0.707882,0.726667,0.702848,0.752156


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

TrainOutput(global_step=8420, training_loss=0.5872786786947001, metrics={'train_runtime': 828.5858, 'train_samples_per_second': 162.59, 'train_steps_per_second': 10.162, 'total_flos': 0.0, 'train_loss': 0.5872786786947001, 'epoch': 5.0})

In [None]:
# Save the model
torch.save(model, "./models/attention_lstm__with_bert_embeddings.pth")  # The model saving path can be modified if necessary