In [1]:
#@title Import Brax and some helper modules

import functools
import time

from IPython.display import HTML, Image 
import gym

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp
import torch
v = torch.ones(1, device='cuda')  # init torch cuda before jax

import sys
import math

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils
import torchvision.transforms as T
from torch.autograd import Variable

import argparse, math, os
import numpy as np
import gym
from gym import wrappers
import brax
from brax import envs
from brax.envs import to_torch
import functools

import torch
from torch.autograd import Variable
import torch.nn.utils as utils


pi = Variable(torch.FloatTensor([math.pi])).cuda()

In [2]:
import gym


class NormalizedActions(gym.ActionWrapper):

    def action(self, action):
        action = (action + 1) / 2  # [-1, 1] => [0, 1]
        action *= (self.action_space.high - self.action_space.low)
        action += self.action_space.low
        return action

    def reverse_action(self, action):
        action -= self.action_space.low
        action /= (self.action_space.high - self.action_space.low)
        action = action * 2 - 1
        return actions


In [3]:
def normal(x, mu, sigma_sq):
    a = (-1*(Variable(x)-mu).pow(2)/(2*sigma_sq)).exp()
    b = 1/(2*sigma_sq*pi.expand_as(sigma_sq)).sqrt()
    return a*b


class Policy(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_outputs)
        self.linear2_ = nn.Linear(hidden_size, num_outputs)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma_sq = self.linear2_(x)

        return mu, sigma_sq


class REINFORCE:
    def __init__(self, hidden_size, num_inputs, action_space):
        self.action_space = action_space
        self.model = Policy(hidden_size, num_inputs, action_space)
        self.model = self.model.cuda()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.model.train()

    def select_action(self, state):
        mu, sigma_sq = self.model(Variable(state).cuda())
        sigma_sq = F.softplus(sigma_sq)

        eps = torch.randn(mu.size())
        # calculate the probability
        action = (mu + sigma_sq.sqrt()*Variable(eps).cuda()).data
        prob = normal(action, mu, sigma_sq)
        entropy = -0.5*((sigma_sq+2*pi.expand_as(sigma_sq)).log()+1)

        log_prob = prob.log()
        return action, log_prob, entropy

    def update_parameters(self, rewards, log_probs, entropies, gamma):
        R = torch.zeros(1, 1)
        loss = 0
        for i in reversed(range(len(rewards))):
            R = gamma * R + rewards[i]
            loss = loss - (log_probs[i]*(Variable(R).expand_as(log_probs[i])).cuda()).sum() - (0.0001*entropies[i].cuda()).sum()
        loss = loss / len(rewards)
		
        self.optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm(self.model.parameters(), 40)
        self.optimizer.step()


In [8]:

seed = 0
gamma = 0.99
exploration_end = 100
num_steps = 1000
num_episodes = 2000
hidden_size = 128


entry_point = functools.partial(envs.create_gym_env, env_name='reacher')
if 'brax-reacher-v0' not in gym.envs.registry.env_specs:
    gym.register('brax-reacher-v0', entry_point=entry_point)
env = gym.make('brax-reacher-v0')
env = to_torch.JaxToTorchWrapper(env, device='cpu')

# if type(env.action_space) != gym.spaces.discrete.Discrete:
#     from reinforce_continuous import REINFORCE
#     # env = NormalizedActions(gym.make(env_name))
# else:
#     from reinforce_discrete import REINFORCE

env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

agent = REINFORCE(hidden_size, env.observation_space.shape[0], env.action_space)

for i_episode in range(num_episodes):
    state = env.reset()[None]
    entropies = []
    log_probs = []
    rewards = []
    for t in range(num_steps):
        action, log_prob, entropy = agent.select_action(state)
        action = action.cpu()
        
        next_state, reward, done, _ = env.step(action.numpy()[0])
        done = done.cpu().numpy().item()
        entropies.append(entropy)
        log_probs.append(log_prob)
        rewards.append(reward.cpu().numpy().item())
        state = next_state[None]

        if done:
            break

    agent.update_parameters(rewards, log_probs, entropies, gamma)

    print("Episode: {}, reward: {}".format(i_episode, np.sum(rewards)))
	
env.close()




Episode: 0, reward: -1571.221885085106
Episode: 1, reward: -1640.5578821897507
Episode: 2, reward: -1647.704906783998
Episode: 3, reward: -1545.376258423552
Episode: 4, reward: -1647.3130498751998
Episode: 5, reward: -1640.9833238646388
Episode: 6, reward: -1698.7089726254344
Episode: 7, reward: -1603.1306994650513
Episode: 8, reward: -1524.0778157226741
Episode: 9, reward: -1604.8992355391383
Episode: 10, reward: -1569.7512127421796
Episode: 11, reward: -1633.798154644668
Episode: 12, reward: -1558.9657598733902
Episode: 13, reward: -1547.2181352041662
Episode: 14, reward: -1600.4012333378196
Episode: 15, reward: -1559.8717135414481
Episode: 16, reward: -1549.1581638380885
Episode: 17, reward: -1637.0808537714183
Episode: 18, reward: -1556.6965787038207
Episode: 19, reward: -1560.3428805153817
Episode: 20, reward: -1511.4160367697477
Episode: 21, reward: -1591.6446281820536
Episode: 22, reward: -1639.6424173228443
Episode: 23, reward: -1576.5129375085235
Episode: 24, reward: -1658.999

KeyboardInterrupt: ignored