In [21]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# A bare `device` expression here would be discarded (it is not the cell's last
# expression), so report the choice explicitly.
print(f"using device: {device}")

from time import sleep
from tqdm import tqdm

In [22]:
%reload_ext autoreload
%autoreload 2
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
import gym_super_mario_bros.actions as JoypadActions

from lib.env_wrappers import EnvWrapperFactory
from agents.ForgetfulAgent import ForgetfulAgent
from lib.MetricLogger import MetricLogger


In [23]:
# Square (height, width) that observation frames are resized to; passed to
# EnvWrapperFactory.convert when the env is built.
imageShape = (84,) * 2
# Number of discrete actions: one index per entry of SIMPLE_MOVEMENT, the action
# list JoypadSpace is configured with. (Despite the name, this is a count, not a shape.)
actionShape = len(JoypadActions.SIMPLE_MOVEMENT)

In [24]:
# Build the training env: raw NES env -> SIMPLE_MOVEMENT action space -> the
# project's observation wrappers (which log grayscale, resize, and final shapes).
# autoreload is already configured in the imports cell, so it is not re-issued here.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=imageShape)
state = env.reset()
state.shape

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (84, 84)
shape after all transformations: (5, 84, 84)
(5, 84, 84)


In [25]:
# The agent is constructed from the *declared* observation shape, so make sure it
# matches the frames the wrapped env actually returns from reset().
assert env.observation_space.shape == state.shape, (env.observation_space.shape, state.shape)
env.observation_space.shape

(5, 84, 84)

In [26]:
# Create the agent from the env's declared observation shape and action count,
# plus a fresh timestamped log directory for the MetricLogger.
# autoreload is already configured in the imports cell, so it is not re-issued here.
agent = ForgetfulAgent(state_shape=env.observation_space.shape, action_shape=actionShape, device=device, net_name="CNN84x84")
# Timestamp has 1-second resolution; mkdir raises if the directory already exists.
save_dir = Path("logs") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
logger = MetricLogger(save_dir)
agent.name

ForgetfulAgent-CNN84x84


In [27]:
# Train the agent for `episodes` episodes. An episode ends when the env reports
# `done` OR when Mario reaches the flag; at the flag `done` may still be False,
# which is why the flag is checked explicitly.
learnCount = 0
episodes = 10000
epsRewards = []
with tqdm(range(1, episodes + 1)) as tepoch:
    for eps in tepoch:
        tepoch.set_description(f"Episode {eps}")
        state = env.reset()
        epsReward = 0.0
        maxX = 0
        while True:
            action = agent.getAction(state)
            next_state, reward, done, info = env.step(action)
            reachedFlag = info["flag_get"]
            # This loop stops the episode at the flag, so the transition into it is
            # stored as terminal. Presumably the agent uses the cached done flag to
            # decide whether to bootstrap from next_state -- TODO confirm in ForgetfulAgent.
            terminal = done or reachedFlag

            # episode stats
            epsReward += reward
            maxX = max(maxX, info["x_pos"])

            # add to memory
            agent.cache(state, next_state, action, reward, terminal)

            q, loss = agent.learn()
            logger.log_step(reward, loss, q)

            state = next_state

            if reachedFlag:
                # tqdm.write prints above the progress bar instead of splitting it.
                tqdm.write("reached a flag")
                tqdm.write(str(info))

            if terminal:
                break

        logger.log_episode()
        if eps % 100 == 0:
            logger.record(episode=eps, epsilon=agent.exploration_rate, step=agent.current_step)
        if eps % 1000 == 0:
            agent.save(dir="models/", epoch=eps)

        epsRewards.append(epsReward)
        tepoch.set_postfix_str(f"episode reward :{epsReward}, max X: {maxX}")
        

Episode 534:   5%|▌         | 533/10000 [1:29:09<24:08:04,  9.18s/it, episode reward :2926.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 1000, 'stage': 1, 'status': 'small', 'time': 205, 'world': 1, 'x_pos': 3161, 'y_pos': 95}


Episode 666:   7%|▋         | 665/10000 [1:48:00<49:23:33, 19.05s/it, episode reward :2631.0, max X: 3161]

reached a flag
{'coins': 0, 'flag_get': True, 'life': 0, 'score': 1000, 'stage': 1, 'status': 'small', 'time': 63, 'world': 1, 'x_pos': 3161, 'y_pos': 116}


Episode 753:   8%|▊         | 752/10000 [2:01:11<18:21:56,  7.15s/it, episode reward :3584.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 1, 'score': 1400, 'stage': 1, 'status': 'small', 'time': 269, 'world': 1, 'x_pos': 3161, 'y_pos': 117}


Episode 1001:  10%|█         | 1000/10000 [2:38:05<15:43:52,  6.29s/it, episode reward :2440.0, max X: 1235]

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-06-2022-1000.pytorch


Episode 1403:  14%|█▍        | 1402/10000 [3:19:27<15:39:02,  6.55s/it, episode reward :2874.0, max X: 3161]

