In [1]:
#@title Import Brax and some helper modules

import functools
import time

from IPython.display import HTML, Image 
import gym
from gym import wrappers

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils
import torchvision.transforms as T
from torch.autograd import Variable

import numpy as np
import math

pi = Variable(torch.FloatTensor([math.pi])).cuda()

In [2]:
class Baseline(nn.Module):
    def __init__(self, hidden_size, num_inputs):
        super(Baseline, self).__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 1)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        baseline_out = self.linear2(x)
        return baseline_out

class Policy(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[1]

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_outputs)
        self.linear2_ = nn.Linear(hidden_size, num_outputs)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma_sq = self.linear2_(x)

        return mu, sigma_sq


class REINFORCE:
    def __init__(self, hidden_size, num_inputs, action_space):
        self.action_space = action_space
        
        self.model = Policy(hidden_size, num_inputs, action_space)
        self.model = self.model.cuda()
        
        self.baseline = Baseline(hidden_size, num_inputs)
        self.baseline = self.baseline.cuda()
        
        self.optimizer = optim.Adam(list(self.model.parameters()) + 
                                    list(self.baseline.parameters()))
        self.model.train()
        self.baseline.train()

    def select_action(self, state):
        mu, sigma_sq = self.model(Variable(state).cuda())
        sigma_sq = F.softplus(sigma_sq)
        eps = torch.randn(mu.size())
        action = (mu + sigma_sq.sqrt()*Variable(eps).cuda()).data
        prob = torch.distributions.normal.Normal(mu, sigma_sq.sqrt())
        entropy = prob.entropy().sum(axis=-1)
        log_prob = prob.log_prob(action).sum(axis=-1)
        return action, log_prob, entropy

    def update_parameters(self, rewards, log_probs, entropies, gamma, states):
        # Bookkeeping
        R = torch.zeros(rewards.shape[0],rewards.shape[1]).cuda()
        running_r = torch.zeros(rewards.shape[0],).cuda()
        loss = 0
        R_EPS = 1e-9
        baseline_loss = 0
        
        # Compute discounted cumulative sum TODO: Check this
        for i in reversed(range(rewards.shape[1])):
            running_r = gamma * running_r + rewards[:, i]
            baseline_rpred = self.baseline(states[:, i])[:, 0]
            R[:, i] = running_r - baseline_rpred # Subtract the baseline
            baseline_loss += torch.sum((baseline_rpred - running_r)**2)
            
        # Normalize advantages
        R_mean = torch.mean(R)
        R_std = torch.std(R)
        R = (R - R_mean) / (R_std + R_EPS)
        
        # Compute loss
        loss = -(log_probs*R).sum() - 0.0001*entropies.sum()
        loss = loss.sum(axis=-1) / len(rewards)
        baseline_loss = baseline_loss.sum() / len(rewards)
        loss += baseline_loss
        
        # Update
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [None]:
seed = 0
gamma = 0.99
exploration_end = 100
num_steps = 1000
num_episodes = 2000
hidden_size = 128
num_envs = 100

entry_point = functools.partial(envs.create_gym_env, env_name='reacher')
if 'brax-reacher-v0' not in gym.envs.registry.env_specs:
    gym.register('brax-reacher-v0', entry_point=entry_point)
env = gym.make('brax-reacher-v0', batch_size=num_envs, episode_length=num_steps)
env = to_torch.JaxToTorchWrapper(env, device='cuda')

env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)


agent = REINFORCE(hidden_size, env.observation_space.shape[1], env.action_space)

for i_episode in range(num_episodes):
    state = env.reset()
    entropies = []
    log_probs = []
    rewards = []
    states = []
    actions = []
    for t in range(num_steps):
        action, log_prob, entropy = agent.select_action(state)
        action = action.cpu()
    
        next_state, reward, done, _ = env.step(action.numpy())
        entropies.append(entropy)
        log_probs.append(log_prob)
        rewards.append(reward)
        states.append(state)
        actions.append(action)
        state = next_state

    # Bookkeeping
    states = torch.cat([s[None] for s in states])
    states = torch.transpose(states, 1, 0)
    
    actions = torch.cat([a[None] for a in actions])
    actions = torch.transpose(actions, 1, 0)
    
    rewards = torch.cat([r[None] for r in rewards])
    rewards = torch.transpose(rewards, 1, 0)
    
    entropies = torch.cat([e[None] for e in entropies])
    entropies = torch.transpose(entropies, 1, 0)
    
    log_probs = torch.cat([lp[None] for lp in log_probs])
    log_probs = torch.transpose(log_probs, 1, 0)
    
    agent.update_parameters(rewards, log_probs, entropies, gamma, states)
    
    rewards_np = rewards.cpu().numpy().sum(axis=-1).mean()
    print("Episode: {}, reward: {}".format(i_episode, rewards_np))
    
env.close()




