In [None]:
!export VLLM_NO_USAGE_STATS=1
!export DO_NOT_TRACK=1
!mkdir -p ~/.config/vllm && touch ~/.config/vllm/do_not_track

!pip install -U vllm
!pip install pyngrok https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp310-cp310-linux_x86_64.whl # Needed for Kaggle
!ngrok authtoken YourToken # Needed for Kaggle

from fastapi import FastAPI, HTTPException
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from pyngrok import ngrok
import uvicorn
import nest_asyncio
import sys
from typing import Optional, Dict, Any
from pydantic import BaseModel

class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``.

    Only ``prompt`` is required; every other field has a default, so the
    minimal request is ``{"prompt": "..."}``.

    NOTE(review): the ``Optional[...]`` annotations mean a client may send an
    explicit JSON ``null`` for these fields, which replaces the default with
    ``None``. Also, not every field here is necessarily applied by the
    handler; check ``generate_completion`` before relying on a given knob.
    """
    # Text given to the model as the prompt.
    prompt: str
    # Optional JSON schema used to constrain (guide) the generated output.
    output_schema: Optional[Dict[str, Any]] = None
    # Sampling / generation settings; names follow common text-generation APIs.
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.0
    top_p: Optional[float] = 0.9
    min_p: Optional[float] = 0.0
    top_k: Optional[int] = 0
    typical_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    top_a: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    num_beams: Optional[int] = 1
    # The completion handler treats 0 as "no fixed seed".
    seed: Optional[int] = 0
    add_bos_token: Optional[bool] = True
    truncation_length: Optional[int] = 8192
    ban_eos_token: Optional[bool] = False
    skip_special_tokens: Optional[bool] = True

app = FastAPI()

# Single offline vLLM engine, reused by the completion endpoint for every request.
llm = LLM(
    model="casperhansen/llama-3-8b-instruct-awq",
    quantization="AWQ",
    dtype="half",
    # NOTE(review): trust_remote_code=True permits executing code shipped with
    # the model repository; confirm it is actually required for this checkpoint.
    trust_remote_code=True,
    # Prompt-lookup (n-gram) speculative decoding: propose up to 5 speculative
    # tokens per step, matching n-grams of length at most 4.
    speculative_config=dict(
        method="ngram",
        num_speculative_tokens=5,
        prompt_lookup_max=4,
    ),
)

@app.post("/v1/completions")
async def generate_completion(request: CompletionRequest):
    """Generate one completion for ``request.prompt`` with the shared vLLM engine.

    Builds ``SamplingParams`` from the request. When ``output_schema`` is
    given, it adds JSON-schema guided decoding. It then runs ``llm.generate``
    and returns an OpenAI-style payload with a single choice.

    Returns:
        ``{"choices": [{"text": str, "index": 0, "finish_reason": str}]}``

    Raises:
        HTTPException: status 500 wrapping any error raised while building
            the sampling parameters or generating.
    """
    try:
        # Base sampling parameters. A seed of 0 is treated as "no fixed seed".
        sampling_params_kwargs = {
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "repetition_penalty": request.repetition_penalty,
            "seed": request.seed if request.seed != 0 else None,
        }

        # Forward request fields that map one-to-one onto SamplingParams.
        # An explicit null is skipped so vLLM's own default applies instead
        # of failing validation.
        passthrough = {
            "top_p": request.top_p,
            "min_p": request.min_p,
            "skip_special_tokens": request.skip_special_tokens,
        }
        sampling_params_kwargs.update(
            {key: value for key, value in passthrough.items() if value is not None}
        )
        if request.top_k is not None:
            # -1 means "consider all tokens" in vLLM; map the request's
            # 0 (and any other non-positive value) onto it.
            sampling_params_kwargs["top_k"] = request.top_k if request.top_k > 0 else -1
        # NOTE(review): typical_p, tfs, top_a, no_repeat_ngram_size, num_beams,
        # add_bos_token, truncation_length and ban_eos_token are accepted by
        # the request model but are not applied here.

        # Constrain the output to the JSON schema when one is provided.
        if request.output_schema:
            sampling_params_kwargs["guided_decoding"] = GuidedDecodingParams(
                json=request.output_schema, backend="lm-format-enforcer"
            )

        print(f"Request: {request}")
        sampling_params = SamplingParams(**sampling_params_kwargs)

        # llm.generate is called without await, so it blocks the event loop
        # for the duration of generation.
        outputs = llm.generate([request.prompt], sampling_params)
        completion = outputs[0].outputs[0]
        generated_text = completion.text
        print(f"Generated text: {generated_text}")

        return {
            "choices": [
                {
                    "text": generated_text,
                    "index": 0,
                    # Use vLLM's reported reason: the character length of the
                    # text says nothing about whether the token limit was hit.
                    "finish_reason": completion.finish_reason or "stop",
                }
            ]
        }
    except Exception as e:
        import traceback
        print(f"Error: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e)) from e

# Port the API listens on locally and that the ngrok tunnel forwards to.
PORT = 8000

# Expose the local server publicly through an ngrok tunnel.
ngrok_tunnel = ngrok.connect(PORT)
print('Public URL:', ngrok_tunnel.public_url, flush=True)

# Enable nested asyncio for Jupyter notebooks: the kernel already runs an event
# loop, and uvicorn.run() starts its own.
nest_asyncio.apply()

try:
    # Run the API with host set to 0.0.0.0 to accept external connections.
    uvicorn.run(app, host="0.0.0.0", port=PORT)
finally:
    # Close the tunnel when the server stops (e.g. the cell is interrupted) so
    # re-running the cell does not leave a stale public endpoint behind.
    ngrok.disconnect(ngrok_tunnel.public_url)