In [None]:
import gym
import torch
import numpy as np
import gym_super_mario_bros
import random, datetime, os, copy

from torch import nn
from torchvision import transforms as T
from PIL import Image
from pathlib import Path
from collections import deque
from gym.spaces import Box
from gym.wrappers import FrameStack, Monitor
from gym_wrappers import SkipFrame, GrayScaleObs, ResizeObs
from nes_py.wrappers import JoypadSpace
from actor import Mario
from model import DDQN
from logger import MetricLogger

In [None]:
# Super Mario Bros level 1-1, restricted to two joypad actions: "right" and "right + A".
env = JoypadSpace(
    gym_super_mario_bros.make("SuperMarioBros-1-1-v3"),
    [["right"], ["right", "A"]],
)

# Observation preprocessing, applied innermost-first:
# SkipFrame(skip=4) -> GrayScaleObs -> ResizeObs(shape=84) -> FrameStack(num_stack=4).
env = FrameStack(
    ResizeObs(
        GrayScaleObs(SkipFrame(env, skip=4)),
        shape=84,
    ),
    num_stack=4,
)

In [None]:
# Record episodes/stats of the preprocessed env under ./gym-results; force=True lets a new
# run overwrite monitor output left in that directory by a previous run.
# NOTE(review): every later cell that steps `env` (evaluation and training) goes through this Monitor.
env = Monitor(env, directory="./gym-results", force=True)

In [None]:
import shutil

# Copy the selected checkpoint to ./model.chkpt.
# NOTE(review): presumably consumed by another script (e.g. train.py) — confirm.
# Unlike `!cp`, which only prints an error and lets the notebook continue, shutil.copy raises
# FileNotFoundError when the source checkpoint is missing, and it also works on non-POSIX systems.
shutil.copy("checkpoints/2022-02-22T08-25-58/mario_net_18.chkpt", "model.chkpt")

In [None]:
# Evaluate a saved checkpoint: load its weights into a DDQN and play a few greedy episodes,
# rendering each step (and recording, since `env` is Monitor-wrapped when the cells run in order).
CHECKPOINT_PATH = "checkpoints/2022-02-22T08-25-58/mario_net_18.chkpt"
N_EVAL_EPISODES = 10
# Prefer the GPU as before, but fall back to CPU so the cell also runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# map_location lets a checkpoint saved on GPU load onto whichever device is available.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
state_dict = checkpoint['model']
expl_rate = checkpoint['exploration_rate']

model = DDQN((4, 84, 84), env.action_space.n)
model.load_state_dict(state_dict)
# The loop below picks actions by argmax directly; this attribute is not consulted in this cell.
model.exploration_rate = expl_rate
model = model.to(device=device)
# Inference mode (relevant if DDQN contains dropout / batch-norm layers).
model.eval()

for e in range(N_EVAL_EPISODES):
    state = env.reset()

    # Play the game!
    while True:
        # Add a batch dimension and move the stacked frames to the model's device.
        obs = torch.tensor(state.__array__()).unsqueeze(0).to(device)
        # Action selection only: no autograd graph is needed.
        with torch.no_grad():
            action = torch.argmax(model(obs)).item()

        next_state, reward, done, info = env.step(action)
        state = next_state
        env.render()

        # End the episode on a terminal state or when the flag is reached.
        if done or info["flag_get"]:
            break

# NOTE(review): the training cell below reuses `env`; after close() it is likely unusable,
# so rebuild it (re-run the environment cells) before training in the same kernel.
env.close()

In [None]:
# Create a new DDQN (no checkpoint weights loaded) for training; this replaces any
# checkpoint-loaded `model` left over from the evaluation cell.
# (4, 84, 84) matches FrameStack(num_stack=4) + ResizeObs(shape=84) in the env setup;
# presumably (stacked frames, height, width) — confirm against DDQN's constructor.
# NOTE(review): no device placement here; whether Mario/DDQN moves it to GPU is not visible in this file.
model = DDQN((4, 84, 84), env.action_space.n)

In [None]:
# Train the agent, writing checkpoints and metrics into a fresh timestamped directory.
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

mario = Mario(model, env.action_space.n, save_dir)
logger = MetricLogger(save_dir)

episodes = 80_000
for episode in range(episodes):
    state = env.reset()
    episode_over = False

    # One environment step per iteration: act -> step -> remember -> learn -> log.
    while not episode_over:
        action = mario.act(state)
        next_state, reward, done, info = env.step(action)

        mario.cache(state, next_state, action, reward, done)
        q, loss = mario.learn()
        logger.log_step(reward, loss, q)

        state = next_state
        # Stop on a terminal state or on reaching the flag (flag_get only checked when not done).
        episode_over = done or info["flag_get"]

    logger.log_episode()

    # Summarise on episodes 0, 20, 40, ...; the logger receives a 1-based episode number.
    if episode % 20 == 0:
        logger.record(
            episode=episode + 1,
            epsilon=mario.exploration_rate,
            step=mario.current_step,
        )

In [None]:
import time

# Block the kernel for 3 hours (180 min = 10800 s) before the next cell launches train.py.
# NOTE(review): a fixed sleep is a crude scheduler — presumably waiting for another job to finish; confirm.
time.sleep(3 * 60 * 60)

In [None]:
!python3 train.py