In [4]:
import  gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import numpy as np  

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Build the "LunarLander-v2" environment; render_mode "human" is requested at creation time.
env = gym.make("LunarLander-v2", render_mode = "human")
# Reset the environment once and render it before any training code runs.
env.reset()
env.render()
# Number of actions reported by the environment's action space; the training
# loop reads this global when building action probabilities.
NUM_ACTIONS = env.action_space.n

In [None]:
# Immutable record describing a single environment step kept in the replay buffer.
# Field order matters: callers build it positionally as
# (state, action, next_state, reward).
Transition = namedtuple(
    typename='Transition',
    field_names=['state', 'action', 'next_state', 'reward'],
)

# Experience replay buffer
class ReplayMemory:
    """Bounded store of Transition records used for experience replay.

    The records live in a deque created with ``maxlen=capacity``, so once the
    buffer is full each new record pushes out the oldest one.
    """

    def __init__(self, capacity):
        # Bounded deque: appending beyond `capacity` drops from the left.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a Transition from the positional fields and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return a list of `batch_size` stored records drawn without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [None]:
# Deep Q-network definition
class DQN(nn.Module):
    """Three-layer fully connected network mapping observations to one value per action.

    Layout: n_observations -> 128 -> 128 -> n_actions, with ReLU after the
    first two linear layers and no activation on the output.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        # Input projection: observation vector to 128 hidden units.
        self.layer1 = nn.Linear(in_features=n_observations, out_features=128)
        # Hidden 128 -> 128 transformation.
        self.layer2 = nn.Linear(in_features=128, out_features=128)
        # Output head: one value per possible action.
        self.layer3 = nn.Linear(in_features=128, out_features=n_actions)

    def forward(self, x):
        """Run `x` through the three linear layers and return the final layer's output."""
        hidden_first = F.relu(self.layer1(x))
        hidden_second = F.relu(self.layer2(hidden_first))
        action_values = self.layer3(hidden_second)
        return action_values

In [None]:
def policy_fn(num_actions, e, state, policy_network):
    """Return epsilon-greedy action probabilities for `state`.

    Every action receives a base probability of ``e / num_actions``; the action
    whose value from ``policy_network(state)`` is largest along dimension 1
    additionally receives ``1 - e``.

    Args:
        num_actions: number of entries in the returned probability vector.
        e: exploration rate (epsilon).
        state: tensor fed to `policy_network`; its device is used for the result.
        policy_network: callable producing action values; the result must hold
            a single row, because the winning index is read with ``.item()``.

    Returns:
        1-D tensor of length `num_actions` on `state.device`.
    """
    # The forward pass is only used to pick an index, so no gradients are tracked.
    with torch.no_grad():
        q_values = policy_network(state)
        _, best_index = q_values.max(1)
    greedy_action = best_index.item()

    exploration_share = e / num_actions
    probabilities = torch.ones(num_actions, device=state.device) * exploration_share
    probabilities[greedy_action] += (1 - e)

    return probabilities

The following function performs one optimization step. Once the replay buffer holds at least `batch_size` transitions, we:
 - Sample a batch of transitions from memory.
 - Compute the predicted Q-values of the actions that were taken, using the policy network.
 - Compute target values by adding the observed reward to the discounted next-state value estimated by the target network.
 - Compute the loss between predicted and target values and perform one optimization step on the policy network.

