# DQN with Pytorch_Lightning

Note: Use pytorch_lighning==1.6.0

In [1]:
import copy
import gym
import torch
import random

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

from pytorch_lightning.callbacks import EarlyStopping

from gym.wrappers import RecordVideo, RecordEpisodeStatistics, TimeLimit
import matplotlib.pyplot as plt
%matplotlib inline

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()
print(f"num gpus : {num_gpus}")

num gpus : 1


In [2]:
## Creating Deep Q Network:
class DQN(nn.Module):
    def __init__ (self , hidden_size , obs_size , action_size):
        super().__init__()
        self.net = nn.Sequential(
        nn.Linear(obs_size , hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size , hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size , action_size),
        )
    def forward(self , x):
        return self.net(x.float())

In [3]:
## Creating Policy: state -> action or action_probs
def epsilon_greedy(state , env , net , epsilon=0):
    if np.random.random() < epsilon :
        action = env.action_space.sample()
        
    else:
        state = torch.tensor([state]).to(device)
        q_values = net(state)
        _ , action = torch.max(q_values , dim=1) # returns (value , idx)
        action = int(action.item())
    
    return action

In [4]:
## Creating Replay Buffer:
class ReplayBuffer:
    
    def __init__(self , capacity):
        self.buffer = deque(maxlen=capacity) #    it's like a list but manages its contents automaticly
        
    def __len__(self):
        return len(self.buffer)
    
    def append(self, experience):
        self.buffer.append(experience)
    
    def sample(self , batch_size):
        return random.sample(self.buffer , batch_size)
    

In [5]:
class RLDataset(IterableDataset):
    
    def __init__ (self , buffer , sample_size = 200):
        self.buffer = buffer
        self.sample_size = sample_size
        
    def __iter__(self):
        for experience in self.buffer.sample(self.sample_size):
            yield experience # returns by request of pytorch
     

In [6]:
## Creating Environment
def create_environment(name):
    env = gym.make(name)
    env = TimeLimit(env , max_episode_steps = 400)  #terminates after 400 steps
    env = RecordVideo(env , video_folder = './videos' , episode_trigger=lambda x: x%50==0 )
    env = RecordEpisodeStatistics(env)
    return env

