In [None]:
# nb71 | WebSocket 串流輸出實作
# Goal: Real-time streaming LLM responses via WebSocket

# Cell1:  Shared Cache Bootstrap
import os, pathlib, torch
import sys
from datetime import datetime

# Shared cache configuration (copy this cell into every notebook).
# AI_CACHE_ROOT may be overridden from the environment; each library cache
# variable below is pointed at a sub-directory of it, and that directory is
# created up front.
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "../ai_warehouse/cache")

# env var -> sub-directory of AI_CACHE_ROOT
_CACHE_SUBDIRS = {
    "HF_HOME": "hf",
    "TRANSFORMERS_CACHE": "hf/transformers",
    "HF_DATASETS_CACHE": "hf/datasets",
    "HUGGINGFACE_HUB_CACHE": "hf/hub",
    "TORCH_HOME": "torch",
}

for env_name, subdir in _CACHE_SUBDIRS.items():
    cache_path = f"{AI_CACHE_ROOT}/{subdir}"
    os.environ[env_name] = cache_path
    pathlib.Path(cache_path).mkdir(parents=True, exist_ok=True)

print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())

In [None]:
# Cell 2: Dependencies & Imports
import json
import asyncio
import uuid
from typing import Dict, Set, Optional
from datetime import datetime

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import uvicorn

# For demo purposes - in production, import from shared_utils
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import time

In [None]:
# Cell 3: WebSocket Protocol Design
"""
WebSocket Message Protocol:

Client → Server:
{
  "type": "generate",
  "id": "unique-request-id",
  "data": {
    "messages": [{"role": "user", "content": "Hello"}],
    "max_new_tokens": 256,
    "temperature": 0.7
  }
}

{
  "type": "cancel",
  "id": "request-id-to-cancel"
}

Server → Client:
{
  "type": "token",
  "id": "request-id",
  "data": {"token": "Hello", "is_final": false}
}

{
  "type": "done",
  "id": "request-id",
  "data": {"total_tokens": 127, "duration_ms": 5420}
}

{
  "type": "error",
  "id": "request-id",
  "data": {"error": "Generation failed", "code": "GENERATION_ERROR"}
}

{
  "type": "cancelled",
  "id": "request-id"
}
"""


class WSMessage:
    """Factories for server -> client messages.

    Every message is a dict shaped ``{"type": ..., "id": ..., "data": ...}``
    and is serialized to JSON by the caller.
    """

    @staticmethod
    def _envelope(kind, request_id, payload):
        # Shared wrapper so every message carries keys in the same order.
        return {"type": kind, "id": request_id, "data": payload}

    @staticmethod
    def token(request_id: str, token: str, is_final: bool = False):
        """Message carrying one streamed chunk of generated text."""
        payload = {"token": token, "is_final": is_final}
        return WSMessage._envelope("token", request_id, payload)

    @staticmethod
    def done(request_id: str, total_tokens: int, duration_ms: float):
        """Message sent once a generation has finished, with summary stats."""
        payload = {"total_tokens": total_tokens, "duration_ms": duration_ms}
        return WSMessage._envelope("done", request_id, payload)

    @staticmethod
    def error(request_id: str, error: str, code: str = "UNKNOWN"):
        """Message reporting a failure, with a machine-readable ``code``."""
        payload = {"error": error, "code": code}
        return WSMessage._envelope("error", request_id, payload)

