In [18]:
import boto3
import sagemaker
import pandas as pd

# SageMaker session wrapper and the default S3 bucket associated with it.
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Execution role that is later handed to the estimator.
role = sagemaker.get_execution_role()

# Region configured for the default boto3 session.
region = boto3.Session().region_name

# Low-level SageMaker API client bound to that region.
sm = boto3.Session().client("sagemaker", region_name=region)

In [19]:
# Restore variables that were persisted with `%store`
# (presumably by earlier notebooks in this workflow; confirm they have been run first).
%store -r processed_train_data_s3_uri
%store -r processed_validation_data_s3_uri
%store -r processed_test_data_s3_uri
%store -r max_seq_length
%store -r experiment_name
%store -r trial_name

In [20]:
# Display the restored value so the reader can confirm it before it is passed as a hyperparameter.
max_seq_length

183

### Initial Hyperparameters

In [21]:
# --- Optimisation ---
epochs = 3
learning_rate = 1e-5
epsilon = 1e-8

# --- Batching: every split shares the same batch size and step budget ---
train_batch_size = validation_batch_size = test_batch_size = 128
train_steps_per_epoch = validation_steps = test_steps = 100

# --- Training infrastructure ---
train_instance_count = 1
train_instance_type = "ml.c5.9xlarge"
train_volume_size = 1024
input_mode = "File"

# --- Switches passed through to the estimator's hyperparameters ---
# NOTE(review): the training log reports "No GPUs detected" for this instance type;
# confirm that use_xla / use_amp are intended for a CPU-only run.
use_xla = True
use_amp = True
freeze_bert_layer = False
enable_sagemaker_debugger = False
enable_checkpointing = False
enable_tensorboard = True

# --- Which evaluation phases to run ---
run_validation = True
run_test = True
run_sample_predictions = True

### Checkpoint Location

In [22]:
import uuid

# A fresh UUID per execution gives each run its own checkpoint prefix in the bucket.
checkpoint_s3_prefix = f"checkpoints/{uuid.uuid4()}"
checkpoint_s3_uri = f"s3://{bucket}/{checkpoint_s3_prefix}/"

print(checkpoint_s3_uri)

s3://sagemaker-us-east-1-211125778552/checkpoints/8fc50957-6520-426b-8616-ee8b42c5f3f1/


### Metrics

In [23]:
# Regular expressions used to extract metrics from the training job's log output.
#
# The two training patterns are anchored with a word boundary (`\b`). Without it,
# "loss: " also matches inside "val_loss: " and "accuracy: " inside "val_accuracy: ",
# so validation values could be reported under the train:* metric names as well.
# `_` is a word character, so `\b` does not match between "val_" and "loss".
metrics_definitions = [
    {"Name": "train:loss", "Regex": "\\bloss: ([0-9\\.]+)"},
    {"Name": "train:accuracy", "Regex": "\\baccuracy: ([0-9\\.]+)"},
    {"Name": "validation:loss", "Regex": "val_loss: ([0-9\\.]+)"},
    {"Name": "validation:accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"},
]

In [24]:
from sagemaker.tensorflow import TensorFlow

# Configure the TensorFlow estimator: script location, runtime versions, compute
# resources, checkpoint destination, metric regexes, and the hyperparameters.
estimator = TensorFlow(
    # Training script and the directory that contains it.
    source_dir="src",
    entry_point="tf_bert_reviews_rec.py",
    # Framework / interpreter versions.
    framework_version="2.3.1",
    py_version="py37",
    # Identity and compute resources.
    role=role,
    instance_type=train_instance_type,
    instance_count=train_instance_count,
    volume_size=train_volume_size,
    # Data access, checkpoints and metrics.
    input_mode=input_mode,
    checkpoint_s3_uri=checkpoint_s3_uri,
    metric_definitions=metrics_definitions,
    # Values defined in the hyperparameter cell, keyed by the same names.
    hyperparameters=dict(
        epochs=epochs,
        learning_rate=learning_rate,
        epsilon=epsilon,
        train_batch_size=train_batch_size,
        validation_batch_size=validation_batch_size,
        test_batch_size=test_batch_size,
        train_steps_per_epoch=train_steps_per_epoch,
        validation_steps=validation_steps,
        test_steps=test_steps,
        use_xla=use_xla,
        use_amp=use_amp,
        max_seq_length=max_seq_length,
        freeze_bert_layer=freeze_bert_layer,
        enable_sagemaker_debugger=enable_sagemaker_debugger,
        enable_checkpointing=enable_checkpointing,
        enable_tensorboard=enable_tensorboard,
        run_validation=run_validation,
        run_test=run_test,
        run_sample_predictions=run_sample_predictions,
    ),
)

