# Speculative Decoding

---

## Prerequisites

In [None]:
%pip install huggingface_hub

In [None]:
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential


# Prefer the default credential chain; fall back to an interactive browser
# login when it cannot be used.
try:
    credential = DefaultAzureCredential()
    # Request a token up front so an unusable credential is detected here
    # rather than on the first SDK call.
    credential.get_token("https://management.azure.com/.default")
except Exception:
    # Deliberately broad: any failure of DefaultAzureCredential triggers the
    # fallback to InteractiveBrowserCredential.
    credential = InteractiveBrowserCredential()

In [None]:
from azure.ai.ml import MLClient


# Client built from local configuration with the credential resolved above;
# used below to submit the job and to create the model, environment,
# endpoint, and deployment.
ml_client = MLClient.from_config(credential=credential)

## Create draft model

In [None]:
# Registry and component name used below to look up the EAGLE3 pipeline component.
registry_name = "test_centralus"
comp_name = "eagle3_chat_completion_pipeline"

In [None]:
# Registry-scoped client, used here only to fetch the component.
registry_ml_client = MLClient(credential=credential, registry_name=registry_name)
# Fetch the component version that carries the "latest" label.
eagle3_comp = registry_ml_client.components.get(name=comp_name, label="latest")
# Bare expression: shows the component in the notebook output.
eagle3_comp

In [None]:
# Configuration of the draft model. It is serialized to JSON further down and
# handed to the pipeline as its `draft_model_config` input. Key order is kept
# as-is so the serialized file is unchanged.
draft_model_config = {
    "architectures": ["LlamaForCausalLMEagle3"],
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 2048,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 1,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-05,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.28.1",
    "use_cache": True,
    "vocab_size": 128256,
    "draft_vocab_size": 32000,
}

In [None]:
# Where the draft model config is written, and the data file that the pipeline
# below uses for both its train and validation inputs.
draft_config_path = "./data/config/draft_model_config.json"
input_data_path = "./data/train/sharegpt_train_small.jsonl"

In [None]:
import json
import os


# open(..., "w") does not create parent directories, so make sure the target
# folder exists before writing the config.
os.makedirs(os.path.dirname(draft_config_path) or ".", exist_ok=True)

# Serialize the draft model config so it can be passed to the pipeline as a file.
with open(draft_config_path, "w", encoding="utf-8") as f:
    json.dump(draft_model_config, f, indent=4)

In [None]:
# Input and AssetTypes are imported from the SDK's public modules rather than
# the private `_inputs_outputs` / `_common` modules.
from azure.ai.ml import Input
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.dsl import pipeline


# Pipeline with a single step: the EAGLE3 component fetched from the registry.
# It receives the base model, the train/validation data and the draft model
# config, and exposes the component's `output_model_path` output as
# "output_model" (the name used later when downloading the job output).
@pipeline
def speculative_decoding_pipeline():
    base_model = Input(
        type=AssetTypes.MLFLOW_MODEL,
        path="azureml://registries/azureml-meta/models/Meta-Llama-3-8B-Instruct/versions/9",
    )
    # The same local file is used for both the train and validation splits.
    train_split = Input(type=AssetTypes.URI_FILE, path=input_data_path)
    validation_split = Input(type=AssetTypes.URI_FILE, path=input_data_path)
    draft_config_input = Input(type=AssetTypes.URI_FILE, path=draft_config_path)

    node = eagle3_comp(
        mlflow_model_path=base_model,
        dataset_train_split=train_split,
        dataset_validation_split=validation_split,
        draft_model_config=draft_config_input,
        # resume_from_checkpoint=None,
    )
    return {"output_model": node.outputs.output_model_path}


# Build the pipeline job object; it is submitted in the following cell.
spec_dec_job = speculative_decoding_pipeline()

In [None]:
# Submit the pipeline under the given experiment. The name is rebound to the
# returned job, whose `name` is used later to download the output model.
spec_dec_job = ml_client.jobs.create_or_update(
    spec_dec_job, experiment_name="speculative-decoding-exp"
)
# Bare expression: shows the submitted job in the notebook output.
spec_dec_job

## Download models

In [None]:
# Hugging Face repo id of the base model downloaded below.
base_model_name = "nvidia/Llama-3.1-8B-Instruct-FP8"
# Presumably an alternative, ready-made draft model repo — left disabled.
# draft_model_name = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"

