# Creating assets for model customization

In [None]:
from rich.pretty import pprint

from sagemaker.ai_registry.air_constants import REWARD_FUNCTION, REWARD_PROMPT
from sagemaker.ai_registry.dataset import DataSet, CustomizationTechnique
from sagemaker.ai_registry.evaluator import Evaluator

In [None]:
# Configure AWS credentials and region
#! ada credentials update --provider=isengard --account=<> --role=Admin --profile=default --once
#! aws configure set region us-west-2

## DataSets

#### Create
- The expected DataSet input format depends on the customization technique.
- If no customization technique is provided, client-side validation is skipped.
- Provide a source, which can be a local file path or an S3 URL.

In [None]:

# 1. S3 data source (a local file path is also accepted as `source`).
# Supplying `customization_technique` turns on client-side validation of the
# input format; without it, that validation is skipped.
dataset = DataSet.create(
    name="sdkv3-gen-ds2",
    source="s3://sdk-air-test-bucket/datasets/training-data/jamjee-sft-ds1.jsonl",
    # customization_technique=CustomizationTechnique.SFT,
)

In [None]:
# Pull the dataset's current status from the hub, then show its attributes.
dataset.refresh()
pprint(vars(dataset))

In [None]:
# List the versions of this dataset.
# Print the result directly, as the other get_versions() cells do: `pprint`
# renders lists and objects alike, while `.__dict__` would raise
# AttributeError if get_versions() returns a list.
versions = dataset.get_versions()
pprint(versions)

In [None]:
# Delete one specific version of the dataset.
# Replace the value below with a version taken from `versions`; once deleted,
# that version should no longer show up when listing versions.
version_to_delete = "0.0.4"
dataset.delete(version=version_to_delete)

In [None]:
# Called without a `version` argument, delete() removes every version of this
# dataset.
# NOTE(review): later cells fetch "sdkv3-gen-ds2" by name, so they presumably
# require the dataset to be re-created after this cell runs.
dataset.delete()

#### List DataSet

In [None]:
# List datasets. `max_results` is optional (for pagination); when omitted,
# the default configuration applies.
# A separate loop variable keeps the notebook's `dataset` handle from being
# overwritten with the last listed item.
datasets = DataSet.get_all(max_results=2)
for listed_dataset in datasets:
    pprint(listed_dataset)

#### Use an existing DataSet

In [None]:
# Use the first dataset yielded by a fresh get_all() iterator.
# Do not loop over `datasets` here: it is an iterator that the listing cell
# has already exhausted, and looping would also rebind `dataset`.
dataset = next(DataSet.get_all(max_results=2))
pprint(dataset.__dict__)

In [None]:
# Look up an existing dataset by its name.
# The returned object supports the usual CRUD operations,
# for example `dataset.delete()`.
dataset = DataSet.get(name="sdkv3-gen-ds2")
pprint(dataset)

In [None]:
# Add a new version to this dataset from another source
# (replace <bucket> with a real bucket name).
dataset.create_version(
    source="s3://<bucket>/datasets/test_ds",
)

In [None]:
# Show the versions of this dataset.
versions = dataset.get_versions()
pprint(
    versions
)

## Evaluator

In [None]:
# Method: Lambda — register an existing Lambda function as a reward function.
# Fill in the account ID and function name placeholders in the ARN.
evaluator = Evaluator.create(
    name="sdk-new-rf11",
    source="arn:aws:lambda:us-west-2:<account-id>:function:<function-name>",
    type=REWARD_FUNCTION,
)

In [None]:
# Method: BYOC (bring your own code) — the reward function comes from a local
# Python file.
evaluator = Evaluator.create(
    name="eval-lambda-test",
    source="/path_to_local/eval_lambda_1.py",
    type=REWARD_FUNCTION,
)

In [None]:
# Reward prompt — created from a local Jinja template file.
evaluator = Evaluator.create(
    name="jamj-rp2",
    source="/path_to_local/custom_prompt.jinja",
    type=REWARD_PROMPT,
)

In [None]:
# Optional: create() already waits by default (wait=True), so an explicit
# wait is only needed if that default was turned off.
evaluator.wait()

In [None]:
# Re-fetch the evaluator's latest state and display it.
evaluator.refresh()
pprint(
    evaluator
)

In [None]:
# List evaluators; `max_results` is optional and controls pagination.
# A separate loop variable keeps the notebook's `evaluator` handle from being
# overwritten with the last listed item.
evaluators = Evaluator.get_all(max_results=2)
for listed_evaluator in evaluators:
    pprint(listed_evaluator)

In [None]:
# List evaluators filtered by type.
# NOTE(review): confirm 'RewardPrompt' matches the value of the REWARD_PROMPT
# constant used when creating evaluators above.
# A separate loop variable keeps the notebook's `evaluator` handle intact.
evaluators = Evaluator.get_all(type='RewardPrompt', max_results=2)
for listed_evaluator in evaluators:
    pprint(listed_evaluator)

In [None]:
# Look up a single evaluator by its name and display it.
evaluator = Evaluator.get(name="sdk-new-rf11")
pprint(
    evaluator
)

In [None]:
# Create a new version of this evaluator, using its `reference` attribute as
# the source (presumably the current source location — confirm).
evaluator.create_version(
    source=evaluator.reference,
)

In [None]:
# Show the versions of this evaluator.
versions = evaluator.get_versions()
pprint(
    versions
)

In [None]:
# Delete the evaluator. Pass the optional `version` argument to remove a
# single version; without it, all versions are deleted.
evaluator.delete()