In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from env_model import ModelDyna
from minipackman import multi_env
from env_model import EnvModel
from a2c import ActorCritic, RolloutStorage
from i2a import ImaginationCore, I2A

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('./tensorboard/I2A')

In [7]:
num_envs = 8
mode = 'regular'
md = ModelDyna()
envs = multi_env(num_envs=num_envs, mode=mode)
state_shape = envs.observation_space.shape
num_actions = envs.action_space.n
num_rewards = len(md.mode_rewards[mode])

# load Env Model
env_model = EnvModel(envs.observation_space.shape, md.num_pixels, num_rewards)
env_model.load_state_dict(torch.load("./model/env_model_regular_5000", map_location=torch.device('cpu')))
if torch.cuda.is_available() : env_model.cuda()

distil_policy = ActorCritic(envs.observation_space.shape, envs.action_space.n)
distil_optimizer = optim.Adam(distil_policy.parameters())

imagination = ImaginationCore(1, state_shape, num_actions, num_rewards, env_model, distil_policy, full_rollout=False)
actor_critic = I2A(state_shape, num_actions, num_rewards, 256, imagination, full_rollout=False)
if torch.cuda.is_available() : 
  actor_critic.cuda()
  distil_policy.cuda()

In [11]:
# rmsprop
lr = 7e-4
eps = 1e-5
alpha = 0.99
optimizer = optim.RMSprop(actor_critic.parameters(), lr, eps=eps, alpha=alpha)

gamma = 0.99
entropy_coef = 0.01
value_loss_coef = 0.5
max_grad_norm = 0.5
num_steps = 5
num_frames = int(10e5)

rollout = RolloutStorage(num_steps, num_envs, envs.observation_space.shape)
if torch.cuda.is_available(): rollout.cuda()

all_rewards = []
all_losses = []

state = envs.reset()
current_state = torch.FloatTensor(np.float32(state))
if torch.cuda.is_available() : current_state.cuda()

rollout.states[0].copy_(current_state)

episode_rewards = torch.zeros(num_envs, 1)
final_rewards = torch.zeros(num_envs, 1)

for i_update in range(num_frames):
    for step in range(num_steps):
        action = actor_critic.act(current_state)

        next_state, reward, done, _ = envs.step(action.squeeze(1).cpu().data.numpy())

        reward = torch.FloatTensor(reward).unsqueeze(1)
        episode_rewards += reward
        masks = torch.FloatTensor(1-np.array(done)).unsqueeze(1)
        final_rewards *= masks
        final_rewards += (1-masks) * episode_rewards
        episode_rewards *= masks

        if torch.cuda.is_available(): masks = masks.cuda()
        current_state = torch.FloatTensor(np.float32(next_state))
        rollout.insert(step, current_state, action.data, reward, masks)

    with torch.no_grad():
      _, next_value = actor_critic(rollout.states[-1])
    next_value = next_value.data
  
    returns = rollout.get_batch_returns(next_value, gamma)

    logit, action_log_probs, values, entropy = actor_critic.evaluate_actions(
        rollout.states[:-1].view(-1, *state_shape),
        rollout.actions.view(-1, 1)
    )
    
    distil_logit, _, _, _ = distil_policy.evaluate_actions(
        rollout.states[:-1].view(-1, *state_shape),
        rollout.actions.view(-1, 1)
    )
        
    distil_loss = 0.01 * (F.softmax(logit).detach() * F.log_softmax(distil_logit)).sum(1).mean()

    values = values.view(num_steps, num_envs, 1)
    action_log_probs = action_log_probs.view(num_steps, num_envs, 1)
    advantages = returns - values

    value_loss = advantages.pow(2).mean()
    action_loss = -(advantages.data * action_log_probs).mean()

    optimizer.zero_grad()
    loss = value_loss * value_loss_coef + action_loss - entropy * entropy_coef
    loss.backward()
    nn.utils.clip_grad_norm_(actor_critic.parameters(), max_grad_norm)
    optimizer.step()
    
    distil_optimizer.zero_grad()
    distil_loss.backward()
    optimizer.step()
    
    writer.add_scalar('training reward', final_rewards.sum(), i_update)
    writer.add_scalar('training loss', loss.item(), i_update)

    rollout.after_update()

    if i_update % 1000 == 0 or i_update == num_frames - 1:
        print(f'{i_update} th Update :::: Rewards : {final_rewards.sum()} :::: Loss : {loss.item()}')
        torch.save(actor_critic.state_dict(), "./model/i2a_regular_" + str(i_update+1))

writer.close()
envs.close()

  probs = F.softmax(logit)
  imagined_state = F.softmax(imagined_state).max(1)[1].data.cpu()
  imagined_reward = F.softmax(imagined_reward).max(1)[1].data.cpu()
  probs = F.softmax(logit)
  log_probs = F.log_softmax(logit)
  distil_loss = 0.01 * (F.softmax(logit).detach() * F.log_softmax(distil_logit)).sum(1).mean()
  distil_loss = 0.01 * (F.softmax(logit).detach() * F.log_softmax(distil_logit)).sum(1).mean()


0 th Update :::: Rewards : 0.0 :::: Loss : 0.1786385327577591
1000 th Update :::: Rewards : 131.0 :::: Loss : 0.10291830450296402
2000 th Update :::: Rewards : 177.0 :::: Loss : 2.7502641677856445
3000 th Update :::: Rewards : 138.0 :::: Loss : 5.660292148590088
4000 th Update :::: Rewards : 234.0 :::: Loss : 1.693057656288147
5000 th Update :::: Rewards : 128.0 :::: Loss : 2.1137661933898926
6000 th Update :::: Rewards : 119.0 :::: Loss : 4.023800373077393
7000 th Update :::: Rewards : 115.0 :::: Loss : 4.701552391052246
8000 th Update :::: Rewards : 225.0 :::: Loss : 1.4390003681182861
9000 th Update :::: Rewards : 194.0 :::: Loss : 0.06058957800269127
10000 th Update :::: Rewards : 290.0 :::: Loss : 3.297708034515381
11000 th Update :::: Rewards : 210.0 :::: Loss : 0.22788883745670319
12000 th Update :::: Rewards : 264.0 :::: Loss : 28.337942123413086
13000 th Update :::: Rewards : 149.0 :::: Loss : 3.9439284801483154
14000 th Update :::: Rewards : 178.0 :::: Loss : 1.50297403335571