In [8]:
# Install the three third-party packages this notebook installs explicitly into
# the kernel's environment (%pip targets the running kernel).
# NOTE(review): versions are unpinned. The recorded output shows
# opencv-python 4.5.5.62 being installed. The code below unpacks four values from
# env.step() and calls env.seed(), so it presumably needs a gym release that
# still has that API - confirm and pin the versions.
%pip install nes_py
%pip install gym_super_mario_bros
%pip install opencv-python

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Collecting opencv-python
  Downloading opencv_python-4.5.5.62-cp36-abi3-win_amd64.whl (35.4 MB)
     --------------------------------------- 35.4/35.4 MB 28.5 MB/s eta 0:00:00
Installing collected packages: opencv-python
Successfully installed opencv-python-4.5.5.62
Note: you may need to restart the kernel to use updated packages.


In [1]:
# Button-combination tables passed to JoypadSpace. An entry's position in its
# table is the discrete action id the agent emits.

# Smallest action set: keep moving right, optionally pressing A as well.
RIGHT_AND_JUMP = [["right"], ["right", "A"]]  # ids 0 and 1

# Right-only movement plus an idle action.
RIGHT_ONLY = [
    ["NOOP"],
    ["right"],
    ["right", "A"],
    ["right", "B"],
    ["right", "A", "B"],
]

# Reduced movement set: the right-moving combos, a standalone A press, and left.
SIMPLE_MOVEMENT = [
    ["right"],
    ["right", "A"],
    ["right", "B"],
    ["right", "A", "B"],
    ["A"],
    ["left"],
]

In [2]:
import numpy as np
import gym
from gym.wrappers import *
from nes_py.wrappers import JoypadSpace
from gym.spaces import Box
from torchvision import transforms
import torch
import gym_super_mario_bros

In [75]:
class Counter(dict):
    """Dictionary of fixed-size NumPy vectors that are created on first lookup.

    Item lookup (``counter[idx]``) converts ``idx`` with ``str`` and, when that
    key is absent, stores and returns ``np.zeros(size)``. The stored array is
    returned by reference, so in-place updates persist.

    Only ``__getitem__`` coerces keys: assignment, ``in`` and ``get`` use the
    key exactly as given.

    Args:
        size: length (or shape) handed to ``np.zeros`` for new entries.
    """

    def __init__(self, size=1):
        super().__init__()
        self.size = size

    def __getitem__(self, idx):
        # dict.__getitem__ falls back to __missing__ for absent keys, so the
        # zero vector is only allocated when a key is actually new.
        return super().__getitem__(str(idx))

    def __missing__(self, key):
        fresh = np.zeros(self.size)
        dict.__setitem__(self, key, fresh)
        return fresh



