In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root, _subdirs, files in os.walk('/kaggle/input'):
    # Print the full path of every file found anywhere under the Kaggle input dir.
    for path in (os.path.join(root, f) for f in files):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/checkpoint7/model_1000_iter.pt
/kaggle/input/checkpoint9/model_15000_iter.pt
/kaggle/input/checkpoint8/model_3000_iter.pt


In [2]:
# Install the TorchRL stack (tensordict + torchrl, both from their git default
# branches) and the Super Mario Bros environment.
# `%pip` installs into the environment of the running kernel, whereas `!pip`
# may resolve to a different interpreter; `-q` keeps the progress spinner out
# of the saved output.
# NOTE(review): none of these installs are pinned, so re-running later may pull
# incompatible versions — pin a commit / version once a working set is known.
%pip install -q git+https://github.com/pytorch-labs/tensordict
%pip install -q git+https://github.com/pytorch/rl.git
%pip install -q gym-super-mario-bros

Collecting git+https://github.com/pytorch-labs/tensordict
  Cloning https://github.com/pytorch-labs/tensordict to /tmp/pip-req-build-jwn3tx83


  Running command git clone --filter=blob:none --quiet https://github.com/pytorch-labs/tensordict /tmp/pip-req-build-jwn3tx83


  Resolved https://github.com/pytorch-labs/tensordict to commit 85b02047dd816dd248b5e589100d9903d26696f9


  Installing build dependencies ... [?25l-

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

 -

 \

 |

 /

In [None]:
import torch
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY
import datetime
from pathlib import Path
from nes_py.wrappers import JoypadSpace
import os

import numpy as np
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from torch import nn

from gym import Wrapper
from gym.wrappers import GrayScaleObservation, ResizeObservation, FrameStack


class SkipFrame(Wrapper):
    """Repeat each chosen action for ``skip`` consecutive frames.

    Rewards from the repeated frames are summed; the observation, the
    termination/truncation flags and the info dict of the last executed frame
    are returned. Expects the 5-tuple ``step`` API
    ``(obs, reward, terminated, truncated, info)``.
    """

    def __init__(self, env, skip):
        super().__init__(env)
        if skip < 1:
            # With skip == 0 the loop in step() never runs and its return
            # statement would reference unbound locals.
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip

    def step(self, action):
        """Apply ``action`` up to ``self.skip`` times and accumulate the reward.

        Stops early as soon as the episode terminates *or* is truncated, since
        stepping an environment past the end of an episode is invalid.
        """
        total_reward = 0.0
        for _ in range(self.skip):
            next_state, reward, done, trunc, info = self.env.step(action)
            total_reward += reward
            if done or trunc:
                break
        return next_state, total_reward, done, trunc, info
    

def apply_wrappers(env, skip=4, shape=84, num_stack=4):
    """Wrap ``env`` in the Mario preprocessing pipeline.

    Order: frame skipping -> resize -> grayscale -> frame stacking.

    Args:
        env: The base environment to wrap.
        skip: Number of frames each chosen action is repeated for.
        shape: Target size forwarded to ``ResizeObservation`` (the default
            84 resizes the 240x256 NES frame to 84x84).
        num_stack: Number of consecutive processed frames stacked into one
            observation.

    Returns:
        The wrapped environment.
    """
    env = SkipFrame(env, skip=skip)
    env = ResizeObservation(env, shape=shape)
    env = GrayScaleObservation(env)
    # May need to change lz4_compress to False if issues arise
    env = FrameStack(env, num_stack=num_stack)
    return env

class MarioNN(nn.Module):
    """Convolutional Q-network mapping stacked frames to per-action Q-values.

    Architecture: three conv layers (8x8 stride 4, 4x4 stride 2, 3x3 stride 1)
    with ReLU, then a 512-unit hidden linear layer and a linear output layer
    with ``n_actions`` units. The module moves itself to CUDA when available,
    otherwise CPU, and records the choice in ``self.device``.

    Args:
        input_shape: Shape of one observation without the batch dimension;
            ``input_shape[0]`` is the number of input channels (e.g.
            ``(4, 84, 84)`` for four stacked 84x84 grayscale frames).
        n_actions: Number of discrete actions (size of the output layer).
        freeze: If True, disable gradients for every parameter (used for the
            target network).
    """

    def __init__(self, input_shape, n_actions, freeze=False):
        super().__init__()
        # Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(input_shape)

        # Full network: conv features -> flatten -> linear layers
        self.network = nn.Sequential(
            self.conv_layers,
            nn.Flatten(),
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

        if freeze:
            self._freeze()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.to(self.device)

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a batch of observations ``x``."""
        return self.network(x)

    def _get_conv_out(self, shape):
        """Return the flattened size of the conv stack's output for one input of ``shape``."""
        # The dummy pass is only used to measure the output size, so skip
        # building an autograd graph for it.
        with torch.no_grad():
            o = self.conv_layers(torch.zeros(1, *shape))
        # np.prod multiplies all dimensions together (batch dim is 1).
        return int(np.prod(o.size()))

    def _freeze(self):
        """Disable gradient computation for every parameter of the network."""
        for p in self.network.parameters():
            p.requires_grad = False
    

class Mario:
    """Double-DQN agent for Super Mario Bros.

    Holds an online Q-network (trained), a target Q-network (frozen and
    periodically overwritten with the online weights), an epsilon-greedy
    exploration schedule and a TorchRL replay buffer.

    Args:
        input_dims: Observation shape without the batch dimension, forwarded
            to ``MarioNN``.
        num_actions: Number of discrete actions.
        learning_rate: Adam learning rate.
        gamma: Discount factor for future rewards.
        exploration_rate: Initial epsilon for epsilon-greedy action selection.
        exploration_decay: Multiplicative epsilon decay applied after each
            gradient update.
        exploration_min: Lower bound for epsilon.
        replay_buffer_capacity: Maximum number of stored transitions.
        batch_size: Number of transitions sampled per update.
        sync_network_rate: Copy the online weights into the target network
            every this many training steps.
    """

    def __init__(self, 
                 input_dims, 
                 num_actions, 
                 learning_rate=0.00025, 
                 gamma=0.9, 
                 exploration_rate=1.0, 
                 exploration_decay=0.99999975, 
                 exploration_min=0.1, 
                 replay_buffer_capacity=10_000, 
                 batch_size=32, 
                 sync_network_rate=10000):

        self.num_actions = num_actions
        self.train_step_counter = 0

        # Hyperparameters
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.exploration_rate = exploration_rate
        self.exploration_decay = exploration_decay
        self.exploration_min = exploration_min
        self.batch_size = batch_size
        self.sync_network_rate = sync_network_rate
        # Number of train() calls (once the buffer has a full batch) that are
        # skipped before gradient updates begin.
        self.burnin = 32

        self.use_cuda = torch.cuda.is_available()

        # Networks for Q_online and Q_target
        self.online_network = MarioNN(input_dims, num_actions)
        self.target_network = MarioNN(input_dims, num_actions, freeze=True)
        # Start the target as an exact copy of the online network; otherwise
        # bootstrap targets come from an unrelated random network until the
        # first sync.
        self.target_network.load_state_dict(self.online_network.state_dict())

        # Optimizer and loss (torch.nn.SmoothL1Loss is a common alternative)
        self.optimizer = torch.optim.Adam(self.online_network.parameters(), lr=self.learning_rate)
        self.loss = torch.nn.MSELoss()

        # Replay buffer
        storage = LazyMemmapStorage(replay_buffer_capacity)
        self.replay_buffer = TensorDictReplayBuffer(storage=storage)

    def do_action(self, observation):
        """Pick an action index with an epsilon-greedy policy.

        With probability ``self.exploration_rate`` a uniformly random action
        is returned; otherwise the index of the highest Q-value the online
        network predicts for ``observation``.
        """
        if np.random.random() < self.exploration_rate:
            return np.random.randint(self.num_actions)

        observation = torch.tensor(np.array(observation), dtype=torch.float32) \
                        .unsqueeze(0) \
                        .to(self.online_network.device)
        # Inference only: no autograd graph is needed to select an action.
        with torch.no_grad():
            # Index of the action associated with the highest Q-value
            return self.online_network(observation).argmax().item()

    def decay_exploration_rate(self):
        """Multiply epsilon by the decay factor, clamped at ``exploration_min``."""
        self.exploration_rate = max(self.exploration_rate * self.exploration_decay, self.exploration_min)

    def store_memory(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (states stored as float32 tensors)."""
        self.replay_buffer.add(TensorDict({
                                            "state": torch.tensor(np.array(state), dtype=torch.float32), 
                                            "action": torch.tensor(action),
                                            "reward": torch.tensor(reward), 
                                            "next_state": torch.tensor(np.array(next_state), dtype=torch.float32), 
                                            "done": torch.tensor(done)
                                          }, batch_size=[]))

    def sync_networks(self):
        """Copy the online weights into the target network every ``sync_network_rate`` steps."""
        if self.train_step_counter % self.sync_network_rate == 0 and self.train_step_counter > 0:
            self.target_network.load_state_dict(self.online_network.state_dict())

    def save_model(self, path):
        """Save the online weights and the current epsilon to ``path``."""
        torch.save(
            dict(model = self.online_network.state_dict(), exploration_rate = self.exploration_rate),
            path
        )
        print(f"Checkpoint {path} saved successfully")

    def load_model(self, path):
        """Load weights (and epsilon, if recorded) from a checkpoint file.

        Accepts checkpoints written by ``save_model`` (a dict with ``model``
        and ``exploration_rate`` keys) as well as a bare ``state_dict``. The
        weights are loaded into both the online and the target network. A
        falsy ``path`` is a no-op; ``str`` and ``Path`` are both accepted.

        Raises:
            ValueError: If ``path`` does not exist.
        """
        if not path:
            return
        path = Path(path)
        if not path.exists():
            raise ValueError(f"{path} does not exist")

        ckp = torch.load(path, map_location = ('cuda' if self.use_cuda else 'cpu'))
        if 'model' in ckp:
            state_dict = ckp['model']
            # Keep the current epsilon if the checkpoint did not record one.
            self.exploration_rate = ckp.get('exploration_rate', self.exploration_rate)
        else:
            # Bare state_dict, e.g. saved via torch.save(model.state_dict(), path).
            state_dict = ckp

        print(f"Loading model {path} with exploration rate {self.exploration_rate}")
        self.online_network.load_state_dict(state_dict)
        self.target_network.load_state_dict(state_dict)

    def train(self):
        """Run one Double-DQN gradient step on a sampled minibatch.

        Does nothing until the buffer holds ``batch_size`` transitions, then
        skips ``burnin`` calls (counting them as steps). Each real update may
        sync the target network, minimizes the loss between Q_online(s, a) and
        r + gamma * Q_target(s', argmax_a' Q_online(s', a')) * (1 - done),
        increments the step counter and decays epsilon.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        if self.train_step_counter < self.burnin:
            self.train_step_counter += 1
            return

        self.sync_networks()

        samples = self.replay_buffer.sample(self.batch_size).to(self.online_network.device)

        keys = ("state", "action", "reward", "next_state", "done")
        states, actions, rewards, next_states, dones = [samples[key] for key in keys]

        # Row indices used to pick one Q-value per sample, on the same device.
        batch_idx = torch.arange(actions.shape[0], device=actions.device)
        actions = actions.long()

        # Q_online(s, a) for the actions actually taken -> shape (batch_size,)
        predicted_q_values = self.online_network(states)[batch_idx, actions]

        # The bootstrap target is treated as a constant, so build it without autograd.
        with torch.no_grad():
            # Double DQN: the online network selects the next action ...
            best_action = self.online_network(next_states).argmax(dim=1)
            # ... and the target network evaluates it.
            next_q_values = self.target_network(next_states)[batch_idx, best_action]
            # Future rewards don't count for terminal transitions: if done is
            # true, 1 - done is 0 and only the immediate reward remains.
            target_q_values = rewards.to(predicted_q_values.dtype) \
                + self.gamma * next_q_values * (1 - dones.float())

        loss = self.loss(predicted_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.train_step_counter += 1
        self.decay_exploration_rate()

        

# Run configuration
ENV = 'SuperMarioBros-1-1-v0'  # gym-super-mario-bros environment id
TRAIN = True                   # False -> no replay storage / updates, fixed epsilon
DISPLAY = False                # True -> create the env with render_mode='human' and render each step
CKPT_SAVE_INTERVAL = 200       # save a checkpoint every this many episodes (training only)
EPISODES = 50_000              # exclusive upper bound on the episode index

def main(ckpt_path=Path("/kaggle/input/checkpoint9/model_15000_iter.pt"), start_episode=15000):
    """Train (or evaluate) the Mario agent, optionally resuming from a checkpoint.

    Builds the wrapped environment, creates the agent, loads ``ckpt_path``
    (skipped when falsy) and runs episodes ``start_episode .. EPISODES - 1``.
    When ``TRAIN`` is set, every transition is stored, ``mario.train()`` is
    called after each action, and a checkpoint is written to a timestamped
    directory under ``checkpoints/`` every ``CKPT_SAVE_INTERVAL`` episodes.

    Args:
        ckpt_path: Checkpoint to resume from; pass ``None`` to start fresh.
        start_episode: Episode index to resume counting from; it should match
            the checkpoint so saved file names stay consistent.
    """
    save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    save_dir.mkdir(parents = True, exist_ok = True)

    if torch.cuda.is_available():
        print("Using CUDA device:", torch.cuda.get_device_name(0))
    else:
        print("CUDA is not available")

    env = gym_super_mario_bros.make(ENV, render_mode='human' if DISPLAY else 'rgb', apply_api_compatibility=True)
    env = JoypadSpace(env, RIGHT_ONLY)
    env = apply_wrappers(env)

    try:
        mario = Mario(input_dims=env.observation_space.shape, num_actions=env.action_space.n)
        mario.load_model(ckpt_path)

        if not TRAIN:
            # Evaluation: fixed exploration rate with no decay.
            mario.exploration_rate = 0.2
            mario.exploration_min = 0.0
            mario.exploration_decay = 0.0

        for e in range(start_episode, EPISODES):
            total_reward = 0
            max_x = 0
            done = trunc = False

            state, _ = env.reset()

            # The episode ends on termination or truncation. Only true
            # termination is stored as `done`, so truncated transitions are
            # still bootstrapped during training.
            while not (done or trunc):
                if DISPLAY:
                    env.render()

                # Act on the current observation (the freshly reset frame on
                # the first step), not a frame left over from the last episode.
                action = mario.do_action(state)
                next_state, reward, done, trunc, info = env.step(action)
                total_reward += reward

                max_x = max(info["x_pos"], max_x)

                if TRAIN:
                    mario.store_memory(state, action, reward, next_state, done)
                    mario.train()

                state = next_state

            if TRAIN and (e + 1) % CKPT_SAVE_INTERVAL == 0:
                save_path = os.path.join(save_dir, "model_" + str(e + 1) + "_iter.pt")
                mario.save_model(save_path)

            print(f"Episode {e + 1} Total reward: {total_reward}, Max x: {max_x}")
    finally:
        # Release the emulator even if training is interrupted or fails.
        env.close()



In [None]:
# In an IPython kernel the executing module is named "__main__", so running
# this cell starts main().
_run_as_entry_point = __name__ == "__main__"
if _run_as_entry_point:
    main()