## Make all necessary imports

In [None]:
import sagemaker
from sagemaker.tensorflow import TensorFlow
import os
from sagemaker.debugger import TensorBoardOutputConfig
from datetime import datetime

## Set up general AWS resources

In [None]:
# Session used for every SageMaker call below; the region is read from its
# underlying boto3 session.
sagemaker_session = sagemaker.Session()
boto_sess = sagemaker_session.boto_session
region = boto_sess.region_name

# IAM role the training job runs as. 'aws-role' is a placeholder: replace it
# with a real role name/ARN (inside SageMaker, e.g. sagemaker.get_execution_role()).
role = 'aws-role'

## Set up variables

In [None]:
# Timestamp used to make every job name unique (e.g. "230616-142501").
date = datetime.now().strftime("%y%m%d-%H%M%S")
epochs = 5
instance_count = 1
entry_point = 'train.py'

# One bucket holds the datasets, model artifacts and TensorBoard logs; the
# other S3 locations are derived from it so the bucket name lives in one place.
s3_uri = 's3://16062023-sagemaker-bucket-01'
s3_uri_model = f'{s3_uri}/models/'
s3_uri_training_data = f'{s3_uri}/datasets/'
# Channel name -> S3 prefix of the data to load; passed to estimator.fit().
input_channels = {'train': s3_uri_training_data}

instance_type = "ml.c4.xlarge"  # choose instance

# Derive the device tag from the instance family instead of hard-coding it,
# so the job name cannot claim "gpu" for a CPU-only instance such as ml.c4.
# SageMaker's GPU instances are the ml.p* and ml.g* families; local mode
# uses "local" / "local_gpu".
if instance_type.startswith('ml.'):
    instance_family = instance_type.split('.')[1]
else:
    instance_family = instance_type
if instance_family.startswith(('p', 'g')) or instance_family.endswith('gpu'):
    device = 'gpu'
else:
    device = 'cpu'

job_name = '{}-TensorFlow-Mnist-data-loading-{}-{}-{}-{}e'.format(
    date,
    instance_count,
    instance_type.replace('.', '-').replace('ml-', ''),
    device,
    epochs)

# SageMaker rejects training job names longer than 63 characters; fail here
# with a clear message rather than at job submission.
if len(job_name) > 63:
    raise ValueError(
        f'job_name is {len(job_name)} characters long; SageMaker allows at most 63: {job_name}'
    )

## Set up TensorBoard

In [None]:
# Directory inside the training container that TensorBoardOutputConfig maps to
# the S3 prefix below. NOTE(review): train.py presumably writes its TensorBoard
# event files here — confirm it uses this exact path.
LOG_DIR = "/opt/ml/output/tensorboard"

# S3 URIs always use '/' as separator. os.path.join would insert '\' when the
# notebook runs on Windows, producing an invalid S3 key.
output_path = f"{s3_uri.rstrip('/')}/{job_name}"

tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"{output_path}/tensorboard",
    container_local_output_path=LOG_DIR,
)

## Construct the TensorFlow Estimator

In [None]:
# TensorFlow 2.12 / Python 3.10 estimator that runs `entry_point` on the
# chosen instances. No `script_mode` argument is passed: SageMaker Python
# SDK v2 (required by `instance_count` and these framework versions) removed
# it, and TF 2.x containers only support script mode, so the former
# `script_mode=False` requested an unsupported legacy mode.
estimator = TensorFlow(entry_point=entry_point,
                       role=role,
                       instance_count=instance_count,
                       instance_type=instance_type,
                       sagemaker_session=sagemaker_session,
                       framework_version="2.12",
                       py_version="py310",
                       # NOTE(review): fixed prefix shared by every job; runs may
                       # overwrite each other's model output — consider a per-job path.
                       model_dir=s3_uri_model,
                       tensorboard_output_config=tensorboard_output_config,
                       # Forwarded to the training script (as command-line arguments in script mode).
                       hyperparameters={
                           'epochs': epochs
                       },
                       # Parameter-server strategy; only has an effect with instance_count > 1.
                       distribution={"parameter_server": {"enabled": True}},  # choose distribution strategy if necessary
                       )


## Start the training job

In [None]:
# Submit the training job and return immediately. Set wait=True to block the
# notebook and stream the job's logs in real time instead.
estimator.fit(
    inputs=input_channels,
    job_name=job_name,
    wait=False,
)