In [11]:
import numpy as np
import time
import random
import gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt

from IPython.display import clear_output
%matplotlib inline

In [12]:
# MuJoCo inverted-pendulum task used by the cells below.
# NOTE(review): the "-v2" id depends on mujoco-py; newer gym releases ship
# InvertedPendulum-v4 instead — confirm against the installed gym version.
env = gym.make('InvertedPendulum-v2')

# Report the observation shape, then the action-space object itself
# (the space is printed, not its .shape).
print('Observation Shape:', env.observation_space.shape,
      '\nAction Shape:', env.action_space)

Observation Shape: (4,) 
Action Shape: Box(1,)


In [13]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [14]:
## Hyperparameters

BATCH_SIZE = 128
LEARNING_RATE = 0.001
DISCOUNT = 0.99
EPS = 1                 # initial exploration rate; decayed on every critics_action call
EPS_DECAY = 0.9999      # multiplicative decay factor applied to EPS
END_EPS = 0.1           # lower bound for EPS

N_EPISODE = 2000

# Seed every RNG this notebook draws from (Python `random` for epsilon-greedy,
# NumPy, and torch for weight init / action sampling) so runs are reproducible.
# torch.manual_seed also seeds the CUDA generators.
# NOTE(review): the gym environment's own RNG (used by action_space.sample())
# is not seeded here — depending on the installed gym version use
# env.seed(SEED) / env.action_space.seed(SEED) or env.reset(seed=SEED).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

In [15]:
class Actor(nn.Module):
    """Gaussian policy network: maps an observation to a mean and a variance.

    A shared trunk (observations -> 32 -> 16, ReLU after each layer) feeds two
    linear heads: ``l1`` produces the mean and ``l2`` (passed through softplus)
    produces the variance, one value per action dimension.

    Args:
        observations: length of the observation vector.
        actions: number of action dimensions (output size of each head).
        min_variance: constant added to the softplus output so the variance is
            strictly positive. softplus underflows to exactly 0 for very
            negative inputs, and a zero scale fails the default argument
            validation of ``torch.distributions.Normal``.
    """

    def __init__(self, observations, actions, min_variance=1e-6):
        super(Actor, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(observations, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU()
        )
        self.l1 = nn.Linear(16, actions)
        self.l2 = nn.Linear(16, actions)
        # Plain float attribute (not a buffer), so state_dict keys are unchanged.
        self.min_variance = min_variance

    def forward(self, x):
        """Return ``(mean, variance)`` tensors for the input observation(s) ``x``."""
        x = self.actor(x)
        mean = self.l1(x)
        variance = F.softplus(self.l2(x)) + self.min_variance

        return mean, variance

In [16]:
# Policy network, placed on the selected device (Module.to returns the module).
actor = Actor(observations=obs_dim, actions=action_dim)
actor = actor.to(device)

In [10]:
def actors_action(state):
    """Sample one action from the actor's Gaussian policy.

    Args:
        state: observation accepted by ``torch.FloatTensor`` (e.g. the array
            returned by the environment).

    Returns:
        tuple: ``(action, log_prob)``. ``action`` is the sampled action as a
        NumPy array on the CPU, detached from autograd. ``log_prob`` is its
        log-density as a torch tensor that still depends on the actor's
        outputs, so it can be back-propagated through.
    """
    state_t = torch.FloatTensor(state).to(device)
    mu, var = actor(state_t)

    # The actor outputs a variance; Normal is parameterised by standard deviation.
    policy = torch.distributions.Normal(mu, var.sqrt())
    sampled = policy.sample()

    return sampled.detach().cpu().numpy(), policy.log_prob(sampled)

In [17]:
class Critic(nn.Module):
    """Feed-forward network mapping an observation to ``actions`` raw outputs.

    Layout: observations -> 64 -> 32 -> actions, with ReLU after the two
    hidden layers and no activation on the output layer.
    """

    def __init__(self, observations, actions):
        super().__init__()
        widths = (64, 32)
        self.critic = nn.Sequential(
            nn.Linear(observations, widths[0]), nn.ReLU(),
            nn.Linear(widths[0], widths[1]), nn.ReLU(),
            nn.Linear(widths[1], actions),
        )

    def forward(self, x):
        """Return the unactivated network output for input ``x``."""
        return self.critic(x)

In [18]:
# Critic network, placed on the selected device (Module.to returns the module).
critic = Critic(observations=obs_dim, actions=action_dim)
critic = critic.to(device)

In [19]:
def critics_action(state):
    """Epsilon-greedy action selection driven by the critic.

    On every call the global exploration rate ``EPS`` is multiplied by
    ``EPS_DECAY`` and floored at ``END_EPS``. With probability ``EPS`` a
    random sample from ``env.action_space`` is returned; otherwise the index
    of the largest critic output for ``state`` is returned.

    Args:
        state: observation accepted by ``torch.FloatTensor``.

    Returns:
        ``env.action_space.sample()`` when exploring, otherwise a Python ``int``
        index into the critic's outputs.
    """
    global EPS  # only EPS is reassigned; EPS_DECAY and END_EPS are only read
    EPS = max(EPS * EPS_DECAY, END_EPS)

    if random.random() < EPS:
        return env.action_space.sample()

    state = torch.FloatTensor(state).to(device)
    # Inference only: no autograd graph is needed to pick the greedy action.
    with torch.no_grad():
        critic_out = critic(state)
    # argmax must run on the tensor itself; .item() then turns the 0-d index
    # tensor into a Python int.
    # NOTE(review): the critic has `action_dim` outputs (1 for this env), so
    # this argmax is always 0 and is not a continuous action from the Box
    # action space — confirm this greedy branch is what is intended.
    return torch.argmax(critic_out).item()