# Imports

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from nes_py.wrappers import JoypadSpace
import gym
import gym_super_mario_bros
from gym.wrappers import FrameStack, GrayScaleObservation, ResizeObservation, TransformObservation, Monitor
import numpy as np
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project-local helpers (plotting and the epsilon-greedy Mario agent)
from src.helper_functions.create_Plot import plot_results
from src.helper_functions.create_Agent import MarioAgentEpsilonGreedy

# Set Hyperparameters

In [None]:
# Button combinations available to the agent. Each "+"-separated string is
# split into the list of NES buttons that JoypadSpace presses together; the
# position in this list is the discrete action number.
action_space = [
    combo.split('+')
    for combo in (
        'NOOP',
        'A',
        'B',
        'right',
        'left',
        'right+A',
        'right+B',
        'right+A+B',
    )
]

# --- Replay memory and optimisation ---
buffer_size = 25_000
batch_size = 64
learning_rate = 9e-5

# --- Observation preprocessing ---
stacking_number = 10  # number of consecutive frames stacked into one state
# Frame skipping is not implemented yet (planned value: skipping_number = 4).

# --- Network update schedule ---
online_update_every = 3
exp_before_target_sync = 5_000

# --- Exploration (epsilon-greedy) and discounting ---
epsilon_start = 1.0
epsilon_min = 1e-2
epsilon_decay = 1e-3
gamma = 0.99
num_episodes = 1_000

# --- Reporting cadence (in episodes) ---
plot_every = 25
save_every = 50

# Initialize the environment, output folders, agent and metric lists

In [None]:
# Output locations, relative to the notebook's working directory.
root_folder = os.path.join("..")  # NOTE(review): not referenced in the visible cells
vid_folder = os.path.join("res", "training_v2", "all_videos")
# Warm-up size handed to the agent: one batch plus a small margin.
exp_before_training = batch_size + 5

# Build the environment. Each wrapper transforms the output of the one below it.
env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
env = JoypadSpace(env, action_space)
# Record every episode. Placed before the observation wrappers, so the videos
# show the unprocessed frames. force=True lets Monitor reuse an existing folder.
env = Monitor(env, vid_folder, video_callable=lambda episode_id: True, force=True)
# Frame skipping is not implemented yet: env = SkipFrame(env, skip=skipping_number)
env = GrayScaleObservation(env, keep_dim=False)
env = ResizeObservation(env, shape=84)
# Drop the trailing size-1 axis the resize step leaves on the frame
# (np.squeeze(axis=-1) would raise if that axis were not of size 1).
env = TransformObservation(env, lambda obs: np.squeeze(obs, axis=-1))
# Scale pixel values from [0, 255] to [0, 1]. NOTE(review): with uint8 input
# this yields float64 frames, which doubles/octuples replay-buffer memory
# depending on how the agent stores states — consider float32 if memory is tight.
env = TransformObservation(env, f=lambda x: x / 255.)
env = FrameStack(env, num_stack=stacking_number)

# One reset to discover the shape of a stacked observation for the network.
state = env.reset()
state_shape = state.shape

# Folders for weights, checkpoints and plots (created if missing).
model_folder = os.path.join("models")
os.makedirs(model_folder, exist_ok=True)
checkpoint_folder = os.path.join(model_folder, "training_v2", "checkpoints")
os.makedirs(checkpoint_folder, exist_ok=True)
# Set to a checkpoint path, e.g. os.path.join(checkpoint_folder, "model_ep850.pth"),
# to hand it to the agent as a starting point; None passes no checkpoint.
starting_point = None
plot_folder = os.path.join("res", "training_v2", "plots")
os.makedirs(plot_folder, exist_ok=True)

mario = MarioAgentEpsilonGreedy(
    num_actions=len(action_space),
    state_shape=state_shape,
    checkpoint_folder=checkpoint_folder,
    model_folder=model_folder,
    wantcuda=True,
    starting_point=starting_point,
    learning_rate=learning_rate,
    epsilon_start=epsilon_start,
    epsilon_min=epsilon_min,
    epsilon_decay=epsilon_decay,
    # Use the configured batch size. A hard-coded value here would silently
    # override the hyperparameter cell and disagree with exp_before_training.
    batch_size=batch_size,
    gamma=gamma,
    buffer_size=buffer_size,
    exp_before_training=exp_before_training,
    online_update_every=online_update_every,
    exp_before_target_sync=exp_before_target_sync,
    save_every=save_every,
)

# Per-episode metrics collected by the training loop and passed to plot_results.
reward_list = []
steps_list = []
q_list = []
loss_list = []
epsilon_list = []

# Start the training

In [None]:
def _mean_or_nan(values):
    """Return the mean of the non-None entries of *values*, or NaN if none remain.

    NOTE(review): guards against mario.learn_get_TDest_loss() returning None for
    q/loss on steps where no learning update happens (e.g. during warm-up);
    np.mean over a list containing None raises TypeError. Confirm against
    src.helper_functions.create_Agent. For all-numeric input this equals np.mean.
    """
    present = [v for v in values if v is not None]
    return np.mean(present) if present else np.nan


def _save_checkpoint(path):
    """Save the agent's model weights, optimizer state and epsilon to *path*."""
    torch.save(
        dict(
            model=mario.model.state_dict(),
            optimizer=mario.optimizer.state_dict(),
            epsilon=mario.epsilon,
        ),
        path,
    )


for episode in range(1, num_episodes + 1):
    state = env.reset()
    total_reward = 0
    steps = 0
    mean_episode_q = []     # per-step q values returned by the agent this episode
    mean_episode_loss = []  # per-step losses returned by the agent this episode
    while True:
        # env.render()  # uncomment to watch the agent play
        action = mario.selectAction(state)
        next_state, reward, resetnow, info = env.step(action)
        mario.saveExp(state, action, next_state, reward, resetnow)
        q, loss = mario.learn_get_TDest_loss()
        state = next_state
        total_reward += reward
        steps += 1
        mean_episode_q.append(q)
        mean_episode_loss.append(loss)
        # End the episode when the env reports done or Mario reaches the flag.
        if resetnow or info['flag_get']:
            break
    print(f"Episode {episode} abgeschlossen mit {steps} Schritten, Gesamtbelohnung: {total_reward}, Epsilon: {mario.epsilon}\n\n")

    reward_list.append(total_reward)
    steps_list.append(steps)
    q_list.append(_mean_or_nan(mean_episode_q))
    loss_list.append(_mean_or_nan(mean_episode_loss))
    epsilon_list.append(mario.epsilon)  # epsilon used during this episode (before decay)

    if episode % plot_every == 0:
        plot_results(reward_list, steps_list, q_list, loss_list, epsilon_list, os.path.join(plot_folder, f"plot_{episode}.png"))

    # NOTE(review): the agent also receives save_every/checkpoint_folder; confirm
    # it does not write its own checkpoints too, or this duplicates them.
    if episode % save_every == 0:
        _save_checkpoint(os.path.join(checkpoint_folder, f"model_ep{episode}.pth"))

    # Decay epsilon once per episode; "lin" presumably selects a linear schedule.
    mario.decayEpsilon(strat="lin")

_save_checkpoint(os.path.join(model_folder, "final_model.pth"))

env.close()