## Batch-size tuning script: find the maximum batch size that fits into memory

### Installation

This example notebook requires the **SageMaker Python SDK v2.70.0** and **transformers v4.11.0**.

In [None]:
!pip install --force-reinstall sagemaker==2.70.0

In [None]:
!pip install transformers==4.11.0

In [None]:
import botocore
import boto3
import sagemaker
import transformers
import pandas as pd

# Print the imported library versions so they can be compared with the versions pinned above.
print(f"sagemaker: {sagemaker.__version__}")
print(f"transformers: {transformers.__version__}")

### SageMaker environment 

In [None]:
import sagemaker

# Name of the S3 bucket used for uploading data, models and logs.
# Leave it as None to fall back to the session's default bucket, which
# SageMaker creates automatically if it does not exist.
sagemaker_session_bucket = None

sess = sagemaker.Session()
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()
# Recreate the session with the selected bucket as its default bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

### Training Setup

This notebook uses a Hugging Face (HF) training script to demonstrate how to find the maximum batch size that fits in memory. If you are using a customized training script, please update the `find_max_batch_size.py` script and `hyperparameters` accordingly.

In [None]:
# Model and tokenizer identifiers passed to the training job as hyperparameters.
MODEL_NAME = "gpt2"
TOKENIZER_NAME = "gpt2"

# Hyperparameter key under which MODEL_NAME is passed.
MODEL_CONFIG = "model_name_or_path"

# Used to build the training script name: run_<LANGUAGE_MODELING_LOSS>.py
LANGUAGE_MODELING_LOSS = "clm"

# Instance type used by both training jobs below.
INSTANCE_TYPE = "ml.p3.8xlarge"

### Tune Native PyTorch

In [None]:
from sagemaker.huggingface import HuggingFace

# These values are handed to the training entry point as arguments.
hyperparameters = {
    "training_script": f"run_{LANGUAGE_MODELING_LOSS}.py",
    MODEL_CONFIG: MODEL_NAME,
    "tokenizer_name": TOKENIZER_NAME,
    "fp16": True,
    "sequence_len": 512,
    # Lower and upper bounds given to the batch size search.
    "per_device_train_batch_size_min": 1,
    "per_device_train_batch_size_max": 128,
}

# Training job definition for the native (non-compiled) run.
native_estimator = HuggingFace(
    # What to run
    source_dir="./scripts",
    entry_point="find_max_batch_size.py",
    hyperparameters=hyperparameters,
    # Framework versions
    transformers_version="4.11.0",
    pytorch_version="1.9.0",
    py_version="py38",
    # Infrastructure
    role=role,
    instance_type=INSTANCE_TYPE,
    instance_count=1,
    volume_size=100,
    # SageMaker Profiler and Debugger are disabled to avoid overheads during benchmarking
    disable_profiler=True,
    debugger_hook_config=False,
)

# Launch the job without blocking, then display its name.
native_estimator.fit(wait=False)
native_estimator.latest_training_job.name

### Tune Optimized PyTorch (SageMaker Training Compiler)

In [None]:
# Display the source of the launcher script with syntax highlighting.
# NOTE(review): the estimator in the next cell uses find_max_batch_size.py as its entry point,
# not this script; presumably find_max_batch_size.py invokes it — confirm against ./scripts.
!pygmentize ./scripts/launch_sm_training_compiler.py

In [None]:
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

# Training job definition for the compiler-optimized run. It uses the same entry
# point, hyperparameters and instance settings as the native job, plus a
# compiler configuration.
optimized_estimator = HuggingFace(
    # What to run
    source_dir="./scripts",
    entry_point="find_max_batch_size.py",
    hyperparameters=hyperparameters,
    # This is what enables SageMaker Training Compiler
    compiler_config=TrainingCompilerConfig(),
    # Framework versions
    transformers_version="4.11.0",
    pytorch_version="1.9.0",
    py_version="py38",
    # Infrastructure
    role=role,
    instance_type=INSTANCE_TYPE,
    instance_count=1,
    volume_size=100,
    # SageMaker Profiler and Debugger are disabled to avoid overheads during benchmarking
    disable_profiler=True,
    debugger_hook_config=False,
)

# Launch the job without blocking, then display its name.
optimized_estimator.fit(wait=False)
optimized_estimator.latest_training_job.name

### Wait for training jobs to complete

In [None]:
# Block until each job has completed or been stopped, native job first.
for estimator in (native_estimator, optimized_estimator):
    waiter = estimator.sagemaker_session.sagemaker_client.get_waiter(
        "training_job_completed_or_stopped"
    )
    waiter.wait(TrainingJobName=estimator.latest_training_job.name)

## Analysis

### Load logs for compiler optimized case

In [None]:
%%capture optimized

# Fetch the logs of the optimized training job. The %%capture magic stores this cell's
# output in `optimized`; the next cell reads it through `optimized.stdout`.
optimized_estimator.sagemaker_session.logs_for_job(optimized_estimator.latest_training_job.name)

In [None]:
# Print only the log lines that report the batch size search results.
for log_line in optimized.stdout.split("\n"):
    is_result_line = "result" in log_line and "max_batch_size" in log_line
    is_total_line = "Total max batch" in log_line
    if is_result_line or is_total_line:
        print(log_line)

### Load logs for native case

In [None]:
%%capture native

# Fetch the logs of the native training job. The %%capture magic stores this cell's
# output in `native`; the next cell reads it through `native.stdout`.
native_estimator.sagemaker_session.logs_for_job(native_estimator.latest_training_job.name)

In [None]:
# Print only the log lines that report the batch size search results.
for log_line in native.stdout.split("\n"):
    is_result_line = "result" in log_line and "max_batch_size" in log_line
    is_total_line = "Total max batch" in log_line
    if is_result_line or is_total_line:
        print(log_line)

## Clean up

Stop any of the launched training jobs that are still running.

In [None]:
import boto3

# SageMaker client used by stop_training_job below to describe and stop jobs.
sm = boto3.client("sagemaker")


def stop_training_job(name):
    """Stop the training job called `name` if its status is "InProgress".

    Jobs in any other status are left untouched. Uses the module-level
    SageMaker client `sm`.
    """
    description = sm.describe_training_job(TrainingJobName=name)
    if description["TrainingJobStatus"] != "InProgress":
        return
    sm.stop_training_job(TrainingJobName=name)


# Stop both launched jobs, native job first; jobs that are not in progress are skipped.
for estimator in (native_estimator, optimized_estimator):
    stop_training_job(estimator.latest_training_job.name)

Also, to find instructions on cleaning up resources, see [Clean Up](https://docs.aws.amazon.com/sagemaker/latest/dg/ex1-cleanup.html) in the *Amazon SageMaker Developer Guide*.