<a href="https://colab.research.google.com/github/threewisemonkeys-as/genrl/blob/master/examples/DQN_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Examples with DQN from GenRL

## Setup

In [4]:
!git clone https://github.com/SforAiDl/genrl
!pip install -e genrl

fatal: destination path 'genrl' already exists and is not an empty directory.
Obtaining file:///content/genrl
Installing collected packages: genrl
  Found existing installation: genrl 0.0.1
    Can't uninstall 'genrl'. No files were found to uninstall.
  Running setup.py develop for genrl
Successfully installed genrl


In [None]:
import torch

from genrl.agents import DQN
from genrl.agents.deep.dqn.utils import ddqn_q_target, prioritized_q_loss
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer, OnPolicyTrainer

## Training Vanilla DQN on CartPole 

In [None]:
# Baseline: vanilla DQN with an MLP Q-network on CartPole-v0.
env = VectorEnv("CartPole-v0")
agent = DQN("mlp", env)

# Off-policy training for 20k environment steps, followed by an evaluation run.
trainer = OffPolicyTrainer(
    agent,
    env,
    max_timesteps=20_000,
)
trainer.train()
trainer.evaluate()

timestep         Episode          value_loss       epsilon          Episode Reward   
44               0.0              0                0.9785           0                
240              10.0             0                0.8695           20.1             
450              20.0             0                0.7117           23.0             
734              30.0             0                0.559            28.2             
930              40.0             0                0.4411           19.8             
1120             50.0             0.3117           0.3654           19.4             
1226             60.0             0.5066           0.3162           10.6             
1466             70.0             0.9103           0.268            14.4             
3290             80.0             2.2416           0.115            156.9            
5290             90.0             9.0437           0.0259           200.0            
7288             100.0            21.788           0.0

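## Extending DQN to Double DQN

Double DQN decouples action selection from action evaluation when building the bootstrap target. The online network picks the greedy next action, and the target network scores that action. This reduces the overestimation bias of the vanilla max-based target.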

In [None]:
class DoubleDQN(DQN):
    """DQN variant that computes Double DQN bootstrap targets.

    The online network (``self.model``) picks the greedy next action and the
    target network (``self.target_model``) evaluates that action.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): DQN.__init__ may already build the networks; confirm
        # whether this second call is required for this subclass.
        self._create_model()

    def get_target_q_values(self, next_states, rewards, dones):
        """Return ``rewards + gamma * Q_target(s', argmax_a Q(s', a)) * (1 - dones)``.

        Args:
            next_states: next observations, fed to both the online and target networks.
            rewards: per-transition rewards; unsqueezed on the last axis to line
                up with the gathered Q-values.
            dones: per-transition terminal flags; ``1 - dones`` zeroes the
                bootstrap term for terminal transitions.

        Returns:
            Tensor of target Q-values with a trailing singleton axis (the shape
            of the ``gather`` index).
        """
        # Action selection with the online network.
        next_q_value_dist = self.model(next_states)
        next_best_actions = torch.argmax(next_q_value_dist, dim=-1).unsqueeze(-1)
        rewards, dones = rewards.unsqueeze(-1), dones.unsqueeze(-1)
        # Action evaluation with the target network. gather(2, ...) assumes the
        # Q-values are 3-D with actions on axis 2 (presumably batch x envs x
        # actions) — TODO confirm against DQN.get_q_values.
        next_q_target_value_dist = self.target_model(next_states)
        max_next_q_target_values = next_q_target_value_dist.gather(2, next_best_actions)
        # Use this agent's own discount factor rather than a notebook global.
        target_q_values = rewards + self.gamma * torch.mul(
            max_next_q_target_values, (1 - dones)
        )
        return target_q_values

In [None]:
# Same setup as the baseline, swapping in the Double DQN agent defined above.
env = VectorEnv("CartPole-v0")
agent = DoubleDQN("mlp", env)

# Off-policy training for 20k environment steps, followed by an evaluation run.
trainer = OffPolicyTrainer(
    agent,
    env,
    max_timesteps=20_000,
)
trainer.train()
trainer.evaluate()

timestep         Episode          value_loss       epsilon          Episode Reward   
26               0.0              0                0.9872           0                
238              10.0             0                0.8783           19.6             
404              20.0             0                0.7283           19.1             
644              30.0             0                0.597            20.1             
842              40.0             0                0.4812           23.1             
1054             50.0             0.3035           0.394            16.5             
1158             60.0             0.4897           0.3374           16.4             
1288             70.0             0.7634           0.3013           12.8             
2686             80.0             2.1095           0.1569           101.9            
4610             90.0             9.0553           0.0399           197.3            
6372             100.0            19.4667          0.0

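## Extending DQN to Dueling DQN

This variant combines the dueling network architecture with Double DQN targets. It is trained below with a prioritized replay buffer (`buffer_type="prioritized"`).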

In [None]:
class DuelingDQN(DQN):
    """DQN with ``dqn_type = "dueling"`` and Double DQN targets.

    The prioritized-replay loss is used only when the agent is constructed
    with ``buffer_type="prioritized"``; any other buffer type uses the base
    DQN loss.
    """

    def __init__(self, *args, buffer_type="push", **kwargs):
        super().__init__(*args, buffer_type=buffer_type, **kwargs)
        # The prioritized loss is only meant for a prioritized replay buffer.
        self._use_prioritized_loss = buffer_type == "prioritized"
        self.dqn_type = "dueling"  # can be "noisy" for NoisyDQN
        # Rebuild the networks now that dqn_type is set (presumably read by
        # _create_model to choose the architecture — confirm in DQN).
        self._create_model()

    def get_target_q_values(self, *args):
        """Compute Double DQN targets via genrl's ``ddqn_q_target`` helper."""
        return ddqn_q_target(self, *args)

    def get_q_loss(self, *args):
        """Return the prioritized loss for a prioritized buffer, else the base DQN loss.

        Previously the prioritized loss was applied unconditionally, including
        with the default ``buffer_type="push"``.
        """
        if self._use_prioritized_loss:
            return prioritized_q_loss(self, *args)
        return super().get_q_loss(*args)

In [None]:
# Dueling DQN trained from a prioritized replay buffer.
env = VectorEnv("CartPole-v0")
agent = DuelingDQN(
    "mlp",
    env,
    buffer_type="prioritized",
)

# Off-policy training for 20k environment steps, followed by an evaluation run.
trainer = OffPolicyTrainer(
    agent,
    env,
    max_timesteps=20_000,
)
trainer.train()
trainer.evaluate()

timestep         Episode          value_loss       epsilon          Episode Reward   
24               0.0              0                0.9882           0                
182              10.0             0                0.9031           16.4             
392              20.0             0                0.7536           20.8             
568              30.0             0                0.6228           17.7             
764              40.0             0                0.5189           18.8             
1026             50.0             0.6067           0.4153           26.6             
1172             60.0             0.4115           0.3398           14.5             
1282             70.0             0.3327           0.3001           11.8             
1524             80.0             0.3231           0.2538           18.3             
2684             90.0             0.2799           0.1375           84.5             
3830             100.0            0.5504           0.0