## Self Balancing Robot in PyBullet
**Balance and control of a 2-wheeled robot simulated with PyBullet Physics library**
<br>V2: Everything Implemented in GYM 

Importing Required Libraries

In [1]:
import pybullet
import time
import pybullet_data
import numpy as np
import matplotlib.pyplot as plt
import gym
from gym import spaces


GYM Environment For Robot

In [2]:
class SelfBalancing(gym.Env):
    #metadata = {'render.modes': ['human']}

    def __init__(self):
        super(SelfBalancing, self).__init__()
        # Define action and observation space
        self.action_space = spaces.Box(low=0.0, high=+1.0,shape=(3,),dtype=np.float64)
        self.observation_space = spaces.Box(low=np.array([-np.pi/2,-1000]), high=np.array([+np.pi/2,+1000]))
        """
            Action Space: action[0] -> kp, action[1] -> ki, action[2] -> kd
            Observation Space: torso_pitch orientation, torso linear speed
        """
        self.state = np.array([0.0,0.0])
        self.steps = 0
        self.max_episode_steps = 10000
        # Instantiate PyBullet
        phisycsClient = pybullet.connect(pybullet.GUI)
        pybullet.setAdditionalSearchPath(pybullet_data.getDataPath())
        # Spawn Robot
        self.robotID = None
        self.reset()
        # Initialize Controller Parameters
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0
        
    def step(self, action):
        motion = self.controller(action)
        self.take_action(motion)
        ## Calculating reward
        reward = self.calculate_reward()
        obs = self.observe()
        done = self.terminated()
        return obs, reward, done, {}
    
    def take_action(self,motion):
        # Takes a tuple as input
        # motion --> (left wheel speed, right wheel speed)
        pybullet.setJointMotorControl2(bodyUniqueId=self.robotID, 
                        jointIndex=0, 
                        controlMode=pybullet.VELOCITY_CONTROL,
                        targetVelocity = motion[0])
        pybullet.setJointMotorControl2(bodyUniqueId=self.robotID, 
                        jointIndex=1, 
                        controlMode=pybullet.VELOCITY_CONTROL,
                        targetVelocity = motion[1])
        pybullet.stepSimulation()
        time.sleep(1.0/400)
        self.steps += 1
    
    def observe(self):
        position, orientation = pybullet.getBasePositionAndOrientation(self.robotID)
        self.state[0] = np.array([pybullet.getEulerFromQuaternion(orientation)[0]])
        linear_vel, anagular_vel = pybullet.getBaseVelocity(self.robotID)
        self.state[1] = (linear_vel[0]**2 + linear_vel[1]**2 + linear_vel[2]**2) ** 0.5
        return self.state
    
    def calculate_reward(self):
        reward = - (self.observe()[0]**2 + self.observe()[1]**2)
        if self.terminated():
            reward += (self.steps - self.max_episode_steps) / 500
        return reward
    
    def controller(self,action):
        ## Simple PID
        error = self.observe()[0]
        self.integral += error
        self.derivative = error - self.prev_error
        self.prev_error = error
        
        motion = ((action[0]*1000) * error + (action[1]*0.1) * self.integral + (action[2]*100) * self.derivative)
        return (motion,motion)
    
    def reset(self):
        pybullet.resetSimulation()
        planeID = pybullet.loadURDF("plane.urdf")
        pybullet.setGravity(0,0,-9.81)
        self.robotID = pybullet.loadURDF("robot.urdf",
                                 [0.0,0.0,0.0],pybullet.getQuaternionFromEuler([0.0,0.0,0.0]),useFixedBase = 0)
        pybullet.setRealTimeSimulation(0) # change to (1) for real time simulation
        self.state = self.observe()
        return self.state
    
    def terminated(self):
        ## If the robot tilt angle reaches 75 degrees or
        ## the simulation reaches its maximum time steps
        if self.steps > self.max_episode_steps or abs(self.state[0]) > (np.pi / (75/80)):
            return True
        else:
            return False
    
    def __del__(self):
        pybullet.disconnect()

Test Environment with (Kp = 650, Ki = 0.005, Kd = 10.0)

In [10]:
env = SelfBalancing()

while not env.terminated():
    env.step((0.55,0.1,0.25))
del env

Advantagous Actor Critic (A2C)

In [6]:
import gym
import coax
import optax
import haiku as hk
import jax.numpy as jnp
import jax
from numpy import prod

# pick environment
name = 'a2c'
env = SelfBalancing()
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")

