In [None]:
import torch
from data_logger import DataLogger
from dqn_agent import Agent
from utils import *
from tqdm import tqdm
# Report whether PyTorch can see a CUDA device from this kernel.
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
# Load the configuration. get_params() yields three sections, in order:
# params["environment"], params["logging"], params["hyperparameters"].
config_sections = get_params()
ep, lp, hp = config_sections

# Hyperparameter sub-section stored under the 'dqn/ddqn' key.
hp_algo = hp['dqn/ddqn']

In [None]:
# Create the base environment from the environment parameters.
env = get_env(
    game=ep["game"],
    level=ep["level"],
    action_space=ep["action_space"],
)

# Apply the configured wrappers to the base environment.
env = apply_wrappers(
    env,
    skip=ep["skip"],
    gray_scale=ep["gray_scale"],
    shape=ep["frame_shape"],
    num_stack=ep["num_stack"],
    buffer=ep["buffer"],
)

# Reset once and show the shape of the observation the wrapped env returns.
next_state = env.reset()
print(next_state.shape)

In [None]:
# Sanity check: take one sampled action and plot the observation that comes back.
sampled_action = get_action_sample(env)
next_state, reward, done, info = env.step(sampled_action)
plot_sequence_observations(next_state)

In [None]:
# Build the agent from the 'dqn/ddqn' hyperparameters. The observation size is
# (num_stack, *frame_shape), taken from the environment parameters.
agent = Agent(
    env,
    batch_size=hp_algo["batch_size"],
    buffer_size=hp_algo["buffer_size"],
    learning_rate=hp_algo["learning_rate"],
    observation_size=(ep["num_stack"],) + tuple(ep["frame_shape"]),
    action_size=get_action_space_size(env),
    discount_factor=hp_algo["discount_factor"],
    model=hp_algo["model"],
    epsilon_min=hp_algo["exploration_min"],
    epsilon_decay=hp_algo["exploration_decay"],
    update_online_every=hp_algo["train_freq"],
    update_target_from_online_every=hp_algo["target_update_interval"],
    start_learning_after=hp_algo["learning_starts"],
)
logger = DataLogger(env, hp, agent, model=hp_algo["model"])

# try/finally guarantees logger.close() runs even when training is interrupted
# from the notebook (KeyboardInterrupt) or fails part-way through.
try:
    for episode in tqdm(range(hp_algo["num_episodes"])):
        state = env.reset()

        while True:
            # Choose an action index and translate it into an environment action.
            action_idx = agent.act(state)
            next_state, reward, done, info = env.step(get_action(action_idx, env))

            # Store the transition, then run one learning step.
            agent.append_experience_to_memory(state, next_state, action_idx, reward, done)
            loss = agent.learn()

            state = next_state

            env.render()

            # Log BEFORE the termination check so the final transition of each
            # episode (its reward and loss) is recorded as well.
            # NOTE(review): assumes DataLogger should see every step — confirm.
            logger.time_step(reward, loss, n_time_steps_save_model=lp["n_time_steps_save_model"])

            # The episode ends when the env reports done or the flag is reached.
            if done or info["flag_get"]:
                break

        logger.episode_step(info)
finally:
    logger.close()

In [None]:
# Close the training environment; the evaluation cells below build a fresh one.
env.close()

In [None]:
#Testing trained model
env = get_env(game=ep["game"], level=ep["level"], action_space=ep["action_space"])
env = apply_wrappers(env, skip=ep["skip"], gray_scale=ep["gray_scale"], shape=ep["frame_shape"], num_stack=ep["num_stack"], buffer=ep["buffer"])

agent = Agent(env, batch_size=hp_algo["batch_size"], buffer_size=hp_algo["buffer_size"], learning_rate=hp_algo["learning_rate"], observation_size=(ep["num_stack"],)+tuple(ep["frame_shape"]), action_size=get_action_space_size(env), 
              discount_factor=hp_algo["discount_factor"], model=hp_algo["model"], epsilon_min=hp_algo["exploration_min"], epsilon_decay=hp_algo["exploration_decay"], update_online_every=hp_algo["train_freq"],
              update_target_from_online_every=hp_algo["target_update_interval"], start_learning_after=hp_algo["learning_starts"])
agent.load(r"logs/SuperMarioBros-1-1-v0/DDQN/20240123234603/checkpoints/best_model.pth")

In [None]:
# Cap on steps per evaluation episode so an episode that never terminates
# cannot run forever.
MAX_STEPS_PER_EPISODE = 10_000

try:
    # Replay episodes until the kernel is interrupted.
    while True:
        state = env.reset()
        for _ in range(MAX_STEPS_PER_EPISODE):
            action_idx = agent.act(state)

            next_state, reward, done, info = env.step(get_action(action_idx, env))

            state = next_state

            env.render()

            # The episode ends when the env reports done or the flag is reached.
            if done or info["flag_get"]:
                break
except KeyboardInterrupt:
    # Interrupting the kernel is the expected way to stop playback. Any other
    # exception is a real error and is allowed to propagate instead of being
    # silently swallowed.
    pass
finally:
    # Always release the environment, whether we stopped cleanly or failed.
    env.close()