add llama3 example #51

Closed · wants to merge 6 commits
README.md: 202 additions & 0 deletions
@@ -331,6 +331,208 @@ Clients are expected to auth with the same API key set in the `X-API-Key` HTTP h
</details>
&nbsp;


# Examples

Use the following examples to quickly get started serving the model of your choice.

<details>
<summary><b>Serve Llama 3</b></summary>

&nbsp;

You can serve Llama 3 and stream the chat response to the client. This example is based on LitGPT, which can be installed
by following the instructions [here](https://github.com/Lightning-AI/litgpt?tab=readme-ov-file#install-litgpt).
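
Before running the example, download the Meta-Llama-3-8B-Instruct checkpoint with `litgpt download` (see the comment in the `__main__` block below). As an optional sanity check, you can verify that the files the server reads during `setup` are in place; this is just a convenience sketch and assumes the default download location:

```python
from pathlib import Path

# Default location used by `litgpt download --repo_id meta-llama/Meta-Llama-3-8B-Instruct`
checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B-Instruct")

# Files the server below loads during `setup`
for name in ("lit_model.pth", "model_config.yaml"):
    path = checkpoint_dir / name
    print(f"{path}: {'found' if path.exists() else 'MISSING'}")
```

With the checkpoint in place, the full streaming server looks like this: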

```python
import json
from pathlib import Path
from typing import Any, Generator, Iterator, List, Optional

import lightning as L
import torch
from pydantic import BaseModel

from litserve import LitAPI, LitServer

from litgpt.model import GPT
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
from litgpt.generate.base import next_token
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import (
    check_valid_checkpoint_dir,
    get_default_supported_precision,
    load_checkpoint,
)


class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.8
    top_k: int = 50


class LlamaAPI(LitAPI):
    def __init__(
        self,
        checkpoint_dir: Path,
        precision: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.precision = precision

    def setup(self, device: str) -> None:
        # Set up the model so it can be called in `predict`.
        config = Config.from_file(self.checkpoint_dir / "model_config.yaml")
        device = torch.device(device)
        torch.set_float32_matmul_precision("high")

        precision = self.precision or get_default_supported_precision(training=False)

        fabric = L.Fabric(
            accelerator=device.type,
            devices=1
            if device.type == "cpu"
            else [device.index],  # TODO: Update once LitServe supports "auto"
            precision=precision,
        )
        checkpoint_path = self.checkpoint_dir / "lit_model.pth"
        self.tokenizer = Tokenizer(self.checkpoint_dir)
        self.prompt_style = (
            load_prompt_style(self.checkpoint_dir)
            if has_prompt_style(self.checkpoint_dir)
            else PromptStyle.from_config(config)
        )
        with fabric.init_module(empty_init=True):
            model = GPT(config)
        with fabric.init_tensor():
            # Enable the KV cache for autoregressive decoding.
            model.set_kv_cache(batch_size=1)
        model.eval()

        self.model = fabric.setup_module(model)
        load_checkpoint(fabric, self.model, checkpoint_path)
        self.device = fabric.device

    def decode_request(self, request: PromptRequest) -> Any:
        # Convert the request payload to the model input: apply the prompt style and
        # tokenize. Keep the request around so `predict` can read its sampling parameters.
        prompt = self.prompt_style.apply(request.prompt)
        encoded = self.tokenizer.encode(prompt, device=self.device)
        return encoded, request

    @torch.inference_mode()
    def generate_iter(
        self,
        model: GPT,
        prompt: torch.Tensor,
        max_returned_tokens: int,
        *,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        eos_id: Optional[int] = None,
    ) -> Iterator[torch.Tensor]:
        # Samples tokens one at a time so the response can be streamed to the client.
        T = prompt.size(0)
        assert max_returned_tokens > T
        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )

        device = prompt.device
        input_pos = torch.tensor([T], device=device)
        # Prefill: process the full prompt and sample the first new token.
        token = next_token(
            model,
            torch.arange(0, T, device=device),
            prompt.view(1, -1),
            temperature=temperature,
            top_k=top_k,
        ).clone()
        if token == eos_id:
            return
        yield token
        # Decode the remaining tokens one at a time, feeding the previous token back in.
        for _ in range(2, max_returned_tokens - T + 1):
            token = next_token(
                model,
                input_pos,
                token.view(1, -1),
                temperature=temperature,
                top_k=top_k,
            ).clone()
            if token == eos_id:
                break
            input_pos = input_pos.add_(1)
            yield token

    def predict(self, x: List) -> Generator:
        # Run the model on the input and stream the generated tokens.
        inputs, request = x
        prompt_length = inputs.size(0)
        max_returned_tokens = prompt_length + request.max_new_tokens

        y_iter = self.generate_iter(
            self.model,
            inputs,
            max_returned_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            eos_id=self.tokenizer.eos_id,
        )
        for token in y_iter:
            yield token

        # Reset the KV cache so the next request starts from a clean state.
        for block in self.model.transformer.h:
            block.attn.kv_cache.reset_parameters()

    def encode_response(self, outputs: Generator) -> Generator:
        # Convert each generated token to a JSON response chunk.
        for output in outputs:
            decoded_output = self.tokenizer.decode(output)
            yield json.dumps({"output": decoded_output})


if __name__ == "__main__":
    # 1. Download Llama 3:
    #    litgpt download --repo_id meta-llama/Meta-Llama-3-8B-Instruct

    # 2. Run the server:
    checkpoint_dir: Path = Path("checkpoints/meta-llama/Meta-Llama-3-8B-Instruct")
    check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

    server = LitServer(
        LlamaAPI(checkpoint_dir=checkpoint_dir),
        accelerator="cuda",
        devices=1,
        stream=True,
    )

    server.run(port=8000)
```

You can stream the response with a Python client as follows:

```python
import json

import requests

url = "http://127.0.0.1:8000/stream-predict"
prompt = "Write Python code to sort a linked list in reverse order."
resp = requests.post(
    url,
    json={
        "prompt": prompt,
        "max_new_tokens": 200,
    },
    stream=True,
)
for chunk in resp.iter_content(chunk_size=4000):
    if chunk:
        msg = json.loads(chunk.decode("utf-8"))["output"]
        print(msg, end="")
```
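
Note that `iter_content` does not guarantee that chunk boundaries line up with the JSON objects emitted by `encode_response`, so a chunk may contain a partial object or several objects at once. If you hit decoding errors, a slightly more defensive client can buffer the stream and decode complete JSON objects as they arrive; a minimal sketch using the same endpoint and request payload as above:

```python
import json

import requests

url = "http://127.0.0.1:8000/stream-predict"
decoder = json.JSONDecoder()
buffer = ""

resp = requests.post(
    url,
    json={"prompt": "Write Python code to sort a linked list in reverse order.", "max_new_tokens": 200},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None):
    if not chunk:
        continue
    buffer += chunk.decode("utf-8")
    # Decode as many complete JSON objects as the buffer currently holds.
    while buffer:
        try:
            obj, end = decoder.raw_decode(buffer)
        except json.JSONDecodeError:
            break  # partial object; wait for more data
        print(obj["output"], end="", flush=True)
        buffer = buffer[end:].lstrip()
print()
```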

</details>


## License

litserve is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.