In [None]:
import os
import pickle
import random
import re
import time
from collections import Counter
from datetime import datetime

import gensim
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import wandb
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from matplotlib_venn import venn2
from nltk.tokenize import word_tokenize
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange

In [None]:
SEED = 42

# Seed every RNG the notebook relies on so that runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current one
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may select different kernels between runs; disable it so
# the deterministic flag above actually yields repeatable results.
torch.backends.cudnn.benchmark = False

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Download tokenizer resources. NLTK >= 3.8.2 makes word_tokenize read the
# 'punkt_tab' tables instead of the pickled 'punkt' model, so fetch both to
# work across NLTK versions (nltk.download reports, rather than raises, on failure).
for nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(nltk_resource)

# Download the FastText model from Hugging Face (cached locally after the first call)
model_path = hf_hub_download(repo_id="facebook/fasttext-en-vectors", filename="model.bin")

# Load the model; only its word-vector table (wv) is used downstream
ft_model = gensim.models.fasttext.load_facebook_model(model_path)
wv = ft_model.wv
print(f"Word embeddings loaded: {len(wv.key_to_index)} words, {wv.vector_size} dimensions")

In [None]:
# Three splits are carved out of CommonsenseQA:
#   train -> every training example except the last 1000
#   valid -> the last 1000 training examples (used for model selection)
#   test  -> the official validation split
train = load_dataset("tau/commonsense_qa", split="train[:-1000]")
valid = load_dataset("tau/commonsense_qa", split="train[-1000:]")
test = load_dataset("tau/commonsense_qa", split="validation")

split_sizes = (len(train), len(valid), len(test))
print("Dataset loaded - Train: {}, Validation: {}, Test: {}".format(*split_sizes))

In [None]:
# pandas views of each split for exploratory analysis
train_df, valid_df, test_df = (pd.DataFrame(split) for split in (train, valid, test))

# Show the first training example, one field at a time
for heading, column in (
    ("\nSample question:", "question"),
    ("\nSample choices:", "choices"),
    ("\nSample answer key:", "answerKey"),
):
    print(heading)
    print(train_df[column].iloc[0])

In [None]:
# Add question length (number of NLTK tokens) to each split's dataframe
for split_df in (train_df, valid_df, test_df):
    split_df['question_length'] = split_df['question'].apply(lambda q: len(word_tokenize(q)))

# The splits have very different sizes, so raw counts would make the smaller
# splits look flat. Normalise each histogram on its own (stat='density') to
# compare shapes, and use one integer-aligned bin per length (discrete=True)
# so the three histograms share bin edges.
fig, ax = plt.subplots(figsize=(12, 6))
for label, split_df in (("Train", train_df), ("Validation", valid_df), ("Test", test_df)):
    sns.histplot(data=split_df, x='question_length', kde=True, stat='density',
                 discrete=True, label=label, alpha=0.6, ax=ax)
ax.set_title('Distribution of Question Lengths')
ax.set_xlabel('Number of tokens in question')
ax.set_ylabel('Density')
ax.legend()
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot answer key distributions for each split (this dict is reused by later cells)
datasets = {"Training": train_df, "Validation": valid_df, "Test": test_df}

for ax, (split, df) in zip(axes, datasets.items()):
    answer_counts = Counter(df['answerKey'])
    # Sort so every panel reads A..E; matplotlib would otherwise order the
    # categorical bars by first appearance in the data.
    answer_labels = sorted(answer_counts)
    bars = ax.bar(answer_labels, [answer_counts[label] for label in answer_labels])
    ax.set_title(f"{split} Answer Key Distribution")
    ax.set_xlabel("Answer Labels")
    ax.set_ylabel("Frequency")

    # Count labels placed just above each bar, independent of the y-axis scale
    ax.bar_label(bars)

plt.tight_layout()
plt.show()

In [None]:
def get_question_type(question):
    """Classify a question by the first wh-word it contains.

    Words are matched as whole tokens, so "however" is not counted as "how"
    and "whatever" is not "what", while punctuation does not hide a match
    ("...found where?" is "where"). The earliest question word in the text wins,
    so a question that opens with a wh-word is labelled by it.

    Args:
        question: Question text.

    Returns:
        One of 'what', 'which', 'who', 'how', 'why', 'when', 'where', or 'other'.
    """
    question_words = {'what', 'which', 'who', 'how', 'why', 'when', 'where'}

    for token in re.findall(r"[a-z]+", question.lower()):
        if token in question_words:
            return token

    return 'other'

