In [19]:
from pyvirtualdisplay import Display
Display(visible=False, size=(1400, 900)).start()

<pyvirtualdisplay.display.Display at 0x75a015023350>

In [20]:
import copy
import gymnasium as gym
import torch
import random

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

from pytorch_lightning.callbacks import EarlyStopping

from gymnasium.wrappers import RecordVideo, RecordEpisodeStatistics, TimeLimit


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()



In [21]:
def display_video(episode=0):
  video_file = open(f'./videos/rl-video-episode-{episode}.mp4', "r+b").read()
  video_url = f"data:video/mp4;base64,{b64encode(video_file).decode()}"
  return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

In [22]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, hidden_size, obs_size, n_actions):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)    
        )


    def forward(self, x):
        return self.net(x.float())



In [23]:
sample_dqn = DQN(128, 4, 2) 
sample_obs = torch.rand(2, 4) # 2 observations, each with 4 features
chosen_action = sample_dqn(sample_obs) # 2 actions, one for each observation
print (sample_dqn)
print (sample_obs.shape)
print (chosen_action.shape)
print (chosen_action)
_, action = torch.max(chosen_action, dim=1)
torch.tensor([int(x.item()) for x in action])

DQN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=2, bias=True)
  )
)
torch.Size([2, 4])
torch.Size([2, 2])
tensor([[ 0.0580,  0.0497],
        [-0.0133,  0.0088]], grad_fn=<AddmmBackward0>)


tensor([0, 1])

In [24]:
# Define the ε-greedy policy
def epsilon_greedy(state, env, net, epsilon=0.0):
    if random.random() < epsilon:
        return env.action_space.sample()
    else:
        state = torch.tensor([state]).to(device=device)
        q_values = net(state)
        _, action = torch.max(q_values, dim=1)
        action = torch.tensor([int(x.item()) for x in action])
    return action
    

In [25]:
class ReplaysBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)
    
    def append(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    

In [26]:
from typing import Iterator


class RLDatasest(IterableDataset):

    def __init__(self, buffer, sample_size=200):
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self):
        for experience in self.buffer.sample(self.sample_size):
            yield experience


In [27]:
def create_environment(name):
    env = gym.make(name, render_mode='rgb_array')
    env = TimeLimit(env, max_episode_steps=400)
    env = RecordVideo(env, video_folder='./videos', episode_trigger=lambda x: x % 50 == 0)
    env = RecordEpisodeStatistics(env)
    return env

In [28]:
env = create_environment("LunarLander-v3")
for episode in range(10):
    done = False
    env.reset()
    while not done:
        action = env.action_space.sample()
        next_state, reward, done, truncated, info = env.step(action)
        

  logger.warn(


In [29]:
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS


class DeepQLearning(LightningModule):
    def __init__(self, env_name, policy = epsilon_greedy, capacity = 100000, batch_size=256, lr=1e-3, hidden_size=128, 
                 gamma=0.99, loss_fn=F.smooth_l1_loss, optim=AdamW, epsilon_start=1.0, epsilon_end=0.15, epsilon_last_episode=100,
                 samples_per_epoch = 10000, sync_rate=10):
        super(DeepQLearning, self).__init__()
        self.env = create_environment(env_name)
        self.q_net = DQN(hidden_size, self.env.observation_space.shape[0], self.env.action_space.n)

        self.target_q_net = copy.deepcopy(self.q_net)

        self.policy = policy
        self.buffer = ReplaysBuffer(capacity)

        self.save_hyperparameters()

        while len(self.buffer) < self.hparams.samples_per_epoch:
            print(f"Populating buffer: {len(self.buffer)}/{self.hparams.samples_per_epoch * 100}%")
            self.play_episode(epsilon=self.hparams.epsilon_start)

    @torch.no_grad        
    def play_episode(self, policy=None, epsilon=0.0):
        done = False
        state = self.env.reset()
        while not done:
            if policy:
                action = policy(state, self.env, self.q_net, epsilon)
            else:
                action = self.env.action_space.sample() # raondom action to increase exploration
            next_state, reward, done, truncated, info = self.env.step(action)
            experience = (state, action, reward, done, next_state)
            self.buffer.append(experience)
            state = next_state

    def forward(self, x):
        return self.q_net(x)
    
    def configure_optimizers(self):
        q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
        return [q_net_optimizer]
    

    def train_dataloader(self):
        dataset = RLDatasest(self.buffer, self.hparams.samples_per_epoch)
        dataloader = DataLoader(dataset, batch_size=self.hparams.batch_size)
        return dataloader
    
    def training_step(self, batch, batch_idx):
        states, actions, rewards, dones, next_states = batch
        actions = actions.unsqueeze(1)
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1)

        state_action_values = self.q_net(states).gather(1, actions)

        next_action_values, _ = self.target_q_net(next_states).max(dim=1, keepdim=True)[0]
        next_action_values[dones] = 0.0

        expected_state_action_values = rewards + self.hparams.gamma * next_action_values 

        loss = self.hparams.loss_fn(state_action_values, expected_state_action_values)

        self.log("episode/Q_error", loss)
        
        return loss
    
    
    def training_epoch_end(self, training_step_outputs):
        epsilon = max(self.hparams.eps_end, 
                      self.hparams.end_start  - self.current_epoch / self.hparams.epsilon_last_episode)
        
        self.play_episode(policy=self.policy, epsilon=epsilon)
        self.log('episode/Return', self.env.return_queue[-1])

        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
    


In [30]:
! rm -r videos/
! rm -r lightning_logs/

rm: cannot remove 'lightning_logs/': No such file or directory


In [31]:
algo = DeepQLearning("LunarLander-v3")

trainer = Trainer(
    max_epochs = 10000,
    callbacks=[EarlyStopping(monitor="episode/Return", mode="max", patience=500)]
)

Populating buffer: 0/1000000%
Populating buffer: 89/1000000%
Populating buffer: 250/1000000%
Populating buffer: 326/1000000%
Populating buffer: 401/1000000%
Populating buffer: 523/1000000%
Populating buffer: 619/1000000%
Populating buffer: 677/1000000%
Populating buffer: 750/1000000%
Populating buffer: 851/1000000%
Populating buffer: 972/1000000%
Populating buffer: 1053/1000000%
Populating buffer: 1156/1000000%
Populating buffer: 1226/1000000%
Populating buffer: 1299/1000000%
Populating buffer: 1392/1000000%
Populating buffer: 1461/1000000%
Populating buffer: 1562/1000000%
Populating buffer: 1690/1000000%
Populating buffer: 1782/1000000%
Populating buffer: 1910/1000000%
Populating buffer: 1986/1000000%
Populating buffer: 2092/1000000%
Populating buffer: 2160/1000000%
Populating buffer: 2268/1000000%
Populating buffer: 2352/1000000%
Populating buffer: 2452/1000000%
Populating buffer: 2556/1000000%
Populating buffer: 2676/1000000%
Populating buffer: 2773/1000000%
Populating buffer: 2891/

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Populating buffer: 9684/1000000%
Populating buffer: 9795/1000000%
Populating buffer: 9875/1000000%
Populating buffer: 9960/1000000%
