In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import json
import re
import numpy as np
from collections import Counter
import string
from tqdm import tqdm

# Device setup: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load model checkpoint (weights, vocabulary and training config saved by the training notebook).
# NOTE(review): torch.load is pickle-based; only load checkpoints from a trusted source.
checkpoint = torch.load('../../models/qa/best_qa_model.pt', map_location=device)
vocab = checkpoint['vocab']
word2idx = vocab['word2idx']
idx2word = vocab['idx2word']
config = checkpoint['config']

print(f"Loaded model trained for {checkpoint['epoch'] + 1} epochs")
print(f"Best Dev Exact Acc: {checkpoint['dev_exact_acc']:.4f}")
print(f"Vocabulary size: {config['vocab_size']}")

In [None]:
# Recreate model architecture (copy from training notebook)
class Highway(nn.Module):
    """Stack of highway layers.

    Each layer computes ``g * relu(W_n x) + (1 - g) * (W_l x)`` with
    ``g = sigmoid(W_g x)``; note the carry path is a learned linear map,
    not the identity.

    Args:
        size: Feature width; input and output share it.
        num_layers: Number of stacked highway layers.
    """

    def __init__(self, size, num_layers=1):
        super().__init__()
        self.num_layers = num_layers

        def make_stack():
            return nn.ModuleList(nn.Linear(size, size) for _ in range(num_layers))

        # Creation order (nonlinear, linear, gate) is preserved so seeded
        # initialisation and state_dict keys match the training notebook.
        self.nonlinear = make_stack()
        self.linear = make_stack()
        self.gate = make_stack()

    def forward(self, x):
        """Apply each highway layer in turn; ``x``'s last dimension must equal ``size``."""
        for gate_fc, nonlinear_fc, linear_fc in zip(self.gate, self.nonlinear, self.linear):
            g = torch.sigmoid(gate_fc(x))
            transformed = F.relu(nonlinear_fc(x))
            carried = linear_fc(x)
            x = g * transformed + (1 - g) * carried
        return x

