In [12]:
import gym
from torchvision.transforms import Grayscale, Resize, Normalize
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import plotly.graph_objects as go
import pandas as pd
import imageio
import os

# Raw SMB frames are (240, 256, 3) = (height, width, channels); they are
# converted to single-channel images of this (height, width).
PROCESSED_IMG_SIZE = (84, 84)
# Number of consecutive frames each action is repeated for by SkipFrames.
# Frames are skipped, not stacked: only the last one is returned.
NUM_STACKED_FRAMES = 4
# Adam optimizer step size (alpha).
LEARNING_RATE = 3e-3
# Discount factor applied to future rewards when computing returns.
GAMMA = 0.99
# How many episodes to train for.
NUM_EPISODES = 1000
# When capturing is enabled, save a GIF of the agent every n episodes.
CAPTURE_EVERY_N_EPISODES = 250

# Frame skipping: repeat each chosen action for several consecutive frames,
# summing the rewards and returning only the last frame's observation. This
# reduces the number of policy decisions per episode and speeds up training.
class SkipFrames(gym.Wrapper):
    """Repeat every action ``skip`` times, stopping early if the episode ends.

    Args:
        env: the wrapped environment (4-tuple ``step`` API).
        skip: frames per agent action; defaults to ``NUM_STACKED_FRAMES``.

    Raises:
        ValueError: if ``skip`` is smaller than 1 (no frame would be stepped).
    """

    def __init__(self, env, skip=NUM_STACKED_FRAMES):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action):
        """Step ``action`` up to ``skip`` times and return the last transition.

        Returns:
            (state, total_reward, done, info) where ``total_reward`` is the sum
            of rewards over the frames actually stepped and the other fields
            come from the final inner step.
        """
        total_reward = 0
        for _ in range(self._skip):
            state, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return state, total_reward, done, info


# Preprocess raw SMB frames: RGB -> one grayscale channel, resize to
# PROCESSED_IMG_SIZE, and scale pixel values from [0, 255] down to [0, 1].
class ImagePreprocessing(gym.ObservationWrapper):
    """Convert (H, W, 3) RGB frames into (1, *PROCESSED_IMG_SIZE) float tensors.

    The returned observation is a float32 torch tensor with values in [0, 1],
    and ``observation_space`` is declared to match it.
    """

    def __init__(self, env):
        super().__init__(env)
        # Declare the space that observation() actually returns:
        # one channel, PROCESSED_IMG_SIZE spatially, float values in [0, 1].
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(1, *PROCESSED_IMG_SIZE), dtype=np.float32
        )
        # Build the transforms once instead of on every frame.
        self._to_grayscale = Grayscale()
        self._resize = Resize(PROCESSED_IMG_SIZE)
        # Normalize(mean=0, std=255) computes x / 255, i.e. [0, 255] -> [0, 1].
        self._rescale = Normalize(0, 255)

    def observation(self, image):
        """Return the preprocessed (1, 84, 84) float tensor for one RGB frame."""
        # (H, W, C) -> (C, H, W): torchvision tensor transforms are channel-first.
        chw = np.moveaxis(image, 2, 0).copy()
        image = torch.tensor(chw, dtype=torch.float)
        image = self._to_grayscale(image)  # (3, H, W) -> (1, H, W)
        image = self._resize(image)  # (1, H, W) -> (1, *PROCESSED_IMG_SIZE)
        return self._rescale(image)


