In [21]:
# #only run once
# #!pip install nes-py==0.2.6
# !brew update
# !brew install ffmpeg
# !brew install libsm
# !brew install libxext
# !brew install mesa
# !pip install opencv-python
# !pip install gym-super-mario-bros
# !pip install gym

In [2]:
import torch
import torch.nn as nn
import random
import gym
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros import SuperMarioBrosEnv
from tqdm import tqdm
import pickle 
import gym
import numpy as np
import collections 
import cv2
import matplotlib.pyplot as plt
import time
import datetime

In [3]:
from toolkit.gym_env import *
from toolkit.action_utils import *
from toolkit.marlios_model import *
from toolkit.constants import *

CONSECUTIVE_ACTIONS = 2

%load_ext autoreload
%autoreload 2

In [4]:
def show_state(env, ep=0, info=""):
    plt.figure(3)
    plt.clf()
    plt.imshow(env.render(mode='rgb_array'))
    plt.title("Episode: %d %s" % (ep, info))
    plt.axis('off')

    # display.clear_output(wait=True)
    # display.display(plt.gcf())
    display(plt.gcf(), clear=True)

In [5]:
def make_env(env, actions=ACTION_SPACE):
    env = MaxAndSkipEnv(env, skip=2) # I am testing out fewer fram repetitions for our two actions modelling
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, 4)
    env = ScaledFloatFrame(env)
    return JoypadSpace(env, actions)

def generate_epoch_time_id():
    epoch_time = int(time.time())
    return str(epoch_time)

In [6]:
def save_checkpoint(agent, total_rewards, terminal_info, run_id):
    with open(f"ending_position-{run_id}.pkl", "wb") as f:
        pickle.dump(agent.ending_position, f)
    with open(f"num_in_queue-{run_id}.pkl", "wb") as f:
        pickle.dump(agent.num_in_queue, f)
    with open(f"total_rewards-{run_id}.pkl", "wb") as f:
        pickle.dump(total_rewards, f)
    with open(f"terminal_info-{run_id}.pkl", "wb") as f:
        pickle.dump(terminal_info, f)
    if agent.double_dq:
        torch.save(agent.local_net.state_dict(), f"dq1-{run_id}.pt")
        torch.save(agent.target_net.state_dict(), f"dq2-{run_id}.pt")
    else:
        torch.save(agent.dqn.state_dict(), f"dq-{run_id}.pt")  
    # torch.save(agent.STATE_MEM,  f"STATE_MEM-{run_id}.pt")
    # torch.save(agent.ACTION_MEM, f"ACTION_MEM-{run_id}.pt")
    # torch.save(agent.REWARD_MEM, f"REWARD_MEM-{run_id}.pt")
    # torch.save(agent.STATE2_MEM, f"STATE2_MEM-{run_id}.pt")
    # torch.save(agent.DONE_MEM,   f"DONE_MEM-{run_id}.pt")
    # torch.save(agent.SPACE_MEM,   f"SPACE_MEM-{run_id}.pt")