class BiDAFModel(nn.Module):
    """Word-level BiDAF reader that scores answer-span start/end positions.

    The submodule attribute names must stay identical to the training
    notebook's model, because they are the keys of the saved ``state_dict``.

    Args:
        vocab_size: Rows in the embedding table.
        embed_dim: Word-embedding width (also the highway width).
        hidden_dim: Hidden size of each LSTM direction; encoder outputs are
            ``2 * hidden_dim`` wide.
        dropout: Probability for the embedding dropout and the dropout on G.
        pad_idx: Padding token id. Defaults to the notebook-level
            ``word2idx['<PAD>']`` so existing calls keep working.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=128, dropout=0.2, pad_idx=None):
        super().__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Resolved once so forward() does not read a mutable global on every call.
        self.pad_idx = word2idx['<PAD>'] if pad_idx is None else pad_idx

        # Word embeddings
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=self.pad_idx)
        self.embedding_dropout = nn.Dropout(dropout)

        # Highway network
        self.highway = Highway(embed_dim, num_layers=2)

        # Contextual encoding layers. All LSTMs here are single-layer; nn.LSTM's
        # `dropout` argument only acts between stacked layers, so passing it was a
        # no-op that just raised a UserWarning whenever dropout > 0.
        self.context_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.question_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Attention weights (trilinear similarity: w_c.c + w_q.q + w_cq.(c*q))
        self.att_weight_c = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.att_weight_q = nn.Linear(2 * hidden_dim, 1, bias=False)
        self.att_weight_cq = nn.Linear(2 * hidden_dim, 1, bias=False)

        # Modeling layer
        self.modeling_lstm1 = nn.LSTM(8 * hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.modeling_lstm2 = nn.LSTM(2 * hidden_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Output projections over [G; M] (8h + 2h = 10h features)
        self.start_linear = nn.Linear(10 * hidden_dim, 1)
        self.end_linear = nn.Linear(10 * hidden_dim, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, context, question):
        """Score every context position as an answer start and end.

        Args:
            context: Token ids indexed as (batch, context_len).
            question: Token ids indexed as (batch, question_len).

        Returns:
            ``(start_logits, end_logits)``, each (batch, context_len), with
            padded context positions filled with -1e9.
        """
        context_len = context.size(1)

        # Boolean masks: True for real tokens, False for padding.
        context_mask = context != self.pad_idx
        question_mask = question != self.pad_idx

        # Embeddings with highway network
        context_emb = self.embedding_dropout(self.highway(self.embedding(context)))
        question_emb = self.embedding_dropout(self.highway(self.embedding(question)))

        # Contextual encoding
        context_enc, _ = self.context_lstm(context_emb)
        question_enc, _ = self.question_lstm(question_emb)

        # Attention Flow Layer: (batch, context_len, question_len) scores with
        # padded question positions pushed to -1e9 before any softmax.
        similarity = self._compute_similarity(context_enc, question_enc)
        similarity = similarity.masked_fill(~question_mask.unsqueeze(1), -1e9)

        # Context-to-Question Attention
        c2q_att = F.softmax(similarity, dim=2)
        c2q = torch.bmm(c2q_att, question_enc)

        # Question-to-Context Attention
        # NOTE(review): padded context positions are not masked out of this
        # softmax. Kept as-is because this class mirrors the training notebook
        # and the checkpoint was presumably trained this way; fix it there first.
        max_similarity = similarity.max(dim=2).values
        q2c_att = F.softmax(max_similarity, dim=1)
        q2c = torch.bmm(q2c_att.unsqueeze(1), context_enc).expand(-1, context_len, -1)

        # Query-aware context representation
        G = torch.cat([context_enc, c2q, context_enc * c2q, context_enc * q2c], dim=2)
        G = self.dropout(G)

        # Modeling Layer
        M1, _ = self.modeling_lstm1(G)
        M2, _ = self.modeling_lstm2(M1)

        # Output Layer
        start_logits = self.start_linear(torch.cat([G, M1], dim=2)).squeeze(-1)
        end_logits = self.end_linear(torch.cat([G, M2], dim=2)).squeeze(-1)

        # Apply context mask
        start_logits = start_logits.masked_fill(~context_mask, -1e9)
        end_logits = end_logits.masked_fill(~context_mask, -1e9)

        return start_logits, end_logits

    def _compute_similarity(self, context_enc, question_enc):
        """Trilinear BiDAF similarity between every context and question position.

        Mathematically identical to applying the three linear layers to the
        broadcast (batch, C, Q, 2h) tensors, but never materialises them:
        ``w_cq . (c * q) == (c * w_cq) . q``, so the cross term is one ``bmm``.

        Args:
            context_enc: (batch, context_len, 2 * hidden_dim).
            question_enc: (batch, question_len, 2 * hidden_dim).

        Returns:
            (batch, context_len, question_len) similarity scores.
        """
        score_c = self.att_weight_c(context_enc)                      # (B, C, 1)
        score_q = self.att_weight_q(question_enc).transpose(1, 2)     # (B, 1, Q)
        weighted_context = context_enc * self.att_weight_cq.weight    # (B, C, 2h) * (1, 2h)
        score_cq = torch.bmm(weighted_context, question_enc.transpose(1, 2))  # (B, C, Q)
        return score_c + score_q + score_cq

In [None]:
# Initialize and load model.
# dropout=0.0 because the network is only used for inference here (it is also switched to eval mode below).
model = BiDAFModel(
    vocab_size=config['vocab_size'],
    embed_dim=config['embed_dim'],
    hidden_dim=config['hidden_dim'],
    dropout=0.0,
)
model.to(device)  # nn.Module.to moves parameters in place
model.load_state_dict(checkpoint['model_state_dict'])  # strict by default: keys must match the architecture
model.eval()

print("Model loaded successfully!")

In [None]:
# Text preprocessing functions
def clean_text(text):
    """Collapse every whitespace run to a single space and trim both ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

def simple_tokenize(text):
    """Split text into alphanumeric words and standalone '.', '!', '?' tokens.

    Every other character (commas, apostrophes, hyphens, non-ASCII letters, ...)
    acts as a separator and is dropped.
    """
    padded = re.sub(r"([.!?])", r" \1 ", text)
    return re.sub(r"[^a-zA-Z0-9.!?]+", r" ", padded).split()

def encode_text(text, word2idx, max_len):
    """Turn text into a padded id list plus the tokens that were kept.

    The text is lower-cased, cleaned and tokenized, then truncated to
    ``max_len`` tokens. Unknown words map to ``word2idx['<UNK>']`` and the
    id list is right-padded with ``word2idx['<PAD>']`` up to ``max_len``.

    Returns:
        ``(ids, tokens)`` where ``tokens`` are the (possibly truncated) tokens.
    """
    kept = simple_tokenize(clean_text(text.lower()))[:max_len]
    ids = [word2idx.get(tok, word2idx['<UNK>']) for tok in kept]
    padding = [word2idx['<PAD>']] * (max_len - len(ids))
    return ids + padding, kept