In [None]:
def optimize_model(gamma, batch_size, memory, policy_network, target_network, loss_fn, optimizer):
    """Run one gradient step of Q-learning on a batch sampled from `memory`.

    Does nothing (returns None) while `memory` holds fewer than `batch_size`
    transitions.

    Args:
        gamma: discount factor applied to the bootstrapped next-state value.
        batch_size: number of transitions to sample.
        memory: replay buffer supporting ``len()`` and ``sample(batch_size)``;
            sampled items must expose ``state``, ``action``, ``next_state`` and
            ``reward`` attributes. ``next_state`` is ``None`` for terminal steps.
            States are concatenated along dim 0; actions must be integer index
            tensors of shape (1, 1) and rewards tensors of shape (1,).
        policy_network: network being trained; produces action values.
        target_network: network used (without gradients) for the bootstrap target.
        loss_fn: callable ``loss_fn(predicted, target)`` returning a scalar tensor.
        optimizer: optimizer over `policy_network`'s parameters.

    Returns:
        None.
    """
    if len(memory) < batch_size:
        return

    # nn.Module has no built-in `.device`; honour one if the network defines it,
    # otherwise take the device of its first parameter (CPU if it has none).
    device = getattr(policy_network, "device", None)
    if device is None:
        first_param = next(policy_network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    transitions = memory.sample(batch_size)
    # Read fields by attribute so the batch can be transposed without
    # reconstructing a record type.
    next_states = [t.next_state for t in transitions]

    non_final_mask = torch.tensor([s is not None for s in next_states],
                                  device=device, dtype=torch.bool)

    states = torch.cat([t.state for t in transitions])
    actions = torch.cat([t.action for t in transitions])
    rewards = torch.cat([t.reward for t in transitions])

    # Value predicted by the policy network for the action actually taken.
    predicted_values = policy_network(states).gather(1, actions)

    # Terminal transitions keep a next-state value of zero.
    next_state_values = torch.zeros(batch_size, device=device)
    # torch.cat raises on an empty list, so skip bootstrapping entirely when
    # every sampled transition is terminal.
    if non_final_mask.any():
        non_final_next_states = torch.cat([s for s in next_states if s is not None])
        with torch.no_grad():
            next_state_values[non_final_mask] = target_network(non_final_next_states).max(1)[0]

    # Bootstrapped target: r + gamma * max_a' Q_target(s', a').
    target_values = (next_state_values * gamma) + rewards

    loss = loss_fn(predicted_values, target_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-100, 100] before the update.
    torch.nn.utils.clip_grad_value_(policy_network.parameters(), 100)
    optimizer.step()

In [None]:
def _soft_update(target_network, policy_network, tau):
    """Blend policy weights into the target network: θ′ ← τ θ + (1 − τ) θ′."""
    target_state = target_network.state_dict()
    policy_state = policy_network.state_dict()
    for key in policy_state:
        target_state[key] = policy_state[key] * tau + target_state[key] * (1 - tau)
    # Load once, after every entry has been blended.
    target_network.load_state_dict(target_state)


def trainer(num_episodes, batch_size, target_network, policy_network, gamma, epsilon, device, replay_buffer, optimizer, tau, loss_fn=None):
    """Train `policy_network` with epsilon-greedy DQN on the module-level `env`.

    Relies on the notebook globals `env`, `NUM_ACTIONS`, `policy_fn` and
    `optimize_model`.

    Args:
        num_episodes: number of episodes to run.
        batch_size: batch size forwarded to `optimize_model`.
        target_network: network providing bootstrap targets; soft-updated each step.
        policy_network: network being trained and used to select actions.
        gamma: discount factor forwarded to `optimize_model`.
        epsilon: exploration rate forwarded to `policy_fn`.
        device: device on which state/action/reward tensors are created.
        replay_buffer: buffer with ``push(state, action, next_state, reward)``.
        optimizer: optimizer forwarded to `optimize_model`.
        tau: soft-update coefficient for the target network.
        loss_fn: loss callable forwarded to `optimize_model`; defaults to
            ``nn.SmoothL1Loss()`` when omitted.

    Returns:
        None.
    """
    if loss_fn is None:
        loss_fn = nn.SmoothL1Loss()

    for _episode in range(num_episodes):

        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        # Must start False, otherwise the step loop below never runs.
        episode_end = False

        while not episode_end:

            action_probs = policy_fn(NUM_ACTIONS, epsilon, state, policy_network)
            # NumPy needs a host-side float array that sums to 1; renormalise in
            # float64 to absorb float32 rounding.
            probs = action_probs.detach().cpu().numpy().astype(np.float64)
            probs /= probs.sum()
            action = int(np.random.choice(NUM_ACTIONS, p=probs))

            observation, reward, terminated, truncated, _ = env.step(action)
            episode_end = terminated or truncated

            # Terminal states are stored as None so optimize_model assigns them
            # a next-state value of zero.
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, dtype=torch.float32,
                                          device=device).unsqueeze(0)

            # Store action and reward as tensors: optimize_model concatenates
            # them and uses the action as a gather index.
            replay_buffer.push(
                state,
                torch.tensor([[action]], dtype=torch.long, device=device),
                next_state,
                torch.tensor([reward], dtype=torch.float32, device=device),
            )
            state = next_state

            optimize_model(gamma, batch_size, replay_buffer, policy_network,
                           target_network, loss_fn, optimizer)

            _soft_update(target_network, policy_network, tau)
            
            