# Tag every question with its wh-word category
for split_df in (train_df, valid_df, test_df):
    split_df['question_type'] = split_df['question'].apply(get_question_type)

# Plot question type distribution, one panel per split
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (split, df) in zip(axes, datasets.items()):
    type_counts = Counter(df['question_type'])
    type_labels = sorted(type_counts)

    ax.bar(type_labels, [type_counts[qtype] for qtype in type_labels])
    ax.set_title(f"{split} Question Type Distribution")
    ax.set_xlabel("Question Type")
    ax.set_ylabel("Frequency")
    ax.tick_params(axis='x', labelrotation=45)

plt.tight_layout()
plt.show()

In [None]:
def get_vocabulary(texts):
    """Return the set of distinct lower-cased NLTK tokens found in ``texts``.

    Args:
        texts: Iterable of strings.

    Returns:
        set[str] of unique tokens (empty for no input).
    """
    return {token for text in texts for token in word_tokenize(text.lower())}

# Lower-cased question vocabularies of the train and test splits
train_vocab = get_vocabulary(train_df['question'].tolist())
test_vocab = get_vocabulary(test_df['question'].tolist())

# Region sizes of the two-set Venn diagram
overlap = len(train_vocab & test_vocab)
train_only = len(train_vocab) - overlap
test_only = len(test_vocab) - overlap

plt.figure(figsize=(8, 6))
venn2(subsets=(train_only, test_only, overlap),
      set_labels=('Train Vocabulary', 'Test Vocabulary'))
plt.title('Vocabulary Overlap Between Train and Test Sets')
plt.show()

print(f"Train vocabulary size: {len(train_vocab)}")
print(f"Test vocabulary size: {len(test_vocab)}")
print(f"Vocabulary overlap: {overlap} words ({overlap / len(train_vocab) * 100:.2f}% of train vocabulary)")

In [None]:
def preprocess_text(text):
    """
    Tokenize text with NLTK without any normalisation.

    Case, punctuation and every word are kept exactly as they appear.

    Args:
        text: Input string.

    Returns:
        List of tokens; an empty list for empty or whitespace-only input.

    Raises:
        TypeError: If ``text`` is not a ``str``.
    """
    if isinstance(text, str):
        # strip() removes exactly the characters str.isspace() recognises
        if text.strip():
            return word_tokenize(text)
        return []

    raise TypeError(f"Input must be a string, got {type(text).__name__} instead")

In [None]:
class QAEmbeddingDataset(Dataset):
    """Dataset for CommonsenseQA with averaged word embeddings.

    Each item is ``(question_tensor, choices_tensor, answer_tensor)``:
    a float32 ``[embedding_dim]`` question vector, a float32
    ``[num_choices, embedding_dim]`` matrix of choice vectors, and a scalar
    long answer index (A->0, B->1, ...).

    Text embeddings are memoised in ``self.cache`` and can be persisted with
    :meth:`save_cache`. When the dataset is consumed through a ``DataLoader``
    with ``num_workers > 0``, each worker fills its *own copy* of the cache and
    nothing flows back to this object. :meth:`save_cache` therefore computes
    any missing entries itself before writing.
    """

    def __init__(self, hf_dataset, word_vectors, cache_path=None):
        """
        Args:
            hf_dataset: Indexable/iterable collection of CommonsenseQA examples.
            word_vectors: KeyedVectors-like object (``vector_size``, ``in``, ``[]``).
            cache_path: Optional pickle file to load/save the embedding cache.
        """
        self.data = hf_dataset
        self.wv = word_vectors
        self.embedding_dim = word_vectors.vector_size
        self.cache = {}
        self.cache_path = cache_path

        # Load cache if provided and exists.
        # SECURITY: pickle.load can execute arbitrary code -- only point
        # cache_path at files this notebook wrote itself.
        if cache_path and os.path.exists(cache_path):
            try:
                with open(cache_path, 'rb') as f:
                    self.cache = pickle.load(f)
                print(f"Loaded {len(self.cache)} cached embeddings")
            except Exception as e:
                print(f"Failed to load cache: {e}")
                self.cache = {}

    def __len__(self):
        return len(self.data)

    def get_embedding(self, text):
        """Return the mean word vector of ``text`` (cached per exact string).

        Tokens missing from the word vectors are skipped; if none remain the
        result is a float32 zero vector of length ``embedding_dim``.
        """
        if text in self.cache:
            return self.cache[text]

        valid_tokens = [word for word in preprocess_text(text) if word in self.wv]

        if valid_tokens:
            embedding = np.mean([self.wv[word] for word in valid_tokens], axis=0)
        else:
            # float32 to match the dtype of the averaged word vectors
            embedding = np.zeros(self.embedding_dim, dtype=np.float32)

        self.cache[text] = embedding
        return embedding

    def __getitem__(self, idx):
        example = self.data[idx]

        question_embedding = self.get_embedding(example["question"])
        choice_embeddings = [self.get_embedding(choice) for choice in example["choices"]["text"]]

        # Convert answer key to index (A->0, B->1, etc.)
        answer_index = ord(example["answerKey"]) - ord("A")

        # torch.tensor copies, so callers can never mutate the cached array.
        question_tensor = torch.tensor(question_embedding, dtype=torch.float32)
        # Stack into a single ndarray first: building a tensor from a Python
        # list of ndarrays is very slow. np.stack returns a fresh array, so
        # sharing its memory via from_numpy cannot touch the cache.
        choices_tensor = torch.from_numpy(np.stack(choice_embeddings).astype(np.float32, copy=False))
        answer_tensor = torch.tensor(answer_index, dtype=torch.long)

        return question_tensor, choices_tensor, answer_tensor

    def save_cache(self):
        """Save the embedding cache to ``cache_path`` (no-op when it is unset).

        Every question/choice embedding is computed first, because entries
        produced inside DataLoader worker processes never reach this object.
        """
        if not self.cache_path:
            return

        for example in self.data:
            self.get_embedding(example["question"])
            for choice in example["choices"]["text"]:
                self.get_embedding(choice)

        with open(self.cache_path, 'wb') as f:
            pickle.dump(self.cache, f)
        print(f"Saved {len(self.cache)} embeddings to cache")

