In [1]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
from tqdm import tqdm
import os
import json
import time

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Path
MODEL_DIR = '../../models/qa'

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.0008
NUM_EPOCHS = 8
EMBED_DIM = 300
HIDDEN_DIM = 128
DROPOUT = 0.2
MAX_GRAD_NORM = 5.0
WARMUP_PROPORTION = 0.1
PATIENCE = 3

In [3]:
# Load processed data and vocabulary
with open('../../data/q&a/vocab.pkl', 'rb') as f:
    vocab_data = pickle.load(f)
    word2idx = vocab_data['word2idx']
    idx2word = vocab_data['idx2word']

with open('../../data/q&a/train_processed.pkl', 'rb') as f:
    train_data = pickle.load(f)

with open('../../data/q&a/dev_processed.pkl', 'rb') as f:
    dev_data = pickle.load(f)

print(f"Loaded {len(train_data)} training examples")
print(f"Loaded {len(dev_data)} development examples")
print(f"Vocabulary size: {len(word2idx)}")

Loaded 87599 training examples
Loaded 10570 development examples
Vocabulary size: 50000


In [4]:
# Dataset class
class QADataset(Dataset):
    def __init__(self, data, max_context_len=400, max_question_len=50):
        # Filter out invalid examples
        self.data = [ex for ex in data if ex['start_position'] >= 0 and ex['end_position'] >= 0]
        self.max_context_len = max_context_len
        self.max_question_len = max_question_len
        print(f"Filtered to {len(self.data)} valid examples")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        example = self.data[idx]
        
        # Ensure positions are within bounds
        start_pos = min(example['start_position'], self.max_context_len - 1)
        end_pos = min(example['end_position'], self.max_context_len - 1)
        
        return {
            'context_ids': torch.tensor(example['context_ids'][:self.max_context_len], dtype=torch.long),
            'question_ids': torch.tensor(example['question_ids'][:self.max_question_len], dtype=torch.long),
            'start_position': torch.tensor(start_pos, dtype=torch.long),
            'end_position': torch.tensor(end_pos, dtype=torch.long),
            'id': example['id']
        }

