In [13]:
#@title Import Brax and some helper modules

import functools
import time

from IPython.display import HTML, Image 
import gym
from gym import wrappers

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils
import torchvision.transforms as T
from torch.autograd import Variable

import numpy as np
import math

pi = Variable(torch.FloatTensor([math.pi])).cuda()

In [14]:
class ReplayBuffer(object):
    """Buffer to store environment transitions."""
    def __init__(self, obs_size, action_size, capacity, device):
        self.capacity = capacity
        self.device = device
        
        self.obses = np.empty((capacity, obs_size), dtype=np.float32)
        self.next_obses = np.empty((capacity, obs_size), dtype=np.float32)
        self.actions = np.empty((capacity, action_size), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)

        self.idx = 0
        self.last_save = 0
        self.full = False

    def __len__(self):
        return self.capacity if self.full else self.idx

    def add(self, obs, action, reward, next_obs, done):
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def sample(self, batch_size):
        idxs = np.random.randint(0,
                                 self.capacity if self.full else self.idx,
                                 size=batch_size)

        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs],
                                     device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)

        return obses, actions, rewards, next_obses, not_dones
        
class Qf(nn.Module):
    def __init__(self, hidden_size, num_inputs):
        super(Qf, self).__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 1)

    def forward(self, obs, action):
        x = torch.cat([obs, action], dim=-1)
        x = F.relu(self.linear1(x))
        baseline_out = self.linear2(x)
        return baseline_out

class Policy(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[1]

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_outputs)
        self.linear2_ = nn.Linear(hidden_size, num_outputs)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma_sq = self.linear2_(x)

        return mu, sigma_sq


class ActorCritic(nn.Module):
    def __init__(self, hidden_size, num_inputs, action_space, discount=0.99):
        super(ActorCritic, self).__init__()
        self.action_space = action_space
        self.discount = discount
        
        self.model = Policy(hidden_size, num_inputs, action_space)
        self.model = self.model.cuda()
        
        self.qf1 = Qf(hidden_size, num_inputs + action_space.shape[1])
        self.qf1 = self.qf1.cuda()
        
        self.target_qf1 = Qf(hidden_size, num_inputs + action_space.shape[1])
        self.target_qf1 = self.target_qf1.cuda()
        
        self.model.train()
        self.qf1.train()

    @torch.jit.export
    def dist_sample_no_postprocess(self, loc, scale):
        return torch.normal(loc, scale)

    @torch.jit.export
    def dist_entropy(self, loc, scale):
        log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
        entropy = 0.5 + log_normalized
        entropy = entropy * torch.ones_like(loc)
        dist = torch.normal(loc, scale)
        log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
        entropy = entropy + log_det_jacobian
        return entropy.sum(dim=-1)

    @torch.jit.export
    def dist_log_prob(self, loc, scale, dist):
        log_unnormalized = -0.5 * ((dist - loc) / scale).square()
        log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
        log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
        log_prob = log_unnormalized - log_normalized - log_det_jacobian
        return log_prob.sum(dim=-1)

    @torch.jit.export
    def select_action(self, state):
        mu, sigma_sq = self.model(state.cuda())
        sigma_sq = F.softplus(sigma_sq)
        sigma = sigma_sq.sqrt()
        action = torch.tanh(self.dist_sample_no_postprocess(mu, sigma))
        entropy = self.dist_entropy(mu,  sigma)
        log_prob = self.dist_log_prob(mu,  sigma, action)
        return action, log_prob, entropy
        
    # TODO: Check maximum entropy
    @torch.jit.export
    def compute_losses(self, obs_t, actions_t, rewards_t, next_obs_t, not_dones_t):
        # Policy loss
        mu, sigma_sq = self.model(obs_t)
        sigma_sq = F.softplus(sigma_sq)
        sigma = sigma_sq.sqrt()
        new_obs_actions = torch.tanh(self.dist_sample_no_postprocess(mu, sigma))
        entropy = self.dist_entropy(mu,  sigma)
        log_pi = self.dist_log_prob(mu,  sigma, new_obs_actions)
        
        # TODO: Change this over to REINFORCE loss maybe instead of reparameterization??
        q_new_actions = self.qf1(obs_t, new_obs_actions)
        policy_loss = -q_new_actions.mean()

        # Compute Bellman loss
        q1_pred = self.qf1(obs_t, actions_t)
        mu, sigma_sq = self.model(next_obs_t.cuda())
        sigma_sq = F.softplus(sigma_sq)
        sigma = sigma_sq.sqrt()
        new_next_actions = torch.tanh(self.dist_sample_no_postprocess(mu, sigma))
        new_log_pi = self.dist_log_prob(mu,  sigma, new_obs_actions)
        
        target_q_values = self.target_qf1(next_obs_t, new_next_actions)
        q_target = rewards_t + not_dones_t * self.discount * target_q_values
        
        # L2 error on bellman
        qf_loss = torch.linalg.norm(q1_pred - q_target.detach(), dim=-1).mean()

        return policy_loss, qf_loss

In [17]:
def flip_target(net, target_net, tau):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

In [18]:
seed = 0
gamma = torch.Tensor([0.99]).cuda()
exploration_end = 100
num_steps = 1000
num_episodes = 2000
hidden_size = 128
num_envs = 100
num_update_steps = 100
capacity = 10000
device = 'cuda'
batch_size = 32
target_flip_freq = 10
target_flip_tau = 5e-3