In [None]:
# Cell 4: LLM Adapter Integration (Streaming-enabled)
class StreamingLLMAdapter:
    """Load a Hugging Face causal LM and stream its output as an async generator."""

    def __init__(self, model_id: str = "Qwen/Qwen2.5-7B-Instruct"):
        # Passing ``load_in_4bit`` straight to ``from_pretrained`` is deprecated;
        # the documented form is a BitsAndBytesConfig in ``quantization_config``.
        from transformers import BitsAndBytesConfig

        print(f"Loading {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # Low VRAM option
        )

        # Pad token for batch processing
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"✓ Model loaded on {self.model.device}")

    def _uses_chat_template(self):
        """Return True when the tokenizer ships its own chat template."""
        return bool(getattr(self.tokenizer, "chat_template", None))

    def format_messages(self, messages):
        """Render a list of ``{"role", "content"}`` dicts into one prompt string.

        Instruct models are fine-tuned on a specific chat format, so the
        tokenizer's chat template is used when it has one (with the assistant
        generation prompt appended). Otherwise a plain "Role: content"
        transcript ending in "Assistant: " is produced; in that fallback,
        roles other than system/user/assistant are skipped.
        """
        if self._uses_chat_template():
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        formatted = ""
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                formatted += f"System: {content}\n"
            elif role == "user":
                formatted += f"User: {content}\n"
            elif role == "assistant":
                formatted += f"Assistant: {content}\n"
        formatted += "Assistant: "
        return formatted

    async def generate_stream(
        self, messages, max_new_tokens=256, temperature=0.7, stop_event=None
    ):
        """Asynchronously yield decoded text chunks as the model produces them.

        ``model.generate`` runs on a worker thread. The blocking reads from the
        ``TextIteratorStreamer`` are done in the event loop's default executor,
        so the loop stays free to serve other connections (and cancel
        requests) while text is produced.

        Args:
            messages: Chat history as a list of ``{"role", "content"}`` dicts.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature. A value <= 0 selects greedy
                decoding, because sampling needs a strictly positive value.
            stop_event: Optional object with ``is_set()`` (e.g. an
                ``asyncio.Event``). Once set, no further chunks are yielded and
                ``model.generate`` is halted through a stopping criterion.

        Raises:
            Exception: anything raised while tokenizing or inside
                ``model.generate``. The latter is re-raised here instead of
                being lost on the worker thread.
        """
        import threading

        from transformers import StoppingCriteria, StoppingCriteriaList

        # Set once the consumer stops reading (break, cancel, error), so the
        # worker stops instead of running on to max_new_tokens.
        abandoned = threading.Event()

        def _should_stop():
            return abandoned.is_set() or (
                stop_event is not None and stop_event.is_set()
            )

        class _StopRequested(StoppingCriteria):
            """Ends generation when the caller cancels or stops consuming."""

            def __call__(self, input_ids, scores, **kwargs):
                # One flag per sequence in the batch.
                return torch.full(
                    (input_ids.shape[0],),
                    _should_stop(),
                    dtype=torch.bool,
                    device=input_ids.device,
                )

        try:
            use_template = self._uses_chat_template()
            prompt = self.format_messages(messages)
            # A rendered chat template already contains the model's special
            # tokens; letting the tokenizer add them again would duplicate them.
            inputs = self.tokenizer(
                prompt, return_tensors="pt", add_special_tokens=not use_template
            ).to(self.model.device)

            # TextIteratorStreamer hands decoded text from generate() to us.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )

            generation_kwargs = {
                **inputs,
                "max_new_tokens": max_new_tokens,
                "streamer": streamer,
                "pad_token_id": self.tokenizer.eos_token_id,
                "stopping_criteria": StoppingCriteriaList([_StopRequested()]),
            }
            if temperature is not None and temperature <= 0:
                generation_kwargs["do_sample"] = False
            else:
                generation_kwargs["do_sample"] = True
                generation_kwargs["temperature"] = temperature

            worker_errors = []

            def generate():
                try:
                    self.model.generate(**generation_kwargs)
                except Exception as exc:
                    worker_errors.append(exc)
                    # generate() only closes the stream itself when it
                    # completes; without this the reader below waits forever.
                    streamer.end()

            loop = asyncio.get_running_loop()
            thread = Thread(target=generate, daemon=True)
            thread.start()

            end_of_stream = object()
            try:
                while True:
                    # next() blocks on the streamer's queue: keep it off the loop.
                    chunk = await loop.run_in_executor(
                        None, next, streamer, end_of_stream
                    )
                    if chunk is end_of_stream:
                        break
                    if stop_event is not None and stop_event.is_set():
                        print("Generation cancelled by stop_event")
                        break
                    yield chunk
            finally:
                # Also reached on aclose()/cancellation: halt the worker, then
                # wait for it without blocking the event loop.
                abandoned.set()
                await loop.run_in_executor(None, thread.join)

            if worker_errors:
                raise worker_errors[0]

        except Exception as e:
            print(f"Generation error: {e}")
            raise

