# Training Pipeline Example

## 1. Imports

In [1]:
%reload_ext dotenv
%dotenv

import io
import os
import warnings

# Registered before the project imports below, presumably so that paramiko
# warnings raised while importing them are suppressed — keep this ordering.
warnings.filterwarnings(action='ignore', module='.*paramiko.*')

from src.helpers.widget_logs import start_log_viewer
from src.training_pipeline import train_pipeline
from src.types.hyperparameters import HyperParameters
from src.types.model_metadata import ModelMetadata

## 2. Defining model parameters and reward function

In [2]:
# Default values set from official documentation.
# Both configuration objects are constructed with no arguments, so every field
# keeps its class default; the next two cells display the resulting values.
# `model_name` also appears as the folder prefix in the training output paths
# (e.g. rl-deepracer-sagemaker/reward_function.py).
model_name = 'rl-deepracer-sagemaker'
hyperparameters = HyperParameters()
model_metadata = ModelMetadata()

In [3]:
# Bare last expression: shows the HyperParameters repr so the defaults in use are visible.
hyperparameters

HyperParameters(batch_size=64, beta_entropy=0.01, discount_factor=0.999, e_greedy_value=0.05, epsilon_steps=10000, exploration_type=<ExplorationType.CATEGORICAL: 'categorical'>, loss_type=<LossType.HUBER: 'huber'>, lr=0.0003, num_episodes_between_training=40, num_epochs=3, stack_size=1, term_cond_avg_score=100000, term_cond_max_episodes=100000)

In [4]:
# Bare last expression: shows the ModelMetadata repr (action space, algorithm, network, sensors).
model_metadata

ModelMetadata(action_space_type=<ActionSpaceType.CONTINUOUS: 'continuous'>, action_space=ContinuousActionSpace(steering_angle=SteeringAngle(high=30.0, low=-30.0), speed=Speed(high=4.0, low=1.0)), version=5, training_algorithm=<TrainingAlgorithm.PPO: 'ppo'>, neural_network=<NeuralNetwork.DEEP_CONVOLUTIONAL_NETWORK_SHALLOW: 'DEEP_CONVOLUTIONAL_NETWORK_SHALLOW'>, sensor=[<Sensor.FRONT_FACING_CAMERA: 'FRONT_FACING_CAMERA'>])

In [5]:
def reward_function(params):
  """Return a constant reward of ``1.0`` regardless of the input.

  Args:
    params: Accepted to satisfy the expected signature but never read by
      this placeholder implementation.

  Returns:
    float: Always ``1.0``.
  """
  # Placeholder reward: every call yields the same value.
  return float(1.0)

## 3. Pipelines

### Training

In [6]:
# Launch training for `model_name` using the hyperparameters, model metadata
# and reward function defined in the cells above.
# NOTE(review): overwrite=True presumably replaces artifacts left by a previous
# run with the same model name — confirm in src.training_pipeline.
train_pipeline(
  model_name=model_name,
  hyperparameters=hyperparameters,
  model_metadata=model_metadata,
  reward_function=reward_function,
  overwrite=True
)

Data uploaded successfully to custom files
The reward function copied successfully to models folder at rl-deepracer-sagemaker/reward_function.py
Verified: Training params file exists at rl-deepracer-sagemaker/training_params.yaml
Upload successfully the RoboMaker training configurations
Exposing the envs from config.env and system.env
Starting model training


In [7]:
# Open an interactive log viewer widget filtered to the "rl_coach" service.
# `start_log_viewer` is imported in the notebook's import cell so that all
# dependencies are declared in one place and fail fast on a fresh kernel.
# NOTE(review): wait_time / refresh_interval are presumably in seconds and
# tail is presumably a line count — confirm in src.helpers.widget_logs.
start_log_viewer(
    service_filter="rl_coach",
    wait_time=30,
    refresh_interval=2.0,
    tail=100,
)

Textarea(value='', description='Logs (rl_coach):', layout=Layout(height='500px', width='100%'))

HBox(children=(Button(button_style='success', description='Start Refresh', style=ButtonStyle()), Button(button…