In [60]:
def run(training_mode=True, pretrained=False, lr=0.0001, gamma=0.90, exploration_decay=0.995, exploration_min=0.02, ep_per_stat = 100,
        mario_env='SuperMarioBros-1-1-v0', action_space=TWO_ACTIONS_SET, num_episodes=1000, run_id=None, n_actions=5, consecutiveActions = 2):
   
    run_id = run_id or generate_epoch_time_id()
    fh = open(f'progress-{run_id}.txt', 'a')
    env = gym.make(mario_env)
    env = make_env(env, ACTION_SPACE)
    # observation_space = env.observation_space.shape # not using this anymore

    #todo: add agent params as a setting/create different agents in diff functions to run 

    agent = DQNAgent(
                     action_space=action_space,
                     max_memory_size=30000,
                     batch_size=32,
                     gamma=gamma,
                     lr=lr,
                     dropout=0.,
                     exploration_max=1,
                     exploration_min=exploration_min,
                     exploration_decay=exploration_decay,
                     double_dq=True,
                     pretrained=pretrained,
                     run_id=run_id,
                     n_actions=n_actions)
    
    
    # num_episodes = 10
    env.reset()
    total_rewards = []
    total_info = []
    
    for ep_num in tqdm(range(num_episodes)):
        state = env.reset()[-1] # take the final dimension of shape (4, 84, 84) leaving shape (84, 84) 
        state = torch.Tensor([state]).unsqueeze(0) # converts (1, 84, 84) to (1, 1, 84, 84)
        total_reward = 0
        steps = 0
        while True:
            if not training_mode:
                show_state(env, ep_num)


            two_actions_index = agent.act(state)
            two_actions_vector = agent.cur_action_space[0, two_actions_index]
            two_actions = vec_to_action(two_actions_vector.cpu()) # tuple of actions
            
            if ep_num%5 == 0 and ep_num != 0:
                print(two_actions)
            
            steps += 1
            reward = 0
            info = None
            terminal = False
            for action in two_actions: 
                if not terminal:
                    # compute index into ACTION_SPACE of our action
                    step_action = ACTION_TO_INDEX[action]

                    state_next, cur_reward, terminal, info = env.step(step_action)
                    total_reward += cur_reward
                    reward += cur_reward
                    
            state_next = torch.Tensor([state_next[-1]]).unsqueeze(0)
            reward = torch.tensor([reward]).unsqueeze(0)        
            terminal = torch.tensor([int(terminal)]).unsqueeze(0)
            
            if training_mode:
                agent.remember(state, two_actions_index, reward, state_next, terminal)
                agent.experience_replay()
            
            state = state_next
            if terminal:
                break

        total_info.append(info)
        total_rewards.append(total_reward)

        if training_mode and (ep_num % ep_per_stat) == 0:
            save_checkpoint(agent, total_rewards, total_info, run_id)

        with open(f'total_reward-{run_id}.txt', 'a') as f:
            f.write("Total reward after episode {} is {}\n".format(ep_num + 1, total_rewards[-1]))
            if (ep_num%100 == 0):
                f.write("==================\n")
                f.write("{} current time at episode {}\n".format(datetime.datetime.now(), ep_num+1))
                f.write("==================\n")
            #print("Total reward after episode {} is {}".format(ep_num + 1, total_rewards[-1]))
            num_episodes += 1
    
    if training_mode:
        save_checkpoint(agent, total_rewards, run_id)
    
    env.close()
    fh.close()
    
    if num_episodes > ep_per_stat:
        plt.title("Episodes trained vs. Average Rewards (per 500 eps)")
        plt.plot([0 for _ in range(ep_per_stat)] + 
                 np.convolve(total_rewards, np.ones((ep_per_stat,))/ep_per_stat, mode="valid").tolist())
        plt.show()





In [61]:
run(training_mode=True, pretrained=False, ep_per_stat=100, lr=0.000005, action_space=SIMPLE_MOVEMENT, n_actions=len(SIMPLE_MOVEMENT)+2)

  0%|          | 5/1000 [01:50<4:54:00, 17.73s/it]

(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'down', 'right'), ('A', 'down', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'ri

  1%|          | 6/1000 [01:53<3:33:43, 12.90s/it]

(('B', 'right'), ('B', 'right'))
(('left',), ('left',))
(('left',), ('left',))
(('B', 'right'), ('B', 'right'))
(('A', 'down', 'right'), ('A', 'down', 'right'))


  1%|          | 10/1000 [05:18<12:03:33, 43.85s/it]

(('left',), ('left',))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))


  2%|▏         | 15/1000 [10:04<15:01:15, 54.90s/it]

(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',),

  2%|▏         | 16/1000 [11:01<15:10:11, 55.50s/it]

(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))


  2%|▏         | 20/1000 [14:42<15:03:54, 55.34s/it]

