In [1]:
!pip install wandb



In [2]:
import gym
import os
import torch
import wandb
import torch.optim as optim
from itertools import count
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib
from matplotlib import pyplot as plt

In [3]:
# Never embed the API key in the notebook. wandb.login() takes it from the
# WANDB_API_KEY environment variable or an existing netrc entry, and otherwise
# prompts for it. The key that was previously inlined in this cell has been
# exposed and should be revoked.
wandb.login()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [4]:
class Critic(nn.Module):
  """Value network: two ReLU hidden layers followed by a single linear output."""

  def __init__(self, d_in, hidden_dim):
    """Build the layers.

    d_in: size of the input (state) vector.
    hidden_dim: width of both hidden layers.
    """
    super().__init__()
    # Attribute names double as state_dict keys, so they are kept stable.
    self.lin1 = nn.Linear(d_in, hidden_dim)
    self.lin2 = nn.Linear(hidden_dim, hidden_dim)
    self.lin3 = nn.Linear(hidden_dim, 1)

  def forward(self, x):
    """Return the scalar head's output for `x`.

    The input is cast to float64 first, so the module's parameters are
    expected to be float64 as well (e.g. via `.double()`).
    """
    hidden = x.double()
    for layer in (self.lin1, self.lin2):
      hidden = F.relu(layer(hidden))
    return self.lin3(hidden)

class Actor(nn.Module):
  """Policy network: maps a state to a probability distribution over actions."""

  def __init__(self, d_in, hidden_dim, d_out):
    """Build the layers.

    d_in: size of the input (state) vector.
    hidden_dim: width of both hidden layers.
    d_out: number of actions (size of the output distribution).
    """
    super(Actor, self).__init__()
    self.lin1 = nn.Linear(d_in, hidden_dim)
    self.lin2 = nn.Linear(hidden_dim, hidden_dim)
    self.lin3 = nn.Linear(hidden_dim, d_out)
    # Normalise over the last (action) axis. For a single 1-D state this is
    # the same as dim=0; for a batch of states it keeps every row a proper
    # distribution instead of normalising across the batch.
    self.soft = nn.Softmax(dim=-1)

  def forward(self, x):
    """Return action probabilities for `x`, which is cast to float64 first.

    `x` may be a single state of shape (d_in,) or a batch of shape
    (..., d_in); the probabilities sum to 1 along the last axis.
    """
    x = x.double()
    x = F.relu(self.lin1(x))
    x = F.relu(self.lin2(x))
    return self.soft(self.lin3(x))

In [5]:
def log(durations,durations_means,losses_act,losses_crit,mean_action_confidences,solved_episode):
  """Report the latest rolling training statistics to wandb.

  durations: per-episode durations; only its length is used, to size the window.
  durations_means: tensor of rolling mean durations (max and last entry are logged).
  losses_act / losses_crit / mean_action_confidences: per-episode series; the
    mean of the most recent window of each is logged.
  solved_episode: value logged as "Solved Timestep".
  """
  # Every rolling mean uses the same window: up to 100 episodes.
  window = min(len(durations), 100)

  def latest_window_mean(series):
    # Mean of the last full window of `series`.
    as_tensor = torch.tensor(series, dtype=torch.float)
    return as_tensor.unfold(0, window, 1).mean(1).view(-1)[-1]

  metrics = {
    "Max Duration": durations_means.max(),
    "Solved Timestep": solved_episode,
    "Mean Duration": durations_means[-1],
    "Actor Loss": latest_window_mean(losses_act),
    "Critic Loss": latest_window_mean(losses_crit),
    "Action-Confidence": latest_window_mean(mean_action_confidences),
  }
  wandb.log(metrics)

In [6]:
### HYPERPARAMETERS ###

# Environment shared by every training run started from this notebook.
env = gym.make('CartPole-v0')
print(f'Observation max: {env.observation_space.high}')
print(f'Observation min: {env.observation_space.low}')

# Prefer the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"using {device}")

n_actions = env.action_space.n
# train() feeds only the last two observation components to the networks,
# hence two fewer inputs than the environment provides.
n_state = env.observation_space.shape[0]-2
n_episodes = 500

# hyperparameter sweep
# Other knobs that can be added to 'parameters' (left out of this grid):
#   lr_actor:  [0.01, 0.001, 0.0001]
#   lr_critic: [0.05, 0.005, 0.0005]
#   gamma:     [.99, 0.999]
#   actor_lr_drop / critic_lr_drop: [100, 200, 300, 400, 500]
# A top-level 'controller': {'type': 'local'} entry can also be added.
sweep_config = dict(
    method='grid',  # grid, random
    metric=dict(
        name='Solved Timestep',
        goal='minimize',
    ),
    parameters=dict(
        n_steps=dict(values=[3,5,10,15]),
        hidden_dim=dict(values=[16,32,64]),
    ),
)

sweep_id = wandb.sweep(sweep_config, project="a2c")

