# Setup

In [None]:
# Imports: stdlib first, then third-party; no statements interleaved with imports.
import random

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from matplotlib import pyplot as plt
from tqdm import tqdm

# Environment under study. The agent cell later hard-codes state_size=2 / action_size=3,
# which must match this environment's observation and action spaces.
env = gym.make('MountainCar-v0')

In [None]:
from dqnv2 import DQNAgentV2

In [None]:
# Hyperparameters for the DQNv2 MountainCar run.

# --- Training loop ---
n_episodes = 3_000
batch_size = 128
logging_interval = 1  # NOTE(review): not referenced by any other visible cell

# --- Optimiser ---
learning_rate = 1e-3
weight_decay = 1e-4
amsgrad = True

# --- Network ---
hidden_size = 128
dropout_rate = 0.0

# --- Bellman target ---
discount_factor = 0.99
target_network = True
target_network_update = 10_000  # same value as int(1e4), written as an int literal

# --- Epsilon-greedy exploration ---
start_epsilon = 0.9
final_epsilon = 0.05
# Factor used by agent.decay_epsilon(), which the training loop calls once per logged
# episode, to reduce the exploration over time.
epsilon_decay = 0.95

# --- Replay buffer ---
replay_size = 10_000

# --- Agent-specific ---
alpha = 1.5  # semantics defined in dqnv2.DQNAgentV2 — TODO document

# --- Auxiliary reward predictor ---
reward_hidden_size = 64
reward_factor = 1
predictor_learning_rate = 1e-3

# --- Reproducibility: seed every RNG in use ---
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


In [None]:
agent = DQNAgentV2(
    # Problem dimensions — must match MountainCar-v0's observation/action sizes.
    state_size=2,
    action_size=3,
    # Network architecture and optimiser settings.
    hidden_size=hidden_size,
    dropout_rate=dropout_rate,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    amsgrad=amsgrad,
    # Bellman target / target-network settings.
    discount_factor=discount_factor,
    target_network=target_network,
    target_network_update=target_network_update,
    # Epsilon-greedy schedule (the config cell names the start value `start_epsilon`).
    initial_epsilon=start_epsilon,
    final_epsilon=final_epsilon,
    epsilon_decay=epsilon_decay,
    # Replay capacity.
    replay_size=replay_size,
    # Agent-specific knob; meaning defined in dqnv2 — TODO confirm.
    alpha=alpha,
    # Presumably the auxiliary (RND-style) reward model — confirm in dqnv2.
    reward_hidden_size=reward_hidden_size,
    reward_factor=reward_factor,
    predictor_learning_rate=predictor_learning_rate,
    # NOTE(review): literal, not a config-cell constant, and not logged to W&B.
    running_window=10,
)

In [None]:
# Start the W&B run. Every hyperparameter that shapes the run is logged so the run
# can be reproduced from its recorded config alone.
# NOTE(review): running_window=10 is passed as a literal in the agent cell and is not logged.
run = wandb.init(
    project='ANN',
    config={
        "learning_rate": learning_rate,
        "n_episodes": n_episodes,
        "start_epsilon": start_epsilon,
        "final_epsilon": final_epsilon,
        "epsilon_decay": epsilon_decay,
        "batch_size": batch_size,
        "discount_factor": discount_factor,
        "replay_size": replay_size,
        "hidden_size": hidden_size,
        "dropout_rate": dropout_rate,
        "weight_decay": weight_decay,
        "target_network": target_network,
        "alpha": alpha,
        "target_network_update": target_network_update,
        "reward_factor": reward_factor,
        "reward_hidden_size": reward_hidden_size,
        "amsgrad": amsgrad,
        "predictor_learning_rate": predictor_learning_rate,
        "seed": seed,
    },
    name='DQNv2',
)


In [None]:
# Record per-episode statistics. Only wrap once, so re-running this cell in the same
# kernel does not stack a second RecordEpisodeStatistics wrapper on top of the first.
if not isinstance(env, gym.wrappers.RecordEpisodeStatistics):
    env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)

# Seed the action space in case the agent samples exploratory actions from it
# (it receives `env` in get_action — assumption, confirm in dqnv2). The env's own
# RNG is seeded on the first reset below.
env.action_space.seed(seed)

# NOTE(review): `agent` keeps its state across re-runs of this cell; re-run the agent
# cell first for a fresh training run.
with tqdm(total=n_episodes, desc=f"Episode 0/{n_episodes}") as pbar:
    target_count = 0
    finished = 0
    # True until the agent has returned a loss at the end of some episode; logging and
    # epsilon decay start only after that.
    empty = True
    cumulative_auxiliary_reward = 0
    cumulative_env_reward = 0
    # Plain range: `pbar` is the only progress bar (wrapping this in tqdm drew a second one).
    for episode in range(n_episodes):
        # Seed only the first reset; later resets continue the env's seeded RNG stream.
        obs, info = env.reset(seed=seed if episode == 0 else None)
        done = False
        t = 0
        episode_auxiliary_reward = 0
        episode_env_reward = 0
        episode_loss = 0

        # Play one episode.
        while not done:
            action = agent.get_action(obs, env)
            next_obs, env_reward, terminated, truncated, info = env.step(action)

            # The episode ends on either termination or truncation...
            done = terminated or truncated

            # ...but only true termination is passed as terminal to the update.
            loss, target_count, RND, intrinsic_loss = agent.update(
                obs, action, env_reward, next_obs,
                batch_size=batch_size, target_count=target_count, terminal=terminated,
            )

            # The environment reward is earned whether or not a training update happened.
            episode_env_reward += env_reward
            # Loss and auxiliary reward are only accumulated when the agent reports a loss.
            if loss is not None:
                episode_auxiliary_reward += RND
                episode_loss += loss
            obs = next_obs
            t += 1

        pbar.set_description(f"Episode {episode + 1}/{n_episodes}")
        # set_postfix redraws the bar itself (refresh=True by default).
        pbar.set_postfix(train_loss=episode_loss, epsilon=agent.epsilon, target_count=target_count, episode_steps=t, episode_auxiliary_reward=episode_auxiliary_reward, episode_env_reward=episode_env_reward, intrinsic_loss=intrinsic_loss, finished=finished)
        pbar.update(1)

        if not empty:
            # `terminated` is from the last step: the goal was reached rather than truncated.
            finished += int(terminated)
            cumulative_auxiliary_reward += episode_auxiliary_reward
            cumulative_env_reward += episode_env_reward

            agent.decay_epsilon()
            wandb.log({"train_loss": episode_loss, "epsilon": agent.epsilon, "episode_steps": t, "finished": finished, "episode_env_reward":episode_env_reward, "episode_aux_reward":episode_auxiliary_reward, "cumulative_env_reward":cumulative_env_reward, "cumulative_aux_reward":cumulative_auxiliary_reward})

        if loss is not None:
            empty = False

In [None]:
# Close the run started by `wandb.init` and flush any pending logs.
run.finish()