In [69]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Sequence
from collections import namedtuple, deque
import itertools
import random                       
import warnings
import time
import os
warnings.filterwarnings("ignore")

In [70]:
# Hyperparameters of the DQN agent.
GAMMA = 0.99              # discount factor applied to the bootstrapped target Q-value
BATCH_SIZE = 128          # number of transitions drawn from the replay memory per sample
BUFFER_SIZE = 10_000      # maximum number of transitions kept in the replay memory
MIN_REPLAY_SIZE = 5_000   # transitions collected with random actions before training starts
EPS_START = 1.0           # initial exploration rate (epsilon)
EPS_END = 0.05            # lower bound for the exploration rate
EPS_DECAY = 0.995         # multiplicative decay applied to epsilon when an episode ends
TARGET_UPDATE_FREQ = 5    # interval used when syncing the target network with the policy network

In [71]:
# Environment shared by the replay memory, the action-selection policy and the training loop.
env = gym.make("LunarLander-v2")
# Start a first episode; reset() returns the initial observation and an info dict.
obs, info = env.reset()
# Running return of the episode in progress; the training loop accumulates rewards into it.
episode_reward = 0.0

In [72]:
# One environment step as stored in the replay memory.
Transition = namedtuple('Transition', ('states', 'actions', 'rewards', 'dones', 'next_states'))


class Replay_memory():
    """Fixed-capacity experience replay buffer.

    Args:
        env: Environment used by ``initialize`` to collect random transitions.
            It must expose ``reset()``, ``step(action)`` returning a five-tuple
            ``(obs, reward, terminated, truncated, info)``, and
            ``action_space.sample()``.
        fullsize: Maximum number of transitions kept; the oldest are evicted first.
        minsize: Number of random transitions collected by ``initialize``.
        batchsize: Number of transitions returned by ``sample_batch``.
    """

    def __init__(self, env, fullsize, minsize, batchsize):
        self.env = env
        self.memory = deque(maxlen=fullsize)
        # Returns of the most recent episodes (at most 50), appended by the caller.
        self.rewards = deque(maxlen=50)
        self.batchsize = batchsize
        self.minsize = minsize

    def append(self, transition):
        """Store one ``Transition``, evicting the oldest one when the buffer is full."""
        self.memory.append(transition)

    def sample_batch(self):
        """Draw ``batchsize`` transitions uniformly without replacement.

        Returns:
            Tuple ``(states, actions, rewards, dones, next_states)`` of tensors.
            ``states`` and ``next_states`` are float32 with one row per
            transition; ``actions`` (int64), ``rewards`` (float32) and
            ``dones`` (bool) have shape ``(batchsize, 1)``.

        Raises:
            ValueError: If fewer than ``batchsize`` transitions are stored.
        """
        batch = random.sample(self.memory, self.batchsize)
        batch = Transition(*zip(*batch))

        states = torch.from_numpy(np.array(batch.states, dtype=np.float32))
        actions = torch.from_numpy(np.array(batch.actions, dtype=np.int64)).unsqueeze(1)
        rewards = torch.from_numpy(np.array(batch.rewards, dtype=np.float32)).unsqueeze(1)
        # np.bool_ instead of the np.bool8 alias, which newer NumPy releases no longer provide.
        dones = torch.from_numpy(np.array(batch.dones, dtype=np.bool_)).unsqueeze(1)
        next_states = torch.from_numpy(np.array(batch.next_states, dtype=np.float32))

        return states, actions, rewards, dones, next_states

    def initialize(self):
        """Fill the buffer with ``minsize`` transitions taken with random actions.

        Returns:
            ``self``, so construction and initialization can be chained.
        """
        obs, _ = self.env.reset()
        for _ in range(self.minsize):
            action = self.env.action_space.sample()
            new_obs, reward, terminated, truncated, _ = self.env.step(action)
            # Only true termination is stored as `dones`: a time-limit truncation
            # must still bootstrap from the next state when computing targets.
            self.append(Transition(obs, action, reward, terminated, new_obs))
            if terminated or truncated:
                # Continue from the fresh initial observation, not the stale terminal one.
                obs, _ = self.env.reset()
            else:
                obs = new_obs
        return self

In [73]:
# Build the replay memory and pre-fill it with MIN_REPLAY_SIZE transitions gathered
# with randomly sampled actions; initialize() returns the memory itself.
replay_memory = Replay_memory(env, BUFFER_SIZE, MIN_REPLAY_SIZE, BATCH_SIZE).initialize()

In [74]:
class DQN(nn.Module):
    """Two-layer fully connected Q-network.

    Args:
        ninputs: Number of input features (size of one observation).
        noutputs: Number of outputs (one Q-value per action).
    """

    def __init__(self, ninputs, noutputs):
        super().__init__()
        self.a1 = nn.Linear(ninputs, 128)
        self.a2 = nn.Linear(128, noutputs)

    def forward(self, X):
        """Return the network output for ``X`` (last dimension ``ninputs``)."""
        o = self.a1(X)
        o = F.relu(o)
        o = self.a2(o)
        return o

    # __call__ is intentionally not overridden: nn.Module.__call__ already
    # dispatches to forward() and additionally runs registered hooks, which a
    # custom __call__ would silently skip.

In [75]:
# Policy and target networks share one architecture: one input per observation
# component and one output per discrete action.
dqn_policy = DQN(env.observation_space.shape[0], env.action_space.n)
dqn_target = DQN(env.observation_space.shape[0], env.action_space.n)
# Start the target network as an exact copy of the policy network.
dqn_target.load_state_dict(dqn_policy.state_dict())
# The optimizer only receives the policy network's parameters, so the target
# network is kept in evaluation mode.
dqn_target.eval()

