# Training Reward Predictor for Doom 

In [1]:
from utils.env import make_doom_env
from agents.doom_ppo_agent import DoomPpoAgent
from reward_predictors.doom_reward_predictor import DoomRewardPredictor
import gym
from datetime import datetime
from utils.memory import Memory
from utils.time import current_timestamp_ms
from torch.utils.tensorboard import SummaryWriter

# --- Environment -------------------------------------------------------------
# A synchronous vector of independent copies of the ViZDoom "basic" scenario,
# built from the same config, with rendering turned off.
num_envs = 4
envs = gym.vector.SyncVectorEnv(
    [
        make_doom_env(level_config_path='vizdoom/scenarios/basic.cfg', render=False)
        for _ in range(num_envs)
    ]
)

# --- Models ------------------------------------------------------------------
# The PPO agent only generates experience in the training cell; it is not
# updated there. To act with a previously trained policy instead of a fresh
# one, the author has passed a checkpoint, e.g.
#   models_path='./models/doom_ppo_agent/training_run_2023_07_07_02_24_27/checkpoint_step_292864'
agent = DoomPpoAgent(envs.single_observation_space, envs.single_action_space)

# The model being trained here. The positional `1` is presumably the output
# size (one scalar reward per input) -- confirm against DoomRewardPredictor.
reward_predictor = DoomRewardPredictor(
    envs.single_observation_space.shape,
    1,
    hidden_size=512,
    learning_rate=0.0001,
)

# --- Run bookkeeping and rollout sizes ---------------------------------------
global_step = 0
start_datetime = datetime.now()
start_time = start_datetime.timestamp()

num_steps = 256                                # steps collected per env per update
num_mini_batches = 32
num_training_epochs = 10
batch_size = num_envs * num_steps              # transitions per update (4 * 256 = 1024)
mini_batch_size = batch_size // num_mini_batches
total_timesteps = 50000
num_updates = total_timesteps // batch_size    # whole updates only; any remainder is dropped

# Rollout buffer holding num_steps transitions for each of the num_envs envs.
memory = Memory(
    agent.device,
    num_steps,
    num_envs,
    envs.single_observation_space.shape,
    envs.single_action_space.shape,
)

# --- TensorBoard -------------------------------------------------------------
# One log directory per run, suffixed with the current millisecond timestamp.
tensorboard_writer = SummaryWriter(f"logs/doom_basic_level/reward_predictor_training_{current_timestamp_ms()}")

In [2]:
import time
import torch
import numpy as np

# Collect rollouts using the PPO agent's actions (the agent itself is not
# updated here). After each rollout, fit the reward predictor to the stored
# observations/actions/rewards and checkpoint it whenever its average
# training loss improves.
observation = torch.Tensor(envs.reset()).to(agent.device)
done = torch.zeros(num_envs).to(agent.device)
best_avg_loss = float('+inf')

# All checkpoints of this run share one directory named after the run's start
# time; only the step suffix changes between saves.
run_checkpoint_dir = (
    f"./models/doom_reward_predictor/training_run_{start_datetime.strftime('%Y_%m_%d_%H_%M_%S')}"
)

try:
    for update in range(1, num_updates + 1):

        for step in range(num_steps):
            global_step += num_envs

            # Get the next action and its value estimate. No gradients are
            # needed because the agent is only used to generate experience.
            with torch.no_grad():
                action, log_prob, _, value = agent.get_optimal_action_and_value(observation)
                value = value.flatten()

            observation_, reward, done_, info = envs.step(action.cpu().numpy())

            # Save the transition. `done` is the flag that arrived together
            # with `observation`, i.e. it was produced by the previous step.
            memory.remember(
                step=step,
                observation=observation,
                action=action,
                value=value,
                log_prob=log_prob,
                reward=torch.tensor(np.array(reward, dtype=np.float32)).to(agent.device).view(-1),
                done=done,
            )

            # Carry the new observation and done flags into the next step.
            observation = torch.Tensor(observation_).to(agent.device)
            done = torch.Tensor(done_).to(agent.device)

            # Report every episode that ended on this step. With several envs,
            # more than one can finish at once, so do not stop at the first.
            for item in info:
                if "episode" in item:
                    episodic_return = item["episode"]["r"]
                    print(f"global_step={global_step}, episodic_return={episodic_return}")
                    tensorboard_writer.add_scalar("charts/episodic_return", episodic_return, global_step)

        # Train the reward predictor on the rollout just stored in `memory`.
        # How `batch_size` is used inside train() is not visible here.
        training_results = reward_predictor.train(memory.observations, memory.actions, memory.rewards, batch_size)
        avg_loss = training_results["avg_loss"]

        print(f"Avg Training Loss: {avg_loss}")
        tensorboard_writer.add_scalar("charts/avg_loss", avg_loss, global_step)

        # Save a checkpoint whenever the average loss beats the best so far.
        # NOTE: this is the training loss on the latest rollout, not a
        # held-out validation metric.
        if avg_loss < best_avg_loss:
            reward_predictor.save_models(f"{run_checkpoint_dir}/checkpoint_step_{global_step}")
            best_avg_loss = avg_loss
finally:
    # Runs are usually stopped by interrupting the cell, so push any buffered
    # TensorBoard events to disk before the exception propagates.
    tensorboard_writer.flush()


    

global_step=8, episodic_return=95.0
global_step=16, episodic_return=87.0
global_step=20, episodic_return=91.0
global_step=32, episodic_return=71.0
global_step=36, episodic_return=83.0
global_step=44, episodic_return=79.0
global_step=52, episodic_return=83.0
global_step=64, episodic_return=91.0
global_step=100, episodic_return=63.0
global_step=108, episodic_return=95.0
global_step=192, episodic_return=-47.0
global_step=200, episodic_return=-70.0
global_step=264, episodic_return=-95.0
global_step=272, episodic_return=26.0
global_step=284, episodic_return=91.0
global_step=300, episodic_return=-380.0
global_step=308, episodic_return=95.0
global_step=320, episodic_return=91.0
global_step=360, episodic_return=63.0
global_step=384, episodic_return=-43.0
global_step=468, episodic_return=-121.0
global_step=492, episodic_return=-370.0
global_step=568, episodic_return=10.0
global_step=660, episodic_return=-370.0
global_step=684, episodic_return=-360.0
global_step=736, episodic_return=39.0
global_

KeyboardInterrupt: 