def func_v(S, is_training):
    # custom haiku function
    value = hk.Sequential([
                          hk.Linear(20),
                          hk.Linear(20),
                          hk.Linear(1,w_init=jnp.zeros),jnp.ravel])
    return value(S)  # output shape: (batch_size,)

def func_pi(S, is_training):
    shared = hk.Sequential((
        hk.Linear(20), jax.nn.relu,
        hk.Linear(20), jax.nn.relu,
    ))
    mu = hk.Sequential((
        shared,
        hk.Linear(10), jax.nn.relu,
        hk.Linear(3, w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    logvar = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(3, w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}

# function approximators
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)


# specify how to update policy and value function
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.002))


# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)


for ep in range(300):
    s = env.reset()

    for t in range(env.max_episode_steps):
        a, logp = pi(s, return_logp=True)
        s_next, r, done, info = env.step(a)

        # add transition to buffer
        # N.B. vanilla-pg doesn't use logp but we include it to make it easy to
        # swap in another policy updater that does require it, e.g. ppo-clip
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) == buffer.capacity:
            for _ in range(4 * buffer.capacity // 32):  # ~4 passes
                transition_batch = buffer.sample(batch_size=32)
                metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
                metrics_pi = vanilla_pg.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

            buffer.clear()

        if done:
            break

        s = s_next
del env

[MainThread|TrainMonitor|INFO] ep: 1,	T: 5,001,	G: -19.8,	avg_G: 0,	t: 5000,	dt: 15.123ms,	SimpleTD/loss: 7.01e-06,	VanillaPG/loss: -0.000372
[MainThread|TrainMonitor|INFO] ep: 2,	T: 5,003,	G: 0.00164,	avg_G: 0.00164,	t: 1,	dt: 304.329ms
[MainThread|TrainMonitor|INFO] ep: 3,	T: 5,005,	G: 0.00364,	avg_G: 0.00264,	t: 1,	dt: 291.905ms
[MainThread|TrainMonitor|INFO] ep: 4,	T: 5,007,	G: 0.00564,	avg_G: 0.00364,	t: 1,	dt: 311.455ms
[MainThread|TrainMonitor|INFO] ep: 5,	T: 5,009,	G: 0.00764,	avg_G: 0.00464,	t: 1,	dt: 264.317ms
[MainThread|TrainMonitor|INFO] ep: 6,	T: 5,011,	G: 0.00964,	avg_G: 0.00564,	t: 1,	dt: 279.904ms
[MainThread|TrainMonitor|INFO] ep: 7,	T: 5,013,	G: 0.0116,	avg_G: 0.00664,	t: 1,	dt: 282.554ms
[MainThread|TrainMonitor|INFO] ep: 8,	T: 5,015,	G: 0.0136,	avg_G: 0.00764,	t: 1,	dt: 298.038ms
[MainThread|TrainMonitor|INFO] ep: 9,	T: 5,017,	G: 0.0156,	avg_G: 0.00864,	t: 1,	dt: 258.788ms
[MainThread|TrainMonitor|INFO] ep: 10,	T: 5,019,	G: 0.0176,	avg_G: 0.00964,	t: 1,	dt: 282.301

[MainThread|TrainMonitor|INFO] ep: 88,	T: 5,175,	G: 0.174,	avg_G: 0.156,	t: 1,	dt: 251.482ms
[MainThread|TrainMonitor|INFO] ep: 89,	T: 5,177,	G: 0.176,	avg_G: 0.158,	t: 1,	dt: 263.359ms
[MainThread|TrainMonitor|INFO] ep: 90,	T: 5,179,	G: 0.178,	avg_G: 0.16,	t: 1,	dt: 275.988ms
[MainThread|TrainMonitor|INFO] ep: 91,	T: 5,181,	G: 0.18,	avg_G: 0.162,	t: 1,	dt: 261.011ms
[MainThread|TrainMonitor|INFO] ep: 92,	T: 5,183,	G: 0.182,	avg_G: 0.164,	t: 1,	dt: 240.779ms
[MainThread|TrainMonitor|INFO] ep: 93,	T: 5,185,	G: 0.184,	avg_G: 0.166,	t: 1,	dt: 284.033ms
[MainThread|TrainMonitor|INFO] ep: 94,	T: 5,187,	G: 0.186,	avg_G: 0.168,	t: 1,	dt: 262.692ms
[MainThread|TrainMonitor|INFO] ep: 95,	T: 5,189,	G: 0.188,	avg_G: 0.17,	t: 1,	dt: 281.057ms
[MainThread|TrainMonitor|INFO] ep: 96,	T: 5,191,	G: 0.19,	avg_G: 0.172,	t: 1,	dt: 296.639ms
[MainThread|TrainMonitor|INFO] ep: 97,	T: 5,193,	G: 0.192,	avg_G: 0.174,	t: 1,	dt: 263.795ms
[MainThread|TrainMonitor|INFO] ep: 98,	T: 5,195,	G: 0.194,	avg_G: 0.176,	t

[MainThread|TrainMonitor|INFO] ep: 176,	T: 5,351,	G: 0.35,	avg_G: 0.332,	t: 1,	dt: 298.339ms
[MainThread|TrainMonitor|INFO] ep: 177,	T: 5,353,	G: 0.352,	avg_G: 0.334,	t: 1,	dt: 263.311ms
[MainThread|TrainMonitor|INFO] ep: 178,	T: 5,355,	G: 0.354,	avg_G: 0.336,	t: 1,	dt: 278.568ms
[MainThread|TrainMonitor|INFO] ep: 179,	T: 5,357,	G: 0.356,	avg_G: 0.338,	t: 1,	dt: 282.775ms
[MainThread|TrainMonitor|INFO] ep: 180,	T: 5,359,	G: 0.358,	avg_G: 0.34,	t: 1,	dt: 245.596ms
[MainThread|TrainMonitor|INFO] ep: 181,	T: 5,361,	G: 0.36,	avg_G: 0.342,	t: 1,	dt: 327.948ms
[MainThread|TrainMonitor|INFO] ep: 182,	T: 5,363,	G: 0.362,	avg_G: 0.344,	t: 1,	dt: 299.075ms
[MainThread|TrainMonitor|INFO] ep: 183,	T: 5,365,	G: 0.364,	avg_G: 0.346,	t: 1,	dt: 245.449ms
[MainThread|TrainMonitor|INFO] ep: 184,	T: 5,367,	G: 0.366,	avg_G: 0.348,	t: 1,	dt: 281.106ms
[MainThread|TrainMonitor|INFO] ep: 185,	T: 5,369,	G: 0.368,	avg_G: 0.35,	t: 1,	dt: 312.091ms
[MainThread|TrainMonitor|INFO] ep: 186,	T: 5,371,	G: 0.37,	avg_G

[MainThread|TrainMonitor|INFO] ep: 264,	T: 5,527,	G: 0.526,	avg_G: 0.508,	t: 1,	dt: 296.073ms
[MainThread|TrainMonitor|INFO] ep: 265,	T: 5,529,	G: 0.528,	avg_G: 0.51,	t: 1,	dt: 262.611ms
[MainThread|TrainMonitor|INFO] ep: 266,	T: 5,531,	G: 0.53,	avg_G: 0.512,	t: 1,	dt: 264.045ms
[MainThread|TrainMonitor|INFO] ep: 267,	T: 5,533,	G: 0.532,	avg_G: 0.514,	t: 1,	dt: 314.976ms
[MainThread|TrainMonitor|INFO] ep: 268,	T: 5,535,	G: 0.534,	avg_G: 0.516,	t: 1,	dt: 263.445ms
[MainThread|TrainMonitor|INFO] ep: 269,	T: 5,537,	G: 0.536,	avg_G: 0.518,	t: 1,	dt: 345.797ms
[MainThread|TrainMonitor|INFO] ep: 270,	T: 5,539,	G: 0.538,	avg_G: 0.52,	t: 1,	dt: 312.961ms
[MainThread|TrainMonitor|INFO] ep: 271,	T: 5,541,	G: 0.54,	avg_G: 0.522,	t: 1,	dt: 296.211ms
[MainThread|TrainMonitor|INFO] ep: 272,	T: 5,543,	G: 0.542,	avg_G: 0.524,	t: 1,	dt: 264.826ms
[MainThread|TrainMonitor|INFO] ep: 273,	T: 5,545,	G: 0.544,	avg_G: 0.526,	t: 1,	dt: 330.508ms
[MainThread|TrainMonitor|INFO] ep: 274,	T: 5,547,	G: 0.546,	avg_

Proximal Policy Optimization (PPO)

In [8]:
import gym
import coax
import optax
import haiku as hk
import jax.numpy as jnp
import jax
from numpy import prod

# pick environment
name = 'PPO'
env = SelfBalancing()
env = coax.wrappers.TrainMonitor(env)


def func_v(S, is_training):
    # custom haiku function
    value = hk.Sequential([
                          hk.Linear(20),
                          hk.Linear(20),
                          hk.Linear(1,w_init=jnp.zeros),jnp.ravel])
    return value(S)  # output shape: (batch_size,)

def func_pi(S, is_training):
    shared = hk.Sequential((
        hk.Linear(20), jax.nn.relu,
        hk.Linear(20), jax.nn.relu,
    ))
    mu = hk.Sequential((
        shared,
        hk.Linear(3, w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape), jax.nn.sigmoid,
    ))
    logvar = hk.Sequential((
        shared,
        hk.Linear(3, w_init=jnp.zeros),
         hk.Reshape(env.action_space.shape), jax.nn.sigmoid
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}

# function approximators
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)


# slow-moving avg of pi
pi_behavior = pi.copy()


# specify how to update policy and value function
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.001))


# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)


for ep in range(300):
    s = env.reset()

    for t in range(env.max_episode_steps):
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, info = env.step(a)

        # add transition to buffer
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) == buffer.capacity:
            for _ in range(4 * buffer.capacity // 32):  # ~4 passes
                transition_batch = buffer.sample(batch_size=32)
                metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
                metrics_pi = ppo_clip.update(transition_batch, td_error)
                #env.record_metrics(metrics_v)
                #env.record_metrics(metrics_pi)

            buffer.clear()
            pi_behavior.soft_update(pi, tau=0.1)

        if done:
            break

        s = s_next
del env

[MainThread|TrainMonitor|INFO] ep: 1,	T: 5,001,	G: -8.61,	avg_G: 0,	t: 5000,	dt: 14.323ms
[MainThread|TrainMonitor|INFO] ep: 2,	T: 5,003,	G: 0.00164,	avg_G: 0.00164,	t: 1,	dt: 328.269ms
[MainThread|TrainMonitor|INFO] ep: 3,	T: 5,005,	G: 0.00364,	avg_G: 0.00264,	t: 1,	dt: 260.278ms
[MainThread|TrainMonitor|INFO] ep: 4,	T: 5,007,	G: 0.00564,	avg_G: 0.00364,	t: 1,	dt: 317.798ms
[MainThread|TrainMonitor|INFO] ep: 5,	T: 5,009,	G: 0.00764,	avg_G: 0.00464,	t: 1,	dt: 230.662ms
[MainThread|TrainMonitor|INFO] ep: 6,	T: 5,011,	G: 0.00964,	avg_G: 0.00564,	t: 1,	dt: 315.943ms
[MainThread|TrainMonitor|INFO] ep: 7,	T: 5,013,	G: 0.0116,	avg_G: 0.00664,	t: 1,	dt: 298.444ms
[MainThread|TrainMonitor|INFO] ep: 8,	T: 5,015,	G: 0.0136,	avg_G: 0.00764,	t: 1,	dt: 317.068ms
[MainThread|TrainMonitor|INFO] ep: 9,	T: 5,017,	G: 0.0156,	avg_G: 0.00864,	t: 1,	dt: 346.997ms
[MainThread|TrainMonitor|INFO] ep: 10,	T: 5,019,	G: 0.0176,	avg_G: 0.00964,	t: 1,	dt: 302.496ms
[MainThread|TrainMonitor|INFO] ep: 11,	T: 5,021,	

[MainThread|TrainMonitor|INFO] ep: 89,	T: 5,177,	G: 0.176,	avg_G: 0.158,	t: 1,	dt: 332.600ms
[MainThread|TrainMonitor|INFO] ep: 90,	T: 5,179,	G: 0.178,	avg_G: 0.16,	t: 1,	dt: 232.800ms
[MainThread|TrainMonitor|INFO] ep: 91,	T: 5,181,	G: 0.18,	avg_G: 0.162,	t: 1,	dt: 248.132ms
[MainThread|TrainMonitor|INFO] ep: 92,	T: 5,183,	G: 0.182,	avg_G: 0.164,	t: 1,	dt: 315.267ms
[MainThread|TrainMonitor|INFO] ep: 93,	T: 5,185,	G: 0.184,	avg_G: 0.166,	t: 1,	dt: 316.243ms
[MainThread|TrainMonitor|INFO] ep: 94,	T: 5,187,	G: 0.186,	avg_G: 0.168,	t: 1,	dt: 268.896ms
[MainThread|TrainMonitor|INFO] ep: 95,	T: 5,189,	G: 0.188,	avg_G: 0.17,	t: 1,	dt: 283.262ms
[MainThread|TrainMonitor|INFO] ep: 96,	T: 5,191,	G: 0.19,	avg_G: 0.172,	t: 1,	dt: 263.373ms
[MainThread|TrainMonitor|INFO] ep: 97,	T: 5,193,	G: 0.192,	avg_G: 0.174,	t: 1,	dt: 251.260ms
[MainThread|TrainMonitor|INFO] ep: 98,	T: 5,195,	G: 0.194,	avg_G: 0.176,	t: 1,	dt: 265.839ms
[MainThread|TrainMonitor|INFO] ep: 99,	T: 5,197,	G: 0.196,	avg_G: 0.178,	t

[MainThread|TrainMonitor|INFO] ep: 177,	T: 5,353,	G: 0.352,	avg_G: 0.334,	t: 1,	dt: 315.207ms
[MainThread|TrainMonitor|INFO] ep: 178,	T: 5,355,	G: 0.354,	avg_G: 0.336,	t: 1,	dt: 329.170ms
[MainThread|TrainMonitor|INFO] ep: 179,	T: 5,357,	G: 0.356,	avg_G: 0.338,	t: 1,	dt: 316.164ms
[MainThread|TrainMonitor|INFO] ep: 180,	T: 5,359,	G: 0.358,	avg_G: 0.34,	t: 1,	dt: 283.590ms
[MainThread|TrainMonitor|INFO] ep: 181,	T: 5,361,	G: 0.36,	avg_G: 0.342,	t: 1,	dt: 248.082ms
[MainThread|TrainMonitor|INFO] ep: 182,	T: 5,363,	G: 0.362,	avg_G: 0.344,	t: 1,	dt: 252.782ms
[MainThread|TrainMonitor|INFO] ep: 183,	T: 5,365,	G: 0.364,	avg_G: 0.346,	t: 1,	dt: 312.040ms
[MainThread|TrainMonitor|INFO] ep: 184,	T: 5,367,	G: 0.366,	avg_G: 0.348,	t: 1,	dt: 315.179ms
[MainThread|TrainMonitor|INFO] ep: 185,	T: 5,369,	G: 0.368,	avg_G: 0.35,	t: 1,	dt: 266.157ms
[MainThread|TrainMonitor|INFO] ep: 186,	T: 5,371,	G: 0.37,	avg_G: 0.352,	t: 1,	dt: 299.691ms
[MainThread|TrainMonitor|INFO] ep: 187,	T: 5,373,	G: 0.372,	avg_

[MainThread|TrainMonitor|INFO] ep: 265,	T: 5,529,	G: 0.528,	avg_G: 0.51,	t: 1,	dt: 264.779ms
[MainThread|TrainMonitor|INFO] ep: 266,	T: 5,531,	G: 0.53,	avg_G: 0.512,	t: 1,	dt: 298.720ms
[MainThread|TrainMonitor|INFO] ep: 267,	T: 5,533,	G: 0.532,	avg_G: 0.514,	t: 1,	dt: 332.013ms
[MainThread|TrainMonitor|INFO] ep: 268,	T: 5,535,	G: 0.534,	avg_G: 0.516,	t: 1,	dt: 283.657ms
[MainThread|TrainMonitor|INFO] ep: 269,	T: 5,537,	G: 0.536,	avg_G: 0.518,	t: 1,	dt: 282.322ms
[MainThread|TrainMonitor|INFO] ep: 270,	T: 5,539,	G: 0.538,	avg_G: 0.52,	t: 1,	dt: 249.371ms
[MainThread|TrainMonitor|INFO] ep: 271,	T: 5,541,	G: 0.54,	avg_G: 0.522,	t: 1,	dt: 299.049ms
[MainThread|TrainMonitor|INFO] ep: 272,	T: 5,543,	G: 0.542,	avg_G: 0.524,	t: 1,	dt: 316.277ms
[MainThread|TrainMonitor|INFO] ep: 273,	T: 5,545,	G: 0.544,	avg_G: 0.526,	t: 1,	dt: 265.237ms
[MainThread|TrainMonitor|INFO] ep: 274,	T: 5,547,	G: 0.546,	avg_G: 0.528,	t: 1,	dt: 252.425ms
[MainThread|TrainMonitor|INFO] ep: 275,	T: 5,549,	G: 0.548,	avg_

Checking Learned Policy

In [11]:
env = SelfBalancing()
while not env.terminated():
    s = env.observe()
    action = pi(s, return_logp=False)
    env.step(action)
del env
print (action)

[0.7042213 0.7425323 0.4257823]
