# LLM Serving from First Principles

You can code along with the presenter or just run the cells:
* To move to the next step, run `jj next` in the terminal.
* To move back to the previous step, run `jj prev`.

## Step 7: using pre-invented wheels

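Right now we format prompts by hand with a hard-coded template string (`PROMPT_FORMAT`), which is fragile: it has to match the model's chat format exactly, character for character. Let's use the tokenizer's built-in chat template instead, so the code no longer depends on that hand-written string.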

In [None]:
import torch

DEVICE = torch.device('cuda:0')

# This server only runs inference, so turn autograd off. Calling
# `torch.inference_mode()` bare (not in a `with` statement or as a decorator)
# only constructs a context-manager object and leaves grad mode unchanged;
# `torch.set_grad_enabled(False)` takes effect as soon as it is called.
# NOTE(review): grad mode is thread-local, so this only covers the thread that
# runs this cell, not other threads (e.g. ones that handle HTTP requests).
torch.set_grad_enabled(False)

In [None]:
from transformers import GemmaTokenizer, Gemma3ForCausalLM

checkpoint = 'google/gemma-3-1b-it'

# Tokenizer and weights are both read from the path `checkpoints/<checkpoint>`.
# The model is loaded in bfloat16 directly onto DEVICE and put in eval mode.
TOKENIZER = GemmaTokenizer.from_pretrained(f'checkpoints/{checkpoint}')
MODEL = Gemma3ForCausalLM.from_pretrained(
    f'checkpoints/{checkpoint}',
    device_map=DEVICE,
    torch_dtype=torch.bfloat16,
).eval()

# Upper bound on the number of tokens produced by each generate() call.
MODEL.generation_config.max_new_tokens = 256

# Hand-written Gemma turn markup: a single user turn followed by the opening of
# the model's turn, so generation continues as the model's reply.
PROMPT_FORMAT = (
    "<start_of_turn>user\n{prompt}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

In [None]:
# Ask generate() to use a fixed-size ("static") KV cache. Fixed tensor shapes
# are what allow the compiled forward pass below to be reused across decoding
# steps instead of being recompiled (per the transformers static-cache guide).
MODEL.generation_config.cache_implementation = 'static'

# Swap the model's forward for a compiled version. fullgraph=True makes
# torch.compile raise on a graph break instead of silently running part of the
# model eagerly. Compilation is lazy: it happens on the first forward call.
MODEL.forward = torch.compile(
    MODEL.forward,
    fullgraph=True,
)

In [None]:
from fastapi import FastAPI
from pydantic import BaseModel

# Request body schema for the inference endpoint.
class Query(BaseModel):
    # Raw user message; the endpoint wraps it in Gemma's turn format.
    prompt: str


# The ASGI application that the endpoints below are registered on.
app = FastAPI()

In [None]:
import threading
import time

# FastAPI runs plain `def` endpoints in a worker threadpool, so several requests
# can reach MODEL.generate at once. They would share one model object and its
# static KV cache, so generation is serialized with this lock.
_GENERATE_LOCK = threading.Lock()


@app.post("/infer")
def inference(query: Query):
    """Generate a chat reply to ``query.prompt`` and report timing metrics.

    The prompt is wrapped as a single user turn by the tokenizer's built-in
    chat template; ``add_generation_prompt=True`` appends the opening of the
    model's turn so generation continues as the reply. This replaces the
    hand-written ``PROMPT_FORMAT`` string.

    Returns:
        A dict with ``response`` (only the newly generated text, with special
        tokens stripped) and ``metrics``:
        ``latency`` (seconds, wall clock for the whole request),
        ``tokens_count`` (number of newly generated tokens),
        ``prompt_tokens_count`` (number of prompt tokens), and
        ``tokens_per_second`` (generated tokens / latency).
    """
    since = time.perf_counter()  # monotonic clock, suitable for intervals
    inputs = TOKENIZER.apply_chat_template(
        [{"role": "user", "content": query.prompt}],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    with _GENERATE_LOCK:
        model_outputs = MODEL.generate(**inputs)

    # generate() returns prompt + continuation; keep only the continuation so
    # the response does not echo the prompt and the metrics count real output.
    new_token_ids = model_outputs[0, prompt_len:]
    decoded_text = TOKENIZER.decode(new_token_ids, skip_special_tokens=True)
    latency = time.perf_counter() - since
    tokens_count = new_token_ids.shape[0]
    return {
        "response": decoded_text,
        "metrics": {
            "latency": latency,
            "tokens_count": tokens_count,
            "prompt_tokens_count": prompt_len,
            "tokens_per_second": tokens_count / latency,
        },
    }

In [None]:
import uvicorn

config = uvicorn.Config(app, host="0.0.0.0")
server = uvicorn.Server(config)
await server.serve()