In [None]:
import torch
import torch.nn as nn
import pandas as pd
import ast
from torch.utils.data import TensorDataset, DataLoader, random_split
from collections import Counter
import re

class CustomPipeline:
    """Tokenize, index and pad multiple-choice reading-comprehension data.

    Each CSV row (``context``, ``question``, ``answers``, ``label``) becomes one
    group of ``"{context} ||| {question} {option}"`` strings, one string per
    answer option. Each string is encoded to exactly ``max_length`` vocabulary ids.

    Args:
        max_length: Number of token ids per encoded text (longer texts are
            truncated, shorter ones right-padded with ``<PAD>``).
        min_freq: Minimum accumulated count for a token to get its own id;
            rarer tokens are encoded as ``<UNK>``.
    """

    def __init__(self, max_length=128, min_freq=2):
        self.max_length = max_length
        self.min_freq = min_freq
        # Ids 0 and 1 are reserved; real tokens are numbered from 2 upwards.
        self.vocab = {"<PAD>": 0, "<UNK>": 1}
        # Token counts accumulated over every build_vocab call.
        self.word_counts = Counter()

    def tokenizer(self, text):
        """Lower-case ``text`` and split it into word runs and single punctuation marks.

        Whitespace is dropped, e.g. ``"Hi, there!"`` -> ``["hi", ",", "there", "!"]``.
        """
        return re.findall(r"\w+|[^\w\s]", text.lower())

    def build_vocab(self, texts):
        """Count the tokens of ``texts`` and give every frequent-enough token an id.

        Counts accumulate across calls. A token gets the next free id the first
        time its accumulated count reaches ``min_freq``. Tokens that already have
        an id keep it, so repeated calls never reassign or duplicate ids.
        """
        print("Building Vocabulary...")
        for text in texts:
            self.word_counts.update(self.tokenizer(text))

        # Only keep words that appear at least min_freq times; this drops rare
        # typos that would otherwise get their own (barely trained) embedding.
        for word, count in self.word_counts.items():
            if count >= self.min_freq and word not in self.vocab:
                self.vocab[word] = len(self.vocab)

        print(f"Vocabulary built: {len(self.vocab)} unique tokens.")

    def encode(self, text):
        """Return exactly ``max_length`` ids for ``text``.

        Out-of-vocabulary tokens map to ``<UNK>``. The result is truncated or
        right-padded with ``<PAD>``.
        """
        unk_id = self.vocab["<UNK>"]
        ids = [self.vocab.get(token, unk_id) for token in self.tokenizer(text)]
        ids = ids[:self.max_length]
        return ids + [self.vocab["<PAD>"]] * (self.max_length - len(ids))

    @staticmethod
    def _parse_options(raw, row_idx):
        """Return one row's answer options as a list.

        ``raw`` is expected to be the string form of a Python list (parsed with
        ``ast.literal_eval``). An already-parsed list or tuple is accepted as-is.

        Raises:
            ValueError: if ``raw`` is not a list/tuple literal. The original
                fallback of iterating the raw string would silently turn every
                character into an "option".
        """
        if isinstance(raw, (list, tuple)):
            return list(raw)
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError) as exc:
            raise ValueError(
                f"Row {row_idx}: could not parse 'answers' value {raw!r} as a list"
            ) from exc
        if not isinstance(parsed, (list, tuple)):
            raise ValueError(
                f"Row {row_idx}: 'answers' must be a list, got {type(parsed).__name__}"
            )
        return list(parsed)

    def prepare_dataset(self, csv_path, fit_vocab=True):
        """Read ``csv_path`` and return ``(TensorDataset(X, y), vocab_size)``.

        ``X`` is a long tensor of shape ``[num_questions, num_options, max_length]``.
        ``y`` is a long tensor holding the ``label`` column.

        Args:
            csv_path: Anything ``pd.read_csv`` accepts. It must provide the
                ``context``, ``question``, ``answers`` and ``label`` columns.
            fit_vocab: If True (default), extend the vocabulary with this
                file's tokens before encoding. Pass False to encode data
                (e.g. a test set) with the vocabulary built so far.

        Raises:
            ValueError: if an ``answers`` cell is not a list literal, or if
                rows have different numbers of options (so ``X`` would be ragged).
        """
        print(f"Reading {csv_path}...")
        df = pd.read_csv(csv_path)

        # PASS 1: build the combined texts for every question.
        groups = []
        labels = []
        for row_idx, row in df.iterrows():
            context = str(row['context'])
            question = str(row['question'])
            options = self._parse_options(row['answers'], row_idx)
            # " ||| " is a plain-text separator between the context and the
            # question+option; the tokenizer turns it into three "|" tokens.
            groups.append([f"{context} ||| {question} {option}" for option in options])
            labels.append(int(row['label']))

        option_counts = sorted({len(group) for group in groups})
        if len(option_counts) > 1:
            raise ValueError(
                f"All questions must have the same number of options, found {option_counts}"
            )

        if fit_vocab:
            self.build_vocab(text for group in groups for text in group)

        # PASS 2: encode every option text to a fixed-length id list.
        print("Encoding Data...")
        encoded = [[self.encode(text) for text in group] for group in groups]

        # Shape: [Num_Questions, Num_Options, Seq_Len]
        X = torch.tensor(encoded, dtype=torch.long)
        y = torch.tensor(labels, dtype=torch.long)

        return TensorDataset(X, y), len(self.vocab)