# Advanced span selection
def get_best_span(start_logits, end_logits, max_answer_length=30):
    """Return the most probable answer span ``(start, end, score)``.

    Every span with ``start <= end < start + max_answer_length`` is scored by
    ``P(start) * P(end)``, where each probability is a softmax over its 1-D
    logit vector. The search is a single masked outer product instead of a
    Python double loop. Ties go to the earliest start, then the earliest end,
    which matches the original strict ``>`` scan order.

    Args:
        start_logits: 1-D tensor of start logits over context positions.
        end_logits: 1-D tensor of end logits over context positions.
        max_answer_length: Maximum span length in tokens.

    Returns:
        ``(best_start, best_end, best_score)`` with Python ints/float;
        ``(0, 0, 0.0)`` when no span has a positive score (empty input,
        ``max_answer_length <= 0``, or probabilities underflowing to zero).
    """
    num_start, num_end = len(start_logits), len(end_logits)
    if num_start == 0 or num_end == 0 or max_answer_length <= 0:
        return 0, 0, 0.0

    start_probs = F.softmax(start_logits, dim=0)
    end_probs = F.softmax(end_logits, dim=0)

    # span_len[s, e] = e - s; a span is valid when 0 <= e - s < max_answer_length.
    starts = torch.arange(num_start, device=start_probs.device).unsqueeze(1)
    ends = torch.arange(num_end, device=end_probs.device).unsqueeze(0)
    span_len = ends - starts
    valid = (span_len >= 0) & (span_len < max_answer_length)

    # (num_start, num_end) table of P(start) * P(end); invalid spans get -1 so they never win.
    scores = (start_probs.unsqueeze(1) * end_probs.unsqueeze(0)).masked_fill(~valid, -1.0)

    # torch.argmax returns the first maximal index of the row-major flattening,
    # i.e. the lexicographically smallest (start, end) among equal scores.
    flat_best = int(torch.argmax(scores))
    best_score = float(scores.reshape(-1)[flat_best])
    if best_score <= 0:
        return 0, 0, 0.0

    best_start, best_end = divmod(flat_best, num_end)
    return best_start, best_end, best_score

In [None]:
# Main prediction function
def predict_answer(context, question, model, word2idx, max_context_len=400, max_question_len=50, max_answer_length=30):
    """Answer one question about one context paragraph.

    Both texts go through ``encode_text`` (lower-cased, tokenized, truncated
    and padded). The span chosen by ``get_best_span`` is rebuilt from those
    tokens joined by single spaces, not sliced from the raw ``context``.

    Returns:
        dict with 'answer', 'start_idx', 'end_idx', 'confidence',
        'context_tokens' and 'question_tokens'. If the span falls outside
        the real (unpadded) tokens, 'answer' is "" and 'confidence' is 0.0.
    """
    context_ids, context_tokens = encode_text(context, word2idx, max_context_len)
    question_ids, question_tokens = encode_text(question, word2idx, max_question_len)

    # Batch of one on the notebook-level device.
    context_batch = torch.tensor([context_ids], dtype=torch.long, device=device)
    question_batch = torch.tensor([question_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        start_logits, end_logits = model(context_batch, question_batch)
        start_idx, end_idx, confidence = get_best_span(start_logits[0], end_logits[0], max_answer_length)

    n_real_tokens = len(context_tokens)
    if start_idx < n_real_tokens and end_idx < n_real_tokens:
        answer_text = ' '.join(context_tokens[start_idx:end_idx + 1])
    else:
        answer_text, confidence = "", 0.0

    return {
        'answer': answer_text,
        'start_idx': start_idx,
        'end_idx': end_idx,
        'confidence': confidence,
        'context_tokens': context_tokens,
        'question_tokens': question_tokens
    }

In [None]:
# Evaluation metrics (SQuAD style)
_PUNCTUATION = frozenset(string.punctuation)
_ARTICLES = re.compile(r'\b(a|an|the)\b', re.IGNORECASE)

def normalize_answer(s):
    """Normalize an answer the way the official SQuAD script does.

    Lower-cases, strips ASCII punctuation, replaces the articles a/an/the
    with spaces, and collapses whitespace runs to single spaces.
    """
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in _PUNCTUATION)
    text = _ARTICLES.sub(' ', text)
    return ' '.join(text.split())

def get_tokens(s):
    """Tokens of the normalized answer; falsy input (None, "") gives []."""
    return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
    """1 if both answers normalize to the same string, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return int(gold_norm == pred_norm)

def compute_f1(a_gold, a_pred):
    """Token-overlap F1 between gold and predicted answers.

    When either side has no tokens the score is 1 only if both are empty.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)

    if not gold_toks or not pred_toks:
        return int(gold_toks == pred_toks)

    num_same = sum((Counter(gold_toks) & Counter(pred_toks)).values())
    if num_same == 0:
        return 0

    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

