In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from gym.wrappers import FrameStack
from torchvision import transforms
import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
from nes_py.wrappers import JoypadSpace
from torch.distributions import Categorical
from gym.spaces import Box
import matplotlib.pyplot as plt
import gym

In [2]:
class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = False
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info


class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(low=0, high=255, shape=self.observation_space.shape[:2], dtype=np.uint8)

    def observation(self, observation):
        transform = transforms.Grayscale()
        return transform(torch.tensor(np.transpose(observation, (2, 0, 1)).copy(), dtype=torch.float))


class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        self.shape = (shape, shape)
        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transformations = transforms.Compose([transforms.Resize(self.shape), transforms.Normalize(0, 255)])
        return transformations(observation).squeeze(0)

In [3]:
env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
env = JoypadSpace(env, [["right"], ["right", "A"]])
env = FrameStack(ResizeObservation(GrayScaleObservation(SkipFrame(env, skip=4)), shape=84), num_stack=4)
env = gym.wrappers.RecordVideo(env, 'D:/COMP_Topics_i_AI/ProjectB/pg/video', episode_trigger = lambda x: True)
env.seed(42)
env.action_space.seed(42)
torch.manual_seed(42)
torch.random.manual_seed(42)
np.random.seed(42)

  logger.warn(


In [4]:
class MarioSolver:
    def __init__(self, learning_rate):
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.action_space.n),
            nn.Softmax(dim=-1)
        ).cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate, eps=1e-4)
        self.reset()

    def forward(self, x):
        return self.model(x)

    def reset(self):
        self.episode_actions = torch.tensor([], requires_grad=True).cuda()
        self.episode_rewards = []

    def save_checkpoint(self, directory, episode):
        if not os.path.exists(directory):
            os.makedirs(directory)
        filename = os.path.join(directory, 'checkpoint_{}.pth'.format(episode))
        torch.save(self.model.state_dict(), f=filename)
        print('Checkpoint saved to \'{}\''.format(filename))

    def load_checkpoint(self, directory, filename):
        self.model.load_state_dict(torch.load(os.path.join(directory, filename)))
        print('Resuming training from checkpoint \'{}\'.'.format(filename))
        return int(filename[11:-4])

    def backward(self):
        future_reward = 0
        rewards = []
        for r in self.episode_rewards[::-1]:
            future_reward = r + gamma * future_reward
            rewards.append(future_reward)
        rewards = torch.tensor(rewards[::-1], dtype=torch.float32).cuda()
        rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
        loss = torch.sum(torch.mul(self.episode_actions, rewards).mul(-1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.reset()

In [5]:
batch_size = 10
gamma = 0.95
load_filename = 'checkpoint_2000.pth'
save_directory = 'D:/COMP_Topics_i_AI/ProjectB/pg'
batch_rewards = []
episode = 0

In [6]:
model = MarioSolver(learning_rate=0.00025)
if load_filename is not None:
    episode = model.load_checkpoint(save_directory, load_filename)
    print("loaded")
all_episode_rewards = []
all_mean_rewards = []
while True:
    observation = env.reset()
    done = False
    while not done:
        env.render()
        observation = torch.tensor(observation.__array__()).cuda().unsqueeze(0)
        distribution = Categorical(model.forward(observation))
        action = distribution.sample()
        observation, reward, done, _ = env.step(action.item())
        model.episode_actions = torch.cat([model.episode_actions, distribution.log_prob(action).reshape(1)])
        model.episode_rewards.append(reward)
        if done:
            all_episode_rewards.append(np.sum(model.episode_rewards))
            batch_rewards.append(np.sum(model.episode_rewards))
            model.backward()
            episode += 1
            if episode % batch_size == 0:
                print('Batch: {}, average reward: {}'.format(episode // batch_size, np.array(batch_rewards).mean()))
                batch_rewards = []
                all_mean_rewards.append(np.mean(all_episode_rewards[-batch_size:]))
                if episode % 500 == 0:
                    plt.plot(all_mean_rewards)
                    plt.savefig("{}/mean_reward_{}.png".format(save_directory, episode))
                    plt.clf()
            if episode % 1000 == 0 and save_directory is not None:
                model.save_checkpoint(save_directory, episode)

Resuming training from checkpoint 'checkpoint_2000.pth'.
loaded


  return (self.ram[0x86] - self.ram[0x071c]) % 256


Batch: 201, average reward: 626.9
Batch: 202, average reward: 709.2
Batch: 203, average reward: 709.7
Batch: 204, average reward: 737.2
Batch: 205, average reward: 701.0
Batch: 206, average reward: 784.0
Batch: 207, average reward: 633.5
Batch: 208, average reward: 667.9
Batch: 209, average reward: 663.3
Batch: 210, average reward: 594.5
Batch: 211, average reward: 698.2
Batch: 212, average reward: 741.5
Batch: 213, average reward: 674.8
Batch: 214, average reward: 482.9
Batch: 215, average reward: 765.4
Batch: 216, average reward: 488.7
Batch: 217, average reward: 785.0
Batch: 218, average reward: 748.4
Batch: 219, average reward: 704.1
Batch: 220, average reward: 669.9
Batch: 221, average reward: 864.0
Batch: 222, average reward: 580.9
Batch: 223, average reward: 848.4
Batch: 224, average reward: 997.9
Batch: 225, average reward: 715.2
Batch: 226, average reward: 641.5
Batch: 227, average reward: 764.8
Batch: 228, average reward: 688.9
Batch: 229, average reward: 762.9
Batch: 230, av

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>