In [4]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
from tqdm import tqdm
import os
import json

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Load processed data and vocabulary
with open('../../data/q&a/vocab.pkl', 'rb') as f:
    vocab_data = pickle.load(f)
    word2idx = vocab_data['word2idx']
    idx2word = vocab_data['idx2word']

with open('../../data/q&a/train_processed.pkl', 'rb') as f:
    train_data = pickle.load(f)

with open('../../data/q&a/dev_processed.pkl', 'rb') as f:
    dev_data = pickle.load(f)

print(f"Loaded {len(train_data)} training examples")
print(f"Loaded {len(dev_data)} development examples")
print(f"Vocabulary size: {len(word2idx)}")

Loaded 87599 training examples
Loaded 10570 development examples
Vocabulary size: 50000


In [6]:
# Dataset class
class QADataset(Dataset):
    def __init__(self, data, max_context_len=400, max_question_len=50):
        # Filter out invalid examples
        self.data = [ex for ex in data if ex['start_position'] >= 0 and ex['end_position'] >= 0]
        self.max_context_len = max_context_len
        self.max_question_len = max_question_len
        print(f"Filtered to {len(self.data)} valid examples")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        example = self.data[idx]
        
        # Ensure positions are within bounds
        start_pos = min(example['start_position'], self.max_context_len - 1)
        end_pos = min(example['end_position'], self.max_context_len - 1)
        
        return {
            'context_ids': torch.tensor(example['context_ids'][:self.max_context_len], dtype=torch.long),
            'question_ids': torch.tensor(example['question_ids'][:self.max_question_len], dtype=torch.long),
            'start_position': torch.tensor(start_pos, dtype=torch.long),
            'end_position': torch.tensor(end_pos, dtype=torch.long),
            'id': example['id']
        }

