In [87]:
# Import required libraries
import requests

# Base URL for the Llama Stack client API
# The client connects to the remote TRL service for training
base_url = "http://127.0.0.1:8321"

# Remote TRL service runs on http://localhost:8080
# Client forwards training requests to the remote service
remote_service_url = "http://localhost:8080"

# Headers for GET requests (retrieving data)
headers_get = {
    "accept": "application/json"
}

# Headers for POST requests (sending data)
headers_post = {
    "Content-Type": "application/json"
}

In [88]:
# Get the list of available providers
# This will show us what services are available (remote TRL for post-training, localfs for datasets, etc.)

url_providers = f"{base_url}/v1/providers"
response_providers = requests.get(url_providers, headers=headers_get)

# Display the providers and their configurations
# You should see 'remote::trl' provider for post-training and 'localfs' for dataset storage
print(response_providers.json())

{'data': [{'api': 'post_training', 'provider_id': 'trl_remote', 'provider_type': 'remote::trl', 'config': {'base_url': 'http://localhost:8080', 'timeout': 3600, 'connect_timeout': 30, 'max_retries': 3, 'retry_delay': 5, 'training_config': {'device': 'cuda', 'dpo_beta': 0.1, 'use_reference_model': True, 'max_seq_length': 2048, 'gradient_checkpointing': False, 'logging_steps': 10, 'warmup_ratio': 0.1, 'weight_decay': 0.01}}, 'health': {'status': 'Not Implemented', 'message': 'Provider does not implement health check'}}, {'api': 'datasetio', 'provider_id': 'localfs', 'provider_type': 'inline::localfs', 'config': {'kvstore': {'type': 'sqlite', 'db_path': '/tmp/llama_stack_provider_trl_remote/datasetio.db'}}, 'health': {'status': 'Not Implemented', 'message': 'Provider does not implement health check'}}]}


In [89]:
# List all available datasets in the system
# This will show existing datasets that can be used for training

url_datasets = f"{base_url}/v1/datasets"
response_datasets = requests.get(url_datasets, headers=headers_get)

# Display the datasets - each dataset should have a purpose (e.g., 'post-training/messages')
# and a source containing the training data
print(response_datasets.json())

