# Task 1: Build a Text Classification API - SOLUTION

## Scenario
Build a production-ready FastAPI service for sentiment classification that:
1. Loads the models once at startup using a lifespan context manager
2. Validates input/output with Pydantic models
3. Provides single and batch classification endpoints
4. Includes proper error handling

## Setup

In [None]:
import json
from contextlib import asynccontextmanager
from typing import List

from fastapi import FastAPI, HTTPException, status
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field, validator
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.linear_model import LogisticRegression

print("Imports successful!")

### Helper: Train a simple classifier

In [None]:
# Load sample data
with open('../fixtures/input/sample_texts.json') as f:
    sample_data = json.load(f)

print(f"Loaded {len(sample_data)} samples")
print(f"\nSample: {sample_data[0]}")
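Before training, it can help to confirm that every record has the two fields `train_simple_classifier` reads (`text` and `expected_label`) and that the labels are reasonably balanced. The following is a minimal sketch. `check_samples` is a hypothetical helper, and the inline records are made-up stand-ins shaped like `sample_texts.json`:

```python
from collections import Counter


def check_samples(records: list) -> Counter:
    """Verify the keys the training helper needs and return the label distribution."""
    required = {"text", "expected_label"}
    for i, item in enumerate(records):
        missing = required - set(item)
        if missing:
            raise KeyError(f"record {i} is missing {sorted(missing)}")
        if not item["text"]:
            raise ValueError(f"record {i} has empty text")
    return Counter(item["expected_label"] for item in records)


# Hypothetical records in the same shape as sample_texts.json
demo_records = [
    {"text": "Love it!", "expected_label": "positive"},
    {"text": "Broke after a day.", "expected_label": "negative"},
    {"text": "Arrived on Tuesday.", "expected_label": "neutral"},
    {"text": "Great value.", "expected_label": "positive"},
]
label_counts = check_samples(demo_records)
print(label_counts)
```

In the notebook, `check_samples(sample_data)` would run the same checks on the real fixture.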

In [None]:
# Train a simple classifier (for demo purposes)
def train_simple_classifier():
    """Train a simple sentiment classifier."""
    texts = [item['text'] for item in sample_data]
    labels = [item['expected_label'] for item in sample_data]
    
    # Create embeddings
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(texts)
    
    # Train classifier
    clf = LogisticRegression(max_iter=1000)
    clf.fit(embeddings, labels)
    
    return model, clf

print("Classifier training function ready")
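`train_simple_classifier` builds a new `SentenceTransformer` on every call. The lifespan below calls it each time the app starts, which includes every `with TestClient(app)` block. One way to avoid repeated loads is to memoize the loader. The sketch below uses `functools.lru_cache`. `load_encoder` and its placeholder return value are hypothetical stand-ins for `SentenceTransformer('all-MiniLM-L6-v2')`:

```python
from functools import lru_cache

load_calls = 0  # counts how often the expensive load actually runs


@lru_cache(maxsize=1)
def load_encoder(name: str = "all-MiniLM-L6-v2") -> dict:
    """Hypothetical loader; in the notebook this would return SentenceTransformer(name)."""
    global load_calls
    load_calls += 1
    return {"model_name": name}  # placeholder object standing in for the real encoder


first_encoder = load_encoder()
second_encoder = load_encoder()
print(first_encoder is second_encoder, load_calls)  # True 1
```

`lru_cache` keys on the arguments exactly as passed. Calling `load_encoder()` and `load_encoder("all-MiniLM-L6-v2")` therefore creates two separate cache entries.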

---
## Task 1: Create Lifespan Context and App - SOLUTION

In [None]:
# SOLUTION

# Global dict to store models
ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for loading/unloading models."""
    # Startup: Load models
    print("Loading models...")
    encoder, classifier = train_simple_classifier()
    ml_models["encoder"] = encoder
    ml_models["classifier"] = classifier
    print(f"Models loaded: {list(ml_models.keys())}")
    
    yield  # App runs here
    
    # Shutdown: Cleanup
    print("Cleaning up models...")
    ml_models.clear()

# Create FastAPI app with lifespan
app = FastAPI(
    title="Text Classification API",
    description="Sentiment classification service with ML models",
    version="1.0.0",
    lifespan=lifespan
)

print("FastAPI app created with lifespan")
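In the lifespan above, the cleanup lines after `yield` are skipped if an exception is thrown into the context manager at `yield`. Wrapping the `yield` in `try`/`finally` makes cleanup run in both cases. The stdlib-only sketch below demonstrates this with a simulated failure. `demo_lifespan`, `demo_models`, and `SimulatedFailure` are hypothetical names, chosen so that they do not overwrite the notebook's real `lifespan` and `ml_models`:

```python
import asyncio
from contextlib import asynccontextmanager

demo_models: dict = {}
events: list = []


class SimulatedFailure(Exception):
    """Raised inside the context to mimic an error while the app is serving."""


@asynccontextmanager
async def demo_lifespan(app):
    # Startup: placeholder strings stand in for the real encoder/classifier
    demo_models["encoder"] = "fake-encoder"
    demo_models["classifier"] = "fake-classifier"
    events.append("startup")
    try:
        yield  # the application would serve requests while suspended here
    finally:
        # Runs on a normal exit and also when an exception is thrown in at `yield`
        demo_models.clear()
        events.append("shutdown")


async def simulate_failure() -> None:
    async with demo_lifespan(app=None):
        events.append(f"serving with {sorted(demo_models)}")
        raise SimulatedFailure("boom")


try:
    asyncio.run(simulate_failure())
except SimulatedFailure:
    pass

print(events)  # ['startup', "serving with ['classifier', 'encoder']", 'shutdown']
print(demo_models)  # {}
```

Jupyter already runs an event loop, and `asyncio.run` refuses to start inside one. In a notebook cell, use `await simulate_failure()` (inside the same `try`/`except`) instead.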

In [None]:
# TEST
assert 'ml_models' in dir(), "ml_models dict not found"
assert 'lifespan' in dir(), "lifespan function not found"
assert 'app' in dir(), "app not found"

# Test that models load (context manager triggers lifespan startup)
with TestClient(app) as client:
    assert 'encoder' in ml_models, "encoder not loaded in ml_models"
    assert 'classifier' in ml_models, "classifier not loaded in ml_models"
    print("✓ Task 1 PASSED!")
    print(f"  Models loaded: {list(ml_models.keys())}")

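# Optional: repopulate ml_models so the endpoint functions can be called directly outside a
# TestClient context. Each `with TestClient(app)` block below re-runs the lifespan anyway,
# which reloads the models on entry and clears them again on exit.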
encoder, classifier = train_simple_classifier()
ml_models["encoder"] = encoder
ml_models["classifier"] = classifier

---
## Task 2: Define Pydantic Models - SOLUTION

In [None]:
# SOLUTION

class TextInput(BaseModel):
    """Input model for single text classification."""
    text: str = Field(..., min_length=1, max_length=5000, description="Text to classify")
    
    class Config:
        json_schema_extra = {
            "example": {
                "text": "This product is amazing!"
            }
        }

class ClassificationResult(BaseModel):
    """Output model for classification result."""
    text: str = Field(..., description="Original text")
    label: str = Field(..., description="Predicted label")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Prediction confidence")
    
    class Config:
        json_schema_extra = {
            "example": {
                "text": "This product is amazing!",
                "label": "positive",
                "confidence": 0.95
            }
        }

class BatchTextInput(BaseModel):
    """Input model for batch text classification."""
    texts: List[str] = Field(..., min_length=1, max_length=100, description="List of texts to classify")
    
    @validator('texts')
    def validate_texts(cls, v):
        """Validate each text in the list."""
        for text in v:
            if not text:  # rejects empty strings
                raise ValueError("Each text must be non-empty")
            if len(text) > 5000:
                raise ValueError("Each text must be <= 5000 characters")
        return v
    
    class Config:
        json_schema_extra = {
            "example": {
                "texts": [
                    "This is great!",
                    "This is terrible.",
                    "It's okay."
                ]
            }
        }

class BatchClassificationResult(BaseModel):
    """Output model for batch classification."""
    results: List[ClassificationResult] = Field(..., description="List of classification results")
    count: int = Field(..., description="Number of texts classified")
    
    class Config:
        json_schema_extra = {
            "example": {
                "results": [
                    {"text": "Great!", "label": "positive", "confidence": 0.95},
                    {"text": "Terrible.", "label": "negative", "confidence": 0.92}
                ],
                "count": 2
            }
        }

print("Pydantic models defined")

In [None]:
# TEST
assert 'TextInput' in dir(), "TextInput not found"
assert 'ClassificationResult' in dir(), "ClassificationResult not found"
assert 'BatchTextInput' in dir(), "BatchTextInput not found"
assert 'BatchClassificationResult' in dir(), "BatchClassificationResult not found"

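# Test validation
from pydantic import ValidationError

valid_input = TextInput(text="This is a test")
assert valid_input.text == "This is a test"

# Catch only ValidationError: a bare `except:` would also swallow the AssertionError
try:
    TextInput(text="")  # Should fail
except ValidationError:
    pass
else:
    raise AssertionError("Empty text should be rejected")

try:
    TextInput(text="a" * 6000)  # Should fail
except ValidationError:
    pass
else:
    raise AssertionError("Text too long should be rejected")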

result = ClassificationResult(text="test", label="positive", confidence=0.95)
assert 0 <= result.confidence <= 1

batch_input = BatchTextInput(texts=["text1", "text2"])
assert len(batch_input.texts) == 2

print("✓ Task 2 PASSED!")
print("  All Pydantic models defined with proper validation")
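When a model rejects input, `ValidationError.errors()` returns one dict per failure. FastAPI builds its automatic 422 response from these entries and prefixes `loc` with the request part the value came from (for example `"body"`). The sketch below assumes Pydantic v2, because error type names differ in v1. `DemoTextInput` is a hypothetical copy of `TextInput`'s constraint, included so the cell runs on its own:

```python
from pydantic import BaseModel, Field, ValidationError


class DemoTextInput(BaseModel):
    # Same constraint as TextInput above, repeated so this cell runs on its own
    text: str = Field(..., min_length=1, max_length=5000)


errors = []
try:
    DemoTextInput(text="")
except ValidationError as exc:
    errors = exc.errors()  # list of dicts with "type", "loc", "msg", ...

print(errors[0]["loc"], errors[0]["type"])  # ('text',) string_too_short
```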

---
## Task 3: Implement Single Classification Endpoint - SOLUTION

In [None]:
# SOLUTION

@app.post("/classify", response_model=ClassificationResult, status_code=status.HTTP_200_OK)
async def classify_text(input_data: TextInput):
    """
    Classify a single text.
    
    Args:
        input_data: TextInput with text to classify
        
    Returns:
        ClassificationResult with label and confidence
    """
    try:
        # Get models
        encoder = ml_models.get("encoder")
        classifier = ml_models.get("classifier")
        
        if encoder is None or classifier is None:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Models not loaded"
            )
        
        # Encode text
        embedding = encoder.encode([input_data.text])
        
        # Predict
        label = classifier.predict(embedding)[0]
        
        # Get confidence (max probability)
        probabilities = classifier.predict_proba(embedding)[0]
        confidence = float(np.max(probabilities))
        
        return ClassificationResult(
            text=input_data.text,
            label=label,
            confidence=confidence
        )
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Classification failed: {str(e)}"
        )

print("/classify endpoint implemented")
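The endpoint calls both `predict` and `predict_proba`. In scikit-learn, the columns of `predict_proba` follow the order of `classifier.classes_`, so a single `predict_proba` call can supply both the label and the confidence. For `LogisticRegression`, the argmax of the probabilities should agree with `predict`. The sketch below uses NumPy. `label_and_confidence` is a hypothetical helper, and the probability rows are made-up values rather than real model output:

```python
import numpy as np


def label_and_confidence(probabilities: np.ndarray, classes: np.ndarray) -> list:
    """Map predict_proba rows to (label, confidence) pairs using one argmax per row."""
    best = np.argmax(probabilities, axis=1)  # index of the most likely class per row
    labels = classes[best]  # same column order as classifier.classes_
    confidences = probabilities[np.arange(len(best)), best]
    return [(str(label), float(conf)) for label, conf in zip(labels, confidences)]


demo_classes = np.array(["negative", "neutral", "positive"])
demo_probs = np.array([[0.1, 0.2, 0.7], [0.8, 0.15, 0.05]])
pairs = label_and_confidence(demo_probs, demo_classes)
print(pairs)  # [('positive', 0.7), ('negative', 0.8)]
```

In the endpoint, this would look like `label_and_confidence(classifier.predict_proba(embedding), classifier.classes_)[0]`.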

In [None]:
# TEST
with TestClient(app) as client:
    # Test positive sentiment
    response = client.post(
        "/classify",
        json={"text": "This is amazing! I love it!"}
    )
    assert response.status_code == 200, f"Expected 200, got {response.status_code}"
    data = response.json()
    assert 'label' in data, "Response missing 'label'"
    assert 'confidence' in data, "Response missing 'confidence'"
    assert 'text' in data, "Response missing 'text'"
    assert 0 <= data['confidence'] <= 1, "Confidence should be between 0 and 1"
    assert data['label'] == 'positive', f"Expected 'positive', got {data['label']}"

    # Test negative sentiment
    response = client.post(
        "/classify",
        json={"text": "This is terrible and awful!"}
    )
    assert response.status_code == 200
    data = response.json()
    assert data['label'] == 'negative', f"Expected 'negative', got {data['label']}"

    # Test validation error
    response = client.post(
        "/classify",
        json={"text": ""}  # Empty text
    )
    assert response.status_code == 422, "Empty text should return 422"

print("✓ Task 3 PASSED!")
print("  /classify endpoint working correctly")
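Both endpoints are declared `async def` but call `encoder.encode` and `classifier.predict_proba` directly. These calls block, and the event loop cannot serve other requests while they run. FastAPI runs plain `def` path operations in a threadpool, so one option is simply to declare the endpoints with `def`. Another option is to offload the blocking call explicitly. The following stdlib sketch shows the second approach with `asyncio.to_thread` (Python 3.9+). `blocking_encode` is a hypothetical stand-in that sleeps instead of running a model:

```python
import asyncio
import time


def blocking_encode(text: str) -> list:
    """Hypothetical stand-in for encoder.encode(): blocks its thread for 0.2 s."""
    time.sleep(0.2)
    return [float(len(text))]


async def encode_offloaded(text: str) -> list:
    # Run the blocking call in a worker thread; the event loop stays free meanwhile.
    return await asyncio.to_thread(blocking_encode, text)


async def run_three() -> tuple:
    start = time.perf_counter()
    results = await asyncio.gather(*(encode_offloaded(t) for t in ["a", "bb", "ccc"]))
    return results, time.perf_counter() - start


offload_results, elapsed = asyncio.run(run_three())
print(offload_results)  # [[1.0], [2.0], [3.0]]
print(f"{elapsed:.2f}s")  # roughly 0.2 s rather than 0.6 s, because the calls overlap
```

For CPU-bound inference, threads only help if the library releases the GIL during heavy computation. Many native ML libraries do, but check yours; otherwise a process pool or a separate worker may be needed. In Jupyter, use `await run_three()` instead of `asyncio.run`.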

---
## Task 4: Implement Batch Classification Endpoint - SOLUTION

In [None]:
# SOLUTION

@app.post("/classify/batch", response_model=BatchClassificationResult, status_code=status.HTTP_200_OK)
async def classify_batch(input_data: BatchTextInput):
    """
    Classify multiple texts in batch.
    
    Args:
        input_data: BatchTextInput with list of texts
        
    Returns:
        BatchClassificationResult with all results
    """
    try:
        # Get models
        encoder = ml_models.get("encoder")
        classifier = ml_models.get("classifier")
        
        if encoder is None or classifier is None:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Models not loaded"
            )
        
        # Encode all texts in batch (more efficient)
        embeddings = encoder.encode(input_data.texts)
        
        # Predict for all texts
        labels = classifier.predict(embeddings)
        
        # Get probabilities for confidence
        probabilities = classifier.predict_proba(embeddings)
        confidences = np.max(probabilities, axis=1)
        
        # Build results list
        results = [
            ClassificationResult(
                text=text,
                label=label,
                confidence=float(confidence)
            )
            for text, label, confidence in zip(input_data.texts, labels, confidences)
        ]
        
        return BatchClassificationResult(
            results=results,
            count=len(results)
        )
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch classification failed: {str(e)}"
        )

print("/classify/batch endpoint implemented")
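The batch endpoint makes one `encode` call and one `predict_proba` call for the whole list, whereas a per-item loop would make N of each. `SentenceTransformer.encode` also splits a list into mini-batches internally, controlled by its `batch_size` argument. The sketch below makes the call counts concrete. `CountingEncoder` is a hypothetical stand-in that returns fake one-dimensional embeddings:

```python
class CountingEncoder:
    """Hypothetical encoder stand-in that counts encode() calls."""

    def __init__(self) -> None:
        self.calls = 0

    def encode(self, texts: list) -> list:
        self.calls += 1
        return [[float(len(t))] for t in texts]  # one fake 1-d embedding per text


demo_texts = ["Great!", "Awful.", "It's fine."]

per_item = CountingEncoder()
per_item_embeddings = [per_item.encode([t])[0] for t in demo_texts]  # N calls

batched = CountingEncoder()
batched_embeddings = batched.encode(demo_texts)  # 1 call

print(per_item.calls, batched.calls)  # 3 1
print(per_item_embeddings == batched_embeddings)  # True
```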

In [None]:
# TEST
with TestClient(app) as client:
    # Test batch classification
    response = client.post(
        "/classify/batch",
        json={
            "texts": [
                "This is amazing!",
                "Terrible quality. Complete waste of money.",
                "The package arrived today."
            ]
        }
    )
    assert response.status_code == 200, f"Expected 200, got {response.status_code}"
    data = response.json()
    assert 'results' in data, "Response missing 'results'"
    assert 'count' in data, "Response missing 'count'"
    assert data['count'] == 3, f"Expected count=3, got {data['count']}"
    assert len(data['results']) == 3, f"Expected 3 results, got {len(data['results'])}"

    # Check first result is positive
    assert data['results'][0]['label'] == 'positive', "First text should be positive"

    # Check second result is negative
    assert data['results'][1]['label'] == 'negative', "Second text should be negative"

    # Test validation error
    response = client.post(
        "/classify/batch",
        json={"texts": []}  # Empty list
    )
    assert response.status_code == 422, "Empty list should return 422"

print("✓ Task 4 PASSED!")
print("  /classify/batch endpoint working correctly")

---
## Task 5: Add Health Check and Error Handling - SOLUTION

In [None]:
# SOLUTION

@app.get("/", status_code=status.HTTP_200_OK)
async def root():
    """Root endpoint with API information."""
    return {
        "name": "Text Classification API",
        "version": "1.0.0",
        "description": "Sentiment classification service",
        "endpoints": {
            "classify": "/classify",
            "batch_classify": "/classify/batch",
            "health": "/health",
            "docs": "/docs"
        }
    }

@app.get("/health", status_code=status.HTTP_200_OK)
async def health_check():
    """Health check endpoint."""
    models_loaded = "encoder" in ml_models and "classifier" in ml_models
    
    if not models_loaded:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Models not loaded"
        )
    
    return {
        "status": "healthy",
        "models_loaded": list(ml_models.keys()),
        "model_info": {
            "encoder": type(ml_models["encoder"]).__name__,
            "classifier": type(ml_models["classifier"]).__name__
        }
    }

print("Health check and root endpoints added")
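The Task 5 tests only exercise the healthy path. Moving the health decision into a pure function makes the 503 branch easy to test without starting the app. The sketch below is a variant of the `/health` logic. `health_payload` and `REQUIRED_MODELS` are hypothetical names, and the `"missing"` field is an illustrative addition not present in the endpoint above:

```python
from http import HTTPStatus

REQUIRED_MODELS = ("encoder", "classifier")


def health_payload(models: dict) -> tuple:
    """Return (status_code, body) for a health probe given the model registry."""
    missing = [name for name in REQUIRED_MODELS if name not in models]
    if missing:
        # "missing" is an illustrative extra field, not part of the /health response above
        return HTTPStatus.SERVICE_UNAVAILABLE, {"status": "unavailable", "missing": missing}
    return HTTPStatus.OK, {"status": "healthy", "models_loaded": sorted(models)}


healthy = health_payload({"encoder": object(), "classifier": object()})
degraded = health_payload({"encoder": object()})
print(healthy)
print(degraded)
```

Inside an endpoint, the tuple could be returned as `fastapi.responses.JSONResponse(status_code=code, content=body)`.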

In [None]:
# TEST
with TestClient(app) as client:
    # Test root endpoint
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert 'message' in data or 'name' in data, "Root should return info"

    # Test health endpoint
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert 'status' in data, "Health check missing 'status'"
    assert data['status'] == 'healthy', f"Expected healthy, got {data['status']}"

print("✓ Task 5 PASSED!")
print("  Health check and root endpoints working")

---
## Bonus: Test Full API Flow

In [None]:
# Test complete API flow
with TestClient(app) as client:
    print("=== Full API Test ===")
    print()

    # 1. Health check
    response = client.get("/health")
    print("1. Health Check:")
    print(f"   {response.json()}")
    print()

    # 2. Single classification
    text = "This product exceeded all my expectations!"
    response = client.post("/classify", json={"text": text})
    result = response.json()
    print("2. Single Classification:")
    print(f"   Text: {text}")
    print(f"   Label: {result['label']}")
    print(f"   Confidence: {result['confidence']:.4f}")
    print()

    # 3. Batch classification
    texts = [
        "Absolutely love it!",
        "Worst purchase ever.",
        "It's okay, nothing special."
    ]
    response = client.post("/classify/batch", json={"texts": texts})
    result = response.json()
    print("3. Batch Classification:")
    for i, r in enumerate(result['results']):
        print(f"   {i+1}. {r['label']} ({r['confidence']:.4f}): {r['text']}")
    print()

print("✓ Full API flow completed")

---
## Summary

**Key techniques used:**

1. **Lifespan context manager:**
   - Load models once at startup
   - Store in global dict accessible to all endpoints
   - Clean up on shutdown

2. **Pydantic models:**
   - Automatic validation (min/max length)
   - Type checking and conversion
   - API documentation generation
   - Custom validators with `@validator` (the Pydantic v1-style decorator, deprecated in Pydantic v2 in favor of `@field_validator`)

3. **Error handling:**
   - Use `HTTPException` with proper status codes
   - 422 for validation errors (automatic)
   - 503 when the models are not loaded
   - 500 for internal errors

4. **Batch processing:**
   - Encode all texts at once (more efficient)
   - Single predict call for all inputs
   - Build results list with list comprehension

5. **Type hints:**
   - Better IDE support
   - Automatic OpenAPI documentation
   - Runtime validation of requests (performed by FastAPI and Pydantic from the annotations)

**Common pitfalls avoided:**
- Loading models in each request (slow)
- Not validating input length
- Processing batch items one at a time
- Not handling model loading failures
- Using generic error messages