In [7]:
# Import required libraries
import requests

# Llama Stack client API endpoint; the requests below are all built on top of it.
# The client connects to the remote TRL service for training.
base_url = "http://127.0.0.1:8321"

# Address of the remote TRL service that training requests are forwarded to.
remote_service_url = "http://localhost:8080"

# Header sets, one per HTTP verb used in this notebook.
headers_get = {"accept": "application/json"}  # GET: retrieving data
headers_post = {"Content-Type": "application/json"}  # POST: sending data

In [8]:
# Ask the stack which providers are registered.
# Expect a 'remote::trl' provider for post-training and 'localfs' for dataset storage.
url_providers = base_url + "/v1/providers"
response_providers = requests.get(url=url_providers, headers=headers_get)

# Decode the JSON body once, then show the providers with their configurations.
providers_listing = response_providers.json()
print(providers_listing)

{'data': [{'api': 'post_training', 'provider_id': 'trl_remote', 'provider_type': 'remote::trl', 'config': {'base_url': 'http://localhost:8080', 'timeout': 3600, 'connect_timeout': 30, 'max_retries': 3, 'retry_delay': 5, 'training_config': {'device': 'cuda', 'dpo_beta': 0.1, 'dpo_loss_type': 'sigmoid', 'use_reference_model': True, 'max_seq_length': 2048, 'gradient_checkpointing': False, 'logging_steps': 10, 'warmup_ratio': 0.1, 'weight_decay': 0.01}}, 'health': {'status': 'Not Implemented', 'message': 'Provider does not implement health check'}}, {'api': 'datasetio', 'provider_id': 'localfs', 'provider_type': 'inline::localfs', 'config': {'kvstore': {'type': 'sqlite', 'db_path': '/tmp/llama_stack_provider_trl_remote/datasetio.db'}}, 'health': {'status': 'Not Implemented', 'message': 'Provider does not implement health check'}}]}


In [10]:
# Register a MINIMAL DPO preference dataset so the remote training run stays fast.
# Endpoint used for dataset registration:
url_upload_dataset = base_url + "/v1/datasets"

# Preference examples as (prompt, chosen, rejected) triples - just 2 for speed!
preference_examples = [
    (
        "What is 2+2?",
        "2+2 equals 4. This is basic arithmetic.",
        "I don't know math.",
    ),
    (
        "What color is the sky?",
        "The sky is blue during clear weather.",
        "No idea about colors.",
    ),
]

# Dataset registration payload; the rows are supplied inline rather than by URI.
dataset_payload = {
    "dataset_id": "test-dpo-dataset-remote",
    "purpose": "post-training/messages",
    "dataset_type": "preference",
    "source": {
        "type": "rows",
        "rows": [
            {"prompt": prompt, "chosen": chosen, "rejected": rejected}
            for prompt, chosen, rejected in preference_examples
        ],
    },
    "metadata": {
        "provider_id": "localfs",  # Use local filesystem storage
        "description": "Minimal DPO dataset for fast testing",
    },
}

# POST the payload as JSON to register the dataset.
response_dataset = requests.post(
    url_upload_dataset,
    json=dataset_payload,
    headers=headers_post,
)
print("Dataset Upload Status:", response_dataset.status_code)
# Decode the body only after the status has been shown, so the status is still
# visible if the body turns out not to be JSON.
upload_body = response_dataset.json()
print("Dataset Upload Response:", upload_body)

Dataset Upload Status: 200
Dataset Upload Response: {'identifier': 'test-dpo-dataset-remote', 'provider_resource_id': 'test-dpo-dataset-remote', 'provider_id': 'localfs', 'type': 'dataset', 'owner': {'principal': '', 'attributes': {}}, 'purpose': 'post-training/messages', 'source': {'type': 'rows', 'rows': [{'prompt': 'What is 2+2?', 'chosen': '2+2 equals 4. This is basic arithmetic.', 'rejected': "I don't know math."}, {'prompt': 'What color is the sky?', 'chosen': 'The sky is blue during clear weather.', 'rejected': 'No idea about colors.'}]}, 'metadata': {'provider_id': 'localfs', 'description': 'Minimal DPO dataset for fast testing'}}


In [4]:
# Confirm the upload by listing every registered dataset.
# The listing should now contain our "test-dpo-dataset-remote" dataset.
url_datasets = base_url + "/v1/datasets"
response_datasets = requests.get(url=url_datasets, headers=headers_get)

# The uploaded dataset should appear together with all of its preference pairs.
datasets_listing = response_datasets.json()
print(datasets_listing)

