TODO:
 * Test environment
 * Set up experiments (WandB?)
 * Train!

In [None]:
# !python -m pip install gymnasium==0.28.1
# !python -m pip install stable-baselines3[extra]==2.1.0

In [130]:
import time

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from control_comms import ControlComms, StatusCode, DebugLevel

# Reinforcement model modules
import stable_baselines3 as sb3
from stable_baselines3.common import env_checker
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import KVWriter, Logger

# Check versions
print(f"gym version: {gym.__version__}")
print(f"sb3 version: {sb3.__version__}")

gym version: 0.28.1
sb3 version: 2.1.0


In [118]:
# Communication settings
SERIAL_PORT = "COM6"    # Check your devices
BAUD_RATE = 500000      # Must match what's in the Arduino code!
CTRL_TIMEOUT = 1.0      # Seconds
DEBUG_LEVEL = DebugLevel.DEBUG_ERROR

# Reinforcement learning settings
K_T = 1                 # Reward constant to multiply theta (angle of encoder)
K_DT = 0.1              # Reward constant to multiply dtheto/dt (angular velocity of encoder)
K_P = 0.01              # Reward constant to multiply phi (angle of stepper)
K_DP = 0.001            # Reward constant to multiply dphi/dt (angular velocity of stepper)
REWARD_OOB = -10_000    # Reward (penalty) for having the stepper motor move out of bounds (OOB)
ENC_OFFSET = 180.0      # Pendulum in the "up" position should be 0 deg
STP_MOVE_MIN = -10.0
STP_MOVE_MAX = 10.0
STP_ANGLE_MIN = -180.0  # Episode ends if stepper goes beyond this angle
STP_ANGLE_MAX = 180.0   # Episode ends if stepper goes beyond this angle
ENV_TIMEOUT = 10.0
RESET_SETTLE_TIME = 1.0

In [3]:
# Communication constants
CMD_SET_HOME = 0        # Set current stepper position as home (0 deg)
CMD_MOVE_TO = 1         # Move stepper to a particular position (deg)
CMD_MOVE_BY = 2         # Move stepper by a given amount (deg)
CMD_SET_STEP_MODE = 3   # Set step mode
CMD_SET_BLOCK_MODE = 4  # Set blocking mode
CMD_NOP = 5             # Take no action, just receive observation
STEP_MODE_1 = 0         # 1 division per step
STEP_MODE_2 = 1         # 2 divisions per step
STEP_MODE_4 = 2         # 4 divisions per step
STEP_MODE_8 = 3         # 8 divisions per step
STEP_MODE_16 = 4        # 16 divisions per step
STATUS_OK = 0           # Stepper idle
STATUS_STP_MOVING = 1   # Stepper is currently moving

# Set to desired step mode
STEP_MODE = STEP_MODE_8

In [187]:
# Close connection to Arduino board (if open)
try:
    controller.close()
except:
    pass

In [185]:
# Connect to Arduino board
controller = ControlComms(timeout=CTRL_TIMEOUT, debug_level=DEBUG_LEVEL)
ret = controller.connect(SERIAL_PORT, BAUD_RATE)
if ret is not StatusCode.OK:
    print("ERROR: Could not connect to board")

In [186]:
# Test basic comms
controller.step(CMD_SET_STEP_MODE, [STEP_MODE_8])
controller.step(CMD_SET_HOME, [0])
controller.step(CMD_SET_BLOCK_MODE, [1])
controller.step(CMD_MOVE_BY, [90])

(0, 4071823, False, [1.2, 0.0])

## Build gym Environment

Subclass gymnasium.Env to create a custom environment. Learn more here:<br>
https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/

