In [1]:
from rich.pretty import pprint

from sagemaker.ai_registry.air_constants import REWARD_FUNCTION, REWARD_PROMPT
from sagemaker.ai_registry.dataset import DataSet, CustomizationTechnique
from sagemaker.ai_registry.evaluator import Evaluator

## DataSets

#### Create
- The expected DataSet input format depends on the customization technique (e.g. SFT, RLVR).
- If no customization technique is provided, client-side validation is skipped.
- Provide a source: either a local file path or an S3 URI.

In [2]:

# 1. Register a dataset whose source is an existing S3 object.
#    Passing customization_technique (here SFT) turns on client-side validation.
sft_source_uri = "s3://sdk-air-test-bucket/datasets/training-data/jamjee-sft-ds1.jsonl"
dataset = DataSet.create(
    name="sdkv3-gen-ds2",
    source=sft_source_uri,
    customization_technique=CustomizationTechnique.SFT,
)

Output()

In [3]:
# 2. Register a dataset from a local file (RLVR format).
# ------------------------------------
# To remove this line post testing/dogfooding : Sample source https://quip-amazon.com/hXbKA1U0aKTL/Model-Customisation-Bug-Bash#temp:s:temp:C:bYf1df6d6a2346e4fea8eb89d6c9;temp:C:bYf4ecae019198f4eb8940daf7f8
# Download the dataset from the link above and set RLVR_DATASET_SOURCE to its local path,
# or upload the file to an accessible S3 location and set RLVR_DATASET_SOURCE to that S3 URI.
# Either form is passed to DataSet.create via the `source` argument.
# NOTE(review): the default below is a machine-specific absolute path — TODO adjust before running.
RLVR_DATASET_SOURCE = "/Volumes/workplace/sagemaker-python-sdk-staging/recipes-data/rlvr/train_256.jsonl"

dataset = DataSet.create(
    name="my-rlvr-ds1",
    source=RLVR_DATASET_SOURCE,
    customization_technique=CustomizationTechnique.RLVR,
)

Output()

In [4]:
# Re-fetch this dataset's status from the hub, then display it.
dataset.refresh()
pprint(dataset)

In [5]:
# Show every version currently registered for this dataset.
versions = dataset.get_versions()
pprint(versions)

In [6]:
# Delete a single version of this dataset.
# NOTE(review): "0.0.4" is hard-coded from an earlier run — replace it with one of the
# versions printed by `dataset.get_versions()` above. Re-running that cell afterwards
# should no longer list the deleted version.
dataset.delete(version="0.0.4")

True

In [None]:
# Called without a `version` argument, delete() removes every version of this dataset.
dataset.delete()

#### List DataSets

In [8]:
# List registered datasets. `max_results` is optional (pagination); omit it to use
# the default config.
# The loop variable is deliberately NOT called `dataset`, so listing does not
# overwrite the dataset this notebook is working with.
datasets = DataSet.get_all(max_results=2)
for listed_dataset in datasets:
    pprint(listed_dataset)

#### Use an existing DataSet

In [None]:
# Take the first dataset returned by get_all() and display it.
# iter() lets this work whether get_all() returns an iterator or a list.
# Raises StopIteration if no datasets are registered.
dataset = next(iter(DataSet.get_all(max_results=2)))
pprint(dataset)

In [4]:
# Look up an existing dataset by its name and display it.
dataset = DataSet.get(name="sdkv3-gen-ds2")
pprint(dataset)

# The returned DataSet supports CRUD operations,
# e.g. dataset.delete()

In [None]:
# Register a new version of this dataset from another S3 source.
new_version_source = "s3://sdk-air-test-bucket/datasets/test_ds"
dataset.create_version(source=new_version_source)

In [None]:
# List versions again; the version created above should now appear.
versions = dataset.get_versions()
pprint(versions)

## Evaluators

In [None]:
# Method: Lambda — register a reward function backed by an existing AWS Lambda function ARN.
reward_lambda_arn = "arn:aws:lambda:us-west-2:052150106756:function:sm-eval-vinayshm-rlvr-llama-321b-instruct-v1-1762713051528"
evaluator = Evaluator.create(
    name="sdk-new-rf11",
    source=reward_lambda_arn,
    type=REWARD_FUNCTION,
)

In [None]:
# Method: BYOC — register a reward function from a local Python source file.
# NOTE(review): machine-specific absolute path; adjust before running.
byoc_source_path = "/Volumes/workplace/sagemaker-python-sdk-staging/recipes-data/eval_lambda_1.py"
evaluator = Evaluator.create(
    name="eval-lambda-test",
    source=byoc_source_path,
    type=REWARD_FUNCTION,
)

In [9]:
# Reward Prompt
# ------------------------------------
# To remove this line post testing/dogfooding : Sample source https://quip-amazon.com/hXbKA1U0aKTL/Model-Customisation-Bug-Bash#temp:s:temp:C:bYf5c2e9e77efea4868b0420892a;temp:C:bYf4ecae019198f4eb8940daf7f8
# Download the prompt from the link above and set REWARD_PROMPT_SOURCE to its local path,
# or upload the file to an accessible S3 location and set REWARD_PROMPT_SOURCE to that S3 URI.
# Either form is passed to Evaluator.create via the `source` argument.
# NOTE(review): the default below is a machine-specific absolute path — TODO adjust before running.
REWARD_PROMPT_SOURCE = "/Users/jamjee/workplace/hubpuller/prompt/custom_prompt.jinja"

evaluator = Evaluator.create(
    name="jamj-rp2",
    source=REWARD_PROMPT_SOURCE,
    type=REWARD_PROMPT,
)

Output()

In [12]:
# Optional explicit wait — create() already waits by default (wait=True).
evaluator.wait()

Output()

In [11]:
# Re-fetch this evaluator's status and display it.
evaluator.refresh()
pprint(evaluator)

In [10]:
# List registered evaluators; `max_results` is optional (pagination).
# Use a distinct loop variable so the `evaluator` created above is not overwritten —
# otherwise a later evaluator.refresh()/wait() silently acts on the last listed item.
evaluators = Evaluator.get_all(max_results=2)
for listed_evaluator in evaluators:
    pprint(listed_evaluator)

In [None]:
# List only reward-prompt evaluators. Filter with the imported REWARD_PROMPT constant
# (the same value passed to Evaluator.create) instead of a hand-typed string.
# NOTE(review): assumes REWARD_PROMPT == "RewardPrompt" — confirm in air_constants.
# A distinct loop variable keeps the notebook's `evaluator` from being overwritten.
evaluators = Evaluator.get_all(type=REWARD_PROMPT, max_results=2)
for listed_evaluator in evaluators:
    pprint(listed_evaluator)

In [13]:
# Look up an existing evaluator by its name and display it.
evaluator = Evaluator.get(name="sdk-new-rf11")
pprint(evaluator)

In [14]:
# Register a new version of this evaluator, reusing its current source reference.
current_reference = evaluator.reference
evaluator.create_version(source=current_reference)

Output()

True

In [15]:
# Show every version currently registered for this evaluator.
versions = evaluator.get_versions()
pprint(versions)

In [None]:
# Delete the evaluator: pass the optional version argument to delete a single
# version, or call with no arguments to delete all versions.
evaluator.delete()