In [1]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
%reload_ext autoreload
%autoreload 

In [2]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the notebook displays the selected device.
device

'cuda'

In [3]:
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

In [4]:
from nes_py.wrappers import JoypadSpace

In [5]:
import gym_super_mario_bros
import gym_super_mario_bros.actions as JoypadActions

from lib.env_wrappers import EnvWrapperFactory

# Target size each frame is resized to; passed as `shape=` to
# EnvWrapperFactory.convert when the environment is wrapped.
stateShape = (50, 50)
# Number of entries in SIMPLE_MOVEMENT. Despite the name this is an integer
# count (the result of len()), not a shape tuple.
actionShape = len(JoypadActions.SIMPLE_MOVEMENT)

In [6]:
%reload_ext autoreload
%autoreload 
# Build the training environment in one expression: raw emulator -> restricted
# joypad action set -> project preprocessing wrappers sized to `stateShape`.
env = EnvWrapperFactory.convert(
    JoypadSpace(
        gym_super_mario_bros.make('SuperMarioBros-v0'),
        JoypadActions.SIMPLE_MOVEMENT,
    ),
    shape=stateShape,
)

# Reset once and show the shape of the first observation as a sanity check.
state = env.reset()
print(state.shape)

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (50, 50)
shape after all transformations: (5, 50, 50)
(5, 50, 50)


In [7]:
from agents.ForgetfulAgent import ForgetfulAgent
from lib.MetricLogger import MetricLogger
%reload_ext autoreload
%autoreload 

In [8]:
# Sanity check: display the wrapped environment's observation shape, which is
# the value handed to the agent as `state_shape` when it is constructed.
env.observation_space.shape

(5, 50, 50)

In [9]:
%reload_ext autoreload
%autoreload 
# Size the agent from the wrapped environment's observation shape and the
# number of available joypad actions.
agent = ForgetfulAgent(
    state_shape=env.observation_space.shape,
    action_shape=actionShape,
    device=device,
)

