## Make all necessary imports

In [None]:
import sagemaker
from sagemaker.tensorflow import TensorFlow
import os
from sagemaker.debugger import TensorBoardOutputConfig
from datetime import datetime

## Set up general AWS resources

In [None]:
# SageMaker session handed to every estimator created below.
sagemaker_session = sagemaker.Session()

# Region name reported by the session's underlying boto3 session.
region = sagemaker_session.boto_session.region_name

# IAM role (name or ARN) that the training jobs run under.
# NOTE(review): 'aws-role' looks like a placeholder - replace it with a real role.
role = 'aws-role'

## Set up variables

In [None]:
# --- Run timestamp -----------------------------------------------------------
# Formatted as YYMMDD-HHMMSS and appended to the job name below.
date = datetime.now().strftime("%y%m%d-%H%M%S")

# --- Training hyperparameters ------------------------------------------------
epochs = 10
batch_size = 128
learning_rate = 0.001

# --- Compute configuration ---------------------------------------------------
instance_count = 1
instance_type = "ml.m5.4xlarge"
entry_point = 'train.py'
device = 'cpu'

# --- S3 locations ------------------------------------------------------------
s3_uri = 's3://path-to-s3-bucket/'
s3_uri_model = 's3://path-to-s3-bucket/models/'
s3_uri_training_data = 's3://path-to-s3-bucket/datasets/datapath/'

# Channels passed to `estimator.fit`: put the S3 URI for each dataset to load here.
input_channels = {'train': s3_uri_training_data}

# Job name encoding the run configuration, e.g.
# 'Non-dist-33000-artType-1i-128b-m5-4xlarge-10e-<date>'
# (the instance type 'ml.m5.4xlarge' is shortened to 'm5-4xlarge').
job_name = (
    f"Non-dist-33000-artType-{instance_count}i-{batch_size}b-"
    f"{instance_type.replace('.', '-').replace('ml-', '')}-{epochs}e-{date}"
)

## Set up TensorBoard output and metric definitions

In [None]:
# Directory inside the training container that TensorBoardOutputConfig uploads
# to S3. Presumably the training scripts write their TensorBoard event files
# here - confirm against train*.py.
LOG_DIR = "/opt/ml/output/tensorboard"

# S3 prefix for this run's TensorBoard logs: <s3_uri>/tensorboard-logs/<job_name>.
# The URI is assembled with "/" explicitly: os.path.join uses the host OS
# separator (a backslash on Windows), which is wrong for S3 URIs.
output_path = f"{s3_uri.rstrip('/')}/tensorboard-logs/{job_name}"

tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"{output_path}/tensorboard",
    container_local_output_path=LOG_DIR,
)

# Metrics SageMaker extracts from the training job logs. Each "Regex" has exactly
# one capture group holding the metric value. The patterns expect Keras-style
# progress lines: "... - loss: X - accuracy: Y - val_loss: Z - val_accuracy: W".
# Raw strings keep "\d" and "\." as regex escapes; in a plain string "\d" is an
# invalid Python escape (DeprecationWarning, SyntaxWarning on Python 3.12+).
metric_definitions = [
    {"Name": "train:loss", "Regex": r".*loss: ([0-9\.]+) - accuracy: [0-9\.]+.*"},
    {"Name": "train:accuracy", "Regex": r".*loss: [0-9\.]+ - accuracy: ([0-9\.]+).*"},
    {
        "Name": "validation:accuracy",
        "Regex": r".*step - loss: [0-9\.]+ - accuracy: [0-9\.]+ - val_loss: [0-9\.]+ - val_accuracy: ([0-9\.]+).*",
    },
    {
        "Name": "validation:loss",
        "Regex": r".*step - loss: [0-9\.]+ - accuracy: [0-9\.]+ - val_loss: ([0-9\.]+) - val_accuracy: [0-9\.]+.*",
    },
    {
        # Despite the name, the captured number is the per-sample time in
        # milliseconds or microseconds ("ms"/"us"); the unit is not captured.
        # NOTE(review): this expects Keras' "/sample" progress format. If the
        # TF 2.12 jobs log "ms/step" instead, this metric never matches -
        # confirm against the job logs.
        "Name": "sec/sample",
        "Regex": r".* - \d+s (\d+)[mu]s/sample - loss: [0-9\.]+ - accuracy: [0-9\.]+ - val_loss: [0-9\.]+ - val_accuracy: [0-9\.]+",
    },
]

## Construct the TensorFlow Estimator and start the non-distributed training job

