In [1]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
%reload_ext autoreload
%autoreload 

In [2]:
# Seed the global Python / NumPy / torch RNGs so repeated runs start from the same state.
# NOTE(review): the NES environment itself is not seeded here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train on the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

In [4]:
from nes_py.wrappers import JoypadSpace

In [5]:
import gym_super_mario_bros
import gym_super_mario_bros.actions as JoypadActions

from lib.env_wrappers import EnvWrapperFactory

# Number of discrete actions the agent picks from (a count, despite the name).
actionShape = len(JoypadActions.SIMPLE_MOVEMENT)
# Size each frame is resized to by EnvWrapperFactory.convert.
stateShape = (50, 50)

In [6]:
%reload_ext autoreload
%autoreload 
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=stateShape)
state = env.reset()
print(state.shape)

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (50, 50)
shape after all transformations: (5, 50, 50)
(5, 50, 50)


In [7]:
from agents.ForgetfulAgent import ForgetfulAgent
from lib.MetricLogger import MetricLogger
# Reload the autoreload extension and trigger a reload of changed modules
# (per IPython docs, a bare `%autoreload` reloads once, immediately).
%reload_ext autoreload
%autoreload 

In [8]:
# Observation shape reported by the wrapped env; passed to the agent as `state_shape` below.
env.observation_space.shape

(5, 50, 50)

In [9]:
%reload_ext autoreload
%autoreload 
agent = ForgetfulAgent(state_shape=env.observation_space.shape, action_shape=actionShape, device=device)
save_dir = Path("logs") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
logger = MetricLogger(save_dir)

In [10]:
# Train the agent: act, cache the transition, (maybe) learn, and log every step.
# A summary is recorded every RECORD_EVERY episodes and always after the final one.

maxStepsPerEpisode = 100000
learnCount = 0
episodes = 50
RECORD_EVERY = 10         # episodes between logger.record summaries
PRINT_EVERY_LEARN = 1000  # learning updates between q/loss prints

for eps in range(episodes):
    state = env.reset()
    print(f"starting episode: {eps}")
    for _ in range(maxStepsPerEpisode):

        action = agent.getAction(state)
        next_state, reward, done, info = env.step(action)

        # add the transition to the agent's memory
        agent.cache(state, next_state, action, reward, done)

        # q is None on steps where learn() did not report an update (presumably no
        # update happened yet) — only non-None results are counted below.
        q, loss = agent.learn()

        logger.log_step(reward, loss, q)

        state = next_state

        if q is not None:
            learnCount += 1
            if learnCount % PRINT_EVERY_LEARN == 0:
                print(f"q:{q}, loss={loss}")

        # Stop when the env reports done or Mario reaches the flag.
        if done or info["flag_get"]:
            break

    print(f"done: {done},\n info: {info}")
    logger.log_episode()
    # Also record after the last episode; otherwise the trailing episodes
    # (e.g. 41-49 for 50 episodes) would never be summarised.
    if eps % RECORD_EVERY == 0 or eps == episodes - 1:
        logger.record(episode=eps, epsilon=agent.exploration_rate, step=agent.current_step)
        

starting episode: 0
q:0.03959231823682785, loss=0.1873483657836914
done: True,
 info: {'coins': 0, 'flag_get': False, 'life': 255, 'score': 100, 'stage': 1, 'status': 'small', 'time': 392, 'world': 1, 'x_pos': 298, 'y_pos': 79}
Episode 0 - Step 8365 - Epsilon 0.9197527286676134 - Mean Reward 0.0 - Mean Length 0.0 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 36.249 - Time 2022-03-28T13:54:49
starting episode: 1
q:0.02794359251856804, loss=0.0770876407623291
q:0.026085305958986282, loss=0.11122407019138336
done: True,
 info: {'coins': 1, 'flag_get': False, 'life': 255, 'score': 300, 'stage': 1, 'status': 'small', 'time': 393, 'world': 1, 'x_pos': 302, 'y_pos': 79}
starting episode: 2
q:0.039665818214416504, loss=0.062245652079582214
done: True,
 info: {'coins': 0, 'flag_get': False, 'life': 255, 'score': 200, 'stage': 1, 'status': 'small', 'time': 393, 'world': 1, 'x_pos': 307, 'y_pos': 79}
starting episode: 3
q:0.0967075526714325, loss=0.1661219298839569
q:0.054130829870700836, loss=

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

In [11]:

def watch_agent(agent, n_steps=1000):
    """Play ``n_steps`` frames with ``agent`` in a fresh, rendered Mario environment.

    A dedicated environment is created and always closed, so the global training
    ``env`` is left untouched and the training cell can still be re-run afterwards.

    NOTE(review): ``agent.getAction`` may still explore or advance the agent's
    internal counters — confirm in ForgetfulAgent if a purely greedy demo is wanted.
    """
    demo_env = gym_super_mario_bros.make('SuperMarioBros-v0')
    demo_env = JoypadSpace(demo_env, JoypadActions.SIMPLE_MOVEMENT)
    demo_env = EnvWrapperFactory.convert(demo_env, shape=stateShape)
    try:
        done = True  # forces a reset on the first iteration
        for _ in range(n_steps):
            if done:
                state = demo_env.reset()
            state, reward, done, info = demo_env.step(agent.getAction(state))
            demo_env.render()
    finally:
        # Close the render window even if the loop is interrupted.
        demo_env.close()


watch_agent(agent)

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (50, 50)
shape after all transformations: (5, 50, 50)




In [None]:
def saveCheckpoint(epoch, model, optimizer, exploration_rate=None, save_dir=None, prefix="CNN50x50"):
    """Save model and optimizer state plus the exploration rate to a checkpoint file.

    The file is named ``f"{prefix}-checkpoint-{epoch}"``.

    Args:
        epoch: Counter stored under ``"epoch"`` and embedded in the file name.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        exploration_rate: Value stored under ``"exploration_rate"``. Defaults to the
            notebook-global ``agent.exploration_rate`` (the original behaviour).
        save_dir: Directory to write into; defaults to the current working directory.
        prefix: File-name prefix; the default matches the 50x50 CNN setup.

    Returns:
        ``pathlib.Path`` of the written checkpoint.
    """
    if exploration_rate is None:
        # Backward-compatible fallback: earlier callers relied on the global `agent`.
        exploration_rate = agent.exploration_rate
    path = Path(save_dir if save_dir is not None else ".") / f"{prefix}-checkpoint-{epoch}"
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "exploration_rate": exploration_rate,
        },
        path,
    )
    return path

In [None]:
# Persist the current network/optimizer state.
# NOTE(review): 1000 is a hand-picked label, not the number of episodes trained —
# consider deriving it from agent.current_step instead.
saveCheckpoint(epoch=1000, model=agent.net, optimizer=agent.optimizer)