In [None]:
# Cell 5: WebSocket Connection Manager
class ConnectionManager:
    """Registry of open WebSockets and per-client generation stop events.

    Both tables are keyed by ``client_id``; at most one stop event is tracked
    per client.
    """

    def __init__(self):
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> Event that, once set, tells the running generation to stop
        self.active_generations: Dict[str, asyncio.Event] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake and register *websocket* under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        total = len(self.active_connections)
        print(f"Client {client_id} connected. Total: {total}")

    def disconnect(self, client_id: str):
        """Forget *client_id*, first signalling its generation (if any) to stop."""
        pending = self.active_generations.pop(client_id, None)
        if pending is not None:
            pending.set()

        self.active_connections.pop(client_id, None)
        total = len(self.active_connections)
        print(f"Client {client_id} disconnected. Total: {total}")

    async def send_message(self, client_id: str, message: dict):
        """Send *message* as JSON text; on any failure log it and drop the client.

        Silently does nothing when *client_id* is not connected.
        """
        socket = self.active_connections.get(client_id)
        if socket is None:
            return
        try:
            await socket.send_text(json.dumps(message))
        except Exception as e:
            print(f"Failed to send to {client_id}: {e}")
            self.disconnect(client_id)

    def cancel_generation(self, client_id: str, request_id: str):
        """Signal the client's current generation to stop.

        *request_id* is only used in the log line: the event is looked up by
        client, not by request.
        """
        event = self.active_generations.get(client_id)
        if event is None:
            return
        event.set()
        print(f"Cancelled generation {request_id} for client {client_id}")

In [None]:
# Cell 6: FastAPI App with WebSocket Endpoint
# A single FastAPI application serves the demo page, the REST probes and the
# WebSocket endpoint.
app = FastAPI(title="Streaming LLM API", version="1.0.0")

# Process-wide singletons shared by every connection.
connection_manager = ConnectionManager()
llm_adapter = None  # created lazily by get_llm_adapter() on first use


