In [10]:
import torch
import torch.nn as nn
import numpy as np
from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage
from tensordict import TensorDict

In [11]:
class MarioNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim
        if h != 84:
            raise ValueError(f"Expecting input height: 84, got: {h}")
        if w != 84:
            raise ValueError(f"Expecting input width: 84, got: {w}")

        self.online = self._build_cnn_layers(c, output_dim)
        self.target = self._build_cnn_layers(c, output_dim)

    def forward(self, input, model):
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)

    def _build_cnn_layers(self, input_dim, output_dim):
        """Construct the convolutional layers"""
        self.conv1 = nn.Conv2d(
            in_channels=input_dim, out_channels=32, kernel_size=8, stride=4
        )
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

        return nn.Sequential(
            self.conv1,
            nn.ReLU(),
            self.conv2,
            nn.ReLU(),
            self.conv3,
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )


In [12]:
class MarioAgent:
    def __init__(self, state_dim, action_dim, save_dir):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir

        self.net = MarioNet(self.state_dim, self.action_dim).float()

        self.exploration_rate = 1
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0

        self.save_every = 5e5

        self.memory = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(max_size=50_000, device="cpu"),
        )
        self.batch_size = 32

    def act(self, state):
        """Given a state, choose an epsilon-greedy action"""

        # explore
        if np.random.rand() < self.exploration_rate:
            return np.random.randint(self.action_dim)
        # exploit
        else:
            state = (
                state[0].__array__() if isinstance(state, tuple) else state.__array__()
            )
            state = torch.tensor(state, device=self.device).unsqueeze(0)
            action_values = self.net(state, model="online")
            action_idx = torch.argmax(action_values, dim=1).item()

        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate, self.exploration_rate_min)
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """Add the experience to memory"""

        def first_if_tuple(x):
            return x[0] if isinstance(x, tuple) else x
        
        state = first_if_tuple(state).__array__()
        next_state = first_if_tuple(next_state).__array__()

        self.memory.add(
            TensorDict(
                {
                    "state": torch.tensor(state,device=self.device),
                    "next_state": torch.tensor(next_state,device=self.device),
                    "action": torch.tensor([action],device=self.device),
                    "reward": torch.tensor([reward],device=self.device),
                    "done": torch.tensor([done],device=self.device),
                },
                batch_size=[],
            )
        )

    def recall(self):
        """Sample experiences from memory"""
        batch = self.memory.sample(self.batch_size).to(self.device)
        state, next_state, action, reward, done = (
            batch.get(key)
            for key in ("state", "next_state", "action", "reward", "done")
        )
        return state, next_state, action, reward, done

In [13]:
class Mario(MarioAgent):
    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)
        self.gamma = 0.99
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
        self.loss_fn = torch.nn.SmoothL1Loss()

        self.burnin = 1e4
        self.learn_every = 3
        self.sync_every = 1e4

    def td_estimate(self, state, action):
        current_Q = self.net(state, model="online")[
            np.arange(0, self.batch_size), action
        ]
        return current_Q

    @torch.no_grad()
    def td_target(self, next_state, reward, done):
        next_state_Q = self.net(next_state, model="target")
        best_action = torch.argmax(next_state_Q, dim=1)
        next_Q = self.net(next_state, model="target")[
            np.arange(0, self.batch_size), best_action
        ]
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def save_model(self):
        save_path = (
            self.save_dir / f"mario_net_{int(self.curr_step // self.save_every)}.chkpt"
        )
        torch.save(
            dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),
            save_path,
        )
        print(f"MarioNet saved to {save_path} at step {self.curr_step}")

    def learn(self):
        if self.curr_step % self.learn_every != 0:
            return None, None

        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save_model()

        if self.curr_step < self.burnin:
            return None, None

        state, next_state, action, reward, done = self.recall()

        # get TD estimate
        td_est = self.td_estimate(state, action)
        # get TD target
        td_tgt = self.td_target(next_state, reward, done)
        # backpropagate loss
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)



# Preprocess & environment setup

In [14]:
import gym
from gym.spaces import Box
import numpy as np
import torch
import torchvision.transforms as T
from gym.wrappers import FrameStack
# NES Emulator for OpenAI Gym
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros

In [15]:
class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward"""
        total_reward = 0.0
        for i in range(self._skip):
            # Accumulate reward and repeat the same action
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, trunk, info


class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation


class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]
        )
        observation = transforms(observation).squeeze(0)
        return observation



In [16]:
env = gym_super_mario_bros.make(
    "SuperMarioBros-1-1-v0", render_mode="rgb", apply_api_compatibility=True
)


# Limit the action-space to
#   0. walk right
#   1. jump right
env = JoypadSpace(env, [["right"], ["right", "A"]])


env.reset()
next_state, reward, done, trunc, info = env.step(action=0)
print(f"{next_state.shape},\n {reward},\n {done},\n {info}")
env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=(84, 84))
env = FrameStack(env, num_stack=4)

# done = True
# for step in range(5000):
#     if done:
#         state = env.reset()
#         env.step(env.action_space.sample())
#     env.render()

print('action space',env.action_space.n)
env.close()



  logger.warn(
  logger.warn(


(240, 256, 3),
 0.0,
 False,
 {'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 400, 'world': 1, 'x_pos': 40, 'y_pos': 79}
action space 2


  if not isinstance(terminated, (bool, np.bool8)):


In [17]:
import datetime
from pathlib import Path

In [18]:
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)


env = gym_super_mario_bros.make(
    "SuperMarioBros-1-1-v0", render_mode="rgb_array", apply_api_compatibility=True
)

# limit actions and apply wrappers
env = JoypadSpace(env, [["right"], ["right", "A"]])
env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=(84, 84))
env = FrameStack(env, num_stack=4)


mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)

episodes = 40

for e in range(episodes):
    state = env.reset()
    
    while True:
        action = mario.act(state)
        next_state, reward, done, trunc, info = env.step(action)
        mario.cache(state, next_state, action, reward, done)
        q, loss = mario.learn()
        state = next_state

        if done or info["flag_get"]:
            print(
                f"Episode: {e}, "
                f"Step: {mario.curr_step}, "
                f"Exploration Rate: {mario.exploration_rate:.5f}, "
                f"Q: {q}, "
                f"Loss: {loss}"
            )
            break
        

cuda
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net_0.chkpt at step 0
MarioNet saved to checkpoints/2025-10-12T00-21-50/mario_net

### Making it play the game !!!!