class SkipFrame(gym.Wrapper):
    """Repeat every action for up to ``skip`` consecutive env steps.

    Rewards of the repeated steps are summed. The observation, done flag and
    info returned are those of the last step actually executed.

    Args:
        env: environment to wrap; its ``step`` must return
            ``(obs, reward, done, info)``.
        skip: number of times each action is repeated; must be at least 1.

    Raises:
        ValueError: if ``skip`` is smaller than 1 (``step`` would otherwise
            have no observation to return).
    """

    def __init__(self, env, skip):
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip!r}")
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` repeatedly and return the accumulated transition."""
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                # Stop repeating once the episode has ended.
                break
        return obs, total_reward, done, info


class GrayScaleObservation(gym.ObservationWrapper):
    """Convert each observation to grayscale with torchvision.

    The observation's axes are reordered with ``(2, 0, 1)`` (presumably
    H x W x C to C x H x W - confirm against the wrapped env), turned into a
    float tensor and passed through ``transforms.Grayscale()``.

    NOTE(review): the declared ``observation_space`` is the first two dims of
    the wrapped space with dtype uint8, whereas ``observation`` returns a float
    torch tensor produced by the transform; the two are not kept in sync.

    NOTE(review): the notebook also runs ``from gym.wrappers import *``; if
    gym.wrappers exports a class with this name (confirm for the installed
    version), a star import executed after this cell rebinds the name.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(
            low=0, high=255, shape=self.observation_space.shape[:2], dtype=np.uint8)
        # Build the transform once instead of on every frame.
        self._to_grayscale = transforms.Grayscale()

    def observation(self, observation):
        """Return the grayscale float tensor for ``observation``."""
        # .copy() gives torch a contiguous array after the axis reorder.
        channels_first = np.transpose(observation, (2, 0, 1)).copy()
        return self._to_grayscale(torch.tensor(channels_first, dtype=torch.float))


class ResizeObservation(gym.ObservationWrapper):
    """Resize each observation and scale it with ``Normalize(0, 255)``.

    Args:
        env: environment to wrap; ``observation`` expects a tensor that
            torchvision's ``Resize`` and ``Normalize`` accept.
        shape: target size. An ``int`` means a square ``(shape, shape)``
            image; a ``(height, width)`` sequence is also accepted.

    NOTE(review): ``observation_space`` is declared as uint8 in [0, 255], but
    the returned values have been normalised with mean 0 / std 255 (that is,
    divided by 255). The declaration and the data disagree.

    NOTE(review): the notebook also runs ``from gym.wrappers import *``; if
    gym.wrappers exports a class with this name (confirm for the installed
    version), a star import executed after this cell rebinds the name.
    """

    def __init__(self, env, shape):
        super().__init__(env)
        self.shape = (shape, shape) if isinstance(shape, int) else tuple(shape)
        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
        # Build the pipeline once instead of on every frame.
        self._pipeline = transforms.Compose(
            [transforms.Resize(self.shape), transforms.Normalize(0, 255)])

    def observation(self, observation):
        """Return the resized, normalised observation without a size-1 leading axis."""
        return self._pipeline(observation).squeeze(0)
    
# Removes the top part of the image.
class CropImage(gym.ObservationWrapper):
    """Drop the first ``top`` rows of every observation.

    Args:
        env: environment to wrap; observations must support 2-D slicing
            (``observation[top:, :]``).
        top: number of leading rows to discard (default 20, the value that
            was previously hard-coded).

    NOTE(review): ``observation_space`` is inherited unchanged from the
    wrapped env, so it still describes the un-cropped shape.
    """

    def __init__(self, env, top=20):
        super().__init__(env)
        self._top = top

    def observation(self, observation):
        """Return ``observation`` without its first ``top`` rows."""
        return observation[self._top:, :]
    
def setup_environment(actions=RIGHT_AND_JUMP, skip=4, second=False, seed=42):
    """Build the wrapped, seeded Super Mario Bros environment.

    Wrappers are applied innermost first: JoypadSpace -> SkipFrame ->
    GrayScaleObservation -> ResizeObservation(120) -> CropImage ->
    FrameStack(4).

    Args:
        actions: list of button combinations passed to ``JoypadSpace``.
        skip: number of frames each action is repeated for (``SkipFrame``).
        second: if true, create ``"SuperMarioBros2-v0"`` instead of
            ``'SuperMarioBros-v0'``.
        seed: seed applied to both the env and its action space
            (default 42, the value that was previously hard-coded).

    Returns:
        The fully wrapped environment.

    NOTE(review): wrapper class names are looked up when this function is
    *called*. A later cell runs ``from gym.wrappers import *`` again; if
    gym.wrappers exports ``GrayScaleObservation`` / ``ResizeObservation``
    (confirm for the installed version), that import rebinds those names and
    this function then uses gym's classes instead of the ones defined in this
    notebook. Prefer explicit imports such as
    ``from gym.wrappers import FrameStack``.
    """
    if second:
        env = gym.make("SuperMarioBros2-v0")
    else:
        env = gym_super_mario_bros.make('SuperMarioBros-v0')

    env = JoypadSpace(env, actions)
    env = SkipFrame(env, skip)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=120)
    env = CropImage(env)
    env = FrameStack(env, num_stack=4)

    env.seed(seed)
    env.action_space.seed(seed)

    return env

In [79]:
import copy
import os
import random
from collections import deque
from os.path import exists

import matplotlib.pyplot as plt
import numpy as np
import torch

from gym.wrappers import *
from torch import nn
from torch.distributions import *


# Seed every RNG this notebook draws from. The stdlib `random` module is
# included because the replay buffer is sampled with random.sample().
# (torch.random.manual_seed is the same function as torch.manual_seed, so a
# single call is enough.)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


class DDQNSolver(nn.Module):
    """Two identical conv nets: a trainable ``online`` copy and a frozen ``target`` copy.

    Each net is three Conv2d+ReLU stages on a 4-channel input, then
    Flatten -> Linear(6336, 512) -> ReLU -> Linear(512, output_dim).
    6336 equals 64 * 9 * 11, i.e. the flattened conv output for 100 x 120
    frames (assumed input size - confirm against the env wrappers).
    """

    def __init__(self, output_dim):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(6336, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        ]
        self.online = nn.Sequential(*layers)

        # The target net starts as an exact copy and never receives gradients.
        self.target = copy.deepcopy(self.online)
        for frozen in self.target.parameters():
            frozen.requires_grad = False

    def forward(self, input, model):
        """Run ``input`` through the ``online`` net if ``model == "online"``, else ``target``.

        A 5-D input is squeezed (all size-1 dims dropped); a 3-D input gets a
        leading batch axis. The input is cast to float before the forward pass.
        """
        frames = input
        if frames.ndim == 5:
            frames = frames.squeeze()
        if frames.ndim == 3:
            frames = frames.unsqueeze(0)

        network = self.online if model == "online" else self.target
        return network(frames.float())


class DDQNAgent:
    """Double-DQN agent with epsilon-greedy acting and a replay memory.

    Args:
        action_dim: number of discrete actions (size of the network output).
        save_directory: directory for the checkpoint and the reward plot;
            ``''`` means the current working directory.
        device: torch device (or device string) to run the network on.
            Defaults to CUDA when available, otherwise CPU.
    """

    def __init__(self, action_dim, save_directory, device=None):
        self.action_dim = action_dim
        self.save_directory = save_directory
        if device is None:
            # Fall back to CPU instead of failing on machines without a GPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.net = DDQNSolver(self.action_dim).to(self.device)
        # Epsilon-greedy schedule: multiplied by the decay on every act() call.
        self.exploration_rate = 1.0
        self.exploration_rate_decay = 0.999
        self.exploration_rate_min = 0.01
        self.current_step = 0
        self.maxlen_memory = 70000
        self.memory = deque(maxlen=self.maxlen_memory)
        self.batch_size = 64
        self.gamma = 0.95
        # Number of act() steps between online -> target weight copies.
        self.sync_period = 1e4
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025, eps=1e-4)
        self.loss = torch.nn.SmoothL1Loss()
        self.episode_rewards = []
        self.moving_average_episode_rewards = []
        self.current_episode_reward = 0.0

    def log_episode(self):
        """Record the finished episode's total reward and reset the running sum."""
        self.episode_rewards.append(self.current_episode_reward)
        self.current_episode_reward = 0.0

    def log_period(self, episode, epsilon, step, checkpoint_period):
        """Print and plot the mean reward of the last ``checkpoint_period`` episodes.

        The plot of all recorded means is written to
        ``episode_rewards_plot.png`` inside ``save_directory``.
        """
        self.moving_average_episode_rewards.append(np.round(
            np.mean(self.episode_rewards[-checkpoint_period:]), 3))
        print(f"Episode {episode} - Step {step} - Epsilon {epsilon} "
              f"- Mean Reward {self.moving_average_episode_rewards[-1]}")
        plt.plot(self.moving_average_episode_rewards)
        filename = os.path.join(self.save_directory, "episode_rewards_plot.png")
        # savefig creates or overwrites the file itself; no separate open() needed.
        plt.savefig(filename, format="png")
        plt.clf()

    def load_checkpoint(self, path):
        """Restore network weights and exploration rate from ``path``.

        NOTE: ``torch.load`` unpickles the file; only load checkpoints from a
        trusted source.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.net.load_state_dict(checkpoint['model'])
        self.exploration_rate = checkpoint['exploration_rate']

    def save_checkpoint(self):
        """Write network weights and exploration rate to ``checkpoint.pth``."""
        filename = os.path.join(self.save_directory, 'checkpoint.pth')
        torch.save(dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate), f=filename)
        print('Checkpoint saved to \'{}\''.format(filename))

    def remember(self, state, next_state, action, reward, done):
        """Append one transition to the replay memory as a tuple of tensors."""
        self.memory.append((torch.tensor(state.__array__()), torch.tensor(next_state.__array__()),
                            torch.tensor([action]), torch.tensor([reward]), torch.tensor([done])))

    def experience_replay(self, step_reward):
        """Accumulate ``step_reward`` and run one Double-DQN update.

        The target net is synced whenever ``current_step`` is a multiple of
        ``sync_period``. No update happens until the memory holds at least
        ``batch_size`` transitions.
        """
        self.current_episode_reward += step_reward
        if (self.current_step % self.sync_period) == 0:
            self.net.target.load_state_dict(self.net.online.state_dict())

        if len(self.memory) < self.batch_size:
            return

        state, next_state, action, reward, done = (
            batch.to(self.device) for batch in self.recall())
        rows = torch.arange(self.batch_size, device=self.device)

        q_estimate = self.net(state, model="online")[rows, action]
        with torch.no_grad():
            # Double DQN: the online net picks the action, the target net scores it.
            best_action = torch.argmax(self.net(next_state, model="online"), dim=1)
            next_q = self.net(next_state, model="target")[rows, best_action]
            q_target = (reward + (1 - done.float()) * self.gamma * next_q).float()
        loss = self.loss(q_estimate, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def recall(self):
        """Sample ``batch_size`` transitions and stack them field by field.

        Returns:
            ``(state, next_state, action, reward, done)`` tensors; the last
            three are squeezed to drop their size-1 axis.
        """
        state, next_state, action, reward, done = map(
            torch.stack, zip(*random.sample(self.memory, self.batch_size)))

        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def act(self, state):
        """Choose an action epsilon-greedily, then decay epsilon and count the step."""
        if np.random.rand() < self.exploration_rate:
            action = np.random.randint(self.action_dim)
        else:
            frames = torch.tensor(state.__array__()).to(self.device).unsqueeze(0)
            # Inference only: no autograd graph needed for action selection.
            with torch.no_grad():
                action_values = self.net(frames, model="online")
            action = torch.argmax(action_values, dim=1).item()
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
        self.current_step += 1
        return action

In [80]:
def train(num_episodes=40000, checkpoint_period=50, save_directory='', load_checkpoint=None):
    """Train a DDQNAgent on the SIMPLE_MOVEMENT environment.

    Args:
        num_episodes: number of episodes to play.
        checkpoint_period: every this many episodes the agent saves a
            checkpoint and logs/plots the mean reward.
        save_directory: directory for checkpoints and the reward plot
            (``''`` means the current working directory).
        load_checkpoint: optional checkpoint file name inside
            ``save_directory`` to resume from; ignored if the file is missing.

    The environment is closed when training ends, including when it is
    interrupted (e.g. by KeyboardInterrupt).
    """
    # Transitions are only stored once Mario's x position exceeds this value.
    # NOTE(review): magic number carried over unchanged - confirm its intent.
    min_x_pos_to_remember = 190

    env = setup_environment(actions=SIMPLE_MOVEMENT, skip=2)
    try:
        agent = DDQNAgent(action_dim=env.action_space.n, save_directory=save_directory)
        if load_checkpoint is not None:
            # os.path.join avoids the "/<name>" root path that plain
            # concatenation yields when save_directory is ''.
            checkpoint_path = os.path.join(save_directory, load_checkpoint)
            if exists(checkpoint_path):
                agent.load_checkpoint(checkpoint_path)

        for episode in range(1, num_episodes + 1):
            state = env.reset()
            done = False
            while not done:
                action = agent.act(state)
                next_state, reward, done, info = env.step(action)

                if info["x_pos"] > min_x_pos_to_remember:
                    agent.remember(state, next_state, action, reward, done)
                agent.experience_replay(reward)

                state = next_state

            agent.log_episode()
            if episode % checkpoint_period == 0:
                agent.save_checkpoint()
                agent.log_period(episode, agent.exploration_rate, agent.current_step, checkpoint_period)
    finally:
        env.close()

In [81]:
# Start training with train()'s built-in settings (40000 episodes, checkpoint
# every 50 episodes). This is a long-running cell.
train()

KeyboardInterrupt: 