# Train your own model with your data
This notebook shows how to train the word-segmentation and handwriting-recognition models on your own labelled dataset.

In [None]:
import os
import json
import boto3
import sagemaker
import random
from sklearn.model_selection import train_test_split

from src.sm_utils import parse_s3_url

In [None]:
# Load the deployment's stack outputs. This notebook reads the keys
# S3Bucket, SolutionPrefix, SageMakerIamRole and SageMakerTrainingInstanceType.
with open("stack_outputs.json") as stack_outputs_file:
    sagemaker_config = json.load(stack_outputs_file)
s3_bucket = sagemaker_config["S3Bucket"]

# Data preprocessing
Download the `output.manifest` file produced by the labelling job.

In [None]:
# Name of the labelling job whose output manifest is read below.
label_job_id = sagemaker_config["SolutionPrefix"] + "-labelling-job"

In [None]:
s3_client = boto3.client('s3')

# S3 key of the labelling job's output manifest inside the solution bucket.
output_manifest_location = "custom_data/output/{}/manifests/output/output.manifest".format(label_job_id)

# Fetch the whole manifest and decode it as UTF-8 text; it is later split on
# newlines and each non-empty line is parsed as a JSON record.
item_object = s3_client.get_object(Key=output_manifest_location, Bucket=s3_bucket)
manifest_stream = item_object['Body']
body = manifest_stream.read().decode('utf-8')

Randomly split the labelled data into a training set and a test set.

In [None]:
# Fraction of the labelled examples assigned to each split (must sum to 1).
train_size = 0.8
test_size = 0.2

In [None]:
def process_manifest(entries, train_or_test):
    """Write one split's manifest locally and copy its images into the training area.

    Every non-blank line in ``entries`` is written to ``<train_or_test>.manifest``
    in the current directory. The image referenced by each record's
    ``"source-ref"`` S3 URL is copied to
    ``s3://<s3_bucket>/custom_data/training/<train_or_test>/<image basename>``.

    Args:
        entries: iterable of manifest lines, each a JSON document containing a
            ``"source-ref"`` key. Blank or whitespace-only lines are skipped.
        train_or_test: split name. It is used for the local manifest filename
            and for the destination S3 folder (e.g. ``"train"`` or ``"test"``).

    Uses the notebook-level ``s3_client`` and ``s3_bucket``.
    """
    with open("{}.manifest".format(train_or_test), "w") as f:
        for entry in entries:
            # Check for blank entries *before* writing. Splitting text on
            # newlines can yield empty entries (e.g. from a trailing newline),
            # and these must not end up as blank lines in the manifest.
            if not entry.strip():
                continue
            f.write(entry + "\n")
            json_entry = json.loads(entry)
            source_image_bucket, source_image_key = parse_s3_url(json_entry["source-ref"])
            # NOTE(review): images are flattened to their basename, so two source
            # images with the same filename under different prefixes would
            # overwrite each other in the destination folder.
            copy_location = "custom_data/training/{}/{}".format(
                train_or_test, os.path.basename(source_image_key))
            s3_client.copy(CopySource={"Key": source_image_key,
                                       "Bucket": source_image_bucket},
                           Bucket=s3_bucket,
                           Key=copy_location)

In [None]:
# Compare with a tolerance: float sums such as 0.1 + 0.2 are not exactly
# representable, so an exact `== 1.0` check can reject valid fractions.
assert abs(train_size + test_size - 1.0) < 1e-9, "train + test must be equal to 1.0"

# Drop blank lines (e.g. the one produced by a trailing newline) so they are
# neither counted as examples nor written into either manifest.
manifest_entries = [entry for entry in body.split("\n") if entry.strip()]

# A fixed random_state keeps the split reproducible across notebook re-runs.
train_entries, test_entries = train_test_split(
    manifest_entries, train_size=train_size, test_size=test_size, random_state=42)

process_manifest(train_entries, "train")
# The test manifest must be built from the held-out entries only, never the
# training entries; otherwise test metrics would be measured on training data.
process_manifest(test_entries, "test")

In [None]:
# Upload each split's manifest into its split folder, next to the copied images.
# Only then delete the local copies.
# boto3 is used instead of `!aws s3 cp` / `!rm` because a failing shell escape
# does not stop the notebook. A failed upload would then be followed by
# deleting the only local copy of the manifest.
split_names = ("train", "test")
for split_name in split_names:
    s3_client.upload_file("{}.manifest".format(split_name), s3_bucket,
                          "custom_data/training/{0}/{0}.manifest".format(split_name))

for split_name in split_names:
    os.remove("{}.manifest".format(split_name))

# Word segmentation
In this section we will train an object detection model to locate all the words in an image of a passage.
You can find the details of the model and algorithm in `src/word_and_line_segmentation.py`.

