# 🚀 NL2SQL Backend Server with BART Model

This notebook runs a FastAPI backend server that translates natural language to SQL using the BART model.

## 📋 What This Does

1. ✅ Installs all required dependencies
2. ✅ Loads the BART NL2SQL model (SwastikM/bart-large-nl2sql)
3. ✅ Starts a FastAPI server
4. ✅ Exposes it through a publicly accessible ngrok tunnel
5. ✅ Provides an API endpoint for your frontend to use

## ⚡ Quick Start

1. **Runtime → Change runtime type** → Select **GPU (T4)**
2. **Run all cells** (Runtime → Run all)
3. **Copy the ngrok URL** from the output
4. **Use it in your frontend** to connect to the backend
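
As a rough illustration of step 4, a client could talk to the backend along the lines sketched below. The request and response fields mirror the `TranslateRequest` and `TranslateResponse` models defined in Step 3; the helper names (`build_payload`, `best_candidate`, `translate`) and the placeholder URL are assumptions made for this example, not part of the notebook.

```python
import json
import urllib.request

def build_payload(natural_language, database="default", schema_context=""):
    """Build the JSON body expected by POST /api/translate/."""
    return {
        "natural_language": natural_language,
        "database": database,
        "schema_context": schema_context,
    }

def best_candidate(response_body):
    """Return the candidate with the highest confidence, or None if empty."""
    candidates = response_body.get("candidates", [])
    return max(candidates, key=lambda c: c["confidence"], default=None)

def translate(base_url, natural_language, **kwargs):
    """POST a question to the backend and return the decoded JSON response."""
    request = urllib.request.Request(
        f"{base_url.rstrip('/')}/api/translate/",
        data=json.dumps(build_payload(natural_language, **kwargs)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        return json.load(response)

# Hypothetical usage; replace the URL with the one printed in Step 4:
# result = translate("https://<your-ngrok-url>", "Show all students")
# print(best_candidate(result))
```

A browser frontend would send the same JSON body with `fetch`; only the payload shape matters to the server.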

---

## 📦 Step 1: Install Dependencies

In [None]:
!pip install -q fastapi uvicorn transformers torch pyngrok pydantic nest-asyncio requests

## 🤖 Step 2: Load BART Model

The first run downloads the model (~1.6GB), which takes about 1-2 minutes.
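
The ~1.6GB figure is roughly what one would expect for a BART-large checkpoint stored as 32-bit floats. The back-of-the-envelope check below assumes about 406 million parameters (the size commonly quoted for BART-large) and 4 bytes per parameter; both numbers are assumptions, not values read from this particular checkpoint.

```python
# Assumed values, not measured from SwastikM/bart-large-nl2sql
ASSUMED_PARAMETERS = 406_000_000   # parameter count commonly quoted for BART-large
BYTES_PER_PARAMETER = 4            # float32 weights

def estimated_size_gb(parameters, bytes_per_parameter):
    """Estimate checkpoint size in gigabytes (1 GB = 10**9 bytes)."""
    return parameters * bytes_per_parameter / 10**9

size_gb = estimated_size_gb(ASSUMED_PARAMETERS, BYTES_PER_PARAMETER)
print(f"Estimated download size: {size_gb:.2f} GB")
```

Under the same assumptions, the weights occupy about the same amount of GPU memory once loaded, which fits comfortably on a T4.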

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

print("🔄 Loading BART NL2SQL model...")
model_name = "SwastikM/bart-large-nl2sql"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

print(f"✅ Model loaded successfully on {device.upper()}!")
print(f"   GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")

## 🌐 Step 3: Create FastAPI Server

In [None]:
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import uvicorn
from pyngrok import ngrok
import nest_asyncio

# Allow nested event loops (required for Colab)
nest_asyncio.apply()

# Create FastAPI app
app = FastAPI(
    title="NL2SQL BART API",
    description="Natural Language to SQL translation using BART",
    version="1.0.0"
)

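# Enable CORS for frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (adjust for production)
    # Wildcard origins should not be combined with credentials; this API
    # uses no cookies, so credentialed requests are disabled.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)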

# Request/Response models
class TranslateRequest(BaseModel):
    natural_language: str
    database: str = "default"
    schema_context: str = ""

class SQLCandidate(BaseModel):
    sql: str
    confidence: float
    reasoning: str

class TranslateResponse(BaseModel):
    candidates: List[SQLCandidate]
    database: str

def translate_to_sql(natural_language: str, schema_context: str = "") -> List[SQLCandidate]:
    """Translate natural language to SQL using BART"""
    try:
        # Format prompt
        if schema_context:
            prompt = f"sql_prompt: {natural_language}\nsql_context: {schema_context}"
        else:
            prompt = f"sql_prompt: {natural_language}"
        
        # Tokenize
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).input_ids.to(device)
        
        # Generate SQL with beam search
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=200,
                num_beams=3,
                num_return_sequences=3,
                do_sample=False,
                early_stopping=True,
                return_dict_in_generate=True,
                output_scores=True,
            )
        
        # Decode results
        candidates = []
        sequences = outputs.sequences
        scores = outputs.sequences_scores if hasattr(outputs, 'sequences_scores') else None
        
        for idx, sequence in enumerate(sequences):
            sql = tokenizer.decode(sequence, skip_special_tokens=True)
            
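            # Heuristic confidence: sequences_scores holds each beam's
            # length-normalized log-probability, so exp() maps it into (0, 1].
            # Treat it as a ranking signal, not a calibrated probability.
            if scores is not None:
                confidence = float(torch.exp(scores[idx]).cpu())
            else:
                # Fallback when scores are unavailable: decay by beam rank
                confidence = max(0.5, 1.0 - (idx * 0.15))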
            
            candidates.append(
                SQLCandidate(
                    sql=sql.strip(),
                    confidence=round(confidence, 3),
                    reasoning=f"Generated by BART model (beam {idx + 1})"
                )
            )
        
        # Remove duplicates
        unique_candidates = []
        seen_sql = set()
        for candidate in candidates:
            if candidate.sql not in seen_sql:
                unique_candidates.append(candidate)
                seen_sql.add(candidate.sql)
        
        return unique_candidates
        
    except Exception as e:
        print(f"Error in translation: {e}")
        return []

# API Endpoints
@app.get("/")
async def root():
    return {
        "message": "NL2SQL BART API",
        "status": "running",
        "model": "SwastikM/bart-large-nl2sql",
        "device": device,
        "endpoints": {
            "translate": "/api/translate/",
            "health": "/api/health",
            "docs": "/docs"
        }
    }

@app.get("/api/health")
async def health():
    return {"status": "healthy", "model_loaded": True, "device": device}

@app.post("/api/translate/", response_model=TranslateResponse)
async def translate(request: TranslateRequest):
    """Translate natural language to SQL"""
    if not request.natural_language or not request.natural_language.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    
    candidates = translate_to_sql(request.natural_language, request.schema_context)
    
    if not candidates:
        raise HTTPException(status_code=500, detail="Translation failed")
    
    return TranslateResponse(
        candidates=candidates,
        database=request.database
    )

print("✅ FastAPI server created!")
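
The confidence value computed in `translate_to_sql` deserves a closer look. Assuming the usual Transformers beam-search behavior, in which `sequences_scores` is the sum of token log-probabilities divided by the sequence length (for the default `length_penalty` of 1.0), exponentiating it gives the geometric mean of the per-token probabilities rather than the probability of the whole query. The pure-Python sketch below reproduces that arithmetic on made-up token probabilities; `beam_confidence` is a hypothetical helper, not part of the notebook.

```python
import math

def beam_confidence(token_probabilities, length_penalty=1.0):
    """Mimic exp(sequences_scores) for one beam, under the stated assumption."""
    log_prob_sum = sum(math.log(p) for p in token_probabilities)
    normalized = log_prob_sum / (len(token_probabilities) ** length_penalty)
    return math.exp(normalized)

# Made-up per-token probabilities for a two-token output
probs = [0.9, 0.8]
confidence = beam_confidence(probs)
joint_probability = math.prod(probs)

print(f"confidence (geometric mean): {confidence:.3f}")
print(f"joint probability:           {joint_probability:.3f}")
```

Because of the length normalization, a long query is not penalized just for being long, which is why the value is better read as a ranking score across beams than as the chance that the SQL is correct.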

## 🚀 Step 4: Start Server with ngrok Tunnel

**IMPORTANT:** Copy the ngrok URL from the output below!

In [None]:
import threading

# Set your ngrok auth token (get free token from https://ngrok.com)
# Optional but recommended to avoid session limits
NGROK_AUTH_TOKEN = ""  # Paste your token here (optional)

if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Start server in background thread
def run_server():
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

thread = threading.Thread(target=run_server, daemon=True)
thread.start()

# Create ngrok tunnel
import time
time.sleep(2)  # Wait for server to start

# ngrok.connect() returns an NgrokTunnel object; keep only its URL string
public_url = ngrok.connect(8000).public_url

print("\n" + "="*70)
print("🎉 BACKEND SERVER IS RUNNING!")
print("="*70)
print(f"\n📡 Public URL: {public_url}")
print(f"\n🔗 API Endpoints:")
print(f"   - Translation: {public_url}/api/translate/")
print(f"   - Health: {public_url}/api/health")
print(f"   - Docs: {public_url}/docs")
print(f"\n💡 Usage in Frontend:")
print(f"   Update your frontend to use: {public_url}")
print(f"\n⚠️  Keep this notebook running to keep the server alive!")
print("="*70 + "\n")

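# Keep the cell running (time is already imported above)
try:
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    # Interrupting only ends this loop; the daemon server thread and the
    # ngrok tunnel stay up until the runtime is restarted.
    print("\n🛑 Keep-alive loop interrupted")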
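
The fixed `time.sleep(2)` in the cell above assumes the server is ready within two seconds. One alternative, sketched here with only the standard library, is to poll the port until it accepts TCP connections. `wait_for_port` is a hypothetical helper and the timeout values are arbitrary.

```python
import socket
import time

def wait_for_port(host, port, timeout=10.0, interval=0.2):
    """Return True once host:port accepts a TCP connection, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Nothing is listening yet; wait briefly and retry
            time.sleep(interval)
    return False

# Hypothetical usage before opening the tunnel:
# if not wait_for_port("127.0.0.1", 8000):
#     raise RuntimeError("Server did not start in time")
```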

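## 🧪 Step 5: Test the API (Optional)

Run this cell to test if the translation works. The Step 4 cell blocks while it runs, so interrupt it first; the server thread and ngrok tunnel should keep running in the background.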

In [None]:
import requests
import json

# Get the public URL from ngrok
tunnels = ngrok.get_tunnels()
if tunnels:
    api_url = str(tunnels[0].public_url)
    
    # Test translation
    test_query = "Show all students with marks above 80"
    
    print(f"🧪 Testing translation: '{test_query}'\n")
    
    response = requests.post(
        f"{api_url}/api/translate/",
        json={
            "natural_language": test_query,
            "database": "students_db",
            "schema_context": "CREATE TABLE students (id INT, name VARCHAR(100), marks INT);"
        },
        timeout=60,  # Avoid hanging forever if the tunnel is down
    )
    
    if response.status_code == 200:
        data = response.json()
        print("✅ Translation successful!\n")
        print("SQL Candidates:")
        for i, candidate in enumerate(data['candidates'], 1):
            print(f"\n{i}. {candidate['sql']}")
            print(f"   Confidence: {candidate['confidence']:.2%}")
    else:
        print(f"❌ Error: {response.status_code}")
        print(response.text)
else:
    print("❌ No tunnels found. Make sure Step 4 has been run!")