In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.tensorboard import SummaryWriter
# TensorBoard sink for the scalars logged by the training loop below.
writer = SummaryWriter(log_dir='./tensorboard/ENV')

from minipackman import multi_env
from env_model import EnvModel, ModelDyna
from a2c import ActorCritic

In [None]:
# Experiment configuration: vectorised environments, model dynamics helper, and the two networks.
num_envs = 8
mode = 'regular'
envs = multi_env(num_envs=num_envs, mode=mode)
state_shape = envs.observation_space.shape
num_actions = envs.action_space.n

md = ModelDyna()
# The number of reward classes must follow the selected `mode`: the training loop
# builds reward targets with md.rewards_to_target(mode, ...), so hard-coding
# "regular" here would mismatch the targets for any other mode.
env_model = EnvModel(state_shape, md.num_pixels, len(md.mode_rewards[mode]))
actor_critic = ActorCritic(state_shape, num_actions)

In [None]:
# Restore the pretrained (imperfect) A2C policy. Checkpoint tensors are mapped
# onto the CPU, so loading does not require a GPU.
actor_critic.load_state_dict(
    torch.load(
        './model/a2c_regular_150000',
        map_location=torch.device('cpu'),
    )
)

# policy hat_pi
def get_action(state):
    """Query the pretrained A2C policy (the rollout policy hat_pi) for actions.

    Args:
        state: one observation, or a batch of observations when
            ``state.ndim == 4`` (presumably (batch, C, H, W) -- TODO confirm).
            A single observation gets a leading batch axis added.

    Returns:
        numpy array of the actions returned by ``actor_critic.act``,
        with dimension 1 squeezed out.
    """
    obs = torch.FloatTensor(np.float32(state))
    if state.ndim != 4:
        # single observation -> batch of one
        obs = obs.unsqueeze(0)

    with torch.no_grad():
        sampled = actor_critic.act(obs)
    return sampled.detach().cpu().squeeze(1).numpy()


def play_games(envs, frames, policy=None):
    """Roll out a policy in the (vectorised) environments for ``frames`` steps.

    Args:
        envs: environment object exposing ``reset()`` and ``step(actions)``;
            ``step`` must return ``(next_states, rewards, dones, info)``.
        frames: number of environment steps to take.
        policy: callable mapping a batch of states to actions. When None
            (the default), the pretrained A2C policy ``get_action`` is used,
            so existing callers behave exactly as before.

    Yields:
        ``(frame_idx, states, actions, rewards, next_states, dones)`` for every step.
    """
    if policy is None:
        policy = get_action

    states = envs.reset()
    for frame_idx in range(frames):
        actions = policy(states)
        next_states, rewards, dones, _ = envs.step(actions)

        yield frame_idx, states, actions, rewards, next_states, dones

        # `dones` is not acted on here; presumably the env resets finished
        # episodes itself -- TODO confirm against multi_env.
        states = next_states

In [None]:
# train the environment model on transitions generated by the pretrained A2C policy
reward_coef = 0.1   # weight of the reward-prediction loss relative to the frame loss
num_updates = 5000  # number of environment steps; one gradient update per step

# Keep the model, its inputs and its targets on one device. Previously only
# `inputs` was moved to CUDA, so the forward pass / loss crashed whenever a GPU
# was available. Move the model before building the optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(env_model.parameters())

for frame_idx, states, actions, rewards, next_states, dones in play_games(envs, num_updates):
    states = torch.FloatTensor(states)
    actions = torch.LongTensor(actions)

    batch_size = states.size(0)

    # Encode each action as constant planes over the trailing observation dims
    # (state_shape[1:], presumably (H, W)) so it can be concatenated with the
    # observation along the channel axis.
    onehot_actions = torch.zeros(batch_size, num_actions, *state_shape[1:])
    onehot_actions[range(batch_size), actions] = 1
    inputs = torch.cat([states, onehot_actions], 1).to(device)

    imagined_state, imagined_reward = env_model(inputs)

    # Class-index targets from the model dynamics helper; they must live on the
    # same device as the predictions for CrossEntropyLoss.
    target_state = torch.LongTensor(md.pix_to_target(next_states)).to(device)
    target_reward = torch.LongTensor(md.rewards_to_target(mode, rewards)).to(device)

    # l_model (auxiliary loss): our env model's dynamics ~ the true dynamics
    optimizer.zero_grad()
    image_loss = criterion(imagined_state, target_state)
    reward_loss = criterion(imagined_reward, target_reward)
    loss = image_loss + reward_coef * reward_loss
    loss.backward()
    optimizer.step()

    # log
    writer.add_scalar('training reward', rewards.sum(), frame_idx)
    writer.add_scalar('training loss', loss.item(), frame_idx)
    if frame_idx % 1000 == 0 or frame_idx == num_updates - 1:
        print(f'frame_idx : {frame_idx} :::: rewards: {rewards.sum()} :::: losses: {loss.item()}')
        torch.save(env_model.state_dict(), './model/env_model_' + mode + '_'+ str(frame_idx+1))

# Make sure all logged scalars reach disk before the cell finishes.
writer.flush()