In [1]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
%reload_ext autoreload
%autoreload 

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

In [4]:
from nes_py.wrappers import JoypadSpace

In [5]:
import gym_super_mario_bros
import gym_super_mario_bros.actions as JoypadActions

from lib.env_wrappers import EnvWrapperFactory

stateShape = (50, 50)
actionShape = len(JoypadActions.SIMPLE_MOVEMENT)

In [6]:
%reload_ext autoreload
%autoreload 
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=stateShape)
state = env.reset()
print(state.shape)

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (50, 50)
shape after all transformations: (5, 50, 50)
(5, 50, 50)


In [7]:
from agents.ForgetfulAgent import ForgetfulAgent
from lib.MetricLogger import MetricLogger
%reload_ext autoreload
%autoreload 

In [8]:
env.observation_space.shape

(5, 50, 50)

In [10]:
%reload_ext autoreload
%autoreload 
agent = ForgetfulAgent(state_shape=env.observation_space.shape, action_shape=actionShape, device=device)
save_dir = Path("logs") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
logger = MetricLogger(save_dir)

In [12]:
# try training

maxStepsPerEpisode = 10000
learnCount = 0
episodes = 100

for eps in range(episodes):
    state = env.reset()
    print(f"starting episode: {eps}")
    for i in range(maxStepsPerEpisode):

        action = agent.getAction(state)
        next_state, reward, done, info = env.step(env.action_space.sample())

        # add to memory
        agent.cache(state, next_state, action, reward, done)

        q, loss = agent.learn()


        if q is not None:
            learnCount += 1
            if learnCount % 1000 == 0:
                print(f"q:{q}, loss={loss}")
        
        if done or info["flag_get"]:
            break
    logger.log_episode()
    if eps % 10 == 0:
        logger.record(episode=eps, epsilon=agent.exploration_rate, step=agent.current_step)
        

starting episode: 0
q:0.033822257071733475, loss=0.5003948211669922
q:0.9758144617080688, loss=0.7996938228607178
q:0.5458636283874512, loss=0.49854713678359985
q:0.04848060756921768, loss=0.506773054599762
q:-0.07975272834300995, loss=0.4988541603088379
q:0.023601992055773735, loss=0.43127378821372986
q:0.015494290739297867, loss=0.2425265908241272
q:0.05393262952566147, loss=0.45383840799331665
q:0.048948850482702255, loss=0.28938645124435425
q:0.10607253015041351, loss=0.5580621957778931
q:0.11869934946298599, loss=0.3423507809638977
q:0.13250091671943665, loss=0.395752489566803
q:0.22253137826919556, loss=0.39205414056777954
q:0.21984729170799255, loss=0.38969722390174866
q:0.28037193417549133, loss=0.48115724325180054
q:0.2815655469894409, loss=0.5894237756729126
q:0.31347864866256714, loss=0.37424594163894653
q:0.3499940037727356, loss=0.23588182032108307
q:0.35263246297836304, loss=0.6450822353363037
q:0.35153335332870483, loss=0.6975681781768799
q:0.39159131050109863, loss=0.38

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

In [None]:
states, next_states, actions, rewards, dones = agent.sampleExperienceBatch()

In [None]:
states.shape

In [None]:
actions.shape

In [None]:
rewards.shape

In [None]:
dones.shape

In [None]:
print(rewards)