TODO:
 * Stepper motor skipping steps--more current (better power supply)?
 * Normalize action and observation spaces (see: https://ai.stackexchange.com/questions/21477/why-do-we-also-need-to-normalize-the-actions-values-on-continuous-action-spaces)

In [1]:
# !python -m pip install gymnasium==0.28.1
# !python -m pip install stable-baselines3[extra]==2.1.0
# !python -m pip install ax-platform==0.3.4
# !python -m pip install wandb
# !python -m pip install onnx==1.14.1
# !python -m pip install onnxruntime==1.16.0

In [2]:
# Check versions
import importlib.metadata

print(f"torch version: {importlib.metadata.version('torch')}")
print(f"gymnasium version: {importlib.metadata.version('gymnasium')}")
print(f"sb3 version: {importlib.metadata.version('stable-baselines3')}")
print(f"cv2 version: {importlib.metadata.version('opencv-python')}")
print(f"ax version: {importlib.metadata.version('ax-platform')}")

torch version: 1.13.1
gymnasium version: 0.28.1
sb3 version: 2.1.0
cv2 version: 4.8.0.76
ax version: 0.3.4


In [16]:
# Python Standard Library
import time
import datetime
import os
import random
import logging
import math
import csv
from typing import Any, Dict, Tuple, Union

# Encoder and stepper controls (local)
from control_comms import ControlComms, StatusCode, DebugLevel

# Third-party packages
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import wandb
import torch as th
import onnx
import onnxruntime as ort

# Reinforcement model modules
import stable_baselines3 as sb3
from stable_baselines3.common import env_checker
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import KVWriter, Logger

# Meta Ax
from ax.service.ax_client import AxClient
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render
from ax.utils.tutorials.cnn_utils import train, evaluate

## Settings

In [17]:
# Communication settings
SERIAL_PORT = "COM4"    # Check your devices
BAUD_RATE = 1_000_000   # Must match what's in the Arduino code!
CTRL_TIMEOUT = 2.0      # Seconds
DEBUG_LEVEL = DebugLevel.DEBUG_ERROR

# Reinforcement learning settings
K_T = 1                 # Reward constant to multiply theta (angle of encoder)
K_DT = 0.01             # Reward constant to multiply dtheto/dt (angular velocity of encoder)
K_P = 0.001             # Reward constant to multiply phi (angle of stepper)
K_DP = 0.00001          # Reward constant to multiply dphi/dt (angular velocity of stepper)
REWARD_OOB = -500       # Reward (penalty) for having the stepper motor move out of bounds (OOB)
REWARD_LAND = 100       # Reward for having the pendulum softly reach the top of the swing up
REWARD_CRASH = -200     # Reward (penalty) for having the pendulum swing too fast into the crash zone
ENC_ANGLE_NORM = 180    # Divide by this to normalize +/-180 deg angle to +/-1
ENC_GOAL_ANGLE = 5      # Reward agent and end episode if pendulum is within the goal (+/-5 deg)
ENC_GOAL_VELOCITY = 540 # ...and velocity is <= this amount (deg/sec)
ENC_CRASH_ANGLE = 45    # Penalize agent and end episode if pendulum is within the crash zone (+/- 45 deg)
ENC_CRASH_VELOCITY = 540 # ...and veolicy is > this amount (deg/sec)
STP_ACTIONS_MAP = {
    0: -10,
    1: 0,
    2: 10,
}
STP_ANGLE_MIN = -180    # Episode ends if stepper goes beyond this angle
STP_ANGLE_MAX = 180     # Episode ends if stepper goes beyond this angle
STP_ANGLE_NORM = 180    # Divide by this to normalize +/-180 deg angle to +/-1

# Environment settings
ENV_TIMEOUT = 30.0
RESET_SETTLE_TIME = 8.0 # Seconds to wait after reset to start moving again

# Angle constants
ENC_OFFSET = 180.0      # Pendulum in the "up" position should be 0 deg
ANG_REV = 360           # Degrees in a single revolution

# Conversion settings
REP_SAMPLE_SET_PATH = "rep-sample.npy"  # Where to save representative sample set
OUT_ONNX_PATH = "pendulum-policy.onnx"  # Where to save the final policy network

In [18]:
# Communication constants
CMD_SET_HOME = 0        # Set current stepper position as home (0 deg)
CMD_MOVE_TO = 1         # Move stepper to a particular position (deg)
CMD_MOVE_BY = 2         # Move stepper by a given amount (deg)
CMD_SET_STEP_MODE = 3   # Set step mode
CMD_SET_BLOCK_MODE = 4  # Set blocking mode
CMD_NOP = 5             # Take no action, just receive observation
CMD_MOVE_HOME = 6       # Slowly move pendulum back to starting position
STEP_MODE_1 = 0         # 1 division per step
STEP_MODE_2 = 1         # 2 divisions per step
STEP_MODE_4 = 2         # 4 divisions per step
STEP_MODE_8 = 3         # 8 divisions per step
STEP_MODE_16 = 4        # 16 divisions per step
STATUS_OK = 0           # Stepper idle
STATUS_STP_MOVING = 1   # Stepper is currently moving

# Set to desired step mode
STEP_MODE = STEP_MODE_2

## Setup

In [19]:
# Close connection to Arduino board (if open)
try:
    controller.close()
except:
    pass

In [20]:
# Connect to Arduino board
controller = ControlComms(timeout=CTRL_TIMEOUT, debug_level=DEBUG_LEVEL)
ret = controller.connect(SERIAL_PORT, BAUD_RATE)
if ret is not StatusCode.OK:
    print("ERROR: Could not connect to board")

In [21]:
# Test basic comms
controller.step(CMD_SET_STEP_MODE, [STEP_MODE_8])
controller.step(CMD_SET_HOME, [0])
controller.step(CMD_SET_BLOCK_MODE, [1])
for i in range(10):
    controller.step(CMD_MOVE_BY, [10])

Error parsing message. Message received: Observation: 0.94, -0.00, 0.08, 0.02, dtime: 2.350

Error parsing message. Message received: Observation: 0.90, -0.50, 0.13, 0.69, dtime: 0.080

Error parsing message. Message received: Observation: 0.90, -0.19, 0.13, 0.00, dtime: 0.009

Error parsing message. Message received: Observation: 0.90, -0.09, 0.19, 0.73, dtime: 0.075

Error parsing message. Message received: Observation: 0.90, 0.21, 0.19, 0.00, dtime: 0.008

Error parsing message. Message received: Observation: 0.92, 0.27, 0.24, 0.73, dtime: 0.075

Error parsing message. Message received: Observation: 0.92, 0.74, 0.24, 0.00, dtime: 0.009

Error parsing message. Message received: Observation: 0.96, 0.53, 0.30, 0.73, dtime: 0.075

Error parsing message. Message received: Observation: -0.99, 0.58, 0.36, 0.69, dtime: 0.080

Error parsing message. Message received: Observation: -0.88, 1.59, 0.31, -0.75, dtime: 0.067

Error parsing message. Message received: Observation: -0.80, 1.26, 0.25, 

In [22]:
# Numpy test
action = np.array([[-25]])
action_list = action.flatten().tolist()
controller.step(CMD_MOVE_BY, action_list)

Error parsing message. Message received: Observation: -0.76, -0.19, 0.21, 0.00, dtime: 0.009



In [23]:
# Test hard limit (360 deg)
for i in range(20):
    resp = controller.step(CMD_MOVE_BY, [50])
    print(resp)
    time.sleep(0.1)

Error parsing message. Message received: Observation: -0.76, -0.21, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: -0.77, -0.63, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: -0.77, -0.56, 0.21, 0.00, dtime: 0.009

None
Error parsing message. Message received: Observation: -0.78, -0.83, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: -0.79, -0.83, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: -0.79, -0.56, 0.21, 0.00, dtime: 0.009

None
Error parsing message. Message received: Observation: -0.80, -1.04, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: -0.81, -1.04, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: -0.82, -1.30, 0.21, 0.00, dtime: 0.009

None
Error parsing message. Message received: Observation: -0.83, -1.25, 0.21, 0.00, dtime: 0.008

None
Error pars

In [24]:
# Stress/torque test
resp = controller.step(CMD_MOVE_HOME, [0])
controller.step(CMD_SET_BLOCK_MODE, [1])
print(resp)
time.sleep(2.0)
action = 180
for i in range(10):
    action = -180 if action == 180 else 180
    resp = controller.step(CMD_MOVE_BY, [action])
    print(resp)
    time.sleep(0.01)

Error parsing message. Message received: Observation: -0.98, -2.08, 0.21, 0.00, dtime: 0.008

Error parsing message. Message received: Observation: -0.99, -1.87, 0.21, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: 1.00, -1.48, 0.21, 0.00, dtime: 0.009

None
Error parsing message. Message received: Observation: 0.83, -2.24, 0.26, 0.73, dtime: 0.075

None
Error parsing message. Message received: Observation: 0.81, -2.08, 0.26, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: 0.70, -1.53, 0.31, 0.73, dtime: 0.075

None
Error parsing message. Message received: Observation: 0.69, -0.74, 0.31, 0.00, dtime: 0.009

None
Error parsing message. Message received: Observation: 0.68, -0.83, 0.31, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: 0.68, -0.42, 0.31, 0.00, dtime: 0.008

None
Error parsing message. Message received: Observation: 0.68, -0.19, 0.31, 0.00, dtime: 0.009

None
Error parsing message. 

In [25]:
# Comms stress test
# resp = controller.step(CMD_MOVE_TO, [0])
# print(resp)
# time.sleep(2.0)
# for i in range(100000):
#     resp = controller.step(CMD_MOVE_BY, [0])
#     time.sleep(0.001)

In [26]:
# Move home
resp = controller.step(CMD_MOVE_HOME, [0])
print(resp)

Error parsing message. Message received: Observation: 0.68, 0.00, 0.31, 0.00, dtime: 0.009

None


In [27]:
# Close comms
controller.close()

In [None]:
# Basic step test
action = np.array([[-10]])

controller = ControlComms(timeout=CTRL_TIMEOUT, debug_level=DEBUG_LEVEL)
ret = controller.connect(SERIAL_PORT, BAUD_RATE)
if ret is not StatusCode.OK:
    print("ERROR: Could not connect to board")

# Set to blocking mode
resp = controller.step(CMD_SET_BLOCK_MODE, [1])
print(resp)
    
# Reset
time.sleep(2.0)
resp = controller.step(CMD_MOVE_HOME, [0.0])
print(resp)
time.sleep(2.0)

# Loop
for i in range(10):
    action_list = action.flatten().tolist()
    resp = controller.step(CMD_MOVE_BY, action_list)
    print(resp)
    time.sleep(0.1)

# Move home and close
resp = controller.step(CMD_MOVE_TO, [0])
print(resp)
controller.close()

In [14]:
# Log in to Weights & Biases
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mshawnhymel[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [15]:
# Make wandb be quiet
os.environ["WANDB_SILENT"] = "true"
logger = logging.getLogger("wandb")
logger.setLevel(logging.ERROR)

## Helper functions

In [16]:
def set_random_seeds(seed: int) -> None:
    """
    Seed the different random generators.
    https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/utils.html
    """
    
    # Set seed for Python random and NumPy
    random.seed(seed)
    np.random.seed(seed)

In [17]:
def calc_angular_velocity(ang, ang_prev, dt):
    """
    Estimate engular velocity based on current and previous readings. Note that we assume that the
    object in question cannot go the long way around (e.g. more than 180 deg).
    """
    da = ang - ang_prev
    if da > (ANG_REV / 2):
        da -= ANG_REV
    elif da < -(ANG_REV / 2):
        da += ANG_REV
    
    return da / dt

## Build gym Environment

Subclass gymnasium.Env to create a custom environment. Learn more here:<br>
https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/

In [28]:
class Pendulum(gym.Env):
    """
    Subclass gymnasium Env class
    
    This is the gym wrapper class that allows our agent to interact with our environment. We need
    to implement four main methods: step(), reset(), render(), and close(). We should also define
    the action_space and observation space as class members.
    
    Note: on Windows, time.sleep() is only accurate to around 10ms. As a result, setting fps_limit
    will give you a "best effort" limit.
    
    More information: https://gymnasium.farama.org/api/env/
    """
    
    def __init__(
        self,
        serial_port,
        baud_rate,
        ctrl_timeout=1.0,
        debug_level=DebugLevel.DEBUG_NONE,
        env_timeout=0.0, 
        stp_mode=STEP_MODE_8, 
        stp_blocking=False
    ):
        """
        Set up the environment, action, and observation shapes. Optional tiemout in seconds.
        """
        
        # Call superclass's constructor
        super().__init__()
        
        # Connect to Arduino board
        self.ctrl = ControlComms(timeout=ctrl_timeout, debug_level=debug_level)
        try:
            self.ctrl.close()
        except:
            pass
        ret = self.ctrl.connect(serial_port, baud_rate)
        if ret is not StatusCode.OK:
            print("ERROR: Could not connect to board")
        
        # Define action space (scalar signifying how many degrees to move stepper by)
        self.action_space = gym.spaces.Discrete(len(STP_ACTIONS_MAP))
        
        # Define observation space 
        # [encoder angle, encoder angular velocity, stepper angle, stepper angular velocity]
        self.observation_space = gym.spaces.Box(
            low=np.array([-180, -np.inf, STP_ANGLE_MIN, -np.inf]),
            high=np.array([180, np.inf, STP_ANGLE_MAX, np.inf]),
            dtype=np.float32
        )
        
        # Record time from microcontroller and own elapsed time
        self.timestamp = 0
        self.timeout = env_timeout
        self.start_time = time.time()
        
        # Record previous encoder and stepper angles (to calculate velocities)
        self.angle_stp_prev = 0
        self.angle_enc_prev = 0
        
        # Set current stepper position as "home" and optionally set blocking
        self.ctrl.step(CMD_SET_STEP_MODE, [stp_mode])
        self.ctrl.step(CMD_SET_HOME, [0])
        if stp_blocking:
            self.ctrl.step(CMD_SET_BLOCK_MODE, [1])
        else:
            self.ctrl.step(CMD_SET_BLOCK_MODE, [0])
    
    def __del__(self):
        """
        Destructor: make sure to close the serial port
        """
        self.close()
    
    def step(self, action: np.ndarray):
        """
        What happens when you tell the stepper motor to do something then record the observation.
        """
        
        # Initialize return values
        obs = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
        reward = 0.0
        info = {"error": False, "dtime": 0.0, "elapsed_time": 0.0}
        terminated = False
        truncated = False
        
        # Move the stepper motor and wait for a response
        resp = self.ctrl.step(CMD_MOVE_BY, [STP_ACTIONS_MAP[action]])
        
        # Figure out reward and episode termination based on observation
        if resp:
            
            # Extract information from controller response
            status, timestamp, terminated, angles = resp
            
            # Compute lapsed time (in seconds) from previous observation (milliseconds)
            info["dtime"] = (timestamp - self.timestamp) / 1000.0
            self.timestamp = timestamp
            
            # Offset encoder angle so that 0 deg is up
            angles[0] -= ENC_OFFSET
            
            # Calculate velocities
            dtheta = calc_angular_velocity(angles[0], self.angle_enc_prev, info['dtime'])
            dphi = calc_angular_velocity(angles[1], self.angle_stp_prev, info['dtime'])
            self.angle_enc_prev = angles[0]
            self.angle_stp_prev = angles[1]
            
            # Construct observation (normalized)
            obs[0] = angles[0] / ENC_ANGLE_NORM
            obs[1] = dtheta / ENC_ANGLE_NORM
            obs[2] = angles[1] / STP_ANGLE_NORM
            obs[3] = dphi / STP_ANGLE_NORM

            # Make sure stepper is in bounds
            if (angles[1] >= STP_ANGLE_MIN) and (angles[1] <= STP_ANGLE_MAX):
                
                # Calculate reward
                reward += -1 * (K_T * obs[0] ** 2 + 
                               K_DT * obs[1] ** 2 + 
                               K_P * obs[2] ** 2 +
                               K_DP * obs[3] ** 2)
                
                
                # If the pendulum is moving too fast in the crash zone, penalize and end
                if (abs(angles[0]) <= ENC_CRASH_ANGLE) and (abs(dtheta) > ENC_CRASH_VELOCITY):
                    reward += REWARD_CRASH
                    terminated = True

                # If the pendulum is moving slow enough in the goal zone, reward and end
                elif (abs(angles[0]) <= ENC_GOAL_ANGLE) and (abs(dtheta) <= ENC_GOAL_VELOCITY):
                    reward += REWARD_LAND
                    terminated = True
            
            # Stepper motor is out of bounds--terminate episode
            else:
                reward = REWARD_OOB
                terminated = True
        
        # Something is wrong with communication
        else:
            print("ERROR: Could not communicate with Arduino")
            info["error"] = True
            terminated = True
        
        # Calculate elapsed time
        info["elapsed_time"] = time.time() - self.start_time
        
        # Check if we've exceeded the time limit
        if not terminated and self.timeout > 0.0 and info["elapsed_time"] >= self.timeout:
            truncated = True
        
        return obs, reward, terminated, truncated, info
    
    def reset(self, seed=None):
        """
        Return the pendulum to the starting position
        """
        
        # Initialize return values
        obs = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
        info = {"error": False, "dtime": 0, "elapsed_time": 0.0}
        
        # Reset timer
        self.start_time = time.time()
        
        # Let the pendulum fall and return to the starting position
        time.sleep(2.0)
        resp = self.ctrl.step(CMD_MOVE_HOME, [0.0])
        if resp:
            
            # Extract information from controller response
            status, timestamp, terminated, angles = resp
            
            # Compute lapsed time (in seconds) from previous observation (milliseconds)
            info["dtime"] = (timestamp - self.timestamp) / 1000.0
            self.timestamp = timestamp
            
            # Offset encoder angle so that 0 deg is up
            angles[0] -= ENC_OFFSET
            
            # Calculate velocities
            dtheta = calc_angular_velocity(angles[0], self.angle_enc_prev, info['dtime'])
            dphi = calc_angular_velocity(angles[1], self.angle_stp_prev, info['dtime'])
            self.angle_enc_prev = angles[0]
            self.angle_stp_prev = angles[1]
            
            # Construct observation (normalized)
            obs[0] = angles[0] / ENC_ANGLE_NORM
            obs[1] = dtheta / ENC_ANGLE_NORM
            obs[2] = angles[1] / STP_ANGLE_NORM
            obs[3] = dphi / STP_ANGLE_NORM
            
            # Let pendulum settle for a bit
            time.sleep(RESET_SETTLE_TIME)
            
        # Something is wrong with communication
        else:
            print("ERROR: Could not communicate with Arduino")
            info["error"] = True
            
        # Calculate elapsed time
        info["elapsed_time"] = time.time() - self.start_time
        
        return obs, info
    
    def close(self):
        """
        Close connection to Arduino
        """
        self.ctrl.close()

## Test gym Environment

Test the gym wrapper before training

In [29]:
# Create our environment
try:
    env.close()
except:
    pass
env = Pendulum(
        SERIAL_PORT,
        BAUD_RATE,
        ctrl_timeout=CTRL_TIMEOUT,
        debug_level=DEBUG_LEVEL,
        env_timeout=ENV_TIMEOUT, 
        stp_mode=STEP_MODE, 
        stp_blocking=True
)

Error parsing message. Message received: 
Error parsing message. Message received: 
Error parsing message. Message received: 


In [30]:
# Check action space (should be what's on the STP_ACTIONS_MAP)
for i in range(10):
    print(env.action_space.sample())

0
0
2
0
0
2
1
2
1
2


In [31]:
# Test encoder
obs, info = env.reset()
obs_str = ", ".join([f"{val:.2f}" for val in obs])
if info["error"]:
    print("Stopping")
else:
    print(f"{'Step': ^8} | {'Observation': ^36} | {'Reward': ^8} | {'Done': ^8} | Info")
    print(f"{'Reset': ^8} | {obs_str: <36} | {0.0: <8} | {str(False): ^8} | {info}")
    for i in range(10):
        obs, reward, terminated, truncated, info = env.step(1)
        obs_str = ", ".join([f"{val:.2f}" for val in obs])
        print(f"{i: ^8} | {obs_str: <36} | {reward: <8.2f} | {str(terminated or truncated): ^8} | {info}")
        if info["error"]:
            print("Stopping")
            break
        if terminated or truncated:
            print("Episode done")
            break
        time.sleep(1.0)

Error parsing message. Message received: 

ERROR: Could not communicate with Arduino
Stopping


In [32]:
# Test encoder (no pause)
obs, info = env.reset()
obs_str = ", ".join([f"{val:.2f}" for val in obs])
if info["error"]:
    print("Stopping")
else:
    print(f"{'Step': ^8} | {'Observation': ^36} | {'Reward': ^8} | {'Done': ^8} | Info")
    print(f"{'Reset': ^8} | {obs_str: <36} | {0.0: <8} | {str(False): ^8} | {info}")
    for i in range(20):
        obs, reward, terminated, truncated, info = env.step(1)
        obs_str = ", ".join([f"{val:.2f}" for val in obs])
        print(f"{i: ^8} | {obs_str: <36} | {reward: <8.2f} | {str(terminated or truncated): ^8} | {info}")
        if info["error"]:
            print("Stopping")
            break
        if terminated or truncated:
            print("Episode done")
            break

Error parsing message. Message received: --- Start episode ---

ERROR: Could not communicate with Arduino
Stopping


In [23]:
# Run some random steps
actions = []
rewards = []
obs, info = env.reset()
if info["error"]:
    print("Stopping")
else:
    print(f"{'Step': ^8} | {'Action': ^8} | {'Observation': ^36} | {'Reward': ^8} | {'Done': ^8} | Info")
    print(f"{'Reset': ^8} | {0: ^8} | {obs_str: <36} | {0.0: <8} | {str(False): ^8} | {info}")
    for i in range(20):
        action = env.action_space.sample()
        actions.append(action)
        obs, reward, terminated, truncated, info = env.step(action)
        obs_str = ", ".join([f"{val:.2f}" for val in obs])
        rewards.append(reward)
        print(f"{i: ^8} | {action: ^8} | {obs_str: <36} | {reward: <8.2f} | {str(terminated or truncated): ^8} | {info}")
        if info["error"]:
            print("Stopping")
            break
        if terminated or truncated:
            print("Episode done")
            break

# Print stats
action_mean = np.mean(np.array(actions))
action_std = np.std(np.array(actions))
print(f"Action mean: {action_mean}")
print(f"Action std dev: {action_std}")
print(f"Total reward: {sum(rewards)}")

  Step   |  Action  |             Observation              |  Reward  |   Done   | Info
 Reset   |    0     | 1.00, 0.00, 0.00, 0.00               | 0.0      |  False   | {'error': False, 'dtime': 2.025, 'elapsed_time': 10.032556772232056}
   0     |    0     | -0.96, 0.00, -0.05, -0.01            | -0.93    |  False   | {'error': False, 'dtime': 8.086, 'elapsed_time': 10.098191261291504}
   1     |    1     | -0.96, 0.28, -0.05, 0.00             | -0.93    |  False   | {'error': False, 'dtime': 0.012, 'elapsed_time': 10.11621618270874}
   2     |    1     | -0.96, -0.16, -0.05, 0.00            | -0.93    |  False   | {'error': False, 'dtime': 0.021, 'elapsed_time': 10.131706714630127}
   3     |    1     | -0.97, -0.11, -0.05, 0.00            | -0.93    |  False   | {'error': False, 'dtime': 0.015, 'elapsed_time': 10.150167226791382}
   4     |    0     | -0.94, 0.31, -0.10, -0.72            | -0.89    |  False   | {'error': False, 'dtime': 0.069, 'elapsed_time': 10.216284990310669}
 

In [24]:
# Test timeout (stepper will reset and vibrate for a while)
obs, info = env.reset()
print(f"obs: {obs}, info: {info}")
action = 0
if not info["error"]:
    for i in range(1000):
        action = 0 if action == 2 else 2
        obs, reward, terminated, truncated, info = env.step(action)
        # print(f"{i: ^8} | {obs_str: <36} | {reward: <8.2f} | {str(terminated or truncated): ^8} | {info}")
        if terminated or truncated:
            print("Episode done")
            break

obs: [-0.8833333   0.13091879  0.          0.00940734], info: {'error': False, 'dtime': 2.126, 'elapsed_time': 10.106022596359253}
Episode done


In [25]:
# Final environment check to make sure it works with Stable-Baselines3 (no errors means it worked)
env_checker.check_env(env)

In [26]:
# Close the environment
env.close()

## Define Test

In [27]:
# Function that tests the model in the given environment
def test_agent(env, model, max_steps=0):

    # Reset environment
    obs, info = env.reset()
    ep_len = 0
    ep_rew = 0
    avg_step_time = 0.0
    actions = []

    # Run episode until complete
    while True:

        # Provide observation to policy to predict the next action
        timestamp = time.time()
        action, _ = model.predict(obs)
        action = int(action)
        actions.append(action)

        # Perform action, update total reward
        obs, reward, terminated, truncated, info = env.step(action)
        avg_step_time += time.time() - timestamp
        ep_rew += reward

        # Increase step counter
        ep_len += 1
        if (max_steps > 0) and (ep_len >= max_steps):
            break

        # Check to see if episode has ended
        if terminated or truncated:
            break
        
    # Calculate average step time
    avg_step_time /= ep_len
    
    # Calculate action stats
    action_mean = np.mean(np.array(actions))
    action_std = np.std(np.array(actions))
    
    return ep_len, ep_rew, avg_step_time, action_mean, action_std

## Testing and logging callbacks

Construct custom callbacks for Stable-Baselines3 to test our agent and log metrics to Weights & Biases.

In [33]:
# Evaluate agent on a number of tests
def evaluate_agent(env, model, steps_per_test, num_tests):
    
    # Initialize metrics
    avg_ep_len = 0.0
    avg_ep_rew = 0.0
    avg_step_time = 0.0
    avg_action_mean = 0.0
    avg_action_std = 0.0
    
    # Test the agent a number of times
    for ep in range(num_tests):
        ep_len, ep_rew, step_time, action_mean, action_std = test_agent(env, model, max_steps=steps_per_test)
        avg_ep_len += ep_len
        avg_ep_rew += ep_rew
        avg_step_time += step_time
        avg_action_mean += action_mean
        avg_action_std += action_std
        
    # Compute metrics
    avg_ep_len /= num_tests
    avg_ep_rew /= num_tests
    avg_step_time /= num_tests
    avg_action_mean /= num_tests
    avg_action_std /= num_tests
    
    return avg_ep_len, avg_ep_rew, avg_step_time, avg_action_mean, avg_action_std

In [34]:
class EvalAndSaveCallback(BaseCallback):
    """
    Evaluate and save the model every ``check_freq`` steps
    
    More info: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
    """
    
    # Constructor
    def __init__(
        self, 
        check_freq, 
        save_dir,
        model_name="model",
        replay_buffer_name=None,
        steps_per_test=0, 
        num_tests=10,
        step_offset=0,
        verbose=1,
    ):
        super(EvalAndSaveCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_dir = save_dir
        self.model_name = model_name
        self.replay_buffer_name = replay_buffer_name
        self.num_tests = num_tests
        self.steps_per_test = steps_per_test
        self.step_offset = step_offset
        self.verbose = verbose
        
    # Create directory for saving the models
    def _init_callback(self):
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)
            
    # Save and evaluate model at a set interval
    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            
            # Set actual number of steps (including offset)
            actual_steps = self.step_offset + self.n_calls
            
            # Save model
            model_path = os.path.join(self.save_dir, f"{self.model_name}_{str(actual_steps)}")
            self.model.save(model_path)
            
            # Save replay buffer
            if self.replay_buffer_name != None:
                replay_buffer_path = os.path.join(self.save_dir, f"{self.replay_buffer_name}")
                self.model.save_replay_buffer(replay_buffer_path)
            
            # Evaluate the agent
            avg_ep_len, avg_ep_rew, avg_step_time, avg_action_mean, avg_action_std = evaluate_agent(
                env, 
                self.model, 
                self.steps_per_test, 
                self.num_tests
            )
            if self.verbose:
                print(f"{str(actual_steps)} steps | average test length: {avg_ep_len}, average test reward: {avg_ep_rew}")
                
            # Log metrics to WandB
            log_dict = {
                'avg_ep_len': avg_ep_len,
                'avg_ep_rew': avg_ep_rew,
                'avg_step_time': avg_step_time,
                'avg_action_mean': avg_action_mean,
                'avg_action_std': avg_action_std,
            }
            wandb.log(log_dict, commit=True, step=actual_steps)
            
        return True

In [35]:
class WandBWriter(KVWriter):
    """
    Log metrics to Weights & Biases when called by .learn()
    
    More info: https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/logger.html#KVWriter
    """
    
    # Initialize run
    def __init__(self, run, verbose=1):
        super().__init__()
        self.run = run
        self.verbose = verbose

    # Write metrics to W&B project
    def write(self, 
              key_values: Dict[str, Any], 
              key_excluded: Dict[str, Union[str, Tuple[str, ...]]], 
              step: int = 0) -> None:
        log_dict = {}
        
        # Go through each key/value pairs
        for (key, value), (_, excluded) in zip(
            sorted(key_values.items()), sorted(key_excluded.items())):
            
            if self.verbose >= 2:
                print(f"step={step} | {key} : {value} ({type(value)})")
            
            # Skip excluded items
            if excluded is not None and "wandb" in excluded:
                continue
                
            # Log integers and floats
            if isinstance(value, np.ScalarType):
                if not isinstance(value, str):
                    wandb.log(data={key: value}, step=step)
                    log_dict[key] = value
                
        # Print to console
        if self.verbose >= 1:
            print(f"Log for steps={step}")
            print(f"--------------")
            for (key, value) in sorted(log_dict.items()):
                print(f"  {key}: {value}")
            print()
                
    # Close the W&B run
    def close(self) -> None:
        self.run.finish()

## Define train and test function for a single trial

A single "trial" is fully training and then testing the agent using one set of hyperparameters.

In [36]:
def do_trial(settings, hparams):
    """
    Training loop used to evaluate a set of hyperparameters
    """
    
    # Set random seed
    set_random_seeds(settings['seed'])
    
    # Create new W&B run
    config = {}
    dt = datetime.datetime.now(datetime.timezone.utc)
    dt = dt.replace(microsecond=0, tzinfo=None)
    run = wandb.init(
        project=settings['wandb_project'], 
        name=str(dt), 
        config=config,
        settings=wandb.Settings(silent=(not settings['verbose_wandb']))
    )

    # Print run info
    if settings['verbose_trial'] > 0:
        print(f"WandB run ID: {run.id}")
        print(f"WandB run name: {run.name}")
    
    # Log hyperparameters to W&B
    wandb.config.update(hparams)
    
    # Set custom logger with our custom writer
    wandb_writer = WandBWriter(run, verbose=settings['verbose_log'])
    loggers = Logger(
        folder=None,
        output_formats=[wandb_writer]
    )
    
    # Calculate derived hyperparameters
    n_steps = 2 ** hparams['steps_per_update_pow2']
    minibatch_size = (hparams['n_envs'] * n_steps) // (2 ** hparams['batch_size_div_pow2'])
    layer_1 = 2 ** hparams['layer_1_pow2']
    layer_2 = 2 ** hparams['layer_2_pow2']

    # Create new agent
    # PPO docs: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
    # Policy networks: https://stable-baselines.readthedocs.io/en/master/modules/policies.html
    model = sb3.PPO(
        'MlpPolicy',
        env,
        learning_rate=hparams['learning_rate'], # Learning rate of neural network (default: 0.0003)
        n_steps=n_steps,                        # Number of steps per update (default: 2048)
        batch_size=minibatch_size,              # Minibatch size for NN update (default: 64)
        gamma=hparams['gamma'],                 # Discount factor (default: 0.99)
        gae_lambda=hparams['gae_lambda'],       # Trade-off of bias vs. variance for GAE (default: 0.95)
        clip_range=hparams['clip_range'],       # Clipping parameter (default: 0.2)
        ent_coef=hparams['entropy_coef'],       # Entropy, how much to explore (default: 0.0)
        vf_coef=hparams['vf_coef'],             # Value function coefficient for the loss calculation (default: 0.5)
        max_grad_norm=hparams['max_grad_norm'], # Max value for gradient clipping (default: 0.5)
        use_sde=hparams['use_sde'],             # Use generalized State Dependent Exploration (default: False)
        sde_sample_freq=hparams['sde_freq'],    # Number of steps before sampling new noise matrix (default -1)
        policy_kwargs={
            'activation_fn': th.nn.ReLU,        # (default: th.nn.Tanh)
            'net_arch': [layer_1, layer_2]      # (default: [64, 64])
        }, 
        verbose=settings['verbose_train']       # Print training metrics (default: 0)
    )
    steps_to_complete = settings['total_steps']
        
    # Set up checkpoint callback
    checkpoint_callback = EvalAndSaveCallback(
        check_freq=settings['checkpoint_freq'], 
        save_dir=settings['save_dir'],
        model_name=settings['model_name'],
        replay_buffer_name=settings['replay_buffer_name'],
        steps_per_test=settings['steps_per_test'],
        num_tests=settings['tests_per_check'],
        step_offset=(settings['total_steps'] - steps_to_complete),
        verbose=settings['verbose_test'],
    )
    
    # Choo choo train
    model.learn(total_timesteps=steps_to_complete, 
                callback=[checkpoint_callback])
    
    # Get dataframe of run metrics
    history = wandb.Api().run(f"{run.project}/{run.id}").history()

    # Get index of evaluation with maximum reward
    max_idx = np.argmax(history.loc[:, 'avg_ep_rew'].values)

    # Find number of steps required to produce that maximum reward
    max_rew_steps = history['_step'][max_idx]
    if settings['verbose_trial'] > 0:
        print(f"Steps with max reward: {max_rew_steps}")
    
    # Load model with maximum reward from previous run
    model_path = os.path.join(settings['save_dir'], f"{settings['model_name']}_{str(max_rew_steps)}.zip")
    model = sb3.PPO.load(model_path, env)
    
    # Evaluate the agent
    avg_ep_len, avg_ep_rew, avg_step_time, avg_action_mean, avg_action_std = evaluate_agent(
        env, 
        model, 
        settings['steps_per_test'],
        settings['tests_per_check'],
    )
    
    # Log final evaluation metrics to WandB run
    wandb.run.summary['Average test episode length'] = avg_ep_len
    wandb.run.summary['Average test episode reward'] = avg_ep_rew
    wandb.run.summary['Average test step time'] = avg_step_time
    
    # Print final run metrics
    if settings['verbose_trial'] > 0:
        print("---")
        print(f"Best model: {settings['model_name']}_{str(max_rew_steps)}.zip")
        print(f"Average episode length: {avg_ep_len}")
        print(f"Average episode reward: {avg_ep_rew}")
        print(f"Average step time: {avg_step_time}")
        print(f"Average action mean: {avg_action_mean}")
        print(f"Average action std dev: {avg_action_std}")
                      
    # Close W&B run
    run.finish()
    
    return avg_ep_rew

## Perform trials

In [37]:
# Project settings that do not change
settings = {
    'wandb_project': "pendulum-esp32-hpo-5",
    'model_name': "ppo-pendulum",
    'ax_experiment_name': "ppo-pendulum-esp32-5",
    'ax_objective_name': "avg_ep_rew",
    'replay_buffer_name': None,
    'save_dir': "checkpoints",
    'checkpoint_freq': 5_000,
    'steps_per_test': 500,
    'tests_per_check': 10,
    'total_steps': 50_000,
    'num_trials': 100,
    'seed': 42,
    'verbose_ax': False,
    'verbose_wandb': False,
    'verbose_train': 0,
    'verbose_log': 0,
    'verbose_test': 0,
    'verbose_trial': 1,
}

In [38]:
# Define hyperparameters we want to optimize
# Ref: https://github.com/facebook/Ax/blob/6443cee30cbf8cec290200a7420a3db08e4b5445/ax/service/ax_client.py#L236
# Example: https://github.com/facebook/Ax/blob/main/tutorials/tune_cnn_service.ipynb
# Hyperparameters: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.PPO
hparams = [
    {
        'name': "n_envs",
        'type': "fixed",
        'value_type': "int",
        'value': 1,
    },
    {
        'name': "learning_rate",
        'type': "range",
        'value_type': "float",
        'bounds': [1e-5, 1e-3],
        'log_scale': True,
    },
    {
        'name': "steps_per_update_pow2",
        'type': "range",
        'value_type': "int",
        'bounds': [8, 11], # Inclusive, 2**n between [256, 2048]
        'log_scale': False,
        'is_ordered': False,
    },
    {
        'name': "batch_size_div_pow2",
        'type': "range",
        'value_type': "int",
        'bounds': [0, 3], # Inclusive, 2**n between [512, 4096]
        'log_scale': False,
        'is_ordered': False,
    },
    {
        'name': "gae_lambda",
        'type': "range",
        'value_type': "float",
        'bounds': [0.9, 0.99],
        'log_scale': False,
    },
    {
        'name': "clip_range",
        'type': "range",
        'value_type': "float",
        'bounds': [0.1, 0.4],
        'log_scale': False,
    },
    {
        'name': "gamma",
        'type': "range",
        'value_type': "float",
        'bounds': [0.92, 0.99],
        'log_scale': False,
    },
    {
        'name': "entropy_coef",
        'value_type': "float",
        'type': "range",
        'bounds': [0.0, 0.01],
        'log_scale': False,
    },
    {
        'name': "vf_coef",
        'type': "range",
        'value_type': "float",
        'bounds': [0.2, 0.7],
        'log_scale': False,
    },
    {
        'name': "max_grad_norm",
        'type': "range",
        'value_type': "float",
        'bounds': [0.5, 5.0],
        'log_scale': False,
    },
    {
        'name': "use_sde",
        'type': "fixed",
        'value_type': "bool",
        'value': False,
    },
    {
        'name': "sde_freq",
        'type': "fixed",
        'value_type': "int",
        'value': -1,
    },
    {
        'name': "layer_1_pow2",
        'type': "fixed",
        'value_type': "int",
        'value': 8, # 2**n (is 256)
        'log_scale': False,
        'is_ordered': False,
    },
    {
        'name': "layer_2_pow2",
        'type': "fixed",
        'value_type': "int",
        'value': 8, # 2**n (is 256)
        'log_scale': False,
        'is_ordered': False,
    },
]

# Set parameter constraints
# Example: https://github.com/facebook/Ax/issues/621
parameter_constraints = []

In [39]:
# Create our environment
try:
    env.close()
except:
    pass
env = Pendulum(
        SERIAL_PORT,
        BAUD_RATE,
        ctrl_timeout=CTRL_TIMEOUT,
        debug_level=DEBUG_LEVEL,
        env_timeout=ENV_TIMEOUT, 
        stp_mode=STEP_MODE, 
        stp_blocking=True
)

Error parsing message. Message received: Observation: 0.90, -0.36, 0.21, 0.00, dtime: 0.709

Error parsing message. Message received: Observation: 0.91, 1.46, 0.21, 0.00, dtime: 0.008

Error parsing message. Message received: Observation: 0.99, 1.04, 0.27, 0.73, dtime: 0.075



In [35]:
# Cosntruct path to Ax experiment snapshot file
ax_snapshot_path = os.path.join(settings['save_dir'], f"{settings['ax_experiment_name']}.json")

In [36]:
# DANGER! Uncomment to delete the experiment file to start over

# os.remove(ax_snapshot_path)

In [37]:
# Load experiment from snapshot if it exists, otherwise create a new one
# Ref: https://ax.dev/versions/0.2.10/api/service.html#ax.service.ax_client.AxClient.create_experiment
if os.path.exists(ax_snapshot_path):
    print(f"Loading experiment from snapshot: {ax_snapshot_path}")
    ax_client = AxClient.load_from_json_file(ax_snapshot_path)
else:
    print(f"Creating new experiment. Snapshot to be saved at {ax_snapshot_path}.")
    ax_client = AxClient(
        random_seed=settings['seed'],
        verbose_logging=settings['verbose_ax'],
    )
    ax_client.create_experiment(
        name=settings['ax_experiment_name'],
        parameters=hparams,
        objective_name=settings['ax_objective_name'],
        minimize=False,
        parameter_constraints=parameter_constraints,
    )

[INFO 10-07 23:45:49] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.


Loading experiment from snapshot: checkpoints\ppo-pendulum-esp32-5.json


In [38]:
# DANGER! Use this cell to mark trials as failed (e.g. if component breaks and WandB shows bad data for a given trial)
# Check .json file with a site like https://jsonformatter.org/json-pretty-print

# trial_index = 5
# trial = ax_client.experiment.trials[trial_index]
# trial.mark_failed(unsafe=True)
# print(trial)
# ax_client.save_to_json_file(ax_snapshot_path)

In [39]:
# Choo choo! Perform trials to optimize hyperparameters
while True:
    
    # Get next hyperparameters and end experiment if we've reached max trials
    next_hparams, trial_index = ax_client.get_next_trial()
    if trial_index >= settings['num_trials']:
        break
        
    # Show that we're starting a new trial
    if settings['verbose_trial'] > 0:
        print(f"--- Trial {trial_index} ---")
        
    # Perform trial
    avg_ep_rew = do_trial(settings, next_hparams)
    ax_client.complete_trial(
        trial_index=trial_index,
        raw_data=avg_ep_rew,
    )
    
    # Save experiment snapshot
    ax_client.save_to_json_file(ax_snapshot_path)

[INFO 10-07 23:46:13] ax.service.ax_client: Generated new trial 1 with parameters {'learning_rate': 6.5e-05, 'steps_per_update_pow2': 10, 'batch_size_div_pow2': 1, 'gae_lambda': 0.964348, 'clip_range': 0.223857, 'gamma': 0.960777, 'entropy_coef': 0.007052, 'vf_coef': 0.215646, 'max_grad_norm': 3.956313, 'n_envs': 1, 'use_sde': False, 'sde_freq': -1, 'layer_1_pow2': 8, 'layer_2_pow2': 8}.


--- Trial 1 ---


WandB run ID: ybnog1su
WandB run name: 2023-10-08 05:46:13
Steps with max reward: 30000
---
Best model: ppo-pendulum_30000.zip
Average episode length: 310.7
Average episode reward: -230.7636959360869
Average step time: 0.06225318977041803
Average action mean: 0.9614089186781973
Average action std dev: 0.8115665986086205


[INFO 10-08 02:05:17] ax.service.ax_client: Completed trial 1 with data: {'avg_ep_rew': (-230.763696, None)}.
[INFO 10-08 02:05:17] ax.service.ax_client: Saved JSON-serialized state of optimization to `checkpoints\ppo-pendulum-esp32-5.json`.
[INFO 10-08 02:05:17] ax.service.ax_client: Generated new trial 2 with parameters {'learning_rate': 2.4e-05, 'steps_per_update_pow2': 9, 'batch_size_div_pow2': 2, 'gae_lambda': 0.906158, 'clip_range': 0.168439, 'gamma': 0.976861, 'entropy_coef': 0.001267, 'vf_coef': 0.43887, 'max_grad_norm': 3.731091, 'n_envs': 1, 'use_sde': False, 'sde_freq': -1, 'layer_1_pow2': 8, 'layer_2_pow2': 8}.


--- Trial 2 ---


WandB run ID: 8e5q1hzw
WandB run name: 2023-10-08 08:05:17
Steps with max reward: 50000
---
Best model: ppo-pendulum_50000.zip
Average episode length: 323.6
Average episode reward: -303.80948916444106
Average step time: 0.058976005315618585
Average action mean: 0.9826148590820922
Average action std dev: 0.7805947377105328


[INFO 10-08 04:25:27] ax.service.ax_client: Completed trial 2 with data: {'avg_ep_rew': (-303.809489, None)}.
[INFO 10-08 04:25:27] ax.service.ax_client: Saved JSON-serialized state of optimization to `checkpoints\ppo-pendulum-esp32-5.json`.
[INFO 10-08 04:25:27] ax.service.ax_client: Generated new trial 3 with parameters {'learning_rate': 0.000276, 'steps_per_update_pow2': 11, 'batch_size_div_pow2': 0, 'gae_lambda': 0.972363, 'clip_range': 0.349125, 'gamma': 0.943164, 'entropy_coef': 0.008766, 'vf_coef': 0.601399, 'max_grad_norm': 1.834309, 'n_envs': 1, 'use_sde': False, 'sde_freq': -1, 'layer_1_pow2': 8, 'layer_2_pow2': 8}.


--- Trial 3 ---


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

WandB run ID: x4ls8r5t
WandB run name: 2023-10-08 10:25:27
Steps with max reward: 20000
---
Best model: ppo-pendulum_20000.zip
Average episode length: 222.5
Average episode reward: -474.5755365269048
Average step time: 0.06511511215441525
Average action mean: 0.9968893671504565
Average action std dev: 0.8251102325083902


[INFO 10-08 06:48:14] ax.service.ax_client: Completed trial 3 with data: {'avg_ep_rew': (-474.575537, None)}.
[INFO 10-08 06:48:14] ax.service.ax_client: Saved JSON-serialized state of optimization to `checkpoints\ppo-pendulum-esp32-5.json`.
[INFO 10-08 06:48:14] ax.service.ax_client: Generated new trial 4 with parameters {'learning_rate': 0.00015, 'steps_per_update_pow2': 9, 'batch_size_div_pow2': 1, 'gae_lambda': 0.988324, 'clip_range': 0.391125, 'gamma': 0.93386, 'entropy_coef': 0.000134, 'vf_coef': 0.371305, 'max_grad_norm': 3.286363, 'n_envs': 1, 'use_sde': False, 'sde_freq': -1, 'layer_1_pow2': 8, 'layer_2_pow2': 8}.


--- Trial 4 ---


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666656966, max=1.0…

WandB run ID: td11jkfg
WandB run name: 2023-10-08 12:48:14
Steps with max reward: 35000
---
Best model: ppo-pendulum_35000.zip
Average episode length: 136.4
Average episode reward: -141.79422704900588
Average step time: 0.05034694238028429
Average action mean: 1.0086364124560443
Average action std dev: 0.7198438113578044


[INFO 10-08 09:07:45] ax.service.ax_client: Completed trial 4 with data: {'avg_ep_rew': (-141.794227, None)}.
[INFO 10-08 09:07:45] ax.service.ax_client: Saved JSON-serialized state of optimization to `checkpoints\ppo-pendulum-esp32-5.json`.
[INFO 10-08 09:07:45] ax.service.ax_client: Generated new trial 5 with parameters {'learning_rate': 1e-05, 'steps_per_update_pow2': 11, 'batch_size_div_pow2': 3, 'gae_lambda': 0.913861, 'clip_range': 0.126504, 'gamma': 0.966461, 'entropy_coef': 0.007634, 'vf_coef': 0.638448, 'max_grad_norm': 2.235512, 'n_envs': 1, 'use_sde': False, 'sde_freq': -1, 'layer_1_pow2': 8, 'layer_2_pow2': 8}.


--- Trial 5 ---


WandB run ID: 40oyst7s
WandB run name: 2023-10-08 15:07:45


KeyboardInterrupt: 

## Analyze Top Performing Trials

In [None]:
# Get runs in WandB project
runs = wandb.Api().runs(settings['wandb_project'])

In [None]:
# Plot best average episode reward from each run over time
avg_rews = []
for i, run in enumerate(runs):
    avg_rew = run.summary['Average test episode reward']
    if isinstance(avg_rew, float):
        avg_rews.append(avg_rew)
avg_rews.reverse()
plt.plot(avg_rews)

In [None]:
# CSV file path
csv_file_path = os.path.join(".", settings['wandb_project'] + ".csv")

# List summary names
summary_names = [
    "Average test episode reward",
    "Average test episode length",
    "Average test step time",
]

# Get hyperparameter names
hparam_names = [hparam['name'] for hparam in hparams]

print()

# Create CSV with HPO trial results
with open(csv_file_path, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["name"] + summary_names + hparam_names)
    for run in runs:
        row = [run.name]
        for name in summary_names:
            row.append(run.summary[name])
        for name in hparam_names:
            row.append(run.config[name])
        writer.writerow(row)

## Train Model on Best Hyperparameters

In [None]:
# TODO

## Covnert Model to ONNX

In [40]:
# Settings
max_rew_steps = 40000   # Number of steps with best average reward (for model loading file)
max_steps = 500
num_tests = 10

In [41]:
class OnnxablePolicy(th.nn.Module):
  """
  Convert SB3 model to ONNX model using PyTorch.
  From: https://stable-baselines3.readthedocs.io/en/master/guide/export.html
  """
  def __init__(self, extractor, action_net, value_net):
    super().__init__()
    self.extractor = extractor
    self.action_net = action_net
    self.value_net = value_net

  def forward(self, observation):
    # NOTE: You may have to process (normalize) observation in the correct
    #       way before using this. See `common.preprocessing.preprocess_obs`
    action_hidden, value_hidden = self.extractor(observation)
    return self.action_net(action_hidden), self.value_net(value_hidden)

In [42]:
# Load model 
model_path = os.path.join(settings['save_dir'], f"{settings['model_name']}_{str(max_rew_steps)}.zip")
model = sb3.PPO.load(model_path, env)

In [11]:
# Create our environment
try:
    env.close()
except:
    pass
env = Pendulum(
        SERIAL_PORT,
        BAUD_RATE,
        ctrl_timeout=CTRL_TIMEOUT,
        debug_level=DEBUG_LEVEL,
        env_timeout=ENV_TIMEOUT, 
        stp_mode=STEP_MODE, 
        stp_blocking=True
)

# Store all observations for creating a representative sample set
obss = []

# Initialize metrics
avg_ep_len = 0.0
avg_ep_rew = 0.0
avg_step_time = 0.0
avg_action_mean = 0.0
avg_action_std = 0.0

# Test the agent a number of times
for ep in range(num_tests):
    
    # Reset environment
    obs, info = env.reset()
    ep_len = 0
    ep_rew = 0
    avg_step_time = 0.0
    actions = []

    # Run episode until complete
    while True:

        # Provide observation to policy to predict the next action
        timestamp = time.time()
        action, _ = model.predict(obs)
        action = int(action)
        actions.append(action)

        # Perform action, update total reward
        obs, reward, terminated, truncated, info = env.step(action)
        avg_step_time += time.time() - timestamp
        ep_rew += reward
        obss.append(obs)

        # Increase step counter
        ep_len += 1
        if (max_steps > 0) and (ep_len >= max_steps):
            break

        # Check to see if episode has ended
        if terminated or truncated:
            break
        
    # Calculate average step time
    avg_step_time /= ep_len
    
    # Calculate action stats
    action_mean = np.mean(np.array(actions))
    action_std = np.std(np.array(actions))
    
    # Acculumate metrics
    avg_ep_len += ep_len
    avg_ep_rew += ep_rew
    avg_step_time += avg_step_time
    avg_action_mean += action_mean
    avg_action_std += action_std

# Compute metrics
avg_ep_len /= num_tests
avg_ep_rew /= num_tests
avg_step_time /= num_tests
avg_action_mean /= num_tests
avg_action_std /= num_tests

# Print metrics
print(f"Avg ep len: {avg_ep_len}")
print(f"Avg ep rew: {avg_ep_rew}")
print(f"Avg step time: {avg_step_time}")
print(f"Avg action mean: {avg_action_mean}")
print(f"Avg action std: {avg_action_std}")

NameError: name 'Pendulum' is not defined

In [53]:
# Close environment and save observations as our representative sample set
env.close()
np.save(REP_SAMPLE_SET_PATH, np.array(obss))

In [43]:
# Convert model to an intermediate form
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)

# Set input shape and save to ONNX file
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size).to('cpu')
th.onnx.export(
    onnxable_model.to('cpu'),
    dummy_input,
    OUT_ONNX_PATH,
    opset_version=9,
    input_names=["input"],
)

