In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch
import boto3
import json

In [None]:
# Account-level settings reused by the cells below: the S3 bucket that is
# interpolated into the input URIs and the IAM role handed to the estimator.
bucket = "slip-ml"
role = "arn:aws:iam::438465160412:role/Sagemaker"

# Session object passed to the estimator and used to read the current region.
sagemaker_session = sagemaker.Session()

In [None]:
# Fetch the API key from AWS Secrets Manager; it is later exported to the
# training job as the HF_TOKEN environment variable.
secret_name = "huggingface"
region_name = "us-east-1"

session = boto3.Session()
secretsmanager = session.client("secretsmanager", region_name=region_name)

# SecretString holds a JSON document; the key is stored under "API_KEY".
get_secret_value_response = secretsmanager.get_secret_value(SecretId=secret_name)
secret = get_secret_value_response["SecretString"]
api_key = json.loads(secret)["API_KEY"]

In [None]:
# Instance type used both to pick the container image and to run the training job.
instance_type = "ml.g5.2xlarge"

In [None]:
# Resolve the PyTorch training container image for the session's region and
# the chosen instance type.
# NOTE(review): no framework `version` is passed -- confirm the installed
# SageMaker SDK resolves one for this call.
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_region_name,
    instance_type=instance_type,
    image_scope="training",
)
print(image_uri)

In [None]:
# Configure a single-instance PyTorch training job that runs
# source/finetune_llama.py inside the container image resolved above.
estimator = PyTorch(
    entry_point="finetune_llama.py",
    source_dir="source",
    role=role,
    instance_count=1,
    instance_type=instance_type,
    image_uri=image_uri,
    py_version="py310",
    # Handed to the entry point by SageMaker; the keys must match whatever
    # arguments finetune_llama.py parses (script not visible here).
    hyperparameters={
        "batch-size": 2,
        "epochs": 7,
        "lr": 3e-4,
        "project-name": "vallr-phoneme-llama",
        "bucket": bucket,
    },
    sagemaker_session=sagemaker_session,
    volume_size=100,
    # NOTE(review): values in `environment` are stored with the training job
    # definition, so the token is readable by anyone who can describe the job.
    # Consider having the entry point read it from Secrets Manager instead.
    environment={"HF_TOKEN": api_key},
)

In [None]:
# S3 prefixes for the job's input data, keyed by channel name.
input_channels = {
    "training": f"s3://{bucket}/data/vallr/train/text/",
    "test": f"s3://{bucket}/data/vallr/test/text/",
}

# Launch the training job with both channels attached.
estimator.fit(input_channels)