In [1]:
import os
import time
import gym
import pybullet_envs
import argparse

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [67]:
from common.multiprocessing_env import SubprocVecEnv

# Number of environment factories built for the training vec-env.
num_envs = 16
# Gym id passed to `gym.make` for both the training and the test environments.
env_name = "MinitaurBulletEnv-v0"

def make_env():
    """Return a zero-argument factory that builds one gym environment.

    The factory looks up the module-level `env_name` when it is invoked,
    not when `make_env` is called, so environment creation is deferred to
    whoever finally calls the returned function.
    """
    def _build_env():
        return gym.make(env_name)

    return _build_env
    
# Disabled draft of the environment setup, kept as a bare string so it is
# displayed but never executed. The live setup is in a later cell.
# NOTE(review): the draft calls `make_envs(env_name)`, but the function defined
# in this cell is `make_env()` and takes no arguments.
'''
envs = [make_envs(env_name) for _ in range(num_envs)]
# Environment for train
envs = SubprocVecEnv(envs)
# Environment for test
env = gym.make(env_name)
'''

'\nenvs = [make_envs(env_name) for _ in range(num_envs)]\n# Environment for train\nenvs = SubprocVecEnv(envs)\n# Environment for test\nenv = gym.make(env_name)\n'

In [68]:
# Pick the compute device. `torch.cuda.is_available` has to be *called*: the
# bare function object is always truthy, which selected "cuda" even on
# machines without a usable GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Run using {device}")

Run using cuda


In [81]:
class Actor(nn.Module):
    """Policy network: maps a state batch to a categorical action distribution.

    Args:
        input_shape: size of the last dimension of the state tensor.
        n_actions: number of categories in the output distribution.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(input_shape, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
            # Normalise over the last axis so each row is a probability vector.
            nn.Softmax(dim=-1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a `Categorical` built from the network's output probabilities."""
        return Categorical(probs=self.net(x))