Observation max: [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]
Observation min: [-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]
using cpu
Create sweep with ID: ee1p4wrv
Sweep URL: https://app.wandb.ai/nj-nj23essdf-sl-/a2c/sweeps/ee1p4wrv


In [7]:
def train():
  """Run one actor-critic training session on the shared CartPole `env`.

  Hyperparameters are read from `wandb.config`; the defaults below apply
  unless a sweep overrides them. Per-episode statistics are sent to wandb
  through `log`. Training ends after `n_episodes` episodes, or as soon as the
  best rolling mean duration reaches 195. In both cases the environment is
  closed and the actor and critic weights are saved into the wandb run
  directory.
  """

  config_defaults = dict(
    lr_actor = .001,
    lr_critic = .005,
    n_steps = 2,
    hidden_dim = 64,
    gamma = 0.999,
    actor_lr_drop = 400,
    critic_lr_drop = 400
  )

  wandb.init(project="a2c",config=config_defaults)
  config = wandb.config

  lr_actor = config.lr_actor
  lr_critic = config.lr_critic
  n = config.n_steps
  hidden_dim = config.hidden_dim
  gamma = config.gamma
  # drop by factor of 10 every [actor_lr_drop] / [critic_lr_drop] episodes
  actor_lr_drop = config.actor_lr_drop
  critic_lr_drop = config.critic_lr_drop

  # actions:{0,1} apply force {+1 move right, -1 move left}
  # Actor represents policy, given a state provide the probability of taking each action
  actor = Actor(d_in=n_state, hidden_dim=hidden_dim, d_out=n_actions).double().to(device)

  # Critic represents value function, given a state return the estimated value
  critic = Critic(d_in=n_state,hidden_dim=hidden_dim).double().to(device)

  optimizer_actor = optim.Adam(actor.parameters(),lr=lr_actor)
  optimizer_critic = optim.Adam(critic.parameters(),lr=lr_critic)
  criterion_critic = nn.L1Loss()

  wandb.watch(actor)
  wandb.watch(critic)

  def _close_and_save():
    # Release the environment and store both networks in the run directory.
    env.close()
    torch.save(actor.state_dict(), os.path.join(wandb.run.dir, 'actor.pt'))
    torch.save(critic.state_dict(), os.path.join(wandb.run.dir, 'critic.pt'))

  durations = []
  losses_act = []
  losses_crit = []
  mean_action_confidences = []

  # Value logged as "Solved Timestep" while the task is unsolved: the episode
  # budget itself (500 with the notebook's current n_episodes).
  solved_episode = n_episodes

  for index in range(n_episodes):
    episode_confidences = []
    state = env.reset()
    # only last two state observations are needed: angle and angular velocity
    state = torch.DoubleTensor(state)[-2:].to(device)

    # lr scheduler for actor, drop by factor of 10 every actor_lr_drop episodes
    if (index+1) % actor_lr_drop == 0:
      lr_actor *= .1
      optimizer_actor = optim.Adam(actor.parameters(),lr=lr_actor)

    # lr scheduler for critic, drop by factor of 10 every critic_lr_drop episodes
    if (index+1) % critic_lr_drop == 0:
      lr_critic *= .1
      # Rebuild with the critic's own rate (previously the actor's rate was
      # passed here by mistake).
      optimizer_critic = optim.Adam(critic.parameters(),lr=lr_critic)

    I = 1. # anneal policy update

    # log progress
    acc_loss_actor = 0.
    acc_loss_critic = 0.
    n_actor_losses = []
    n_critic_losses = []

    action = actor(state)

    for t in count():
      # choose action stochastically
      step_action = np.random.choice(n_actions,p=action.detach().cpu().numpy())
      episode_confidences.append(action.detach().cpu().max())
      state_prime, reward, done, _ = env.step(step_action)
      done_mask = 0. if done else 1.

      done_mask = torch.as_tensor(done_mask).requires_grad_().to(device)
      reward = torch.as_tensor(reward).requires_grad_().to(device)
      state_prime = torch.tensor(state_prime[-2:], requires_grad=True).to(device)

      value = critic(state)
      value_prime = critic(state_prime)

      # update actor in direction of higher value (or quality)
      quality = reward + gamma*value_prime*done_mask # definition of Q(s,a)
      # advantage of the action taken: Q(s,a) - V(s)
      advantage = quality - value

      target = torch.tensor([step_action])
      log_prob = torch.log(action)[target]

      # update actor's policy in direction of increasing advantage
      actor_loss = -I * log_prob * advantage.detach()
      n_actor_losses.append(actor_loss)

      # weigh the predicted value of our state
      # against the n-step predicted future value
      # after taking action 'a'
      critic_loss = criterion_critic(value, quality)
      n_critic_losses.append(critic_loss)

      # only update every n-steps
      if (t+1) % n == 0:

        ### UPDATE ACTOR ###
        optimizer_actor.zero_grad()
        torch.autograd.backward(n_actor_losses,retain_graph=True)
        optimizer_actor.step()

        ### UPDATE CRITIC ###
        optimizer_critic.zero_grad()
        torch.autograd.backward(n_critic_losses,retain_graph=True)
        optimizer_critic.step()

        n_actor_losses = []
        n_critic_losses = []
        # Only the most recent step's losses are accumulated for logging.
        acc_loss_actor += actor_loss.detach().cpu()
        acc_loss_critic += critic_loss.detach().cpu()

      # the next step's policy output is computed after any parameter update
      action = actor(state_prime)
      state = state_prime
      I *= gamma # anneal policy update

      if done:
        durations.append(t+1)
        losses_act.append(acc_loss_actor/(t+1))
        losses_crit.append(acc_loss_critic/(t+1))
        mean_action_confidences.append(np.array(episode_confidences).mean())

        # cartpole is solved when the agent can balance the pole
        # for an average of at least 195 timesteps over 100 consecutive
        # episodes
        durations_t = torch.tensor(durations, dtype=torch.float)
        # log 100-episode means
        # NOTE: until 100 episodes exist the window is the whole history.
        durations_means = durations_t.unfold(0, min(len(durations),100), 1).mean(1).view(-1)
        # max 100-episode mean duration
        max_hundred = durations_means.max()

        solved = bool(max_hundred >= 195.)
        if solved:
          solved_episode = index
        log(durations,durations_means,losses_act,losses_crit,
            mean_action_confidences,solved_episode)
        if solved:
          # early stop if model has triumphed
          _close_and_save()
          return
        break

  _close_and_save()
  return

