# EpiHelix AI Services (Kaggle GPU)

**GPU-accelerated AI microservices for EpiHelix backend**

## Setup Requirements

1. ✅ **Enable GPU** in Kaggle notebook settings (P100 recommended)
2. ✅ **Enable Internet** in settings
3. ✅ **Get ngrok auth token** from https://dashboard.ngrok.com/get-started/your-authtoken
4. ✅ **Add token to Kaggle Secrets**:
   - Go to notebook settings → Add-ons → Secrets
   - Click "Add a new secret"
   - Name: `NGROK_AUTH_TOKEN`
   - Value: Your token from ngrok dashboard
   
Optional: Add `BACKEND_API_KEY` secret for authentication

In [None]:
# Install dependencies (takes ~2 min first time)
!pip install -q fastapi uvicorn pyngrok sentence-transformers transformers torch accelerate

In [None]:
# Set ngrok auth token from Kaggle Secrets
from kaggle_secrets import UserSecretsClient

try:
    ngrok_token = UserSecretsClient().get_secret("NGROK_AUTH_TOKEN")
    
    # Configure ngrok
    from pyngrok import ngrok, conf
    conf.get_default().auth_token = ngrok_token
    
    print("✅ ngrok authentication configured successfully")
except Exception as e:
    print(f"⚠️  Warning: Could not get NGROK_AUTH_TOKEN from secrets: {e}")
    print("   Please add your ngrok token to Kaggle Secrets:")
    print("   1. Go to https://dashboard.ngrok.com/get-started/your-authtoken")
    print("   2. Copy your token")
    print("   3. Add to Kaggle: Settings → Add-ons → Secrets → Add 'NGROK_AUTH_TOKEN'")
    print("\n   Without this, ngrok tunnels may be rate-limited or fail.")

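## 0. Set Up ngrok Authentication

The cell above reads the `NGROK_AUTH_TOKEN` secret from Kaggle Secrets and passes it to pyngrok. If it printed a warning, add the token as described under Setup Requirements and re-run the cell.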

## 1. Load All Models on GPU

Load all 3 models into GPU memory once at startup:
- **Reranker**: cross-encoder/ms-marco-MiniLM-L-6-v2
- **Embedder**: sentence-transformers/all-MiniLM-L6-v2
- **LLM (Shared)**: Qwen/Qwen2.5-3B-Instruct (used for both summarization AND chat)
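
A back-of-the-envelope estimate of weight memory helps confirm that all three models fit on one GPU. The sketch below assumes approximate parameter counts from the public model cards (about 22.7M for each MiniLM model and about 3.09B for Qwen2.5-3B), FP32 weights for the two small models, and FP16 weights for the LLM. Activations and the generation cache come on top of these numbers.

```python
# Rough weight-memory estimate. Parameter counts are approximations taken
# from the public model cards (assumptions, not measured values).
BYTES_PER_PARAM = {"float32": 4, "float16": 2}

MODELS = [
    # (name, approx. parameter count, dtype assumed for loading)
    ("cross-encoder/ms-marco-MiniLM-L-6-v2", 22.7e6, "float32"),
    ("sentence-transformers/all-MiniLM-L6-v2", 22.7e6, "float32"),
    ("Qwen/Qwen2.5-3B-Instruct", 3.09e9, "float16"),
]


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Memory needed for the weights alone, in GB (1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


total_gb = sum(weight_memory_gb(n, dtype) for _, n, dtype in MODELS)

for name, n, dtype in MODELS:
    print(f"{name}: {weight_memory_gb(n, dtype):.2f} GB ({dtype})")
print(f"Total weights: {total_gb:.2f} GB")
```

Under these assumptions the weights come to roughly 6.4 GB, which should leave headroom on a 16 GB P100.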

In [None]:
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gc

print("üî• CUDA available:", torch.cuda.is_available())
print("üìä GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

# Force GPU device
device = "cuda" if torch.cuda.is_available() else "cpu"

print("\n" + "="*60)
print("üîÑ Loading Reranker Model (ms-marco-MiniLM-L-6-v2)...")
print("="*60)
reranker_model = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    max_length=512,
    device=device
)
print("✅ Reranker loaded")

print("\n" + "="*60)
print("🔍 Loading Embedding Model (all-MiniLM-L6-v2)...")
print("="*60)
embedding_model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device=device
)
print("✅ Embedder loaded")

print("\n" + "="*60)
print("💬 Loading LLM (Qwen2.5-3B-Instruct - Shared for Summarization & Chat)...")
print("="*60)
llm_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct",
    trust_remote_code=True
)
llm_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct",
    torch_dtype=torch.float16,  # Use FP16 to save memory
    device_map="auto",
    trust_remote_code=True
)
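# No `device` argument here: the model was already placed by accelerate via
# device_map="auto", and recent transformers versions reject an explicit device.
llm_pipeline = pipeline(
    "text-generation",
    model=llm_model,
    tokenizer=llm_tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9
)
print("✅ LLM loaded (serves both /summarize and /chat endpoints)")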

print("\n" + "="*60)
print("üéâ All models loaded successfully!")
print("="*60)

# Check GPU memory
if torch.cuda.is_available():
    print(f"\nüìä GPU Memory: {torch.cuda.memory_allocated(0) / 1e9:.2f}GB / {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f}GB")

## 2. FastAPI Service with 4 Endpoints

Create REST API endpoints for each AI service:
- `/rerank` - Cross-encoder reranking
- `/embed` - Sentence embeddings
- `/summarize` - Text summarization (Qwen Instruct)
- `/chat` - Chatbot conversations (Same Qwen Instruct model)
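
The `/chat` endpoint turns the `messages` list into a single prompt string before generation. The sketch below mirrors that formatting as a standalone function so the prompt layout is easy to inspect. `build_chatml_prompt` is a name chosen for illustration and is not part of the service.

```python
from typing import Dict, List


def build_chatml_prompt(messages: List[Dict[str, str]]) -> str:
    """Format chat messages with <|im_start|>/<|im_end|> markers."""
    parts = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Only the three roles the endpoint handles; any other role is skipped
        if role in ("system", "user", "assistant"):
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    # Open an assistant turn so the model continues from here
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


prompt = build_chatml_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a pandemic?"},
])
print(prompt)
```

Printing the prompt this way is a quick check that a multi-turn conversation is laid out as expected before it is sent to the model.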

In [None]:
from fastapi import Depends, FastAPI, HTTPException, Header
from pydantic import BaseModel, Field
from typing import Optional, List
import asyncio
from functools import partial

# Optional API key authentication (if you set BACKEND_API_KEY secret in Kaggle)
API_KEY = None  # Set to your secret if needed
# from kaggle_secrets import UserSecretsClient
# API_KEY = UserSecretsClient().get_secret("BACKEND_API_KEY")

def verify_api_key(authorization: Optional[str] = Header(None)):
    """Verify API key if configured."""
    if API_KEY and authorization != f"Bearer {API_KEY}":
        raise HTTPException(status_code=401, detail="Invalid API key")

# Registering the check as an app-level dependency applies it to every endpoint
# defined below; it is a no-op while API_KEY is None.
app = FastAPI(
    title="EpiHelix AI Services",
    description="GPU-accelerated reranking, embedding, and LLM services",
    version="1.0.0",
    dependencies=[Depends(verify_api_key)]
)


# ===== REQUEST/RESPONSE MODELS =====

class RerankRequest(BaseModel):
    query: str = Field(..., description="Search query")
    documents: List[str] = Field(..., description="List of documents to rerank")
    top_k: int = Field(20, description="Number of top results to return")

class RerankResult(BaseModel):
    index: int
    score: float

class RerankResponse(BaseModel):
    results: List[RerankResult]
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"

class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., description="Texts to embed")
    normalize: bool = Field(True, description="Normalize embeddings to unit length")

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]
    dimension: int
    model: str = "sentence-transformers/all-MiniLM-L6-v2"

class SummarizeRequest(BaseModel):
    text: str = Field(..., description="Text to summarize")
    max_length: int = Field(150, description="Maximum summary length in tokens")
    temperature: float = Field(0.7, ge=0.0, le=2.0)

class SummarizeResponse(BaseModel):
    summary: str
    model: str = "Qwen/Qwen2.5-3B-Instruct"

class ChatRequest(BaseModel):
    messages: List[dict] = Field(..., description="Chat messages (role + content)")
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    max_tokens: int = Field(512, ge=1, le=2048)

class ChatResponse(BaseModel):
    response: str
    model: str = "Qwen/Qwen2.5-3B-Instruct"


# ===== ENDPOINTS =====

@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "services": ["rerank", "embed", "summarize", "chat"],
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
        "models": {
            "reranker": "cross-encoder/ms-marco-MiniLM-L-6-v2",
            "embedder": "sentence-transformers/all-MiniLM-L6-v2",
            "llm": "Qwen/Qwen2.5-3B-Instruct (shared for summarize & chat)"
        }
    }

@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "gpu_available": torch.cuda.is_available(),
        "gpu_memory_used_gb": torch.cuda.memory_allocated(0) / 1e9 if torch.cuda.is_available() else 0,
        "models_loaded": True
    }

@app.post("/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
    """
    Rerank documents using cross-encoder model.
    
    Returns documents sorted by relevance score (highest first).
    """
    try:
        # Create query-document pairs
        pairs = [[request.query, doc] for doc in request.documents]
        
        # Run reranking in thread pool (blocking operation)
        loop = asyncio.get_event_loop()
        scores = await loop.run_in_executor(
            None,
            partial(reranker_model.predict, pairs)
        )
        
        # Sort by score descending
        ranked_indices = sorted(
            enumerate(scores),
            key=lambda x: x[1],
            reverse=True
        )[:request.top_k]
        
        results = [
            RerankResult(index=idx, score=float(score))
            for idx, score in ranked_indices
        ]
        
        return RerankResponse(results=results)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Reranking failed: {str(e)}")

@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest):
    """
    Generate embeddings for texts.
    
    Returns 384-dimensional vectors.
    """
    try:
        # Run embedding in thread pool
        loop = asyncio.get_event_loop()
        embeddings = await loop.run_in_executor(
            None,
            partial(
                embedding_model.encode,
                request.texts,
                normalize_embeddings=request.normalize,
                convert_to_numpy=True
            )
        )
        
        return EmbedResponse(
            embeddings=embeddings.tolist(),
            dimension=embeddings.shape[1]
        )
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {str(e)}")

@app.post("/summarize", response_model=SummarizeResponse)
async def summarize(request: SummarizeRequest):
    """
    Summarize text using Qwen Instruct model.
    
    Optimized for concise summaries.
    """
    try:
        # Create summarization prompt
        prompt = f"""Summarize the following text concisely in {request.max_length} tokens or less:

{request.text}

Summary:"""
        
        # Generate summary in thread pool (using shared LLM)
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None,
            partial(
                llm_pipeline,
                prompt,
                max_new_tokens=request.max_length,
                temperature=request.temperature,
                return_full_text=False
            )
        )
        
        summary_text = result[0]["generated_text"].strip()
        
        return SummarizeResponse(summary=summary_text)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Summarization failed: {str(e)}")

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """
    Generate chat response using Qwen Instruct model (same as summarizer).
    
    Expects messages in format: [{"role": "user", "content": "..."}]
    """
    try:
        # Format messages into prompt (Qwen chat template)
        prompt = ""
        for msg in request.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
            elif role == "user":
                prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
            elif role == "assistant":
                prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
        
        prompt += "<|im_start|>assistant\n"
        
        # Generate response in thread pool (using shared LLM)
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None,
            partial(
                llm_pipeline,
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                return_full_text=False
            )
        )
        
        response_text = result[0]["generated_text"].strip()
        
        return ChatResponse(response=response_text)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat generation failed: {str(e)}")


print("✅ FastAPI app created with 4 endpoints:")
print("   POST /rerank    - Cross-encoder reranking")
print("   POST /embed     - Sentence embeddings")
print("   POST /summarize - Text summarization (Qwen Instruct)")
print("   POST /chat      - Chatbot (Qwen Instruct)")

## 3. Start Server with ngrok Tunnel

Expose the API publicly via ngrok.

In [None]:
from pyngrok import ngrok
import uvicorn
from threading import Thread
import time

# Start ngrok tunnel
print("\n" + "="*60)
print("üåê Starting ngrok tunnel...")
print("="*60)

tunnel = ngrok.connect(8000)
public_url = tunnel.public_url  # Extract the URL string from NgrokTunnel object

print(f"\n✅ PUBLIC URL: {public_url}")
print(f"📝 Add this to your backend .env file:")
print(f"\n   KAGGLE_AI_ENDPOINT={public_url}\n")
print("="*60)

# Start FastAPI server in background thread
def run_server():
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

server_thread = Thread(target=run_server, daemon=True)
server_thread.start()

print("\n⏳ Waiting for server to start...")
time.sleep(5)  # Increased wait time for server startup

print("\n🎉 Server is running!")
print(f"   Health check: {public_url}/health")
print(f"   API docs: {public_url}/docs")
print(f"\n💡 Test the server by running the next cell")
print("⚠️  Keep this notebook running to maintain the service")
print("   (Kaggle sessions last ~12 hours, then you need to restart)")

## 4. Test Endpoints (Optional)

Quick tests to verify all services work.
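
A fixed sleep before testing can be too short on a slow start and wasteful on a fast one. One alternative is to poll the health endpoint until it answers. The sketch below uses only the standard library; `wait_until_ready` and its parameters are hypothetical names, not part of the notebook.

```python
import time
import urllib.error
import urllib.request


def wait_until_ready(url: str, timeout_s: float = 30.0, interval_s: float = 0.5) -> bool:
    """Poll `url` until it returns HTTP 200 or `timeout_s` elapses."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            # Server not accepting connections yet (or returned an error); retry
            pass
        time.sleep(interval_s)
    return False
```

In the notebook this could be called as `wait_until_ready("http://127.0.0.1:8000/health")` before running the tests, on the assumption that the server listens on port 8000 as configured in the startup cell.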

In [None]:
import requests
import json
import time

# Wait a bit more to ensure server is ready
print("⏳ Waiting for server to be fully ready...")
time.sleep(3)

print("="*60)
print("🧪 Testing all endpoints...")
print("="*60)

# Test 1: Health check
print("\n1️⃣ Health Check:")
try:
    response = requests.get(f"{public_url}/health", timeout=10)
    print(json.dumps(response.json(), indent=2))
except Exception as e:
    print(f"❌ Error: {e}")
    print("💡 Make sure the previous cell (server startup) has finished running!")

# Test 2: Reranking
print("\n2️⃣ Reranking Test:")
try:
    rerank_data = {
        "query": "COVID-19 cases in USA",
        "documents": [
            "Total COVID cases in United States reached 100 million",
            "France reported new influenza outbreak",
            "USA vaccination rates increased to 70%"
        ],
        "top_k": 3
    }
    response = requests.post(f"{public_url}/rerank", json=rerank_data, timeout=30)
    print(json.dumps(response.json(), indent=2))
except Exception as e:
    print(f"❌ Error: {e}")

# Test 3: Embeddings
print("\n3️⃣ Embedding Test:")
try:
    embed_data = {
        "texts": ["COVID-19 pandemic", "Influenza outbreak"],
        "normalize": True
    }
    response = requests.post(f"{public_url}/embed", json=embed_data, timeout=30)
    result = response.json()
    print(f"Generated {len(result['embeddings'])} embeddings of dimension {result['dimension']}")
    print(f"First embedding (truncated): {result['embeddings'][0][:5]}...")
except Exception as e:
    print(f"❌ Error: {e}")

# Test 4: Summarization
print("\n4️⃣ Summarization Test:")
try:
    summarize_data = {
        "text": "The COVID-19 pandemic has affected millions worldwide. Countries implemented lockdowns, mask mandates, and vaccination programs. The virus spread rapidly through communities, overwhelming healthcare systems. Scientists developed multiple vaccines in record time. Global cooperation was essential in fighting the pandemic.",
        "max_length": 50,
        "temperature": 0.7
    }
    response = requests.post(f"{public_url}/summarize", json=summarize_data, timeout=60)
    print(json.dumps(response.json(), indent=2))
except Exception as e:
    print(f"❌ Error: {e}")

# Test 5: Chat
print("\n5️⃣ Chat Test:")
try:
    chat_data = {
        "messages": [
            {"role": "user", "content": "What is a pandemic?"}
        ],
        "temperature": 0.7,
        "max_tokens": 100
    }
    response = requests.post(f"{public_url}/chat", json=chat_data, timeout=60)
    print(json.dumps(response.json(), indent=2))
except Exception as e:
    print(f"❌ Error: {e}")

print("\n✅ All tests completed!")

## 5. Keep Alive (Run Forever)

Keep the server running. The notebook will stay active as long as this cell runs.

In [None]:
print("üîÑ Server is running. Press interrupt (‚ñ†) to stop.")
print(f"   Public URL: {public_url}")
print(f"   API Docs: {public_url}/docs")
print(f"   Health: {public_url}/health")

# Keep alive loop
try:
    while True:
        time.sleep(60)
        # Optional: print heartbeat every minute
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated(0) / 1e9
            print(f"üíì Heartbeat - GPU Memory: {memory_used:.2f}GB")
except KeyboardInterrupt:
    print("\nüõë Server stopped")
    ngrok.disconnect(public_url)

## 📚 Usage Instructions

### Backend Integration

Add to your `backend/.env`:
```bash
# Kaggle AI Services
KAGGLE_AI_ENDPOINT=https://xxxx-xx-xxx-xxx-xx.ngrok-free.app

# Enable AI features
RERANKER_PROVIDER=kaggle
EMBEDDER_PROVIDER=kaggle
LLM_PROVIDER=kaggle
```
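
How the backend consumes these variables is not shown here. As a rough sketch, they could be read into a small settings object at startup. `KaggleAISettings`, `load_settings`, and the `"local"` fallback value are all hypothetical names for illustration.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class KaggleAISettings:
    """Hypothetical settings holder; fields mirror the .env keys."""
    endpoint: str
    reranker_provider: str
    embedder_provider: str
    llm_provider: str


def load_settings(env: Mapping[str, str] = os.environ) -> KaggleAISettings:
    # Strip a trailing slash so f"{endpoint}/rerank" does not produce "//"
    endpoint = env.get("KAGGLE_AI_ENDPOINT", "").strip().rstrip("/")
    if not endpoint.startswith(("http://", "https://")):
        raise ValueError("KAGGLE_AI_ENDPOINT must be set to the ngrok URL")
    return KaggleAISettings(
        endpoint=endpoint,
        # "local" is an assumed fallback name, not something the backend defines
        reranker_provider=env.get("RERANKER_PROVIDER", "local"),
        embedder_provider=env.get("EMBEDDER_PROVIDER", "local"),
        llm_provider=env.get("LLM_PROVIDER", "local"),
    )


# Example with an explicit mapping instead of the real environment
settings = load_settings({
    "KAGGLE_AI_ENDPOINT": "https://example.ngrok-free.app/",
    "RERANKER_PROVIDER": "kaggle",
})
print(settings)
```

Failing fast on a missing or malformed endpoint is useful here because the ngrok URL changes whenever the notebook restarts.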

### Python Client Example

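```python
import os

import httpx

# Public ngrok URL printed by the notebook (same value as in backend/.env)
KAGGLE_ENDPOINT = os.environ["KAGGLE_AI_ENDPOINT"]

async def rerank_with_kaggle(query: str, documents: list[str]):
    # Generous timeout: the request goes through the ngrok tunnel to the GPU
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(
            f"{KAGGLE_ENDPOINT}/rerank",
            json={"query": query, "documents": documents, "top_k": 20}
        )
        response.raise_for_status()  # Surface 4xx/5xx instead of parsing an error body
        return response.json()
```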
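
The `/rerank` response carries only indices and scores, so the caller has to map them back to the original documents. A minimal sketch, assuming a response shaped like `RerankResponse`; the helper name `apply_rerank` and the scores below are made up for illustration.

```python
from typing import Any, Dict, List, Tuple


def apply_rerank(documents: List[str], response: Dict[str, Any]) -> List[Tuple[str, float]]:
    """Pair each returned index with its document text, preserving rank order."""
    return [(documents[item["index"]], item["score"]) for item in response["results"]]


# Hypothetical response with invented scores, highest first
documents = ["doc about flu", "doc about COVID", "doc about vaccines"]
response = {
    "results": [
        {"index": 1, "score": 8.2},
        {"index": 2, "score": 1.5},
        {"index": 0, "score": -3.1},
    ],
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
}

ranked = apply_rerank(documents, response)
for text, score in ranked:
    print(f"{score:6.2f}  {text}")
```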

### Maintenance

- **Session Duration**: ~12 hours max
- **Weekly Quota**: 30 GPU hours
- **Restart**: Click "Run All" again. Each restart produces a new ngrok URL, so update `KAGGLE_AI_ENDPOINT` in `backend/.env` afterwards.
- **Monitoring**: Check `/health` endpoint for status