In [None]:
import requests
import time
import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor, as_completed
import nest_asyncio
import json

# Apply nest_asyncio so asyncio code can be driven from inside Jupyter.
nest_asyncio.apply()

# Root URL of the service under test; the container must be running and port-mapped.
API_BASE_URL = "http://localhost:8000"

# Echo the target followed by one blank line so later output is visually separated.
print(f"API Base URL: {API_BASE_URL}", end="\n\n")

# --- Helper function to make requests ---
def make_request(method, endpoint, payload=None, expected_status=200):
    """Send one GET or POST request to the API and print a short summary.

    Args:
        method: HTTP verb, case-insensitive; only "GET" and "POST" are supported.
        endpoint: Path appended verbatim to ``API_BASE_URL`` (e.g. "/healthz").
        payload: JSON-serialisable body sent with POST requests; not sent for GET.
        expected_status: Status code that counts as success.

    Returns:
        The decoded JSON body on success, the raw response text when the body
        is not valid JSON, or ``{"error": ..., "status_code": ...}`` when the
        status code differs from ``expected_status`` or the request itself fails.

    Raises:
        ValueError: If ``method`` is neither GET nor POST.
    """
    verb = method.upper()
    url = f"{API_BASE_URL}{endpoint}"
    try:
        if verb == "GET":
            response = requests.get(url, timeout=60)
        elif verb == "POST":
            # Longer timeout: a prediction may have to load a model first.
            response = requests.post(url, json=payload, timeout=120)
        else:
            raise ValueError(f"Unsupported method: {method}")

        print(f"--- Request to {verb} {endpoint} ---")
        if payload:
            # The label now matches the 150-character slice actually applied.
            print(f"Payload (first 150 chars if long): {str(payload)[:150]}")

        if response.status_code != expected_status:
            print(f"Status: {response.status_code} - Error: {response.text[:300]}...")
            return {"error": response.text, "status_code": response.status_code}

        print(f"Status: {response.status_code} OK")
        try:
            res_json = response.json()
        except ValueError:
            # Every JSON decode error raised by Response.json() is a ValueError
            # subclass, including requests.exceptions.JSONDecodeError, which only
            # exists in newer requests releases.
            print(f"Response (not JSON): {response.text[:300]}...")
            return response.text
        print(f"Response (sample): {str(res_json)[:300]}...")
        return res_json
    except requests.exceptions.RequestException as e:
        print(f"Request failed for {endpoint}: {e}")
        return {"error": str(e), "status_code": "N/A"}

In [None]:
# --- 1. Check Basic API Endpoints ---
# Probe the informational routes before sending any prediction traffic.
print("\n--- Checking Basic API Endpoints ---")
for probe_endpoint in (
    "/",
    "/healthz",
    "/readiness",  # may fail if the default model preload failed
    "/metrics",  # Prometheus metrics endpoint
):
    make_request("GET", probe_endpoint)
# Left as a bare trailing expression so the notebook still echoes the returned value.
make_request("GET", "/cache_info")  # cache state of whichever worker answers

In [None]:
# --- 2. Test Single Prediction (Default Model) ---
# One string input sent to the sentiment model.
print("\n--- Test Single Prediction (Default Model) ---")
single_payload_sentiment = dict(
    model_name="distilbert-base-uncased-finetuned-sst-2-english",
    task="sentiment-analysis",
    inputs="This is a fantastic product, I highly recommend it!",
)
result_single = make_request("POST", "/predict", single_payload_sentiment)

In [None]:
# --- 3. Test Batch Prediction (Default Model) ---
# Same model and task as the single prediction, but "inputs" is a list of texts.
print("\n--- Test Batch Prediction (Default Model) ---")
batch_payload_sentiment = dict(
    model_name="distilbert-base-uncased-finetuned-sst-2-english",
    task="sentiment-analysis",
    inputs=[
        "I am incredibly happy with the service.",
        "This is the worst thing I have ever bought.",
        "It's an okay movie, neither good nor bad.",
    ],
)
result_batch = make_request("POST", "/predict", batch_payload_sentiment)

In [None]:
# --- 4. Test a Different Model and Task (e.g., Text Generation) ---
# Requesting a model that has not been used before makes the server download and cache it,
# so the first call can be noticeably slower than later ones.
print("\n--- Test Different Model/Task (Text Generation with GPT-2) ---")
# 'gpt2' is the small variant, chosen to keep the demo quick;
# 'gpt2-medium' or larger may be preferable for real generation.
generation_payload = dict(
    model_name="gpt2",
    task="text-generation",
    inputs="Once upon a time, in a land far away",
    # Extra arguments passed along for the pipeline.
    pipeline_kwargs=dict(max_new_tokens=20, num_return_sequences=1),
)
result_generation = make_request("POST", "/predict", generation_payload)
# Re-check the cache: a different worker may answer, but at least one should now hold gpt2.
# Left as a bare trailing expression so the notebook still echoes the returned value.
make_request("GET", "/cache_info")


