In [1]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [2]:
%reload_ext autoreload
%autoreload 2
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
import gym_super_mario_bros.actions as JoypadActions

from lib.env_wrappers import EnvWrapperFactory
from agents.ForgetfulAgent import ForgetfulAgent
from lib.MetricLogger import MetricLogger


In [3]:
# Target 2-D frame size; passed to EnvWrapperFactory.convert as `shape`.
imageShape = (84, 84)
# Number of entries in the SIMPLE_MOVEMENT action list; passed to the agent as `action_shape`.
actionShape = len(JoypadActions.SIMPLE_MOVEMENT)

In [13]:
# Build the training environment: base Super Mario Bros env, restricted to the
# SIMPLE_MOVEMENT action list, then wrapped by the project's preprocessing
# factory (lib.env_wrappers) with the configured frame size.
# The autoreload magics are not repeated here: they are already enabled once in
# the import cell and stay active for the whole kernel session.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=imageShape)

# Reset once to obtain an initial observation and check its shape.
state = env.reset()
print(state.shape)

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (84, 84)
shape after all transformations: (5, 84, 84)
(5, 84, 84)


In [14]:
# Shape advertised by the wrapped env's observation space (compare with state.shape printed above).
env.observation_space.shape

(5, 84, 84)

In [15]:
# Create the agent for the wrapped env's observation shape and the configured
# action count. The autoreload magics are not repeated here: they are already
# enabled once in the import cell.
agent = ForgetfulAgent(state_shape=env.observation_space.shape, action_shape=actionShape, device=device, net_name="CNN84x84")

# One timestamped log directory per run, handed to the metric logger.
save_dir = Path("logs") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# exist_ok=True keeps this cell re-runnable: the timestamp has one-second
# resolution, so a quick second execution would otherwise raise FileExistsError.
save_dir.mkdir(parents=True, exist_ok=True)
logger = MetricLogger(save_dir)
print(agent.name)

ForgetfulAgent-CNN84x84


In [17]:
# Training run: play `episodes` games, calling agent.learn() after every env step.

maxStepsPerEpisode = 20000
learnCount = 0
episodes = 20

for eps in range(episodes):
    state = env.reset()
    print(f"starting episode: {eps}")

    for i in range(maxStepsPerEpisode):
        # Act, observe the transition, and store it in the agent's memory.
        action = agent.getAction(state)
        next_state, reward, done, info = env.step(action)
        agent.cache(state, next_state, action, reward, done)

        # One learning call per step; its results go to the step logger.
        q, loss = agent.learn()
        logger.log_step(reward, loss, q)

        # Count the steps where learn() returned a q value (non-None) and
        # print a progress line on every 1000th of them.
        if q is not None:
            learnCount += 1
            if not learnCount % 1000:
                print(f"q:{q}, loss={loss}")

        state = next_state

        # The episode ends when the env reports done, or when the flag is reached.
        if done:
            break
        if info["flag_get"]:
            break

    logger.log_episode()
    # Emit a summary record on every tenth episode (0, 10, ...).
    if not eps % 10:
        logger.record(episode=eps, epsilon=agent.exploration_rate, step=agent.current_step)
        

starting episode: 0
Episode 0 - Step 1878 - Epsilon 0.9990614404880913 - Mean Reward 2212.5 - Mean Length 938.5 - Mean Loss 1.244 - Mean Q Value 0.219 - Time Delta 27.577 - Time 2022-03-30T15:33:19
starting episode: 1
starting episode: 2
starting episode: 3
starting episode: 4
starting episode: 5
starting episode: 6
q:2.087984800338745, loss=1.0603065490722656
starting episode: 7
starting episode: 8
starting episode: 9
starting episode: 10
Episode 10 - Step 10963 - Epsilon 0.9945334946459722 - Mean Reward 1793.917 - Mean Length 913.5 - Mean Loss 1.583 - Mean Q Value 1.715 - Time Delta 88.265 - Time 2022-03-30T15:34:47
starting episode: 11
q:2.4226925373077393, loss=0.9642242193222046
starting episode: 12
starting episode: 13
starting episode: 14
starting episode: 15
q:3.2718420028686523, loss=0.9225887060165405
starting episode: 16
starting episode: 17
starting episode: 18
starting episode: 19


<Figure size 432x288 with 0 Axes>

In [18]:

# Watch the agent play: build a fresh environment with the same wrappers as the
# training env and render 1000 steps of play.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=imageShape)

try:
    # Start with done=True so the first iteration resets the env; later resets
    # happen whenever an episode finishes inside the 1000-step window.
    done = True
    for step in range(1000):
        if done:
            state = env.reset()
        state, reward, done, info = env.step(agent.getAction(state))
        env.render()
finally:
    # Always release the environment, even if a step or render call raises
    # or the cell is interrupted.
    env.close()

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (84, 84)
shape after all transformations: (5, 84, 84)


In [22]:
def saveCheckpoint(name, epoch, model, optimizer, exploration_rate=None):
    """Write a training checkpoint to ``{name}-checkpoint-{epoch}.pytorch``.

    Args:
        name: Prefix of the checkpoint file name (may include a directory part).
        epoch: Counter stored in the checkpoint and embedded in the file name.
        model: Object whose ``state_dict()`` is stored under ``"model_state_dict"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"optimizer_state_dict"``.
        exploration_rate: Value stored under ``"exploration_rate"``. When omitted,
            the notebook-level ``agent.exploration_rate`` is used, which is what
            existing four-argument calls have always stored.

    Returns:
        The path of the file that was written, as a string.
    """
    if exploration_rate is None:
        # Backward-compatible fallback: read the value from the global agent.
        exploration_rate = agent.exploration_rate

    checkpoint_path = f"{name}-checkpoint-{epoch}.pytorch"
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "exploration_rate": exploration_rate,
        },
        checkpoint_path,
    )
    return checkpoint_path

In [23]:
# Checkpoint the agent after the training run. The epoch is taken from `episodes`
# rather than a literal so the stored value and the file name follow the
# configured run length.
# Note: the file name is built from agent.name alone, so the checkpoint is
# written relative to the current working directory, not under save_dir.
saveCheckpoint(agent.name, episodes, agent.net, agent.optimizer)