reached a flag
{'coins': 0, 'flag_get': True, 'life': 2, 'score': 600, 'stage': 1, 'status': 'small', 'time': 153, 'world': 1, 'x_pos': 3161, 'y_pos': 120}


Episode 1446:  14%|█▍        | 1445/10000 [3:22:41<14:34:57,  6.14s/it, episode reward :2953.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 1500, 'stage': 1, 'status': 'small', 'time': 232, 'world': 1, 'x_pos': 3161, 'y_pos': 98}


Episode 1621:  16%|█▌        | 1620/10000 [3:38:34<23:06:16,  9.93s/it, episode reward :2974.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 1400, 'stage': 1, 'status': 'small', 'time': 253, 'world': 1, 'x_pos': 3161, 'y_pos': 103}


Episode 2001:  20%|██        | 2000/10000 [4:36:24<9:26:01,  4.25s/it, episode reward :2759.0, max X: 1240] 

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-06-2022-2000.pytorch


Episode 2047:  20%|██        | 2046/10000 [4:40:10<12:51:14,  5.82s/it, episode reward :3865.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 0, 'score': 900, 'stage': 1, 'status': 'small', 'time': 295, 'world': 1, 'x_pos': 3161, 'y_pos': 112}


Episode 3000:  30%|██▉       | 2999/10000 [47:20:57<7:02:26,  3.62s/it, episode reward :1885.0, max X: 1149]       

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-3000.pytorch


Episode 3879:  39%|███▉      | 3878/10000 [48:26:54<9:58:04,  5.86s/it, episode reward :2967.0, max X: 3161] 

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 1300, 'stage': 1, 'status': 'small', 'time': 246, 'world': 1, 'x_pos': 3161, 'y_pos': 116}


Episode 4001:  40%|████      | 4000/10000 [48:37:33<8:22:54,  5.03s/it, episode reward :1524.0, max X: 1142] 

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-4000.pytorch


Episode 4740:  47%|████▋     | 4739/10000 [49:37:25<12:42:23,  8.69s/it, episode reward :2886.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 400, 'stage': 1, 'status': 'small', 'time': 165, 'world': 1, 'x_pos': 3161, 'y_pos': 106}


Episode 4861:  49%|████▊     | 4860/10000 [49:45:58<8:05:00,  5.66s/it, episode reward :3267.0, max X: 3161] 

reached a flag
{'coins': 1, 'flag_get': True, 'life': 1, 'score': 700, 'stage': 1, 'status': 'small', 'time': 295, 'world': 1, 'x_pos': 3161, 'y_pos': 120}


Episode 4922:  49%|████▉     | 4921/10000 [49:50:54<7:25:16,  5.26s/it, episode reward :3024.0, max X: 3161] 

reached a flag
{'coins': 4, 'flag_get': True, 'life': 2, 'score': 1800, 'stage': 1, 'status': 'small', 'time': 303, 'world': 1, 'x_pos': 3161, 'y_pos': 130}


Episode 5001:  50%|█████     | 5000/10000 [49:57:20<5:45:25,  4.15s/it, episode reward :1573.0, max X: 1516] 

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-5000.pytorch


Episode 5494:  55%|█████▍    | 5493/10000 [50:34:56<8:40:12,  6.93s/it, episode reward :3704.0, max X: 3161] 

reached a flag
{'coins': 0, 'flag_get': True, 'life': 0, 'score': 300, 'stage': 1, 'status': 'small', 'time': 329, 'world': 1, 'x_pos': 3161, 'y_pos': 144}


Episode 5523:  55%|█████▌    | 5522/10000 [50:37:51<9:59:49,  8.04s/it, episode reward :3604.0, max X: 3161] 

reached a flag
{'coins': 2, 'flag_get': True, 'life': 0, 'score': 2000, 'stage': 1, 'status': 'small', 'time': 327, 'world': 1, 'x_pos': 3161, 'y_pos': 137}


Episode 5568:  56%|█████▌    | 5567/10000 [50:41:24<7:10:08,  5.82s/it, episode reward :3882.0, max X: 3161]

reached a flag
{'coins': 0, 'flag_get': True, 'life': 0, 'score': 300, 'stage': 1, 'status': 'small', 'time': 284, 'world': 1, 'x_pos': 3161, 'y_pos': 133}


Episode 5652:  57%|█████▋    | 5651/10000 [50:48:30<6:35:55,  5.46s/it, episode reward :3631.0, max X: 3161]

reached a flag
{'coins': 1, 'flag_get': True, 'life': 1, 'score': 900, 'stage': 1, 'status': 'small', 'time': 299, 'world': 1, 'x_pos': 3161, 'y_pos': 201}


Episode 5847:  58%|█████▊    | 5846/10000 [51:03:11<7:17:06,  6.31s/it, episode reward :4673.0, max X: 3161] 

