In [1]:
import gym
import os
import torch
import wandb
import torch.optim as optim
from itertools import count
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib
from matplotlib import pyplot as plt

In [3]:
!wandb login 9c6dd4b0a9335b3cfcfbb62569cf65cd4e537266

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [4]:
class Critic(nn.Module):
  def __init__(self, d_in):
    super(Critic, self).__init__()
    self.lin1 = nn.Linear(d_in, 32)
    self.lin2 = nn.Linear(32, 1)

  def forward(self, x):
    x = x.double()
    x = F.relu(self.lin1(x))
    return self.lin2(x)

class Actor(nn.Module):
  def __init__(self, d_in, d_out):
    super(Actor, self).__init__()
    self.lin1 = nn.Linear(d_in, 32)
    self.lin2 = nn.Linear(32, d_out)
    self.soft = nn.Softmax(dim=0)

  def forward(self, x):
    x = x.double()
    x = F.relu(self.lin1(x))
    return self.soft(self.lin2(x))

In [5]:
### HYPERPARAMETERS ###

wandb.init(project="a2c") # weights & biases tracking
env = gym.make('CartPole-v0')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using " + str(device))

n_actions = env.action_space.n
n_state = env.observation_space.shape[0]-2

# actions:{0,1} apply force {+1 move right, -1 move left}
# Actor represents policy, given a state provide the probability of taking each action
actor = Actor(d_in=n_state, d_out=n_actions).double().to(device)

# Critic represents value function, given a state return the estimated value
critic = Critic(d_in=n_state).double().to(device)

# reward discount factor
gamma = 0.99
n_episodes = 500
lr_actor = .001
lr_critic = .005

optimizer_actor = optim.Adam(actor.parameters(),lr=lr_actor)
optimizer_critic = optim.Adam(critic.parameters(),lr=lr_critic)
criterion_actor = nn.CrossEntropyLoss()
criterion_critic = nn.MSELoss()

wandb.watch(actor)
wandb.watch(critic)

using cuda


[<wandb.wandb_torch.TorchGraph at 0x7ff1a5ae9908>]

In [6]:
for index, episode in enumerate(range(n_episodes)):
  state = env.reset()
  # only last two state observations are needed: angle and angular velocity
  state = torch.DoubleTensor(state)[-2:].to(device)

  if (index+1) % 200 == 0:
    lr_actor *= .1
    optimizer_actor = optim.Adam(actor.parameters(),lr=lr_actor)

  I = 1. # anneal policy update

  # log progress
  acc_loss_actor = 0.0
  acc_loss_critic = 0.0

  action = actor(state)

  for t in count():    
    # choose action stochastically
    step_action = np.random.choice(n_actions,p=action.detach().cpu().numpy())
    state_prime, reward, done, info = env.step(step_action)
    done_mask = 0. if done else 1.

    done_mask = torch.as_tensor(done_mask).requires_grad_().to(device)
    reward = torch.as_tensor(reward).requires_grad_().to(device)
    state_prime = torch.tensor(state_prime[-2:], requires_grad=True).to(device)

    # state-value V is expectation of return
    value = critic(state.detach())
    value_prime = critic(state_prime.detach())

    ### UPDATE ACTOR ###
    optimizer_actor.zero_grad()
    
    # update actor in direction of higher value (or quality)
    quality = reward + gamma*value_prime*done_mask # definition of Q(s,a)

    g_prime = quality - value # update expected return G of the action taken 
    advantage = g_prime
    
    target = torch.tensor([step_action])
    log_prob = torch.log(action)[target]
    
    # update actor's policy in direction of increasing advantage
    actor_loss = -(I * log_prob * advantage.detach()).mean()
    actor_loss.backward(retain_graph=True)
    optimizer_actor.step()

    ### UPDATE CRITIC ###

    optimizer_critic.zero_grad()
    
    # weigh the predicted value of our state against the added/lossed value when we took action a
    critic_loss = criterion_critic(value, quality)
    critic_loss.backward(retain_graph=True)
    optimizer_critic.step()

    acc_loss_actor += actor_loss.detach().cpu()
    acc_loss_critic += critic_loss.detach().cpu()
    action_prime = actor(state_prime)
    action = action_prime
    state = state_prime
    I *= gamma # anneal policy update

    if done:
      wandb.log({"Episode Duration": t+1, "Actor Loss": acc_loss_actor/(t+1),
                "Critic Loss": acc_loss_critic/(t+1)})
      break

print('Complete')
env.close()

# Save model to wandb
torch.save(actor.state_dict(), os.path.join(wandb.run.dir, 'actor.pt'))
torch.save(critic.state_dict(), os.path.join(wandb.run.dir, 'critic.pt'))

Complete