In [None]:
# Load development data for comprehensive evaluation.
# NOTE(review): pickle.load can execute arbitrary code; only load files produced by this project's preprocessing.
with open('../../data/q&a/dev_processed.pkl', 'rb') as dev_file:
    dev_data = pickle.load(dev_file)

print(f"Loaded {len(dev_data)} development examples for evaluation")

In [None]:
# Comprehensive evaluation function
def evaluate_on_squad(model, dev_data, num_examples=1000):
    """Score the model with SQuAD exact match and F1 on the first examples of ``dev_data``.

    Each example's context and question are rebuilt by space-joining its
    'context_tokens' / 'question_tokens', answered with ``predict_answer``,
    and compared against its 'answer_text'.

    Returns:
        dict with 'exact_match' and 'f1' as percentages (NaN if the subset is
        empty), 'predictions' (example 'id' -> predicted answer) and
        'total_examples'.
    """
    subset = dev_data[:num_examples]
    print(f"Evaluating on {len(subset)} examples...")

    predictions = {}
    per_example = []  # (exact, f1) for each example, in order
    for example in tqdm(subset):
        context_text = ' '.join(example['context_tokens'])
        question_text = ' '.join(example['question_tokens'])
        gold = example['answer_text']

        predicted = predict_answer(context_text, question_text, model, word2idx)['answer']
        predictions[example['id']] = predicted
        per_example.append((compute_exact(gold, predicted), compute_f1(gold, predicted)))

    exact_scores = [em for em, _ in per_example]
    f1_scores = [f1 for _, f1 in per_example]

    return {
        'exact_match': np.mean(exact_scores) * 100,
        'f1': np.mean(f1_scores) * 100,
        'predictions': predictions,
        'total_examples': len(subset)
    }

In [None]:
# Run comprehensive evaluation on the first 1000 dev examples.
banner = "=" * 50
print("\n" + banner)
print("COMPREHENSIVE EVALUATION")
print(banner)

results = evaluate_on_squad(model, dev_data, num_examples=1000)

print(f"\nResults on {results['total_examples']} examples:")
print(f"Exact Match: {results['exact_match']:.2f}%")
print(f"F1 Score: {results['f1']:.2f}%")

In [None]:
# Some example predictions: show the first five dev examples side by side with the model's answer.
for example_no, example in enumerate(dev_data[:5], start=1):
    context_text = ' '.join(example['context_tokens'])
    question_text = ' '.join(example['question_tokens'])
    ground_truth = example['answer_text']

    result = predict_answer(context_text, question_text, model, word2idx)
    predicted_answer = result['answer']

    exact_score = compute_exact(ground_truth, predicted_answer)
    f1_score = compute_f1(ground_truth, predicted_answer)

    print(f"\nExample {example_no}:")
    print(f"Context: {context_text[:200]}...")
    print(f"Question: {question_text}")
    print(f"Ground Truth: '{ground_truth}'")
    print(f"Predicted: '{predicted_answer}'")
    print(f"Confidence: {result['confidence']:.3f}")
    print(f"Exact Match: {exact_score}, F1: {f1_score:.3f}")
    print("-" * 50)

In [None]:
# Confidence analysis
def _format_confidence_group(label, group):
    """One summary line for a confidence bucket; empty buckets avoid np.mean([]) -> nan."""
    if not group:
        return f"{label} (0): no examples"
    em = np.mean([r['exact'] for r in group]) * 100
    f1 = np.mean([r['f1'] for r in group]) * 100
    return f"{label} ({len(group)}): EM={em:.1f}%, F1={f1:.1f}%"