# Load ONNX file
onnx_model = onnx.load(OUT_ONNX_PATH)
onnx.checker.check_model(OUT_ONNX_PATH)

# Create a new architecture without the value network
value_net_out = None
new_nodes = []
for node in onnx_model.graph.node:

  # Find the value network output name
  if "value_net" in node.name:
    try:
      int(node.output[0])
      value_net_out = node.output[0]
    except ValueError:
      pass

  # Construct new graph with non-value network nodes
  else:
    new_nodes.append(node)

# Remove the output associated with the value network
graph_output = []
for output in onnx_model.graph.output:
  if output.name != value_net_out:
    graph_output.append(output)

# Make sure we found the output node
assert value_net_out is not None

# Construct a new graph
new_graph = onnx.helper.make_graph(
    new_nodes,
    onnx_model.graph.name,
    onnx_model.graph.input,
    graph_output,
    onnx_model.graph.initializer
)

# Set the new graph as the model's graph
onnx_model.graph.CopyFrom(new_graph)

# Save the modified model
# onnx.save(onnx_model, OUT_ONNX_PATH)

## Test the ONNX Model in the Environment

You will need to turn this into C++ code to run on the target device

In [102]:
def calc_obs(angles, angles_prev, dtime):
    
    # Offset encoder angle so that 0 deg is up
    angles[0] -= ENC_OFFSET

    # Calculate velocities
    dtheta = calc_angular_velocity(angles[0], angles_prev[0], dtime)
    dphi = calc_angular_velocity(angles[1], angles_prev[1], dtime)

    # Construct observation (normalized)
    obs[0] = angles[0] / ENC_ANGLE_NORM
    obs[1] = dtheta / ENC_ANGLE_NORM
    obs[2] = angles[1] / STP_ANGLE_NORM
    obs[3] = dphi / STP_ANGLE_NORM
    
    return obs

