In [19]:
!pip install -Uqq sagemaker transformers[torch] datasets

In [29]:
import sagemaker

# Execution role of this notebook; handed to the estimator further down.
role = sagemaker.get_execution_role()

# A single session supplies both the region name and the default bucket.
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name

# The dataset lives at s3://<bucket>/<prefix>.
prefix = "sagemaker/food101"
bucket = sagemaker_session.default_bucket()  # "nico-ml-ops-course"

In [23]:
from functools import partial

from torchvision.datasets import Food101
from transformers import ViTImageProcessor

# Checkpoint whose preprocessing configuration is loaded below.
model_name_or_path = "google/vit-base-patch16-224-in21k"

# ViTImageProcessor is the supported replacement for the deprecated
# ViTFeatureExtractor alias. The variable keeps its original name so any
# other cell that refers to `feature_extractor` continues to work.
feature_extractor = ViTImageProcessor.from_pretrained(model_name_or_path)

# Callable passed to the datasets as `transform`; return_tensors="pt" asks the
# processor for PyTorch tensors.
preprocessor = partial(feature_extractor, return_tensors="pt")

# Local directory that holds the Food101 files for both splits.
dataset_root = "food101_dataset"

# Only the train split passes download=True; the test split is created
# afterwards and is expected to find the files already under `dataset_root`.
train_ds = Food101(
    root=dataset_root, split="train", transform=preprocessor, download=True
)
test_ds = Food101(root=dataset_root, split="test", transform=preprocessor)

# Class names as exposed by the dataset object.
labels = train_ds.classes

In [None]:
# Upload the local dataset directory to S3 under s3://<bucket>/<prefix>
# and keep the value returned by upload_data for reference.
inputs = sagemaker_session.upload_data(
    key_prefix=prefix,
    bucket=bucket,
    path="./food101_dataset",
)
input_message = "input spec (in this case, just an S3 path): {}".format(inputs)
print(input_message)

In [36]:
import sagemaker
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()

# Training-job settings gathered in one mapping, then expanded into the
# estimator constructor.
# NOTE(review): output_path names a fixed bucket, whereas the training input
# is addressed through `bucket` — confirm this split is intentional.
# Optional settings not currently used: source_dir, hyperparameters.
estimator_config = dict(
    entry_point="train.py",
    role=role,
    instance_count=1,
    instance_type="ml.g4dn.2xlarge",
    framework_version="2.0.0",
    py_version="py310",
    dependencies=["requirements.txt"],
    output_path="s3://nico-ml-ops-course/foodformer",
)

pytorch_estimator = PyTorch(**estimator_config)

In [37]:
# Point the "train" channel at the dataset prefix in S3 and start the training
# job; wait=True is passed so the call waits on the job.
train_channels = {"train": f"s3://{bucket}/{prefix}"}
pytorch_estimator.fit(train_channels, wait=True)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2023-08-31-20-06-40-404


Using provided s3_resource


In [10]:
# Deploy the estimator to an endpoint backed by one GPU instance; the returned
# predictor is what later cells use to reach (and remove) that endpoint.
endpoint_settings = {
    "initial_instance_count": 1,
    "instance_type": "ml.g4dn.2xlarge",
}
predictor = pytorch_estimator.deploy(**endpoint_settings)

-------!

In [24]:
# Remove the endpoint that backs `predictor` once it is no longer needed.
endpoint_to_remove = predictor.endpoint_name
sagemaker_session.delete_endpoint(endpoint_name=endpoint_to_remove)