### Training the BERT model

In [25]:
# Experiment/trial identifiers (restored via %store) under which the job is recorded,
# with "train" as the trial component's display name.
experiment_config = dict(
    ExperimentName=experiment_name,
    TrialName=trial_name,
    TrialComponentDisplayName="train",
)

In [26]:
from sagemaker.inputs import TrainingInput

# Build one TrainingInput per data split; all three use the same distribution setting.
s3_input_train_data, s3_input_validation_data, s3_input_test_data = (
    TrainingInput(s3_data=split_s3_uri, distribution="ShardedByS3Key")
    for split_s3_uri in (
        processed_train_data_s3_uri,
        processed_validation_data_s3_uri,
        processed_test_data_s3_uri,
    )
)

# Show the resulting channel configurations, one per line.
print(
    s3_input_train_data.config,
    s3_input_validation_data.config,
    s3_input_test_data.config,
    sep="\n",
)

{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-211125778552/sagemaker-scikit-learn-2024-04-20-05-17-50-785/output/bert-train', 'S3DataDistributionType': 'ShardedByS3Key'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-211125778552/sagemaker-scikit-learn-2024-04-20-05-17-50-785/output/bert-validation', 'S3DataDistributionType': 'ShardedByS3Key'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-211125778552/sagemaker-scikit-learn-2024-04-20-05-17-50-785/output/bert-test', 'S3DataDistributionType': 'ShardedByS3Key'}}}


In [27]:
# Start the training job without blocking (wait=False); a later cell waits on it.
estimator.fit(
    inputs=dict(
        train=s3_input_train_data,
        validation=s3_input_validation_data,
        test=s3_input_test_data,
    ),
    experiment_config=experiment_config,
    wait=False,
)

In [28]:
# Keep the name of the job that was just launched and show it.
training_job_name = estimator.latest_training_job.name
print(f"Training Job Name:  {training_job_name}")

Training Job Name:  tensorflow-training-2024-04-20-05-32-55-117


In [29]:
%%time

# Block until the training job launched earlier (with wait=False) finishes,
# with logs=True so the job's log lines appear in the cell output.
estimator.latest_training_job.wait(logs=True)

2024-04-20 05:32:57 Starting - Starting the training job...
2024-04-20 05:33:19 Starting - Preparing the instances for trainingProfilerReport-1713591175: InProgress
...
2024-04-20 05:33:44 Downloading - Downloading input data...
2024-04-20 05:34:25 Training - Training image download completed. Training in progress....2024-04-20 05:34:43.521241: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:460] Initializing the SageMaker Profiler.
2024-04-20 05:34:43.521388: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:105] SageMaker Profiler is not enabled. The timeline writer thread will not be started, future recorded events will be dropped.
2024-04-20 05:34:43.549803: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:460] Initializing the SageMaker Profiler.
2024-04-20 05:34:44,724 sagemaker-training-toolkit INFO     Imported framework sagemaker_tensorflow_container.training
2024-04-20 05:34:44,732 sagemaker-training-toolkit INFO     No GPUs detected (normal if n

In [30]:
# Show the metrics recorded for the training job as a DataFrame
# (columns: timestamp, metric_name, value).
estimator.training_job_analytics.dataframe()

Unnamed: 0,timestamp,metric_name,value
0,0.0,train:loss,0.708600
1,60.0,train:loss,0.695357
2,120.0,train:loss,0.681417
3,180.0,train:loss,0.665957
4,240.0,train:loss,0.645329
...,...,...,...
101,1080.0,validation:loss,0.420300
102,2160.0,validation:loss,0.402800
103,0.0,validation:accuracy,0.896600
104,1080.0,validation:accuracy,0.890700


AttributeError: 'TensorFlow' object has no attribute 'predict'