In [1]:
from IPython.display import Markdown, display
def printmd(string):
    display(Markdown(string))

In [2]:
import itertools
import random
from collections import deque

import gym
import torch
import numpy as np
from torch import nn

In [3]:
GAMMA = 0.99              # discount rate for computing temporal difference target
BATCH_SIZE = 32           # How many transitions to sample from replay buffer when computing gradients
BUFFER_SIZE = 50000       # Max number of transitions to store before overwriting old transitions
MIN_REPLAY_SIZE = 1000    # How many transitions we want in replay buffer before starting to compute gradients, training
EPSILON_START = 1.0       # Start epsilon
EPSILON_END = 0.02        # End epsilon
EPSILON_DECAY = 10000     # Decay period under which epsilon will linearly anneal from start to end 
TARGET_UPDATE_FREQ = 1000 # Number of steps where we set target parameters equal to online parameters

In [4]:
env = gym.make('CartPole-v0') # Cartpole allows quick iteration 

In [5]:
class Network(nn.Module):
    def __init__(self, env):
        super().__init__()
        
        # How many neurons are in input layer of neural network
        # 4 in the case of cartpole (prod is not necessary here, since 1 dim)
        in_features = int(np.prod(env.observation_space.shape))  
        
        # 2 layer sequential linear network with 64 hidden units, separated by tanh nonlinearity
        # Outputs in network are equal to number of actions agents can take 
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.Tanh(),
            nn.Linear(64, env.action_space.n) # Q-learning can only be used with action spaces that have finite number of actions
        )
        
    def forward(self, x):
        # Given an input observation (our state) predict q values for that given state
        return self.net(x)
        
    def act(self, obs):
        # Select an action intelligently. Predict the q value of your current state, then pick the action 
        # i.e. the index that maximizes the value of the current state
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        q_values = self.forward(obs_t.unsqueeze(0)) # create fake batch dimension        
        max_q_index = torch.argmax(q_values, dim=1)[0]
        action = max_q_index.detach().item()
        
        return action
        
        


In [6]:
replay_buffer = deque(maxlen=BUFFER_SIZE)  

rew_buffer = deque([0, 0], maxlen=100)  # reward buffer, stores rewards earned by agent during episode (track improvment of agent)

episode_reward = 0.0

In [7]:
online_net = Network(env)
target_net = Network(env)

target_net.load_state_dict(online_net.state_dict())  # Set target parameters to onine net parameters

optimizer = torch.optim.Adam(online_net.parameters(), lr=5e-4)

In [8]:
# Initialize replay buffer
obs = env.reset()
for _ in range(MIN_REPLAY_SIZE):
    action = env.action_space.sample()
    
    new_obs, rew, done, info = env.step(action)
    transition = (obs, action, rew, done, new_obs)
    replay_buffer.append(transition)
    
    if done:
        obs = env.reset()

transition_labels = ['**Observation**: ', '**Action**: ', '**Reward**: ', '**Done** (is episode over): ', '**Info**: ']
_ = [printmd(x + str(y)) for (x, y) in zip(transition_labels, transition)]

**Observation**: [0.01056365 0.02887569 0.00452288 0.02949128]

**Action**: 0

**Reward**: 1.0

**Done** (is episode over): False

**Info**: [0.03143533 0.21146269 0.03802817 0.01330967]

In [9]:
# Main training loop
obs = env.reset()