entry_point = functools.partial(envs.create_gym_env, env_name='reacher')
if 'brax-reacher-v0' not in gym.envs.registry.env_specs:
    gym.register('brax-reacher-v0', entry_point=entry_point)
env = gym.make('brax-reacher-v0', batch_size=num_envs, episode_length=num_steps)
env = to_torch.JaxToTorchWrapper(env, device='cuda')

env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

replay_buffer = ReplayBuffer(env.observation_space.shape[1], 
                             env.action_space.shape[1], 
                             capacity, 
                             device)

agent = ActorCritic(hidden_size, env.observation_space.shape[1], env.action_space)
agent = torch.jit.script(agent)
optimizer = optim.Adam(list(agent.model.parameters()) + list(agent.qf1.parameters()))
    
# Copy parameters initially
flip_target(agent.qf1, agent.target_qf1, 1.0)

for i_episode in range(num_episodes):
    print("Episode ", i_episode)
    state = env.reset()
    entropies = []
    log_probs = []
    rewards = []
    states = []
    actions = []
    for t in range(num_steps):
        action, log_prob, entropy = agent.select_action(state)
        action = action.cpu()
    
        next_state, reward, done, _ = env.step(action)

        # TODO: Check this for speed. 
        for j in range(next_state.shape[0]):
            replay_buffer.add(state.cpu().detach().numpy()[j], 
                              action.cpu().detach().numpy()[j], 
                              reward.cpu().detach().numpy()[j], 
                              next_state.cpu().detach().numpy()[j], 
                              done.cpu().detach().numpy()[j])
        
        entropies.append(entropy)
        log_probs.append(log_prob)
        rewards.append(reward)
        states.append(state)
        actions.append(action)
        state = next_state

    # Bookkeeping and logging
    rewards = torch.cat([r[None] for r in rewards])
    rewards = torch.transpose(rewards, 1, 0)
    rewards_np = rewards.cpu().numpy().sum(axis=-1).mean()
    print("Episode: {}, reward: {}".format(i_episode, rewards_np))
    
    # Perform actor-critic update
    for update_num in range(num_update_steps):
        obs_t, actions_t, rewards_t, next_obs_t, not_dones_t = replay_buffer.sample(batch_size)
        policy_loss, qf_loss = agent.compute_losses(obs_t, actions_t, rewards_t, next_obs_t, not_dones_t)
        loss = policy_loss + qf_loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Flip target network originally
    if i_episode % target_flip_freq == 0:
        flip_target(agent.qf1, agent.target_qf1, target_flip_tau)
        
env.close()


Episode  0
Update step number  0
Update step number  1
Update step number  2
Update step number  3
Update step number  4
Update step number  5
Update step number  6
Update step number  7
Update step number  8
Update step number  9
Update step number  10
Update step number  11
Update step number  12
Update step number  13
Update step number  14
Update step number  15
Update step number  16
Update step number  17
Update step number  18
Update step number  19
Update step number  20
Update step number  21
Update step number  22
Update step number  23
Update step number  24
Update step number  25
Update step number  26
Update step number  27
Update step number  28
Update step number  29
Update step number  30
Update step number  31
Update step number  32
Update step number  33
Update step number  34
Update step number  35
Update step number  36
Update step number  37
Update step number  38
Update step number  39
Update step number  40
Update step number  41
Update step number  42
Update ste

Update step number  73
Update step number  74
Update step number  75
Update step number  76
Update step number  77
Update step number  78
Update step number  79
Update step number  80
Update step number  81
Update step number  82
Update step number  83
Update step number  84
Update step number  85
Update step number  86
Update step number  87
Update step number  88
Update step number  89
Update step number  90
Update step number  91
Update step number  92
Update step number  93
Update step number  94
Update step number  95
Update step number  96
Update step number  97
Update step number  98
Update step number  99
Episode: 3, reward: -831.6994018554688
Episode  4
Update step number  0
Update step number  1
Update step number  2
Update step number  3
Update step number  4
Update step number  5
Update step number  6
Update step number  7
Update step number  8
Update step number  9
Update step number  10
Update step number  11
Update step number  12
Update step number  13
Update step numbe

Update step number  76
Update step number  77
Update step number  78
Update step number  79
Update step number  80
Update step number  81
Update step number  82
Update step number  83
Update step number  84
Update step number  85
Update step number  86
Update step number  87
Update step number  88
Update step number  89
Update step number  90
Update step number  91
Update step number  92
Update step number  93
Update step number  94
Update step number  95
Update step number  96
Update step number  97
Update step number  98
Update step number  99
Episode: 7, reward: -838.4404907226562
Episode  8


KeyboardInterrupt: 

In [None]:
    def sample_t(self, batch_size): 
        obses, actions, rewards, next_obses, not_dones = self.sample(batch_size)
        import IPython
        IPython.embed()
        obs_t = torch.Tensor(obses).cuda()
        actions_t = torch.Tensor(actions).cuda()
        rewards_t = torch.Tensor(rewards).cuda()
        next_obs_t = torch.Tensor(next_obses).cuda()
        not_dones_t = torch.Tensor(not_dones).cuda()
        return obs_t, actions_t, rewards_t, next_obs_t, not_dones_t