# LLM Serving from First Principles

You can code along with the presenter or just run the cells.
* To move to the next step, run `jj next` in the terminal.
* To move back to the previous step, run `jj prev`.

# Step 8: Independent work

We're done with a minimal LLM server! Now it's your turn:

- Play with it:
  - Try various prompts and see how they influence the model's output length _and quality_
    - https://www.promptingguide.ai/
  - Try feeding the entire codebase to the model and see what it recommends improving next
- Refactor:
  - Rewrite this whole thing using a [`transformers.pipeline`](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextGenerationPipeline)
  - Use [FastAPI lifespans](https://fastapi.tiangolo.com/advanced/events/#lifespan-events) to load the model
    - You could also run a few warmup requests there, so the first real request does not pay for `torch.compile`'s compilation
  - Dockerize this entire app
    - Don't forget to `COPY` the model checkpoints into the container image
- New features:
  - Remove the prompt prefix from the model output
    - easy mode: Python string indexing: `decoded_output[index:]`
    - hard mode: tensor indexing: `model_outputs[0, index:]`
    - optionally, find a way to skip special tokens
  - Multi-turn conversations
    - Create a new session key for each conversation, and store the user prompts and model outputs in a list under that key
  - More metrics
    - The metrics were intentionally kept simple. You can also measure tokenization time, decoding time, each stage's share of the total time, the number of input tokens, the number of generated tokens, generated tokens per second, and a myriad of other statistics
    - Bonus points for figuring out how to read GPU utilization metrics
- ... and everything else you come up with!
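
You can try different prompts without leaving Python by using the minimal client sketch below, which relies only on the standard library. It assumes the server from the cells below is running on uvicorn's default port 8000 on the same machine. `INFER_URL`, `make_infer_request`, and `infer` are hypothetical names and are not part of the app. Because `await server.serve()` keeps the notebook kernel busy, call the client from a second notebook or a terminal.

```python
import json
import urllib.request

# Assumption: uvicorn's default port (8000) on this machine
INFER_URL = "http://localhost:8000/infer"


def make_infer_request(prompt: str, url: str = INFER_URL) -> urllib.request.Request:
    """Build a POST request whose JSON body matches the `Query` model."""
    body = json.dumps({"prompt": prompt}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def infer(prompt: str, url: str = INFER_URL, timeout: float = 120.0) -> dict:
    """Send one prompt to the running server and return the parsed JSON reply."""
    with urllib.request.urlopen(make_infer_request(prompt, url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

For example, `infer("Explain KV caching in one sentence.")["metrics"]` returns the latency and token counts. This lets you compare how different phrasings affect output length.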
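
The prompt-prefix exercise can be sketched both ways. This sketch assumes the same `apply_chat_template(..., return_dict=True)` call as `/infer`. In that case, every row returned by `generate` starts with the `input_ids.shape[1]` prompt tokens, so slicing past them leaves only the new tokens. The helpers `strip_prompt_text` and `strip_prompt_tokens` are hypothetical. Plain lists stand in for tensors so the snippet runs without a GPU, and the same slicing works on the 2-D `torch.Tensor` that `generate` returns. The string version also assumes that decoding the prompt alone reproduces an exact prefix of the decoded output. That usually holds, but it is not guaranteed.

```python
def strip_prompt_text(decoded_output: str, decoded_prompt: str) -> str:
    """Easy mode: drop the prompt by string indexing.

    Falls back to the full text if the decoded prompt is not an exact prefix.
    """
    if decoded_output.startswith(decoded_prompt):
        return decoded_output[len(decoded_prompt):]
    return decoded_output


def strip_prompt_tokens(output_ids, prompt_len: int):
    """Hard mode: drop the prompt by indexing the token IDs.

    `output_ids` is a batch (only the first row is used), e.g. generate()'s output;
    `prompt_len` is the prompt length, e.g. tokenized['input_ids'].shape[1].
    """
    return output_ids[0][prompt_len:]


# Toy stand-ins for tokenizer/model outputs (hypothetical IDs, no model needed)
prompt_ids = [[10, 11, 12, 13]]          # prompt tokens, batch of 1
output_ids = [[10, 11, 12, 13, 20, 21]]  # prompt + generated tokens
print(strip_prompt_tokens(output_ids, len(prompt_ids[0])))
print(strip_prompt_text("<user>Hi<model>Hello!", "<user>Hi<model>"))
```

To also skip special tokens, decode only the sliced IDs with `TOKENIZER.decode(new_ids, skip_special_tokens=True)`.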
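
For multi-turn conversations, one option is an in-memory dict keyed by session ID. Each entry holds messages in the same `{"role": ..., "content": ...}` format that `apply_chat_template` already receives in `/infer`. This is a sketch, and `SESSIONS`, `new_session`, `build_messages`, and `record_turn` are hypothetical names. The store is per-process and unbounded, so a real server would need eviction and locking.

```python
import uuid

# session_id -> list of chat messages, oldest first (in-memory, per process)
SESSIONS: dict[str, list[dict[str, str]]] = {}


def new_session() -> str:
    """Create an empty conversation and return its random session key."""
    session_id = uuid.uuid4().hex
    SESSIONS[session_id] = []
    return session_id


def build_messages(session_id: str, prompt: str) -> list[dict[str, str]]:
    """Return the history plus the new user turn, without mutating the store."""
    return SESSIONS[session_id] + [{"role": "user", "content": prompt}]


def record_turn(session_id: str, prompt: str, response: str) -> None:
    """Append the user prompt and the model's (prompt-stripped) reply."""
    SESSIONS[session_id].append({"role": "user", "content": prompt})
    SESSIONS[session_id].append({"role": "assistant", "content": response})


sid = new_session()
record_turn(sid, "What is 2 + 2?", "4")
print(build_messages(sid, "And times 3?"))
```

To give the model the full history, pass `build_messages(sid, query.prompt)` to `apply_chat_template` instead of the single-message list. Then call `record_turn` with the prompt-stripped reply. This layout keeps user and assistant turns alternating, which Gemma's chat template expects. The template renders `assistant` as the `model` role.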
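
The metrics exercise can be approached by timing each stage with `time.perf_counter()` and deriving the statistics listed above from those timings. This is a sketch under assumptions: `timed_stage`, `summarize`, and the stage names are hypothetical. In the server, you would wrap the `apply_chat_template`, `generate`, and `decode` calls. The input token count would come from `tokenized['input_ids'].shape[1]`, and the total from `model_outputs.shape[1]`.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed_stage(timings: dict[str, float], name: str):
    """Record the wall-clock duration of the enclosed block under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = time.perf_counter() - start


def summarize(timings: dict[str, float], input_tokens: int, total_tokens: int) -> dict:
    """Derive latency, per-stage shares, and generation throughput."""
    latency = sum(timings.values())  # sum of timed stages; excludes untimed overhead
    generated_tokens = total_tokens - input_tokens  # generate() returns prompt + new tokens
    generate_time = timings.get("generate", 0.0)
    return {
        "latency": latency,
        "stages": dict(timings),
        "stage_share": {k: v / latency for k, v in timings.items()} if latency else {},
        "input_tokens": input_tokens,
        "generated_tokens": generated_tokens,
        "generated_tokens_per_second": generated_tokens / generate_time if generate_time else 0.0,
    }


# Usage with stand-in work instead of the real tokenizer/model calls
timings: dict[str, float] = {}
with timed_stage(timings, "tokenize"):
    time.sleep(0.01)
with timed_stage(timings, "generate"):
    time.sleep(0.05)
with timed_stage(timings, "decode"):
    time.sleep(0.01)
print(summarize(timings, input_tokens=12, total_tokens=40))
```

Throughput here is computed over `generate` time only, so tokenization and decoding do not dilute it. Dividing by end-to-end latency instead gives a user-facing figure rather than a model-only one.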

In [None]:
import torch

DEVICE = torch.device('cuda:0')

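# torch.inference_mode() only takes effect as a context manager or decorator; called
# on its own it is a no-op. Disable autograd for this thread instead (grad mode is
# thread-local, and MODEL.generate() already runs under torch.no_grad()).
torch.set_grad_enabled(False)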

In [None]:
from transformers import GemmaTokenizer, Gemma3ForCausalLM

checkpoint = 'google/gemma-3-1b-it'
TOKENIZER = GemmaTokenizer.from_pretrained(f'checkpoints/{checkpoint}')
MODEL = Gemma3ForCausalLM.from_pretrained(f'checkpoints/{checkpoint}', torch_dtype=torch.bfloat16, device_map=DEVICE).eval()
MODEL.generation_config.max_new_tokens = 256

In [None]:
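# A static KV cache keeps tensor shapes fixed across decoding steps, which lets
# torch.compile reuse the compiled graph instead of recompiling
MODEL.generation_config.cache_implementation = 'static'
# fullgraph=True makes compilation fail loudly on graph breaks instead of silently falling back
MODEL.forward = torch.compile(MODEL.forward, fullgraph=True)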

In [None]:
from fastapi import FastAPI
from pydantic import BaseModel

class Query(BaseModel):
    prompt: str

app = FastAPI()

In [None]:
import time

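@app.post("/infer")
def inference(query: Query):
    since = time.perf_counter()  # monotonic clock, better suited to durations than time.time()
    # Wrap the prompt in Gemma's chat template and move the tensors to the GPU
    tokenized = TOKENIZER.apply_chat_template(
        [{"role": "user", "content": query.prompt}],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors='pt').to(DEVICE)
    # Shape: (batch=1, prompt tokens + generated tokens)
    model_outputs = MODEL.generate(**tokenized)
    # Still contains the prompt and special tokens (see the exercises above)
    decoded_text = TOKENIZER.decode(model_outputs[0])
    latency = time.perf_counter() - since
    tokens_count = model_outputs.shape[1]  # counts prompt *and* generated tokens
    return {
        "response": decoded_text,
        "metrics": {
            "latency": latency,
            "tokens_count": tokens_count,
            "tokens_per_second": tokens_count / latency,
        },
    }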

In [None]:
import uvicorn

config = uvicorn.Config(app, host="0.0.0.0")
server = uvicorn.Server(config)
await server.serve()