# LLM Serving from First Principles

You can code along with the presenter or just run the cells.
* To move to the next step, run `jj next` in the terminal.
* To move back to the previous step, run `jj prev`.

# Step 8: self work

These are the tasks for Part 2:

- Play with it:
  - Feed the entire codebase to the model and see what it recommends improving next
  - Send different requests manually with `curl localhost:8000/triton/{MODEL_NAME} --json '{ ... }'` and watch what happens in Grafana
- Refactor:
  - Add a new model that does all the steps (tokenization, inference, decoding) and see how well it behaves
  - Add a `torch.compile` stanza to the `gemma_torch` model and see how it impacts latency and GPU usage
- New features:
  - Remove the prompt prefix from the model output
    - easy mode: a TensorRT parameter 🙂
    - hard mode: add another output called `"num_tokens"` to the `preprocessing` model, forward it to the `postprocessing` model, and do the indexing there
    - optionally, find a way to skip special tokens
  - Multi-turn conversations
    - you can make the `preprocessing` model stateful by introducing a session key and appending each conversation to a dictionary keyed by that session key
  
- ... and everything else you come up with!
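
As a starting point for the multi-turn task, here is a minimal sketch of a session-keyed conversation store. The names `SESSIONS`, `add_user_turn`, and `add_model_turn` are hypothetical and not part of the workshop code; the message dictionaries mirror the `{"role": ..., "content": ...}` shape passed to `apply_chat_template` later in this notebook, and the `"assistant"` role name is an assumption about what the chat template accepts.

```python
from collections import defaultdict

# Hypothetical in-memory store: session key -> list of chat messages.
SESSIONS: dict[str, list[dict[str, str]]] = defaultdict(list)


def add_user_turn(session_key: str, prompt: str) -> list[dict[str, str]]:
    """Append a user message and return a copy of the full history."""
    SESSIONS[session_key].append({"role": "user", "content": prompt})
    return list(SESSIONS[session_key])


def add_model_turn(session_key: str, reply: str) -> None:
    """Record the model's reply so the next request sees it as context."""
    SESSIONS[session_key].append({"role": "assistant", "content": reply})


# Two user turns with one model reply in between.
history = add_user_turn("demo", "Hi!")
add_model_turn("demo", "Hello, how can I help?")
history = add_user_turn("demo", "What is Triton?")
print(len(history))  # 3
```

The returned history is what you would pass to the chat template instead of a single-message list. A plain dictionary grows without bound and lives in one process only, so a real deployment would likely need eviction and shared storage.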

In [None]:
import torch

DEVICE = torch.device('cuda:0')

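# A bare `torch.inference_mode()` call only constructs a context manager and has
# no effect; it must be used as `with torch.inference_mode():` or as a decorator.
# `MODEL.generate` already disables gradient tracking internally, so no global
# switch is needed here.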

In [None]:
from transformers import GemmaTokenizer, Gemma3ForCausalLM

checkpoint = 'google/gemma-3-1b-it'
TOKENIZER = GemmaTokenizer.from_pretrained(f'checkpoints/{checkpoint}')
MODEL = Gemma3ForCausalLM.from_pretrained(f'checkpoints/{checkpoint}', torch_dtype=torch.bfloat16, device_map=DEVICE).eval()
MODEL.generation_config.max_new_tokens = 256

In [None]:
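# A static KV cache keeps tensor shapes fixed across decoding steps,
# which lets torch.compile capture the forward pass as a single graph
MODEL.generation_config.cache_implementation = 'static'
MODEL.forward = torch.compile(MODEL.forward, fullgraph=True)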

In [None]:
from fastapi import FastAPI
from pydantic import BaseModel

class Query(BaseModel):
    prompt: str

app = FastAPI()

In [None]:
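import time

@app.post("/infer")
def inference(query: Query):
    # perf_counter is monotonic, so it suits latency measurement better than time.time
    since = time.perf_counter()
    # Wrap the prompt in the tokenizer's chat template and tokenize it in one step
    tokenized = TOKENIZER.apply_chat_template(
        [{"role": "user", "content": query.prompt}],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors='pt').to(DEVICE)
    model_outputs = MODEL.generate(**tokenized)
    # The output holds the prompt tokens followed by the generated ones, so the
    # decoded text and the counts below include the prompt
    decoded_text = TOKENIZER.decode(model_outputs[0])
    latency = time.perf_counter() - since
    tokens_count = model_outputs.shape[1]
    return {
        "response": decoded_text,
        "metrics": {
            "latency": latency,
            "tokens_count": tokens_count,
            "tokens_per_second": tokens_count / latency,
        },
    }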
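
The handler above decodes `model_outputs[0]`, which contains the prompt tokens followed by the generated ones. The sketch below shows the indexing idea behind the "remove the prompt prefix" task on plain Python lists with toy token IDs; `strip_prompt` and `generated_tokens_per_second` are hypothetical helpers, not part of the workshop code. In the handler, the prompt length would presumably come from `tokenized["input_ids"].shape[1]`.

```python
def strip_prompt(output_ids: list[int], num_prompt_tokens: int) -> list[int]:
    """Return only the tokens that were generated after the prompt."""
    return output_ids[num_prompt_tokens:]


def generated_tokens_per_second(
    output_ids: list[int], num_prompt_tokens: int, latency_s: float
) -> float:
    """Throughput over generated tokens only, excluding the prompt."""
    return len(strip_prompt(output_ids, num_prompt_tokens)) / latency_s


# Toy IDs standing in for `model_outputs[0]`: 3 prompt tokens, then 4 generated.
output_ids = [2, 105, 106, 900, 901, 902, 1]
completion = strip_prompt(output_ids, num_prompt_tokens=3)
print(completion)  # [900, 901, 902, 1]
print(generated_tokens_per_second(output_ids, 3, latency_s=2.0))  # 2.0
```

The same slice applies to a tensor row. For the optional part of the task, `TOKENIZER.decode` accepts `skip_special_tokens=True`, which drops markers such as the end-of-turn token from the decoded text.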

In [None]:
import uvicorn

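# Listen on all interfaces; uvicorn's default port is 8000
config = uvicorn.Config(app, host="0.0.0.0")
server = uvicorn.Server(config)
# Top-level await works in a notebook; the cell keeps running until the server stops
await server.serve()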
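
To exercise the `/infer` endpoint, a request can be built with the standard library alone. This is a minimal sketch that assumes the server is reachable at `http://localhost:8000` (uvicorn's default port); `build_infer_request` is a hypothetical helper. Because `await server.serve()` keeps the cell busy, the actual call has to come from a separate process, so it is left commented out here.

```python
import json
import urllib.request


def build_infer_request(
    prompt: str, base_url: str = "http://localhost:8000"
) -> urllib.request.Request:
    """Build a POST request whose JSON body matches the `Query` model."""
    body = json.dumps({"prompt": prompt}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/infer",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_infer_request("Why is the sky blue?")
print(request.get_method(), request.full_url)

# Uncomment once the server is running in another process:
# with urllib.request.urlopen(request) as response:
#     print(json.loads(response.read())["metrics"])
```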