In [7]:
# BiDAF Model with attention Q&A Model
class BiDAFModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=128, dropout=0.2):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        
        # Word embeddings
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=word2idx['<PAD>'])
        self.embedding_dropout = nn.Dropout(dropout)
        
        # Highway network for embeddings
        self.highway = Highway(embed_dim, num_layers=2)
        
        # Contextual encoding layers
        self.context_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, 
                                   bidirectional=True, dropout=dropout if dropout > 0 else 0)
        self.question_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, 
                                    bidirectional=True, dropout=dropout if dropout > 0 else 0)
        
        # Attention weights
        self.att_weight_c = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.att_weight_q = nn.Linear(2 * hidden_dim, 1, bias=False)  
        self.att_weight_cq = nn.Linear(2 * hidden_dim, 1, bias=False)
        
        # Modeling layer
        self.modeling_lstm1 = nn.LSTM(8 * hidden_dim, hidden_dim, batch_first=True,
                                     bidirectional=True, dropout=dropout if dropout > 0 else 0)
        self.modeling_lstm2 = nn.LSTM(2 * hidden_dim, hidden_dim, batch_first=True,
                                     bidirectional=True, dropout=dropout if dropout > 0 else 0)
        
        # Output projections
        self.start_linear = nn.Linear(10 * hidden_dim, 1)
        self.end_linear = nn.Linear(10 * hidden_dim, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize model weights"""
        for name, param in self.named_parameters():
            if 'weight' in name and len(param.shape) > 1:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
    
    def forward(self, context, question):
        batch_size = context.size(0)
        context_len = context.size(1)
        question_len = question.size(1)
        
        # Masks
        context_mask = (context != word2idx['<PAD>']).float()
        question_mask = (question != word2idx['<PAD>']).float()
        
        # Embeddings with highway network
        context_emb = self.embedding(context)
        question_emb = self.embedding(question)
        
        context_emb = self.highway(context_emb)
        question_emb = self.highway(question_emb)
        
        context_emb = self.embedding_dropout(context_emb)
        question_emb = self.embedding_dropout(question_emb)
        
        # Contextual encoding
        context_enc, _ = self.context_lstm(context_emb)  # (batch, context_len, 2*hidden)
        question_enc, _ = self.question_lstm(question_emb)  # (batch, question_len, 2*hidden)
        
        # Attention Flow Layer
        # Similarity matrix computation
        similarity = self._compute_similarity(context_enc, question_enc)  # (batch, context_len, question_len)
        
        # Mask similarity scores
        question_mask_expanded = question_mask.unsqueeze(1).expand(-1, context_len, -1)
        similarity = similarity.masked_fill(question_mask_expanded == 0, -1e9)
        
        # Context-to-Question Attention
        c2q_att = F.softmax(similarity, dim=2)  # (batch, context_len, question_len)
        c2q = torch.bmm(c2q_att, question_enc)  # (batch, context_len, 2*hidden)
        
        # Question-to-Context Attention
        max_similarity = torch.max(similarity, dim=2)[0]  # (batch, context_len)
        q2c_att = F.softmax(max_similarity, dim=1)  # (batch, context_len)
        q2c = torch.bmm(q2c_att.unsqueeze(1), context_enc)  # (batch, 1, 2*hidden)
        q2c = q2c.expand(-1, context_len, -1)  # (batch, context_len, 2*hidden)
        
        # Query-aware context representation
        G = torch.cat([
            context_enc,
            c2q, 
            context_enc * c2q,
            context_enc * q2c
        ], dim=2)  # (batch, context_len, 8*hidden)
        
        G = self.dropout(G)
        
        # Modeling Layer
        M1, _ = self.modeling_lstm1(G)  # (batch, context_len, 2*hidden)
        M2, _ = self.modeling_lstm2(M1)  # (batch, context_len, 2*hidden)
        
        # Output Layer
        start_input = torch.cat([G, M1], dim=2)  # (batch, context_len, 10*hidden)
        end_input = torch.cat([G, M2], dim=2)    # (batch, context_len, 10*hidden)
        
        start_logits = self.start_linear(start_input).squeeze(-1)  # (batch, context_len)
        end_logits = self.end_linear(end_input).squeeze(-1)       # (batch, context_len)
        
        # Apply context mask
        start_logits = start_logits.masked_fill(context_mask == 0, -1e9)
        end_logits = end_logits.masked_fill(context_mask == 0, -1e9)
        
        return start_logits, end_logits
    
    def _compute_similarity(self, context_enc, question_enc):
        """Compute similarity matrix between context and question"""
        batch_size, context_len, hidden_size = context_enc.size()
        question_len = question_enc.size(1)
        
        # Expand tensors for element-wise operations
        context_expanded = context_enc.unsqueeze(2).expand(-1, -1, question_len, -1)
        question_expanded = question_enc.unsqueeze(1).expand(-1, context_len, -1, -1)
        
        # Element-wise product
        elementwise_prod = context_expanded * question_expanded
        
        # Compute attention weights
        alpha = (self.att_weight_c(context_expanded) + 
                self.att_weight_q(question_expanded) + 
                self.att_weight_cq(elementwise_prod))  # (batch, context_len, question_len, 1)
        
        return alpha.squeeze(-1)  # (batch, context_len, question_len)

In [8]:
# Highway Network for better gradient flow
class Highway(nn.Module):
    def __init__(self, size, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        
    def forward(self, x):
        for layer in range(self.num_layers):
            gate = torch.sigmoid(self.gate[layer](x))
            nonlinear = F.relu(self.nonlinear[layer](x))
            linear = self.linear[layer](x)
            x = gate * nonlinear + (1 - gate) * linear
        return x

In [9]:
# Training hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.0008
NUM_EPOCHS = 3
EMBED_DIM = 300
HIDDEN_DIM = 128
DROPOUT = 0.2
MAX_GRAD_NORM = 5.0
WARMUP_PROPORTION = 0.1

In [10]:
# Create data loaders
train_dataset = QADataset(train_data, max_context_len=400, max_question_len=50)
dev_dataset = QADataset(dev_data, max_context_len=400, max_question_len=50)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"Training batches: {len(train_loader)}")
print(f"Development batches: {len(dev_loader)}")

Filtered to 87599 valid examples
Filtered to 10570 valid examples
Training batches: 2738
Development batches: 331


In [12]:
# Initialize model
model = BiDAFModel(
    vocab_size=len(word2idx),
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")



Total parameters: 18,002,730
Trainable parameters: 18,002,730


In [13]:
# Optimizer with weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# Learning rate scheduler with warmup
total_steps = len(train_loader) * NUM_EPOCHS
warmup_steps = int(total_steps * WARMUP_PROPORTION)

In [14]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [15]:
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

# Loss function
criterion = nn.CrossEntropyLoss()

In [16]:
# Training function
def train_epoch(model, train_loader, optimizer, scheduler):
    model.train()
    total_loss = 0
    correct_starts = 0
    correct_ends = 0
    total_examples = 0
    
    progress_bar = tqdm(train_loader, desc="Training")
    
    for batch in progress_bar:
        context_ids = batch['context_ids'].to(device)
        question_ids = batch['question_ids'].to(device)
        start_positions = batch['start_position'].to(device)
        end_positions = batch['end_position'].to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        start_logits, end_logits = model(context_ids, question_ids)
        
        # Calculate loss
        start_loss = criterion(start_logits, start_positions)
        end_loss = criterion(end_logits, end_positions)
        loss = start_loss + end_loss
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()
        
        # Calculate accuracy
        start_preds = start_logits.argmax(dim=1)
        end_preds = end_logits.argmax(dim=1)
        
        correct_starts += (start_preds == start_positions).sum().item()
        correct_ends += (end_preds == end_positions).sum().item()
        total_examples += start_positions.size(0)
        
        total_loss += loss.item()
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Start Acc': f'{correct_starts/total_examples:.3f}',
            'End Acc': f'{correct_ends/total_examples:.3f}',
            'LR': f'{scheduler.get_last_lr()[0]:.6f}'
        })
    
    avg_loss = total_loss / len(train_loader)
    start_acc = correct_starts / total_examples
    end_acc = correct_ends / total_examples
    
    return avg_loss, start_acc, end_acc

In [17]:
# Evaluation function  
def evaluate_model(model, dev_loader):
    model.eval()
    total_loss = 0
    correct_starts = 0
    correct_ends = 0
    exact_matches = 0
    total_examples = 0
    
    with torch.no_grad():
        for batch in tqdm(dev_loader, desc="Evaluating"):
            context_ids = batch['context_ids'].to(device)
            question_ids = batch['question_ids'].to(device)
            start_positions = batch['start_position'].to(device)
            end_positions = batch['end_position'].to(device)
            
            # Forward pass
            start_logits, end_logits = model(context_ids, question_ids)
            
            # Calculate loss
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            loss = start_loss + end_loss
            total_loss += loss.item()
            
            # Calculate accuracy
            start_preds = start_logits.argmax(dim=1)
            end_preds = end_logits.argmax(dim=1)
            
            correct_starts += (start_preds == start_positions).sum().item()
            correct_ends += (end_preds == end_positions).sum().item()
            exact_matches += ((start_preds == start_positions) & (end_preds == end_positions)).sum().item()
            total_examples += start_positions.size(0)
    
    avg_loss = total_loss / len(dev_loader)
    start_acc = correct_starts / total_examples
    end_acc = correct_ends / total_examples
    exact_acc = exact_matches / total_examples
    
    return avg_loss, start_acc, end_acc, exact_acc

In [18]:
# Training loop
best_exact_acc = 0
best_dev_loss = float('inf')
patience = 2
patience_counter = 0

os.makedirs('../../models/qa/', exist_ok=True)

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    
    # Train
    train_loss, train_start_acc, train_end_acc = train_epoch(model, train_loader, optimizer, scheduler)
    
    # Evaluate
    dev_loss, dev_start_acc, dev_end_acc, dev_exact_acc = evaluate_model(model, dev_loader)
    
    print(f"Train - Loss: {train_loss:.4f}, Start Acc: {train_start_acc:.4f}, End Acc: {train_end_acc:.4f}")
    print(f"Dev - Loss: {dev_loss:.4f}, Start Acc: {dev_start_acc:.4f}, End Acc: {dev_end_acc:.4f}, Exact Acc: {dev_exact_acc:.4f}")
    
    # Save best model based on exact accuracy
    if dev_exact_acc > best_exact_acc:
        best_exact_acc = dev_exact_acc
        best_dev_loss = dev_loss
        patience_counter = 0
        
        # Save model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'dev_loss': dev_loss,
            'dev_exact_acc': dev_exact_acc,
            'vocab': {'word2idx': word2idx, 'idx2word': idx2word},
            'config': {
                'vocab_size': len(word2idx),
                'embed_dim': EMBED_DIM,
                'hidden_dim': HIDDEN_DIM,
                'dropout': DROPOUT
            }
        }, '../../models/qa/best_qa_model.pt')
        
        print(f"New best model saved! Exact Acc: {dev_exact_acc:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

print("\nTraining completed!")
print(f"Best Dev Exact Accuracy: {best_exact_acc:.4f}")
print(f"Best Dev Loss: {best_dev_loss:.4f}")

Starting training...

Epoch 1/3


Training: 100%|███████████| 2738/2738 [18:05<00:00,  2.52it/s, Loss=6.7644, Start Acc=0.130, End Acc=0.140, LR=0.000593]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 331/331 [00:24<00:00, 13.30it/s]


Train - Loss: 7.3202, Start Acc: 0.1304, End Acc: 0.1397
Dev - Loss: 6.3361, Start Acc: 0.2134, End Acc: 0.2268, Exact Acc: 0.1377
New best model saved! Exact Acc: 0.1377

Epoch 2/3


Training: 100%|███████████| 2738/2738 [19:23<00:00,  2.35it/s, Loss=3.7434, Start Acc=0.377, End Acc=0.412, LR=0.000296]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 331/331 [00:22<00:00, 14.41it/s]


Train - Loss: 4.6911, Start Acc: 0.3773, End Acc: 0.4116
Dev - Loss: 4.4873, Start Acc: 0.4056, End Acc: 0.4434, Exact Acc: 0.2952
New best model saved! Exact Acc: 0.2952

Epoch 3/3


Training: 100%|███████████| 2738/2738 [28:34<00:00,  1.60it/s, Loss=3.3797, Start Acc=0.509, End Acc=0.553, LR=0.000000]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 331/331 [01:12<00:00,  4.56it/s]


Train - Loss: 3.5035, Start Acc: 0.5093, End Acc: 0.5529
Dev - Loss: 4.3978, Start Acc: 0.4221, End Acc: 0.4594, Exact Acc: 0.3114
New best model saved! Exact Acc: 0.3114

Training completed!
Best Dev Exact Accuracy: 0.3114
Best Dev Loss: 4.3978
