# 04 — Testing ML Pipelines & APIs

This notebook covers:
- Testing ML pipelines with reproducible seeds
- FastAPI TestClient (with lifespan!)
- Mocking ML model predictions
- Coverage reports

## Setup

In [None]:
import subprocess, sys, textwrap, tempfile, pathlib, os

# Locations probed for the fixtures directory, in priority order:
#   1. <parent of cwd>/11_testing_code_quality/fixtures/input
#   2. <cwd>/fixtures/input            (running from within the module dir)
#   3. <parent of cwd>/fixtures/input
_fixture_candidates = [
    os.path.join(os.path.dirname(os.path.abspath(".")),
                 "11_testing_code_quality", "fixtures", "input"),
    os.path.join(os.path.abspath("."), "fixtures", "input"),
    os.path.join(os.path.abspath(".."), "fixtures", "input"),
]

# Keep the first candidate that exists. If none exists, FIXTURES stays at the
# last candidate tried so the assertion message below reports that path.
for _candidate in _fixture_candidates:
    FIXTURES = _candidate
    if os.path.exists(FIXTURES):
        break

print(f"Fixtures dir: {FIXTURES}")
assert os.path.exists(FIXTURES), f"Fixtures not found at {FIXTURES}"

# Put the fixtures directory first on the import path so modules stored
# there can be imported by name.
sys.path.insert(0, FIXTURES)

## 1. Testing the Sample Module

In [None]:
from sample_module import clean_text, tokenize, count_words, extract_emails, is_palindrome

# Direct testing — no pytest needed for exploration.
# Each assertion pins one helper from sample_module to a known input/output pair.

# Whitespace handling: outer whitespace removed, inner runs collapsed.
assert clean_text("  hello   world  ") == "hello world"

# Tokenisation lower-cases and splits; the empty string yields no tokens.
assert tokenize("Hello World") == ["hello", "world"]
assert tokenize("") == []

# Word frequencies are returned as a word -> count mapping.
assert count_words("the cat sat on the mat") == {"the": 2, "cat": 1, "sat": 1, "on": 1, "mat": 1}

# Addresses come back in order of appearance.
assert extract_emails("email me at user@example.com or admin@test.org") == ["user@example.com", "admin@test.org"]

# Identity checks (not `== True` / `== False`) so that only real booleans pass;
# a merely truthy/falsy return value such as 1 or 0 would be rejected.
assert is_palindrome("A man, a plan, a canal: Panama") is True
assert is_palindrome("hello") is False

print("All sample_module functions work correctly!")

## 2. Testing ML Pipeline — Reproducibility

In [None]:
import numpy as np
from sample_ml_pipeline import (
    create_pipeline, train_pipeline, predict, evaluate,
    SAMPLE_TEXTS, SAMPLE_LABELS, preprocess_texts
)

def _fit_and_predict(seed, max_features=100):
    """Seed NumPy's global RNG, then build, train and run a fresh pipeline.

    Args:
        seed: Value passed to ``np.random.seed`` before the pipeline is created.
        max_features: Forwarded to ``create_pipeline``.

    Returns:
        A ``(pipeline, predictions)`` tuple: the trained pipeline and the
        result of ``predict`` on ``SAMPLE_TEXTS``.
    """
    np.random.seed(seed)
    pipeline = create_pipeline(max_features=max_features)
    train_pipeline(pipeline, SAMPLE_TEXTS, SAMPLE_LABELS)
    return pipeline, predict(pipeline, SAMPLE_TEXTS)


# Train with fixed seed — results should be reproducible.
# Two independent runs go through exactly the same code path, so any
# difference in predictions comes from randomness rather than from setup drift.
# NOTE(review): this only exercises the seed if the pipeline draws from NumPy's
# global RNG — confirm against sample_ml_pipeline.
pipe1, preds1 = _fit_and_predict(seed=42)
pipe2, preds2 = _fit_and_predict(seed=42)

assert np.array_equal(preds1, preds2), "Same seed should give same predictions"
print(f"Predictions (run 1): {preds1}")
print(f"Predictions (run 2): {preds2}")
print("Reproducibility confirmed!")

In [None]:
# Score the first run's predictions against the reference labels.
metrics = evaluate(SAMPLE_LABELS, preds1.tolist())
print(f"Metrics: {metrics}")

# Both reported scores must lie in the closed interval [0, 1].
for metric_name in ("accuracy", "f1_macro"):
    assert 0 <= metrics[metric_name] <= 1
print("Metrics are in valid range.")

## 3. Testing FastAPI with TestClient

**Important**: with Starlette 0.50+, `TestClient` must be used as a context manager (`with TestClient(app) as client:`) for the app's lifespan events to run.

In [None]:
from fastapi.testclient import TestClient
from sample_fastapi_app import app

# Request bodies used below, named so the assertions read on their own.
single_payload = {"text": "This is great and amazing"}
batch_payload = {"texts": ["good", "bad", "ok"]}

# CORRECT: using TestClient as a context manager triggers lifespan (model loads)
with TestClient(app) as client:
    # --- GET /health: service is up and reports a loaded model
    resp = client.get("/health")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "ok"
    assert data["model_loaded"] is True
    print(f"Health: {data}")

    # --- POST /predict: one text in, one labelled score out
    resp = client.post("/predict", json=single_payload)
    assert resp.status_code == 200
    pred = resp.json()
    assert pred["label"] in ("positive", "negative", "neutral")
    assert 0 <= pred["score"] <= 1
    print(f"Predict: {pred}")

    # --- POST /predict/batch: one prediction per submitted text
    resp = client.post("/predict/batch", json=batch_payload)
    assert resp.status_code == 200
    batch = resp.json()
    n_predictions = len(batch["predictions"])
    assert n_predictions == 3
    print(f"Batch: {n_predictions} predictions")

print("\nAll API tests passed!")

## 4. Testing Validation Errors

In [None]:
# Each case: (label printed in the report, request body, suffix for the printed line).
invalid_requests = [
    # Empty text — should fail validation (min_length=1)
    ("Empty text", {"text": ""}, " (validation error)"),
    # Required field missing altogether
    ("Missing field", {}, ""),
    # Wrong type.
    # NOTE(review): a 422 here assumes the request model does not coerce
    # int -> str (Pydantic v2 behaviour) — confirm against the pinned version.
    ("Wrong type", {"text": 123}, ""),
]

with TestClient(app) as client:
    for label, body, suffix in invalid_requests:
        resp = client.post("/predict", json=body)
        # 422 is the status the API is expected to return for an invalid body.
        assert resp.status_code == 422
        print(f"{label} -> {resp.status_code}{suffix}")

print("\nValidation error handling works correctly!")

## 5. Mocking the ML Model in API Tests

In [None]:
from unittest.mock import MagicMock, patch
from sample_fastapi_app import ml_models

# Stand-in model whose `predict` always returns a fixed (label, score) pair,
# so the test controls the prediction instead of depending on real inference.
mock_model = MagicMock()
mock_model.predict.return_value = ("positive", 0.99)

with TestClient(app) as client:
    # Swap the model in only AFTER the client has started, so that whatever the
    # lifespan startup registered is what gets replaced. patch.dict restores the
    # registry's previous contents on exit, so the mock cannot leak into later
    # cells or survive a re-run of this one (a plain item assignment would).
    with patch.dict(ml_models, {"sentiment": mock_model}):
        resp = client.post("/predict", json={"text": "anything"})
        assert resp.status_code == 200
        assert resp.json()["label"] == "positive"
        assert resp.json()["score"] == 0.99

        # Verify the mock was called exactly once, with the submitted text
        mock_model.predict.assert_called_once_with("anything")

print("Mock model works — we control predictions in tests!")

## 6. Running pytest with Coverage

In [None]:
# Write a proper test file and run with coverage.
#
# FIXTURES is embedded with !r so the generated source contains a correctly
# quoted and escaped string literal; a raw interpolation breaks on paths that
# contain backslashes (Windows) or quote characters.
test_content = textwrap.dedent(f'''
import sys
sys.path.insert(0, {FIXTURES!r})
from sample_module import clean_text, tokenize, extract_emails

def test_clean():
    assert clean_text("  hi  there  ") == "hi there"

def test_tokenize():
    assert tokenize("Hello World") == ["hello", "world"]

def test_tokenize_empty():
    assert tokenize("") == []

def test_emails():
    assert extract_emails("a@b.com") == ["a@b.com"]
    assert extract_emails("no email") == []
''')

# The test file lives in a throwaway directory that is removed when the block exits.
with tempfile.TemporaryDirectory() as td:
    p = pathlib.Path(td) / "test_cov.py"
    p.write_text(test_content, encoding="utf-8")
    cmd = [
        sys.executable, "-m", "pytest", str(p), "-v",
        # Measure by importable module name. A --cov value that is not a
        # directory is treated as a module name, so a ".py" file path never
        # matches anything and no data is collected for it.
        "--cov=sample_module",
        "--cov-report=term-missing",
        "--tb=short", "--no-header",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    print(result.stdout)
    if result.stderr:
        print(result.stderr)

# Surface failures (failing tests, missing pytest-cov plugin, ...) instead of
# letting the cell look successful; the output above has already been printed.
assert result.returncode == 0, f"pytest exited with code {result.returncode}"

## Key Takeaways

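1. **Reproducibility**: always set a fixed seed (e.g. `np.random.seed(42)`) for ML tests — calling `np.random.seed()` with no argument reseeds from fresh entropy and is *not* reproducible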
2. **TestClient context manager**: required for lifespan events (Starlette 0.50+)
3. **Mock ML models** to control predictions and speed up tests
4. **Test validation errors** (422) — they're part of your API contract
5. **Coverage**: `--cov-report=term-missing` shows exactly which lines need tests