In [None]:
# --- 5. Demonstrate Parallel Requests (ThreadPoolExecutor) ---
print("\n--- Demonstrate Parallel Requests (ThreadPoolExecutor) ---")
# Two passes over the eight sentences below give 16 requests in total.
parallel_texts_sentiment = [
    sentence
    for _pass in range(2)
    for sentence in (
        "The weather today is beautiful and sunny.",
        "I'm feeling a bit down after hearing the news.",
        "This new software update is incredibly buggy.",
        "The concert was an unforgettable experience!",
        "Customer support was surprisingly helpful and efficient.",
        "I am neutral about this new policy change.",
        "This book is a masterpiece of modern literature.",
        "The food at that restaurant was utterly disappointing.",
    )
]

def send_predict_request(text_input, model_name, task):
    """POST a single prediction request to ``/predict`` and return the result.

    Args:
        text_input: Value sent as the payload's "inputs" field.
        model_name: Value sent as the payload's "model_name" field.
        task: Value sent as the payload's "task" field.

    Returns:
        The decoded JSON body on success. On any request failure, HTTP error
        status, or undecodable body, a dict with "error", "input_text" and
        "status_code" keys ("N/A" when no response was received).
    """
    payload = {"model_name": model_name, "task": task, "inputs": text_input}
    url = f"{API_BASE_URL}/predict"
    # Sentinel instead of inspecting locals(): stays None if the POST itself fails.
    response = None
    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose decode
        # error is not a RequestException subclass.
        # Compare against None explicitly: a Response is falsy for 4xx/5xx statuses.
        status_code = response.status_code if response is not None else "N/A"
        return {"error": str(e), "input_text": text_input, "status_code": status_code}

start_time_parallel = time.time()
parallel_results_sentiment = []
total_parallel_requests = len(parallel_texts_sentiment)
# One thread per request so every request is in flight at once, which exercises
# server-side parallelism. ThreadPoolExecutor rejects max_workers=0, hence the floor of 1.
with ThreadPoolExecutor(max_workers=max(1, total_parallel_requests)) as executor:
    futures = [
        executor.submit(send_predict_request, text, "distilbert-base-uncased-finetuned-sst-2-english", "sentiment-analysis")
        for text in parallel_texts_sentiment
    ]
    # `i` counts completions, not submissions: as_completed yields in finish order.
    for i, future in enumerate(as_completed(futures), start=1):
        try:
            data = future.result()
        except Exception as exc:
            print(f"Parallel Req {i}/{total_parallel_requests} generated an exception: {exc}")
            parallel_results_sentiment.append({"error": str(exc)})
            continue

        # Record each result exactly once, whatever its shape.
        parallel_results_sentiment.append(data)
        if not isinstance(data, dict):
            print(f"Parallel Req {i}/{total_parallel_requests}: unexpected non-dict response '{str(data)[:60]}...'")
        elif "error" in data:
            # send_predict_request reports failures as a dict; show them as failures
            # instead of a success line full of 'N/A' placeholders.
            print(f"Parallel Req {i}/{total_parallel_requests} failed (status {data.get('status_code', 'N/A')}): {str(data['error'])[:200]}")
        else:
            print(f"Parallel Req {i}/{total_parallel_requests}: PID {data.get('worker_pid', 'N/A')}, Model '{data.get('model_name', 'N/A')}', Input '{str(data.get('inputs', 'N/A'))[:30]}...'")

end_time_parallel = time.time()
print(f"\nThreadPoolExecutor: Completed {len(parallel_results_sentiment)} sentiment requests in {end_time_parallel - start_time_parallel:.2f} seconds.")

# A result counts as successful only if it is a dict that reports a worker PID.
successful_sentiment_results = [
    res for res in parallel_results_sentiment
    if isinstance(res, dict) and "worker_pid" in res
]
worker_pids_sentiment = {res["worker_pid"] for res in successful_sentiment_results}
successful_sentiment_requests = len(successful_sentiment_results)
print(f"Sentiment requests handled by PIDs: {worker_pids_sentiment} ({successful_sentiment_requests} successful)")

In [None]:
# --- 6. Check Metrics Again After Load ---
print("\n--- Final check of /metrics endpoint ---")
# The counters should now reflect the requests made in the earlier cells.
metrics_after_load = make_request("GET", "/metrics")
# make_request returns a str only when the body was not JSON (raw Prometheus text).
if isinstance(metrics_after_load, str):
    expected_metric_markers = ("hf_requests_total", "gpt2")
    if all(marker in metrics_after_load for marker in expected_metric_markers):
        print("Metrics endpoint seems to contain data for requests made.")
    else:
        print("Metrics endpoint content doesn't explicitly show expected counters, but was fetched.")

print("\n--- Demo Complete ---")
print("Remember to check Docker container logs to see PIDs and model loading/caching messages from Gunicorn workers.")