In [8]:
# !wandb controller $sweep_id

In [9]:
# Start a wandb agent for the sweep created above, passing `train` as the
# function the agent runs.
wandb.agent(sweep_id, train)

wandb: Agent Starting Run: gljrszd6 with config:
	hidden_dim: 16
	n_steps: 1
wandb: Agent Started Run: gljrszd6


wandb: Agent Finished Run: gljrszd6 

wandb: Agent Starting Run: g1yjvqpf with config:
	hidden_dim: 16
	n_steps: 3
wandb: Agent Started Run: g1yjvqpf


wandb: Agent Finished Run: g1yjvqpf 

wandb: Agent Starting Run: 0hazt5dt with config:
	hidden_dim: 16
	n_steps: 5
wandb: Agent Started Run: 0hazt5dt


wandb: Agent Finished Run: 0hazt5dt 

wandb: Agent Starting Run: 4ld4nef9 with config:
	hidden_dim: 16
	n_steps: 10
wandb: Agent Started Run: 4ld4nef9


wandb: Agent Finished Run: 4ld4nef9 

wandb: Agent Starting Run: la9mobyx with config:
	hidden_dim: 16
	n_steps: 15
wandb: Agent Started Run: la9mobyx


wandb: Agent Finished Run: la9mobyx 

wandb: Agent Starting Run: ziq80kis with config:
	hidden_dim: 32
	n_steps: 1
wandb: Agent Started Run: ziq80kis


wandb: Agent Finished Run: ziq80kis 

wandb: Agent Starting Run: 6c8l7p4c with config:
	hidden_dim: 32
	n_steps: 3
wandb: Agent Started Run: 6c8l7p4c


wandb: Agent Finished Run: 6c8l7p4c 

wandb: Agent Starting Run: ruqv42jb with config:
	hidden_dim: 32
	n_steps: 5
wandb: Agent Started Run: ruqv42jb


wandb: Agent Finished Run: ruqv42jb 

wandb: Agent Starting Run: crj1dgsw with config:
	hidden_dim: 32
	n_steps: 10
wandb: Agent Started Run: crj1dgsw


wandb: Agent Finished Run: crj1dgsw 

wandb: Agent Starting Run: 8g68kyd9 with config:
	hidden_dim: 32
	n_steps: 15
wandb: Agent Started Run: 8g68kyd9


wandb: Agent Finished Run: 8g68kyd9 

wandb: Agent Starting Run: 5xfmg9vq with config:
	hidden_dim: 64
	n_steps: 1
wandb: Agent Started Run: 5xfmg9vq


wandb: Agent Finished Run: 5xfmg9vq 

wandb: Agent Starting Run: mwyp5rvm with config:
	hidden_dim: 64
	n_steps: 3
wandb: Agent Started Run: mwyp5rvm


wandb: Agent Finished Run: mwyp5rvm 

wandb: Agent Starting Run: zcfgx6dt with config:
	hidden_dim: 64
	n_steps: 5
wandb: Agent Started Run: zcfgx6dt


wandb: Agent Finished Run: zcfgx6dt 

wandb: Agent Starting Run: bkk17j6q with config:
	hidden_dim: 64
	n_steps: 10
wandb: Agent Started Run: bkk17j6q


wandb: Agent Finished Run: bkk17j6q 

wandb: Agent Starting Run: 836g1m3h with config:
	hidden_dim: 64
	n_steps: 15
wandb: Agent Started Run: 836g1m3h


wandb: Agent Finished Run: 836g1m3h 



In [10]:
# train()