# Azure Reinforcement Learning (GRPO) with Speculative Decoding

This notebook demonstrates the complete RL training and speculative decoding workflow at a high level. Most implementation details are deferred to `rl_spec_dec_utils.py`.

## 1. Setup Workspace

In [1]:
from rl_spec_dec_utils import (
    setup_workspace,
    run_rl_training_pipeline,
    run_draft_model_pipeline,
    prepare_combined_model_for_deployment,
    deploy_speculative_decoding_endpoint,
    test_deployment,
    download_and_register_hf_model,
)

ml_client, registry_ml_client = setup_workspace()

Found the config file in: config.json
Class DeploymentTemplateOperations: This is an experimental class, and may change at any time. Please see https://aka.ms/azuremlexperimental for more information.
Overriding of current TracerProvider is not allowed
Overriding of current LoggerProvider is not allowed
Overriding of current MeterProvider is not allowed
Attempting to instrument while already instrumented
Attempting to instrument while already instrumented
Attempting to instrument while already instrumented


Workspace setup complete, connected
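
The `Found the config file in: config.json` message suggests that the helper reads a workspace config file, as `MLClient.from_config` does. The internals of `setup_workspace` are not shown in this notebook, so the following is only a minimal sketch of a pre-flight check one could run before connecting. The function name `load_workspace_config` is hypothetical; the three keys are the ones `MLClient.from_config` expects.

```python
import json
from pathlib import Path

# Keys that MLClient.from_config reads from config.json.
REQUIRED_KEYS = ("subscription_id", "resource_group", "workspace_name")


def load_workspace_config(path="config.json"):
    """Hypothetical helper: read config.json and fail early on missing keys."""
    config = json.loads(Path(path).read_text(encoding="utf-8"))
    missing = [key for key in REQUIRED_KEYS if not config.get(key)]
    if missing:
        raise ValueError(f"{path} is missing required keys: {', '.join(missing)}")
    return {key: config[key] for key in REQUIRED_KEYS}
```

Running a check like this first turns a vague authentication or lookup failure into an explicit message about which key is absent.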


## 2. Run RL Training Pipeline (GRPO)

In [None]:
# Run complete RL training pipeline: verify datasets, register data, train model, register model
rl_job, status, registered_model = run_rl_training_pipeline(
    base_model_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    compute_cluster="k8s-a100-compute",
    config={
        "num_nodes_finetune": 1,
        "trainer_total_epochs": 0,
    },
)
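
For readers new to GRPO, its core idea is to score each sampled completion relative to the other completions drawn for the same prompt, instead of using a learned value function. The snippet below is a minimal sketch of that group-relative advantage step only. It is not the code that `run_rl_training_pipeline` runs; the function name is hypothetical, and the choice of population standard deviation and `eps` is an assumption (training libraries differ on both).

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one group of completions for the same prompt.

    Assumption: population standard deviation plus a small eps for stability.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions for one prompt, rewarded 1 when the final answer is correct.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
# Correct completions get a positive advantage, incorrect ones a negative one.
```

When every completion in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal.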

In [2]:
import os
import time
import uuid
import json
import shutil
import requests
from pathlib import Path
from huggingface_hub import snapshot_download
from azure.ai.ml import MLClient, Input, dsl
from azure.ai.ml.constants import AssetTypes
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.entities import (
    Model,
    KubernetesOnlineEndpoint,
    KubernetesOnlineDeployment,
    ProbeSettings,
    OnlineRequestSettings,
    Environment,
    BuildContext,
)

In [None]:
hf_model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
azureml_model_name = "qwen-7b-base"

model = download_and_register_hf_model(
    hf_model_id=hf_model_id,
    azureml_model_name=azureml_model_name,
)

## 2.1. Deploy Base and RL Models
With the Qwen-7B base and fine-tuned models registered, we now deploy them as Kubernetes online endpoints.

In [3]:
ft_model = ml_client.models.get(name="grpo-finqa-3a436a81", version="1")

In [None]:
#--------ENVIRONMENT--------
# Build the serving image from ./grpoenv; all probe and scoring routes use port 8000.
environment_definition = Environment(
    name="vllm-env",
    build=BuildContext(path="./grpoenv"),
    inference_config={
        "liveness_route": {"port": 8000, "path": "/health"},
        "readiness_route": {"port": 8000, "path": "/health"},
        "scoring_route": {"port": 8000, "path": "/"},
    },
)
environment = ml_client.environments.create_or_update(environment_definition)

Uploading grpoenv (0.0 MBs): 100%|██████████| 38/38 [00:00<00:00, 1254.75it/s]



In [None]:
#--------ENDPOINT--------
endpoint_name = f"{ft_model.name}-endpoint"
endpoint = KubernetesOnlineEndpoint(
    name=endpoint_name,
    auth_mode="key",
    compute="k8s-a100-compute",
)
ml_client.online_endpoints.begin_create_or_update(endpoint).wait()

#--------DEPLOYMENT--------
model = ft_model
deployment_name = "deployment-2"
environment_vars = {
    "MODEL_PATH": "/model/model_output",
}
probe_settings = ProbeSettings(
    initial_delay=600,
    period=10,
    timeout=2,
    success_threshold=1,
    failure_threshold=30,
)
deployment = KubernetesOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=model,
    instance_type="monogpu",
    model_mount_path="/model",
    instance_count=1,
    environment=environment,
    environment_variables=environment_vars,  # pass MODEL_PATH to the container
    liveness_probe=probe_settings,
    readiness_probe=probe_settings,
    request_settings=OnlineRequestSettings(
        request_timeout_ms=90000,
        max_concurrent_requests_per_instance=4,
    ),
)
ml_client.online_deployments.begin_create_or_update(deployment).wait()
endpoint.traffic = {deployment_name: 100}
ml_client.online_endpoints.begin_create_or_update(endpoint).wait()

Check: endpoint grpo-finqa-3a436a81-endpoint exists
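
The probe settings above determine how long the container may take to load the model before it is treated as unhealthy. As a rough sketch, assuming the usual Kubernetes probe semantics (wait `initial_delay`, then probe every `period` seconds until `failure_threshold` consecutive failures), the startup budget can be estimated as follows. The helper name is hypothetical, and the estimate ignores the per-probe `timeout`.

```python
def probe_failure_budget_seconds(initial_delay, period, failure_threshold):
    """Approximate time before a never-healthy container is declared failed."""
    return initial_delay + period * failure_threshold


budget = probe_failure_budget_seconds(initial_delay=600, period=10, failure_threshold=30)
print(f"{budget} s (~{budget / 60:.0f} min)")  # 900 s (~15 min)
```

If model loading on the target GPU takes longer than this, raising `initial_delay` or `failure_threshold` is the first thing to try.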



## 3. Create Draft Model for Speculative Decoding

In [None]:
# Train EAGLE3 draft model for speculative decoding
draft_job, draft_status = run_draft_model_pipeline(
    ml_client=ml_client,
    registry_ml_client=registry_ml_client,
    compute_cluster="shj-a100",
    num_epochs=1,
    monitor=False,  # Set to True to wait for completion
)
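
To make the role of the draft model concrete, here is a toy sketch of the generic draft-then-verify loop behind speculative decoding, using greedy decoding over integer tokens. It is not EAGLE3 (which drafts from the target model's internal features) and not the code in `rl_spec_dec_utils.py`; every function name here is hypothetical. In a real system the verification of all drafted positions happens in one batched forward pass of the target model, which is where the speedup comes from.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]


def greedy_decode(target: NextToken, prompt: List[int], max_new: int) -> List[int]:
    """Baseline: one target call per generated token."""
    seq = list(prompt)
    for _ in range(max_new):
        seq.append(target(seq))
    return seq


def speculative_decode(
    target: NextToken, draft: NextToken, prompt: List[int], max_new: int, k: int = 4
) -> Tuple[List[int], int]:
    """Draft up to k tokens, keep the prefix the target agrees with."""
    seq = list(prompt)
    limit = len(prompt) + max_new
    verify_passes = 0
    while len(seq) < limit:
        # 1) The cheap draft model proposes a short continuation.
        proposal: List[int] = []
        for _ in range(min(k, limit - len(seq))):
            proposal.append(draft(seq + proposal))
        # 2) The target checks every drafted position (one pass in practice).
        verify_passes += 1
        for i, token in enumerate(proposal):
            expected = target(seq + proposal[:i])
            if token != expected:
                # Keep the agreed prefix, then the target's own token.
                seq.extend(proposal[:i])
                seq.append(expected)
                break
        else:
            # Everything accepted; the same pass also yields one extra token.
            seq.extend(proposal)
            if len(seq) < limit:
                seq.append(target(seq))
    return seq, verify_passes


def toy_target(seq: List[int]) -> int:
    return (sum(seq) * 31 + len(seq)) % 50


def toy_draft(seq: List[int]) -> int:
    # Agrees with the target except at every fifth position.
    token = toy_target(seq)
    return (token + 1) % 50 if len(seq) % 5 == 0 else token


prompt = [1, 2, 3]
reference = greedy_decode(toy_target, prompt, max_new=20)
output, passes = speculative_decode(toy_target, toy_draft, prompt, max_new=20, k=4)
```

The output matches plain greedy decoding token for token, while the number of verification passes is smaller than the number of generated tokens. The better the draft model agrees with the target, the fewer passes are needed.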

## 4. Prepare Combined Model for Deployment

In [None]:
# Download draft model, download base model, combine and register for deployment
combined_model = prepare_combined_model_for_deployment(
    ml_client=ml_client,
    draft_job_name=draft_job.name,
    base_model_hf_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    model_name="grpo-speculative-decoding",
)

## 5. Deploy Speculative Decoding Endpoint

In [None]:
# Deploy managed online endpoint with speculative decoding
endpoint_name = deploy_speculative_decoding_endpoint(
    ml_client=ml_client,
    combined_model=combined_model,
    instance_type="monogpu",
    compute_name="shj-a100"
)

## 6. Test Deployment

In [None]:
# Test the deployed endpoint with a financial reasoning question
result = test_deployment(ml_client, endpoint_name)
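
What `test_deployment` sends is defined in `rl_spec_dec_utils.py` and is not shown here. As a hedged sketch, if the endpoint exposes an OpenAI-compatible chat API (an assumption, as vLLM servers commonly do), a request could be assembled as below. The function name, the model name and the `/v1/chat/completions` route in the usage comment are all hypothetical. The `Authorization: Bearer <key>` header is the standard form for key-authenticated Azure ML online endpoints.

```python
import json


def build_chat_request(question, api_key, model_name, max_tokens=512):
    """Hypothetical helper: headers and JSON body for a chat-style request."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model_name,  # assumption: the name the server registered
        "messages": [{"role": "user", "content": question}],
        "max_tokens": max_tokens,
        "temperature": 0.0,  # deterministic output makes comparisons easier
    }
    return headers, json.dumps(payload)


headers, body = build_chat_request(
    question="Revenue grew from 120 to 150. What is the percentage increase?",
    api_key="<endpoint-key>",
    model_name="<served-model-name>",
)
# Usage (assumed route): requests.post(f"{scoring_uri}/v1/chat/completions",
#                                      headers=headers, data=body, timeout=90)
```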

## 7. Cleanup (Optional)

In [None]:
# Uncomment to delete endpoint and free up resources
# ml_client.online_endpoints.begin_delete(name=endpoint_name).wait()
# print(f"✓ Endpoint deleted: {endpoint_name}")

## Summary

This simplified notebook demonstrates the complete workflow:

1. ✅ **Setup**: Connected to Azure ML workspace and registry
2. ✅ **RL Training**: Trained GRPO model on FinQA dataset  
3. ✅ **Draft Model**: Created EAGLE3 draft model for speculative decoding
4. ✅ **Model Preparation**: Combined base and draft models
5. ✅ **Deployment**: Deployed speculative decoding endpoint
6. ✅ **Testing**: Queried the endpoint with a financial reasoning question; speculative decoding targets 2-3x faster inference
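
One way to check an inference speedup claim is to compare generation throughput for the same prompts against a baseline endpoint and the speculative decoding endpoint. The sketch below shows only the arithmetic; the helper names are hypothetical and the numbers are placeholders, not measurements from this notebook.

```python
def tokens_per_second(num_tokens, seconds):
    """Throughput of one generation request."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return num_tokens / seconds


def speedup(baseline_tps, speculative_tps):
    """Ratio above 1.0 means the speculative endpoint is faster."""
    return speculative_tps / baseline_tps


# Placeholder timings for illustration only (not measured results).
baseline = tokens_per_second(num_tokens=256, seconds=8.0)
speculative = tokens_per_second(num_tokens=256, seconds=3.2)
ratio = speedup(baseline, speculative)
```

Averaging over several prompts and using the same `max_tokens` and temperature on both endpoints gives a fairer comparison than a single request.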