# A three-layer CNN followed by two fully connected layers that outputs a
# probability distribution over actions. Default input size is 1x84x84.
class SuperMarioCNN(nn.Module):
    """Convolutional policy network for image observations.

    Args:
        input_channels: number of channels C in the input image.
        output_dim: number of actions; the output is a softmax over them.
        input_hw: (height, width) of the input image, used to size the first
            linear layer. Defaults to (84, 84).

    ``forward`` accepts an unbatched ``(C, H, W)`` tensor and returns shape
    ``(output_dim,)``, or a batched ``(B, C, H, W)`` tensor and returns
    ``(B, output_dim)``.
    """

    def __init__(self, input_channels, output_dim, input_hw=(84, 84)):
        super(SuperMarioCNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        # Each conv (kernel 4, stride 2, no padding) maps a side length n to
        # (n - 4) // 2 + 1; for 84 that is 84 -> 41 -> 19 -> 8, giving
        # 64 * 8 * 8 = 4096 flattened features.
        height, width = input_hw
        for _ in range(3):
            height = (height - 4) // 2 + 1
            width = (width - 4) // 2 + 1
        self.fc1 = nn.Linear(64 * height * width, 256)
        self.fc2 = nn.Linear(256, output_dim)

    def forward(self, x):
        """Return action probabilities (softmax over the last dimension)."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # Flatten only the trailing (C, H, W) dims so a leading batch dim, when
        # present, is kept; a plain torch.flatten(x) would merge it in.
        x = torch.flatten(x, start_dim=-3)
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=-1)


# A small MLP policy: one hidden ReLU layer followed by a softmax output.
class AtariNN(nn.Module):
    """MLP policy network for flat (vector) observations.

    Args:
        input_dim: length of the observation vector.
        output_dim: number of actions; the output is a softmax over them.
        hidden_dim: width of the single hidden layer.

    ``forward`` accepts a numpy array or a tensor. A 1-D observation of length
    ``input_dim`` gets a batch dimension added and yields ``(1, output_dim)``;
    a 2-D ``(B, input_dim)`` batch yields ``(B, output_dim)``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=32):
        super(AtariNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return action probabilities of shape (batch, output_dim)."""
        # as_tensor accepts numpy arrays (like from_numpy) as well as tensors.
        x = torch.as_tensor(x, dtype=torch.float32)
        # Only add a batch dim for a single observation; an already-batched
        # input must keep its batch axis so the softmax runs over actions.
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=-1)


# REINFORCE (Monte Carlo policy gradient)
class REINFORCE:
    """Monte Carlo policy-gradient (REINFORCE) agent.

    Args:
        env: gym-style environment exposing ``action_space.n``,
            ``observation_space.shape``, ``reset()``, a 4-tuple ``step()`` and
            ``render(mode="rgb_array")``.
        env_type: ``"atari"`` selects the MLP policy (``AtariNN``); any other
            value selects the CNN policy (``SuperMarioCNN``) with 1 input channel.
        learning_rate: Adam step size.
    """

    def __init__(self, env, env_type, learning_rate):
        self.env = env
        self.num_actions = env.action_space.n
        self.num_states = env.observation_space.shape[0]
        self.actions_list = np.arange(self.num_actions)
        if env_type == "atari":
            self.policy = AtariNN(self.num_states, self.num_actions)
        else:
            self.policy = SuperMarioCNN(1, self.num_actions)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

    def choose_action(self, state):
        """Sample an action from the current policy.

        Returns:
            (action, log_prob): the sampled action as a Python int, and the
            log-probability tensor of that action, still attached to the
            autograd graph so the policy loss can backpropagate through it.
        """
        probabilities = self.policy(state)
        dist = torch.distributions.Categorical(probabilities)
        sample = dist.sample()
        return sample.item(), dist.log_prob(sample)

    def update_policy(self, rewards, log_probs):
        """Take one gradient step on the REINFORCE loss for one episode.

        Args:
            rewards: per-step rewards of the episode, in time order.
            log_probs: log-probability tensors of the actions taken, aligned
                with ``rewards``.
        """
        if not log_probs:
            # Nothing to learn from; torch.stack would fail on an empty list.
            return

        # Discounted return-to-go G_t = r_t + GAMMA * G_{t+1}, accumulated
        # backwards and reversed once (list.insert(0, ...) would be O(n^2)).
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + GAMMA * G
            returns.append(G)
        returns.reverse()

        # Loss = -sum_t log pi(a_t | s_t) * G_t
        policy_loss = torch.stack(
            [-log_prob * G_t for log_prob, G_t in zip(log_probs, returns)]
        ).sum()

        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

    def train(self, episodes, capture=False):
        """Train for ``episodes`` episodes, updating after each one.

        Args:
            episodes: number of episodes to run.
            capture: when True, frames from every
                ``CAPTURE_EVERY_N_EPISODES``-th episode and from the final
                episode are saved as ``./videos/episode_<n>.gif``.

        Returns:
            A list with the total (undiscounted) reward of each episode.
        """
        total_rewards = []
        for episode in range(episodes):
            # Parenthesised: `and` binds tighter than `or`, so without the
            # parentheses the final episode was recorded even if capture=False.
            record = capture and (
                episode % CAPTURE_EVERY_N_EPISODES == 0 or episode == episodes - 1
            )
            state = self.env.reset()
            rewards = []
            log_probs = []
            images = []
            done = False
            # Generate one episode of experience with the current policy.
            while not done:
                action, log_prob = self.choose_action(state)
                state, reward, done, _ = self.env.step(action)
                if record:
                    images.append(self.env.render(mode="rgb_array"))
                log_probs.append(log_prob)
                rewards.append(reward)

            self.update_policy(rewards, log_probs)

            episode_reward = sum(rewards)
            if episode % 100 == 0:
                print(f"Episode {episode}: Total reward: {episode_reward}")
            if images:
                self._save_gif(images, episode)
            total_rewards.append(episode_reward)
        return total_rewards

    @staticmethod
    def _save_gif(images, episode):
        """Write captured frames to ./videos/episode_<episode>.gif at 30 fps."""
        video_dir = os.path.join(os.getcwd(), "videos")
        os.makedirs(video_dir, exist_ok=True)
        imageio.mimsave(
            os.path.join(video_dir, f"episode_{episode}.gif"),
            [np.array(img) for img in images],
            fps=30,
        )


def create_mario_env():
    """Build the Super Mario Bros world 1-1 environment used for training.

    Wrapper order, innermost first:
    1. the raw game;
    2. JoypadSpace restricted to SIMPLE_MOVEMENT;
    3. SkipFrames;
    4. ImagePreprocessing.

    Putting frame skipping inside preprocessing means only the frame returned
    to the agent is converted, not every skipped frame. The observation and
    reward the agent receives are the same either way.
    """
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
    env = JoypadSpace(env, SIMPLE_MOVEMENT)
    env = SkipFrames(env)
    env = ImagePreprocessing(env)
    return env


def create_cartpole_env():
    """Return a freshly constructed CartPole-v1 environment from gym."""
    return gym.make("CartPole-v1")


def plot_rewards(rewards, title="REINFORCE - CartPole"):
    """Plot per-episode reward curves with plotly, one line per run.

    Args:
        rewards: iterable of ``(episode_rewards, label)`` pairs.
            ``episode_rewards`` is a sequence of per-episode totals;
            ``label`` (e.g. the learning rate) names the trace in the legend.
        title: figure title.
    """
    fig = go.Figure()

    for results, lr in rewards:
        results = list(results)
        # Episodes are numbered from 1. The x axis comes from this run's own
        # length, so runs of any length plot correctly.
        episodes = list(range(1, len(results) + 1))
        fig.add_trace(
            go.Scatter(
                x=episodes,
                y=results,
                name=f"{lr}",
            )
        )

    fig.update_layout(
        title=title,
        xaxis_title="Episodes",
        yaxis_title="Reward",
        legend_title="Legend",
        font=dict(family="Courier New, monospace", size=18, color="RebeccaPurple"),
    )

    fig.show()


# To train on Super Mario Bros instead of CartPole:
# env = create_mario_env()
# agent = REINFORCE(env, "mario", learning_rate=LEARNING_RATE)

env = create_cartpole_env()

# Compare three Adam step sizes. Each run gets a freshly initialised agent;
# the runs execute one after another on the same CartPole env instance.
lr_high = 0.003
lr_med = 0.002
lr_low = 0.001

rewards_lr_high, rewards_lr_med, rewards_lr_low = (
    REINFORCE(env, "atari", learning_rate=lr).train(episodes=NUM_EPISODES)
    for lr in (lr_high, lr_med, lr_low)
)

results = [
    (rewards_lr_high, lr_high),
    (rewards_lr_med, lr_med),
    (rewards_lr_low, lr_low),
]

plot_rewards(results)


Episode 0: Total reward: 15.0
Episode 100: Total reward: 16.0
Episode 200: Total reward: 24.0
Episode 300: Total reward: 482.0
Episode 400: Total reward: 500.0
Episode 500: Total reward: 128.0
Episode 600: Total reward: 500.0
Episode 700: Total reward: 500.0
Episode 800: Total reward: 500.0
Episode 900: Total reward: 500.0
Episode 0: Total reward: 16.0
Episode 100: Total reward: 28.0
Episode 200: Total reward: 49.0
Episode 300: Total reward: 204.0
Episode 400: Total reward: 261.0
Episode 500: Total reward: 161.0
Episode 600: Total reward: 227.0
Episode 700: Total reward: 500.0
Episode 800: Total reward: 500.0
Episode 900: Total reward: 500.0
Episode 0: Total reward: 10.0
Episode 100: Total reward: 27.0
Episode 200: Total reward: 28.0
Episode 300: Total reward: 26.0
Episode 400: Total reward: 53.0
Episode 500: Total reward: 204.0
Episode 600: Total reward: 91.0
Episode 700: Total reward: 106.0
Episode 800: Total reward: 429.0
Episode 900: Total reward: 421.0
