# Recurrent Neural Networks for BoolQ Reading Comprehension

## 1. Introduction

- **Objective**: Develop a reading comprehension model using a 2-layer LSTM and a 2-layer classifier. The model will be trained end-to-end on the BoolQ dataset.
- **Task**: The BoolQ dataset involves answering yes/no questions given a passage. The goal is to predict the correct label for each question.
- **Approach**: Utilize PyTorch for building the model, and Hugging Face's datasets library to manage data.


## 2. Setup
- **Libraries**:
  - `torch`: For building the neural network.
  - `datasets`: For loading the BoolQ dataset.
  - `transformers`: For the pre-trained BPE tokenizer of the initial plan (replaced by NLTK, see Preprocessing).
  - `nltk`: For word-level tokenization.
  - `fasttext`: To load and use FastText embeddings.
  - `numpy`, `pandas`, `matplotlib`, `seaborn`: For data manipulation and visualization.
  - `gensim`: For loading the pre-trained word embedding model.
  - `sklearn`: For metrics.
  - `wandb`: For experiment tracking.

- **Planned Correctness Tests**:
  - Use `assert` statements to check tensor dimensions, and confirm the expected shapes of inputs and outputs throughout the data pipeline.
  - Print sample outputs at different stages to validate transformations.
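The shape checks planned above could be centralized in a small helper. The following is a minimal sketch, not part of the original notebook: the name `check_shape` and the use of `None` as a wildcard are assumptions made for illustration.

```python
from typing import Optional, Sequence


def check_shape(actual: Sequence[int], expected: Sequence[Optional[int]], name: str = "tensor") -> None:
    """Raise AssertionError if `actual` does not match `expected`.

    A `None` entry in `expected` acts as a wildcard, e.g. for a variable batch size.
    """
    actual = tuple(actual)
    assert len(actual) == len(expected), (
        f"{name}: expected {len(expected)} dims, got {len(actual)} ({actual})"
    )
    for axis, (got, want) in enumerate(zip(actual, expected)):
        assert want is None or got == want, (
            f"{name}: axis {axis} is {got}, expected {want} (full shape {actual})"
        )


# Example: a batch of padded sequences with 512 tokens and 300-dim embeddings.
check_shape((32, 512, 300), (None, 512, 300), name="inputs")
```

With NumPy arrays or PyTorch tensors the call would be `check_shape(x.shape, ...)`, since both expose a `.shape` that can be converted to a tuple.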

- **Experiment Tracking**:
  - Use `wandb` for logging experiments, including hyperparameters, metrics, and visualizations.


In [2]:
# Install the required packages (uncomment to run).
# `re` is part of the Python standard library and cannot be installed with pip.

# %pip install torch datasets transformers fasttext numpy pandas matplotlib seaborn gensim scikit-learn wandb nltk

In [3]:
# TODO: remove all chatGPT comments and add my own documentation in MDF

In [4]:
import wandb
from datasets import load_dataset
import re
import nltk
import fasttext
import fasttext.util  # needed for fasttext.util.download_model
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

Define the hyperparameters (kept at the top for easy access)

In [5]:
max_seq_len: int = 512          # Maximum sequence length in tokens
embedding_dim: int = 300        # fastText model embedding dimension
padding_type: str = 'zeros'     # Choose between padding w/ 'zeros' or the 'avg' of the embedding

optimizer_choice: str = 'Adam'  # Optimizer ['Adam', 'AdamW', 'SGD']

hidden_dim: int  = 128          # Hidden size for LSTM
output_dim: int = 1             # Binary classification (yes/no)
n_layers: int  = 2              # Two LSTM layers
dropout_rate: float = 0.3       # Dropout rate for regularization
learning_rate: float = 1e-3     # Optimizer learning rate
weight_decay: float = 0.01      # For AdamW or L2 in SGD
train_batch_size: int = 32      # Training Batch size
val_batch_size: int = 32        # Validation and Testing Batch size
n_epochs: int  = 10             # Number of epochs
patience: int = 3               # Early stopping patience


In [6]:
run_number: int = 1 # TODO: Don't forget to change this!!!

# defining the WandB project and run name
project_name: str = 'nlp-rnn_lstm_pt'
run_name: str = f"run_{run_number}-default-run"

wandb.init(
    project=project_name,
    name=run_name,
    config={
        "max_seq_len": max_seq_len,
        "padding_type": padding_type,
        "embedding_dim": embedding_dim,
        "hidden_dim": hidden_dim,
        "output_dim": output_dim,
        "n_layers": n_layers,
        "dropout_rate": dropout_rate,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "train_batch_size": train_batch_size,
        "val_batch_size": val_batch_size,
        "n_epochs": n_epochs,
        "patience": patience,
        "optimizer_choice": optimizer_choice
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maintnoair[0m. Use [1m`wandb login --relogin`[0m to force relogin


Download the BoolQ dataset and split it as required by the project presentation

In [8]:
train_data = load_dataset('google/boolq', split='train[:-1000]')
validation_data = load_dataset('google/boolq', split='train[-1000:]')
test_data = load_dataset('google/boolq', split='validation')

In [7]:
wandb.finish()


Have a look at the data and labels

In [9]:
test_question = train_data[5]['question']
test_passage = train_data[5]['passage']
print(train_data[5])
print(f"Number of training samples: {len(train_data)}")
print(f"Number of validation samples: {len(validation_data)}")
print(f"Number of test samples: {len(test_data)}")

train_yes_count = sum(1 for label in train_data['answer'] if label == 1)
train_no_count = sum(1 for label in train_data['answer'] if label == 0)
train_total = train_yes_count + train_no_count

validation_yes_count = sum(1 for label in validation_data['answer'] if label == 1)
validation_no_count = sum(1 for label in validation_data['answer'] if label == 0)
validation_total = validation_yes_count + validation_no_count

test_yes_count = sum(1 for label in test_data['answer'] if label == 1)
test_no_count = sum(1 for label in test_data['answer'] if label == 0)
test_total = test_yes_count + test_no_count

# Print the counts and ratios
print(f"Train set - Yes: {train_yes_count}, No: {train_no_count}, Ratio (y/n): {round(train_yes_count / train_no_count, 2)}, Percent Yes: {round(train_yes_count / train_total * 100, 2)}%")

print(f"Validation set - Yes: {validation_yes_count}, No: {validation_no_count}, Ratio (y/n): {round(validation_yes_count / validation_no_count, 2)}, Percent Yes: {round(validation_yes_count / validation_total * 100, 2)}%")

print(f"Test set - Yes: {test_yes_count}, No: {test_no_count}, Ratio (y/n): {round(test_yes_count / test_no_count, 2)}, Percent Yes: {round(test_yes_count / test_total * 100, 2)}%")

{'question': 'can you use oyster card at epsom station', 'answer': False, 'passage': "Epsom railway station serves the town of Epsom in Surrey. It is located off Waterloo Road and is less than two minutes' walk from the High Street. It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations. The station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article)."}
Number of training samples: 8427
Number of validation samples: 1000
Number of test samples: 3270
Train set - Yes: 5279, No: 3148, Ratio (y/n): 1.68, Percent Yes: 62.64%
Validation set - Yes: 595, No: 405, Ratio (y/n): 1.47, Percent Yes: 59.5%
Test set - Yes: 2033, No: 1237, Ratio (y/n): 1.64, Percent Yes: 62.17%


Download the fastText model if not already in directory

In [10]:
# Load the fastText model
model_bin = Path('cc.en.300.bin')

if not model_bin.exists():
    fasttext.util.download_model('en', if_exists='ignore') # download if not already in dir

ft = fasttext.load_model(str(model_bin))



Ensure the nltk tokenizer is downloaded

In [11]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/blackbook/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

## 3. Preprocessing

- **Text Cleaning**:
  - **Operations**:
    - Convert text to lowercase for consistency.
    - Remove special characters and URLs while keeping necessary hyphens.
    - Remove extra whitespace between words as well as before or after a sequence.
  - **Reasoning**: These basic cleaning steps standardize the input without over-complicating the preprocessing, while removing as little meaning as possible from the sentences. For the same reason, I chose not to remove stopwords and not to apply stemming or lemmatization.


In [12]:
test_question = train_data[5]['question']
test_passage = train_data[5]['passage']

In [13]:
def to_lowercase(text: str) -> str:
    lowered_text = text.lower()
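    # str.islower() is False for strings without cased characters (e.g. "2012"),
    # so compare against a second lowercasing instead.
    assert lowered_text == lowered_text.lower(), "Text is not fully lowercase"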
    return lowered_text

print(to_lowercase(test_question))
print(to_lowercase(test_passage))

can you use oyster card at epsom station
epsom railway station serves the town of epsom in surrey. it is located off waterloo road and is less than two minutes' walk from the high street. it is not in the london oyster card zone unlike epsom downs or tattenham corner stations. the station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article).


In [14]:
def remove_special_characters_and_urls(text: str) -> str:
    # Remove URLs
    text = re.sub(r'http[s]?://\S+|www\.\S+', '', text)
    # Replace slashes with spaces first
    text = text.replace('/', ' ')
    # Remove special characters except for alphanumeric characters and spaces
    cleaned_text = re.sub(r'[^a-zA-Z0-9\s\'\-]', '', text)
    assert not re.search(r'http[s]?://|www\.', cleaned_text), "URLs were not fully removed"
    return cleaned_text

test_question_w_url = "Visit us at https://example.com for more info!"
test_passage_w_url = "Some RANDOM text With VARIETY. Check this out: www.example.org and the year is 2012/2013."

print(remove_special_characters_and_urls(test_question))
print(remove_special_characters_and_urls(test_passage))
print(remove_special_characters_and_urls(test_question_w_url))
print(remove_special_characters_and_urls(test_passage_w_url))

can you use oyster card at epsom station
Epsom railway station serves the town of Epsom in Surrey It is located off Waterloo Road and is less than two minutes' walk from the High Street It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations The station building was replaced in 2012 2013 with a new building with apartments above the station see end of article
Visit us at  for more info
Some RANDOM text With VARIETY Check this out  and the year is 2012 2013


In [15]:
def remove_extra_whitespace(text: str) -> str:
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    assert '  ' not in cleaned_text, "There are still multiple spaces"
    return cleaned_text

print(remove_extra_whitespace(test_question))
print(remove_extra_whitespace(test_passage))

can you use oyster card at epsom station
Epsom railway station serves the town of Epsom in Surrey. It is located off Waterloo Road and is less than two minutes' walk from the High Street. It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations. The station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article).


*Initial Plan:*
- ***Tokenization**:*
  - ***Decision**: Use a pre-trained Byte-Pair Encoding (BPE) tokenizer from the `transformers` library.*
  - ***Reasoning**:*
    - *Using a pre-trained tokenizer simplifies the preprocessing pipeline, as the tokenizer has already been trained on a large and diverse corpus, which increases its generalization capability.*
    - *Pre-trained tokenizers from `transformers` are well-optimized and widely used in various NLP tasks.*
    - *BPE helps handle out-of-vocabulary (OOV) words by breaking them into known subword units, allowing for more robust word representations.*

Adjusted plan after feedback on project 1 as well as stage 1 of project 2:
- **Tokenization**
  - **Decision**: Use NLTK's word-level tokenizer.
  - **Reasoning**:
    - Ensures compatibility with FastText, which expects whole words.
    - Reduces complexity by relying on FastText's internal OOV handling.
    - Aligns with the word-based tokenization used in FastText’s training.

In [16]:
# Tokenization using NLTK
def tokenize(text: str) -> list:
    tokens = nltk.word_tokenize(text)
    assert isinstance(tokens, list) and len(tokens) > 0, "Tokenization failed or empty token list"
    return tokens


In [17]:
lower_passage = to_lowercase(test_passage)
lower_question = to_lowercase(test_question)

cleaned_passage = remove_extra_whitespace(remove_special_characters_and_urls(lower_passage))
cleaned_question = remove_extra_whitespace(remove_special_characters_and_urls(lower_question))

print(tokenize(cleaned_passage))
print(tokenize(cleaned_question))

['epsom', 'railway', 'station', 'serves', 'the', 'town', 'of', 'epsom', 'in', 'surrey', 'it', 'is', 'located', 'off', 'waterloo', 'road', 'and', 'is', 'less', 'than', 'two', 'minutes', "'", 'walk', 'from', 'the', 'high', 'street', 'it', 'is', 'not', 'in', 'the', 'london', 'oyster', 'card', 'zone', 'unlike', 'epsom', 'downs', 'or', 'tattenham', 'corner', 'stations', 'the', 'station', 'building', 'was', 'replaced', 'in', '2012', '2013', 'with', 'a', 'new', 'building', 'with', 'apartments', 'above', 'the', 'station', 'see', 'end', 'of', 'article']
['can', 'you', 'use', 'oyster', 'card', 'at', 'epsom', 'station']


- **Word Embedding Lookups**:
  - **Decision**: Use the FastText API directly to obtain embeddings for tokenized words.
  - **Reasoning**:
    - The FastText API considers subword information when generating word embeddings, providing robust handling of OOV words.
    - Looking up whole words avoids having to map subword tokens (such as BPE pieces) to embeddings, which a conventional embedding lookup table does not support.
  - **OOV Word Handling**:
    - Rely on FastText's built-in subword handling to generate embeddings for unknown words.
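To make the subword idea concrete: fastText wraps each word in boundary markers `<` and `>` and represents it through its character n-grams, so an unseen word still shares n-grams with known words. The sketch below only illustrates the n-gram extraction. The function name `char_ngrams` and the n-gram range are assumptions for illustration and are not necessarily the settings of `cc.en.300.bin`; the real library additionally hashes n-grams into buckets.

```python
from typing import List


def char_ngrams(word: str, min_n: int = 3, max_n: int = 6) -> List[str]:
    """Return all character n-grams of `<word>` with lengths min_n..max_n."""
    wrapped = f"<{word}>"  # boundary markers distinguish prefixes and suffixes
    ngrams = []
    for n in range(min_n, max_n + 1):
        for start in range(len(wrapped) - n + 1):
            ngrams.append(wrapped[start:start + n])
    return ngrams


print(char_ngrams("where", min_n=3, max_n=3))
# ['<wh', 'whe', 'her', 'ere', 're>']
```

An out-of-vocabulary token such as `tattenham` would be embedded from the vectors of n-grams like these, which is why `ft.get_word_vector` always returns a vector.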

In [18]:
# Function to get FastText embeddings
def get_fasttext_embeddings(tokens: list) -> np.ndarray:
    embeddings = []
    for token in tokens:
        embedding = ft.get_word_vector(token)
        assert embedding.shape == (embedding_dim,), f"Embedding shape mismatch for token: {token}"
        embeddings.append(embedding)
    return np.array(embeddings)

- **Sequence Truncation and Padding**:
  - **Truncating**: Truncate sequences to a fixed length of 512 tokens.
  - **Padding**: Apply padding to make all sequences in a batch have the same length.
  - **Reasoning**:
    - Limiting the sequence length to 512 tokens balances computational efficiency and context retention: the input size stays manageable while still covering most of the content in the passages. 512 is also a common sequence length in NLP applications, which is why I chose it as a starting point. I will vary the maximum sequence length during tuning.
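One way to check how much content a given maximum length covers is to compute the fraction of sequences that fit without truncation. This is a sketch under assumptions: the helper name `coverage` is hypothetical, and the token counts below are made-up example values, not measurements from BoolQ.

```python
from typing import Iterable


def coverage(lengths: Iterable[int], max_len: int) -> float:
    """Fraction of sequences whose token count fits within `max_len`."""
    lengths = list(lengths)
    if not lengths:
        raise ValueError("lengths must not be empty")
    return sum(length <= max_len for length in lengths) / len(lengths)


# Hypothetical token counts; in the notebook they would come from
# len(tokenize(...)) applied to every cleaned passage.
example_lengths = [80, 120, 95, 610, 300, 45, 512, 1024]
for candidate in (128, 256, 512):
    print(f"max_len={candidate}: {coverage(example_lengths, candidate):.2f} of sequences fit")
```

Running this on the real token counts would show whether a shorter maximum length truncates only a small share of passages, which would reduce the amount of padding the LSTM has to process.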

In [19]:
# Padding/truncation function
def pad_or_truncate(sequence: np.ndarray, max_length: int = max_seq_len, padding_type: str = 'zeros') -> np.ndarray:
    current_length = len(sequence)
    
    # Truncate if the sequence is longer than max_length
    if current_length > max_length:
        return sequence[:max_length]
    
    # If the sequence is shorter than max_length, pad
    elif current_length < max_length:
        if padding_type == 'zeros':
            # Pad with zeros
            padding = np.zeros((max_length - current_length, sequence.shape[1]))
        elif padding_type == 'avg':
            # Pad with the average embedding
            avg_embedding = np.mean(sequence, axis=0)
            padding = np.tile(avg_embedding, (max_length - current_length, 1))
        else:
            raise ValueError("Invalid padding_type. Use 'zeros' or 'avg'.")
        
        padded_sequence = np.vstack((sequence, padding))
        assert padded_sequence.shape == (max_length, sequence.shape[1]), "Padding failed"
        return padded_sequence
    
    else:
        return sequence

Combine all preprocessing steps into a single preprocessing pipeline for easy data preparation.

In [20]:
# Full preprocessing pipeline
def preprocessing_pipeline(text: str) -> np.ndarray:
    # Step 1: Clean the text
    lowercase_text = to_lowercase(text)
    cleaned_text = remove_special_characters_and_urls(lowercase_text)
    prepared_text = remove_extra_whitespace(cleaned_text)
    
    # Step 2: Tokenize the text
    tokens = tokenize(prepared_text)
    
    # Step 3: Get FastText embeddings
    embeddings = get_fasttext_embeddings(tokens)
    
    # Step 4: Pad or truncate the embeddings to a fixed length
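    # Pass the hyperparameters explicitly so that the global padding_type is actually used
    padded_embeddings = pad_or_truncate(embeddings, max_length=max_seq_len, padding_type=padding_type)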
    
    # Ensure the final output shape is correct
    assert padded_embeddings.shape == (max_seq_len, embedding_dim), f"Final embedding shape is {padded_embeddings.shape}, expected ({max_seq_len}, {embedding_dim})"
    
    return padded_embeddings


Print some test results to check that the preprocessing works as expected.

In [193]:
# Function to print some sample embeddings for inspection
def print_sample_embeddings(tokens: list, embeddings: np.ndarray, num_samples: int = 5):
    print(f"Displaying first {num_samples} tokens and their embeddings:")
    for i in range(min(num_samples, len(tokens))):
        print(f"Token: {tokens[i]} - Embedding: {embeddings[i][:5]}...")  # Show the first 5 dimensions of each embedding for brevity

# Test the pipeline
train_data = load_dataset('google/boolq', split='train[:-1000]')
test_question = train_data[5]['question']

print(f"Original Question: {test_question}")
processed_embeddings = preprocessing_pipeline(test_question)

# Run a separate tokenization to print tokens (as they are not returned from the pipeline in the current design)
tokens = tokenize(remove_extra_whitespace(remove_special_characters_and_urls(to_lowercase(test_question))))
print_sample_embeddings(tokens, processed_embeddings)

Original Question: can you use oyster card at epsom station
Displaying first 5 tokens and their embeddings:
Token: can - Embedding: [ 0.04433588  0.09070976  0.05368941  0.18836261 -0.1909022 ]...
Token: you - Embedding: [ 0.10789153 -0.04412316  0.13406183  0.0902128  -0.15022287]...
Token: use - Embedding: [-0.0271657   0.06062786 -0.06484954  0.04023458  0.0097399 ]...
Token: oyster - Embedding: [ 0.02858743 -0.06040148  0.00882877  0.10480718  0.07261764]...
Token: card - Embedding: [0.07530363 0.10068872 0.0854513  0.08289765 0.00129725]...


In [194]:
# Function to log confusion matrix
def log_confusion_matrix(labels, preds):
    # Generate confusion matrix
    cm = confusion_matrix(labels, preds)

    # Plot confusion matrix using Seaborn heatmap
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=[0, 1], yticklabels=[0, 1])
    plt.xlabel('Prediction')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix on Validation Set')

    # Log confusion matrix plot as an image to WandB
    wandb.log({"confusion_matrix": wandb.Image(plt)})
    plt.close()






- **Input Preparation**:
  - Each input is the concatenation of the passage and the question. Both are tokenized, converted into FastText word embeddings (dimension 300), and padded or truncated to 512 tokens each, giving a total length of 1024 (512 * 2).
  - Because the LSTM is created with `batch_first=True`, a batch has the shape `(batch_size, max_seq_len * 2, embedding_dim)`, for example `(32, 1024, 300)` for a batch size of 32.
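The shapes involved can be walked through with placeholder arrays. This is a sketch with stand-in data: `fake_embeddings` is a hypothetical replacement for the preprocessing pipeline, and the batch-first layout is assumed because the LSTM further below is created with `batch_first=True`.

```python
import numpy as np

max_seq_len, embedding_dim, batch_size = 512, 300, 4


def fake_embeddings(n_tokens: int) -> np.ndarray:
    """Stand-in for the preprocessing pipeline: ones for real tokens, zero padding after."""
    out = np.zeros((max_seq_len, embedding_dim), dtype=np.float32)
    out[:n_tokens] = 1.0
    return out


passage = fake_embeddings(65)   # e.g. a 65-token passage
question = fake_embeddings(8)   # e.g. an 8-token question

# Passage first, then question: one example has 2 * max_seq_len rows.
combined = np.concatenate((passage, question), axis=0)
assert combined.shape == (2 * max_seq_len, embedding_dim)

# Stacking examples gives the batch-first layout (batch, sequence, embedding).
batch = np.stack([combined] * batch_size)
assert batch.shape == (batch_size, 2 * max_seq_len, embedding_dim)

# The question occupies the rows starting at index max_seq_len, so slicing the
# combined array back to max_seq_len rows would keep only the passage.
print(combined[:max_seq_len].sum(), combined[max_seq_len:].sum())
```

The last line makes the layout visible: the first half contains only passage rows and the second half only question rows, so any truncation has to happen before the two parts are concatenated.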



In [195]:
class BoolQDataset(Dataset):
    def __init__(self, data, max_seq_length: int = max_seq_len) -> None:
        # `data` is a Hugging Face dataset split with 'passage', 'question' and 'answer' fields
        self.data = data
        self.max_seq_length = max_seq_length
        
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> tuple:
        # Retrieve passage and question
        passage = self.data[idx]['passage']
        question = self.data[idx]['question']
        
        # Preprocess both passage and question using the preprocessing pipeline
        passage_embeddings = preprocessing_pipeline(passage)
        question_embeddings = preprocessing_pipeline(question)
        
        # Concatenate passage and question embeddings (passage first, then question)
        combined_embeddings = np.concatenate((passage_embeddings, question_embeddings), axis=0)
        
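        # Both parts are already padded/truncated to max_seq_length, so the combined
        # sequence has 2 * max_seq_length rows. Truncating it back to max_seq_length
        # would cut off the question entirely, so the limit is doubled here.
        combined_embeddings = pad_or_truncate(combined_embeddings, max_length=2 * self.max_seq_length)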
        
        # Convert the label to tensor (1 for 'yes', 0 for 'no')
        label = 1 if self.data[idx]['answer'] else 0
        
        # Return the combined embeddings and the label as tensors
        return torch.tensor(combined_embeddings, dtype=torch.float32), torch.tensor(label, dtype=torch.long)


In [196]:
# Define the dataset and dataloader objects (sequence lengths are fixed by max_seq_len)
train_dataset = BoolQDataset(train_data)
validation_dataset = BoolQDataset(validation_data)
test_dataset = BoolQDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=val_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=val_batch_size, shuffle=False)

## 4. Model Architecture
- **RNN Type**:
  - **Decision**: Use LSTM for the RNN layers.
  - **Rationale**: LSTM cells help maintain long-term dependencies through gating mechanisms, which is beneficial for reading comprehension tasks where context from the entire passage can be important for answering questions.

- **Model Configuration**:
  - **Input**: Pre-computed FastText embeddings of dimension 300; the model has no trainable embedding layer.
  - **RNN Layers**: Two LSTM layers with a hidden size of 128.
  - **Dropout**: Apply dropout with a rate of 0.3 between the LSTM layers for regularization.
  - **Classifier**: A two-layer fully connected network (hidden layer of size 64) with ReLU activation.

- **Loss and Optimizer**:
  - **Loss Function**: Use Binary Cross-Entropy Loss for the binary classification task.
  - **Optimizer**: Use the Adam optimizer with an initial learning rate of 0.001.
  - **Rationale**:
    - Adam is chosen for its adaptive learning rate, which can improve training stability and convergence.

- **Regularization**:
  - **Dropout**: Applied to reduce overfitting.
  - **Early Stopping**: Monitor validation loss and stop training if it does not improve for 3 consecutive epochs.
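The early stopping rule can be isolated in a small tracker, sketched below. The class name `EarlyStopping` is an assumption and not part of the notebook; the validation losses are the ones reported in this notebook's training log, replayed here to show when the rule triggers.

```python
class EarlyStopping:
    """Track the best validation loss and signal when to stop."""

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = 0
        self.bad_epochs = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch; return True if training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0  # this is where a checkpoint could be saved
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Validation losses per epoch, as reported in the training log.
val_losses = [0.6749, 0.6755, 0.6748, 0.6747, 0.6763, 0.6752, 0.6787]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best epoch {stopper.best_epoch} (loss {stopper.best_loss})")
```

With a patience of 3, the best loss is reached in epoch 4 and training stops after epoch 7, which matches the log. The branch that resets `bad_epochs` is also the natural place to save the model weights if checkpointing is added.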


In [197]:
# Optimizer selection function
def get_optimizer(optimizer_choice, model, learning_rate, weight_decay=0):
    if optimizer_choice == 'Adam':
        return optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer_choice == 'AdamW':
        return optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_choice == 'SGD':
        return optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_choice}")


In [198]:
# Model definition: 2-layer LSTM followed by a 2-layer classifier
class LSTMModel(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, output_dim, n_layers, dropout_rate):
        super(LSTMModel, self).__init__()
        
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, batch_first=True, dropout=dropout_rate)
        self.fc1 = nn.Linear(hidden_dim, 64)
        self.fc2 = nn.Linear(64, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        lstm_out, (hn, cn) = self.lstm(x)
        last_hidden_state = hn[-1]
        x = self.relu(self.fc1(last_hidden_state))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.sigmoid(x)
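Because the inputs are right-padded, `hn[-1]` in the model above is the hidden state after the LSTM has also processed the padding rows. One possible alternative, sketched here as an assumption and not as the notebook's method, is to pack the batch with `pack_padded_sequence` so that the final hidden state comes from the last real token. The helper name and the `lengths` tensor are hypothetical; the dataset would have to return the true token counts for this to work.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


def last_hidden_ignoring_padding(lstm: nn.LSTM, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Run a batch_first LSTM on right-padded input and return the top layer's
    hidden state at each sequence's last real token."""
    packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
    _, (hn, _) = lstm(packed)
    return hn[-1]  # shape: (batch_size, hidden_dim)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=300, hidden_size=128, num_layers=2, batch_first=True)

# Two right-padded sequences with 4 and 6 real tokens (random stand-in embeddings).
x = torch.zeros(2, 6, 300)
x[0, :4] = torch.randn(4, 300)
x[1, :6] = torch.randn(6, 300)
lengths = torch.tensor([4, 6])

with torch.no_grad():
    h = last_hidden_ignoring_padding(lstm, x, lengths)
    # Reference: the first sequence on its own, without any padding.
    _, (h_ref, _) = lstm(x[0:1, :4])

print(h.shape)
```

The packed result for the first sequence agrees with running that sequence alone, which is the behavior padding would otherwise disturb.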

In [199]:
# Instantiate the model
model = LSTMModel(embedding_dim, hidden_dim, output_dim, n_layers, dropout_rate)

# Get optimizer based on choice
optimizer = get_optimizer(optimizer_choice, model, learning_rate, weight_decay)

# Define loss function
criterion = nn.BCELoss()

In [21]:
# Move model and criterion to the correct device (GPU/CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Type: {device.type}, Device: {device}")

Type: cpu, Device: cpu


In [None]:
model = model.to(device)
criterion = criterion.to(device)

## 5. Training
- **Number of Epochs**: Train for up to 10 epochs (`n_epochs`) with early stopping.
- **Checkpointing**: Save the model with the best validation accuracy to avoid overfitting.

- **Hyperparameter Experimentation**:
  - **Learning Rate**: Test various learning rates (e.g., 0.001, 0.0005, 0.0001) to find an optimal balance between convergence speed and training stability.
  - **Batch Size**: Experiment with different batch sizes (e.g., 16, 32, 64) to optimize memory usage and training time.
  - **Dropout Rate**: Adjust dropout rates (e.g., 0.2, 0.3, 0.5) to find the optimal level of regularization.
  - **Hidden Layer Size**: Try varying the number of hidden units in the RNN and classifier layers (e.g., 64, 128, 256) to assess their impact on model capacity.
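The candidate values listed above span a full grid that can be enumerated with `itertools.product`. This is a minimal sketch: the dictionary keys reuse the notebook's hyperparameter names, while `search_space` and `iter_configs` are hypothetical helpers.

```python
from itertools import product
from typing import Dict, Iterator, List

search_space: Dict[str, List] = {
    "learning_rate": [1e-3, 5e-4, 1e-4],
    "train_batch_size": [16, 32, 64],
    "dropout_rate": [0.2, 0.3, 0.5],
    "hidden_dim": [64, 128, 256],
}


def iter_configs(space: Dict[str, List]) -> Iterator[Dict]:
    """Yield one config dict per combination of hyperparameter values."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield dict(zip(keys, values))


configs = list(iter_configs(search_space))
print(len(configs))  # 3 * 3 * 3 * 3 = 81 combinations
print(configs[0])
```

Since 81 full training runs are expensive on CPU, varying one hyperparameter at a time or sampling a random subset of `configs` may be more practical. Each config could be passed as `config=` to `wandb.init` so that runs stay comparable.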


In [201]:
def train_model(model, train_loader, validation_loader, n_epochs, patience):
    best_val_loss = float('inf')
    early_stop_count = 0
    
    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.0
        all_labels = []
        all_preds = []
        
        # Training loop
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs).squeeze(1)  # squeeze only the output dim so a batch of size 1 keeps its shape
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend((outputs.cpu().detach().numpy() > 0.5).astype(int))

        avg_train_loss = running_loss / len(train_loader)
        train_acc = accuracy_score(all_labels, all_preds)
        
        # Validation loop
        model.eval()
        val_loss = 0.0
        val_labels = []
        val_preds = []
        
        with torch.no_grad():
            for inputs, labels in validation_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs).squeeze(1)  # squeeze only the output dim so a batch of size 1 keeps its shape
                loss = criterion(outputs, labels.float())
                
                val_loss += loss.item()
                val_labels.extend(labels.cpu().numpy())
                val_preds.extend((outputs.cpu().detach().numpy() > 0.5).astype(int))

        avg_val_loss = val_loss / len(validation_loader)
        val_acc = accuracy_score(val_labels, val_preds)
        val_precision = precision_score(val_labels, val_preds)
        val_recall = recall_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds)

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "train_accuracy": train_acc,
            "val_loss": avg_val_loss,
            "val_accuracy": val_acc,
            "val_precision": val_precision,
            "val_recall": val_recall,
            "val_f1": val_f1
        })
        
        print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
        
        # Log the confusion matrix for this epoch's validation predictions
        log_confusion_matrix(val_labels, val_preds)
        
        # Early stopping logic
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stop_count = 0
        else:
            early_stop_count += 1
            if early_stop_count >= patience:
                print("Early stopping triggered.")
                break

    # Log final metrics at the end of training
    wandb.log({"best_val_loss": best_val_loss, "final_val_accuracy": val_acc, "final_val_f1": val_f1})
    
    # Finish WandB run
    wandb.finish()

In [202]:
# Assuming train_loader and validation_loader are already defined
train_model(model, train_loader, validation_loader, n_epochs, patience)

Epoch 1/10 - Train Loss: 0.6638, Val Loss: 0.6749, Val Acc: 0.5950
Epoch 2/10 - Train Loss: 0.6613, Val Loss: 0.6755, Val Acc: 0.5950
Epoch 3/10 - Train Loss: 0.6620, Val Loss: 0.6748, Val Acc: 0.5950
Epoch 4/10 - Train Loss: 0.6623, Val Loss: 0.6747, Val Acc: 0.5950
Epoch 5/10 - Train Loss: 0.6623, Val Loss: 0.6763, Val Acc: 0.5950
Epoch 6/10 - Train Loss: 0.6614, Val Loss: 0.6752, Val Acc: 0.5950
Epoch 7/10 - Train Loss: 0.6611, Val Loss: 0.6787, Val Acc: 0.5950
Early stopping triggered.



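Run history:

| Metric | History |
|---|---|
| best_val_loss | ▁ |
| epoch | ▁▂▃▅▆▇█ |
| final_val_accuracy | ▁ |
| final_val_f1 | ▁ |
| train_accuracy | ▁▆▆▆▇█▇ |
| train_loss | █▁▃▄▄▂▁ |
| val_accuracy | ▁▁▁▁▁▁▁ |
| val_f1 | ▁▁▁▁▁▁▁ |
| val_loss | ▁▂▁▁▄▂█ |
| val_precision | ▁▁▁▁▁▁▁ |

Run summary:

| Metric | Value |
|---|---|
| best_val_loss | 0.67472 |
| epoch | 7.0 |
| final_val_accuracy | 0.595 |
| final_val_f1 | 0.74608 |
| train_accuracy | 0.62656 |
| train_loss | 0.66108 |
| val_accuracy | 0.595 |
| val_f1 | 0.74608 |
| val_loss | 0.67868 |
| val_precision | 0.595 |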

## 6. Evaluation
- **Primary Metric**:
  - **Accuracy**: Chosen as the main evaluation metric since it reflects the overall model performance in binary classification.
- **Baseline Comparison**:
  - Compare the model's accuracy against a majority class baseline (e.g., always predicting "yes") to understand the model's relative performance.
- **Error Analysis**:
  - Analyze the confusion matrix to identify patterns in misclassifications and judge the types of errors the model makes.
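The majority-class baseline can be computed directly from the label counts printed earlier in this notebook. The sketch below also computes the binary cross-entropy of a model that outputs the same probability for every example, as a reference point for the loss values. The function names are assumptions for illustration.

```python
import math


def majority_accuracy(n_yes: int, n_no: int) -> float:
    """Accuracy of always predicting the more frequent class."""
    return max(n_yes, n_no) / (n_yes + n_no)


def constant_bce(p_pred: float, n_yes: int, n_no: int) -> float:
    """Mean binary cross-entropy when predicting probability `p_pred` for every example."""
    total = n_yes + n_no
    return -(n_yes * math.log(p_pred) + n_no * math.log(1.0 - p_pred)) / total


# (yes, no) counts per split, taken from the data inspection output above.
splits = {"train": (5279, 3148), "validation": (595, 405), "test": (2033, 1237)}

# Share of "yes" answers in the training split, used as the constant prediction.
train_prior = splits["train"][0] / sum(splits["train"])

for name, (n_yes, n_no) in splits.items():
    print(
        f"{name}: majority accuracy = {majority_accuracy(n_yes, n_no):.4f}, "
        f"BCE at constant p = {train_prior:.3f}: {constant_bce(train_prior, n_yes, n_no):.4f}"
    )
```

A model whose accuracy equals the majority accuracy and whose loss sits close to the constant-prediction value is likely predicting roughly the class prior for every input, so both numbers are worth comparing against before interpreting the confusion matrix.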


## 7. Interpretation
- **Performance Expectations**:
  - Learning from the results of Project 1, I am setting more realistic expectations this time. I expect the LSTM to reach an accuracy of 65-70%, which would beat the baseline of always predicting "yes" (about 60-63% accuracy, depending on the split).