# Every run logs into its own timestamped (second resolution) folder under logs/.
save_dir = Path("logs", datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
save_dir.mkdir(parents=True)
logger = MetricLogger(save_dir)

In [11]:
# Training loop: on every environment step the agent acts, the transition is
# cached, one learn() call is made, and the step metrics are logged.

maxStepsPerEpisode = 100000  # hard cap on environment steps within one episode
learnCount = 0               # number of learn() calls that returned a non-None q
episodes = 1000              # total number of episodes to run

for eps in range(episodes):
    state = env.reset()
    print(f"starting episode: {eps}")
    for i in range(maxStepsPerEpisode):

        # Pick an action for the current observation and advance the env one step.
        action = agent.getAction(state)
        next_state, reward, done, info = env.step(action)

        # add to memory
        agent.cache(state, next_state, action, reward, done)

        # learn() is called on every step; q can come back as None (see the
        # check below) -- presumably on steps where no update was performed;
        # confirm against ForgetfulAgent.learn.
        q, loss = agent.learn()
        
        logger.log_step(reward, loss, q)
        
        state = next_state

        # Print a progress sample on every 1000th learn() call that produced a q value.
        if q is not None:
            learnCount += 1
            if learnCount % 1000 == 0:
                print(f"q:{q}, loss={loss}")
        
        # End the episode when the env reports `done` or info["flag_get"] is truthy.
        if done or info["flag_get"]:
            break
            
    print(f"done: {done},\n info: {info}")
    logger.log_episode()
    # Record aggregated metrics on every 10th episode, starting with episode 0.
    if eps % 10 == 0:
        logger.record(episode=eps, epsilon=agent.exploration_rate, step=agent.current_step)
        

starting episode: 0
q:0.5180431604385376, loss=0.11596138030290604
q:1.4825770854949951, loss=0.06576018035411835
q:0.8232667446136475, loss=0.09756126999855042
q:0.2698393762111664, loss=0.0666079893708229
done: True,
 info: {'coins': 0, 'flag_get': False, 'life': 255, 'score': 300, 'stage': 1, 'status': 'small', 'time': 0, 'world': 1, 'x_pos': 594, 'y_pos': 79}
Episode 0 - Step 52385 - Epsilon 0.9741475417787012 - Mean Reward 759.0 - Mean Length 26192.0 - Mean Loss 0.087 - Mean Q Value 0.332 - Time Delta 165.256 - Time 2022-03-28T19:20:47
starting episode: 1
q:1.246760606765747, loss=0.10371768474578857
q:0.8174723982810974, loss=0.07346370816230774
q:0.39501258730888367, loss=0.06678405404090881
q:0.2810438275337219, loss=0.13081733882427216
done: True,
 info: {'coins': 1, 'flag_get': False, 'life': 255, 'score': 200, 'stage': 1, 'status': 'small', 'time': 0, 'world': 1, 'x_pos': 594, 'y_pos': 79}
starting episode: 2
q:0.3032544255256653, loss=0.10625596344470978
q:0.527372658252716

<Figure size 432x288 with 0 Axes>

In [16]:

# Watch the agent play: rebuild the environment with the same wrapper stack
# used for training and render 1000 steps driven by agent.getAction.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
# Use `stateShape`, the resize target the training environment (and therefore
# the agent's state_shape) was built with. `imageShape` is not defined anywhere
# in this notebook and raises NameError on a fresh kernel.
env = EnvWrapperFactory.convert(env, shape=stateShape)

try:
    # Start with done=True so the first iteration resets the environment.
    done = True
    for step in range(1000):
        if done:
            state = env.reset()
        # NOTE(review): getAction is the same call used during training, so it
        # may still explore randomly -- confirm in ForgetfulAgent.
        state, reward, done, info = env.step(agent.getAction(state))
        env.render()
finally:
    # Always release the emulator/render window, even if a step or render fails.
    env.close()

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (50, 50)
shape after all transformations: (5, 50, 50)


In [19]:
def saveCheckpoint(name, epoch, model, optimizer, exploration_rate=None):
    """Write a training checkpoint to the file ``f"{name}-checkpoint-{epoch}"``.

    The checkpoint is a dict saved with ``torch.save`` containing the keys
    ``"epoch"``, ``"model_state_dict"``, ``"optimizer_state_dict"`` and
    ``"exploration_rate"``.

    Args:
        name: Prefix of the output file name (may include a directory part).
        epoch: Value stored under ``"epoch"`` and appended to the file name.
        model: Object whose ``state_dict()`` is stored as the model weights.
        optimizer: Object whose ``state_dict()`` is stored as optimizer state.
        exploration_rate: Value stored under ``"exploration_rate"``. When left
            as ``None`` it is read from the notebook-level ``agent``, which
            keeps existing four-argument calls working unchanged.

    Raises:
        NameError: If ``exploration_rate`` is omitted and no global ``agent``
            exists in the kernel.
    """
    if exploration_rate is None:
        # Backward-compatible fallback: rely on the notebook's global agent.
        exploration_rate = agent.exploration_rate

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "exploration_rate": exploration_rate,
    }
    torch.save(checkpoint, f"{name}-checkpoint-{epoch}")

In [18]:
# Earlier variant, kept for reference only: saved just the network weights and
# the exploration rate into the run's log directory.
# save_path = (
#             save_dir / f"mario_net_{int(agent.current_step // agent.onlinePeriod)}.chkpt"
#         )
# torch.save(
#     dict(model=agent.net.state_dict(), exploration_rate=agent.exploration_rate),
#     save_path,
# )
# Current variant: also stores the optimizer state. The file name is relative,
# so the checkpoint lands in the current working directory, not in save_dir.
# NOTE(review): the epoch is hard-coded to 1000 (presumably mirroring
# `episodes`); confirm it reflects the number of episodes actually trained.
saveCheckpoint("ForgetfulAgent-CNN50x50", 1000, agent.net, agent.optimizer)