{'data': [{'identifier': 'test-dpo-dataset-remote', 'provider_resource_id': 'test-dpo-dataset-remote', 'provider_id': 'localfs', 'type': 'dataset', 'purpose': 'post-training/messages', 'source': {'type': 'rows', 'rows': [{'prompt': 'What is 2+2?', 'chosen': '2+2 equals 4. This is basic arithmetic.', 'rejected': "I don't know math."}, {'prompt': 'What color is the sky?', 'chosen': 'The sky is blue during clear weather.', 'rejected': 'No idea about colors.'}]}, 'metadata': {'provider_id': 'localfs', 'description': 'Minimal DPO dataset for fast testing'}}]}


In [None]:
# Endpoint that starts a preference-optimization (DPO) post-training job.
url_train_model = base_url + "/v1/post-training/preference-optimize"

# PPO-style settings: not used for DPO, so they are all zeroed.
# The DPO settings themselves are already in the run.yaml config:
#   "beta": 0.1            - controls the strength of the DPO loss
#   "loss_type": "sigmoid" - specifies the loss function used for DPO
dpo_algorithm_config = {
    "reward_scale": 0.0,
    "reward_clip": 0.0,
    "epsilon": 0.0,
    "gamma": 0.0,
}

dpo_optimizer_config = {
    "optimizer_type": "adamw",  # adaptive learning with weight decay
    "weight_decay": 0.01,  # prevents overfitting
    "num_warmup_steps": 0,
    "lr": 5e-5,  # controls how big of a step the optimizer takes
    "warmup_ratio": 0.1,  # controls how long the optimizer warms up for
}

dpo_data_config = {
    "data_format": "instruct",  # instruct format is used for the preference dataset
    "dataset_id": "test-dpo-dataset-remote",  # the preference dataset uploaded earlier
    "batch_size": 2,
    "train_split_percentage": 0.9,  # split percentage for training and validation
    "shuffle": True,  # shuffle the dataset
}

dpo_training_config = {
    "n_epochs": 3,  # number of epochs to train for
    "max_steps_per_epoch": 50,  # maximum number of steps per epoch
    "gradient_accumulation_steps": 1,  # number of gradient accumulation steps
    "optimizer_config": dpo_optimizer_config,
    "data_config": dpo_data_config,
}

# Full request body for the training job.
train_model_data = {
    "job_uuid": "2",
    # smaller option: distilgpt2; larger option: ibm-granite/granite-3.3-2b-instruct
    "finetuned_model": "ibm-granite/granite-3.3-2b-instruct",
    "algorithm_config": dpo_algorithm_config,
    "training_config": dpo_training_config,
    "hyperparam_search_config": {},
    "logger_config": {},
}

# Submit the training job; the request body is sent as JSON.
response_train_model = requests.post(
    url_train_model,
    json=train_model_data,
    headers=headers_post,
)
print("Train Model Status:", response_train_model.status_code)
# Decode the body only after the status has been shown.
train_body = response_train_model.json()
print("Train Model Response:", train_body)

Train Model Status: 200
Train Model Response: {'job_uuid': '1'}


In [None]:
# List every post-training job that has been submitted to the system.
url_post_training_jobs = base_url + "/v1/post-training/jobs"
response_post_training_jobs = requests.get(
    url=url_post_training_jobs,
    headers=headers_get,
)

# Show all jobs with their current status and metadata.
jobs_listing = response_post_training_jobs.json()
print(jobs_listing)

In [None]:
# Check the status of the training job submitted by the training request.
# The job UUID is taken from the server's reply to that request, so this cell
# (and the artifacts cell that reuses job_uuid) queries the job that was
# actually created instead of a hard-coded placeholder that can drift out of sync.
# To inspect a different job, assign its UUID to job_uuid by hand instead.
job_uuid = response_train_model.json()["job_uuid"]
url_job_status = f"{base_url}/v1/post-training/job/status?job_uuid={job_uuid}"

response_job_status = requests.get(url_job_status, headers=headers_get)

print("Job Status:", response_job_status.status_code)
# The response will include: status, scheduled_at, started_at, completed_at, checkpoints
print("Job Status Response:", response_job_status.json())

In [None]:
# Fetch the artifacts (checkpoints, metrics) of a completed training job.
# Uses the same job_uuid as the status check.
url_job_artifacts = base_url + "/v1/post-training/job/artifacts?job_uuid=" + str(job_uuid)
response_job_artifacts = requests.get(url=url_job_artifacts, headers=headers_get)

print("Job Artifacts Status:", response_job_artifacts.status_code)
# The response will include checkpoint information: identifier, path, epoch, training_metrics
artifacts_body = response_job_artifacts.json()
print("Job Artifacts Response:", artifacts_body)