In [1]:
%matplotlib inline
import os
import random
from datetime import datetime
from itertools import count

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from krazy_gridworld import KrazyGridWorld
from model import Model
from utils import ReplayBuffer, get_state, sample_advice, advice_satisfied

## Hyperparameters

In [2]:
# Run name doubles as the TensorBoard run directory. The timestamp uses '_' and '-'
# instead of ' ' and ':' so the directory name is portable (':' is illegal on Windows)
# and needs no quoting on the command line.
NAME = 'ACTRCE(-)-KGW' + '_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Seed every RNG the run may touch (python `random`, numpy, torch) for reproducibility.
SEED = 0

max_frames = int(32e6)  # total environment steps for the whole run
save_interval = 100     # episodes between checkpoints (checkpointing is commented out in the training cell)

lr = 1e-3
batch_size = 64
# NOTE(review): `epsilon` is not read anywhere; the training loop anneals exploration
# with epsilon_decay(), whose default end value is 0.01.
epsilon = 0.05
T = 25  # maximum steps per episode

# Root directory for TensorBoard logs; override with the KGW_LOG_ROOT environment variable
# instead of editing a machine-specific absolute path.
LOG_ROOT = os.environ.get('KGW_LOG_ROOT', '/mnt/hdd1/ml/logs')
log_dir = os.path.join(LOG_ROOT, NAME)
SAVE_DIR = 'models'

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
env = KrazyGridWorld()
env.reset()

# The network's input geometry is read off a real observation, so Model always matches the env.
channel_in, height, width = tuple(get_state(env).shape)

# Size of the discrete action space.
# NOTE(review): hard-coded; confirm it matches KrazyGridWorld's action set.
action_dim = 4

In [4]:
# Experience storage: the training loop opens one episode per reset with
# new_episode(advice) and samples from it in net.update().
replay_buffer = ReplayBuffer()

# TensorBoard writer for loss / success-rate curves. To disable logging, assign
# `writer = None`. The training loop guards its own uses with `writer is not None`;
# Model is presumably None-tolerant too (TODO confirm).
writer = SummaryWriter(log_dir=log_dir)

In [5]:
# Presumably a DQN-style online network with a target copy (the training loop calls
# select_action / update / update_target_model). Input dims come from the env
# observation; `writer` is forwarded, presumably for TensorBoard logging inside the model.
net = Model(
    lr,
    height,
    width,
    channel_in,
    action_dim,
    writer=writer,
)

## Utility functions

In [6]:
def epsilon_decay(frame_number, eps_init=1.0, eps_end=0.01, decay_len=100000):
    """Linearly anneal the exploration rate used for epsilon-greedy action selection.

    Args:
        frame_number: Global environment step count (the training loop starts it at 1).
        eps_init: Epsilon at frame 0.
        eps_end: Epsilon once ``frame_number >= decay_len``.
        decay_len: Number of frames over which epsilon moves linearly from
            ``eps_init`` to ``eps_end``. ``0`` means "no annealing": ``eps_end``
            is returned immediately.

    Returns:
        ``eps_init`` linearly interpolated toward ``eps_end``, held constant at
        ``eps_end`` from ``decay_len`` frames onward.
    """
    # `>=` rather than `>`: at frame_number == decay_len both branches yield eps_end,
    # and it stops decay_len == 0 from reaching the division below (ZeroDivisionError).
    if frame_number >= decay_len:
        return eps_end
    frac = frame_number / decay_len
    return eps_init * (1 - frac) + eps_end * frac

## Training

In [None]:
# Periodic-task intervals, in frames (environment steps). The values reproduce the original
# schedule: optimise every frame, sync the target network every 2 frames, log every 1000.
update_every = 1
target_update_every = 2
log_every = 1000

frame_number = 0
success = 0       # episodes whose advice was satisfied since the last success_rate log
num_episodes = 0  # episodes finished since the last success_rate log
while frame_number < max_frames:
    env.reset()
    advice = sample_advice()
    replay_buffer.new_episode(advice)
    for t in range(T):
        frame_number += 1
        state = get_state(env)
        eps = epsilon_decay(frame_number)
        action = net.select_action(state, advice.split(" "), epsilon=eps)

        _, _, done, info = env.step(action)
        at_goal = env.at_goal()
        is_lava = env.is_dead()
        color = info['color']

        is_done, satisfied = advice_satisfied(advice, color, at_goal, is_lava)

        # The episode ends on env termination, on the advice being resolved, or when the
        # per-episode step budget T is exhausted.
        done = done or (t == T - 1) or is_done

        replay_buffer.add(state, action, color, at_goal, is_lava)

        if frame_number % update_every == 0:
            loss = net.update(batch_size, replay_buffer)
            if writer is not None and loss is not None:
                writer.add_scalar("loss", loss, frame_number)

        if frame_number % target_update_every == 0:
            net.update_target_model()

        if frame_number % log_every == 0 and writer is not None:
            # A window can contain zero finished episodes (e.g. if T > log_every);
            # skip the point rather than raise ZeroDivisionError.
            if num_episodes > 0:
                writer.add_scalar('success_rate', success / num_episodes, frame_number)
            success = 0
            num_episodes = 0

        if done:
            # The return value was previously bound to the unused name `is_initial`.
            # The call is kept because it presumably finalises this episode's rewards
            # inside the replay buffer (TODO confirm against utils.ReplayBuffer).
            replay_buffer.compute_reward(color, at_goal, is_lava)
            if satisfied:
                success += 1
            num_episodes += 1
            break

    # TODO: periodic checkpointing (every `save_interval` episodes into SAVE_DIR) is disabled.
    # The draft below references `episode` and `total_rewards`, which are not defined in this
    # loop; it needs a persistent episode counter and a "best" metric before it can be re-enabled.
#     if episode % save_interval == 0:
#         print(f'model saved on episode: {episode % (10 * save_interval)}')
#         net.save('models', f'episode-{episode % (10 * save_interval)}')

#         print(f'best model saved with reward: {total_rewards}')
#         net.save('models', f'best')

In [None]:
# net.load('models', f'best')

In [None]:
# def show_state(env, step=0, info=""):
#     plt.figure(3)
#     plt.clf()
#     plt.imshow(env.get_img_pyplot_obs())
#     plt.title("%s | Step: %d %s" % (env, step, info))
#     plt.axis('off')

#     display.clear_output(wait=True)
#     display.display(plt.gcf())