Episode: 0, reward: -1601.69775390625
Episode: 1, reward: -1571.712646484375
Episode: 2, reward: -1554.6365966796875
Episode: 3, reward: -1528.5328369140625
Episode: 4, reward: -1508.90185546875
Episode: 5, reward: -1483.4117431640625
Episode: 6, reward: -1461.5804443359375
Episode: 7, reward: -1445.958251953125
Episode: 8, reward: -1429.760986328125
Episode: 9, reward: -1413.9884033203125
Episode: 10, reward: -1381.12451171875
Episode: 11, reward: -1360.1533203125
Episode: 12, reward: -1345.58984375
Episode: 13, reward: -1325.7763671875
Episode: 14, reward: -1310.641357421875
Episode: 15, reward: -1285.9781494140625
Episode: 16, reward: -1269.312255859375
Episode: 17, reward: -1239.0157470703125
Episode: 18, reward: -1221.7840576171875
Episode: 19, reward: -1189.4298095703125
Episode: 20, reward: -1162.86669921875
Episode: 21, reward: -1124.5494384765625
Episode: 22, reward: -1088.5606689453125
Episode: 23, reward: -1065.1495361328125
Episode: 24, reward: -1029.8067626953125
Episode: 

Episode: 202, reward: -204.9767608642578
Episode: 203, reward: -202.8994903564453
Episode: 204, reward: -199.24969482421875
Episode: 205, reward: -205.3081817626953
Episode: 206, reward: -196.9013214111328
Episode: 207, reward: -195.54225158691406
Episode: 208, reward: -200.2885284423828
Episode: 209, reward: -196.8485565185547
Episode: 210, reward: -194.0188446044922
Episode: 211, reward: -195.2609405517578
Episode: 212, reward: -195.0740966796875
Episode: 213, reward: -192.92689514160156
Episode: 214, reward: -197.0116424560547
Episode: 215, reward: -196.59991455078125
Episode: 216, reward: -192.46315002441406
Episode: 217, reward: -197.60105895996094
Episode: 218, reward: -194.84701538085938
Episode: 219, reward: -197.74774169921875
Episode: 220, reward: -196.2445526123047
Episode: 221, reward: -195.12171936035156
Episode: 222, reward: -195.60433959960938
Episode: 223, reward: -197.26470947265625
Episode: 224, reward: -198.80279541015625
Episode: 225, reward: -198.74261474609375
Epi

Episode: 400, reward: -186.84945678710938
Episode: 401, reward: -190.64137268066406
Episode: 402, reward: -194.64718627929688
Episode: 403, reward: -190.08299255371094
Episode: 404, reward: -194.6665802001953
Episode: 405, reward: -190.2834930419922
Episode: 406, reward: -187.94468688964844
Episode: 407, reward: -190.43695068359375
Episode: 408, reward: -185.1152801513672
Episode: 409, reward: -192.00404357910156
Episode: 410, reward: -188.7035675048828
Episode: 411, reward: -191.00369262695312
Episode: 412, reward: -192.33331298828125
Episode: 413, reward: -186.21046447753906
Episode: 414, reward: -186.23458862304688
Episode: 415, reward: -189.66064453125
Episode: 416, reward: -187.39649963378906
Episode: 417, reward: -187.15469360351562
Episode: 418, reward: -181.24671936035156
Episode: 419, reward: -189.37396240234375
Episode: 420, reward: -189.54910278320312
Episode: 421, reward: -187.93484497070312
Episode: 422, reward: -183.2698211669922
Episode: 423, reward: -182.8888702392578
E

Episode: 597, reward: -128.8185577392578
Episode: 598, reward: -136.3463134765625
Episode: 599, reward: -127.32836151123047
Episode: 600, reward: -130.63876342773438
Episode: 601, reward: -133.39712524414062
Episode: 602, reward: -142.71719360351562
Episode: 603, reward: -144.53253173828125
Episode: 604, reward: -160.83335876464844
Episode: 605, reward: -158.45761108398438
Episode: 606, reward: -165.97149658203125
Episode: 607, reward: -168.69749450683594
Episode: 608, reward: -171.75
Episode: 609, reward: -170.1461944580078
Episode: 610, reward: -156.61489868164062
Episode: 611, reward: -164.91268920898438
Episode: 612, reward: -165.5191650390625
Episode: 613, reward: -148.97605895996094
Episode: 614, reward: -150.3452911376953
Episode: 615, reward: -139.1302947998047
Episode: 616, reward: -134.79066467285156
Episode: 617, reward: -135.92398071289062
Episode: 618, reward: -128.9646453857422
Episode: 619, reward: -131.9531707763672
Episode: 620, reward: -117.20128631591797
Episode: 621

Episode: 796, reward: -253.69952392578125
Episode: 797, reward: -241.0338134765625
Episode: 798, reward: -228.24656677246094
