In [None]:
%%capture
!pip install -U sagemaker

In [None]:
import torch
import sagemaker
from sagemaker.pytorch import PyTorch


In [None]:
# SageMaker session for this notebook (region/account-scoped helpers such as the default bucket).
sagemaker_session = sagemaker.Session()

# Execution role handed to the training estimator as `role=`.
role = sagemaker.get_execution_role()

# The session's default S3 bucket.
# NOTE(review): `bucket` is not referenced by any later visible cell; the training data
# location is a hard-coded S3 URI instead — confirm whether this lookup is still needed.
bucket = sagemaker_session.default_bucket()


In [None]:
# Stopping conditions for the (managed spot) training job, in seconds.
SECONDS_PER_HOUR = 60 * 60
# Training time limit: SageMaker terminates the job after this long, regardless of its status.
MAX_RUN_SECONDS = 60 * SECONDS_PER_HOUR  # 60 hours == 216000 s
# How long to wait for the spot training job overall before SageMaker stops waiting for it.
MAX_WAIT_SECONDS = 2 * MAX_RUN_SECONDS  # 120 hours == 432000 s

# Hyperparameters forwarded to the entry point script.
# NOTE(review): presumably parsed by model/model.py as --epochs / --backend / --batch-size — confirm.
training_hyperparameters = {
    "epochs": 1,
    # Process-group backend (https://pytorch.org/docs/stable/distributed.html).
    # Gloo is the CPU backend and the fallback if NCCL misbehaves; NCCL is the usual choice
    # for distributed GPU training, which the GPU instance type below suggests.
    # TODO: nccl
    "backend": "gloo",
    "batch-size": 512,
}

# Distributed PyTorch training job: 2 spot GPU instances, data streamed from S3 via FastFile.
# NOTE(review): no checkpoint location is configured, so a spot interruption presumably
# restarts training from scratch — confirm whether model.py writes checkpoints.
estimator = PyTorch(
    entry_point="model.py",
    source_dir="model",
    role=role,
    py_version="py38",
    framework_version="1.10.0",
    instance_type="ml.g4dn.12xlarge",
    instance_count=2,
    use_spot_instances=True,
    # Stream data from S3 on demand instead of downloading the whole dataset before training starts.
    input_mode="FastFile",
    max_run=MAX_RUN_SECONDS,
    max_wait=MAX_WAIT_SECONDS,
    # EBS volume size in GB for storing input data during training (SDK default: 30).
    volume_size=900,
    hyperparameters=training_hyperparameters,
)


In [None]:
# S3 prefix holding the pre-tokenized training data; handed to the "training" channel in `fit`.
TRAINING_DATA_S3_URI = "s3://sagemaker-eu-west-1-oasprocessed/tokenized_data/"
inputs = TRAINING_DATA_S3_URI


In [None]:
# Launch the training job with a single input channel named "training" pointing at the S3 data.
# NOTE(review): the entry point presumably locates this data via the channel name
# (e.g. SM_CHANNEL_TRAINING) — confirm in model/model.py.
# With fit's default wait=True, this cell blocks until the job finishes.
estimator.fit(inputs={"training": inputs})
