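### Prerequisites

- An IAM execution role that `get_execution_role()` can resolve (e.g. when running inside a SageMaker notebook environment).
- Training data uploaded to `s3://<default SageMaker bucket>/data`.
- The training script available at `./scripts/fine_tune_clf.py`.
- Enough service quota to launch 6 × `ml.p3.16xlarge` training instances.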

### Imports 

In [None]:
from sagemaker import get_execution_role, Session
from sagemaker.huggingface import HuggingFace
import sagemaker 
import logging

In [None]:
# Send records from the 'sagemaker' logger (the SageMaker SDK's logger name)
# to a StreamHandler (stderr by default), at DEBUG verbosity.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)

# Notebook cells are often re-executed; unconditionally calling addHandler()
# would stack one more handler per run and print every log line repeatedly.
# Tag our handler with a name and only attach it if it is not already there,
# leaving any handlers that other code attached untouched.
_NOTEBOOK_HANDLER_NAME = 'notebook-stream'
if not any(h.name == _NOTEBOOK_HANDLER_NAME for h in logger.handlers):
    _handler = logging.StreamHandler()
    _handler.name = _NOTEBOOK_HANDLER_NAME
    logger.addHandler(_handler)

In [None]:
# Record which SageMaker Python SDK version this notebook is running against.
logger.info('[Using SageMaker = %s]', sagemaker.__version__)

In [None]:
# SageMaker session, plus the IAM role and default S3 bucket the job will use.
session = Session()
ROLE = get_execution_role()
BUCKET = session.default_bucket()
logger.info('Default bucket = %s', BUCKET)

# Framework versions handed to the HuggingFace estimator.
TRANSFORMERS_VERSION, PYTORCH_VERSION, PYTHON_VERSION = '4.17.0', '1.10.2', 'py38'

# Training cluster shape: how many instances, of which type.
INSTANCE_TYPE, INSTANCE_COUNT = 'ml.p3.16xlarge', 6

# Training storage volume size, in GB.
VOLUME_SIZE = 1024

In [None]:
# Input channels passed to fit(): a single 'train' channel read from S3.
DATA = {
    'train': f's3://{BUCKET}/data',
}

# Hyperparameters for the estimator; presumably consumed by
# scripts/fine_tune_clf.py as the model output location (confirm in the script).
HYPERPARAMETERS = {
    'model_s3_save_path': f's3://{BUCKET}/model',
}

# Enable SageMaker's smdistributed data-parallel library for the job.
DISTRIBUTION_STRATEGY = {
    'smdistributed': {
        'dataparallel': {'enabled': True},
    },
}

In [None]:
# Describe the training job; nothing is launched until fit() is called.
huggingface_estimator = HuggingFace(
    # Training code and its arguments.
    entry_point='fine_tune_clf.py',
    source_dir='./scripts',
    hyperparameters=HYPERPARAMETERS,
    # Framework versions.
    transformers_version=TRANSFORMERS_VERSION,
    pytorch_version=PYTORCH_VERSION,
    py_version=PYTHON_VERSION,
    # Permissions, compute, storage and distribution.
    role=ROLE,
    instance_type=INSTANCE_TYPE,
    instance_count=INSTANCE_COUNT,
    volume_size=VOLUME_SIZE,
    distribution=DISTRIBUTION_STRATEGY,
    # SageMaker Profiler and Debugger hook are both turned off.
    disable_profiler=True,
    debugger_hook_config=False,
)

In [None]:
####

In [None]:
%%time

# Start the training job with the input channels defined in DATA; the %%time
# cell magic above reports how long this cell takes to run.
# NOTE(review): fit() defaults to wait=True in the SageMaker SDK, so this should
# block until the job finishes -- confirm for the installed SDK version.
huggingface_estimator.fit(DATA)