In [1]:
#@title Import Brax and some helper modules

import functools
import time

from IPython.display import HTML, Image 
import gym
from gym import wrappers

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils
import torchvision.transforms as T
from torch.autograd import Variable

import numpy as np
import math

pi = Variable(torch.FloatTensor([math.pi])).cuda()

In [2]:

class Policy(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[1]

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_outputs)
        self.linear2_ = nn.Linear(hidden_size, num_outputs)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma_sq = self.linear2_(x)

        return mu, sigma_sq


class REINFORCE:
    def __init__(self, hidden_size, num_inputs, action_space):
        self.action_space = action_space
        self.model = Policy(hidden_size, num_inputs, action_space)
        self.model = self.model.cuda()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.model.train()

    def select_action(self, state):
        mu, sigma_sq = self.model(Variable(state).cuda())
        sigma_sq = F.softplus(sigma_sq)
        eps = torch.randn(mu.size())
        action = (mu + sigma_sq.sqrt()*Variable(eps).cuda()).data
        prob = torch.distributions.normal.Normal(mu, sigma_sq.sqrt())
        entropy = prob.entropy().sum(axis=-1)
        log_prob = prob.log_prob(action).sum(axis=-1)
        return action, log_prob, entropy

    def update_parameters(self, rewards, log_probs, entropies, gamma):
        # Bookkeeping
        R = torch.zeros(rewards.shape[0],rewards.shape[1]).cuda()
        running_r = torch.zeros(rewards.shape[0],).cuda()
        loss = 0
        R_EPS = 1e-9
        
        # Compute discounted cumulative sum TODO: Check this
        for i in reversed(range(rewards.shape[1])):
            running_r = gamma * running_r + rewards[:, i]
            R[:, i] = running_r
            
        # Normalize advantages
        R_mean = torch.mean(R)
        R_std = torch.std(R)
        R = (R - R_mean) / (R_std + R_EPS)
        
        # Compute loss
        loss = -(log_probs*R).sum() - 0.0001*entropies.sum()
        loss = loss.sum(axis=-1) / len(rewards)

        # Update
        self.optimizer.zero_grad()
        loss.backward()
#         utils.clip_grad_norm(self.model.parameters(), 40)
        self.optimizer.step()

In [3]:
seed = 0
gamma = 0.99
exploration_end = 100
num_steps = 1000
num_episodes = 2000
hidden_size = 128
num_envs = 100

entry_point = functools.partial(envs.create_gym_env, env_name='reacher')
if 'brax-reacher-v0' not in gym.envs.registry.env_specs:
    gym.register('brax-reacher-v0', entry_point=entry_point)
env = gym.make('brax-reacher-v0', batch_size=num_envs, episode_length=num_steps)
env = to_torch.JaxToTorchWrapper(env, device='cuda')

env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)


agent = REINFORCE(hidden_size, env.observation_space.shape[1], env.action_space)

for i_episode in range(num_episodes):
    state = env.reset()
    entropies = []
    log_probs = []
    rewards = []
    states = []
    actions = []
    for t in range(num_steps):
        action, log_prob, entropy = agent.select_action(state)
        action = action.cpu()
    
        next_state, reward, done, _ = env.step(action.numpy())
        entropies.append(entropy)
        log_probs.append(log_prob)
        rewards.append(reward)
        states.append(state)
        actions.append(action)
        state = next_state

    # Bookkeeping
    states = torch.cat([s[None] for s in states])
    states = torch.transpose(states, 1, 0)
    
    actions = torch.cat([a[None] for a in actions])
    actions = torch.transpose(actions, 1, 0)
    
    rewards = torch.cat([r[None] for r in rewards])
    rewards = torch.transpose(rewards, 1, 0)
    
    entropies = torch.cat([e[None] for e in entropies])
    entropies = torch.transpose(entropies, 1, 0)
    
    log_probs = torch.cat([lp[None] for lp in log_probs])
    log_probs = torch.transpose(log_probs, 1, 0)
    
    agent.update_parameters(rewards, log_probs, entropies, gamma)
    
    rewards_np = rewards.cpu().numpy().sum(axis=-1).mean()
    print("Episode: {}, reward: {}".format(i_episode, rewards_np))
    
env.close()




Episode: 0, reward: -1604.8577880859375
Episode: 1, reward: -1564.8717041015625
Episode: 2, reward: -1555.08447265625
Episode: 3, reward: -1525.9417724609375
Episode: 4, reward: -1509.2880859375
Episode: 5, reward: -1475.717529296875
Episode: 6, reward: -1461.998291015625
Episode: 7, reward: -1444.377197265625
Episode: 8, reward: -1427.8492431640625
Episode: 9, reward: -1410.3984375
Episode: 10, reward: -1381.02197265625
Episode: 11, reward: -1356.1070556640625
Episode: 12, reward: -1337.4871826171875
Episode: 13, reward: -1318.907470703125
Episode: 14, reward: -1295.4886474609375
Episode: 15, reward: -1269.55322265625
Episode: 16, reward: -1242.878173828125
Episode: 17, reward: -1211.090576171875
Episode: 18, reward: -1183.08447265625
Episode: 19, reward: -1148.06494140625
Episode: 20, reward: -1122.9906005859375
Episode: 21, reward: -1082.00732421875
Episode: 22, reward: -1054.793212890625
Episode: 23, reward: -1025.46630859375
Episode: 24, reward: -1004.3538818359375
Episode: 25, re

KeyboardInterrupt: 