for step in itertools.count():
    
    # First we must select an action to take in the environment. We are using epsilon greedy
    epsilon = np.interp(step, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])
    
    rnd_sample = random.random()
    
    if rnd_sample <= epsilon:   # Explore
        action = env.action_space.sample() 
    else:                       # Exploit - Intelligently select action using network
        action = online_net.act(obs)
    
    new_obs, rew, done, info = env.step(action)
    transition = (obs, action, rew, done, new_obs) # Transition tuple
    replay_buffer.append(transition)
    
    episode_reward += rew
    
    if done:
        obs = env.reset()
        
        rew_buffer.append(episode_reward)
        episode_reward = 0
    
    # Start Gradient Step
    
    # Sample batch size of transitions. Break them out of individual tuples into lists. Turn into tensors
    transitions = random.sample(replay_buffer, BATCH_SIZE)
    obses = np.asarray([t[0] for t in transitions])
    actions = np.asarray([t[1] for t in transitions])
    rews = np.asarray([t[2] for t in transitions])
    dones = np.asarray([t[3] for t in transitions])
    new_obses = np.asarray([t[4] for t in transitions])

    obses_t = torch.as_tensor(obses, dtype=torch.float32)
    actions_t = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(-1) # unsqueeze to put batch dimension on tensors
    rews_t = torch.as_tensor(rews, dtype=torch.float32).unsqueeze(-1)     # we use -1 here to get dimension added at end
    dones_t = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(-1) 
    new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32)
    
    # Compute Targets for loss function 
    
    # For each observation in our tensor, get a prediction of the q value. So each observation will be mapped to a 
    # 2 dimensional vector, the value of being in that state (new_obs) and taking action 0 and taking action 1
    # For each observation we get a set of q values
    # Note: The entire point of Q is to learn the value/quality of being in a given state and taking a certain action
    # So if this has been learned correctly, we can easily know what action to take
    # Note: this is the q value at the next observation!
    target_q_values = target_net(new_obses_t) 
    max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0] # Must collapse down to the highest q value per observation
    
    targets = rews_t + GAMMA * (1 - dones_t) * max_target_q_values # Piece wise function in one step
    
    # Compute Loss and Gradients
    # The online net calculates the q value at the current observation
    q_values = online_net(obses_t)
    # Get the q value for the actual action that we took in the transition
    # This yields the effectively predicted q value of the action we took at the original time of the transition
    action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)

    loss = nn.functional.smooth_l1_loss(action_q_values, targets)
    
    # Gradient Descent
    optimizer.zero_grad()
    loss.backward()  # Compute gradient
    optimizer.step() # Apply gradients
    
    # Update target network (at target update frequency steps)
    if step % TARGET_UPDATE_FREQ:
        target_net.load_state_dict(online_net.state_dict())
    
    # Logging
    if step % 1000 == 0:
        print('Step: ', step)
        print('Avg reward: ', np.mean(rew_buffer))
    
    
    
    
    
    
    
    
    
    # Questions to answer:
    # 1) When does the target net get updated?
    # 2) How does target actually help us? How does it allow us to bootstrap forward?
    

Step:  0
Avg reward:  0.0
Step:  1000
Avg reward:  21.085106382978722
Step:  2000
Avg reward:  22.348314606741575
Step:  3000
Avg reward:  22.78
Step:  4000
Avg reward:  19.34
Step:  5000
Avg reward:  15.42
Step:  6000
Avg reward:  13.03
Step:  7000
Avg reward:  11.79
Step:  8000
Avg reward:  11.24
Step:  9000
Avg reward:  10.65
Step:  10000
Avg reward:  9.7
Step:  11000
Avg reward:  9.41
Step:  12000
Avg reward:  9.47
Step:  13000
Avg reward:  9.47
Step:  14000
Avg reward:  9.44
Step:  15000
Avg reward:  9.47
Step:  16000
Avg reward:  9.46
Step:  17000
Avg reward:  9.42


KeyboardInterrupt: 

In [None]:
obs

In [80]:
online_net.forward(torch.as_tensor(obs, dtype=torch.float32))

tensor([-0.1949, -0.0645], grad_fn=<AddBackward0>)

In [84]:
obs_t = torch.as_tensor(obs, dtype=torch.float32)

In [85]:
online_net.forward(obs_t)

tensor([-0.0539, -0.1741], grad_fn=<AddBackward0>)

In [88]:
obs_t

tensor([ 0.0364, -0.0066,  0.0246, -0.0306])

In [87]:
obs_t.unsqueeze(0)

tensor([[ 0.0364, -0.0066,  0.0246, -0.0306]])

In [90]:
online_net.act(obs)

> [0;32m<ipython-input-81-81b701f5a0fa>[0m(25)[0;36mact[0;34m()[0m
[0;32m     23 [0;31m        [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m        [0mmax_q_index[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0margmax[0m[0;34m([0m[0mq_values[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0maction[0m [0;34m=[0m [0mmax_q_index[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m[0;34m[0m[0m
[0m


ipdb>  q_values


tensor([[-0.0539, -0.1741]], grad_fn=<AddmmBackward0>)


ipdb>  n


> [0;32m<ipython-input-81-81b701f5a0fa>[0m(26)[0;36mact[0;34m()[0m
[0;32m     24 [0;31m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mmax_q_index[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0margmax[0m[0;34m([0m[0mq_values[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 26 [0;31m        [0maction[0m [0;34m=[0m [0mmax_q_index[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m     28 [0;31m        [0;32mreturn[0m [0maction[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  max_q_index


tensor(0)


ipdb>  n


> [0;32m<ipython-input-81-81b701f5a0fa>[0m(28)[0;36mact[0;34m()[0m
[0;32m     26 [0;31m        [0maction[0m [0;34m=[0m [0mmax_q_index[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m        [0;32mreturn[0m [0maction[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m[0;34m[0m[0m
[0m[0;32m     30 [0;31m[0;34m[0m[0m
[0m


ipdb>  action


0


ipdb>  exit


BdbQuit: 