In [4]:
import  gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import numpy as np  

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Create the LunarLander-v2 environment, requesting the "human" render mode.
env = gym.make("LunarLander-v2", render_mode="human")

# Reset the environment before the first frame is requested
# (the values returned by reset() are not needed here and are discarded).
env.reset()

# Ask the environment to render its current state.
env.render()

In [None]:
# Record type for one stored experience; fields are, in order:
# state, action, next_state, reward.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

# Fixed-capacity experience replay buffer
class ReplayMemory:
    """Bounded store of ``Transition`` records.

    The buffer is a ``deque`` created with ``maxlen=capacity``, so once it is
    full each new append discards the oldest stored entry.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` entries."""
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and append it to the buffer."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored entries picked with ``random.sample``.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` is larger
        than the number of stored entries.
        """
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Return how many entries are currently stored."""
        return len(self.memory)

In [None]:
# Deep Q-network definition
class DQN(nn.Module):
    """Three-layer fully connected network producing one output per action.

    Args:
        n_observations: Size of the last dimension of the input tensor.
        n_actions: Number of outputs (one per action).
        hidden_size: Width of the two hidden layers. Defaults to 128, the
            value that was previously hard-coded.
    """

    def __init__(self, n_observations, n_actions, hidden_size=128):
        super().__init__()
        # Attribute names and creation order are kept as-is so parameter
        # names (layer1/layer2/layer3) and seeded initialisation are unchanged.
        # observations -> hidden
        self.layer1 = nn.Linear(n_observations, hidden_size)
        # hidden -> hidden
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        # hidden -> one value per action
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Apply the two ReLU hidden layers, then the linear output layer.

        The output layer has no activation, so the returned values are
        unbounded. The last dimension of the result has size ``n_actions``.
        """
        hidden = F.relu(self.layer1(x))
        hidden = F.relu(self.layer2(hidden))
        return self.layer3(hidden)

In [None]:
def policy_fn(num_actions, e, state, policy_network):
    """Return epsilon-greedy action probabilities for ``state``.

    Every action receives ``e / num_actions``; the action with the largest
    network output additionally receives ``1 - e``.

    Args:
        num_actions: Length of the returned probability vector.
        e: Exploration rate (epsilon).
        state: Tensor passed unchanged to ``policy_network``. It must describe
            a single state, because the greedy index is read with ``.item()``.
        policy_network: Callable mapping ``state`` to per-action values,
            either as a ``(1, num_actions)`` tensor or, for an unbatched
            state, as a 1-D tensor.

    Returns:
        A 1-D tensor of length ``num_actions`` on ``state.device``.
    """
    # Allocate directly on the target device instead of building on the CPU
    # and copying afterwards.
    action_probabilities = torch.ones(num_actions, device=state.device) * (e / num_actions)

    # Action selection needs no gradients.
    with torch.no_grad():
        q_values = policy_network(state)

    # An unbatched state yields a 1-D output; give it a batch axis so the
    # per-row max below works instead of raising IndexError.
    if q_values.dim() == 1:
        q_values = q_values.unsqueeze(0)

    # max(1) returns (values, indices); the index is the greedy action.
    greedy_action = q_values.max(1)[1].item()

    action_probabilities[greedy_action] += (1 - e)

    return action_probabilities

In [None]:
def optimize_model(batch_size, memory, policy_network): 
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)