## Prepare Data
In this example, we use the Daring-Anteater dataset. For improved accuracy, please refer to the [Data Synthesis Section](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding#optional-data-synthesis) in the README.

In [None]:
!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
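
Before training, it can help to confirm that the clone produced a readable JSONL file. The following is a minimal, standard-library sketch; `peek_jsonl` is a hypothetical helper name (not part of modelopt), and it assumes nothing about the record schema beyond one JSON object per line.

```python
import json


def peek_jsonl(path, n=2):
    """Return up to the first n records of a JSONL file (hypothetical helper)."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if len(records) >= n:
                break
            # Skip blank lines so stray newlines do not raise a decode error.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records


# Usage in this notebook (the path matches the file loaded in the training cell):
# for record in peek_jsonl("/tmp/Daring-Anteater/train.jsonl"):
#     print(record)
```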

## Convert Model for Speculative Decoding
Here, we'll adapt our base model for speculative decoding by attaching a smaller EAGLE head. The upcoming code first loads meta-llama/Llama-3.2-1B as our base model and then configures the new draft head. To ensure compatibility, the draft head's hidden size and vocabulary size must match those of the target model. Finally, modelopt attaches this new, untrained head, leaving us with a combined model that is ready for the training phase later.

In [None]:
import transformers

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG

mto.enable_huggingface_checkpointing()

# Load original HF model
base_model = "meta-llama/Llama-3.2-1B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype="auto", device_map="cuda"
)

# Read Default Config for EAGLE3
config = EAGLE3_DEFAULT_CFG["config"]

# Hidden size and vocab size must match base model
config["eagle_architecture_config"].update(
    {
        "hidden_size": model.config.hidden_size,
        "vocab_size": model.config.vocab_size,
        "draft_vocab_size": model.config.vocab_size,
        "max_position_embeddings": model.config.max_position_embeddings,
    }
)

# Convert Model for eagle speculative decoding
mtsp.convert(model, [("eagle", config)])

# Prepare Tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
tokenizer.pad_token_id = tokenizer.eos_token_id
if tokenizer.chat_template is None:
    tokenizer.chat_template = (
        "{%- for message in messages %}"
        "{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
        "{%- endfor %}"
    )
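
To see what the fallback chat template produces when the tokenizer ships without one, the sketch below reimplements it in plain Python. `render_fallback_template` is a hypothetical name introduced here for illustration; it only mirrors the string concatenation in the Jinja template above and does not call the tokenizer.

```python
def render_fallback_template(messages):
    """Plain-Python equivalent of the fallback Jinja chat template above."""
    return "".join(
        "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
        for message in messages
    )


example_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
]
print(render_fallback_template(example_messages))
```

Each message becomes one `<|im_start|>role ... <|im_end|>` block, and the blocks are concatenated in order.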

## Train Draft Head On Daring-Anteater
We will fine-tune the draft head on the Daring-Anteater dataset using the standard Hugging Face Trainer. Only the draft head's weights are updated during this process; the original target model remains unchanged. After training, our speculative decoding model will be ready for export and deployment. Training time depends heavily on the number of epochs (set to 4 in this example) and on the hardware being used.

In [None]:
import json
from dataclasses import dataclass, field

from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset
from transformers import Trainer

with open("/tmp/Daring-Anteater/train.jsonl") as f:
    data_json = [json.loads(line) for line in f]
train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    dataloader_drop_last: bool = field(default=True)
    bf16: bool = field(default=True)


training_args = TrainingArguments(
    output_dir="/tmp/eagle_bf16",
    num_train_epochs=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
)
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorWithPadding(),
)
trainer._move_model_to_device(model, trainer.args.device)

# Make sure label_smoother is None
assert trainer.label_smoother is None, "label_smoother is not supported in speculative decoding!"

trainer.train()
trainer.save_state()
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
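
As a rough way to gauge how long training will run, the sketch below reproduces the 95/5 split arithmetic used above and counts optimizer steps. It is an estimate under stated assumptions: a single device, no gradient accumulation, and `dataloader_drop_last=True` discarding any partial batch. `split_and_steps` is a hypothetical helper, not part of the Trainer API.

```python
def split_and_steps(num_records, train_fraction=0.95, batch_size=1, epochs=4):
    """Return (train size, eval size, total optimizer steps) for the setup above.

    Assumes one device and gradient_accumulation_steps=1 (assumptions, not
    values read from the Trainer).
    """
    n_train = int(num_records * train_fraction)
    n_eval = num_records - n_train
    # With drop_last enabled, an incomplete final batch is discarded each epoch.
    steps_per_epoch = n_train // batch_size
    return n_train, n_eval, steps_per_epoch * epochs


# Usage in this notebook:
# print(split_and_steps(len(data_json)))
```

Multiplying the step count by the measured seconds per step from the first few training log lines gives an approximate total run time.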

## Export Model Checkpoint
To deploy this model, we first need to export it to a unified checkpoint.

In [None]:
from modelopt.torch.export import export_hf_checkpoint

model.eval()
export_hf_checkpoint(
    model,
    export_dir="/tmp/hf_ckpt",
)

## Deploying on TensorRT-LLM

Here we show an example of deploying on TRT-LLM with `trtllm-serve` and the [TRT-LLM container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release). See the [Deployment](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding#deployment) section for more information.

First, we write the `trtllm-serve` command and the config file it needs to the `/tmp` folder.

In [None]:
trtllm_serve_script = f"""trtllm-serve {base_model} \\
    --host 0.0.0.0 \\
    --port 8000 \\
    --backend pytorch \\
    --max_batch_size 32 \\
    --max_num_tokens 8192 \\
    --max_seq_len 8192 \\
    --extra_llm_api_options /tmp/extra-llm-api-config.yml
"""

extra_llm_api_config = """enable_attention_dp: false
disable_overlap_scheduler: true
enable_autotuner: false

cuda_graph_config:
    max_batch_size: 1

speculative_config:
    decoding_type: Eagle
    max_draft_len: 3
    speculative_model_dir: /tmp/hf_ckpt

kv_cache_config:
    enable_block_reuse: false
"""

# Dump the two scripts into /tmp
with open("/tmp/trtllm_serve.sh", "w") as f:
    f.write(trtllm_serve_script)

with open("/tmp/extra-llm-api-config.yml", "w") as f:
    f.write(extra_llm_api_config)
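
For intuition about the `max_draft_len: 3` setting, the sketch below estimates how many tokens one target-model verification step yields. It rests on a simplifying assumption that is ours, not the notebook's: every draft token is accepted independently with the same probability `alpha`. Under that idealisation the expected count is `(1 - alpha**(k + 1)) / (1 - alpha)` for draft length `k`. Real acceptance rates vary with position and prompt, so treat the result as a rough guide only.

```python
def expected_tokens_per_step(alpha, draft_len):
    """Expected tokens emitted per verification step.

    alpha: assumed per-token acceptance probability (hypothetical, i.i.d.).
    draft_len: number of draft tokens proposed per step (max_draft_len).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        # Every draft token is accepted, plus one token from the target model.
        return float(draft_len + 1)
    # Geometric series: 1 + alpha + alpha^2 + ... + alpha^draft_len
    return (1.0 - alpha ** (draft_len + 1)) / (1.0 - alpha)


for alpha in (0.5, 0.8):
    print(alpha, expected_tokens_per_step(alpha, draft_len=3))
```

A better-trained draft head raises `alpha`, which is why the quality of the training data matters for the final speed-up.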

Next, we start a TRT-LLM container in the background and run `trtllm-serve` inside it, using our exported checkpoint and the configuration scripts we just created:

In [None]:
import subprocess
import threading

# Use a fixed container name so we can stop/remove it later
container_name = "trtllm_serve_spec"

docker_cmd = [
    "docker",
    "run",
    "--rm",
    "--net",
    "host",
    "--shm-size=2g",
    "--ulimit",
    "memlock=-1",
    "--ulimit",
    "stack=67108864",
    "--gpus",
    "all",
    "-v",
    "/tmp:/tmp",
    "--name",
    container_name,
    "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
    "bash",
    "-c",
    "bash /tmp/trtllm_serve.sh",
]

# print docker outputs
proc = subprocess.Popen(
    docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1
)

def stream_output(pipe):
    for line in iter(pipe.readline, ""):
        print(line, end="")

# Use thread to print outputs
thread = threading.Thread(target=stream_output, args=(proc.stdout,))
thread.daemon = True
thread.start()

print(
    f"Starting trtllm-serve in Docker (PID: {proc.pid}, container name: {container_name}) in the background:"
)
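
Server readiness can also be checked programmatically instead of by watching the log. The sketch below polls a TCP port until a connection succeeds; `wait_for_port` is a hypothetical helper built only on the standard library. An open port shows that the server is listening, which is a weaker signal than the startup message in the log, so the log remains the authoritative indicator.

```python
import socket
import time


def wait_for_port(host, port, timeout_s=600.0, poll_interval_s=2.0):
    """Return True once a TCP connection to host:port succeeds, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=poll_interval_s):
                return True
        except OSError:
            time.sleep(poll_interval_s)
    return False


# Usage in this notebook (port 8000 matches the trtllm-serve script above):
# wait_for_port("localhost", 8000)
```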

Please wait for the service to fully start inside the container.   
Once you see the message `INFO:     Application startup complete.`, you can proceed to send requests to the service:

In [None]:
import json
import requests

payload = {
    "model": base_model,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me about speculative decoding."},
    ],
    "max_tokens": 512,
    "temperature": 0,
    "chat_template": tokenizer.chat_template,
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}

response = requests.post(
    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
)
output = response.json()

print(output)
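
The raw response dictionary is verbose. Assuming the server returns an OpenAI-style chat completion body (which is what the `/v1/chat/completions` route suggests), the generated text sits under `choices[0].message.content`. The sketch below pulls it out; `extract_reply` is a hypothetical helper.

```python
def extract_reply(output):
    """Return the assistant text from an OpenAI-style chat completion dict."""
    choices = output.get("choices") or []
    if not choices:
        # Error responses typically carry no choices; surface the whole body.
        raise ValueError(f"No choices in response: {output}")
    return choices[0]["message"]["content"]


# Usage in this notebook:
# print(extract_reply(output))
```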

Finally, we clean up the container we created.

In [None]:
!docker rm -f trtllm_serve_spec

## Deploying on SGLang
Here, we deploy our trained model using SGLang. The following code defines the command needed to run the SGLang server with our specific configuration for speculative decoding.

In [None]:
# SGLang server launch command, written to a shell script
sglang_serve_script = f"""python3 -m sglang.launch_server \\
    --model {base_model} \\
    --host 0.0.0.0 \\
    --port 30000 \\
    --speculative-algorithm EAGLE3 \\
    --speculative-draft-model-path /tmp/hf_ckpt \\
    --speculative-num-draft-tokens 3 \\
    --dtype float16
"""

with open("/tmp/sglang_serve.sh", "w") as f:
    f.write(sglang_serve_script)

Launch the SGLang server inside a Docker container as a background process.

In [None]:
import subprocess
import threading
import os

container_name = "sglang_serve_spec"
home_dir = os.path.expanduser("~")
hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")

# Ensure the Hugging Face cache directory exists. It should already be present at
# ~/.cache/huggingface if the meta-llama/Llama-3.2-1B files were downloaded earlier.
os.makedirs(hf_cache_dir, exist_ok=True)

docker_cmd = [
    "docker", "run",
    "--rm",
    "--net", "host",
    "--shm-size=32g",
    "--gpus", "all",
    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
    "-v", "/tmp:/tmp",
    "--ipc=host",
    "--name", container_name,
    "lmsysorg/sglang:latest",
    "bash", "-c", "bash /tmp/sglang_serve.sh",
]

# Launch the Docker container
proc = subprocess.Popen(
    docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1
)

# Stream the process output
def stream_output(pipe):
    for line in iter(pipe.readline, ""):
        print(line, end="")

# Use a thread to stream the output without blocking the notebook
thread = threading.Thread(target=stream_output, args=(proc.stdout,))
thread.daemon = True
thread.start()

print(
    f"Starting SGLang server in Docker (PID: {proc.pid}, container name: {container_name}) in the background:"
)

As with TRT-LLM, please wait for the service to fully start inside the container.   
Once you see the message `INFO:     Application startup complete.`, you can proceed to send requests to the service:

In [None]:
import json
import requests

payload = {
    "model": base_model,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me about speculative decoding."},
    ],
    "max_tokens": 512,
    "temperature": 0,
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}

# Send request to the SGLang server
response = requests.post(
    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
)
output = response.json()

print(output)
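
Since the point of speculative decoding is faster generation, a quick throughput figure is a useful sanity check. The sketch below times a call and divides the number of generated tokens by the elapsed time. It assumes the response includes an OpenAI-style `usage.completion_tokens` field, and both helper names are hypothetical. The figure is end-to-end (prefill and HTTP overhead included), so it is only meaningful when compared against the same request sent to a server launched without the speculative options.

```python
import time


def timed_call(fn, *args, **kwargs):
    """Call fn and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def completion_tokens_per_second(output, elapsed_s):
    """Return generated tokens per second, or None if it cannot be computed."""
    tokens = output.get("usage", {}).get("completion_tokens")
    if tokens is None or elapsed_s <= 0:
        return None
    return tokens / elapsed_s


# Usage in this notebook:
# response, elapsed = timed_call(
#     requests.post,
#     "http://localhost:30000/v1/chat/completions",
#     headers=headers,
#     data=json.dumps(payload),
# )
# print(completion_tokens_per_second(response.json(), elapsed))
```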

Finally, we clean up the container.

In [None]:
!docker rm -f sglang_serve_spec

## Deploying on vLLM (Coming Soon)

While vLLM is another extremely popular, high-performance inference server, direct support for speculative decoding with this demo notebook is still under active development. This notebook will be updated once deployment is possible.