In [None]:
# Averaged-embedding datasets, each with its own on-disk cache file
train_embedding_dataset, valid_embedding_dataset = (
    QAEmbeddingDataset(split_data, wv, cache_path=cache_file)
    for split_data, cache_file in (
        (train, 'train_embeddings.pkl'),
        (valid, 'valid_embeddings.pkl'),
    )
)

# Loader settings shared by both splits
loader_options = dict(num_workers=4, pin_memory=True)

train_loader = DataLoader(train_embedding_dataset, batch_size=128, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_embedding_dataset, batch_size=256, shuffle=False, **loader_options)

print(f"Created train loader with {len(train_loader)} batches")
print(f"Created validation loader with {len(valid_loader)} batches")

In [None]:
class QAEmbeddingModel(nn.Module):
    """
    Question-answering scorer over pre-computed text embeddings.

    Pipeline:
    1. Independent Linear+ReLU+Dropout projections for the question and choices
    2. Each projected choice is concatenated with the projected question
    3. A two-layer MLP maps every (question, choice) pair to one logit
    """
    def __init__(self, embedding_dim, hidden_dim=128, dropout_rate=0.2):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        def make_projection():
            return nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            )

        # Created in this order (question, then choice, then classifier)
        self.question_projection = make_projection()
        self.choice_projection = make_projection()

        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, question, choices):
        """
        Args:
            question: [batch_size, embedding_dim] question embeddings
            choices: [batch_size, num_choices, embedding_dim] choice embeddings

        Returns:
            [batch_size, num_choices] logits, one per choice
        """
        num_choices = choices.size(1)

        # nn.Linear acts on the last dimension, so the 3-D choices tensor can
        # be projected directly without flattening.
        question_hidden = self.question_projection(question)
        choices_hidden = self.choice_projection(choices)

        # Pair the question with every choice along the feature axis
        question_tiled = question_hidden.unsqueeze(1).expand(-1, num_choices, -1)
        pair_features = torch.cat([question_tiled, choices_hidden], dim=-1)

        return self.classifier(pair_features).squeeze(-1)

