# Day 33: Implementing Streaming API - Part 2

In this notebook, we'll implement a streaming API for our LLM server. Streaming is essential for responsive user interfaces, allowing tokens to be displayed as they're generated rather than waiting for the entire response.

## Overview
1. Understanding streaming generation
2. Implementing a streaming endpoint with FastAPI
3. Testing the streaming API
4. Running the server and testing
5. Streaming performance considerations

## 1. Understanding Streaming Generation

Streaming generation allows tokens to be sent to the client as they're produced, rather than waiting for the entire generation to complete. This provides a more responsive user experience, especially for longer responses.

Key benefits of streaming:
- Improved perceived latency (time to first token)
- Better user experience for chat applications
- Ability to cancel generation early if needed
- More efficient use of client resources, since the client can render each chunk as it arrives instead of buffering the whole response
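
To make the first benefit concrete, here is a minimal sketch that measures time to first token against total generation time. `fake_token_stream` is a hypothetical stand-in for a model (it is not part of the server in this notebook), and the token count and delay are arbitrary assumptions. In a notebook, where an event loop is already running, replace the `asyncio.run(...)` call with `await measure_latency()`.

```python
import asyncio
import time
from typing import AsyncGenerator, Optional, Tuple


async def fake_token_stream(n_tokens: int, delay: float) -> AsyncGenerator[str, None]:
    """Hypothetical stand-in for a model: yields one token every `delay` seconds."""
    for i in range(n_tokens):
        await asyncio.sleep(delay)
        yield f" tok{i}"


async def measure_latency(n_tokens: int = 5, delay: float = 0.02) -> Tuple[float, float]:
    """Return (time_to_first_token, total_time) in seconds."""
    start = time.perf_counter()
    first_token_time: Optional[float] = None
    async for _ in fake_token_stream(n_tokens, delay):
        if first_token_time is None:
            # A streaming client can start rendering at this point
            first_token_time = time.perf_counter() - start
    # A non-streaming client sees nothing until this point
    total_time = time.perf_counter() - start
    return first_token_time, total_time


ttft, total = asyncio.run(measure_latency())
print(f"time to first token: {ttft:.3f}s, total: {total:.3f}s")
```

The gap between the two numbers grows with the length of the response, which is why streaming matters most for long generations.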

In [None]:
# Install required packages (requests and sseclient-py are used by the Python test client)
!pip3 install -q vllm fastapi uvicorn sse-starlette pydantic requests sseclient-py

In [None]:
import os
import time
import json
import asyncio
from typing import List, Dict, Any, Optional

# Check if we're running in a notebook
try:
    get_ipython
    is_notebook = True
except NameError:
    is_notebook = False

print(f"Running in {'notebook' if is_notebook else 'script'} mode")

## 2. Implementing a Streaming Endpoint with FastAPI

We'll create a FastAPI server with a streaming endpoint that uses Server-Sent Events (SSE) to stream tokens to the client.
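
For orientation, an SSE stream is plain text: each event carries one or more `data:` lines and ends with a blank line. The sketch below shows that framing for the JSON payloads this notebook sends. The helper names `format_sse` and `parse_sse` are illustrative assumptions, not part of any library, and the sketch uses `\n` line endings although a server may emit `\r\n` instead.

```python
import json
from typing import Dict, List


def format_sse(payload: Dict) -> str:
    """Serialize one payload as a single SSE event (a data line plus a blank line)."""
    return f"data: {json.dumps(payload)}\n\n"


def parse_sse(raw: str) -> List[Dict]:
    """Parse a complete SSE text stream back into a list of JSON payloads."""
    events = []
    for block in raw.split("\n\n"):
        data_lines = []
        for line in block.split("\n"):
            if line.startswith("data:"):
                value = line[5:]
                # Only a single leading space after the colon is part of the framing
                if value.startswith(" "):
                    value = value[1:]
                data_lines.append(value)
        if data_lines:
            events.append(json.loads("\n".join(data_lines)))
    return events


wire = format_sse({"token": " AI", "index": 0, "finished": False})
wire += format_sse({"token": "", "index": 1, "finished": True})
print(wire)
print(parse_sse(wire))
```

In the server below, the SSE response class adds this framing, so the generator only has to yield the JSON string for each event.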

In [None]:
# Create a file for our streaming API server
streaming_server_code = """
import os
import time
import json
import asyncio
from typing import List, Dict, Any, Optional, AsyncGenerator

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel, Field

# Import vLLM for model serving
try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    print("vLLM not available. Using mock implementation for demonstration.")
    VLLM_AVAILABLE = False

# Define the model name
MODEL_NAME = "facebook/opt-350m"  # Using a smaller model for demonstration

# Initialize FastAPI app
app = FastAPI(title="LLM Streaming API Server")

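# Add CORS middleware
# Note: a wildcard origin should not be combined with credentialed requests,
# so credentials stay disabled; list explicit origins if cookies are needed
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)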

# Initialize the model
if VLLM_AVAILABLE:
    print(f"Loading model: {MODEL_NAME}")
    try:
        llm = LLM(model=MODEL_NAME)
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        VLLM_AVAILABLE = False
else:
    llm = None

# Define request and response models
class StreamingRequest(BaseModel):
    prompt: str
    max_tokens: int = Field(default=100, ge=1, le=1024)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    stream: bool = Field(default=True)

class TokenResponse(BaseModel):
    token: str
    index: int
    finished: bool = False

# Mock token generator for demonstration
async def mock_token_generator(prompt: str, max_tokens: int) -> AsyncGenerator[str, None]:
    '''Generate tokens one by one for demonstration purposes.'''
    words = [" AI", " is", " transforming", " the", " world", " through", " innovative", " solutions", 
             " that", " enhance", " productivity", " and", " creativity", "."]
    
    for i in range(min(max_tokens, len(words))):
        await asyncio.sleep(0.3)  # Simulate generation time
        yield words[i]

# vLLM streaming implementation
async def vllm_stream_tokens(prompt: str, sampling_params: "SamplingParams") -> AsyncGenerator[str, None]:
    '''Replay a completed vLLM generation word by word (simulated streaming).'''
    # llm.generate() is blocking, so run it in a worker thread to keep the event loop responsive
    outputs = await asyncio.to_thread(llm.generate, [prompt], sampling_params, use_tqdm=False)
    generated_text = outputs[0].outputs[0].text
    
    # Note: the offline LLM class returns only finished outputs, so streaming is
    # simulated here by replaying whitespace-separated words (not real tokens).
    # True token-by-token streaming requires vLLM's async engine (AsyncLLMEngine).
    words = generated_text.split()
    for word in words:
        await asyncio.sleep(0.1)  # Simulate generation time
        yield f" {word}"

# Define API endpoints
@app.get("/")
async def root():
    return {"message": "LLM Streaming API Server is running", "model": MODEL_NAME}

@app.post("/generate")
async def generate(request: StreamingRequest):
    '''Generate text with or without streaming.'''
    try:
        if request.stream:
            return await stream_generate(request)
        else:
            return await complete_generate(request)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

async def complete_generate(request: StreamingRequest):
    '''Generate complete text without streaming.'''
    start_time = time.time()
    
    if VLLM_AVAILABLE:
        # Set up sampling parameters
        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens
        )
        
        # Generate text in a worker thread: llm.generate() is blocking and
        # would otherwise stall the event loop for every other request
        outputs = await asyncio.to_thread(llm.generate, [request.prompt], sampling_params)
        generated_text = outputs[0].outputs[0].text
    else:
        # Mock generation
        await asyncio.sleep(1)  # Simulate processing time
        generated_text = " AI is transforming the world through innovative solutions that enhance productivity and creativity."
    
    # Calculate token counts (approximate)
    prompt_tokens = len(request.prompt.split())
    completion_tokens = len(generated_text.split())
    total_tokens = prompt_tokens + completion_tokens
    
    # Calculate request time
    request_time = time.time() - start_time
    
    return {
        "text": generated_text,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens
        },
        "request_time": request_time
    }

async def stream_generate(request: StreamingRequest):
    '''Stream generated tokens as Server-Sent Events.'''
    async def event_generator():
        token_index = 0
        
        # Choose the appropriate token generator
        if VLLM_AVAILABLE:
            sampling_params = SamplingParams(
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_tokens
            )
            token_stream = vllm_stream_tokens(request.prompt, sampling_params)
        else:
            token_stream = mock_token_generator(request.prompt, request.max_tokens)
        
        # Stream each token
        async for token in token_stream:
            response = TokenResponse(
                token=token,
                index=token_index,
                finished=False
            )
            token_index += 1
            yield json.dumps(response.model_dump())  # Pydantic v2; use .dict() on v1
        
        # Send final message
        final_response = TokenResponse(
            token="",
            index=token_index,
            finished=True
        )
        yield json.dumps(final_response.model_dump())
    
    return EventSourceResponse(event_generator())

@app.get("/health")
async def health():
    return {"status": "healthy"}

# Run the server
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
"""

# Write the server code to a file
with open("streaming_server.py", "w") as f:
    f.write(streaming_server_code)

print("Streaming server code written to streaming_server.py")
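
The server above waits for the whole completion and then replays it word by word. With an engine that streams for real, each step often reports the full text generated so far, so the endpoint has to forward only the new suffix. The sketch below shows that delta computation under that assumption. `fake_cumulative_stream` is a hypothetical stand-in for such an engine and does not call vLLM. In a notebook, replace `asyncio.run(...)` with `await collect()`.

```python
import asyncio
from typing import AsyncGenerator, List


async def fake_cumulative_stream() -> AsyncGenerator[str, None]:
    """Hypothetical engine output: each item is the full text generated so far."""
    for text in [" AI", " AI is", " AI is transforming", " AI is transforming the world."]:
        await asyncio.sleep(0)  # yield control, as a real engine would between steps
        yield text


async def to_deltas(cumulative: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    """Turn cumulative text snapshots into the newly added pieces."""
    sent = 0  # number of characters already forwarded to the client
    async for text in cumulative:
        delta = text[sent:]
        sent = len(text)
        if delta:  # skip steps that produced no new visible text
            yield delta


async def collect() -> List[str]:
    return [delta async for delta in to_deltas(fake_cumulative_stream())]


deltas = asyncio.run(collect())
print(deltas)
```

Each delta would take the place of the `token` field in the events the endpoint yields, so the client code stays unchanged.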

## 3. Testing the Streaming API

Now let's create clients to test our streaming API. We'll start with a small HTML page that reads the event stream using the browser's `fetch` API. A Python client built on the `requests` library follows in the next section.

In [None]:
# Create a simple HTML client for testing the streaming API
# (raw string, so the JavaScript '\n' escapes are written to the file unchanged)
html_client = r"""
<!DOCTYPE html>
<html>
<head>
    <title>LLM Streaming Client</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        textarea { width: 100%; height: 100px; margin-bottom: 10px; }
        #output { border: 1px solid #ccc; padding: 10px; min-height: 200px; white-space: pre-wrap; }
        button { padding: 10px; margin-right: 10px; }
        .controls { margin: 20px 0; }
    </style>
</head>
<body>
    <h1>LLM Streaming Client</h1>
    
    <div>
        <textarea id="prompt" placeholder="Enter your prompt here...">Artificial intelligence will transform the future by</textarea>
    </div>
    
    <div class="controls">
        <button id="streamBtn">Stream Generation</button>
        <button id="completeBtn">Complete Generation</button>
        <button id="clearBtn">Clear Output</button>
    </div>
    
    <div>
        <h3>Output:</h3>
        <div id="output"></div>
    </div>
    
    <script>
        const API_URL = 'http://localhost:8000/generate';
        const promptInput = document.getElementById('prompt');
        const outputDiv = document.getElementById('output');
        const streamBtn = document.getElementById('streamBtn');
        const completeBtn = document.getElementById('completeBtn');
        const clearBtn = document.getElementById('clearBtn');
        
        // Stream generation
        streamBtn.addEventListener('click', async () => {
            const prompt = promptInput.value.trim();
            if (!prompt) return;
            
            outputDiv.textContent = prompt;
            streamBtn.disabled = true;
            completeBtn.disabled = true;
            
            try {
                const response = await fetch(API_URL, {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        prompt,
                        max_tokens: 100,
                        temperature: 0.7,
                        top_p: 0.95,
                        stream: true
                    })
                });
                
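                const reader = response.body.getReader();
                const decoder = new TextDecoder();
                let buffer = '';
                
                while (true) {
                    const { value, done } = await reader.read();
                    if (done) break;
                    
                    // An event can be split across network chunks, so keep the
                    // trailing partial line in the buffer until the rest arrives
                    buffer += decoder.decode(value, { stream: true });
                    const lines = buffer.split('\n');
                    buffer = lines.pop();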
                    
                    for (const line of lines) {
                        if (line.startsWith('data:')) {
                            const data = line.slice(5).trim();
                            if (data) {
                                try {
                                    const tokenData = JSON.parse(data);
                                    if (!tokenData.finished) {
                                        outputDiv.textContent += tokenData.token;
                                    }
                                } catch (e) {
                                    console.error('Error parsing JSON:', e);
                                }
                            }
                        }
                    }
                }
            } catch (error) {
                console.error('Error:', error);
                outputDiv.textContent += '\n\nError: ' + error.message;
            } finally {
                streamBtn.disabled = false;
                completeBtn.disabled = false;
            }
        });
        
        // Complete generation
        completeBtn.addEventListener('click', async () => {
            const prompt = promptInput.value.trim();
            if (!prompt) return;
            
            outputDiv.textContent = 'Generating...';
            streamBtn.disabled = true;
            completeBtn.disabled = true;
            
            try {
                const response = await fetch(API_URL, {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        prompt,
                        max_tokens: 100,
                        temperature: 0.7,
                        top_p: 0.95,
                        stream: false
                    })
                });
                
                const data = await response.json();
                outputDiv.textContent = prompt + data.text;
            } catch (error) {
                console.error('Error:', error);
                outputDiv.textContent = 'Error: ' + error.message;
            } finally {
                streamBtn.disabled = false;
                completeBtn.disabled = false;
            }
        });
        
        // Clear output
        clearBtn.addEventListener('click', () => {
            outputDiv.textContent = '';
        });
    </script>
</body>
</html>
"""

# Write the HTML client to a file
with open("streaming_client.html", "w") as f:
    f.write(html_client)

print("HTML client written to streaming_client.html")
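
One detail of the reading loop in the client is worth isolating: network chunks do not line up with SSE events, so a parser has to buffer partial events until the terminating blank line arrives. Here is a minimal Python sketch of that idea. The `SSEBuffer` class is a hypothetical helper written for illustration, and it handles only single-line `data:` fields like the ones this server sends.

```python
import json
from typing import Dict, List


class SSEBuffer:
    """Hypothetical incremental parser: feed it raw text chunks, get complete events."""

    def __init__(self) -> None:
        self._buffer = ""

    def feed(self, chunk: str) -> List[Dict]:
        # Normalize on the whole buffer so a "\r\n" split across two chunks is still caught
        self._buffer = (self._buffer + chunk).replace("\r\n", "\n")
        events = []
        # A blank line terminates an event; anything after it stays buffered
        while "\n\n" in self._buffer:
            block, self._buffer = self._buffer.split("\n\n", 1)
            for line in block.split("\n"):
                if line.startswith("data:"):
                    events.append(json.loads(line[5:].strip()))
        return events


# Simulate a stream that arrives in arbitrary 7-character chunks
wire = (
    'data: {"token": " AI", "index": 0, "finished": false}\r\n\r\n'
    'data: {"token": "", "index": 1, "finished": true}\r\n\r\n'
)
parser = SSEBuffer()
parsed_events = []
for start in range(0, len(wire), 7):
    parsed_events.extend(parser.feed(wire[start:start + 7]))
print(parsed_events)
```

Parsing each chunk on its own, without the buffer, would raise a JSON error whenever an event straddles two chunks.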

## 4. Running the Server and Testing

To run the server and test the streaming API, follow these steps:

1. Run the server in a terminal:
```bash
python streaming_server.py
```

2. Open the HTML client in a web browser:
```bash
# On macOS
open streaming_client.html
# On Linux
xdg-open streaming_client.html
```

3. Enter a prompt and click "Stream Generation" to see tokens appear one by one, or "Complete Generation" to get the full response at once.
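
Model loading can take a while, so it helps to wait until the server answers before sending test requests. The sketch below polls the `/health` endpoint defined in the server. The function name `wait_for_server` and the default timeout and interval are assumptions chosen for illustration.

```python
import time
import urllib.request


def wait_for_server(
    url: str = "http://localhost:8000/health",
    timeout: float = 60.0,
    interval: float = 1.0,
) -> bool:
    """Poll `url` until it returns HTTP 200 or `timeout` seconds have passed."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=interval) as response:
                if response.status == 200:
                    return True
        except OSError:
            # Connection refused, timeouts, and HTTP errors all mean "not ready yet"
            pass
        time.sleep(interval)
    return False
```

Calling `wait_for_server()` before the test functions below returns `True` once the server is reachable and `False` if it never comes up within the timeout.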

In [None]:
# Python client for testing the streaming API
import requests
import json
import sseclient  # provided by the sseclient-py package (pip install sseclient-py)

def test_complete_generation(prompt="Artificial intelligence will transform the future by"):
    """Test the non-streaming API endpoint."""
    try:
        # Define the API endpoint
        url = "http://localhost:8000/generate"
        
        # Define the request payload
        payload = {
            "prompt": prompt,
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.95,
            "stream": False
        }
        
        # Send the request (the timeout keeps the call from hanging if the server stalls)
        response = requests.post(url, json=payload, timeout=120)
        
        # Check if the request was successful
        if response.status_code == 200:
            result = response.json()
            print(f"Generated text: {prompt}{result['text']}")
            print(f"\nUsage: {result['usage']}")
            print(f"Request time: {result['request_time']:.2f} seconds")
        else:
            print(f"Error: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"Error testing API: {e}")
        print("Make sure the server is running.")

# Note: This function relies on SSEClient(...).events() from the sseclient-py package,
# not the similarly named sseclient package, whose API differs
# !pip install sseclient-py
def test_streaming_generation(prompt="Artificial intelligence will transform the future by"):
    """Test the streaming API endpoint."""
    try:
        # Define the API endpoint
        url = "http://localhost:8000/generate"
        
        # Define the request payload
        payload = {
            "prompt": prompt,
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.95,
            "stream": True
        }
        
        # Send the request (timeout = seconds to connect, seconds allowed between chunks)
        response = requests.post(url, json=payload, stream=True, timeout=(5, 60))
        
        # Check if the request was successful
        if response.status_code == 200:
            client = sseclient.SSEClient(response)
            
            # Print the prompt first
            print(prompt, end="", flush=True)
            
            # Process each event
            for event in client.events():
                try:
                    data = json.loads(event.data)
                    if not data["finished"]:
                        print(data["token"], end="", flush=True)
                    else:
                        print("\n\nGeneration complete.")
                        break
                except json.JSONDecodeError:
                    print(f"\nError parsing event data: {event.data}")
        else:
            print(f"Error: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"Error testing streaming API: {e}")
        print("Make sure the server is running.")

# Uncomment to test the API
# Note: This requires the server to be running
# test_complete_generation()
# test_streaming_generation()

## 5. Streaming Performance Considerations

When implementing streaming in production, consider these factors:

1. **Connection Management**: Ensure your server can handle many concurrent SSE connections
2. **Backpressure Handling**: Implement mechanisms to handle slow clients
3. **Error Handling**: Gracefully handle disconnections and errors
4. **Monitoring**: Track streaming-specific metrics like time-to-first-token
5. **Load Balancing**: Configure load balancers to support long-lived connections
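
As one illustration of the backpressure point, a bounded queue between the generator and the sender makes the producer wait whenever a slow client falls behind, rather than letting unsent tokens pile up in memory. This is a minimal sketch with hypothetical producer and consumer coroutines. It does not use the server code above, and the queue size and delays are arbitrary assumptions. In a notebook, replace `asyncio.run(...)` with `await run_demo()`.

```python
import asyncio
from typing import List, Optional, Tuple


async def producer(queue: asyncio.Queue, n_tokens: int, depth_log: List[int]) -> None:
    """Hypothetical token source: put() suspends while the queue is full."""
    for i in range(n_tokens):
        await queue.put(f" tok{i}")
        depth_log.append(queue.qsize())  # record how far ahead the producer is
    await queue.put(None)  # sentinel: generation finished


async def slow_consumer(queue: asyncio.Queue, received: List[str]) -> None:
    """Hypothetical slow client: takes a moment to handle each token."""
    while True:
        token: Optional[str] = await queue.get()
        if token is None:
            break
        await asyncio.sleep(0.001)  # simulate a slow network write
        received.append(token)


async def run_demo(n_tokens: int = 20, maxsize: int = 4) -> Tuple[List[str], int]:
    queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)
    received: List[str] = []
    depth_log: List[int] = []
    await asyncio.gather(
        producer(queue, n_tokens, depth_log),
        slow_consumer(queue, received),
    )
    return received, max(depth_log)


received, max_depth = asyncio.run(run_demo())
print(f"delivered {len(received)} tokens, queue never held more than {max_depth}")
```

The queue size bounds the memory held per connection. A production server would also need to decide what to do when a client stays slow for too long, for example by closing the connection.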

## Conclusion

In this notebook, we've implemented a streaming API for our LLM server using FastAPI and Server-Sent Events. Streaming provides a more responsive user experience by sending tokens to the client as they're generated.

Key takeaways:
- Server-Sent Events (SSE) provide a simple way to implement streaming
- Streaming significantly improves perceived latency
- The implementation can be adapted for different LLM frameworks

In the next notebook, we'll explore Text Generation Inference (TGI), another popular framework for LLM serving.