def get_llm_adapter():
    """Return the module-wide StreamingLLMAdapter, constructing it on first call."""
    global llm_adapter
    if llm_adapter is not None:
        return llm_adapter
    # First request: load the model (slow) and cache it for later callers.
    llm_adapter = StreamingLLMAdapter()
    return llm_adapter


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Serve one WebSocket client: read JSON messages and dispatch them.

    ``generate`` requests run as a background task so this loop keeps reading
    the socket while text streams out. Awaiting the generation inline would
    leave a ``cancel`` message unread until the generation it targets had
    already finished. At most one generation runs per connection; another
    ``generate`` while one is active is answered with a
    ``GENERATION_IN_PROGRESS`` error.

    Frames that are not a JSON object get an ``INVALID_MESSAGE`` error instead
    of tearing down the connection.
    """
    await connection_manager.connect(websocket, client_id)
    generation_task = None  # asyncio.Task running handle_generate, if any

    try:
        while True:
            # Receive message from client
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                message = None
            if not isinstance(message, dict):
                await connection_manager.send_message(
                    client_id,
                    WSMessage.error(
                        str(uuid.uuid4()),
                        "Message must be a JSON object",
                        "INVALID_MESSAGE",
                    ),
                )
                continue

            message_type = message.get("type")
            request_id = message.get("id", str(uuid.uuid4()))

            if message_type == "generate":
                if generation_task is not None and not generation_task.done():
                    await connection_manager.send_message(
                        client_id,
                        WSMessage.error(
                            request_id,
                            "A generation is already in progress",
                            "GENERATION_IN_PROGRESS",
                        ),
                    )
                    continue
                generation_task = asyncio.create_task(
                    handle_generate(client_id, request_id, message.get("data", {}))
                )

            elif message_type == "cancel":
                connection_manager.cancel_generation(client_id, request_id)
                await connection_manager.send_message(
                    client_id, {"type": "cancelled", "id": request_id}
                )

            else:
                await connection_manager.send_message(
                    client_id,
                    WSMessage.error(request_id, "Unknown message type", "INVALID_TYPE"),
                )

    except WebSocketDisconnect:
        pass  # normal client close; cleanup happens in ``finally``
    except Exception as e:
        print(f"WebSocket error for {client_id}: {e}")
    finally:
        # Do not leave a generation streaming on after its socket is gone.
        if generation_task is not None and not generation_task.done():
            generation_task.cancel()
        # A reconnect with the same client_id may already have registered a
        # new socket; only unregister when the entry is still this one.
        if connection_manager.active_connections.get(client_id) is websocket:
            connection_manager.disconnect(client_id)


async def handle_generate(client_id: str, request_id: str, data: dict):
    """Stream one generation to *client_id* as ``token`` messages, then ``done``.

    Args:
        client_id: Connection key in ``connection_manager``.
        request_id: Id echoed in every reply for this request.
        data: The ``data`` payload of a ``generate`` message. Reads
            ``messages`` (required, non-empty), ``max_new_tokens`` (default
            256, clamped to [1, 512]) and ``temperature`` (default 0.7).

    Failures are reported to the client as ``error`` messages rather than
    raised. On Python 3.8+, where ``asyncio.CancelledError`` is not an
    ``Exception``, task cancellation still propagates.
    """
    stop_event = None
    try:
        # Validate input
        if not isinstance(data, dict):
            data = {}
        messages = data.get("messages", [])
        if not messages:
            await connection_manager.send_message(
                client_id,
                WSMessage.error(request_id, "No messages provided", "MISSING_MESSAGES"),
            )
            return

        try:
            max_new_tokens = int(data.get("max_new_tokens", 256))
            temperature = float(data.get("temperature", 0.7))
        except (TypeError, ValueError, OverflowError):
            await connection_manager.send_message(
                client_id,
                WSMessage.error(
                    request_id,
                    "max_new_tokens must be an integer and temperature a number",
                    "INVALID_PARAMS",
                ),
            )
            return
        # Upper cap bounds per-request GPU time; lower bound keeps a zero or
        # negative length from reaching generate().
        max_new_tokens = max(1, min(max_new_tokens, 512))

        # Register a stop event so cancel/disconnect can interrupt this request.
        stop_event = asyncio.Event()
        connection_manager.active_generations[client_id] = stop_event

        # Loads the model on first use.
        adapter = get_llm_adapter()

        start_time = time.perf_counter()
        # Counts yielded text chunks; the streamer may group several model
        # tokens into one chunk.
        token_count = 0

        stream = adapter.generate_stream(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            stop_event=stop_event,
        )
        try:
            async for token in stream:
                if stop_event.is_set():
                    break

                # Send token to client
                await connection_manager.send_message(
                    client_id, WSMessage.token(request_id, token)
                )
                token_count += 1
        finally:
            # Close explicitly so the generator's cleanup runs now, not at GC time.
            await stream.aclose()

        # Send completion message
        duration_ms = (time.perf_counter() - start_time) * 1000
        await connection_manager.send_message(
            client_id, WSMessage.done(request_id, token_count, duration_ms)
        )

    except Exception as e:
        await connection_manager.send_message(
            client_id, WSMessage.error(request_id, str(e), "GENERATION_ERROR")
        )

    finally:
        # Remove only our own event: a newer request may have replaced it.
        if (
            stop_event is not None
            and connection_manager.active_generations.get(client_id) is stop_event
        ):
            del connection_manager.active_generations[client_id]

In [None]:
# Cell 7: Client-side HTML/JavaScript Demo
@app.get("/")
async def get_demo_page():
    """Serve a self-contained HTML/JS page for trying the WebSocket protocol by hand.

    The page opens its WebSocket on the same host, port and scheme it was
    served from, so it keeps working when ``run_server`` binds another address.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <title>WebSocket Streaming Demo</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 20px; }
        #output { border: 1px solid #ccc; padding: 10px; height: 300px; overflow-y: auto; white-space: pre-wrap; }
        button { margin: 5px; padding: 10px; }
        input, textarea { width: 100%; margin: 5px 0; }
    </style>