DQN(
  (a1): Linear(in_features=8, out_features=128, bias=True)
  (a2): Linear(in_features=128, out_features=4, bias=True)
)

In [76]:
# Smooth L1 loss between the predicted and the target Q-values.
loss_fn = nn.SmoothL1Loss()
# Only the policy network's parameters are optimized.
optimizer = torch.optim.Adam(dqn_policy.parameters(), lr=0.0001)

In [77]:
def epsilon_greedy_policy(epsilon, obs):
    """Choose an action epsilon-greedily with respect to ``dqn_policy``.

    Args:
        epsilon: Probability threshold for taking a random action.
        obs: Current observation, convertible with ``torch.Tensor``.

    Returns:
        A randomly sampled action when the drawn number is at most ``epsilon``,
        otherwise the index (``int``) of the largest policy-network output.
    """
    explore = random.random() <= epsilon
    if explore:
        return env.action_space.sample()

    # Greedy branch: no gradients are needed for action selection.
    with torch.no_grad():
        q_values = dqn_policy(torch.Tensor(obs))
        return int(torch.argmax(q_values))

In [78]:
# Train the policy network with epsilon-greedy exploration and experience replay.
obs, info = env.reset()
eps_threshold = EPS_START
episode = 0            # number of completed episodes
episode_reward = 0.0   # reset here so re-running this cell never inherits a stale value
scores = []

for step in itertools.count():
    # --- Interact with the environment for one step ---
    action = epsilon_greedy_policy(eps_threshold, obs)
    new_obs, reward, terminated, truncated, _ = env.step(action)
    # Only `terminated` is stored as the done flag: a truncated episode still
    # bootstraps from the next state when the target is computed below.
    replay_memory.append(Transition(obs, action, reward, terminated, new_obs))
    episode_reward += reward
    obs = new_obs

    if terminated or truncated:
        # --- Episode bookkeeping ---
        episode += 1
        scores.append(episode_reward)
        replay_memory.rewards.append(episode_reward)
        episode_reward = 0.0
        eps_threshold = max(eps_threshold * EPS_DECAY, EPS_END)
        obs, info = env.reset()

        # Mean return over the most recent episodes kept in replay_memory.rewards.
        avg_res = np.mean(replay_memory.rewards)

        if episode % 50 == 0:
            print(f'Episode: {episode} Avg Results: {avg_res}')

        # NOTE(review): the stopping threshold is 195; confirm it is the intended
        # "solved" score for this environment.
        if avg_res >= 195:
            print(f'Solved at episode: {episode} Avg Results: {avg_res}')
            break

        # Sync the target network every TARGET_UPDATE_FREQ completed episodes.
        # Testing the global step counter here would only trigger when an episode
        # happens to end on a matching step, making the sync interval erratic.
        if episode % TARGET_UPDATE_FREQ == 0:
            dqn_target.load_state_dict(dqn_policy.state_dict())

    # --- One optimization step on a sampled minibatch ---
    b_states, b_actions, b_rewards, b_dones, b_next_states = replay_memory.sample_batch()

    # Q(s, a) for the actions that were actually taken.
    qvalues = dqn_policy(b_states).gather(1, b_actions)

    with torch.no_grad():
        # max_a' Q_target(s', a'), kept as a column to match `qvalues`.
        max_target_qvalues = dqn_target(b_next_states).max(dim=1, keepdim=True).values
        # Terminal transitions contribute the immediate reward only.
        not_done = 1.0 - b_dones.float()
        expected_qvalues = b_rewards + GAMMA * not_done * max_target_qvalues

    loss = loss_fn(qvalues, expected_qvalues)
    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1] before the parameter update.
    nn.utils.clip_grad_value_(dqn_policy.parameters(), clip_value=1.0)
    optimizer.step()

Episode: 50 Avg Results: -181.10225480887394
Episode: 100 Avg Results: -98.13171124136662
Episode: 150 Avg Results: -56.4291410471585
Episode: 200 Avg Results: 0.7975081113884019
Episode: 250 Avg Results: 50.135902490327624
Episode: 300 Avg Results: 40.607193097348976
Episode: 350 Avg Results: 100.7029911293888
Episode: 400 Avg Results: 93.39247421299002
Episode: 450 Avg Results: 38.8051583060367
Episode: 500 Avg Results: 81.80921048452635
Episode: 550 Avg Results: 40.95476317372967
Episode: 600 Avg Results: -15.77165513890726
Episode: 650 Avg Results: -15.868668570174883
Episode: 700 Avg Results: 16.898228188856375
Episode: 750 Avg Results: -1.8669504556541543
Episode: 800 Avg Results: -15.461618840023334
Episode: 850 Avg Results: 7.235242425656471
Episode: 900 Avg Results: 68.92470105386754
Episode: 950 Avg Results: 21.036322939229567
Episode: 1000 Avg Results: 95.0337940761292
Episode: 1050 Avg Results: 93.44207339524863
Episode: 1100 Avg Results: 112.54930317344629
Episode: 1150 Av

In [1]:
# Plot the return of every training episode. This cell relies on names created by
# earlier cells (plt, np, scores, avg_res), so those cells must have been run in
# the current kernel first.
fig, ax = plt.subplots(figsize=(9, 9))
episode_index = np.arange(len(scores))
ax.plot(episode_index, scores, marker='.')
ax.set_title('Vanilla DQN (Standard DQN)')
ax.set_ylabel('Rewards')
ax.set_xlabel('Episodes')
plt.show()

# Last running-average return computed by the training loop.
print(avg_res)

NameError: name 'plt' is not defined