(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',),

  2%|▏         | 21/1000 [15:39<15:08:03, 55.65s/it]

(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))


  2%|▎         | 25/1000 [19:24<15:07:06, 55.82s/it]

(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',), ('left',))
(('left',),

  3%|▎         | 27/1000 [21:40<13:00:49, 48.15s/it]


KeyboardInterrupt: 

# Debugging

In [54]:
training_mode=True
pretrained=False
lr=0.000005
gamma=0.90
exploration_decay=0.995
exploration_min=0.02
ep_per_stat = 100
mario_env='SuperMarioBros-1-1-v0'
action_space=SIMPLE_MOVEMENT
num_episodes=1000
run_id=None
n_actions=len(SIMPLE_MOVEMENT) + 2
consecutiveActions = 2

run_id = run_id or generate_epoch_time_id()
fh = open(f'progress-{run_id}.txt', 'a')
env = gym.make(mario_env)
#env = gym_super_mario_bros.make('SuperMarioBros-v0')

#env = make_env(env)  # Wraps the environment so that frames are grayscale 
#env = SuperMarioBrosEnv()
env = make_env(env, ACTION_SPACE)
# observation_space = env.observation_space.shape # not using this anymore

#todo: add agent params as a setting/create different agents in diff functions to run 

agent = DQNAgent(
                    action_space=action_space,
                    max_memory_size=30000,
                    batch_size=32,
                    gamma=gamma,
                    lr=lr,
                    dropout=0.,
                    exploration_max=1,
                    exploration_min=exploration_min,
                    exploration_decay=exploration_decay,
                    double_dq=True,
                    pretrained=pretrained,
                    run_id=run_id,
                    n_actions=n_actions)


# num_episodes = 10
env.reset()
total_rewards = []
total_info = []

for ep_num in tqdm(range(num_episodes)):
    state = env.reset()[-1] # take the final dimension of shape (4, 84, 84) leaving shape (84, 84) 
    state = torch.Tensor([state]).unsqueeze(0) # converts (1, 84, 84) to (1, 1, 84, 84)
    total_reward = 0
    steps = 0
    while True:
        if not training_mode:
            show_state(env, ep_num)


        two_actions_index = agent.act(state)
        # print()
        # print(two_actions_index.float())
        two_actions_vector = agent.cur_action_space[0, two_actions_index]
        # print(two_actions_vector)

        two_actions = vec_to_action(two_actions_vector.cpu()) # tuple of actions
        # print(two_actions)

        steps += 1
        reward = 0
        info = None
        terminal = False
        for action in two_actions: 
            if not terminal:
                # compute index into ACTION_SPACE of our action
                step_action = ACTION_TO_INDEX[action]

                state_next, cur_reward, terminal, info = env.step(step_action)
                total_reward += cur_reward
                reward += cur_reward
                state_next = torch.Tensor([state_next[-1]]).unsqueeze(0)

        reward = torch.tensor([reward]).unsqueeze(0)        
        terminal = torch.tensor([int(terminal)]).unsqueeze(0)
        
        if training_mode:
            agent.remember(state, two_actions_index, reward, state_next, terminal)
            agent.experience_replay()
        
        state = state_next
        if terminal:
            break

    total_info.append(info)
    total_rewards.append(total_reward)

    if training_mode and (ep_num % ep_per_stat) == 0:
        save_checkpoint(agent, total_rewards, total_info, run_id)

    with open(f'total_reward-{run_id}.txt', 'a') as f:
        f.write("Total reward after episode {} is {}\n".format(ep_num + 1, total_rewards[-1]))
        if (ep_num%100 == 0):
            f.write("==================\n")
            f.write("{} current time at episode {}\n".format(datetime.datetime.now(), ep_num+1))
            f.write("==================\n")
        #print("Total reward after episode {} is {}".format(ep_num + 1, total_rewards[-1]))
        num_episodes += 1

if training_mode:
    save_checkpoint(agent, total_rewards, run_id)

env.close()
fh.close()

if num_episodes > ep_per_stat:
    plt.title("Episodes trained vs. Average Rewards (per 500 eps)")
    plt.plot([0 for _ in range(ep_per_stat)] + 
                np.convolve(total_rewards, np.ones((ep_per_stat,))/ep_per_stat, mode="valid").tolist())
    plt.show()



  logger.warn(
  deprecation(
  deprecation(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  0%|          | 2/1000 [02:35<21:30:12, 77.57s/it]


KeyboardInterrupt: 

In [61]:
x = env.reset()[-1] # take the final dimension of shape (4, 84, 84) leaving shape (84, 84) 
x = torch.Tensor([x]).unsqueeze(0).to(agent.device) # converts (1, 84, 84) to (1, 1, 84, 84)
# print(state.shape)

conv_out = agent.local_net.conv(x).view(x.size()[0], -1)
batched_conv_out = conv_out.reshape(conv_out.shape[0], 1, conv_out.shape[-1]).repeat(1, agent.cur_action_space.shape[-2], 1)
batched_actions = torch.cat((batched_conv_out, agent.cur_action_space), dim=2)
intermediate = agent.local_net.fc(batched_actions)
out =  torch.flatten(intermediate, start_dim=1)


In [76]:
# print(intermediate)

double = torch.cat((intermediate, intermediate + 1))
out = torch.flatten(double, start_dim=1)

In [82]:
STATE, ACTION, REWARD, STATE2, DONE, SPACE = agent.recall()
STATE = STATE.to(agent.device)
ACTION = ACTION.to(agent.device)
REWARD = REWARD.to(agent.device)
STATE2 = STATE2.to(agent.device)
SPACE = SPACE.to(agent.device)
DONE = DONE.to(agent.device)

agent.optimizer.zero_grad()
# Double Q-Learning target is Q*(S, A) <- r + γ max_a Q_target(S', a)

target = REWARD + torch.mul((agent.gamma * 
                            agent.target_net(STATE2, SPACE).max(1).values.unsqueeze(1)), 
                            1 - DONE)

current = agent.local_net(STATE, SPACE).gather(1, ACTION.long()) # Local net approximation of Q-value


loss = agent.l1(current, target) # maybe we can play with some L2 loss 
loss.backward() # Compute gradients
agent.optimizer.step() # Backpropagate error

# agent.cur_action_space = torch.from_numpy(agent.subsample_actions(agent.n_actions)).to(torch.float32).to(agent.device)
# I am disabling this here for my testing, but also think we should add it to the run loop for testing til we are sure it works, idk

agent.exploration_rate *= agent.exploration_decay

# Makes sure that exploration rate is always at least 'exploration min'
agent.exploration_rate = max(agent.exploration_rate, agent.exploration_min)

In [93]:
agent.cur_action_space

tensor([[[0., 0., 0., 1., 0., 0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 1., 0., 1., 0., 0., 1.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
         [1., 1., 0., 0., 1., 1., 1., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 1., 1., 0., 1., 1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], device='mps:0')

In [91]:
agent.local_net(STATE, SPACE).gather(1, ACTION.long())

tensor([[-0.3566],
        [-0.3520],
        [ 9.3666],
        [-0.3626],
        [-0.3566],
        [-0.3566],
        [-0.4712],
        [-0.6427],
        [-0.3626],
        [-0.3396],
        [-0.3520],
        [-0.3626],
        [-0.3520],
        [-0.3520],
        [-0.3626],
        [-0.5454],
        [-0.3520],
        [-0.3520],
        [-0.5498],
        [-0.3566],
        [-0.4670],
        [-0.4699],
        [-0.3626],
        [-0.3566],
        [-0.3566],
        [-0.3520],
        [-0.4642],
        [-0.3520],
        [-0.3566],
        [-0.4632],
        [-0.4632],
        [-0.3626]], device='mps:0', grad_fn=<GatherBackward0>)

In [89]:
agent.cur_action_space.shape

torch.Size([1, 9, 10])

In [81]:
print(torch.argmax(out, dim=1))

tensor([0, 0], device='mps:0')


# Checking what the model was doing

In [49]:
training_mode=False
pretrained=True
lr=0.00025
gamma=0.90
exploration_decay=0.995
exploration_min=0.02
ep_per_stat = 100
mario_env='SuperMarioBros-1-1-v0'
action_space=SIMPLE_MOVEMENT
num_episodes=1000
# run_id='1681699251'
n_actions=len(SIMPLE_MOVEMENT) + 2
consecutiveActions = 2

run_id = run_id or generate_epoch_time_id() 
fh = open(f'progress-{run_id}.txt', 'a')
env = gym.make(mario_env)
#env = gym_super_mario_bros.make('SuperMarioBros-v0')

#env = make_env(env)  # Wraps the environment so that frames are grayscale 
#env = SuperMarioBrosEnv()
env = make_env(env, ACTION_SPACE)
# observation_space = env.observation_space.shape # not using this anymore


#todo: add agent params as a setting/create different agents in diff functions to run 

agent = DQNAgent(
                    action_space=action_space,
                    max_memory_size=30000,
                    batch_size=32,
                    gamma=gamma,
                    lr=lr,
                    dropout=0.,
                    exploration_max=.02,
                    exploration_min=exploration_min,
                    exploration_decay=exploration_decay,
                    double_dq=True,
                    pretrained=pretrained,
                    run_id=run_id,
                    n_actions=n_actions)


# num_episodes = 10
env.reset()
total_rewards = []
total_info = []

for ep_num in tqdm(range(num_episodes)):
    state = env.reset()[-1] # take the final dimension of shape (4, 84, 84) leaving shape (84, 84) 
    state = torch.Tensor([state]).unsqueeze(0) # converts (1, 84, 84) to (1, 1, 84, 84)
    total_reward = 0
    steps = 0
    while True:
        # if not training_mode:
            # show_state(env, ep_num)


        two_actions_index = agent.act(state)

        print(agent.local_net(state.to(agent.device), agent.cur_action_space))
        print(torch.argmax(agent.local_net(state.to(agent.device), agent.cur_action_space)))
 
        two_actions_vector = agent.cur_action_space[0, two_actions_index]

        two_actions = vec_to_action(two_actions_vector.cpu()) # tuple of actions
        print(two_actions)
        print()

        steps += 1
        reward = 0
        info = None
        terminal = False
        for action in two_actions: 
            if not terminal:
                # compute index into ACTION_SPACE of our action
                step_action = ACTION_TO_INDEX[action]

                state_next, cur_reward, terminal, info = env.step(step_action)
                total_reward += cur_reward
                reward += cur_reward

        state_next = torch.Tensor([state_next[-1]]).unsqueeze(0)
        reward = torch.tensor([reward]).unsqueeze(0)        
        terminal = torch.tensor([int(terminal)]).unsqueeze(0)
        
        if training_mode:
            agent.remember(state, two_actions_index, reward, state_next, terminal)
            agent.experience_replay()
        
        state = state_next
        if terminal:
            break

    total_info.append(info)
    total_rewards.append(total_reward)

    if training_mode and (ep_num % ep_per_stat) == 0:
        save_checkpoint(agent, total_rewards, total_info, run_id)

    with open(f'total_reward-{run_id}.txt', 'a') as f:
        f.write("Total reward after episode {} is {}\n".format(ep_num + 1, total_rewards[-1]))
        if (ep_num%100 == 0):
            f.write("==================\n")
            f.write("{} current time at episode {}\n".format(datetime.datetime.now(), ep_num+1))
            f.write("==================\n")
        #print("Total reward after episode {} is {}".format(ep_num + 1, total_rewards[-1]))
        num_episodes += 1

if training_mode:
    save_checkpoint(agent, total_rewards, run_id)

env.close()
fh.close()

if num_episodes > ep_per_stat:
    plt.title("Episodes trained vs. Average Rewards (per 500 eps)")
    plt.plot([0 for _ in range(ep_per_stat)] + 
                np.convolve(total_rewards, np.ones((ep_per_stat,))/ep_per_stat, mode="valid").tolist())
    plt.show()



  logger.warn(
  deprecation(
  deprecation(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):


tensor([[8.1580, 8.1779, 8.2110, 8.1554, 8.1368, 8.1752, 8.2084, 8.2065, 8.1580]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[10.8588, 10.8799, 10.9129, 10.8557, 10.8383, 10.8767, 10.9101, 10.9081,
         10.8588]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[11.3143, 11.3354, 11.3684, 11.3114, 11.2947, 11.3325, 11.3655, 11.3646,
         11.3143]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[12.9228, 12.9446, 12.9784, 12.9183, 12.9036, 12.9403, 12.9740, 12.9715,
         12.9228]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[15.6140, 15.6311, 15.6616, 15.6063, 15.5928, 15.6242, 15.6552, 15.6542,
         15.6140]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps

  0%|          | 1/1000 [00:00<09:17,  1.79it/s]

tensor([[34.7344, 34.7477, 34.7729, 34.7244, 34.7225, 34.7379, 34.7655, 34.7645,
         34.7344]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[25.5011, 25.5182, 25.5517, 25.4957, 25.4889, 25.5128, 25.5465, 25.5418,
         25.5011]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[21.9244, 21.9380, 21.9666, 21.9198, 21.9048, 21.9345, 21.9640, 21.9667,
         21.9244]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(7, device='mps:0')
(('A', 'down', 'right'), ('A', 'B'))

tensor([[20.1326, 20.1479, 20.1810, 20.1278, 20.1141, 20.1431, 20.1762, 20.1757,
         20.1326]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[13.4473, 13.4646, 13.4976, 13.4408, 13.4359, 13.4584, 13.4911, 13.4879,
         13.4473]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor

  0%|          | 2/1000 [00:01<10:05,  1.65it/s]

tensor(7, device='mps:0')
(('A', 'down', 'right'), ('A', 'B'))

tensor([[34.7344, 34.7477, 34.7729, 34.7244, 34.7225, 34.7379, 34.7655, 34.7645,
         34.7344]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[25.5011, 25.5182, 25.5517, 25.4957, 25.4889, 25.5128, 25.5465, 25.5418,
         25.5011]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[21.9244, 21.9380, 21.9666, 21.9198, 21.9048, 21.9345, 21.9640, 21.9667,
         21.9244]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(7, device='mps:0')
(('A', 'down', 'right'), ('A', 'B'))

tensor([[20.1326, 20.1479, 20.1810, 20.1278, 20.1141, 20.1431, 20.1762, 20.1757,
         20.1326]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[13.4473, 13.4646, 13.4976, 13.4408, 13.4359, 13.4584, 13.4911, 13.4879,
         13

  0%|          | 3/1000 [00:01<09:58,  1.67it/s]

tensor([[25.9901, 26.0084, 26.0380, 25.9843, 25.9734, 26.0025, 26.0321, 26.0270,
         25.9901]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[24.6113, 24.6278, 24.6607, 24.6040, 24.5981, 24.6203, 24.6519, 24.6523,
         24.6113]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[22.4212, 22.4415, 22.4728, 22.4162, 22.4050, 22.4366, 22.4674, 22.4636,
         22.4212]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[17.7035, 17.7231, 17.7580, 17.6980, 17.6920, 17.7169, 17.7509, 17.7511,
         17.7035]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[14.8877, 14.9079, 14.9444, 14.8835, 14.8692, 14.9037, 14.9404, 14.9335,
         14.8877]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, 

  0%|          | 4/1000 [00:02<09:49,  1.69it/s]

tensor([[13.4473, 13.4646, 13.4976, 13.4408, 13.4359, 13.4584, 13.4911, 13.4879,
         13.4473]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[5.1130, 5.1307, 5.1622, 5.1030, 5.0999, 5.1213, 5.1521, 5.1497, 5.1130]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[-3.8237, -3.8095, -3.7820, -3.8380, -3.8333, -3.8233, -3.7964, -3.8021,
         -3.8237]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[8.1580, 8.1779, 8.2110, 8.1554, 8.1368, 8.1752, 8.2084, 8.2065, 8.1580]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[10.8588, 10.8799, 10.9129, 10.8557, 10.8383, 10.8767, 10.9101, 10.9081,
         10.8588]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B',

  0%|          | 5/1000 [00:02<09:42,  1.71it/s]

tensor([[-3.8237, -3.8095, -3.7820, -3.8380, -3.8333, -3.8233, -3.7964, -3.8021,
         -3.8237]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[8.1580, 8.1779, 8.2110, 8.1554, 8.1368, 8.1752, 8.2084, 8.2065, 8.1580]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[10.8588, 10.8799, 10.9129, 10.8557, 10.8383, 10.8767, 10.9101, 10.9081,
         10.8588]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[11.3143, 11.3354, 11.3684, 11.3114, 11.2947, 11.3325, 11.3655, 11.3646,
         11.3143]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[12.9228, 12.9446, 12.9784, 12.9183, 12.9036, 12.9403, 12.9740, 12.9715,
         12.9228]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps

  1%|          | 6/1000 [00:03<09:23,  1.76it/s]

tensor([[25.5011, 25.5182, 25.5517, 25.4957, 25.4889, 25.5128, 25.5465, 25.5418,
         25.5011]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[21.9244, 21.9380, 21.9666, 21.9198, 21.9048, 21.9345, 21.9640, 21.9667,
         21.9244]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(7, device='mps:0')
(('A', 'down', 'right'), ('A', 'B'))

tensor([[20.1326, 20.1479, 20.1810, 20.1278, 20.1141, 20.1431, 20.1762, 20.1757,
         20.1326]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[13.4473, 13.4646, 13.4976, 13.4408, 13.4359, 13.4584, 13.4911, 13.4879,
         13.4473]], device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[5.1130, 5.1307, 5.1622, 5.1030, 5.0999, 5.1213, 5.1521, 5.1497, 5.1130]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device=

  1%|          | 6/1000 [00:07<20:18,  1.23s/it]

tensor([[5.8660, 5.8787, 5.9105, 5.8587, 5.8462, 5.8714, 5.9029, 5.9054, 5.8660]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))

tensor([[6.4423, 6.4582, 6.4896, 6.4351, 6.4201, 6.4509, 6.4824, 6.4838, 6.4423]],
       device='mps:0', grad_fn=<ReshapeAliasBackward0>)
tensor(2, device='mps:0')
(('B', 'right'), ('B', 'right'))






KeyboardInterrupt: 

In [94]:
for i in range(agent.n_actions):
    two_actions_vector = agent.cur_action_space[0, i]
    two_actions = vec_to_action(two_actions_vector.cpu()) # tuple of actions
    print(two_actions)



(('left',), ('left',))
(('B', 'right'), ('B', 'right'))
(('right',), ('right',))
(('A', 'B', 'right'), ('A', 'B', 'right'))
(('NOOP',), ('NOOP',))
(('A', 'right'), ('A', 'right'))
(('A',), ('A',))
(('B', 'down', 'right'), ('A', 'B'))
(('NOOP',), ('NOOP',))