</head>
<body>
    <h1>WebSocket Streaming LLM Demo</h1>

    <div>
        <label>Message:</label>
        <textarea id="messageInput" rows="3" placeholder="Type your message here...">請介紹什麼是 RAG？</textarea>
    </div>

    <div>
        <label>Max Tokens:</label>
        <input type="number" id="maxTokens" value="128" min="1" max="512">

        <label>Temperature:</label>
        <input type="number" id="temperature" value="0.7" min="0" max="2" step="0.1">
    </div>

    <div>
        <button onclick="connect()">Connect</button>
        <button onclick="sendMessage()">Send Message</button>
        <button onclick="cancelGeneration()">Cancel</button>
        <button onclick="disconnect()">Disconnect</button>
        <button onclick="clearOutput()">Clear</button>
    </div>

    <div>
        <label>Status: <span id="status">Disconnected</span></label>
    </div>

    <div id="output"></div>

    <script>
        let ws = null;
        let clientId = 'client-' + Math.random().toString(36).slice(2, 11);
        let currentRequestId = null;

        function connect() {
            if (ws) {
                // Detach the old socket first: its close event fires later and
                // would otherwise mark the new connection as "Disconnected".
                ws.onclose = null;
                ws.close();
            }

            // Connect back to whichever host/port/scheme served this page.
            const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
            ws = new WebSocket(`${scheme}://${window.location.host}/ws/${clientId}`);

            ws.onopen = function() {
                document.getElementById('status').textContent = 'Connected';
                appendOutput('[SYSTEM] Connected to WebSocket\\n');
            };

            ws.onmessage = function(event) {
                const message = JSON.parse(event.data);
                handleMessage(message);
            };

            ws.onclose = function() {
                document.getElementById('status').textContent = 'Disconnected';
                appendOutput('[SYSTEM] Disconnected\\n');
            };

            ws.onerror = function() {
                // The error Event carries no useful detail for display.
                appendOutput('[ERROR] WebSocket error\\n');
            };
        }

        function handleMessage(message) {
            const { type, id, data } = message;

            switch(type) {
                case 'token':
                    appendOutput(data.token);
                    break;

                case 'done':
                    appendOutput(`\\n[DONE] ${data.total_tokens} tokens in ${data.duration_ms.toFixed(1)}ms\\n\\n`);
                    currentRequestId = null;
                    break;

                case 'error':
                    appendOutput(`\\n[ERROR] ${data.error} (${data.code})\\n\\n`);
                    currentRequestId = null;
                    break;

                case 'cancelled':
                    appendOutput(`\\n[CANCELLED] Generation stopped\\n\\n`);
                    currentRequestId = null;
                    break;
            }
        }

        function sendMessage() {
            if (!ws || ws.readyState !== WebSocket.OPEN) {
                alert('Please connect first');
                return;
            }

            const message = document.getElementById('messageInput').value.trim();
            if (!message) {
                alert('Please enter a message');
                return;
            }

            currentRequestId = 'req-' + Date.now();

            const request = {
                type: 'generate',
                id: currentRequestId,
                data: {
                    messages: [{ role: 'user', content: message }],
                    max_new_tokens: parseInt(document.getElementById('maxTokens').value),
                    temperature: parseFloat(document.getElementById('temperature').value)
                }
            };

            appendOutput(`[USER] ${message}\\n[ASSISTANT] `);
            ws.send(JSON.stringify(request));
        }

        function cancelGeneration() {
            if (currentRequestId && ws && ws.readyState === WebSocket.OPEN) {
                ws.send(JSON.stringify({
                    type: 'cancel',
                    id: currentRequestId
                }));
            }
        }

        function disconnect() {
            if (ws) {
                ws.close();
                ws = null;
            }
        }

        function appendOutput(text) {
            const output = document.getElementById('output');
            output.textContent += text;
            output.scrollTop = output.scrollHeight;
        }

        function clearOutput() {
            document.getElementById('output').textContent = '';
        }

        // Auto-connect on page load
        window.onload = function() {
            // Don't auto-connect in demo
            // connect();
        };
    </script>
