In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch
import boto3
import json

In [None]:
# Core configuration shared by the cells below.
# S3 bucket holding training data, model artifacts, and the uploaded source bundle.
bucket = 'slip-ml'
# Prefix for the training job name and for the S3 output / code locations.
project_name = 'vid-embed-image-transformer'
# IAM role assumed by the training job.
# NOTE(review): this embeds an AWS account ID; consider sourcing it from config/env before sharing the notebook.
role = 'arn:aws:iam::438465160412:role/Sagemaker'
# Default SageMaker session; its region is used below to look up the training image.
sagemaker_session = sagemaker.Session()

In [None]:
# Fetch the Hugging Face access token from AWS Secrets Manager.
# The secret's SecretString is parsed as JSON and its "API_KEY" field is used.
# Caution: never display `secret` or `api_key` in cell output.
secret_name = "huggingface"
region_name = "us-east-1"  # hardcoded; not derived from sagemaker_session's region

session = boto3.Session()
secretsmanager = session.client('secretsmanager', region_name=region_name)
get_secret_value_response = secretsmanager.get_secret_value(SecretId=secret_name)

secret = get_secret_value_response['SecretString']
api_key = json.loads(secret)["API_KEY"]

In [None]:
# GPUs per instance for the SageMaker training instance types under consideration.
training_instances_gpus = dict([
    ("ml.g5.2xlarge", 1),
    ("ml.g5.12xlarge", 4),
    ("ml.p4d.24xlarge", 8),
    ("ml.p5.48xlarge", 8),
])
# Instance type used both for the training-image lookup and for the training job.
instance_type = 'ml.p4d.24xlarge'

In [None]:
# Resolve the AWS Deep Learning Container image for PyTorch training.
# `image_uris.retrieve` validates `version` against the available PyTorch releases and
# raises ValueError when it is None and more than one release exists, so pin it explicitly.
# py_version matches the estimator's `py_version="py310"`.
# NOTE(review): 2.2 chosen as a PyTorch release with py310 training images; set it to the
# version vallr.py actually targets.
framework_version = '2.2'
image_uri = sagemaker.image_uris.retrieve(framework='pytorch',
                                          region=sagemaker_session.boto_region_name,
                                          version=framework_version,
                                          py_version='py310',
                                          instance_type=instance_type,
                                          image_scope='training')
print(image_uri)

In [None]:
# Fail fast on a missing/invalid Hugging Face token before configuring an expensive GPU job.
if not isinstance(api_key, str) or not api_key:
    raise ValueError("Hugging Face API_KEY from Secrets Manager is missing or not a non-empty string")

# Configure the SageMaker PyTorch training job.
# NOTE(review): the selected instance has multiple GPUs (see training_instances_gpus) but no
# `distribution` is configured; confirm vallr.py handles multi-GPU training itself.
estimator = PyTorch(
    entry_point="vallr.py",
    source_dir="source",  # local directory uploaded to code_location; entry_point is relative to it
    role=role,
    base_job_name=project_name,
    instance_count=1,
    instance_type=instance_type,
    image_uri=image_uri,
    py_version="py310",
    # Passed to vallr.py as command-line hyperparameters.
    hyperparameters={
        "batch-size": 2,
        "epochs": 7,
        "lr": 3e-4,
        "project-name": project_name,
    },
    sagemaker_session=sagemaker_session,
    volume_size=50,  # storage volume size in GB per training instance
    environment={"HF_TOKEN": api_key},  # exposed to the training container as an env var
    output_path=f's3://{bucket}/models/{project_name}',
    code_location=f's3://{bucket}/model-building/{project_name}',
)

In [None]:
# Launch training with two S3 input channels: "training" -> .../train/, "test" -> .../test/.
# Presumably vallr.py reads these via the SageMaker channel env vars — confirm.
estimator.fit(inputs={
    channel: f's3://{bucket}/data/vallr/{prefix}/'
    for channel, prefix in (('training', 'train'), ('test', 'test'))
})

In [None]:
# Load the pretrained ViT backbone locally.
from transformers import ViTModel, ViTConfig

vit_checkpoint = "google/vit-base-patch16-224"
vit = ViTModel.from_pretrained(vit_checkpoint)
# The loaded model already carries its ViTConfig; reuse it instead of fetching the config again.
vit_config = vit.config
assert isinstance(vit_config, ViTConfig)