# Model Training

In this notebook we will train the model on the data we prepared in [Module 1: Preprocessing](../01_preprocessing/data_preprocessing.ipynb), using the AWS-managed TensorFlow container and a training script that defines the classification model.

## Import modules and initialize parameters for this notebook

In [None]:
import sagemaker
from sagemaker import get_execution_role

# SageMaker session used to look up account, region and bucket settings.
sagemaker_session = sagemaker.Session()

# IAM execution role under which the training job will run.
role = get_execution_role()

account = sagemaker_session.account_id()
region = sagemaker_session.boto_region_name

# Session default bucket; replace with your own custom bucket name if desired.
default_bucket = sagemaker_session.default_bucket()

# Used below as both the S3 output prefix and the training job base name.
# NOTE(review): 'postprocessing' is an unusual name for training output -- confirm it is intended.
prefix = "postprocessing"

In [None]:
# S3 location of the dataset produced by the data processing module.
# The prefix is a plain string literal: it has no placeholders, so no f-string is needed.
processed_data_prefix = 'preprocess/outputs'  # prefix generated by data processing module
processed_data_s3_uri = f's3://{default_bucket}/{processed_data_prefix}'

In [None]:
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.steps import TrainingStep
from sagemaker.tensorflow import TensorFlow

# Version of the AWS-managed TensorFlow container to train with.
TF_FRAMEWORK_VERSION = '2.4.1'

# Hyperparameters handed to the entry-point script; their exact meaning is
# defined by that script (not shown in this notebook).
hyperparameters = {
    'initial_epochs': 5,
    'batch_size': 8,
    'fine_tuning_epochs': 20,
    'dropout': 0.4,
    # Root directory under which SageMaker exposes the input channels in the container.
    'data_dir': '/opt/ml/input/data',
}

# Regexes used to scrape metrics from the training log.
# 'loss' and 'accuracy' are anchored with a word boundary (\b) so they do not
# also match inside 'val_loss' / 'val_accuracy' ('_' is a word character, so
# there is no boundary between 'val_' and 'loss'). Without the anchor the
# training metrics would also capture the validation values.
metric_definitions = [
    {'Name': 'loss', 'Regex': '\\bloss: ([0-9\\.]+)'},
    {'Name': 'acc', 'Regex': '\\baccuracy: ([0-9\\.]+)'},
    {'Name': 'val_loss', 'Regex': 'val_loss: ([0-9\\.]+)'},
    {'Name': 'val_acc', 'Regex': 'val_accuracy: ([0-9\\.]+)'},
]


# Single-instance training: the parameter-server distribution strategy is disabled.
distribution = {'parameter_server': {'enabled': False}}

# S3 data distribution type applied to every input channel.
DISTRIBUTION_MODE = 'FullyReplicated'

# One TrainingInput per dataset split, each read from its sub-folder of the
# processed data location (constructed in train / valid / test order).
train_in, val_in, test_in = (
    TrainingInput(s3_data=processed_data_s3_uri + '/' + split_dir,
                  distribution=DISTRIBUTION_MODE)
    for split_dir in ('train', 'valid', 'test')
)

# Channel name -> input mapping passed to the estimator's fit() call.
inputs = {
    'train': train_in,
    'test': test_in,
    'validation': val_in,
}

# Compute resources for the training job.
training_instance_type = 'ml.c5.4xlarge'
training_instance_count = 1

In [None]:
# S3 location under which the training job writes its output.
model_path = f"s3://{default_bucket}/{prefix}"

# TensorFlow estimator describing what to run, where to run it and what to record.
estimator = TensorFlow(
    # Training code: entry-point script inside the local 'code' directory.
    entry_point='train-mobilenet.py',
    source_dir='code',
    # Container selection.
    framework_version=TF_FRAMEWORK_VERSION,
    py_version='py37',
    # NOTE(review): script_mode looks like a legacy SDK flag -- confirm it is
    # still required by the installed SageMaker SDK version.
    script_mode=True,
    # Compute resources and distribution strategy.
    instance_type=training_instance_type,
    instance_count=training_instance_count,
    distribution=distribution,
    # Training configuration and metric scraping.
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # Permissions, job naming and output location.
    role=role,
    base_job_name=prefix,
    output_path=model_path,
)

In [None]:
# Launch the training job with the train / test / validation channels.
estimator.fit(inputs=inputs)

In [None]:
# Name of the most recent training job started by this estimator.
latest_job = estimator.latest_training_job
training_job_name = latest_job.name

# Report the S3 output folder built from the output path and the job name.
print(f"Model artifact file is uploaded here: {model_path}/{training_job_name}/output ========")