In [7]:
env = create_environment('LunarLander-v2')
env.reset()

  logger.warn(


array([ 0.00352249,  1.4044876 ,  0.35678267, -0.28589943, -0.00407496,
       -0.08081645,  0.        ,  0.        ], dtype=float32)

In [8]:
print(env.observation_space)
print(env.action_space)

Box([-inf -inf -inf -inf -inf -inf -inf -inf], [inf inf inf inf inf inf inf inf], (8,), float32)
Discrete(4)


In [9]:
env = create_environment('LunarLander-v2')
# for e in range(10):
#     done = False
#     env.reset()
#     while not done:
#         action = env.action_space.sample()
#         _ , _ , done , _ = env.step(action)
# env.close()

In [10]:
class DeepQLearning(LightningModule):
    
    # intialize
    def __init__(self , env_name , policy=epsilon_greedy , capacity=100_000 , batch_size=1024 , lr = 0.001 ,
                 hidden_size=128 , gamma=0.99, loss_fn = F.smooth_l1_loss , optim = AdamW ,
                eps_start = 1.0 , eps_end = 0.15 , eps_last_episode=100 , samples_per_epoch=10_000 ,
                sync_rate=10):
        super().__init__()
        self.env = create_environment(env_name)
        obs_size=self.env.observation_space.shape[0]
        action_size = self.env.action_space.n
        self.q_net=DQN(hidden_size , obs_size , action_size)
        self.target_q_net=copy.deepcopy(self.q_net)
        self.policy = policy
        self.buffer = ReplayBuffer(capacity=capacity)
        
        self.save_hyperparameters()
        
        while len(self.buffer)  < self.hparams.samples_per_epoch:
            self.play_episode(epsilon=self.hparams.eps_start)
    
    @torch.no_grad()
    def play_episode(self ,policy=None ,  epsilon =0):
        state = self.env.reset()
        done = False
        while not done :
            
            if policy:
                action = policy(state , self.env , self.q_net , epsilon = epsilon)
            else:
                action = self.env.action_space.sample()
            next_state , reward , done , _ = self.env.step(action)
            exp = (state , action , reward , done , next_state)
            self.buffer.append(exp)
            state = next_state
            
        
        
    # forward
    def forward(self , x):
        return self.q_net(x)
    
    
    # configure optimizers
    def configure_optimizers(self):
        q_net_optimizer = self.hparams.optim(self.q_net.parameters() , lr = self.hparams.lr)
        return [q_net_optimizer]
    
    
    # create dataloader
    def train_dataloader(self):
        dataset = RLDataset(self.buffer , self.hparams.samples_per_epoch)
        dataloader = DataLoader(dataset=dataset ,batch_size=self.hparams.batch_size )
        return dataloader
    
    
    # training step
    def training_step(self , batch , batch_idx):
        states , actions , rewards , dones , next_states = batch
        actions = actions.unsqueeze(1)
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1)
        state_action_values = self.q_net(states).gather(1,actions)
        next_action_values , _ = self.target_q_net(next_states).max(dim=1 , keepdim=True)
        expected_state_action_values = rewards + self.hparams.gamma * next_action_values * (torch.logical_not(dones))
        loss = self.hparams.loss_fn(state_action_values , expected_state_action_values )
        self.log('episode/Q-error' , loss)
        return loss
        
        
    
    # training epoch end
    def training_epoch_end(self, training_step_outputs):
        epsilon = max(self.hparams.eps_end , self.hparams.eps_start - self.current_epoch/self.hparams.eps_last_episode)
        self.play_episode(policy=self.policy , epsilon=epsilon)
        self.log('episode/Return' , self.env.return_queue[-1])
        
        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        

In [15]:
#!rm -r lightning_logs/
#!rm -r videos/
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

rm: cannot remove 'lightning_logs/': No such file or directory


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


rm: cannot remove 'videos/': No such file or directory


Reusing TensorBoard on port 6006 (pid 13104), started 0:00:48 ago. (Use '!kill 13104' to kill it.)

In [16]:
algo = DeepQLearning('LunarLander-v2')
trainer = Trainer ( gpus = num_gpus , max_epochs = 10_000 , callbacks=[EarlyStopping(monitor = 'episode/Return' , mode='max',patience = 500)] )
trainer.fit(algo)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: C:\Users\Ali\Documents\RLwithPhil\code\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type | Params
--------------------------------------
0 | q_net        | DQN  | 18.2 K
1 | target_q_net | DQN  | 18.2 K
--------------------------------------
36.4 K    Trainable params
0         Non-trainable params
36.4 K    Total params
0.145     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Exception ignored in: <function Viewer.__del__ at 0x0000022E489133A0>
Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\envs\vrep\lib\site-packages\gym\envs\classic_control\rendering.py", line 185, in __del__
    self.close()
  File "C:\ProgramData\Anaconda3\envs\vrep\lib\site-packages\gym\envs\classic_control\rendering.py", line 101, in close
    self.window.close()
  File "C:\ProgramData\Anaconda3\envs\vrep\lib\site-packages\pyglet\window\win32\__init__.py", line 332, in close
    super(Win32Window, self).close()
  File "C:\ProgramData\Anaconda3\envs\vrep\lib\site-packages\pyglet\window\__init__.py", line 858, in close
    app.windows.remove(self)
  File "C:\ProgramData\Anaconda3\envs\vrep\lib\_weakrefset.py", line 114, in remove
    self.data.remove(ref(item))
KeyError: <weakref at 0x0000022E5AC489F0; to 'Win32Window' at 0x0000022E427B29D0>
  state = torch.tensor([state]).to(device)
