In [None]:
from src.utils import *

In [5]:
def test_multi_head_attention(
    batch_size: int = 2,
    seq_len: int = 10,
    embed_dim: int = 512,
    num_heads: int = 8,
    seed: int = 0,
    rtol: float = 1e-2,
    atol: float = 1e-2,
) -> None:
    """Smoke-test ``MultiHeadAttention`` on a random self-attention input.

    Builds ``MultiHeadAttention(embed_dim, num_heads)`` and runs one forward
    pass with the same tensor ``x`` as query, key and value. It then checks:

    * ``output`` has shape ``(batch_size, seq_len, embed_dim)`` and is finite;
    * ``weights`` has shape ``(batch_size, num_heads, seq_len, seq_len)``;
    * every attention weight is non-negative;
    * each row of ``weights`` (last dim) sums to 1 within ``rtol``/``atol``.

    Args:
        batch_size: Number of sequences in the random input.
        seq_len: Length of each sequence.
        embed_dim: Embedding size passed to the model and used for the input.
        num_heads: Number of attention heads passed to the model.
        seed: Seed for torch's global RNG, so parameter init and input are
            reproducible across runs.
        rtol: Relative tolerance for the row-sum check.
        atol: Absolute tolerance for the row-sum check.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    # Seed first so both parameter initialisation and the random input are
    # identical on every run (Restart & Run All gives the same result).
    torch.manual_seed(seed)

    model = MultiHeadAttention(embed_dim, num_heads)
    # Disable training-only behaviour. If the module applies dropout to its
    # attention weights, training mode would zero/rescale entries and rows
    # would no longer sum to 1.
    model.eval()
    x = torch.rand(batch_size, seq_len, embed_dim)

    with torch.no_grad():
        output, weights = model(x, x, x)

    expected_out = (batch_size, seq_len, embed_dim)
    expected_w = (batch_size, num_heads, seq_len, seq_len)
    assert output.shape == expected_out, (
        f"output shape {tuple(output.shape)} != {expected_out}"
    )
    assert weights.shape == expected_w, (
        f"weights shape {tuple(weights.shape)} != {expected_w}"
    )
    assert torch.isfinite(output).all(), "output contains NaN/Inf"

    # Attention weights are expected to be a probability distribution over keys.
    assert (weights >= 0).all(), "attention weights contain negative values"
    weight_sums = weights.sum(dim=-1)
    assert torch.allclose(
        weight_sums,
        torch.ones_like(weight_sums),
        rtol=rtol,
        atol=atol,
    ), (
        "attention weight rows do not sum to 1 "
        f"(max deviation {(weight_sums - 1).abs().max().item():.3e})"
    )

In [None]:
# Example usage: run the QA pipeline on one yes/no question and one span
# question, printing each question with its predicted answer and confidence.
def _ask_and_report(pipeline, question, context):
    """Answer ``question`` from ``context`` with ``pipeline``, print a Q/A line, and return the result.

    The returned object is whatever ``pipeline.answer_question`` produces; it is
    read here through its ``'answer'`` and ``'confidence'`` keys.
    """
    result = pipeline.answer_question(question, context)
    print(f"Q: {question}\nA: {result['answer']} (Confidence: {result['confidence']:.2f})")
    return result


qa_pipeline = MedicalQAPipeline()

# Yes/No question
question1 = "Is high blood pressure a risk factor for heart disease?"
context1 = "High blood pressure is one of the main risk factors for heart disease and stroke."
answer1 = _ask_and_report(qa_pipeline, question1, context1)

# Span question
question2 = "What is the most common symptom of a heart attack?"
context2 = "While heart attack symptoms vary, chest pain is the most common symptom, often described as pressure or tightness."
answer2 = _ask_and_report(qa_pipeline, question2, context2)