In [None]:
import sagemaker 

# Bootstrap session: only needed to look up the account's default bucket.
sess = sagemaker.Session()

# S3 bucket SageMaker uses for uploaded data, models and logs.
# SageMaker creates the bucket automatically if it does not exist yet.
# Leave as None to fall back to the session's default bucket.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# IAM role the notebook is executing under.
role = sagemaker.get_execution_role()

# Final session, explicitly bound to the chosen bucket; used by all later cells.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
import os

# Hyperparameters handed to the entry-point script of the SageMaker job.
hyperparameters = {}

# Hugging Face access token for the model repository below.
# Prefer exporting HF_TOKEN in the notebook environment so the real token is
# never saved in the notebook; the "hf_xxxx" placeholder is only a fallback
# and must be replaced with a valid token if HF_TOKEN is not set.
hyperparameters["access_token"] = os.environ.get("HF_TOKEN") or "hf_xxxx"

# Hugging Face Hub id of the model to process.
hyperparameters["model_name"] = "meta-llama/Llama-2-70b-chat-hf"

# Parallelism degrees forwarded to the script (tp = tensor parallel,
# pp = pipeline parallel, going by the key names — confirm against the script).
hyperparameters["tp_size"] = 8
hyperparameters["pp_size"] = 8

In [None]:
# The job needs read/write access to its working paths, so it relies on the
# SageMaker S3 checkpoint mechanism: this S3 prefix is paired with the local
# "checkpoint-dir" path when the estimator is configured.
checkpoint_s3_uri = "".join(
    ["s3://", sagemaker_session_bucket, "/neuronx_llama_experiment"]
)

hyperparameters.update(
    {
        # Output location, placed underneath the checkpoint directory.
        "output_dir": "/opt/ml/checkpoints/llama70b_weights",
        # Local checkpoint directory inside the training container.
        "checkpoint-dir": '/opt/ml/checkpoints',
        # Passed with an empty value — presumably a bare switch for the
        # script; confirm against convert_checkpoints.py.
        "convert_from_full_model": "",
        "n_layers": 80,
    }
)

In [None]:
# PyTorch 1.13.1 / Neuron SDK 2.17.0 training image (tag pinned for reproducibility).
# The ECR host is derived from the session's region instead of being pinned to
# us-east-2, so the image is resolved in the same region the job is launched in.
# NOTE(review): assumes registry account 763104351884 and the amazonaws.com
# domain serve this image in the session's region — confirm for regions outside
# the standard commercial partition.
docker_image = (
    f"763104351884.dkr.ecr.{sess.boto_region_name}.amazonaws.com/"
    "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
)

In [None]:
from sagemaker.pytorch import PyTorch

# TODO: need to check if this works on multinode with torchrun.
estimator = PyTorch(
    # --- what to run ---
    entry_point="convert_checkpoints.py",
    source_dir="./scripts",
    hyperparameters=hyperparameters,
    image_uri=docker_image,
    # --- job identity and permissions ---
    base_job_name="neuronx-llama-download-model-weights",
    role=role,
    sagemaker_session=sess,
    # --- compute and storage ---
    instance_type="ml.trn1.32xlarge",
    instance_count=1,
    volume_size=1024,
    keep_alive_period_in_seconds=600,
    # --- checkpoints and outputs ---
    # The local checkpoint path is paired with the S3 prefix; it reuses the
    # same directory the script receives through its "checkpoint-dir" hyperparameter.
    checkpoint_s3_uri=checkpoint_s3_uri,
    checkpoint_local_path=hyperparameters["checkpoint-dir"],
    disable_output_compression=True,
    debugger_hook_config=False,
)

In [None]:
# Launch the job configured on `estimator`. No input data channels are passed.
# NOTE(review): with the SDK defaults this call is expected to block and stream
# the job logs until completion — confirm for the installed SDK version.
estimator.fit()