In [1]:
%%bash

python3 -m pip install 'gym-super-mario-bros==7.4.0'

Defaulting to user installation because normal site-packages is not writeable


In [2]:
# Generic Non-Torch Imports
import random, time, datetime, copy, functools
from pathlib import Path
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import seaborn
from tqdm.notebook import tqdm

# Torch Imports
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms as T
from tensordict import TensorDict

# An annoying problem with this library is that it appears to not work with older versions of PyTorch
# This surprisingly includes 2.0.1+cu117, which it is SUPPOSED to support.
# Note to self: This is what venv is made for, future TODO: create venv.
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

# OpenAI Gymnasium Toolkit + NesPy (which we will run Mario on)
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

# Reproducibility Measures
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [3]:
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode = "rgb", apply_api_compatibility = True)

# Define keystrokes so life is easier.
move_right = "right"
jump_key = "A"

env = JoypadSpace(env, [[move_right], [move_right, jump_key]])

# Test the environment by inputting an action.
env.reset()
next_state, reward, done, trunc, info = env.step(action = 0) # Ask Mario to move right.
print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

(240, 256, 3),
 0.0,
 False,
 {'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 400, 'world': 1, 'x_pos': 40, 'y_pos': 79}


  logger.warn(
  logger.warn(
  if not isinstance(terminated, (bool, np.bool8)):


In [4]:
# According to the tutorial, color information in this context is not necessary.
# So `next_state` is probably visual information from the screen.
# Of course, `info` is a (rather reductive) summary about the game's current state. 
# The following is our preprocessing regime:
# 1) Reduce the state info down to grayscale.
# 2) Downsample the frames into a square.
# 3) In the tutorial, SkipFrame allows us to skip intermediate frames that may not carry enough useful
# information to matter. Probably improves processing time and allows for greater
# room for error in latency etc.
# 4) Again in the tutorial, FrameStack groups frames together.
# According to the tutorial, these preprocessing steps should wrap around the environment.
# So we will copy the tutorial.

class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward"""
        total_reward = 0.0
        for i in range(self._skip):
            # Accumulate reward and repeat the same action
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, trunk, info

class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low = 0, high = 255, shape = obs_shape, dtype = np.uint8)

    def permute_orientation(self, observation):
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype = torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation


class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low = 0, high = 255, shape = obs_shape, dtype = np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape), T.Normalize(127, 128)] # Restrict the data to be within [-1, 1] and not [0, 1]
        )
        observation = transforms(observation).squeeze(0)
        return observation


# Apply Wrappers to environment
def function_pipeline(*fs):
    return functools.reduce(lambda f,g: (lambda x: g(f(x))), fs)


In [5]:
# Now that our Environment is set up, we should probably get Mario up and running.
# Mario is an Agent. Therefore, he should be able to:
# 1) LEARN a good policy.
# 2) REMEMBER prior feedback (in the form of state-action-reward-new_state tuples).
# 3) ACT according to his learned policy based on his current Environment. That is, Mario should either:
#   a) EXPLORE a slightly modified policy.
#   b) EXPLOIT the current best policy.

# Because it seems more logical to me, we will first get to implementing Mario's Policy.
# That's his brain, and is crucial to his capabilities.

# The tutorial suggests that the policy be learned through a DDQN structure (Double Deep Q-Networks).
# The original paper for DDQN's only innovation appears to be the addition of a target DQN.

class MarioPolicy(nn.Module):

    def weight_init(self, layer):
        # Kaiming initialization for faster training.
        if hasattr(layer, "weight") and layer.weight.dim() > 1:
            nn.init.kaiming_uniform_(layer.weight.data, a = 1e-2, nonlinearity = 'leaky_relu')
    
    def __init__(self, input_dim, output_dim):
        super().__init__()
        
        c, h, w = input_dim
        # We deviate slightly from the tutorial to introduce some improvements:
        # 1) We initialize the weights using Kaiming initialization.
        # 2) We apply dropout between the last 2 convolution layers (to discriminate the most important features)
        # 3) We apply dropout before the final FFN (to regularize the model)
        # 4) (Entirely personal preference) We introduce LeakyReLU as a precautionary measure against dying neurons.
        self.conv_net = nn.Sequential(
                    nn.Conv2d(in_channels = c, out_channels = 32, kernel_size = 8, stride = 4),
                    nn.LeakyReLU(),
                    nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 4, stride = 2),
                    nn.Dropout(0.2),
                    nn.LeakyReLU(),
                    nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1),
                    nn.LeakyReLU(),
                    nn.Flatten()
        )
        
        self.conv_output_shape = self.conv_net(torch.zeros(1, c, h, w, requires_grad = False)).shape

        self.online = nn.Sequential(
                    self.conv_net,
                    nn.Linear(self.conv_output_shape[-1], 512),
                    # nn.Dropout(0.2),
                    nn.LeakyReLU(),
                    nn.Linear(512, output_dim)
                )
        
        print(self.conv_output_shape)

        self.online.apply(self.weight_init)

        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)

In [6]:
# The rest of Mario's body.

class Mario:

    def first_if_tuple(self, x):
        return x[0] if isinstance(x, tuple) else x

    def __init__(self, state_dim, action_dim, save_dir, lr = 2.5e-4,
                 batch_size = 128, train_steps_per_loop = 4, explore_rate = 1, explore_gamma = 0.99999975, 
                 min_explore_rate = 0.1, save_interval = 2e5, gamma = 0.9, burnin = 1e5, learn_every = 3, 
                 sync_every = 1e5, weight_decay = 0):
        
        # Mario's hyperparameters
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir
        self.batch_size = batch_size * train_steps_per_loop
        self.train_steps_per_loop = train_steps_per_loop
        
        self.lr = lr
        self.explore_rate = explore_rate
        self.explore_gamma = explore_gamma
        self.min_explore_rate = min_explore_rate
        self.save_interval = save_interval
        self.gamma = gamma
        self.burnin = burnin
        self.learn_interval = learn_every
        self.sync_interval = sync_every

        self.device = device

        self.gradient_clip = 10
    
        # This is the Policy that Mario will follow.
        # To put it mathematically, it is a function which maps states to action logits.
        self.policy_net = MarioPolicy(self.state_dim, self.action_dim).to(device = self.device)
        self.optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr = lr, weight_decay = weight_decay)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma = explore_gamma)
        self.loss_fn = torch.nn.HuberLoss()

        # This is Mario's memory.
        self.memory = TensorDictReplayBuffer(storage = LazyMemmapStorage(5e5, device = self.device), 
                                             pin_memory = True, batch_size = self.batch_size, prefetch = 32)

        self.step_ct = 0

    def act(self, state, eval = False):
        
        # Determine whether or not we should explore.
        if np.random.rand() < self.explore_rate and not eval:
            action_idx = np.random.randint(self.action_dim)
        else:
            # This implies LazyFrames are not in fact Tensors. Concerning.
            state = self.first_if_tuple(state).__array__()
            state = torch.tensor(state, device = self.device).unsqueeze(0)
            
            logits = self.policy_net(state, model = "online")
            # TODO: Consider sampling according to his policy.
            # Although, that is probably what his EXPLORE stage should be doing...
            action_idx = torch.argmax(logits, dim = -1).item()
        
        # Decay explore rate since Mario should have learnt a bit more by now.
        # gamma = (self.explore_gamma ** (self.explore_rate > self.min_explore_rate))
        # self.explore_rate = self.explore_rate * gamma

        self.explore_rate *= self.explore_gamma
        self.explore_rate = max(self.min_explore_rate, self.explore_rate)

        # Increment step
        self.step_ct += 1
        return action_idx
    
    def cache(self, state, next_state, action, reward, done):
        # Store feedback to Mario's memory.

        state = self.first_if_tuple(state).__array__()
        next_state = self.first_if_tuple(next_state).__array__()

        state = torch.tensor(state, device = self.device)
        next_state = torch.tensor(next_state, device = self.device)
        action = torch.tensor([action], device = self.device)
        reward = torch.tensor([reward], device = self.device)
        done = torch.tensor([done], device = self.device)

        self.memory.add(
            TensorDict({
                "state": state,
                "next_state": next_state,
                "action": action,
                "reward": reward,
                "done": done,                
            }, batch_size = [])
        )

    def recall(self):
        # Recalls feedback from Mario's memory and batches it.
        # Kind of like a batch collator.

        batch = self.memory.sample()
        state, next_state, action, reward, done = (batch.get(key).contiguous() for key in ("state", "next_state", "action", "reward", "done"))
        return state, next_state, action, reward, done
    
    # Q Learning is in fact a special kind of TD learning.
    def td_estimate(self, state, action):
        current_Q = torch.gather(self.policy_net(state, model = "online"), dim = -1, index = action)
        return current_Q

    @torch.no_grad()    
    def td_target(self, reward, next_state, done):
        next_state_Q = self.policy_net(next_state, model = "online")
        best_action = torch.argmax(next_state_Q, dim = -1).unsqueeze(-1)

        next_Q = torch.gather(self.policy_net(next_state, model = "target"), dim = -1, index = best_action)
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()
    
    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        loss.backward()

        # Clip gradients for more stable training
        grad_norm = torch.nn.utils.clip_grad_norm_(self.policy_net.online.parameters(), self.gradient_clip)
        self.optimizer.step()
        self.scheduler.step()
        return loss, grad_norm

    def sync_Q_target(self):
        self.policy_net.target.load_state_dict(self.policy_net.online.state_dict())
    
    def save(self):
        save_path = Path(self.save_dir, f"mario_net_{int(self.step_ct // self.save_interval)}.ckpt")
        torch.save(
            {
                "model": self.policy_net.state_dict(),
                "explore_rate": self.explore_rate
            }, 
            save_path
        )
        print(f"Mario's brain saved to {save_path} at step {self.step_ct}")

    def load(self, load_path: Path):

        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")
        
        ckp = torch.load(load_path, map_location = self.device)
        explore_rate = ckp.get('explore_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {explore_rate}")
        self.policy_net.load_state_dict(state_dict)
        self.explore_rate = explore_rate
        
    def step(self, eval = False):
        if not eval:

            if self.step_ct % self.sync_interval == 0:
                self.sync_Q_target()

            if self.step_ct % self.save_interval == 0:
                self.save()

            if self.step_ct < self.burnin:
                return 0, 0, 0

            if self.step_ct % self.learn_interval != 0:
                return 0, 0, 0

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done)

        self.optimizer.zero_grad()
        # Backpropagate loss through Q_online
        loss, grad_norm = self.update_Q_online(td_est, td_tgt)

        return (loss.cpu().detach(), td_est.mean().item(), grad_norm.cpu().detach())


In [7]:
# Logger class provided by the tutorial.
class MetricLogger:
    def __init__(self, save_dir):
        self.save_log = Path(save_dir, "log")
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = Path(save_dir, "reward_plot.jpg")
        self.ep_lengths_plot = Path(save_dir, "length_plot.jpg")
        self.ep_avg_losses_plot = Path(save_dir, "loss_plot.jpg")
        self.ep_avg_qs_plot = Path(save_dir, "q_plot.jpg")

        # History metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []
        self.ep_avg_grad_norm = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []
        self.moving_avg_ep_grad_norm = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

        # Tensorboard writer
        self.save_dir = str(save_dir)
        self.writer = SummaryWriter(log_dir = f"{save_dir}/tb")

    def log_step(self, reward, loss, q, grad_norm):
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1
            self.curr_ep_grad_norm += grad_norm

    def log_episode(self):
        "Mark end of episode"
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
            ep_avg_grad_norm = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
            ep_avg_grad_norm = np.round(self.curr_ep_grad_norm / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.ep_avg_grad_norm.append(ep_avg_grad_norm)

        self.init_episode()

    def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0
        self.curr_ep_grad_norm = 0.0

    def record(self, episode, epsilon, step, lr):
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        mean_ep_grad_norm = np.round(np.mean(self.ep_avg_grad_norm[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)
        self.moving_avg_ep_grad_norm.append(mean_ep_grad_norm)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        self.writer.add_scalars('moving_avg_reward', {self.save_dir: mean_ep_reward}, global_step = episode)
        self.writer.add_scalars('moving_avg_length', {self.save_dir: mean_ep_length}, global_step = episode)
        self.writer.add_scalars('moving_avg_loss', {self.save_dir: mean_ep_loss}, global_step = episode)
        self.writer.add_scalars('moving_avg_q', {self.save_dir: mean_ep_q}, global_step = episode)
        self.writer.add_scalars('moving_avg_grad_norm', {self.save_dir: mean_ep_grad_norm}, global_step = episode)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - \n"
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Mean Grad Norm {mean_ep_grad_norm}\n"
            f"Learn Rate {lr} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}\n\n"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_lengths", "ep_avg_losses", "ep_avg_qs", "ep_rewards","ep_grad_norm"]:
            
            plt.clf()
            plt.plot(getattr(self, f"moving_avg_{metric}"), label = f"moving_avg_{metric}")
            plt.legend()

            self.writer.add_figure(metric, plt.gcf())

In [8]:
# Finally we can train Mario!

save_dir = Path("checkpoints", "mario_specimen")
if not save_dir.exists():
    save_dir.mkdir(parents = True)

del env
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode = "rgb", apply_api_compatibility = True)

env = JoypadSpace(env, RIGHT_ONLY)

env = function_pipeline(
                lambda x: SkipFrame(x, skip = 4),
                lambda x: GrayScaleObservation(x),
                lambda x: ResizeObservation(x, shape = 84),
                lambda x: FrameStack(x, num_stack = 4)
            )(env)

# env = SkipFrame(env, skip = 4)
# env = GrayScaleObservation(env)
# env = ResizeObservation(env, shape = 84)
# env = FrameStack(env, num_stack = 4)

mario = Mario(state_dim = (4, 84, 84), action_dim = env.action_space.n, save_dir = save_dir,
               train_steps_per_loop = 4, explore_gamma = 0.99999975, gamma = 0.9, batch_size = 256,
               lr = 3e-4, burnin = 1e5, sync_every = 1e5, save_interval = 2e5)
logger = MetricLogger(save_dir)

episodes = 10000
print(mario.policy_net.online)
for e in tqdm(range(episodes)):

    state = env.reset()

    # Play the game!
    while True:
        # env.render()
        # Run agent on the state
        for i in range(mario.train_steps_per_loop):
            action = mario.act(state)
            # Agent performs action
            next_state, reward, done, trunc, info = env.step(action)
            # Remember
            mario.cache(state, next_state, action, reward, done)
            
            if done:
                break
            else:
                state = next_state
        # Learn.
        loss, q, grad_norm = mario.step()

        # Logging
        logger.log_step(reward, loss, q, grad_norm)
        # Update state
        state = next_state

        # Check if end of game
        if done or info["flag_get"]:
            break

    logger.log_episode()

    if e % 20 == 0:
        logger.record(episode = e, epsilon = mario.explore_rate, step = mario.step_ct, lr = mario.scheduler.get_last_lr()[0])

env.close()

torch.Size([1, 3136])
Sequential(
  (0): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): Dropout(p=0.2, inplace=False)
    (4): LeakyReLU(negative_slope=0.01)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (6): LeakyReLU(negative_slope=0.01)
    (7): Flatten(start_dim=1, end_dim=-1)
  )
  (1): Linear(in_features=3136, out_features=512, bias=True)
  (2): LeakyReLU(negative_slope=0.01)
  (3): Linear(in_features=512, out_features=5, bias=True)
)


  0%|          | 0/10000 [00:00<?, ?it/s]



Episode 0 - Step 1222 - Epsilon 0.9996945466221541 - Mean Reward 131.0 - 
Mean Length 306.0 - Mean Loss 0.0 - Mean Q Value 0.0 - Mean Grad Norm 0.0
Learn Rate 0.0003 - Time Delta 11.535 - Time 2023-10-18T10:56:42


Episode 20 - Step 8734 - Epsilon 0.9978188818293865 - Mean Reward 185.333 - 
Mean Length 104.381 - Mean Loss 0.0 - Mean Q Value 0.0 - Mean Grad Norm 0.0
Learn Rate 0.0003 - Time Delta 65.102 - Time 2023-10-18T10:57:47


Episode 40 - Step 13260 - Epsilon 0.9966904881325136 - Mean Reward 160.732 - 
Mean Length 81.268 - Mean Loss 0.0 - Mean Q Value 0.0 - Mean Grad Norm 0.0
Learn Rate 0.0003 - Time Delta 31.711 - Time 2023-10-18T10:58:19


Episode 60 - Step 27528 - Epsilon 0.9931416258759 - Mean Reward 151.803 - 
Mean Length 113.23 - Mean Loss 0.0 - Mean Q Value 0.0 - Mean Grad Norm 0.0
Learn Rate 0.0003 - Time Delta 106.859 - Time 2023-10-18T11:00:06


Episode 80 - Step 33325 - Epsilon 0.9917033576514628 - Mean Reward 153.877 - 
Mean Length 103.247 - Mean Loss 0.0 - Mean Q Valu

KeyboardInterrupt: 

In [16]:
# Evaluate Mario
save_dir = Path("checkpoints", "test")
if not save_dir.exists():
    save_dir.mkdir(parents = True)
load_dir = Path("checkpoints", "mario_specimen", "mario_net_2.ckpt")
del env
# I actually want to see the agent at work, so :P
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode = "human", apply_api_compatibility = True)

env = JoypadSpace(env, RIGHT_ONLY)

env = function_pipeline(
                lambda x: SkipFrame(x, skip = 4),
                lambda x: GrayScaleObservation(x),
                lambda x: ResizeObservation(x, shape = 84),
                lambda x: FrameStack(x, num_stack = 4)
            )(env)


mario = Mario(state_dim = (4, 84, 84), action_dim = env.action_space.n, save_dir = save_dir)
mario.load(load_dir)
logger = MetricLogger(save_dir)

episodes = 10
for e in tqdm(range(episodes)):

    state = env.reset()

    # Play the game!
    while True:
        env.render()
        # Run agent on the state
        action = mario.act(state)
        # Agent performs action
        next_state, reward, done, trunc, info = env.step(action)
        # Remember
        mario.cache(state, next_state, action, reward, done)
        # Learn
        loss, q, grad_norm = mario.step()
        # Logging
        logger.log_step(reward, loss.cpu(), q, grad_norm)
        # Update state
        state = next_state

        # Check if end of game
        if done or info["flag_get"]:
            break

    logger.log_episode()

    if e % 20 == 0:
        logger.record(episode = e, epsilon = mario.explore_rate, step = mario.step_ct, grad)

env.close()

torch.Size([1, 6272])
Loading model at checkpoints/mario_specimen/mario_net_2.ckpt with exploration rate 0.9048374067128394


  0%|          | 0/10 [00:00<?, ?it/s]

  logger.warn(


Episode 0 - Step 138 - Epsilon 0.9048061903568869 - Mean Reward 640.0 - Mean Length 138.0 - Mean Loss 5.331999778747559 - Mean Q Value 35.823 - Time Delta 6.498 - Time 2023-10-16T13:50:08