In [188]:
class Pendulum(gym.Env):
    """
    Subclass gymnasium Env class
    
    This is the gym wrapper class that allows our agent to interact with our environment. We need
    to implement four main methods: step(), reset(), render(), and close(). We should also define
    the action_space and observation space as class members.
    
    Note: on Windows, time.sleep() is only accurate to around 10ms. As a result, setting fps_limit
    will give you a "best effort" limit.
    
    More information: https://gymnasium.farama.org/api/env/
    """
    
    def __init__(
        self,
        serial_port,
        baud_rate,
        ctrl_timeout=1.0,
        debug_level=DebugLevel.DEBUG_NONE,
        env_timeout=0.0, 
        stp_mode=STEP_MODE_8, 
        stp_blocking=False
    ):
        """
        Set up the environment, action, and observation shapes. Optional tiemout in seconds.
        """
        
        # Call superclass's constructor
        super().__init__()
        
        # Connect to Arduino board
        self.ctrl = ControlComms(timeout=ctrl_timeout, debug_level=debug_level)
        ret = self.ctrl.connect(serial_port, baud_rate)
        if ret is not StatusCode.OK:
            print("ERROR: Could not connect to board")
        
        # Define action space (scalar signifying how many degrees to move stepper by)
        self.action_space = gym.spaces.Box(
            low=STP_MOVE_MIN,
            high=STP_MOVE_MAX,
            shape=(1, 1),
            dtype=np.float32
        )
        
        # Define observation space 
        # [encoder angle, encoder angular velocity, stepper angle, stepper angular velocity]
        self.observation_space = gym.spaces.Box(
            low=np.array([-180, -np.inf, STP_ANGLE_MIN, -np.inf]),
            high=np.array([180, np.inf, STP_ANGLE_MAX, np.inf]),
            dtype=np.float32
        )
        
        # Record time from microcontroller and own elapsed time
        self.timestamp = 0
        self.timeout = env_timeout
        self.start_time = time.time()
        
        # Record previous encoder and stepper angles (to calculate velocities)
        self.angle_stp_prev = 0
        self.angle_enc_prev = 0
        
        # Set current stepper position as "home" and optionally set blocking
        self.ctrl.step(CMD_SET_STEP_MODE, [stp_mode])
        self.ctrl.step(CMD_SET_HOME, [0])
        if stp_blocking:
            self.ctrl.step(CMD_SET_BLOCK_MODE, [1])
        else:
            self.ctrl.step(CMD_SET_BLOCK_MODE, [0])
        
    def step(self, action):
        """
        What happens when you tell the stepper motor to do something then record the observation.
        """
        
        print("STEP")
        
        # Initialize return values
        obs = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
        reward = 0.0
        info = {"error": False, "dtime": 0.0, "elapesed_time": 0.0}
        terminated = False
        truncated = False
        
        # Move the stepper motor and wait for a response
        resp = self.ctrl.step(CMD_MOVE_BY, [action])
        if resp:
            status, timestamp, terminated, angles = resp
            
            # Compute lapsed time from previous observation
            info["dtime"] = timestamp - self.timestamp
            self.timestamp = timestamp
            
            # Calculate velocities
            dtheta = (angles[0] - self.angle_enc_prev) / info["dtime"]
            dphi = (angles[1] - self.angle_stp_prev) / info["dtime"]
            
            # Construct observation
            obs[0] = angles[0] - ENC_OFFSET
            obs[1] = dtheta
            obs[2] = angles[1]
            obs[3] = dphi
                    
            # Calculate reward
            if (obs[2] >= STP_ANGLE_MIN) and (obs[2] <= STP_ANGLE_MAX):
                reward = -1 * (K_T * obs[0] ** 2 + 
                               K_DT * obs[1] ** 2 + 
                               K_P * obs[2] ** 2 +
                               K_DP * obs[3] ** 2)
            
            # Stepper motor is out of bounds--terminate episode
            else:
                reward = REWARD_OOB
                terminated = True
        
        # Something is wrong with communication
        else:
            print("ERROR: Could not communicate with Arduino")
            info["error"] = True
            terminated = True
        
        # Calculate elapsed time
        info["elapsed_time"] = time.time() - self.start_time
        
        # Check if we've exceeded the time limit
        if not terminated and self.timeout > 0.0 and info["elapsed_time"] >= self.timeout:
            truncated = True
        
        return obs, reward, terminated, truncated, info
    
    def reset(self, seed=None):
        """
        Return the pendulum to the starting position
        """
        
        print("RESETTING")
        
        # Initialize return values
        obs = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
        info = {"error": False, "dtime": 0, "elapsed_time": 0.0}
        
        # Reset timer
        self.start_time = time.time()
        
        # Let the pendulum fall and return to the starting position
        time.sleep(RESET_SETTLE_TIME)
        resp = self.ctrl.step(CMD_MOVE_TO, [0.0])
        if resp:
            status, timestamp, terminated, angles = resp
            
            # Compute lapsed time from previous observation
            info["dtime"] = timestamp - self.timestamp
            self.timestamp = timestamp
            
            # Calculate velocities
            dtheta = (angles[0] - self.angle_enc_prev) / info["dtime"]
            dphi = (angles[1] - self.angle_stp_prev) / info["dtime"]
            
            # Construct observation
            obs[0] = angles[0] - ENC_OFFSET
            obs[1] = dtheta
            obs[2] = angles[1]
            obs[3] = dphi
            time.sleep(RESET_SETTLE_TIME)
            
        # Something is wrong with communication
        else:
            print("ERROR: Could not communicate with Arduino")
            info["error"] = True
            
        # Calculate elapsed time
        info["elapsed_time"] = time.time() - self.start_time
        
        return obs, info
    
    def close(self):
        """
        Close connection to Arduino
        """
        self.ctrl.close()