def analyze_by_confidence(model, dev_data, num_examples=500, threshold=0.1):
    """Compare EM/F1 of high- vs low-confidence predictions and print a summary.

    Args:
        model: QA model passed through to ``predict_answer``.
        dev_data: Examples with 'context_tokens', 'question_tokens' and 'answer_text'.
        num_examples: How many leading examples of ``dev_data`` to analyze.
        threshold: Predictions with confidence strictly above this are "high".

    Returns:
        List of per-example dicts ('confidence', 'exact', 'f1',
        'answer_length'), sorted by descending confidence.
    """
    results = []

    for example in tqdm(dev_data[:num_examples], desc="Analyzing confidence"):
        context_text = ' '.join(example['context_tokens'])
        question_text = ' '.join(example['question_tokens'])
        ground_truth = example['answer_text']

        result = predict_answer(context_text, question_text, model, word2idx)
        predicted_answer = result['answer']

        results.append({
            'confidence': result['confidence'],
            'exact': compute_exact(ground_truth, predicted_answer),
            'f1': compute_f1(ground_truth, predicted_answer),
            'answer_length': len(predicted_answer.split())
        })

    results.sort(key=lambda r: r['confidence'], reverse=True)

    high_conf = [r for r in results if r['confidence'] > threshold]
    low_conf = [r for r in results if r['confidence'] <= threshold]

    print(f"\nConfidence Analysis:")
    print(_format_confidence_group("High confidence predictions", high_conf))
    print(_format_confidence_group("Low confidence predictions", low_conf))
    return results

print("\n" + "="*50)
print("CONFIDENCE ANALYSIS")
print("="*50)
# Assigned so the returned list is not dumped as this cell's output.
confidence_results = analyze_by_confidence(model, dev_data, num_examples=500)

In [None]:
# Interactive Q&A function
def interactive_qa():
    """Read context/question pairs from stdin and print the model's answers.

    Typing 'quit' at either prompt ends the session. The loop also exits
    cleanly when input is unavailable — EOF, Ctrl-C / kernel interrupt, or a
    kernel without stdin support (IPython raises a NotImplementedError
    subclass) — so Restart & Run All in a headless kernel does not crash.
    Empty contexts or questions are rejected and the prompts repeat.
    """
    def _ask(prompt):
        # Returns the stripped reply, or None when the session should end.
        try:
            reply = input(prompt).strip()
        except (EOFError, KeyboardInterrupt, NotImplementedError):
            return None
        return None if reply.lower() == 'quit' else reply

    print("\n" + "="*50)
    print("INTERACTIVE Q&A")
    print("="*50)
    print("Enter 'quit' to exit")

    while True:
        print("\n" + "-"*30)
        context = _ask("Enter context paragraph: ")
        if context is None:
            break

        question = _ask("Enter question: ")
        if question is None:
            break

        if not context or not question:
            print("Both a context paragraph and a question are required.")
            continue

        # Predict answer
        result = predict_answer(context, question, model, word2idx)

        print(f"\nAnswer: '{result['answer']}'")
        print(f"Confidence: {result['confidence']:.3f}")
        print(f"Position: tokens {result['start_idx']}-{result['end_idx']}")

        if result['confidence'] < 0.05:
            print("Low confidence - answer might be unreliable")

# Save evaluation results
EVAL_RESULTS_PATH = '../../models/qa/detailed_evaluation.json'

def _json_default(obj):
    """json fallback for numpy/torch scalars (anything exposing ``.item()``)."""
    if hasattr(obj, 'item'):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

evaluation_report = {
    'overall_results': {
        'exact_match': results['exact_match'],
        'f1': results['f1'],
        'total_examples': results['total_examples']
    },
    'model_config': config,
    'training_info': {
        'epoch': checkpoint['epoch'],
        'dev_exact_acc': checkpoint['dev_exact_acc']
    }
}
# Serialize before opening the file: json.dump streams into an already-truncated
# file, so a non-serializable value would otherwise leave a corrupt JSON behind.
serialized_report = json.dumps(evaluation_report, indent=2, default=_json_default)
with open(EVAL_RESULTS_PATH, 'w') as f:
    f.write(serialized_report)

print(f"\nEvaluation complete! Results saved to {EVAL_RESULTS_PATH}")
print(f"Overall Performance: EM={results['exact_match']:.1f}%, F1={results['f1']:.1f}%")

In [None]:
# Interactive Q&A — blocks waiting for input(); set RUN_INTERACTIVE_QA = False
# for Restart & Run All or headless execution.
RUN_INTERACTIVE_QA = True
if RUN_INTERACTIVE_QA:
    interactive_qa()