In [115]:
# Close connection to Arduino board (if open)
try:
    controller.close()
except:
    pass

# Connect to Arduino board
controller = ControlComms(timeout=CTRL_TIMEOUT, debug_level=DEBUG_LEVEL)
ret = controller.connect(SERIAL_PORT, BAUD_RATE)
if ret is not StatusCode.OK:
    print("ERROR: Could not connect to board")

# Load ONNX model
ort_sess = ort.InferenceSession(OUT_ONNX_PATH, providers=['CPUExecutionProvider'])

# Move controller home
angles_prev = [0, 0]
timestamp_prev = 0
resp = controller.step(CMD_MOVE_HOME, [0.0])
if resp:
    
    # Extract information from controller response
    status, timestamp, terminated, angles = resp
    
    # Compute lapsed time (in seconds) from previous observation (milliseconds)
    dtime = (timestamp - timestamp_prev) / 1000.0
    timestamp_prev = timestamp
    
    # Calculate the normalized observation
    obs = calc_obs(angles, angles_prev, dtime)
    angles_prev[0] = angles[0]
    angles_prev[1] = angles[1]
    
    # Let pendulum settle for a bit
    time.sleep(RESET_SETTLE_TIME)
    
# Something is wrong with communication
else:
    print("ERROR: Could not communicate with Arduino")