In [None]:
def _run_embedding_epoch(model, loader, criterion, optimizer=None, desc=""):
    """Make one pass of the embedding model over ``loader``.

    Trains (with gradient clipping at norm 1.0) when ``optimizer`` is given;
    otherwise evaluates with gradients disabled. Batches are moved to the
    module-level ``device``.

    Returns:
        (mean per-example loss, accuracy) over the whole loader.

    Raises:
        ValueError: If the loader yields no examples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for question_batch, choices_batch, answer_batch in tqdm(loader, desc=desc):
            question_batch = question_batch.to(device)
            choices_batch = choices_batch.to(device)
            answer_batch = answer_batch.to(device)

            outputs = model(question_batch, choices_batch)
            loss = criterion(outputs, answer_batch)

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
                optimizer.step()

            batch_size = answer_batch.size(0)
            # Weight by batch size so a short final batch does not skew the mean
            total_loss += loss.item() * batch_size
            seen += batch_size
            correct += (outputs.argmax(dim=1) == answer_batch).sum().item()

    if seen == 0:
        raise ValueError(f"{desc or 'data'} loader yielded no examples")
    return total_loss / seen, correct / seen


def train_embedding_model(model, train_loader, valid_loader, epochs=30, lr=1e-3,
                          weight_decay=1e-5, log_wandb=True):
    """
    Train the embedding-based QA model.

    Each epoch trains on ``train_loader``, evaluates on ``valid_loader``, halves
    the learning rate after 3 epochs without validation-accuracy improvement,
    and saves the best checkpoint to ``checkpoints/best_embedding_model.pt``.

    Args:
        model: The model to train (already on ``device``)
        train_loader: Training data loader
        valid_loader: Validation data loader
        epochs: Number of training epochs
        lr: Learning rate
        weight_decay: Weight decay coefficient
        log_wandb: Whether to log metrics to W&B

    Returns:
        (model, best_val_accuracy). The returned model holds the weights of the
        *last* epoch; the best weights are in the checkpoint file.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # ``verbose`` is deprecated on LR schedulers in recent PyTorch, so LR
    # reductions are reported manually after each scheduler step.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    if log_wandb:
        wandb.init(
            project="commonsense-qa-embeddings",
            name=f"embedding-model-{datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}",
            config={
                "model": "qa_embedding_model",
                "embedding_dim": model.embedding_dim,
                "hidden_dim": model.hidden_dim,
                "epochs": epochs,
                "learning_rate": lr,
                "weight_decay": weight_decay,
            }
        )
        wandb.watch(model)

    os.makedirs("checkpoints", exist_ok=True)

    best_val_accuracy = 0.0
    start_time = time.time()

    try:
        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            avg_train_loss, train_accuracy = _run_embedding_epoch(
                model, train_loader, criterion, optimizer, desc="Training")
            avg_val_loss, val_accuracy = _run_embedding_epoch(
                model, valid_loader, criterion, desc="Validation")

            # Update learning rate on validation accuracy (mode='max')
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_accuracy)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"Reducing learning rate from {previous_lr:.2e} to {current_lr:.2e}")

            # Save best model
            if val_accuracy > best_val_accuracy:
                print(f"Validation accuracy improved from {best_val_accuracy:.4f} to {val_accuracy:.4f}")
                best_val_accuracy = val_accuracy
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_accuracy": val_accuracy,
                }, "checkpoints/best_embedding_model.pt")

            print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
            print(f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

            if log_wandb:
                wandb.log({
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "train_accuracy": train_accuracy,
                    "val_loss": avg_val_loss,
                    "val_accuracy": val_accuracy,
                    "learning_rate": current_lr,
                })

        training_time = (time.time() - start_time) / 60
        print(f"Training completed in {training_time:.2f} minutes")
        print(f"Best validation accuracy: {best_val_accuracy:.4f}")

        if log_wandb:
            wandb.run.summary["best_val_accuracy"] = best_val_accuracy
    finally:
        # Close the W&B run even if training is interrupted or fails, so the
        # next wandb.init in this kernel starts a fresh run.
        if log_wandb:
            wandb.finish()

    return model, best_val_accuracy

In [None]:
# Model hyper-parameters; the input width follows the FastText vectors
embedding_dim = wv.vector_size
hidden_dim = 128
dropout_rate = 0.2

embedding_model = QAEmbeddingModel(
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    dropout_rate=dropout_rate,
).to(device)

print("Starting embedding model training...")
trained_embedding_model, best_embedding_accuracy = train_embedding_model(
    embedding_model,
    train_loader,
    valid_loader,
    epochs=30,
    lr=1e-3,
    weight_decay=1e-5,
)
print(f"Embedding model training complete! Best validation accuracy: {best_embedding_accuracy:.4f}")

# Persist the text-embedding caches so later runs can skip recomputation
for embedding_dataset in (train_embedding_dataset, valid_embedding_dataset):
    embedding_dataset.save_cache()

In [None]:
class QARNNDataset(Dataset):
    """Dataset for CommonsenseQA yielding token-id sequences for an RNN.

    Each item is ``(sequences, sequence_lengths, answer)``:

    - ``sequences`` holds one LongTensor per answer choice containing the ids
      of ``question <SEP> choice``.
    - ``sequence_lengths`` holds their lengths.
    - ``answer`` is the gold index (A->0, B->1, ...).

    Ids come from ``self.word_to_idx``, where 0/1/2 are <PAD>/<UNK>/<SEP>. A
    model's embedding table is built from ONE vocabulary, so every split fed to
    that model must encode with it. Pass the training split's ``word_to_idx``
    as ``vocab`` when building validation/test datasets.
    """

    # Special tokens (ids 0, 1, 2 in every vocabulary built here)
    PAD_TOKEN = "<PAD>"
    UNK_TOKEN = "<UNK>"
    SEP_TOKEN = "<SEP>"

    def __init__(self, hf_dataset, word_vectors, max_seq_length=50, cache_path=None,
                 vocab=None, vocab_max_examples=None):
        """
        Args:
            hf_dataset: Indexable/iterable collection of CommonsenseQA examples.
            word_vectors: KeyedVectors-like object (``vector_size``, ``in``, ``[]``).
            max_seq_length: Maximum tokens per question+<SEP>+choice sequence (>= 1).
            cache_path: Optional pickle file for cached sequences and vocabulary.
            vocab: Optional existing word->index mapping to reuse (e.g. the
                training split's). When given, no vocabulary is built.
            vocab_max_examples: How many leading examples to scan when building
                a vocabulary; ``None`` scans the whole dataset.

        Raises:
            ValueError: If ``max_seq_length`` < 1.
        """
        if max_seq_length < 1:
            raise ValueError(f"max_seq_length must be >= 1, got {max_seq_length}")

        self.data = hf_dataset
        self.wv = word_vectors
        self.embedding_dim = word_vectors.vector_size
        self.max_seq_length = max_seq_length
        self.vocab_max_examples = vocab_max_examples
        self.cache_path = cache_path
        self.cache = {}
        self.word_to_idx = {
            self.PAD_TOKEN: 0,
            self.UNK_TOKEN: 1,
            self.SEP_TOKEN: 2,
        }

        cached_sequences, cached_vocab = self._load_cache(cache_path)

        if vocab is not None:
            self.word_to_idx = dict(vocab)
        elif cached_vocab is not None and len(cached_vocab) > 3:
            self.word_to_idx = cached_vocab
        else:
            self._build_vocab(vocab_max_examples)

        # Cached ids are only meaningful under the vocabulary that produced them
        if cached_sequences and cached_vocab == self.word_to_idx:
            self.cache = cached_sequences
        elif cached_sequences:
            print(f"Discarding {len(cached_sequences)} cached sequences built with a different vocabulary")

        print(f"Vocabulary size: {len(self.word_to_idx)}")

    def _load_cache(self, cache_path):
        """Return ``(sequences, vocab)`` read from ``cache_path``.

        Returns ``({}, None)`` when the file is absent or unreadable. Sequences
        saved under a different ``max_seq_length`` are dropped.
        SECURITY: pickle.load can execute arbitrary code -- only load cache files
        this notebook produced.
        """
        if not (cache_path and os.path.exists(cache_path)):
            return {}, None
        try:
            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            sequences = cache_data.get('sequences', {})
            cached_vocab = cache_data.get('vocab')
            if cache_data.get('max_seq_length') != self.max_seq_length:
                sequences = {}
        except Exception as e:
            print(f"Failed to load cache: {e}")
            return {}, None
        print(f"Loaded {len(sequences)} cached sequences")
        return sequences, cached_vocab

    def _build_vocab(self, max_examples=None):
        """Add every token of the dataset that has a word vector to ``word_to_idx``.

        Scans the first ``max_examples`` examples, or all of them when ``None``.
        New ids are assigned consecutively after the special tokens.
        """
        print("Building vocabulary from dataset...")
        n_examples = len(self.data) if max_examples is None else min(max_examples, len(self.data))

        for example_idx in trange(n_examples):
            example = self.data[example_idx]
            for text in (example["question"], *example["choices"]["text"]):
                for token in preprocess_text(text):
                    if token not in self.word_to_idx and token in self.wv:
                        self.word_to_idx[token] = len(self.word_to_idx)

        print(f"Built vocabulary with {len(self.word_to_idx)} words")

    def _get_word_idx(self, word):
        """Get vocabulary index for word, using UNK for unknown words"""
        return self.word_to_idx.get(word, self.word_to_idx[self.UNK_TOKEN])

    @staticmethod
    def _join_and_truncate(question_tokens, choice_tokens, sep_token, max_len):
        """Build ``question + [sep] + choice`` capped at ``max_len`` (>= 1) tokens.

        The question is shortened first (keeping its beginning) so the choice,
        which is the only part that differs between an example's candidate
        sequences, survives. The choice itself is cut only if it alone
        exceeds the budget.
        """
        choice_tokens = choice_tokens[:max(max_len - 1, 0)]
        question_budget = max(max_len - 1 - len(choice_tokens), 0)
        return question_tokens[:question_budget] + [sep_token] + choice_tokens

    def prepare_sequence(self, question, choice):
        """Tokenize and encode one question-choice pair (cached per pair)."""
        # Tuple key: a "question:choice" string is ambiguous when either contains ':'
        cache_key = (question, choice)
        if cache_key in self.cache:
            return self.cache[cache_key]

        combined_tokens = self._join_and_truncate(
            preprocess_text(question),
            preprocess_text(choice),
            self.SEP_TOKEN,
            self.max_seq_length,
        )
        token_ids = [self._get_word_idx(token) for token in combined_tokens]

        self.cache[cache_key] = token_ids
        return token_ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        question = example["question"]

        sequences = []
        sequence_lengths = []
        for choice in example["choices"]["text"]:
            token_ids = self.prepare_sequence(question, choice)
            sequences.append(torch.tensor(token_ids, dtype=torch.long))
            sequence_lengths.append(len(token_ids))

        # Convert answer key to index (A->0, B->1, etc.)
        answer = torch.tensor(ord(example["answerKey"]) - ord("A"), dtype=torch.long)

        return sequences, sequence_lengths, answer

    def get_embedding_matrix(self):
        """Build a ``[vocab_size, embedding_dim]`` matrix for the embedding layer.

        <PAD> gets a zero row. <UNK>, <SEP> and any word whose vector lookup
        fails get small random rows (0.1 * N(0, 1), drawn from torch's global RNG).
        Every other word gets its pre-trained vector.
        """
        embedding_matrix = torch.zeros(len(self.word_to_idx), self.embedding_dim)

        for word, idx in self.word_to_idx.items():
            if word == self.PAD_TOKEN:
                continue
            if word in (self.UNK_TOKEN, self.SEP_TOKEN):
                embedding_matrix[idx] = torch.randn(self.embedding_dim) * 0.1
                continue
            try:
                embedding_matrix[idx] = torch.tensor(self.wv[word])
            except (KeyError, ValueError):
                embedding_matrix[idx] = torch.randn(self.embedding_dim) * 0.1

        return embedding_matrix

    def save_cache(self):
        """Save sequences, vocabulary and ``max_seq_length`` to ``cache_path``.

        Every pair is encoded first, because sequences produced inside
        DataLoader worker processes never reach this object.
        """
        if not self.cache_path:
            return

        for example in self.data:
            for choice in example["choices"]["text"]:
                self.prepare_sequence(example["question"], choice)

        cache_data = {
            'sequences': self.cache,
            'vocab': self.word_to_idx,
            'max_seq_length': self.max_seq_length,
        }
        with open(self.cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        print(f"Saved {len(self.cache)} sequences and {len(self.word_to_idx)} vocabulary items to cache")

In [None]:
def collate_rnn_batch(batch):
    """Collate ``(sequences, lengths, answer)`` items for the RNN model.

    Sequences and lengths are left as per-example Python lists (the model pads
    them itself); only the scalar answers are stacked into one tensor.

    Returns:
        (list of per-example sequence lists,
         list of per-example length lists,
         LongTensor of answers with shape [batch_size])
    """
    sequence_lists = [sequences for sequences, _, _ in batch]
    length_lists = [lengths for _, lengths, _ in batch]
    answers_tensor = torch.stack([answer for _, _, answer in batch])
    return sequence_lists, length_lists, answers_tensor

In [None]:
# Create RNN datasets
train_rnn_dataset = QARNNDataset(
    train,
    wv,
    max_seq_length=50,
    cache_path='train_rnn_cache.pkl'
)

valid_rnn_dataset = QARNNDataset(
    valid,
    wv,
    max_seq_length=50,
    cache_path='valid_rnn_cache.pkl'
)

# The RNN's embedding table is sized and initialised from the *training*
# vocabulary, so validation must encode tokens with that same word->index
# mapping. Otherwise its ids point at unrelated embedding rows. Cached
# validation sequences encoded with a different vocabulary are invalid too.
if valid_rnn_dataset.word_to_idx != train_rnn_dataset.word_to_idx:
    valid_rnn_dataset.word_to_idx = train_rnn_dataset.word_to_idx
    valid_rnn_dataset.cache = {}
print(f"Validation dataset uses the training vocabulary ({len(valid_rnn_dataset.word_to_idx)} words)")

# Create RNN data loaders (collate keeps variable-length sequences as lists)
rnn_loader_options = dict(collate_fn=collate_rnn_batch, num_workers=4, pin_memory=True)

train_rnn_loader = DataLoader(train_rnn_dataset, batch_size=64, shuffle=True, **rnn_loader_options)
valid_rnn_loader = DataLoader(valid_rnn_dataset, batch_size=64, shuffle=False, **rnn_loader_options)

print(f"Created RNN train loader with {len(train_rnn_loader)} batches")
print(f"Created RNN validation loader with {len(valid_rnn_loader)} batches")

In [None]:
class QARNNModel(nn.Module):
    """RNN model for CommonsenseQA with 2-layer LSTM and 2-layer classifier.

    Every ``question <SEP> choice`` token sequence is encoded by a 2-layer
    bidirectional LSTM. The top layer's final forward and backward states
    are concatenated and scored by a Linear-ReLU-Dropout-Linear head.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout_rate=0.3):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Embedding layer (index 0 is padding)
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # 2-layer bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True
        )

        # 2-layer classifier
        lstm_output_dim = hidden_dim * 2  # bidirectional = *2
        self.fc1 = nn.Linear(lstm_output_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, batch_sequences, batch_lengths):
        """
        Score every choice of every example in a single LSTM pass.

        Args:
            batch_sequences: List (one per example) of lists of 1-D token-id
                tensors, one tensor per choice; every example must have the
                same number of choices and every sequence must be non-empty.
            batch_lengths: List of lists of sequence lengths. Accepted for
                interface compatibility; lengths are re-read from the tensors.

        Returns:
            Tensor of logits [batch_size, num_choices]

        Raises:
            ValueError: If examples have differing numbers of choices.
        """
        batch_size = len(batch_sequences)
        num_choices = len(batch_sequences[0])
        if any(len(example) != num_choices for example in batch_sequences):
            raise ValueError("every example in the batch must have the same number of choices")

        device = self.embedding.weight.device

        # Flatten example-major: row b * num_choices + c is choice c of example b
        flat_sequences = [
            seq.to(device=device, dtype=torch.long)
            for example in batch_sequences
            for seq in example
        ]
        # pack_padded_sequence expects lengths on the CPU
        lengths = torch.tensor([seq.size(0) for seq in flat_sequences])

        padded = pad_sequence(flat_sequences, batch_first=True, padding_value=0)
        embedded = self.embedding(padded)

        # enforce_sorted=False sorts internally; nn.LSTM returns the final
        # hidden state already restored to the original (unsorted) order.
        packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)

        # hidden is [num_layers * 2, N, hidden_dim]; the last two rows are the
        # top layer's forward and backward final states.
        final_hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)

        x = self.dropout(self.relu(self.fc1(final_hidden)))
        logits = self.fc2(x).squeeze(-1)

        return logits.view(batch_size, num_choices)

In [None]:
def _run_rnn_epoch(model, loader, criterion, optimizer=None, scheduler=None, desc=""):
    """Make one pass of the RNN model over ``loader``.

    Trains when ``optimizer`` is given: gradient clipping at norm 1.0, with
    ``scheduler.step()`` after every batch if a scheduler is passed.
    Otherwise it evaluates with gradients disabled. Answers are moved to the
    module-level ``device``; the model moves the sequences itself.

    Returns:
        (mean per-example loss, accuracy) over the whole loader.

    Raises:
        ValueError: If the loader yields no examples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training), tqdm(loader, desc=desc) as progress_bar:
        for batch_sequences, batch_lengths, answers in progress_bar:
            answers = answers.to(device)

            outputs = model(batch_sequences, batch_lengths)
            loss = criterion(outputs, answers)

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()  # OneCycleLR is stepped once per batch

            batch_size = answers.size(0)
            # Weight by batch size so a short final batch does not skew the mean
            total_loss += loss.item() * batch_size
            seen += batch_size
            correct += (outputs.argmax(dim=1) == answers).sum().item()

            if training:
                progress_bar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'acc': f"{correct / seen:.4f}"
                })

    if seen == 0:
        raise ValueError(f"{desc or 'data'} loader yielded no examples")
    return total_loss / seen, correct / seen