In [None]:
# Baseline job: no `distribution` strategy is configured.
# `entry_point` comes from the variables cell instead of being hard-coded, so
# changing it there takes effect here. The job also uses the same
# `metric_definitions` as the distributed jobs so their logged metrics can be
# compared.
# NOTE(review): only `epochs` is passed as a hyperparameter here. Confirm that
# train.py's defaults for batch size / learning rate match `batch_size` and
# `learning_rate`.
# NOTE(review): `script_mode` is an older SageMaker SDK argument; confirm the
# installed SDK version still accepts it.
estimator = TensorFlow(
    entry_point=entry_point,
    role=role,
    instance_count=instance_count,
    instance_type=instance_type,
    sagemaker_session=sagemaker_session,
    framework_version="2.12",
    py_version="py310",
    model_dir=s3_uri_model,
    tensorboard_output_config=tensorboard_output_config,
    hyperparameters={
        'epochs': epochs
    },
    metric_definitions=metric_definitions,
    script_mode=False,
)
# wait=False returns right after submitting; True makes the notebook wait and
# stream the job logs.
estimator.fit(inputs=input_channels, wait=False,
              job_name=job_name)

## Construct the TensorFlow Estimator and start the parameter-server-strategy training job

In [None]:
# SageMaker training job names must be unique per account and region, and
# `job_name` is also used by the non-distributed job. Derive a distinct name
# whose tag reflects the strategy ("Non-dist" -> "PS").
ps_job_name = job_name.replace('Non-dist', 'PS', 1)

# Give this job its own TensorBoard S3 prefix, keyed on its own job name, so its
# event files are not mixed with the non-distributed job's.
ps_tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"{s3_uri.rstrip('/')}/tensorboard-logs/{ps_job_name}/tensorboard",
    container_local_output_path=LOG_DIR,
)

# NOTE(review): `script_mode` is an older SageMaker SDK argument; confirm the
# installed SDK version still accepts it.
estimator = TensorFlow(
    entry_point='train_PS_chunking.py',
    role=role,
    instance_count=instance_count,
    instance_type=instance_type,
    sagemaker_session=sagemaker_session,
    framework_version="2.12",
    py_version="py310",
    model_dir=s3_uri_model,
    tensorboard_output_config=ps_tensorboard_output_config,
    hyperparameters={
        'epochs': epochs,
        'learning_rate': learning_rate,
        'batch_size': batch_size,
    },
    metric_definitions=metric_definitions,
    distribution={"parameter_server": {"enabled": True}},
    script_mode=False,
)
estimator.fit(inputs=input_channels, wait=False,  # True makes notebook wait and logs output in real time
              job_name=ps_job_name)

## Construct the TensorFlow Estimator and start the multi-worker mirrored strategy training job

In [None]:
# The MWM job receives `batch_size * instance_count`. Presumably this is the
# global batch size, with `batch_size` being per worker - confirm against
# train_MWM.py. It is kept in a separate variable so that re-running this cell,
# or running another cell afterwards, does not keep multiplying the shared
# `batch_size`.
global_batch_size = batch_size * instance_count

# SageMaker training job names must be unique per account and region, and
# `job_name` is also used by the non-distributed job. Derive a distinct name
# whose tag reflects the strategy ("Non-dist" -> "MWM").
mwm_job_name = job_name.replace('Non-dist', 'MWM', 1)

# Give this job its own TensorBoard S3 prefix, keyed on its own job name, so its
# event files are not mixed with the other jobs'.
mwm_tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"{s3_uri.rstrip('/')}/tensorboard-logs/{mwm_job_name}/tensorboard",
    container_local_output_path=LOG_DIR,
)

# NOTE(review): `script_mode` is an older SageMaker SDK argument; confirm the
# installed SDK version still accepts it.
estimator = TensorFlow(
    entry_point='train_MWM.py',
    role=role,
    instance_count=instance_count,
    instance_type=instance_type,
    sagemaker_session=sagemaker_session,
    framework_version="2.12",
    py_version="py310",
    model_dir=s3_uri_model,
    tensorboard_output_config=mwm_tensorboard_output_config,
    hyperparameters={
        'epochs': epochs,
        'learning_rate': learning_rate,
        'batch_size': global_batch_size,
    },
    metric_definitions=metric_definitions,
    distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
    script_mode=False,
)
estimator.fit(inputs=input_channels, wait=False,  # True makes notebook wait and logs output in real time
              job_name=mwm_job_name)