In [1]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
%reload_ext autoreload
%autoreload 

In [2]:
# Train on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

In [4]:
from nes_py.wrappers import JoypadSpace

In [5]:
import gym_super_mario_bros
import gym_super_mario_bros.actions as JoypadActions

from lib.env_wrappers import EnvWrapperFactory

# Size EnvWrapperFactory resizes frames to, and how many discrete actions
# (entries of SIMPLE_MOVEMENT) the agent chooses between.
imageShape = (50,) * 2
actionShape = len(JoypadActions.SIMPLE_MOVEMENT)

In [6]:
%reload_ext autoreload
%autoreload 
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=imageShape)
state = env.reset()
print(state.shape)

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (50, 50)
shape after all transformations: (5, 50, 50)
(5, 50, 50)


In [7]:
from agents.ForgetfulAgent import ForgetfulAgent
from lib.MetricLogger import MetricLogger
# Reload the autoreload extension and trigger an immediate reload of imported modules
# (bare %autoreload), presumably so source edits in agents/ and lib/ take effect.
%reload_ext autoreload
%autoreload 

In [8]:
# The agent's network is sized from the advertised observation space (next cell),
# so check that it agrees with the observation the wrapped env actually returned.
assert env.observation_space.shape == state.shape, (env.observation_space.shape, state.shape)
env.observation_space.shape

(5, 50, 50)

In [9]:
%reload_ext autoreload
%autoreload 
agent = ForgetfulAgent(state_shape=env.observation_space.shape, action_shape=actionShape, device=device)
save_dir = Path("logs") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
logger = MetricLogger(save_dir)

In [10]:
# Training loop: play `episodes` episodes. Every transition is cached and
# agent.learn() is called after each step. Per-step and per-episode metrics go to
# `logger`, and a summary is recorded every `recordEvery` episodes.

maxStepsPerEpisode = 100_000_000  # effectively "until the episode ends"
learnCount = 0                    # steps on which agent.learn() returned a q value
episodes = 500
recordEvery = 1                   # episodes between logger.record() summaries
printLearnEvery = 1000            # learn steps between q/loss printouts

for eps in range(episodes):
    state = env.reset()
    print(f"starting episode: {eps}")
    for _step in range(maxStepsPerEpisode):
        action = agent.getAction(state)
        next_state, reward, done, info = env.step(action)

        # Add the transition to memory, then give the agent a chance to learn.
        # NOTE(review): when the loop ends the episode on flag_get with done=False,
        # this transition is cached as non-terminal. Confirm that is intended.
        agent.cache(state, next_state, action, reward, done)
        q, loss = agent.learn()

        logger.log_step(reward, loss, q)
        state = next_state

        # q appears to be None on steps where the agent did not update.
        if q is not None:
            learnCount += 1
            if learnCount % printLearnEvery == 0:
                print(f"q:{q}, loss={loss}")

        if info["flag_get"]:
            print("reached a flag")
            print(info)

        if done or info["flag_get"]:
            break

    logger.log_episode()
    if eps % recordEvery == 0:
        logger.record(episode=eps, epsilon=agent.exploration_rate, step=agent.current_step)
        print(info)
        

starting episode: 0
Episode 0 - Step 1817 - Epsilon 0.9990919123343348 - Mean Reward 1808.0 - Mean Length 1817.0 - Mean Loss 1.82 - Mean Q Value 0.27 - Time Delta 17.049 - Time 2022-03-30T16:27:57
{'coins': 2, 'flag_get': False, 'life': 255, 'score': 600, 'stage': 1, 'status': 'small', 'time': 221, 'world': 1, 'x_pos': 832, 'y_pos': 84}
starting episode: 1
Episode 1 - Step 2423 - Epsilon 0.9987892332674471 - Mean Reward 1494.5 - Mean Length 1211.5 - Mean Loss 1.612 - Mean Q Value 0.563 - Time Delta 6.13 - Time 2022-03-30T16:28:03
{'coins': 2, 'flag_get': False, 'life': 255, 'score': 700, 'stage': 1, 'status': 'small', 'time': 299, 'world': 1, 'x_pos': 898, 'y_pos': 79}
starting episode: 2
Episode 2 - Step 2569 - Epsilon 0.9987163242964078 - Mean Reward 1214.667 - Mean Length 856.333 - Mean Loss 1.495 - Mean Q Value 0.797 - Time Delta 1.682 - Time 2022-03-30T16:28:05
{'coins': 1, 'flag_get': False, 'life': 255, 'score': 200, 'stage': 1, 'status': 'small', 'time': 389, 'world': 1, 'x_pos

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

In [None]:

# Watch the current agent play a few episodes in a separate, rendered env.
# A dedicated `eval_env` is used so the training `env` is neither overwritten
# (leaking its emulator) nor left closed. The try/finally closes the render
# window even if the cell is interrupted.
evalEpisodes = 2
maxEvalStepsPerEpisode = 100000

eval_env = gym_super_mario_bros.make('SuperMarioBros-v0')
eval_env = JoypadSpace(eval_env, JoypadActions.SIMPLE_MOVEMENT)
eval_env = EnvWrapperFactory.convert(eval_env, shape=imageShape)

try:
    for _episode in range(evalEpisodes):
        eval_state = eval_env.reset()
        for _step in range(maxEvalStepsPerEpisode):
            eval_state, reward, done, info = eval_env.step(agent.getAction(eval_state))
            eval_env.render()
            if done:
                break
finally:
    eval_env.close()

In [None]:
def saveCheckpoint(name, epoch, model, optimizer, exploration_rate=None, directory="."):
    """Save a training checkpoint to ``<directory>/<name>-checkpoint-<epoch>.pytorch``.

    The file is a ``torch.save``'d dict with the keys ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict`` and ``exploration_rate``.

    Args:
        name: prefix for the checkpoint file name.
        epoch: value stored under ``"epoch"`` and embedded in the file name.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        exploration_rate: epsilon to store. When omitted (None), falls back to the
            notebook-global ``agent.exploration_rate``, which was the original behaviour.
        directory: folder to write into; defaults to the current working directory.

    Returns:
        The ``Path`` of the written checkpoint file.
    """
    if exploration_rate is None:
        # Backward-compatible fallback: earlier callers relied on the global `agent`.
        exploration_rate = agent.exploration_rate
    path = Path(directory) / f"{name}-checkpoint-{epoch}.pytorch"
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "exploration_rate": exploration_rate,
        },
        path,
    )
    return path

In [None]:
# Checkpoint agent.net and its optimizer. The label is the number of training
# episodes started (`eps` from the training loop, 0-based) rather than a
# hard-coded 500, so an interrupted run is not mislabelled as a completed one.
# A full run still produces "...-checkpoint-500".
saveCheckpoint("ForgetfulAgent-CNN50x50", eps + 1, agent.net, agent.optimizer)

In [None]:
# Also persist through ForgetfulAgent's own save() helper.
# NOTE(review): epoch is hard-coded to 501 regardless of how many episodes actually ran. Confirm.
agent.save(epoch=501, dir="")