In [82]:
class Critic(nn.Module):
    """Value network: maps a state batch to one scalar estimate per row.

    Args:
        input_shape: size of the last dimension of the state tensor.
    """

    def __init__(self, input_shape):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(input_shape, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output; the last dimension has size 1."""
        return self.net(x)

In [70]:
def plot(frame_idx, rewards, actor_losses, critic_losses, total_losses):
    """Redraw the four training curves side by side in the notebook output.

    Args:
        frame_idx: current update counter, shown in the first panel's title.
        rewards: sequence of evaluation rewards; its last entry is shown in
            the first panel's title.
        actor_losses: sequence of actor loss values.
        critic_losses: sequence of critic loss values.
        total_losses: sequence of combined loss values.
    """
    # Guard the title against an empty history instead of raising IndexError.
    latest_reward = rewards[-1] if len(rewards) else "n/a"

    # A list, not a set: tuples holding lists are unhashable (a set literal
    # raised TypeError here), and a set would not keep the panels in order.
    panels = [
        (141, f"frame {frame_idx}, rewards : {latest_reward}", rewards),
        (142, "actor_loss", actor_losses),
        (143, "critic_loss", critic_losses),
        (144, "total_loss", total_losses),
    ]

    # Replace the previous figure rather than stacking a new one below it.
    clear_output(True)
    plt.figure(figsize=(30, 5))
    for loc, title, values in panels:
        plt.subplot(loc)
        plt.title(title)
        plt.plot(values)
    plt.show()

In [83]:
def test_env(vis=False):
    """Run one episode in the module-level `env` with the current `actor`.

    Actions are sampled from the distribution returned by `actor`; nothing is
    trained here, so the rollout runs with gradient tracking disabled.

    Args:
        vis: when truthy, call `env.render()` after reset and after every step.

    Returns:
        The sum of the rewards collected until `env.step` reports `done`.
    """
    state = env.reset()
    if vis:
        env.render()
    done = False
    total_reward = 0
    # Evaluation only: skip building autograd graphs for every forward pass.
    with torch.no_grad():
        while not done:
            state_t = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
            dist = actor(state_t)
            # The batch holds a single state, so take the first sampled action.
            action = dist.sample().cpu().numpy()[0]
            state, reward, done, _ = env.step(action)
            if vis:
                env.render()
            total_reward += reward
    return total_reward

In [84]:
def compute_returns(next_values, rewards, masks, gamma=0.99):
    """Compute discounted returns for a rollout, bootstrapped from `next_values`.

    Walking the rollout backwards, each return is
    `rewards[t] + gamma * R * masks[t]`, where `R` is the return of step
    `t + 1` (or `next_values` for the last step). A mask of 0 therefore cuts
    the bootstrap at that step.

    Args:
        next_values: value used to bootstrap after the final step.
        rewards: per-step rewards, oldest first.
        masks: per-step multipliers, indexed like `rewards`.
        gamma: discount factor.

    Returns:
        A list with one return per step, in the same order as `rewards`.
        Empty when `rewards` is empty.
    """
    running_return = next_values
    returns = []
    for step in range(len(rewards) - 1, -1, -1):
        running_return = rewards[step] + gamma * running_return * masks[step]
        returns.append(running_return)
    # Built newest-first; one reverse replaces a quadratic insert(0, ...) loop.
    returns.reverse()
    return returns

In [86]:
# Network sizes are read from the vec-env's observation and action spaces.
# NOTE(review): `envs` is only created in a cell placed further down, so on a
# fresh kernel run top-to-bottom this cell raises NameError — confirm the
# intended cell order.
input_shape = envs.observation_space.shape[0]
# NOTE(review): reading `action_space.shape[0]` suggests a continuous (Box)
# action space, while `Actor` samples discrete category indices — confirm the
# environment accepts such actions.
n_actions = envs.action_space.shape[0]

# Number of environment steps collected before each network update.
num_steps = 5
actor = Actor(input_shape, n_actions).to(device)
critic = Critic(input_shape).to(device)

# Separate Adam optimizers; the critic uses a 10x larger learning rate.
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

In [87]:
# Counter incremented once per iteration of the training loop.
frame_idx = 0
# Mean evaluation reward recorded each time the training loop runs its test.
test_rewards = []

In [88]:
# Training environments: `num_envs` factories handed to SubprocVecEnv in one
# step, so `envs` is never bound to the intermediate list.
envs = SubprocVecEnv([make_env() for _ in range(num_envs)])
# Single environment used by the evaluation rollouts.
env = gym.make(env_name)

urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_root=/home/hskim/.local/lib/python3.8/site-packages/pybullet_data
urdf_r

In [80]:
# Advantage actor-critic training loop.
#
# Each iteration collects `num_steps` transitions from every environment in
# `envs`, computes bootstrapped discounted returns, and applies one gradient
# step to the actor and to the critic. Every 1000 iterations the policy is
# evaluated on `env` and the curves are redrawn.
# NOTE: the loop has no stop condition; interrupt the kernel to end training.
states = envs.reset()

# Loss history as plain floats, so `plot` receives sequences it can draw.
actor_losses = []
critic_losses = []
total_losses = []

while True:
    log_prob_list = []
    value_list = []
    reward_list = []
    mask_list = []
    entropy = 0

    frame_idx += 1
    for _ in range(num_steps):
        states_t = torch.tensor(states, dtype=torch.float).to(device)
        dists = actor(states_t)
        values = critic(states_t)

        actions = dists.sample()
        next_states, rewards, done, _ = envs.step(actions.cpu().numpy())

        # Add a trailing size-1 axis so log-probs line up element-wise with
        # the critic output, rewards and masks. Without it the product with
        # the advantage broadcasts to an (N, N) matrix.
        log_prob = dists.log_prob(actions).unsqueeze(-1)
        entropy += dists.entropy().mean()

        log_prob_list.append(log_prob)
        value_list.append(values)

        rewards_t = torch.tensor(rewards, dtype=torch.float).unsqueeze(-1).to(device)
        reward_list.append(rewards_t)

        # 0 where an episode ended, so the return is not bootstrapped across it.
        mask_t = torch.tensor(1 - done, dtype=torch.float).unsqueeze(-1).to(device)
        mask_list.append(mask_t)
        states = next_states

    # The bootstrap value is a regression target, not something to train through.
    with torch.no_grad():
        next_state_t = torch.tensor(next_states, dtype=torch.float).to(device)
        next_value = critic(next_state_t)
    returns = compute_returns(next_value, reward_list, mask_list)

    log_prob_t = torch.cat(log_prob_list)
    returns_t = torch.cat(returns).detach()
    values = torch.cat(value_list)

    # Check shapes before they are combined, not after.
    assert returns_t.shape == values.shape, f"Different shape. returns_t : {returns_t.shape}, values : {values.shape}"
    assert log_prob_t.shape == values.shape, f"Different shape. log_prob_t : {log_prob_t.shape}, values : {values.shape}"

    advantage = returns_t - values

    # The advantage is detached so the policy gradient does not flow into the critic.
    actor_loss = -(log_prob_t * advantage.detach()).mean() - 0.001 * entropy
    critic_loss = nn.MSELoss()(values, returns_t)

    total_loss = actor_loss + critic_loss

    # The two networks have separate optimizers; one backward pass feeds both.
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    total_loss.backward()
    actor_optimizer.step()
    critic_optimizer.step()

    actor_losses.append(actor_loss.item())
    critic_losses.append(critic_loss.item())
    total_losses.append(total_loss.item())

    if frame_idx % 1000 == 0:
        # `test_env` takes a render flag, not the environment; evaluate headless.
        test_rewards.append(np.mean([test_env() for _ in range(10)]))
        plot(frame_idx, test_rewards, actor_losses, critic_losses, total_losses)


[2 6 6 6 1 0 6 1 3 6 3 5 7 5 1 6]


EOFError: 