# Task 1: Build a Text Classification API

## Scenario
Build a production-ready FastAPI service for sentiment classification that:
1. Loads the model once at startup using a lifespan context manager
2. Validates input/output with Pydantic models
3. Provides single and batch classification endpoints
4. Includes proper error handling

## Your Tasks:
1. **Setup lifespan**: Create FastAPI app with model loading in lifespan
2. **Define Pydantic models**: Request/response validation
3. **Single classification**: Endpoint for one text at a time
4. **Batch classification**: Endpoint for multiple texts
5. **Health check and error handling**: Add `/health` and `/` endpoints and handle edge cases gracefully

## Setup (provided)

In [None]:
import json
import sys
from contextlib import asynccontextmanager
from typing import List, Dict, Any

from fastapi import FastAPI, HTTPException, status
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field, validator
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.linear_model import LogisticRegression

print("Imports successful!")

### Helper: Train a simple classifier

We'll train a small sentiment classifier for testing: `all-MiniLM-L6-v2` sentence embeddings fed into a logistic regression.

In [None]:
# Load sample data
with open('../fixtures/input/sample_texts.json') as f:
    sample_data = json.load(f)

print(f"Loaded {len(sample_data)} samples")
print(f"\nSample: {sample_data[0]}")

In [None]:
# Train a simple classifier (for demo purposes)
def train_simple_classifier():
    """Train a simple sentiment classifier."""
    texts = [item['text'] for item in sample_data]
    labels = [item['expected_label'] for item in sample_data]
    
    # Create embeddings
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(texts)
    
    # Train classifier
    clf = LogisticRegression(max_iter=1000)
    clf.fit(embeddings, labels)
    
    return model, clf

# We'll use this in our API
print("Classifier training function ready")

---
## Task 1: Create Lifespan Context and App

Create:
- Global dictionary `ml_models` to store models
- `lifespan` async context manager that:
  - Loads models on startup
  - Stores them in `ml_models["encoder"]` and `ml_models["classifier"]`
  - Cleans up on shutdown
- FastAPI app with the lifespan
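
As a hint for the pattern (not the solution), here is a minimal sketch of how an async context manager splits startup from shutdown around a single `yield`. The names `resources`, `demo_lifespan`, and `simulate_server` are hypothetical, and the stand-in strings replace real models so the snippet needs only the standard library.

```python
import asyncio
from contextlib import asynccontextmanager

# Hypothetical stand-in for the notebook's ml_models registry.
resources = {}


@asynccontextmanager
async def demo_lifespan(app):
    # Startup: everything before `yield` runs once, before requests are served.
    resources["encoder"] = "stand-in encoder"
    resources["classifier"] = "stand-in classifier"
    yield
    # Shutdown: everything after `yield` runs once, when the server stops.
    resources.clear()


async def simulate_server():
    # FastAPI enters and exits the context manager itself when the app is
    # created with FastAPI(lifespan=demo_lifespan); here we drive it by hand.
    async with demo_lifespan(app=None):
        during = sorted(resources)
    return during, dict(resources)


during, after = asyncio.run(simulate_server())
print(during)  # ['classifier', 'encoder']
print(after)   # {}
```

Inside a notebook cell an event loop is already running, so `await simulate_server()` is the form to use there instead of `asyncio.run(...)`.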

In [None]:
# YOUR CODE HERE
# 1. Create ml_models dict
# 2. Create lifespan async context manager
# 3. Create FastAPI app with lifespan



In [None]:
# TEST - Do not modify
assert 'ml_models' in dir(), "ml_models dict not found"
assert 'lifespan' in dir(), "lifespan function not found"
assert 'app' in dir(), "app not found"

# Test that models load (context manager triggers lifespan startup)
with TestClient(app) as client:
    assert 'encoder' in ml_models, "encoder not loaded in ml_models"
    assert 'classifier' in ml_models, "classifier not loaded in ml_models"
    print("✓ Task 1 PASSED!")
    print(f"  Models loaded: {list(ml_models.keys())}")

# Re-load models for subsequent cells (lifespan cleanup clears them)
encoder, classifier = train_simple_classifier()
ml_models["encoder"] = encoder
ml_models["classifier"] = classifier

---
## Task 2: Define Pydantic Models

Create Pydantic models:

1. **`TextInput`**: 
   - `text`: str (min length 1, max length 5000)

2. **`ClassificationResult`**:
   - `text`: str
   - `label`: str
   - `confidence`: float (between 0 and 1)

3. **`BatchTextInput`**:
   - `texts`: List[str] (min 1 item, max 100 items)

4. **`BatchClassificationResult`**:
   - `results`: List[ClassificationResult]
   - `count`: int
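
The constraints above map onto `Field` arguments. The sketch below shows string-length and numeric-range bounds on two hypothetical models (`ReviewInput` and `ReviewScore` are illustrative names, not the four models required here); the `is_valid` helper is also an assumption added for the demonstration.

```python
from pydantic import BaseModel, Field, ValidationError


class ReviewInput(BaseModel):  # hypothetical model for illustration
    # min_length / max_length bound the length of the string.
    title: str = Field(..., min_length=1, max_length=80)


class ReviewScore(BaseModel):  # hypothetical model for illustration
    # ge / le bound a number inclusively.
    score: float = Field(..., ge=0.0, le=1.0)


def is_valid(model_cls, **data):
    """Return True if the data passes validation, False otherwise."""
    try:
        model_cls(**data)
        return True
    except ValidationError:
        return False


print(is_valid(ReviewInput, title="Great"))   # True
print(is_valid(ReviewInput, title=""))        # False
print(is_valid(ReviewInput, title="x" * 81))  # False
print(is_valid(ReviewScore, score=1.5))       # False
```

For list fields the keyword names depend on the installed Pydantic version: v1 uses `min_items` / `max_items`, while v2 uses `min_length` / `max_length`.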

In [None]:
# YOUR CODE HERE
# Define all four Pydantic models



In [None]:
# TEST - Do not modify
assert 'TextInput' in dir(), "TextInput not found"
assert 'ClassificationResult' in dir(), "ClassificationResult not found"
assert 'BatchTextInput' in dir(), "BatchTextInput not found"
assert 'BatchClassificationResult' in dir(), "BatchClassificationResult not found"

# Test validation
valid_input = TextInput(text="This is a test")
assert valid_input.text == "This is a test"

try:
    TextInput(text="")  # Should fail
except ValueError:  # pydantic's ValidationError subclasses ValueError
    pass
else:
    raise AssertionError("Empty text should be rejected")

try:
    TextInput(text="a" * 6000)  # Should fail
except ValueError:
    pass
else:
    raise AssertionError("Text too long should be rejected")

result = ClassificationResult(text="test", label="positive", confidence=0.95)
assert 0 <= result.confidence <= 1

batch_input = BatchTextInput(texts=["text1", "text2"])
assert len(batch_input.texts) == 2

print("✓ Task 2 PASSED!")
print("  All Pydantic models defined with proper validation")

---
## Task 3: Implement Single Classification Endpoint

Create a POST endpoint `/classify` that:
- Takes `TextInput`
- Returns `ClassificationResult`
- Uses the loaded models from `ml_models`
- Handles errors with appropriate HTTP status codes

Implementation steps:
1. Encode the input text
2. Predict with classifier
3. Get probability for confidence
4. Return result
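
Steps 2 and 3 can be sketched in isolation. The example below assumes toy 2-D vectors in place of real sentence embeddings and a hypothetical helper `classify_vector`; it shows one way to turn `predict_proba` output into a label and a confidence value that serialize cleanly to JSON.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy 2-D vectors stand in for sentence embeddings (illustration only).
X = np.array([[2.0, 0.1], [1.8, -0.2], [-2.0, 0.3], [-1.9, 0.0]])
y = ["positive", "positive", "negative", "negative"]
clf = LogisticRegression().fit(X, y)


def classify_vector(vec):
    """Return (label, confidence) for a single embedding vector."""
    batch = np.asarray(vec).reshape(1, -1)  # scikit-learn expects a 2-D array
    proba = clf.predict_proba(batch)[0]     # one probability per class
    best = int(np.argmax(proba))
    # clf.classes_ gives the label for each predict_proba column.
    # Cast NumPy scalars to plain Python types so they serialize to JSON.
    return str(clf.classes_[best]), float(proba[best])


label, confidence = classify_vector([1.5, 0.0])
print(label, round(confidence, 3))
```

With the real encoder, passing a one-item list such as `encode([text])` already yields a 2-D array, so the `reshape` is not needed in that case.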

In [None]:
# YOUR CODE HERE
# Implement the /classify endpoint



In [None]:
# TEST - Do not modify
with TestClient(app) as client:
    # Test positive sentiment
    response = client.post(
        "/classify",
        json={"text": "This is amazing! I love it!"}
    )
    assert response.status_code == 200, f"Expected 200, got {response.status_code}"
    data = response.json()
    assert 'label' in data, "Response missing 'label'"
    assert 'confidence' in data, "Response missing 'confidence'"
    assert 'text' in data, "Response missing 'text'"
    assert 0 <= data['confidence'] <= 1, "Confidence should be between 0 and 1"
    assert data['label'] == 'positive', f"Expected 'positive', got {data['label']}"

    # Test negative sentiment
    response = client.post(
        "/classify",
        json={"text": "This is terrible and awful!"}
    )
    assert response.status_code == 200
    data = response.json()
    assert data['label'] == 'negative', f"Expected 'negative', got {data['label']}"

    # Test validation error
    response = client.post(
        "/classify",
        json={"text": ""}  # Empty text
    )
    assert response.status_code == 422, "Empty text should return 422"

print("✓ Task 3 PASSED!")
print("  /classify endpoint working correctly")

---
## Task 4: Implement Batch Classification Endpoint

Create a POST endpoint `/classify/batch` that:
- Takes `BatchTextInput`
- Returns `BatchClassificationResult`
- Processes all texts efficiently (batch encoding)
- Handles errors gracefully

Implementation steps:
1. Encode all texts in batch
2. Predict for all texts
3. Get probabilities
4. Build list of results
5. Return with count
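
Steps 3 to 5 can be done without a Python-level loop over predictions. The sketch below assumes a probability matrix shaped like `predict_proba` output (one row per text, one column per class); `build_results` and the sample numbers are hypothetical.

```python
import numpy as np


def build_results(texts, probabilities, classes):
    """Pair each text with its most probable label and that probability."""
    probabilities = np.asarray(probabilities)
    best = probabilities.argmax(axis=1)  # winning column index per row
    confidences = probabilities[np.arange(len(texts)), best]
    results = [
        {"text": t, "label": str(classes[i]), "confidence": float(c)}
        for t, i, c in zip(texts, best, confidences)
    ]
    return {"results": results, "count": len(results)}


# Hypothetical probabilities for three texts and two classes.
classes = np.array(["negative", "positive"])
probs = [[0.1, 0.9], [0.8, 0.2], [0.45, 0.55]]
out = build_results(["a", "b", "c"], probs, classes)
print(out["count"])                          # 3
print([r["label"] for r in out["results"]])  # ['positive', 'negative', 'positive']
```

The efficiency gain in the real endpoint comes from calling the encoder once on the whole list of texts rather than once per text.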

In [None]:
# YOUR CODE HERE
# Implement the /classify/batch endpoint



In [None]:
# TEST - Do not modify
with TestClient(app) as client:
    # Test batch classification
    response = client.post(
        "/classify/batch",
        json={
            "texts": [
                "This is amazing!",
                "Terrible quality. Complete waste of money.",
                "The package arrived today."
            ]
        }
    )
    assert response.status_code == 200, f"Expected 200, got {response.status_code}"
    data = response.json()
    assert 'results' in data, "Response missing 'results'"
    assert 'count' in data, "Response missing 'count'"
    assert data['count'] == 3, f"Expected count=3, got {data['count']}"
    assert len(data['results']) == 3, f"Expected 3 results, got {len(data['results'])}"

    # Check first result is positive
    assert data['results'][0]['label'] == 'positive', "First text should be positive"

    # Check second result is negative
    assert data['results'][1]['label'] == 'negative', "Second text should be negative"

    # Test validation error
    response = client.post(
        "/classify/batch",
        json={"texts": []}  # Empty list
    )
    assert response.status_code == 422, "Empty list should return 422"

print("✓ Task 4 PASSED!")
print("  /classify/batch endpoint working correctly")

---
## Task 5: Add Health Check and Error Handling

Add:
1. **GET `/health`**: Returns status and model info
2. **GET `/`**: Root endpoint with API info
3. **Error handling**: Try-except blocks with proper HTTP exceptions

In [None]:
# YOUR CODE HERE
# Add health check and root endpoints



In [None]:
# TEST - Do not modify
with TestClient(app) as client:
    # Test root endpoint
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert 'message' in data or 'name' in data, "Root should return info"

    # Test health endpoint
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert 'status' in data, "Health check missing 'status'"
    assert data['status'] == 'healthy', f"Expected healthy, got {data['status']}"

print("✓ Task 5 PASSED!")
print("  Health check and root endpoints working")

---
## Bonus: Test Full API Flow

In [None]:
# Test complete API flow
with TestClient(app) as client:
    print("=== Full API Test ===")
    print()

    # 1. Health check
    response = client.get("/health")
    print("1. Health Check:")
    print(f"   {response.json()}")
    print()

    # 2. Single classification
    text = "This product exceeded all my expectations!"
    response = client.post("/classify", json={"text": text})
    result = response.json()
    print("2. Single Classification:")
    print(f"   Text: {text}")
    print(f"   Label: {result['label']}")
    print(f"   Confidence: {result['confidence']:.4f}")
    print()

    # 3. Batch classification
    texts = [
        "Absolutely love it!",
        "Worst purchase ever.",
        "It's okay, nothing special."
    ]
    response = client.post("/classify/batch", json={"texts": texts})
    result = response.json()
    print("3. Batch Classification:")
    for i, r in enumerate(result['results']):
        print(f"   {i+1}. {r['label']} ({r['confidence']:.4f}): {r['text']}")
    print()

print("✓ Full API flow completed!")

---
## Expected Results

After completing all tasks:
- **Task 1**: FastAPI app with lifespan loading models
- **Task 2**: Four Pydantic models with validation
- **Task 3**: `/classify` endpoint returning single result
- **Task 4**: `/classify/batch` endpoint processing multiple texts
- **Task 5**: Health check and root endpoints

## Key Concepts

1. **Lifespan**: Load heavy resources (models) once at startup
2. **Pydantic**: Automatic validation and serialization
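3. **Error handling**: Return meaningful HTTP status codes, such as 422 for invalid input
4. **Batch processing**: Encoding many texts in one call is more efficient than sending one request per text
5. **Type hints**: Drive request validation and the generated API docs, and improve IDE support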