{'data': [{'identifier': 'dataset-181c39ce-d135-48dc-86d8-158cfe7d231b', 'provider_resource_id': 'dataset-181c39ce-d135-48dc-86d8-158cfe7d231b', 'provider_id': 'localfs', 'type': 'dataset', 'purpose': 'post-training/messages', 'source': {'type': 'rows', 'rows': [{'prompt': 'What is 2+2?', 'chosen': '2+2 equals 4. This is basic arithmetic.', 'rejected': 'I dont know math.'}, {'prompt': 'What is the capital of France?', 'chosen': 'Paris is the capital city of France.', 'rejected': 'Dunno.'}, {'prompt': 'What is artificial intelligence?', 'chosen': 'AI is the simulation of human intelligence by machines.', 'rejected': 'No idea what that is.'}, {'prompt': 'What color is the sky?', 'chosen': 'The sky appears blue during clear weather.', 'rejected': 'I dont know colors.'}, {'prompt': 'What is the largest planet?', 'chosen': 'Jupiter is the largest planet in our solar system.', 'rejected': 'Not sure about planets.'}, {'prompt': 'Who wrote Hamlet?', 'chosen': 'William Shakespeare wrote Hamlet.

In [96]:
# Upload a MINIMAL DPO dataset for fast remote training
# This creates a tiny preference dataset for rapid testing

url_upload_dataset = f"{base_url}/v1/datasets"

# Define the minimal dataset payload - just 2 examples for speed!
dataset_payload = {
    "dataset_id": "test-dpo-dataset-remote-final",
    "purpose": "post-training/messages",             
    "dataset_type": "preference",                    
    "source": {
        "type": "rows",                              
        "rows": [
            {
                "prompt": "What is 2+2?",
                "chosen": "2+2 equals 4. This is basic arithmetic.",
                "rejected": "I don't know math."
            },
            {
                "prompt": "What color is the sky?",
                "chosen": "The sky is blue during clear weather.",
                "rejected": "No idea about colors."
            }
        ]
    },
    "metadata": {
        "provider_id": "localfs",                    # Use local filesystem storage
        "description": "Minimal DPO dataset for fast testing"
    }
}

# Send the POST request to upload the dataset
response_dataset = requests.post(url_upload_dataset, headers=headers_post, json=dataset_payload)
print("Dataset Upload Status:", response_dataset.status_code)
print("Dataset Upload Response:", response_dataset.json())

Dataset Upload Status: 200
Dataset Upload Response: {'identifier': 'test-dpo-dataset-remote-final', 'provider_resource_id': 'test-dpo-dataset-remote-final', 'provider_id': 'localfs', 'type': 'dataset', 'owner': {'principal': '', 'attributes': {}}, 'purpose': 'post-training/messages', 'source': {'type': 'rows', 'rows': [{'prompt': 'What is 2+2?', 'chosen': '2+2 equals 4. This is basic arithmetic.', 'rejected': "I don't know math."}, {'prompt': 'What color is the sky?', 'chosen': 'The sky is blue during clear weather.', 'rejected': 'No idea about colors.'}]}, 'metadata': {'provider_id': 'localfs', 'description': 'Minimal DPO dataset for fast testing'}}


In [91]:
# Verify that our dataset was successfully uploaded
# This should now show our "test-dpo-dataset-inline-large" dataset

url_datasets = f"{base_url}/v1/datasets"
response_datasets = requests.get(url_datasets, headers=headers_get)

# The response should include our uploaded dataset with all the preference pairs
print(response_datasets.json())

{'data': [{'identifier': 'dataset-181c39ce-d135-48dc-86d8-158cfe7d231b', 'provider_resource_id': 'dataset-181c39ce-d135-48dc-86d8-158cfe7d231b', 'provider_id': 'localfs', 'type': 'dataset', 'purpose': 'post-training/messages', 'source': {'type': 'rows', 'rows': [{'prompt': 'What is 2+2?', 'chosen': '2+2 equals 4. This is basic arithmetic.', 'rejected': 'I dont know math.'}, {'prompt': 'What is the capital of France?', 'chosen': 'Paris is the capital city of France.', 'rejected': 'Dunno.'}, {'prompt': 'What is artificial intelligence?', 'chosen': 'AI is the simulation of human intelligence by machines.', 'rejected': 'No idea what that is.'}, {'prompt': 'What color is the sky?', 'chosen': 'The sky appears blue during clear weather.', 'rejected': 'I dont know colors.'}, {'prompt': 'What is the largest planet?', 'chosen': 'Jupiter is the largest planet in our solar system.', 'rejected': 'Not sure about planets.'}, {'prompt': 'Who wrote Hamlet?', 'chosen': 'William Shakespeare wrote Hamlet.

In [99]:
# Submit FAST DPO training job to remote TRL service 🚀
# MINIMAL CONFIG: Completes in ~30 seconds with 8-GPU FSDP!

url_train_model = f"{base_url}/v1/post-training/preference-optimize"

train_model_data = {
    "job_uuid": "remote-dpo-fast-test-granite-final-final",  # Unique job ID
    "model": "ibm-granite/granite-3.3-2b-base",  # Small 82M parameter model (fast to load)
    "finetuned_model": "dpo-granite-3.3-2b-base-fast-final-final",
    "checkpoint_dir": "./checkpoints",  # Save to absolute /checkpoints directory
    # NOTE: Client requires LoRA format but remote service converts to DPO
    "algorithm_config": {
        "type": "LoRA",
        "lora_attn_modules": ["attn"],
        "apply_lora_to_mlp": False,
        "apply_lora_to_output": False,
        "rank": 8,   # Smaller rank = faster training
        "alpha": 16  # Smaller alpha
    },
    "training_config": {    
        "n_epochs": 1,                    # Just 1 epoch
        "max_steps_per_epoch": 2,         # Only 2 steps total! 
        "gradient_accumulation_steps": 1, # No accumulation
        "data_config": {
            "dataset_id": "test-dpo-dataset-remote-final",
            "batch_size": 1,              # Smallest batch size 
            "shuffle": False,             # Skip shuffling for speed
            "data_format": "instruct"
        },
        "optimizer_config": {
            "optimizer_type": "adamw",
            "lr": 1e-4,                   # Higher LR for faster convergence
            "lr_scheduler_type": "constant", # No scheduling overhead
            "warmup_steps": 0,            # No warmup steps
            "weight_decay": 0.0,          # No weight decay
            "num_warmup_steps": 0
        }
    },
    "hyperparam_search_config": {},
    "logger_config": {}
}

# This should complete in ~30 seconds: Client -> Llama Stack -> Remote TRL Service
response_train_model = requests.post(url_train_model, headers=headers_post, json=train_model_data)
print("Training Status:", response_train_model.status_code)
print("Training Response:", response_train_model.json())

Training Status: 200
Training Response: {'job_uuid': 'remote-dpo-fast-test-granite-final-final'}


In [93]:
# Get a list of all post-training jobs
# This will show all training jobs that have been submitted to the system

url_post_training_jobs = f"{base_url}/v1/post-training/jobs"
response_post_training_jobs = requests.get(url_post_training_jobs, headers=headers_get)

# Display all jobs with their current status and metadata
print(response_post_training_jobs.json())

{'data': [{'job_uuid': 'remote-dpo-fast-test-granite-hacky'}]}


In [94]:
# Check the status of a specific training job
# Replace the job_uuid with the actual UUID from your training job

job_uuid = "remote-dpo-fast-test-granite-hacky"  # The job UUID from the remote training request
url_job_status = f"{base_url}/v1/post-training/job/status?job_uuid={job_uuid}"

response_job_status = requests.get(url_job_status, headers=headers_get)

print("Job Status:", response_job_status.status_code)
# The response will include: status, scheduled_at, started_at, completed_at, checkpoints
print("Job Status Response:", response_job_status.json())

Job Status: 200
Job Status Response: {'job_uuid': 'remote-dpo-fast-test-granite-hacky', 'status': 'in_progress', 'scheduled_at': '2025-06-23T03:53:19.283760Z', 'started_at': '2025-06-23T03:53:19.285018Z', 'completed_at': None, 'resources_allocated': None, 'checkpoints': []}


In [95]:
# Retrieve artifacts (checkpoints, metrics) from a completed training job
# This will show available model checkpoints and their metadata

url_job_artifacts = f"{base_url}/v1/post-training/job/artifacts?job_uuid={job_uuid}"
response_job_artifacts = requests.get(url_job_artifacts, headers=headers_get)

print("Job Artifacts Status:", response_job_artifacts.status_code)
# The response will include checkpoint information: identifier, path, epoch, training_metrics
print("Job Artifacts Response:", response_job_artifacts.json())

Job Artifacts Status: 200
Job Artifacts Response: {'job_uuid': 'remote-dpo-fast-test-granite-hacky', 'checkpoints': []}