reached a flag
{'coins': 2, 'flag_get': True, 'life': 0, 'score': 1100, 'stage': 1, 'status': 'small', 'time': 352, 'world': 1, 'x_pos': 3161, 'y_pos': 104}


Episode 6001:  60%|██████    | 6000/10000 [51:15:57<6:00:33,  5.41s/it, episode reward :1875.0, max X: 1133] 

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-6000.pytorch


Episode 6202:  62%|██████▏   | 6201/10000 [51:31:14<6:27:28,  6.12s/it, episode reward :3280.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 1, 'score': 900, 'stage': 1, 'status': 'small', 'time': 319, 'world': 1, 'x_pos': 3161, 'y_pos': 122}


Episode 7001:  70%|███████   | 7000/10000 [52:35:39<3:44:59,  4.50s/it, episode reward :1475.0, max X: 692]  

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-7000.pytorch


Episode 7888:  79%|███████▉  | 7887/10000 [53:43:12<4:54:33,  8.36s/it, episode reward :2865.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 1200, 'stage': 1, 'status': 'small', 'time': 150, 'world': 1, 'x_pos': 3161, 'y_pos': 102}


Episode 8001:  80%|████████  | 8000/10000 [53:52:13<2:27:48,  4.43s/it, episode reward :1835.0, max X: 723] 

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-8000.pytorch


Episode 8316:  83%|████████▎ | 8315/10000 [54:14:44<2:26:35,  5.22s/it, episode reward :3622.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 1, 'score': 700, 'stage': 1, 'status': 'small', 'time': 309, 'world': 1, 'x_pos': 3161, 'y_pos': 122}


Episode 8539:  85%|████████▌ | 8538/10000 [54:30:52<2:37:39,  6.47s/it, episode reward :3842.0, max X: 3161]

reached a flag
{'coins': 1, 'flag_get': True, 'life': 0, 'score': 1600, 'stage': 1, 'status': 'small', 'time': 260, 'world': 1, 'x_pos': 3161, 'y_pos': 133}


Episode 9001:  90%|█████████ | 9000/10000 [55:08:32<1:19:35,  4.78s/it, episode reward :1792.0, max X: 678] 

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-9000.pytorch


Episode 9078:  91%|█████████ | 9077/10000 [55:14:12<1:56:31,  7.57s/it, episode reward :2910.0, max X: 3161]

reached a flag
{'coins': 1, 'flag_get': True, 'life': 2, 'score': 800, 'stage': 1, 'status': 'small', 'time': 189, 'world': 1, 'x_pos': 3161, 'y_pos': 97}


Episode 9348:  93%|█████████▎| 9347/10000 [55:36:56<1:01:32,  5.65s/it, episode reward :3010.0, max X: 3161]

reached a flag
{'coins': 2, 'flag_get': True, 'life': 2, 'score': 900, 'stage': 1, 'status': 'small', 'time': 295, 'world': 1, 'x_pos': 3161, 'y_pos': 107}


Episode 9656:  97%|█████████▋| 9655/10000 [55:59:51<37:56,  6.60s/it, episode reward :4266.0, max X: 3161]  

reached a flag
{'coins': 1, 'flag_get': True, 'life': 0, 'score': 900, 'stage': 1, 'status': 'small', 'time': 266, 'world': 1, 'x_pos': 3161, 'y_pos': 108}


Episode 10000: 100%|██████████| 10000/10000 [56:26:14<00:00, 20.32s/it, episode reward :1529.0, max X: 1122]

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-Apr-08-2022-10000.pytorch





<Figure size 432x288 with 0 Axes>

In [20]:

# Watch the agent play with rendering on: at most `maxReplayEpisodes` episodes or
# `maxReplaySteps` env steps, whichever comes first. Note this rebinds `env`.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, JoypadActions.SIMPLE_MOVEMENT)
env = EnvWrapperFactory.convert(env, shape=imageShape)

maxReplaySteps = 10000
maxReplayEpisodes = 2

done = True
count = 0
try:
    for step in range(maxReplaySteps):
        if done:
            count += 1
            if count > maxReplayEpisodes:
                break
            state = env.reset()
        # NOTE(review): getAction is the same policy used during training and may
        # still explore (agent.exploration_rate is nonzero); earlier drafts tried
        # agent.getBestAction / agent._exploit -- confirm which greedy method exists.
        state, reward, done, info = env.step(agent.getAction(state))
        env.render()
finally:
    # Close the emulator/render window even if a step or render raises, or the
    # cell is interrupted.
    env.close()

shape before any transformations: (240, 256, 3)
shape after grayscaler: (240, 256)
shape after resizer: (84, 84)
shape after all transformations: (5, 84, 84)


In [11]:
# Hidden state: `eps` is the last episode index left behind by the training-loop
# cell, so that cell must have run first.
agent.save(epoch=eps, dir="models/")

saving model to models/ForgetfulAgent-CNN84x84-checkpoint-9999.pytorch


In [15]:
# Current exploration rate of the agent, shown as the cell's output value.
agent.exploration_rate

0.12048704660553211
