In [1]:
from base64 import b64encode
from collections import deque, namedtuple

import gym
import numpy as np
import torch
from gym.wrappers import RecordVideo, RecordEpisodeStatistics
from IPython.display import HTML
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

# Prefer the GPU when torch reports one as available, otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression: the notebook shows the selected device as the cell output.
device

'cpu'

In [2]:
from torch import nn

class dqn(nn.Module):
    """Fully connected Q-network: one output value per action.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` with both
    hidden layers of width ``hidden_size``.

    Args:
        observation_dim: Size of the last dimension of the input tensor.
        hidden_size: Width of the two hidden layers.
        action_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, observation_dim, hidden_size, action_dim):
        super().__init__()
        widths = [observation_dim, hidden_size, hidden_size, action_dim]
        stack = []
        # Linear layers are instantiated input-to-output, each followed by a
        # ReLU; the activation after the final (output) layer is dropped.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()
        self.net = nn.Sequential(*stack)

    def forward(self, observation):
        """Return the network output for ``observation``."""
        q_values = self.net(observation)
        return q_values


In [3]:
def epsilon_policy(
    state,
    env,
    net,
    epsilon=0.0,
    device="cuda" if torch.cuda.is_available() else "cpu",
    *,
    eplison=None,
):
    """Pick an action epsilon-greedily with respect to ``net``.

    Args:
        state: A single observation (array-like); a batch dimension is added
            before it is passed to ``net``.
        env: Environment whose ``action_space.sample()`` supplies the random
            action used for exploration.
        net: Callable mapping a ``(1, obs_dim)`` float32 tensor to a
            ``(1, n_actions)`` tensor of Q-values.
        epsilon: Probability of taking a random action instead of the greedy one.
        device: Device the observation tensor is created on before calling ``net``.
        eplison: Deprecated misspelling of ``epsilon`` kept so existing keyword
            callers keep working; when given it overrides ``epsilon``.

    Returns:
        The sampled action as returned by ``env.action_space.sample()`` when
        exploring, otherwise the index of the largest Q-value as a Python int.
    """
    if eplison is not None:
        epsilon = eplison

    # np.random.random() is in [0, 1): epsilon=0 never explores, epsilon=1 always does.
    if np.random.random() < epsilon:
        return env.action_space.sample()

    # Convert through a float32 ndarray: avoids the slow list-of-ndarrays path
    # of torch.tensor([state]) and matches the default dtype of nn.Linear.
    observation = torch.as_tensor(
        np.asarray(state, dtype=np.float32), device=device
    ).unsqueeze(0)
    with torch.no_grad():
        q_values = net(observation)
    # Return the greedy action itself, not the Q-values, so the result can be
    # handed directly to env.step().
    return int(torch.argmax(q_values, dim=1).item())
    

In [4]:
import random

class ReplayBuffer:
    """Bounded store of experiences with random sampling.

    Backed by a ``deque`` created with ``maxlen=capacity``, so appending to a
    full buffer discards the oldest entry.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Number of experiences currently held."""
        stored = self.buffer
        return len(stored)

    def append(self, experience):
        """Add ``experience`` at the newest end of the buffer."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` experiences drawn without replacement.

        Delegates to ``random.sample``, which raises ``ValueError`` when
        ``batch_size`` exceeds the number of stored experiences.
        """
        return random.sample(self.buffer, k=batch_size)

In [5]:
class RLDataset(torch.utils.data.dataset.IterableDataset):
    """Iterable dataset that yields experiences sampled from a buffer.

    Args:
        buffer: Object exposing ``sample(n)`` that returns an iterable of
            experiences.
        sample_size: Number of experiences requested from the buffer each
            time the dataset is iterated.
    """

    def __init__(self, buffer, sample_size=200):
        self.sample_size = sample_size
        self.buffer = buffer

    def __iter__(self):
        # A fresh sample is requested from the buffer on every iteration pass.
        drawn = self.buffer.sample(self.sample_size)
        yield from drawn

In [6]:
def create_env(name):
    """Build the Gym environment registered under ``name``.

    The environment is created with ``render_mode="rgb_array"``.
    """
    return gym.make(name, render_mode="rgb_array")
    

In [7]:
# import matplotlib.pyplot as plt
# env = create_env("LunarLander-v2")
# env.reset()
# frame = env.render()
# plt.imshow(frame)

In [None]:
import torch
# Scratch lookup: evaluating the bare attribute makes the notebook display the
# function object that DeepQLearning uses as its default ``loss_fn``.
torch.nn.functional.smooth_l1_loss

In [8]:
import copy
class DeepQLearning(LightningModule):
    """Deep Q-learning agent trained through the Lightning training loop.

    Each training epoch draws ``sample_per_epoch`` experiences from a replay
    buffer, fits the online Q-network on them, and then plays one episode with
    the epsilon-greedy ``policy`` to add fresh experience to the buffer.

    Args:
        env_name: Gym environment id passed to ``create_env``. The environment
            must have a discrete action space (``action_space.n`` is read).
        policy: Callable ``policy(state, env, net, epsilon)`` returning an action.
        capacity: Maximum number of experiences kept in the replay buffer.
        batch_size: Mini-batch size of the training dataloader.
        lr: Learning rate handed to ``optimizer``.
        hidden_size: Width of the hidden layers of the Q-network.
        gamma: Discount factor applied to the bootstrapped next-state value.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning the loss.
        optimizer: Optimizer class, called as ``optimizer(params, lr=lr)``.
        eps_start: Epsilon used while pre-filling the buffer and at epoch 0.
        eps_end: Lower bound of the linearly decayed epsilon.
        eps_last_episode: Divisor of the epoch index in the epsilon decay.
        sample_per_epoch: Experiences sampled from the buffer per epoch; the
            buffer is pre-filled beyond this size before training starts.
        sync_rate: The target network is synchronised every ``sync_rate`` epochs.
    """

    def __init__(self,
                 env_name,
                 policy=epsilon_policy,
                 capacity=100_000,
                 batch_size=256,
                 lr=1e-3,
                 hidden_size=128,
                 gamma=0.99,
                 loss_fn=torch.nn.functional.smooth_l1_loss,
                 optimizer=torch.optim.AdamW,
                 eps_start=1.0,
                 eps_end=0.15,
                 eps_last_episode=100,
                 sample_per_epoch=10_000,
                 sync_rate=10):
        super().__init__()
        self.env = create_env(env_name)
        observation_size = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.n
        self.q_net = dqn(hidden_size=hidden_size, observation_dim=observation_size, action_dim=action_dim)
        self.target_q_net = copy.deepcopy(self.q_net)
        self.policy = policy
        self.buffer = ReplayBuffer(capacity=capacity)

        self.save_hyperparameters()

        # Pre-fill with random play (policy=None) so the first epoch can
        # sample ``sample_per_epoch`` experiences without replacement.
        while len(self.buffer) <= self.hparams.sample_per_epoch:
            print(f"{len(self.buffer)} samples in experience buffer. Filling...")
            self.play_episode(epsilon=self.hparams.eps_start)

    @staticmethod
    def _as_env_action(action):
        """Convert a policy result to a plain int action index.

        A tensor holding several values is treated as Q-values and reduced
        with argmax over its last dimension, so a policy that returns raw
        Q-values still yields a valid discrete action.
        """
        if isinstance(action, torch.Tensor):
            if action.numel() > 1:
                action = action.argmax(dim=-1)
            action = action.item()
        return int(action)

    @torch.no_grad()
    def play_episode(self, policy=None, epsilon=0.):
        """Play one full episode and store every transition in the buffer.

        Args:
            policy: Callable ``policy(state, env, net, epsilon)``; when falsy,
                actions are sampled uniformly from the action space.
            epsilon: Exploration rate forwarded to ``policy``.

        Returns:
            The undiscounted sum of rewards collected during the episode.
        """
        # The 5-tuple step API used below pairs with reset() -> (obs, info).
        state, _ = self.env.reset()
        episode_return = 0.0
        done = False

        while not done:
            if policy:
                # epsilon is passed positionally so the call does not depend
                # on how the policy spells its fourth parameter.
                action = policy(state, self.env, self.q_net, epsilon)
            else:
                action = self.env.action_space.sample()
            action = self._as_env_action(action)
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            # Only true termination is stored as ``done``: a truncated episode
            # ends the loop but its last state is not marked terminal.
            exp = (state, action, float(reward), bool(terminated), next_state)
            self.buffer.append(exp)
            episode_return += float(reward)
            state = next_state
            done = terminated or truncated

        return episode_return

    def forward(self, x):
        """Return the online network's Q-values for ``x``."""
        return self.q_net(x)

    def configure_optimizers(self):
        """Create the optimizer for the online network only."""
        q_net_optimizer = self.hparams.optimizer(self.q_net.parameters(), lr=self.hparams.lr)
        return [q_net_optimizer]

    def train_dataloader(self):
        """Return a dataloader over a fresh sample of the replay buffer."""
        dataset = RLDataset(self.buffer, self.hparams.sample_per_epoch)
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size)
        return dataloader

    # Backward-compatible alias for the original (misspelled) method name.
    train_dataloaders = train_dataloader

    def training_step(self, batch, batch_inx):
        """Compute the one-step TD loss for a batch of transitions.

        ``batch`` is the collated tuple
        ``(states, actions, rewards, dones, next_states)``.
        """
        states, actions, rewards, dones, next_states = batch
        actions = actions.long().unsqueeze(1)
        dones = dones.bool().unsqueeze(1)

        state_action_values = self.q_net(states.float()).gather(1, actions)
        # Match the reward dtype to the network output so the target and the
        # prediction handed to the loss have the same dtype.
        rewards = rewards.unsqueeze(1).to(state_action_values.dtype)

        # The target is a constant with respect to the optimisation step.
        with torch.no_grad():
            next_action_values, _ = self.target_q_net(next_states.float()).max(dim=1, keepdim=True)
            next_action_values = next_action_values.masked_fill(dones, 0.0)
            expected_state_action_values = rewards + self.hparams.gamma * next_action_values

        loss = self.hparams.loss_fn(state_action_values, expected_state_action_values)
        self.log('episode/Q_Error', loss)
        return loss

    def training_epoch_end(self, outputs=None):
        """Play one episode with the decayed epsilon and sync the target net.

        ``outputs`` is accepted so the hook can be invoked with the collected
        step outputs; it is not used.
        """
        epsilon = max(self.hparams.eps_end,
                      self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode)
        episode_return = self.play_episode(policy=self.policy, epsilon=epsilon)
        self.log('episode/Return', episode_return)

        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())


In [9]:
# !rm -r /content/lightning_logs/
# !rm -r /content/videos/
# %load_ext tensorboard
# %tensorboard --logdir /content/lightning_logs/

In [None]:
from pytorch_lightning.callbacks import EarlyStopping

# Number of CUDA devices torch can see; with 0 the Trainer is expected to
# fall back to the CPU.
num_gpus = torch.cuda.device_count()

algo = DeepQLearning('LunarLander-v2')
trainer = Trainer(gpus=num_gpus,
                  max_epochs=10_000,
                  callbacks=[EarlyStopping(monitor='episode/Return', mode="max", patience=500)])
# fit() is an instance method: call it on the configured trainer, not the class.
trainer.fit(algo)