# **1. Setup Mario Game**

In [None]:
# Basic python imports
import os
from pathlib import Path

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

# Imports for game
import gym_super_mario_bros

# Import for joystick wrapper
from nes_py.wrappers import JoypadSpace

# Import for simplified controls
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

In [None]:
# Patch JoypadSpace.reset so that keyword arguments are forwarded to the wrapped env
def _joypad_reset(self, **kwargs):
    """Reset the wrapped environment, passing every keyword argument through."""
    return self.env.reset(**kwargs)


JoypadSpace.reset = _joypad_reset

In [None]:
# Show the simplified action set that is later handed to the JoypadSpace wrapper
print(SIMPLE_MOVEMENT)

In [None]:
# Build the game: create the base environment and wrap it in one expression
# so that only the simplified controls are exposed
env = JoypadSpace(
    gym_super_mario_bros.make('SuperMarioBros-v0'),
    SIMPLE_MOVEMENT,
)

In [None]:
# Show the shape of the observation space of the wrapped environment
print(env.observation_space.shape)

In [None]:
# Show the action space exposed by the wrapped environment
print(env.action_space)

In [None]:
# # Variable to track if the game is done
# done = True

# # Loop for 10_000 iterations
# for step in range(10_000):
#     # If the game is done
#     if done:
#         # Reset the environment
#         state = env.reset()

#     # Render the environment
#     env.render()

#     # Take a random action
#     state, reward, done, info = env.step(env.action_space.sample())

# # Close the environment
# env.close()

# **2. Preprocess Environment**

In [None]:
# Import the grayscale observation wrapper
from gym.wrappers import GrayScaleObservation

# Import the vectorization and frame-stacking wrappers
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv

# Import matplotlib for rendering
import matplotlib.pyplot as plt

In [None]:
def make_mario_env():
    """Build one preprocessed Mario environment.

    Returns:
        The 'SuperMarioBros-v0' environment restricted to SIMPLE_MOVEMENT and
        wrapped so that observations are grayscale (channel dimension kept).
    """
    # 1. Create the base environment
    mario_env = gym_super_mario_bros.make('SuperMarioBros-v0')

    # 2. Simplify the controls
    mario_env = JoypadSpace(mario_env, SIMPLE_MOVEMENT)

    # 3. GrayScale Observation Wrapper
    return GrayScaleObservation(mario_env, keep_dim=True)


# 4. Wrap the environment in a vectorized environment.
#    A named factory is passed instead of `lambda: env`: that lambda closed
#    over the very name being rebound on the same line and only worked because
#    the factory happens to be called before the assignment completes.
env = DummyVecEnv([make_mario_env])

# 5. Stack the frames
env = VecFrameStack(env, n_stack=4, channels_order="last")

In [None]:
# Reset the environment and keep the returned observation as the current state
state = env.reset()

In [None]:
# NOTE: Re-run this cell several times to watch the stacked frames change

# Advance the environment by one randomly sampled action
random_action = env.action_space.sample()
state, reward, done, info = env.step([random_action])

# Draw every frame of the stack side by side (the last axis indexes the frames)
frame_count = state.shape[3]
plt.figure(figsize=(10, 10))
for position in range(1, frame_count + 1):
    plt.subplot(1, frame_count, position)
    plt.imshow(state[0, :, :, position - 1])
    plt.axis('off')
plt.show()

# **3. Train the Reinforcement Learning Model**

In [None]:
# Import PPO algorithm for training
from stable_baselines3 import PPO

# Import BaseCallback class for saving models
from stable_baselines3.common.callbacks import BaseCallback

In [None]:
# Import pytorch
import torch

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Show which device was selected
print(device)

In [None]:
# Class for training and logging callback
class TrainAndLoggingCallback(BaseCallback):
    """Save a checkpoint of the model being trained at a fixed call interval.

    Args:
        check_freq: Number of ``_on_step`` calls between two checkpoints.
        save_path: Directory that receives the ``model_<n_calls>.zip`` files.
            When ``None``, no directory is created and no checkpoint is written.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    # Constructor
    def __init__(self, check_freq, save_path, verbose=1):
        # Call the constructor of the base class
        super().__init__(verbose)

        # Add the arguments to the class members
        self.check_freq = check_freq
        self.save_path = save_path

    # Method to initialize the callback
    def _init_callback(self) -> None:
        """Create the checkpoint directory when a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    # Method to save the model
    def _on_step(self) -> bool:
        """Write a checkpoint on every ``check_freq``-th call.

        Returns:
            Always ``True``, so that training continues.
        """
        # Skip saving when no save path was configured; this mirrors the None
        # check in _init_callback and avoids passing None to os.path.join.
        # NOTE(review): n_calls is expected to be maintained by BaseCallback.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            checkpoint_file = os.path.join(self.save_path, f"model_{self.n_calls}.zip")
            self.model.save(checkpoint_file)

        # Return True to continue training
        return True

In [None]:
# Directories for model checkpoints and for the training logs
CHECKPOINT_PATH, LOG_DIR = Path("./train/"), Path("./logs/")

In [None]:
# Create the checkpoint callback: one save every 50,000 calls into CHECKPOINT_PATH
callback = TrainAndLoggingCallback(
    save_path=CHECKPOINT_PATH,
    check_freq=50_000,
)

In [None]:
# Build the PPO agent with a CNN policy on the preprocessed environment
model = PPO(
    "CnnPolicy",
    env,
    verbose=1,
    tensorboard_log=LOG_DIR,
    learning_rate=1e-6,
    n_steps=512,
    device=device,
)

In [None]:
# # Train the model
# model.learn(total_timesteps=10_00_000, callback=callback)

# **4. Test the Model**

In [None]:
# Locate the checkpoint file inside the checkpoint directory, then load it
trained_model_file = os.path.join(CHECKPOINT_PATH, "model_1000000.zip")
model = PPO.load(trained_model_file, env=env, device=device)

In [None]:
# Reset the environment
state = env.reset()

# Game loop: plays until the cell is interrupted (kernel interrupt / Ctrl+C)
try:
    while True:
        # Render the environment
        env.render()

        # Get the action
        action, _ = model.predict(state)

        # Take the action. No manual reset is needed when an episode ends:
        # the vectorized environment resets itself and `state` is then already
        # the first observation of the next episode. Resetting again here
        # would reset the game twice per episode.
        state, _, _, _ = env.step(action)
except KeyboardInterrupt:
    # Interrupting is the intended way to stop the endless loop
    pass
finally:
    # Always release the emulator and the render window.
    # Re-create `env` before running this cell again.
    env.close()