In [None]:
class RobustReasoningModel(nn.Module):
    """BiLSTM encoder with padding-aware max + mean pooling that scores sequences.

    Every input row is one token-id sequence, and the model returns one
    unnormalised score per row. Grouping rows into options and normalising
    across them is left to the caller.

    Args:
        vocab_size: Number of embedding rows.
        embed_dim: Embedding width.
        hidden_dim: LSTM hidden size per direction.
        dropout: Dropout probability applied to the embeddings.
        padding_idx: Token id used for padding. Its embedding row is excluded
            from gradient updates (``nn.Embedding(padding_idx=...)``), and padded
            positions are excluded from pooling.
    """

    def __init__(self, vocab_size, embed_dim=100, hidden_dim=100, dropout=0.3, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx

        # 1. Embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(dropout)

        # 2. BiLSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )

        # 3. Classifier
        # Input: Hidden*2 (for BiLSTM) * 2 (for Max+Avg Pooling)
        self.fc = nn.Linear(hidden_dim * 2 * 2, 64)
        self.classifier = nn.Linear(64, 1)  # Final score

    def forward(self, x):
        """Score token-id rows ``x`` of shape ``[rows, seq_len]``; returns ``[rows, 1]``."""
        emb = self.dropout(self.embedding(x))

        # lstm_out: [rows, seq_len, hidden*2]
        lstm_out, _ = self.lstm(emb)

        # --- POOLING STRATEGY ---
        # Max and mean pooling restricted to real (non-padding) positions, so
        # right-padding up to max_length neither dilutes the mean nor feeds the max.
        mask = (x != self.padding_idx).unsqueeze(-1)  # [rows, seq_len, 1]

        # Mean over real tokens; clamp avoids division by zero for all-pad rows.
        lengths = mask.sum(dim=1).clamp(min=1)
        mean_pool = (lstm_out * mask).sum(dim=1) / lengths

        # Max over real tokens. Padded positions get the dtype's most negative
        # finite value. A row with no real tokens would keep that value, so it
        # gets 0 instead.
        neg = torch.finfo(lstm_out.dtype).min
        max_pool = lstm_out.masked_fill(~mask, neg).max(dim=1).values
        has_tokens = mask.any(dim=1)  # [rows, 1]
        max_pool = torch.where(has_tokens, max_pool, torch.zeros_like(max_pool))

        # Concatenate: [rows, hidden*4]
        combined = torch.cat((max_pool, mean_pool), dim=1)

        # Classify (torch.relu: the original referenced an unimported `F`)
        features = torch.relu(self.fc(combined))
        logits = self.classifier(features)

        return logits

In [None]:
# --- SETTINGS ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CSV_PATH = r"aml-2025-read-between-the-lines/train.csv"
SEED = 42        # seeds torch's global RNG and the train/val split generator
VAL_SIZE = 800   # number of questions held out for validation
BATCH_SIZE = 32

# Seed torch's global RNG before anything random happens (model weight
# initialisation, train_loader shuffle order) so a full re-run is reproducible.
# random_split below uses its own generator seeded with the same value.
torch.manual_seed(SEED)

# --- PREPARE DATA ---
print("--- STARTING PIPELINE ---")
pipeline = CustomPipeline(max_length=150, min_freq=2)
# NOTE(review): prepare_dataset builds the vocabulary from every row, including
# the rows that random_split later puts in the validation set.
full_dataset, vocab_size = pipeline.prepare_dataset(CSV_PATH)

# --- SPLIT DATA ---
if not 0 < VAL_SIZE < len(full_dataset):
    raise ValueError(
        f"VAL_SIZE={VAL_SIZE} must be between 1 and {len(full_dataset) - 1} "
        f"for a dataset of {len(full_dataset)} questions"
    )
val_size = VAL_SIZE
train_size = len(full_dataset) - val_size
train_subset, val_subset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

# --- INITIALIZE MODEL ---
# BiLSTM scorer sized to the vocabulary built above; dropout is applied to
# the embeddings inside the model.
print(f"Initializing Model with Vocab Size: {vocab_size}")
model = RobustReasoningModel(
    vocab_size,
    embed_dim=128,
    hidden_dim=128,
    dropout=0.4,
)
model = model.to(DEVICE)

# --- OPTIMIZER & LOSS ---
# Adam with a small L2 penalty. The training loop applies cross-entropy to the
# per-question option scores (reshaped to [batch, n_options]).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# --- TRAINING LOOP ---
EPOCHS = 15
GRAD_CLIP = 5.0  # max global gradient norm passed to clip_grad_norm_


def _score_options(net, inputs):
    """Score every answer option for a batch of questions.

    ``inputs`` has shape ``[batch, n_options, seq_len]``. It is flattened to
    ``[batch * n_options, seq_len]`` so each option is scored independently,
    and the ``[batch * n_options, 1]`` scores are reshaped to ``[batch, n_options]``.
    """
    b_size, n_opts, seq_len = inputs.shape
    logits = net(inputs.view(-1, seq_len))
    return logits.view(b_size, n_opts)


print("\n--- STARTING TRAINING ---")
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        logits = _score_options(model, inputs)  # [Batch, n_options]
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        train_loss += loss.item()
        preds = torch.argmax(logits, dim=1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

    # --- VALIDATION ---
    model.eval()
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            logits = _score_options(model, inputs)
            preds = torch.argmax(logits, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    # Mean of the per-batch losses; previously accumulated but never reported.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    train_acc = train_correct / train_total
    val_acc = val_correct / val_total

    print(
        f"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} "
        f"| Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%"
    )