In [1]:
#@title Import Brax and some helper modules

import functools
import time

from IPython.display import HTML, Image 
import gym
from gym import wrappers

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils
import torchvision.transforms as T
from torch.autograd import Variable

import numpy as np
import math

pi = Variable(torch.FloatTensor([math.pi])).cuda()

In [2]:
class Baseline(nn.Module):
    def __init__(self, hidden_size, num_inputs):
        super(Baseline, self).__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 1)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        baseline_out = self.linear2(x)
        return baseline_out

class Policy(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[1]

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_outputs)
        self.linear2_ = nn.Linear(hidden_size, num_outputs)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma_sq = self.linear2_(x)

        return mu, sigma_sq


class REINFORCE(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(REINFORCE, self).__init__()
        self.action_space = action_space
        
        self.model = Policy(hidden_size, num_inputs, action_space)
        self.model = self.model.cuda()
        
        self.baseline = Baseline(hidden_size, num_inputs)
        self.baseline = self.baseline.cuda()

        self.model.train()
        self.baseline.train()

    @torch.jit.export
    def dist_sample_no_postprocess(self, loc, scale):
        return torch.normal(loc, scale)

    @torch.jit.export
    def dist_entropy(self, loc, scale):
        log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
        entropy = 0.5 + log_normalized
        entropy = entropy * torch.ones_like(loc)
        dist = torch.normal(loc, scale)
        log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
        entropy = entropy + log_det_jacobian
        return entropy.sum(dim=-1)

    @torch.jit.export
    def dist_log_prob(self, loc, scale, dist):
        log_unnormalized = -0.5 * ((dist - loc) / scale).square()
        log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
        log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
        log_prob = log_unnormalized - log_normalized - log_det_jacobian
        return log_prob.sum(dim=-1)

    @torch.jit.export
    def select_action(self, state):
        mu, sigma_sq = self.model(state.cuda())
        sigma_sq = F.softplus(sigma_sq)
        sigma = sigma_sq.sqrt()
        action = torch.tanh(self.dist_sample_no_postprocess(mu, sigma))
#         eps = torch.randn(mu.size()).cuda()
#         action = (mu + sigma_sq.sqrt()*eps).data
#         prob = torch.distributions.normal.Normal(mu, sigma_sq.sqrt())
#         entropy = prob.entropy().sum(axis=-1)
#         log_prob = prob.log_prob(action).sum(axis=-1)
        entropy = self.dist_entropy(mu,  sigma)
        log_prob = self.dist_log_prob(mu,  sigma, action)
        return action, log_prob, entropy

    @torch.jit.export
    def update_parameters(self, rewards, log_probs, entropies, gamma, states):
        # Bookkeeping
        R_EPS = 1e-9
        R = torch.zeros(rewards.shape[0],rewards.shape[1]).cuda()
        running_r = torch.zeros(rewards.shape[0],).cuda()
        baseline_losses = torch.zeros(rewards.shape[1],).cuda()
        
        # Compute discounted cumulative sum TODO: Check this
        for j in range(rewards.shape[1]):
            i = rewards.shape[1] - 1 - j
            running_r = gamma * running_r + rewards[:, i]
            baseline_rpred = self.baseline(states[:, i])[:, 0]
            R[:, i] = running_r - baseline_rpred # Subtract the baseline
            baseline_loss = torch.sum((baseline_rpred - running_r)**2)
            baseline_losses[i] = baseline_loss
            
        # Normalize advantages
        R_mean = torch.mean(R)
        R_std = torch.std(R)
        R = (R - R_mean) / (R_std + R_EPS)
        
        # Compute loss
        loss = -(log_probs*R).sum() - 0.0001*entropies.sum()
        loss = loss / len(rewards)
        baseline_loss = baseline_losses.sum() / len(rewards)
        loss += baseline_loss
        return loss

In [3]:
seed = 0
gamma = torch.Tensor([0.99]).cuda()
exploration_end = 100
num_steps = 1000
num_episodes = 2000
hidden_size = 128
num_envs = 100

entry_point = functools.partial(envs.create_gym_env, env_name='reacher')
if 'brax-reacher-v0' not in gym.envs.registry.env_specs:
    gym.register('brax-reacher-v0', entry_point=entry_point)
env = gym.make('brax-reacher-v0', batch_size=num_envs, episode_length=num_steps)
env = to_torch.JaxToTorchWrapper(env, device='cuda')

env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)


agent = REINFORCE(hidden_size, env.observation_space.shape[1], env.action_space)
agent = torch.jit.script(agent)
optimizer = optim.Adam(list(agent.model.parameters()) + list(agent.baseline.parameters()))
    
for i_episode in range(num_episodes):
    state = env.reset()
    entropies = []
    log_probs = []
    rewards = []
    states = []
    actions = []
    for t in range(num_steps):
        action, log_prob, entropy = agent.select_action(state)
        action = action.cpu()
    
        next_state, reward, done, _ = env.step(action)
        entropies.append(entropy)
        log_probs.append(log_prob)
        rewards.append(reward)
        states.append(state)
        actions.append(action)
        state = next_state

    # Bookkeeping
    states = torch.cat([s[None] for s in states])
    states = torch.transpose(states, 1, 0)
    
    actions = torch.cat([a[None] for a in actions])
    actions = torch.transpose(actions, 1, 0)
    
    rewards = torch.cat([r[None] for r in rewards])
    rewards = torch.transpose(rewards, 1, 0)
    
    entropies = torch.cat([e[None] for e in entropies])
    entropies = torch.transpose(entropies, 1, 0)
    
    log_probs = torch.cat([lp[None] for lp in log_probs])
    log_probs = torch.transpose(log_probs, 1, 0)

    loss = agent.update_parameters(rewards, log_probs, entropies, gamma, states)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    rewards_np = rewards.cpu().numpy().sum(axis=-1).mean()
    print("Episode: {}, reward: {}".format(i_episode, rewards_np))
    
env.close()




Episode: 0, reward: -838.8779907226562
Episode: 1, reward: -831.6727905273438


KeyboardInterrupt: 