In [None]:
from huggingface_hub import snapshot_download


# Both directories sit under ./models, which is uploaded as one model asset later.
base_model_dir = "./models/base"
draft_model_dir = "./models/draft"

# Download the base model repo into base_model_dir.
snapshot_download(repo_id=base_model_name, local_dir=base_model_dir)
# snapshot_download(repo_id=draft_model_name, local_dir=draft_model_dir)

In [None]:
# Block until the pipeline job has finished (streaming its logs meanwhile);
# the job was only submitted above, and its outputs are not available for
# download until it completes.
ml_client.jobs.stream(spec_dec_job.name)

# NOTE(review): confirm the on-disk layout under download_path — the config
# patch that follows expects config.json directly inside draft_model_dir.
ml_client.jobs.download(
    name=spec_dec_job.name,
    output_name="output_model",
    download_path=draft_model_dir,
    all=True,
)

## Change config

In [None]:
import json


draft_config_file = f"{draft_model_dir}/config.json"

# Read inside a context manager so the handle is closed before the file is
# rewritten below.
with open(draft_config_file, encoding="utf-8") as f:
    draft_config = json.load(f)

# Override the maximum position embeddings and set llama3-type RoPE scaling;
# every other key from the downloaded config is kept.
draft_config = {
    **draft_config,
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "factor": 8,
        "high_freq_factor": 4,
        "low_freq_factor": 1,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
}

with open(draft_config_file, "w", encoding="utf-8") as f:
    json.dump(draft_config, f, indent=4)

## Upload model as a whole

In [None]:
from azure.ai.ml.entities import Model


# Package the whole ./models folder (base and draft subfolders) as a single
# model asset with a fixed name and version.
model = Model(
    path="./models", # Path to your model files
    name="llama-3-1-speculative",
    version="1"
)
# Create or update the model asset; as the last expression in the cell, the
# result is shown in the notebook output.
ml_client.models.create_or_update(model)

## Create environment

In [None]:
from azure.ai.ml.entities import Environment, BuildContext


# Environment defined by the build context located in ./environment.
env = Environment(
    build=BuildContext(path="./environment"),
    name="speculative-env",
    description="Environment for speculative decoding inference using sglang.",
)

# Create or update the environment; the result is shown in the notebook output.
ml_client.environments.create_or_update(env)

## Create an online endpoint

In [None]:
# Name of the managed online endpoint that is created and invoked below.
endpoint_name = "llama-3-1-speculative-endpoint"

In [None]:
from azure.ai.ml.entities import ManagedOnlineEndpoint


# Endpoint with key-based authentication; use "aml_token" instead for
# token-based authentication.
endpoint = ManagedOnlineEndpoint(name=endpoint_name, auth_mode="key")

# Start the create/update operation and block until it has finished.
endpoint_poller = ml_client.online_endpoints.begin_create_or_update(endpoint)
endpoint_poller.wait()

In [None]:
# Name of the deployment created under the endpoint and targeted when invoking it.
deployment_name = "llama-3-1-speculative-deployment"

In [None]:
from azure.ai.ml.entities import ManagedOnlineDeployment


# Single-instance deployment of the model and environment defined above.
deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=model,
    environment=env,
    # instance_type="Standard_NC40ads_H100_v5",
    instance_type="STANDARD_ND96ISRF_H100_V5",
    instance_count=1,
)

# Start the create/update operation and block until it has finished.
deployment_poller = ml_client.online_deployments.begin_create_or_update(deployment)
deployment_poller.wait()

## Invoke endpoint

In [None]:
# File the sample request is written to and later passed to the endpoint invocation.
sample_request_file = "./sample_request.json"

In [None]:
# Request body with a system and a user message plus sampling settings.
sample_payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather like today?"},
    ],
    "temperature": 0.7,
    "max_tokens": 300,
}

# Persist the payload; the invoke call reads the request from this file.
with open(sample_request_file, "w") as request_out:
    json.dump(sample_payload, request_out, indent=4)

In [None]:
# Send the saved request file to the named deployment of the endpoint and
# print whatever the invocation returns.
response = ml_client.online_endpoints.invoke(
    endpoint_name=endpoint_name,
    deployment_name=deployment_name,
    request_file=sample_request_file,
)
print(response)