In [1]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import json
import re
import numpy as np
from collections import Counter
import string
from tqdm import tqdm
import os

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Load model checkpoint
checkpoint_path = '../../models/qa/best_qa_model.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
word2idx = checkpoint['vocab']['word2idx']
idx2word = checkpoint['vocab']['idx2word']
config = checkpoint['config']
epoch_trained = checkpoint.get('epoch', None)
dev_exact_acc = checkpoint.get('dev_exact_acc', None)

print(f"Loaded model trained for {epoch_trained + 1 if epoch_trained is not None else 'unknown'} epochs")

if dev_exact_acc is not None:
    print(f"Best Dev Exact Acc: {dev_exact_acc:.4f}")
print(f"Vocabulary size: {config['vocab_size']}")

Loaded model trained for 5 epochs
Vocabulary size: 50000


In [3]:
# Recreate model architecture (copy from training notebook)
class Highway(nn.Module):
    def __init__(self, size, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])

    def forward(self, x):
        for layer in range(self.num_layers):
            gate = torch.sigmoid(self.gate[layer](x))
            nonlinear = F.relu(self.nonlinear[layer](x))
            linear = self.linear[layer](x)
            x = gate * nonlinear + (1 - gate) * linear
        return x

class BiDAFModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=word2idx['<PAD>'])
        self.embedding_dropout = nn.Dropout(dropout)
        self.highway = Highway(embed_dim, num_layers=2)
        self.context_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True, dropout=dropout)
        self.question_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True, dropout=dropout)
        self.att_weight_c = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.att_weight_q = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.att_weight_cq = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.modeling_lstm1 = nn.LSTM(8 * hidden_dim, hidden_dim, batch_first=True, bidirectional=True, dropout=dropout)
        self.modeling_lstm2 = nn.LSTM(2 * hidden_dim, hidden_dim, batch_first=True, bidirectional=True, dropout=dropout)
        self.start_linear = nn.Linear(10 * hidden_dim, 1)
        self.end_linear = nn.Linear(10 * hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()
        
    def _init_weights(self):
        for name, param in self.named_parameters():
            if 'weight' in name and len(param.shape) > 1:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
                
    def forward(self, context, question):
        batch_size = context.size(0)
        context_len = context.size(1)
        question_len = question.size(1)
        context_mask = (context != word2idx['<PAD>']).float()
        question_mask = (question != word2idx['<PAD>']).float()
        context_emb = self.embedding(context)
        question_emb = self.embedding(question)
        context_emb = self.highway(context_emb)
        question_emb = self.highway(question_emb)
        context_emb = self.embedding_dropout(context_emb)
        question_emb = self.embedding_dropout(question_emb)
        context_enc, _ = self.context_lstm(context_emb)
        question_enc, _ = self.question_lstm(question_emb)
        similarity = self._compute_similarity(context_enc, question_enc)
        question_mask_expanded = question_mask.unsqueeze(1).expand(-1, context_len, -1)
        similarity = similarity.masked_fill(question_mask_expanded == 0, -1e9)
        c2q_att = F.softmax(similarity, dim=2)
        c2q = torch.bmm(c2q_att, question_enc)
        max_similarity = torch.max(similarity, dim=2)[0]
        q2c_att = F.softmax(max_similarity, dim=1)
        q2c = torch.bmm(q2c_att.unsqueeze(1), context_enc)
        q2c = q2c.expand(-1, context_len, -1)
        G = torch.cat([context_enc, c2q, context_enc * c2q, context_enc * q2c], dim=2)
        G = self.dropout(G)
        M1, _ = self.modeling_lstm1(G)
        M2, _ = self.modeling_lstm2(M1)
        start_input = torch.cat([G, M1], dim=2)
        end_input = torch.cat([G, M2], dim=2)
        start_logits = self.start_linear(start_input).squeeze(-1)
        end_logits = self.end_linear(end_input).squeeze(-1)
        start_logits = start_logits.masked_fill(context_mask == 0, -1e9)
        end_logits = end_logits.masked_fill(context_mask == 0, -1e9)
        return start_logits, end_logits
        
    def _compute_similarity(self, context_enc, question_enc):
        batch_size, context_len, hidden_size = context_enc.size()
        question_len = question_enc.size(1)
        context_expanded = context_enc.unsqueeze(2).expand(-1, -1, question_len, -1)
        question_expanded = question_enc.unsqueeze(1).expand(-1, context_len, -1, -1)
        elementwise_prod = context_expanded * question_expanded
        alpha = (self.att_weight_c(context_expanded) +
                 self.att_weight_q(question_expanded) +
                 self.att_weight_cq(elementwise_prod))
        return alpha.squeeze(-1)

In [4]:
# Load and prepare model
model = BiDAFModel(
    vocab_size=config['vocab_size'],
    embed_dim=config['embed_dim'],
    hidden_dim=config['hidden_dim'],
    dropout=0.0
).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully!")

Model loaded successfully!


In [5]:
# Tokenization and input encoding (match training pipeline)
def clean_text(text):
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def simple_tokenize(text):
    text = re.sub(r"([.!?])", r" \1 ", text)
    text = re.sub(r"[^a-zA-Z0-9.!?]+", r" ", text)
    return text.split()

def encode_text(text, word2idx, max_len):
    tokens = simple_tokenize(clean_text(text.lower()))
    ids = [word2idx.get(token, word2idx['<PAD>']) for token in tokens[:max_len]]
    ids += [word2idx['<PAD>']] * (max_len - len(ids))
    return ids, tokens[:max_len]

def get_best_span(start_logits, end_logits, max_answer_length=30):
    start_probs = F.softmax(start_logits, dim=0)
    end_probs = F.softmax(end_logits, dim=0)
    best_score = 0
    best_start = 0
    best_end = 0
    for start_idx in range(len(start_probs)):
        for end_idx in range(start_idx, min(start_idx+max_answer_length, len(end_probs))):
            score = start_probs[start_idx] * end_probs[end_idx]
            if score > best_score:
                best_score = score
                best_start = start_idx
                best_end = end_idx
    return best_start, best_end, float(best_score)

# Main prediction function
def predict_answer(context, question, model, word2idx, max_context_len=400, max_question_len=50, max_answer_length=30):
    context_ids, context_tokens = encode_text(context, word2idx, max_context_len)
    question_ids, question_tokens = encode_text(question, word2idx, max_question_len)
    context_tensor = torch.tensor([context_ids], dtype=torch.long).to(device)
    question_tensor = torch.tensor([question_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        start_logits, end_logits = model(context_tensor, question_tensor)
        start_idx, end_idx, confidence = get_best_span(start_logits[0], end_logits[0], max_answer_length)
        if start_idx < len(context_tokens) and end_idx < len(context_tokens):
            answer_tokens = context_tokens[start_idx:end_idx + 1]
            answer_text = ' '.join(answer_tokens)
        else:
            answer_text = ""
            confidence = 0.0
    return {
        'answer': answer_text,
        'start_idx': start_idx,
        'end_idx': end_idx,
        'confidence': confidence,
        'context_tokens': context_tokens,
        'question_tokens': question_tokens
    }

# SQuAD-style normalization & metrics
def normalize_answer(s):
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.IGNORECASE)
        return re.sub(regex, ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()

def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = Counter(gold_toks) & Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)

In [6]:
# Load dev data for evaluation
with open('../../data/q&a/dev_processed.pkl', 'rb') as f:
    dev_data = pickle.load(f)
    
print(f"Loaded {len(dev_data)} development examples for evaluation")

Loaded 10570 development examples for evaluation


In [7]:
# Evaluation loop with progress bar and live metrics
def evaluate_on_squad(model, dev_data, word2idx, num_examples=1000):
    exact_scores, f1_scores = [], []
    predictions = {}
    eval_data = dev_data[:num_examples]
    print(f"Evaluating on {len(eval_data)} examples...")
    pbar = tqdm(eval_data, dynamic_ncols=True)
    running_exact = 0
    running_f1 = 0
    for idx, example in enumerate(pbar):
        # Reconstruct context and question from tokens if needed
        context_text = ' '.join(example['context_tokens']) if 'context_tokens' in example else example['context']
        question_text = ' '.join(example['question_tokens']) if 'question_tokens' in example else example['question']
        ground_truth = example['answer_text']
        result = predict_answer(context_text, question_text, model, word2idx)
        predicted_answer = result['answer']
        predictions[example['id']] = predicted_answer
        exact = compute_exact(ground_truth, predicted_answer)
        f1 = compute_f1(ground_truth, predicted_answer)
        exact_scores.append(exact)
        f1_scores.append(f1)
        running_exact += exact
        running_f1 += f1
        avg_exact = running_exact / (idx + 1)
        avg_f1 = running_f1 / (idx + 1)
        pbar.set_postfix({
            'EM': f'{avg_exact*100:.1f}%',
            'F1': f'{avg_f1*100:.1f}%',
            'Ans': (predicted_answer[:40] + '...' if len(predicted_answer) > 40 else predicted_answer)
        })
        
    return {
        'exact_match': np.mean(exact_scores) * 100,
        'f1': np.mean(f1_scores) * 100,
        'predictions': predictions,
        'total_examples': len(eval_data)
    }

In [8]:
# Run comprehensive evaluation
results = evaluate_on_squad(model, dev_data, word2idx, num_examples=500)

print(f"\nResults on {results['total_examples']} examples:")
print(f"Exact Match: {results['exact_match']:.2f}%")
print(f"F1 Score: {results['f1']:.2f}%")

Evaluating on 500 examples...


100%|████████████████████████████████████████████████| 500/500 [18:40<00:00,  2.24s/it, EM=40.4%, F1=48.9%, Ans=tuesday]


Results on 500 examples:
Exact Match: 40.40%
F1 Score: 48.92%





In [9]:
# Example predictions
for i, example in enumerate(dev_data[:5]):
    context_text = ' '.join(example['context_tokens']) if 'context_tokens' in example else example['context']
    question_text = ' '.join(example['question_tokens']) if 'question_tokens' in example else example['question']
    ground_truth = example['answer_text']
    result = predict_answer(context_text, question_text, model, word2idx)
    predicted_answer = result['answer']
    exact_score = compute_exact(ground_truth, predicted_answer)
    f1_score = compute_f1(ground_truth, predicted_answer)
    print(f"\nExample {i+1}:")
    print(f"Context: {context_text[:200]}...")
    print(f"Question: {question_text}")
    print(f"Ground Truth: '{ground_truth}'")
    print(f"Predicted: '{predicted_answer}'")
    print(f"Exact Match: {exact_score}, F1: {f1_score:.3f}")

print("\nEvaluation complete!")


Example 1:
Context: super bowl 50 was an american football game to determine the champion of the national football league nfl for the 2015 season . the american football conference afc champion denver broncos defeated th...
Question: which nfl team represented the afc at super bowl 50 ?
Ground Truth: 'Denver Broncos'
Predicted: 'american football conference'
Exact Match: 0, F1: 0.000

Example 2:
Context: super bowl 50 was an american football game to determine the champion of the national football league nfl for the 2015 season . the american football conference afc champion denver broncos defeated th...
Question: which nfl team represented the nfc at super bowl 50 ?
Ground Truth: 'Carolina Panthers'
Predicted: 'panthers'
Exact Match: 0, F1: 0.667

Example 3:
Context: super bowl 50 was an american football game to determine the champion of the national football league nfl for the 2015 season . the american football conference afc champion denver broncos defeated th...
Question: where 

In [11]:
import json

# Interactive Q&A function
def interactive_qa():
    print("\n" + "="*50)
    print("INTERACTIVE Q&A")
    print("="*50)
    print("Enter 'quit' to exit")
    
    while True:
        print("\n" + "-"*30)
        context = input("Enter context paragraph: ").strip()
        if context.lower() == 'quit':
            print("Exiting interactive Q&A.")
            break
            
        question = input("Enter question: ").strip()
        if question.lower() == 'quit':
            print("Exiting interactive Q&A.")
            break
        
        # Predict answer using your previously defined predict_answer()
        result = predict_answer(context, question, model, word2idx)
        
        print(f"\nAnswer: '{result['answer']}'")
        print(f"Confidence: {result['confidence']:.3f}")
        print(f"Position: tokens {result['start_idx']}-{result['end_idx']}")
        
        if result['confidence'] < 0.05:
            print("Low confidence - answer might be unreliable")

# Save evaluation results
eval_results_path = '../../models/qa/detailed_evaluation.json'
with open(eval_results_path, 'w') as f:
    json.dump({
        'overall_results': {
            'exact_match': results['exact_match'],
            'f1': results['f1'],
            'total_examples': results['total_examples']
        },
        'model_config': config,
        'training_info': {
            'epoch': checkpoint.get('epoch', None),
            'dev_exact_acc': checkpoint.get('dev_exact_acc', None)
        }
    }, f, indent=2)

print(f"\nEvaluation complete! Results saved to {eval_results_path}")
print(f"Overall Performance: EM={results['exact_match']:.1f}%, F1={results['f1']:.1f}%")


Evaluation complete! Results saved to ../../models/qa/detailed_evaluation.json
Overall Performance: EM=40.4%, F1=48.9%


In [None]:
# Interactive Q&A
interactive_qa()


INTERACTIVE Q&A
Enter 'quit' to exit

------------------------------


Enter context paragraph:  The old lighthouse keeper, Silas, had lived alone on the craggy isle for fifty years. His only companions were the gulls and the rhythmic crash of waves against the rocks. Every evening, he'd climb the winding staircase, his boots echoing in the stone tower, to light the lamp and send its beam sweeping across the dark sea. One stormy night, a small fishing boat was tossed by the tempest. Silas, with his keen eyes, saw the tiny vessel battling the waves. He knew they were in desperate need of guidance. He worked tirelessly, adjusting the lamp, making sure its light was strong and true. The boat, guided by the steady beam, found its way to safety. The next morning, the grateful fishermen came to thank Silas. He smiled, a rare and gentle smile, and said, "The light is my duty. It's all I have to offer." And in that moment, Silas realized that his solitary life, once filled with quiet isolation, had found purpose in the simple act of shining.
Enter question:  who 


Answer: 'old lighthouse'
Confidence: 0.245
Position: tokens 1-2

------------------------------