def train_rnn_model(model, train_loader, valid_loader, epochs=30, lr=1e-3,
                    weight_decay=1e-6, log_wandb=True):
    """
    Train the RNN-based QA model with a one-cycle learning-rate schedule.

    The LR warms up over the first 20% of all batches to ``lr``, then
    cosine-anneals. The best checkpoint (by validation accuracy) is saved to
    ``checkpoints/best_rnn_model.pt``.

    Args:
        model: The model to train (already on ``device``)
        train_loader: Training data loader
        valid_loader: Validation data loader
        epochs: Number of training epochs
        lr: Peak learning rate
        weight_decay: Weight decay coefficient
        log_wandb: Whether to log metrics to W&B

    Returns:
        (model, best_val_accuracy). The returned model holds the weights of the
        *last* epoch; the best weights are in the checkpoint file.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # One-cycle learning rate scheduler, stepped once per training batch
    total_steps = epochs * len(train_loader)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=0.2,  # 20% warmup
        div_factor=25,
        final_div_factor=1000,
        anneal_strategy='cos'
    )

    if log_wandb:
        wandb.init(
            project="commonsense-qa-rnn",
            name=f"rnn-model-{datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}",
            config={
                "model": "qa_rnn_model",
                "embedding_dim": model.embedding_dim,
                "hidden_dim": model.hidden_dim,
                "epochs": epochs,
                "learning_rate": lr,
                "weight_decay": weight_decay,
            }
        )
        wandb.watch(model)

    os.makedirs("checkpoints", exist_ok=True)

    best_val_accuracy = 0.0
    start_time = time.time()

    try:
        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            avg_train_loss, train_accuracy = _run_rnn_epoch(
                model, train_loader, criterion, optimizer, scheduler, desc="Training")
            avg_val_loss, val_accuracy = _run_rnn_epoch(
                model, valid_loader, criterion, desc="Validation")

            # Save best model
            if val_accuracy > best_val_accuracy:
                print(f"Validation accuracy improved from {best_val_accuracy:.4f} to {val_accuracy:.4f}")
                best_val_accuracy = val_accuracy
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_accuracy": val_accuracy,
                }, "checkpoints/best_rnn_model.pt")

            print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
            print(f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

            if log_wandb:
                wandb.log({
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "train_accuracy": train_accuracy,
                    "val_loss": avg_val_loss,
                    "val_accuracy": val_accuracy,
                    "learning_rate": optimizer.param_groups[0]['lr']
                })

        training_time = (time.time() - start_time) / 60
        print(f"Training completed in {training_time:.2f} minutes")
        print(f"Best validation accuracy: {best_val_accuracy:.4f}")

        if log_wandb:
            wandb.run.summary["best_val_accuracy"] = best_val_accuracy
    finally:
        # Close the W&B run even if training is interrupted or fails
        if log_wandb:
            wandb.finish()

    return model, best_val_accuracy

In [None]:
# Initialize the RNN model. The embedding width must equal the size of the
# pre-trained vectors copied in below, so take it from them instead of
# hard-coding 300.
embedding_dim = wv.vector_size
hidden_dim = 128
dropout_rate = 0.2

rnn_model = QARNNModel(
    vocab_size=len(train_rnn_dataset.word_to_idx),
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    dropout_rate=dropout_rate
)

# Initialize the embedding layer with pre-trained word vectors. Rows follow
# train_rnn_dataset.word_to_idx, so every split fed to this model must encode
# its tokens with that same mapping.
embedding_matrix = train_rnn_dataset.get_embedding_matrix()
with torch.no_grad():
    rnn_model.embedding.weight.copy_(embedding_matrix)

# Move model to device
rnn_model = rnn_model.to(device)
print(rnn_model)

# Train the RNN model
print("Starting RNN model training...")
trained_rnn_model, best_rnn_accuracy = train_rnn_model(
    rnn_model,
    train_rnn_loader,
    valid_rnn_loader,
    epochs=30,
    lr=1e-3,
    weight_decay=1e-6
)

print(f"RNN model training complete! Best validation accuracy: {best_rnn_accuracy:.4f}")

# Save cache for future use
train_rnn_dataset.save_cache()
valid_rnn_dataset.save_cache()