# Run episode until complete
ep_running = True
dtimes = []
start_time = time.time()
while ep_running:

    # Provide observation to policy to predict the next action
    action_logits = ort_sess.run(None, {"input": np.expand_dims(obs, axis=0)})
    action = np.argmax(action_logits)
    
    # Convert action index to label (degrees to move stepper)
    action = STP_ACTIONS_MAP[np.argmax(action_logits)]
    
    # Move the stepper motor and wait for a response
    resp = controller.step(CMD_MOVE_BY, [action])
    if resp:
        
        # Extract information from controller response
        status, timestamp, terminated, angles = resp
        
        # Compute lapsed time (in seconds) from previous observation (milliseconds)
        dtime = (timestamp - timestamp_prev) / 1000.0
        timestamp_prev = timestamp
        dtimes.append(dtime)
        
        # Calculate the normalized observation
        obs = calc_obs(angles, angles_prev, dtime)
        angles_prev[0] = angles[0]
        angles_prev[1] = angles[1]
        
        # Make sure stepper is in bounds
        if (angles[1] >= STP_ANGLE_MIN) and (angles[1] <= STP_ANGLE_MAX):
            
            # Fail. We "crashed" by swinging too hard
            if (abs(angles[0]) <= ENC_CRASH_ANGLE) and (abs(obs[1] * ENC_ANGLE_NORM) > ENC_CRASH_VELOCITY):
                print("Crashed")
                ep_running = False
                
            # Success! We landed softly at the top
            elif (abs(angles[0]) <= ENC_GOAL_ANGLE) and (abs(obs[1] * ENC_ANGLE_NORM) <= ENC_GOAL_VELOCITY):
                print("Success!")
                ep_running = False
                
        # Stepper moved out of bounds
        else:
            print("Stepper out of bounds")
            ep_running = False
        
    # Comms to Arduino lost
    else:
        print("ERROR: Could not communicate with Arduino")
        ep_running = False
        
    # Check timeout
    if time.time() - start_time > ENV_TIMEOUT:
        print("Episode timed out")
        ep_running = False
        
# Print average times
print(f"Average dtime: {sum(dtimes) / len(dtimes)}")

# Close comms
controller.close()

Success!
Average dtime: 0.08837241379310345
