# Initial setup

In [None]:
import os

import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

# Resolve the IAM execution role instead of hardcoding an account-specific ARN
# (which leaks the account ID and only works in one AWS account).
# Set SAGEMAKER_ROLE_ARN when running outside a SageMaker-managed notebook;
# otherwise fall back to the execution role of the current SageMaker environment.
role = os.environ.get("SAGEMAKER_ROLE_ARN") or get_execution_role()

bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/sm-dataparallel-distribution-options'
print('Bucket:\n{}'.format(bucket))

In [None]:
# Data preparation was done in lab2 of this chapter.
# If you skipped it, then run following code below
# preparing dataset
! wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
! unzip hymenoptera_data.zip
data_url = sagemaker_session.upload_data(path="./hymenoptera_data", key_prefix="hymenoptera_data")

# Remote Mode

In [None]:
from sagemaker.pytorch import PyTorch

# NOTE(review): ml.p3.2xlarge has a single GPU. Some SageMaker SDK versions only
# accept smdistributed dataparallel on multi-GPU instance types (e.g.
# ml.p3.16xlarge, ml.p3dn.24xlarge, ml.p4d.24xlarge) and reject others when the
# estimator is constructed -- confirm against the installed SDK before running.
instance_type = 'ml.p3.2xlarge'
# Number of training nodes. Passed to the estimator below so the job actually
# runs across multiple instances (it was previously hardcoded to 1).
instance_count = 2

# Enable SageMaker's distributed data parallel library (smdistributed.dataparallel).
# custom_mpi_options are extra flags for the MPI launcher: -x exports
# NCCL_DEBUG=VERSION to the training processes so NCCL logs its version at startup.
distribution = {
    "smdistributed": {
        "dataparallel": {
            "enabled": True,
            "custom_mpi_options": "-verbose -x NCCL_DEBUG=VERSION"
        }
    }
}

sm_dp_estimator = PyTorch(
    entry_point="train.py",  # Pick your train script
    source_dir='3_sources',
    role=role,
    instance_type=instance_type,
    sagemaker_session=sagemaker_session,
    framework_version='1.6.0',
    py_version='py36',
    instance_count=instance_count,
    # Forwarded to train.py; presumably parsed as --batch-size, --epochs, ...
    # by the script's argument parser -- confirm in 3_sources/train.py.
    hyperparameters={
        "batch-size": 64,
        "epochs": 20,
        "model-name": "squeezenet",
        "num-classes": 2,
        "feature-extract": True,
        "sync-s3-path": f"s3://{bucket}/distributed-training/output"
    },
    # Turn off SageMaker Debugger profiling and the debugger hook for this job.
    disable_profiler=True,
    debugger_hook_config=False,
    distribution=distribution,
    base_job_name="SM-DP",
)

In [None]:
# Launch the training job with two input channels, "train" and "val", each
# pointing at a sub-prefix of the uploaded dataset in data_url.
# Presumably train.py consumes them as training/validation data -- confirm in
# 3_sources/train.py.
sm_dp_estimator.fit(
    inputs=dict(
        train=f"{data_url}/train",
        val=f"{data_url}/val",
    )
)