In [5]:
# BiDAF Model
class BiDAFModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=128, dropout=0.2):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        
        # Word embeddings
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=word2idx['<PAD>'])
        self.embedding_dropout = nn.Dropout(dropout)
        
        # Highway network for embeddings
        self.highway = Highway(embed_dim, num_layers=2)
        
        # Contextual encoding layers
        self.context_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, 
                                   bidirectional=True, dropout=dropout if dropout > 0 else 0)
        self.question_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, 
                                    bidirectional=True, dropout=dropout if dropout > 0 else 0)
        
        # Attention weights
        self.att_weight_c = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.att_weight_q = nn.Linear(2 * hidden_dim, 1, bias=False)  
        self.att_weight_cq = nn.Linear(2 * hidden_dim, 1, bias=False)
        
        # Modeling layer
        self.modeling_lstm1 = nn.LSTM(8 * hidden_dim, hidden_dim, batch_first=True,
                                     bidirectional=True, dropout=dropout if dropout > 0 else 0)
        self.modeling_lstm2 = nn.LSTM(2 * hidden_dim, hidden_dim, batch_first=True,
                                     bidirectional=True, dropout=dropout if dropout > 0 else 0)
        
        # Output projections
        self.start_linear = nn.Linear(10 * hidden_dim, 1)
        self.end_linear = nn.Linear(10 * hidden_dim, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize model weights"""
        for name, param in self.named_parameters():
            if 'weight' in name and len(param.shape) > 1:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
    
    def forward(self, context, question):
        batch_size = context.size(0)
        context_len = context.size(1)
        question_len = question.size(1)
        
        # Masks
        context_mask = (context != word2idx['<PAD>']).float()
        question_mask = (question != word2idx['<PAD>']).float()
        
        # Embeddings with highway network
        context_emb = self.embedding(context)
        question_emb = self.embedding(question)
        
        context_emb = self.highway(context_emb)
        question_emb = self.highway(question_emb)
        
        context_emb = self.embedding_dropout(context_emb)
        question_emb = self.embedding_dropout(question_emb)
        
        # Contextual encoding
        context_enc, _ = self.context_lstm(context_emb)  # (batch, context_len, 2*hidden)
        question_enc, _ = self.question_lstm(question_emb)  # (batch, question_len, 2*hidden)
        
        # Attention Flow Layer
        # Similarity matrix computation
        similarity = self._compute_similarity(context_enc, question_enc)  # (batch, context_len, question_len)
        
        # Mask similarity scores
        question_mask_expanded = question_mask.unsqueeze(1).expand(-1, context_len, -1)
        similarity = similarity.masked_fill(question_mask_expanded == 0, -1e9)
        
        # Context-to-Question Attention
        c2q_att = F.softmax(similarity, dim=2)  # (batch, context_len, question_len)
        c2q = torch.bmm(c2q_att, question_enc)  # (batch, context_len, 2*hidden)
        
        # Question-to-Context Attention
        max_similarity = torch.max(similarity, dim=2)[0]  # (batch, context_len)
        q2c_att = F.softmax(max_similarity, dim=1)  # (batch, context_len)
        q2c = torch.bmm(q2c_att.unsqueeze(1), context_enc)  # (batch, 1, 2*hidden)
        q2c = q2c.expand(-1, context_len, -1)  # (batch, context_len, 2*hidden)
        
        # Query-aware context representation
        G = torch.cat([
            context_enc,
            c2q, 
            context_enc * c2q,
            context_enc * q2c
        ], dim=2)  # (batch, context_len, 8*hidden)
        
        G = self.dropout(G)
        
        # Modeling Layer
        M1, _ = self.modeling_lstm1(G)  # (batch, context_len, 2*hidden)
        M2, _ = self.modeling_lstm2(M1)  # (batch, context_len, 2*hidden)
        
        # Output Layer
        start_input = torch.cat([G, M1], dim=2)  # (batch, context_len, 10*hidden)
        end_input = torch.cat([G, M2], dim=2)    # (batch, context_len, 10*hidden)
        
        start_logits = self.start_linear(start_input).squeeze(-1)  # (batch, context_len)
        end_logits = self.end_linear(end_input).squeeze(-1)       # (batch, context_len)
        
        # Apply context mask
        start_logits = start_logits.masked_fill(context_mask == 0, -1e9)
        end_logits = end_logits.masked_fill(context_mask == 0, -1e9)
        
        return start_logits, end_logits
    
    def _compute_similarity(self, context_enc, question_enc):
        """Compute similarity matrix between context and question"""
        batch_size, context_len, hidden_size = context_enc.size()
        question_len = question_enc.size(1)
        
        # Expand tensors for element-wise operations
        context_expanded = context_enc.unsqueeze(2).expand(-1, -1, question_len, -1)
        question_expanded = question_enc.unsqueeze(1).expand(-1, context_len, -1, -1)
        
        # Element-wise product
        elementwise_prod = context_expanded * question_expanded
        
        # Compute attention weights
        alpha = (self.att_weight_c(context_expanded) + 
                self.att_weight_q(question_expanded) + 
                self.att_weight_cq(elementwise_prod))  # (batch, context_len, question_len, 1)
        
        return alpha.squeeze(-1)  # (batch, context_len, question_len)

In [6]:
# Highway Network for better gradient flow
class Highway(nn.Module):
    def __init__(self, size, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        
    def forward(self, x):
        for layer in range(self.num_layers):
            gate = torch.sigmoid(self.gate[layer](x))
            nonlinear = F.relu(self.nonlinear[layer](x))
            linear = self.linear[layer](x)
            x = gate * nonlinear + (1 - gate) * linear
        return x

In [7]:
# Create data loaders
train_dataset = QADataset(train_data, max_context_len=400, max_question_len=50)
dev_dataset = QADataset(dev_data, max_context_len=400, max_question_len=50)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"Training batches: {len(train_loader)}")
print(f"Development batches: {len(dev_loader)}")

Filtered to 87599 valid examples
Filtered to 10570 valid examples
Training batches: 2738
Development batches: 331


In [8]:
# Initialize model
model = BiDAFModel(
    vocab_size=len(word2idx),
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")



Total parameters: 18,002,730
Trainable parameters: 18,002,730


In [9]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [10]:
# Optimizer, Scheduler and Loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
criterion = nn.CrossEntropyLoss()

In [11]:
# Training and Evaluation function
def train_epoch(model, loader, optimizer, epoch):
    model.train()
    total_loss = 0
    total_start_correct = 0
    total_end_correct = 0
    total_samples = 0
    
    # Use tqdm with less frequent updates
    pbar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]", 
                dynamic_ncols=True, leave=False)
    
    start_time = time.time()
    
    for batch_idx, batch in enumerate(pbar):
        # Move to device
        context = batch['context_ids'].to(device, non_blocking=True)
        question = batch['question_ids'].to(device, non_blocking=True)
        start_pos = batch['start_position'].to(device, non_blocking=True)
        end_pos = batch['end_position'].to(device, non_blocking=True)
        
        optimizer.zero_grad()
        
        # Forward pass
        start_logits, end_logits = model(context, question)
        
        # Loss computation
        start_loss = criterion(start_logits, start_pos)
        end_loss = criterion(end_logits, end_pos)
        loss = start_loss + end_loss
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        
        # Update metrics (less frequently)
        batch_size = start_pos.size(0)
        total_loss += loss.item()
        total_start_correct += (start_logits.argmax(1) == start_pos).sum().item()
        total_end_correct += (end_logits.argmax(1) == end_pos).sum().item()
        total_samples += batch_size
        
        # Update progress bar every 50 batches
        if batch_idx % 50 == 0:
            elapsed = time.time() - start_time
            pbar.set_postfix({
                'Loss': f'{total_loss/(batch_idx+1):.3f}',
                'Start': f'{total_start_correct/total_samples:.3f}',
                'End': f'{total_end_correct/total_samples:.3f}',
                'Time': f'{elapsed:.0f}s'
            })
    
    return (total_loss / len(loader), 
            total_start_correct / total_samples, 
            total_end_correct / total_samples)

def eval_epoch(model, loader, epoch):
    model.eval()
    total_loss = 0
    total_start_correct = 0
    total_end_correct = 0
    total_exact_correct = 0
    total_samples = 0
    
    pbar = tqdm(loader, desc=f"Epoch {epoch+1} [Val]", 
                dynamic_ncols=True, leave=False)
    
    with torch.no_grad():
        for batch in pbar:
            context = batch['context_ids'].to(device, non_blocking=True)
            question = batch['question_ids'].to(device, non_blocking=True)
            start_pos = batch['start_position'].to(device, non_blocking=True)
            end_pos = batch['end_position'].to(device, non_blocking=True)
            
            start_logits, end_logits = model(context, question)
            
            start_loss = criterion(start_logits, start_pos)
            end_loss = criterion(end_logits, end_pos)
            loss = start_loss + end_loss
            
            # Predictions
            start_pred = start_logits.argmax(1)
            end_pred = end_logits.argmax(1)
            
            # Update metrics
            batch_size = start_pos.size(0)
            total_loss += loss.item()
            total_start_correct += (start_pred == start_pos).sum().item()
            total_end_correct += (end_pred == end_pos).sum().item()
            total_exact_correct += ((start_pred == start_pos) & (end_pred == end_pos)).sum().item()
            total_samples += batch_size
    
    return (total_loss / len(loader),
            total_start_correct / total_samples,
            total_end_correct / total_samples,
            total_exact_correct / total_samples)

In [12]:
# Training loop
best_exact_acc = 0
patience_count = 0

print("Starting optimized training...")
overall_start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()
    
    # Training
    train_loss, train_start_acc, train_end_acc = train_epoch(model, train_loader, optimizer, epoch)
    
    # Validation
    val_loss, val_start_acc, val_end_acc, val_exact_acc = eval_epoch(model, dev_loader, epoch)
    
    # Scheduler step (once per epoch)
    scheduler.step()
    
    epoch_time = time.time() - epoch_start_time
    
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} - Time: {epoch_time:.1f}s")
    print(f"Train - Loss: {train_loss:.4f}, Start: {train_start_acc:.4f}, End: {train_end_acc:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, Start: {val_start_acc:.4f}, End: {val_end_acc:.4f}, Exact: {val_exact_acc:.4f}")
    print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save best model
    if val_exact_acc > best_exact_acc:
        best_exact_acc = val_exact_acc
        patience_count = 0
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_exact_acc': val_exact_acc,
            'vocab': {'word2idx': word2idx, 'idx2word': idx2word},
            'config': {
                'vocab_size': len(word2idx),
                'embed_dim': EMBED_DIM,
                'hidden_dim': HIDDEN_DIM,
                'dropout': DROPOUT
            }
        }, os.path.join(MODEL_DIR, 'best_qa_model.pt'))
        
        print(f"New best model saved! Exact Acc: {val_exact_acc:.4f}")
    else:
        patience_count += 1
        print(f"No improvement. Patience: {patience_count}/{PATIENCE}")
        if patience_count >= PATIENCE:
            print("Early stopping triggered!")
            break
    
    # Memory cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

total_time = time.time() - overall_start_time
print(f"\nTraining completed in {total_time:.1f}s ({total_time/60:.1f} minutes)")
print(f"Best validation exact accuracy: {best_exact_acc:.4f}")

Starting optimized training...


                                                                                                                        


Epoch 1/8 - Time: 1090.4s
Train - Loss: 6.9092, Start: 0.1657, End: 0.1789
Val   - Loss: 5.4398, Start: 0.3141, End: 0.3364, Exact: 0.2081
LR: 0.000640
New best model saved! Exact Acc: 0.2081


                                                                                                                        


Epoch 2/8 - Time: 1941.6s
Train - Loss: 4.3080, Start: 0.4215, End: 0.4579
Val   - Loss: 4.3923, Start: 0.4176, End: 0.4504, Exact: 0.3005
LR: 0.000512
New best model saved! Exact Acc: 0.3005


                                                                                                                        


Epoch 3/8 - Time: 2434.8s
Train - Loss: 3.5259, Start: 0.5080, End: 0.5491
Val   - Loss: 4.2393, Start: 0.4383, End: 0.4710, Exact: 0.3263
LR: 0.000410
New best model saved! Exact Acc: 0.3263


                                                                                                                        


Epoch 4/8 - Time: 1831.6s
Train - Loss: 3.0387, Start: 0.5598, End: 0.6063
Val   - Loss: 4.2883, Start: 0.4446, End: 0.4753, Exact: 0.3312
LR: 0.000328
New best model saved! Exact Acc: 0.3312


                                                                                                                        


Epoch 5/8 - Time: 1769.1s
Train - Loss: 2.6235, Start: 0.6076, End: 0.6562
Val   - Loss: 4.3498, Start: 0.4490, End: 0.4783, Exact: 0.3353
LR: 0.000262
New best model saved! Exact Acc: 0.3353


                                                                                                                        


Epoch 6/8 - Time: 1725.2s
Train - Loss: 2.2640, Start: 0.6502, End: 0.7023
Val   - Loss: 4.6125, Start: 0.4368, End: 0.4724, Exact: 0.3272
LR: 0.000210


NameError: name 'PATIENCE' is not defined