## Test gym Environment

Test the gym wrapper before training

In [193]:
# Create our environment
try:
    env.close()
except NameError:
    pass
env = Pendulum(
        SERIAL_PORT,
        BAUD_RATE,
        ctrl_timeout=CTRL_TIMEOUT,
        debug_level=DEBUG_LEVEL,
        env_timeout=ENV_TIMEOUT, 
        stp_mode=STEP_MODE, 
        stp_blocking=True
)

In [194]:
# Try running the environment for a few steps
obs, info = env.reset()
obs_str = ", ".join([f"{val:.2f}" for val in obs])
if info["error"]:
    print("Stopping")
else:
    print(f"Reset\t| Obs: [{obs_str}]\t| Reward: 0.0\t| Done: {False}\t| Info: {info}")
    for _ in range(10):
        obs, reward, terminated, truncated, info = env.step(-25)
        obs_str = ", ".join([f"{val:.2f}" for val in obs])
        print(f"Step\t| Obs: [{obs_str}]\t| Reward: {reward:.2f}\t| Done: {terminated or truncated}\t| Info: {info}")
        if info["error"]:
            print("Stopping")
            break
        if terminated or truncated:
            print("Episode done")
            break

RESETTING
Reset	| Obs: [-179.10, 0.00, 0.00, 0.00]	| Reward: 0.0	| Done: False	| Info: {'error': False, 'dtime': 4123662, 'elapsed_time': 2.0221705436706543}
STEP
Step	| Obs: [-179.10, 0.00, 0.00, 0.00]	| Reward: -32076.81	| Done: False	| Info: {'error': False, 'dtime': 1117, 'elapesed_time': 0.0, 'elapsed_time': 2.121436357498169}
STEP
Step	| Obs: [-159.30, 0.21, -24.97, -0.25]	| Reward: -25382.73	| Done: False	| Info: {'error': False, 'dtime': 99, 'elapesed_time': 0.0, 'elapsed_time': 2.2180588245391846}
STEP
Step	| Obs: [-146.40, 0.34, -49.95, -0.51]	| Reward: -21457.92	| Done: False	| Info: {'error': False, 'dtime': 98, 'elapesed_time': 0.0, 'elapsed_time': 2.3158371448516846}
STEP
Step	| Obs: [-142.80, 0.38, -74.93, -0.76]	| Reward: -20448.00	| Done: False	| Info: {'error': False, 'dtime': 99, 'elapesed_time': 0.0, 'elapsed_time': 2.4137394428253174}
STEP
Step	| Obs: [-148.80, 0.32, -99.90, -1.01]	| Reward: -22241.25	| Done: False	| Info: {'error': False, 'dtime': 99, 'elapesed_ti

In [191]:
# Test timeout
obs, info = env.reset()
action = 2
if not info["error"]:
    for _ in range(1000):
        action = -2 if action == 2 else 2
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            print("Episode done")
            break

RESETTING
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP
STEP


In [195]:
# Final environment check to make sure it works with Stable-Baselines3
env_checker.check_env(env)

RESETTING
RESETTING
STEP
Error parsing message. Message received: JSON Error: InvalidInput

ERROR: Could not communicate with Arduino
RESETTING
Error parsing message. Message received: JSON Error: InvalidInput

ERROR: Could not communicate with Arduino
STEP
Error parsing message. Message received: JSON Error: InvalidInput

ERROR: Could not communicate with Arduino
RESETTING
Error parsing message. Message received: JSON Error: InvalidInput

ERROR: Could not communicate with Arduino
STEP
Error parsing message. Message received: JSON Error: InvalidInput

ERROR: Could not communicate with Arduino
RESETTING
Error parsing message. Message received: JSON Error: InvalidInput

ERROR: Could not communicate with Arduino
STEP
Error parsing message. Message received: JSON Error: keys 'command' or 'action' not found

ERROR: Could not communicate with Arduino
RESETTING
Error parsing message. Message received: Error receiving actions

ERROR: Could not communicate with Arduino
STEP
Error parsing messag