</body>
</html>
    """
    return HTMLResponse(content=html_content)

In [None]:
# Cell 8: Health Check & Info Endpoints
@app.get("/health")
async def health_check():
    """Liveness probe: connection/generation counts plus the server's local time."""
    manager = connection_manager
    return dict(
        status="healthy",
        active_connections=len(manager.active_connections),
        active_generations=len(manager.active_generations),
        timestamp=datetime.now().isoformat(),
    )


@app.get("/info")
async def api_info():
    """Describe the API surface: routes and WebSocket message types."""
    routes = {
        "websocket": "/ws/{client_id}",
        "demo": "/",
        "health": "/health",
    }
    inbound = ["generate", "cancel"]
    outbound = ["token", "done", "error", "cancelled"]
    return {
        "name": "Streaming LLM WebSocket API",
        "version": "1.0.0",
        "endpoints": routes,
        "message_types": inbound,
        "response_types": outbound,
    }

In [None]:
# Cell 9: Smoke Test Function
def run_server(host="127.0.0.1", port=8000):
    """Start the WebSocket server.

    From a plain Python process this blocks in ``uvicorn.run`` until Ctrl+C
    and returns ``None``.

    Inside Jupyter/IPython an asyncio event loop is already running on the
    calling thread, and ``uvicorn.run`` (which uses ``asyncio.run``) cannot
    start there. In that case the server is started on a daemon thread and the
    ``uvicorn.Server`` instance is returned; stop it with
    ``server.should_exit = True``.
    """
    print("🚀 Starting WebSocket streaming server...")
    print(f"📡 WebSocket endpoint: ws://{host}:{port}/ws/{{client_id}}")
    print(f"🌐 Demo page: http://{host}:{port}/")
    print(f"💚 Health check: http://{host}:{port}/health")

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running (ordinary script): serve in the foreground.
        print("\nPress Ctrl+C to stop")
        uvicorn.run(
            app,
            host=host,
            port=port,
            reload=False,  # auto-reload stays off (it needs an import string, not an app object)
            access_log=True,
        )
        return None

    # A loop is already running (e.g. a notebook kernel): serve from a thread
    # that owns its own event loop.
    import threading

    server = uvicorn.Server(
        uvicorn.Config(app, host=host, port=port, reload=False, access_log=True)
    )
    threading.Thread(target=server.run, daemon=True, name="uvicorn").start()
    print("\nServer running in a background thread; set server.should_exit = True to stop")
    return server

In [None]:
# Cell 10: Smoke Test - Start Server
if __name__ == "__main__":
    # Notebook-friendly default: announce readiness without blocking the kernel.
    for line in (
        "✅ WebSocket streaming server ready!",
        "Uncomment the line below to start the server:",
        "# run_server()",
    ):
        print(line)

    # Uncomment to start server:
    # run_server()

"""
=== Smoke Test Instructions ===

1. Uncomment `run_server()` above and run this cell
2. Open browser to http://localhost:8000/
3. Click "Connect" button
4. Type a message and click "Send Message"
5. Watch real-time streaming tokens appear
6. Test "Cancel" button during generation
7. Check http://localhost:8000/health for metrics

Expected output:
- WebSocket connection established
- Streaming tokens appear one by one
- Final stats: token count + duration
- Cancellation works mid-generation
"""

In [None]:
# 測試用最小客戶端
import asyncio
import websockets
import json


async def test_websocket():
    """Minimal smoke client: send one ``generate`` request and print the stream.

    Stops after the first ``done`` or ``error`` message, or when the server
    closes the connection.
    """
    uri = "ws://localhost:8000/ws/test-client"
    request = {
        "type": "generate",
        "id": "test-001",
        "data": {
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    }

    async with websockets.connect(uri) as websocket:
        await websocket.send(json.dumps(request))

        # Collect streamed chunks until the server reports completion or failure.
        pieces = []
        async for raw in websocket:
            reply = json.loads(raw)
            print(f"Received: {reply}")

            kind = reply["type"]
            if kind == "token":
                pieces.append(reply["data"]["token"])
                continue
            if kind == "done":
                print(f"✅ Generation complete: {''.join(pieces)}")
                break
            if kind == "error":
                print(f"❌ Error: {reply['data']['error']}")
                break


# asyncio.run(test_websocket())  # Uncomment to test