In [None]:
# Hyperparameters passed to src/word_and_line_segmentation.py.
# The path/filename entries name the split folders and manifests prepared above.
# They are presumably resolved relative to the "train" input channel; confirm
# this in the training script.
word_segmentation_parameters = dict(
    # detector / training settings
    min_c=0.01,
    overlap_thres=0.10,
    topk=150,
    epoch=401,
    image_size=500,
    expand_bb_scale=0.00,
    batch_size=1,
    gpu_count=1,
    # training split
    train_path="train",
    train_annotation_filename="train.manifest",
    # evaluation split
    test_path="test",
    test_annotation_filename="test.manifest",
)

In [None]:
# SageMaker extracts each metric from the training job's log output using its
# regex. Every metric is logged as "<name>: <number>", and the first capture
# group holds the (optionally signed, optionally exponent) number.
_NUMBER_PATTERN = '([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'
metric_definitions = [
    {'Name': metric_name, 'Regex': metric_name + ': ' + _NUMBER_PATTERN}
    for metric_name in ('Epoch', 'train_loss', 'test_loss', 'test accuracy', 'mae')
]

In [None]:
from sagemaker.mxnet import MXNet

session = sagemaker.session.Session()
role = sagemaker_config["SageMakerIamRole"]

# Train the word-segmentation model (src/word_and_line_segmentation.py) on one
# instance. Model artefacts go under word_segmentation_training/ in the bucket.
# NOTE(review): train_instance_type / train_instance_count are SageMaker Python
# SDK v1 argument names (v2 renamed them to instance_type / instance_count).
# Confirm the installed SDK version.
estimator = MXNet(
    # training code and its configuration
    entry_point='word_and_line_segmentation.py',
    source_dir='src',
    hyperparameters=word_segmentation_parameters,
    metric_definitions=metric_definitions,
    # MXNet 1.6 / Python 3 framework container
    framework_version='1.6.0',
    py_version='py3',
    # compute and permissions
    role=role,
    train_instance_type=sagemaker_config["SageMakerTrainingInstanceType"],
    train_instance_count=1,
    sagemaker_session=session,
    # job naming and output location
    base_job_name=sagemaker_config["SolutionPrefix"] + "-word-seg",
    output_path="s3://{}/word_segmentation_training/".format(s3_bucket),
)

# The single "train" channel points at custom_data/training/, the parent of
# both the train/ and test/ split folders.
estimator.fit({"train": "s3://{}/custom_data/training/".format(s3_bucket)})

# Handwriting recognition
In this section we will train a handwriting recognition model to transcribe all the words in a line.
You can find the details of the model and algorithm in `src/handwriting_line_recognition.py`.

In [None]:
# Metrics SageMaker scrapes from the recognition training log. Each is logged as
# "<name>: <number>", and the first capture group holds the number.
_NUMBER_PATTERN = '([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'
metric_definitions = [
    {'Name': metric_name, 'Regex': metric_name + ': ' + _NUMBER_PATTERN}
    for metric_name in ('Epoch', 'train_loss', 'test_loss')
]

In [None]:
# Hyperparameters passed to src/handwriting_line_recognition.py.
handwriting_recognition_parameters = dict(
    # optimiser
    learning_rate=0.00005,
    # data augmentation
    random_x_translation=0.10,
    random_y_translation=0.10,
    random_x_scaling=0.01,
    random_y_scaling=0.1,
    batch_size=1,
    # recurrent decoder size and recognition granularity
    rnn_layers=1,
    rnn_hidden_states=128,
    line_or_word="word",
    # training split
    train_path="train",
    train_annotation_filename="train.manifest",
    # evaluation split
    test_path="test",
    test_annotation_filename="test.manifest",
)

In [None]:
from sagemaker.mxnet import MXNet

session = sagemaker.session.Session()
role = sagemaker_config["SageMakerIamRole"]

# Train the handwriting-recognition model (src/handwriting_line_recognition.py)
# on one instance. Model artefacts go under handwriting_line_recognition/ in the
# bucket.
# NOTE(review): train_instance_type / train_instance_count are SageMaker Python
# SDK v1 argument names (v2 renamed them to instance_type / instance_count).
# Confirm the installed SDK version.
estimator = MXNet(
    # training code and its configuration
    entry_point='handwriting_line_recognition.py',
    source_dir='src',
    hyperparameters=handwriting_recognition_parameters,
    metric_definitions=metric_definitions,
    # MXNet 1.6 / Python 3 framework container
    framework_version='1.6.0',
    py_version='py3',
    # compute and permissions
    role=role,
    train_instance_type=sagemaker_config["SageMakerTrainingInstanceType"],
    train_instance_count=1,
    sagemaker_session=session,
    # job naming and output location
    base_job_name=sagemaker_config["SolutionPrefix"] + "-line-reg",
    output_path="s3://{}/handwriting_line_recognition/".format(s3_bucket),
)

# Same input layout as the word-segmentation job: one "train" channel holding
# both split folders.
estimator.fit({"train": "s3://{}/custom_data/training/".format(s3_bucket)})

# Navigation
- Click [here](./2_label_own_data.ipynb) to create a new labelling job
- Click [here](./5_endpoint_updates.ipynb) to create a SageMaker endpoint with your new model