In [22]:
env = 'unity' # deepmindlab

log_path = './log/' + env + '/mem/50k'
exp_name ='exp'
num_actors = 6

queue_maxsize = 25
lu = 100
lo = 32
traj_no_total = 2000
max_update = 1000000
batch_size = 10

rho_bar = 1.
c_bar = 1.
gamma = 0.99
clip_reward = False

grad_clip_norm = 10.
init_lr = 0.0002
epoch = 1000

policy_loss_c = 1.
v_loss_c = 0.5
entropy_c = 0.00015
model_loss_c = 0.1

action_repeat = 6
action_dim = 9
hidden_dim = 512
act_emb_dim = 128
max_frame = 6e4

latent_dim = 16
draw_step = 8

# memory
code_size = 1024
memory_size = 64
dim_s = 128

map_dec_type = 'pixel' # deconv-complex, deconv_simple
mem_type = 'static'

model_save_path = 'model_saved'

kappa = 1e-3

if env == 'unity':
  map_size = [100., 100.]
  pos_dim = (100 // 20)**2 # 5X5 grids
else:
  map_size = [1000., 1000.]
  pos_dim = 1000 // 10
rot_dim = 360 // 18
debug = False

level_cache_dir = '/tmp/level_cache'


In [23]:
import torch
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# utils

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# In[ ]:


class ResidualBlock(nn.Module):
    
    def __init__(self, in_channels, hid_channels, out_channels,
                 kernel_size_1, stride_1, padding_1,
                 kernel_size_2, stride_2, padding_2):
        super().__init__()
        
        self.main = nn.Sequential(
            Conv2dBlock(in_channels, hid_channels, kernel_size_1, stride_1, padding_1),
            Conv2dBlock(hid_channels, out_channels, kernel_size_2, stride_2, padding_2))
        
        self.skip = Conv2dBlock(in_channels, out_channels, kernel_size_2, stride_2, padding_2)
    
    
    def forward(self, x):
        
        return self.main(x) + self.skip(x)


# In[22]:


class Conv2dBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        
        self.m = conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.weight = nn.Parameter(torch.ones(out_channels))
        self.bias = nn.Parameter(torch.zeros(out_channels))
    
    
    def forward(self, x):
        
        x = self.m(x)
        return F.relu(F.group_norm(x, 8, self.weight, self.bias))


# In[ ]:


def conv2d(in_channels, out_channels, kernel_size, stride=1,
           padding=0, dilation=1, groups=1,
           bias=True, padding_mode='zeros',
           weight_init='kaiming'):
    
    m = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                  padding, dilation, groups, bias, padding_mode)
    
    if weight_init == 'xavier':
        nn.init.xavier_normal_(m.weight)
    else:
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    
    if bias:
        nn.init.zeros_(m.bias)
    
    return m


# In[ ]:


class GaussianWithConv2d(nn.Module):
    
    def __init__(self, c_in, c_out):
        super(GaussianWithConv2d, self).__init__()
        
        self.enc = conv2d(c_in, 2 * c_out, 5, 1, 2, weight_init='xavier')
    
    
    def forward(self, inputs):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        mu, std = self.enc(inputs).chunk(2, 1)
        std = F.softplus(std) + 1e-5
        return mu, std


# In[13]:


class GatedCNN(nn.Module):
  
  def __init__(self, input_size, hidden_size):
    super().__init__()
    
    self.conv_h = nn.Conv2d(input_size, hidden_size*2, 5, 1, 2)
    # self.conv_h = nn.Conv2d(64, hidden_size*2, 3, 1, 1)
    
  def forward(self, h, x):
    
    h = self.conv_h(torch.cat([x, h], dim=1))
    # h = self.conv_h(x + h)
    
    h1, h2 = torch.chunk(h, 2, dim=1)
    h = torch.tanh(h1) * torch.sigmoid(h2)
    
    return h
    
class AgentCoreBaseVisEncoder(nn.Module):
  def __init__(self, c_in, c_out):
    super().__init__()

    self.conv = nn.Conv2d(c_in, c_out, 3, 1, 1)
    self.pad = nn.ZeroPad2d((0, 1, 0, 1))
    self.max_pool = nn.MaxPool2d(3, 2)
    self.residual = ResidualBlock(c_out, c_out, c_out, 3, 1, 1, 3, 1, 1)

    self.net = nn.Sequential(self.pad, self.conv, self.max_pool, self.residual)

  def forward(self, x):
    return self.net(x)

class AgentCoreDeepVisEncoder(nn.Module):
  def __init__(self):
    super().__init__()

    self.block_1 = AgentCoreBaseVisEncoder(3, 16)
    self.block_2 = AgentCoreBaseVisEncoder(16, 32)
    self.block_3 = AgentCoreBaseVisEncoder(32, 32)
    
    self.fc = nn.Linear(8*8*32, 256)

    self.net = nn.Sequential(self.block_1, self.block_2, self.block_3, nn.Flatten(), nn.ReLU(), self.fc, nn.ReLU())

  def forward(self, x):
    return self.net(x)

def weight_init(m): 
  if isinstance(m, nn.Linear):
    init.normal_(m.weight, std=0.04)

# modules_unity

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from utils import *
from torch.distributions import Categorical
import pdb


# In[ ]:
# device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

class Policy(nn.Module):
  
  def __init__(self):
    
    super().__init__()
    self.visual_enc = AgentCoreDeepVisEncoder().to(device)
    
    self.agent_core = nn.LSTM(256+act_emb_dim, hidden_dim, batch_first=True).to(device)
    
    
  def init_state(self, B):
    h0 = torch.zeros(1, B, hidden_dim, device=device)
    c0 = torch.zeros(1, B, hidden_dim, device=device)

    return h0, c0

  def forward(self, x, act):

    B, T, C, H, W = x.shape

    x_t_enc = self.visual_enc(x.reshape(B*T, C, H, W)).reshape(B, T, -1)

    x_t_act = torch.cat([x_t_enc, act], dim=-1)

    states_h = []
    states_c = []

    h, c = self.init_state(B) # 1, B, H
    states_h.append(h)
    states_c.append(c)

    for t in range(T):

      o, (h, c) = self.agent_core(x_t_act[:, t:t+1], (h, c))
      
      states_h.append(h)
      states_c.append(c)

    states_h = torch.cat(states_h, dim=0).permute(1,0,2) # B, T+1, H
    states_c = torch.cat(states_c, dim=0).permute(1,0,2)

    return states_h, states_c

# In[ ]:


class MapDecoder(nn.Module):
  
  def __init__(self):
    
    super().__init__()
    
    #TODO
 
    if map_dec_type == 'pixel':
      self.dec = nn.Sequential(
        Conv2dBlock(hidden_dim*2, 256*4*4, 1, 1, 0),
        nn.PixelShuffle(4),
        Conv2dBlock(256, 128*4*4, 3, 1, 1),
        nn.PixelShuffle(4),
        Conv2dBlock(128, 64*4*4, 3, 1, 1),
        nn.PixelShuffle(4),
        nn.Conv2d(64, 3, 3, 1, 1),
        nn.Sigmoid()
      ).to(device)
    elif map_dec_type == 'deconv-complex':
      self.dec = nn.Sequential(
          nn.ConvTranspose2d(hidden_dim*2, 256, 4, 4),
          nn.GroupNorm(32, 256),
          nn.ReLU(),
          nn.ConvTranspose2d(256, 128, 4, 4),
          nn.GroupNorm(16, 128),
          nn.ReLU(),
          nn.ConvTranspose2d(128, 64, 4, 4),
          nn.GroupNorm(8, 64),
          nn.ReLU(),
          nn.Conv2d(64, 3, 3, 1, 1)
        ).to(device)
    elif map_dec_type == 'deconv-simple':
      self.dec = nn.Sequential(
          nn.ConvTranspose2d(hidden_dim*2, 64, 4, 4),
          nn.GroupNorm(8, 64),
          nn.ReLU(),
          nn.ConvTranspose2d(64, 32, 4, 4),
          nn.GroupNorm(4, 32),
          nn.ReLU(),
          nn.ConvTranspose2d(32, 16, 4, 4),
          nn.GroupNorm(2, 16),
          nn.ReLU(),
          nn.Conv2d(16, 3, 3, 1, 1)
        ).to(device)
    

  def forward(self, h, map_gt):
    map_pred = self.dec(h.unsqueeze(-1).unsqueeze(-1))
    loss = ((map_pred - map_gt).pow(2).sum(dim=(1,2,3))).mean()
    return map_pred, loss


# In[ ]:


class PositionDecoder(nn.Module):
  
  def __init__(self):
    
    super().__init__()
    
    self.dec = nn.Sequential(
      nn.Linear(hidden_dim*2, pos_dim),
    ).to(device)
    
  def forward(self, h, pos_gt):
    pos_gt_ = pos_gt + 50 # [-50, 50] --> [0, 100]
    pos_gt_ = pos_gt_ // 20 # 5x5 grid
    pos_cls_gt = pos_gt_[:, 0]*5 + pos_gt_[:, 1] # 0~25
    pos_pred = self.dec(h)
    pos_acc = ((pos_pred.argmax(dim=1)).long() == (pos_cls_gt).long()).float().mean()
    loss = F.cross_entropy(pos_pred, pos_cls_gt.long(), reduction='none')
    loss = loss.mean()
    return pos_acc, loss

class RotDecoder(nn.Module):
  
  def __init__(self):
    
    super().__init__()
    
    self.dec = nn.Sequential(
      nn.Linear(hidden_dim*2, rot_dim),
    ).to(device)
    
  def forward(self, h, rot_gt):
    rot_gt_ = (rot_gt % 360.) // 18 # 20 cls
    rot_pred = self.dec(h)
    rot_acc = ((rot_pred.argmax(dim=1)).long() == (rot_gt_).long()).float().mean()

    loss = F.cross_entropy(rot_pred, rot_gt_.long(), reduction='none')
    loss = loss.mean()
    return rot_acc, loss

# In[80]:


class ConvDraw(nn.Module):
  
  def __init__(self):
    
    super().__init__()
    
    self.img_enc = nn.Sequential(
      Conv2dBlock(3*2, 16, 4, 2, 1),
      Conv2dBlock(16, 16, 4, 2, 1),
      Conv2dBlock(16, 64, 4, 2, 1),
    ).to(device)
    
    self.img_dec = nn.Sequential(
      Conv2dBlock(64, 32*2*2, 3, 1, 1),
      nn.PixelShuffle(2),
      Conv2dBlock(32, 32*2*2, 3, 1, 1),
      nn.PixelShuffle(2),
      nn.Conv2d(32, 3*2*2, 3, 1, 1),
      nn.PixelShuffle(2),
    ).to(device)
    
    self.sim_core_enc = nn.Sequential(
      nn.ConvTranspose2d(hidden_dim, 256, 4, 2, 1),
      nn.GroupNorm(32, 256),
      nn.ReLU(),
      nn.ConvTranspose2d(256, 128, 4, 2, 1),
      nn.GroupNorm(16, 128),
      nn.ReLU(),
      nn.ConvTranspose2d(128, 64, 4, 2, 1),
      nn.GroupNorm(8, 64),
      nn.ReLU(),
    ).to(device)
    
    self.num_steps = draw_step

    self.decoder = nn.ModuleList([GatedCNN(64+latent_dim, 64).to(device) for _ in range(self.num_steps)])
    self.posterior = nn.ModuleList([GatedCNN(64*2, 64).to(device) for _ in range(self.num_steps)])
    self.prior = nn.ModuleList([GatedCNN(64*2, 64).to(device) for _ in range(self.num_steps)])
    
    self.gaussian_p = GaussianWithConv2d(64, latent_dim).to(device)
    self.gaussian_q = GaussianWithConv2d(64, latent_dim).to(device)

    
    
    
  def init_states(self, B):
    
    h_p = torch.zeros(B, 64, 8, 8, device=device, dtype=torch.float)
    h_q = torch.zeros(B, 64, 8, 8, device=device, dtype=torch.float)
    h_rec = torch.zeros(B, 64, 8, 8, device=device, dtype=torch.float)
    
    return h_p, h_q, h_rec
  
  def forward(self, sim_core_h, x_t, is_training=True):
    
    B = x_t.shape[0]
    
    h_p, h_q, h_rec = self.init_states(B)
    
    sim_core_h = sim_core_h.unsqueeze(-1).unsqueeze(-1)
    sim_core_h_enc = self.sim_core_enc(sim_core_h)
    
    rec_err = x_t.new_zeros(x_t.shape)
    rec_x = x_t.new_zeros(x_t.shape)
    
    kl_loss = 0.
    
    for i in range(self.num_steps):
      
      if i == 0:
        for _ in range(4):
          h_p = self.prior[i](h_p, sim_core_h_enc)
      else:
        h_p = self.prior[i](h_p, sim_core_h_enc)
      
      # x_t_enc = self.img_enc(x_t)
      
      if is_training:
        rec_err_enc = self.img_enc(torch.cat([x_t, rec_err], dim=1))
        h_q = self.posterior[i](h_q, rec_err_enc)
        
        mu_q, std_q = self.gaussian_q(h_q)
        mu_p, std_p = self.gaussian_p(h_p)
        z = self.sample(mu_q, std_q)
        kl_loss = kl_loss + self.cal_kl(mu_p, std_p, mu_q, std_q)
        
        h_rec = self.decoder[i](h_rec, z)
        x_hat = self.img_dec(h_rec)
        rec_x = rec_x + x_hat
        rec_err = x_t - rec_x

      else:
        
        mu_p, std_p = self.gaussian_p(h_p)
        z = self.sample(mu_p, std_p)
      
        h_rec = self.decoder[i](h_rec, z)
        x_hat = self.img_dec(h_rec)
        rec_x = rec_x + x_hat

      rec_x = rec_x.sigmoid()

    if is_training:

      constraint = self.constraint(x_t, rec_x)
      return rec_x, constraint.mean(), kl_loss.sum(dim=(1,2,3)).mean()

    else:
      return rec_x
        
  def constraint(self, x, rec):
    # kappa is pixel space error threshold
    return torch.sum(torch.pow(rec - x, 2), dim = (1,2,3)) - kappa * 64 * 64 * 3

  def sample(self, mu, std):
    
    noise = torch.empty_like(mu).normal_()
    
    return mu + std * noise
  
  def cal_kl(self, mu_p, std_p, mu_q, std_q):
    
    var_ratio = (std_q / std_p) ** 2
    
    return 0.5 * (((mu_q - mu_p) / std_p) ** 2 + (var_ratio - 1) - var_ratio.log())
      



# Logger

In [26]:
from torch.utils.tensorboard import SummaryWriter
from enum import Enum
from threading import Thread
from queue import Empty


class SummaryType(Enum):
    SCALAR = 1
    HISTOGRAM = 2
    VIDEO = 3
    IMAGE = 4
    FIGURE = 5
    GRAPH = 6


# Not working asynchronously
class Statistics(Thread):
    """Writes the statistics of the async processes into a tensorboard"""

    def __init__(self, writer_dir, statistics_queue, nb_episodes):

        super(Statistics, self).__init__()

        self.exit = False

        self.stats_queue = statistics_queue
        self.nb_episodes = nb_episodes

        self._writer = SummaryWriter(log_dir=writer_dir)

    def run(self):

        super(Statistics, self).run()

        # Make sure that all the logs are pushed to the tensorboard
        last_step = 0
        while True:

            try:
                summary_type, tag, data = self.stats_queue.get(timeout=1)
            except Empty:
                if self.exit:
                    break
                continue

            # Push the informations
            step = self.nb_episodes.value

            frames = step * 100

            if summary_type == summary_type.SCALAR:
                self._writer.add_scalar(tag=tag, scalar_value=data, global_step=frames)

            elif summary_type == summary_type.HISTOGRAM:
                self._writer.add_histogram(
                    tag=tag, values=data, global_step=frames, bins="tensorflow"
                )

            elif summary_type == summary_type.FIGURE:
                self._writer.add_figure(tag=tag, figure=frames, global_step=step)

            elif summary_type == summary_type.IMAGE:
                if data.dim() > 3:
                    self._writer.add_images(
                        tag=tag, img_tensor=data, global_step=frames, dataformats="NCHW"
                    )
                else:
                    self._writer.add_image(
                        tag=tag, img_tensor=data, global_step=frames, dataformats="CHW"
                    )

            elif summary_type == summary_type.VIDEO:
                self._writer.add_video(tag=tag, vid_tensor=data, global_step=frames, fps=4)

            elif summary_type == summary_type.GRAPH:
                self._writer.add_graph(model=data)

        self._writer.close()

# Learner_StoredData

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
# from modules_unity import *
from torch.utils.tensorboard import SummaryWriter
import time
# from env_utils import MetricLogger, plot_traj_step
import os
import pickle
import pdb

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float

class Learner(object):

  def __init__(
    self,
    id,
    policy,
    train_dl,
    val_dl,
    
    timeout=200
  ):

    self.id = id
    self.policy = policy

    # todo: learning rate

    self.timeout = timeout

    
    self.train_dl = train_dl
    self.val_dl = val_dl

    self.map_decoder = MapDecoder()
    self.pos_decoder = PositionDecoder()
    self.rot_decoder = RotDecoder()
    self.conv_draw = ConvDraw()
    self.sim_core = nn.LSTMCell(act_emb_dim, hidden_dim).to(device)
    self.action_emb_layer = nn.Embedding(6, act_emb_dim).to(device)

    self.params = [*self.policy.parameters(),
      *self.map_decoder.parameters(), *self.pos_decoder.parameters(), *self.rot_decoder.parameters(),
      *self.conv_draw.parameters(), *self.sim_core.parameters(), *self.action_emb_layer.parameters()]

    # self.params = [*self.policy.parameters()]

    self.optimizer = torch.optim.Adam(
      self.params, lr=init_lr, betas=(0, 0.999), eps=3.125e-7
    )
    # self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda epoch: 0.95)
    self.timer = MetricLogger()

#     exp_name = exp_name + '-lr{}-h{}-kappa{}-zdim{}-mapdec_{}-maxf{}'.format(init_lr,
#                                                           hidden_dim,
#                                                           kappa,
#                                                           latent_dim, map_dec_type, max_frame)
#     log_path = os.path.join(log_path, exp_name)
    self.writer = SummaryWriter(log_dir=log_path)

    self.model_save_path = os.path.join(log_path, model_save_path)
    if not os.path.exists(self.model_save_path):
      os.makedirs(self.model_save_path)

    self.model_file_list = os.path.join(self.model_save_path, 'model_list.pkl')
    if not os.path.exists(self.model_file_list):
      with open(self.model_file_list, 'wb') as f:
        model_list = []
        pickle.dump(model_list, f)

  def run(self):
    self.lambd = torch.empty(1, device=device).fill_(1000.)

    self.best_loss = 1000000.

    train_dl = self.train_dl
    val_dl = self.val_dl
    for e in range(epoch):
      self.learn(e, train_dl)
      self.generate(e, val_dl)


  def enable_eval(self):

    self.policy.eval()
    self.map_decoder.eval()
    self.pos_decoder.eval()
    self.rot_decoder.eval()
    self.conv_draw.eval()
    self.sim_core.eval()
    self.action_emb_layer.eval()

  def enable_training(self):

    self.policy.train()
    self.map_decoder.train()
    self.pos_decoder.train()
    self.rot_decoder.train()
    self.conv_draw.train()
    self.sim_core.train()
    self.action_emb_layer.train()

  def visual_test(self, traj):

    self.enable_eval()

    with torch.no_grad():

      agent_view = traj['agent_view'].to(device) # B, T, C, H, W
      map_top_down = traj['map_top_down'].to(device) # B, T, C, H, W
      actions = traj['actions'].to(device) # B, T, 1
      pos = traj['pos'].to(device) # B, T, 2
      rot = traj['rot'].to(device) # B, T

      B, T, C, H, W = map_top_down.shape

      action_0 = actions.new_zeros(B, 1, 1)
      actions = torch.cat([action_0, actions[:, :-1]], dim=1) # B, T+1, 1
      act_emb = self.action_emb_layer(actions.long()).reshape(B, T, act_emb_dim)

      state_h, state_c = self.policy(agent_view, act_emb) # B, T+1, H

      us = torch.cat([state_h[:, 1:-1], state_c[:, 1:-1].tanh()], dim=-1)
      map_pred, map_loss = self.map_decoder(us.reshape(B*(T-1), -1).detach(),
                                            map_top_down[:, :-1].reshape(B*(T-1), C, H, W))

      pos_acc, pos_loss = self.pos_decoder(us.reshape(B*(T-1), -1).detach(), pos[:,:-1].reshape(B*(T-1), -1))
      rot_acc, rot_loss = self.rot_decoder(us.reshape(B*(T-1), -1).detach(), rot[:,:-1].reshape(B*(T-1)))

      agent_view_rec, agent_view_gen, agent_view_gt, constraint, kl_loss, agent_mse = self.sim_core_rollout((state_h,state_c), act_emb.reshape(B, T, -1), agent_view, is_training=False)

      seq_gt, seq_gen, gen_mse = self.sim_core_rollout_seq((state_h,state_c), act_emb.reshape(B, T, -1), agent_view, t_i=50)
      rollout_t = seq_gt.shape[1]

    map_traj_batch = []
    for i in range(B):
      map_traj = plot_traj_step(map_top_down[i, 0].permute(1,2,0).cpu().numpy(), pos[i].cpu().numpy())
      map_traj = torch.tensor(map_traj).float().permute(0, 3, 1, 2)
      map_traj_batch.append(map_traj)

    map_traj_batch = torch.stack(map_traj_batch, dim=0)

    return map_traj_batch, map_pred.reshape(B, T-1, C, H, W), seq_gen, agent_view


  def generate(self, e, val_dl):

    val_log = dict(
      map_loss_total = 0.,
      pos_loss_total = 0.,
      rot_loss_total = 0.,
      kl_loss_total = 0.,
      model_loss_total = 0.,
      pos_acc_total = 0.,
      rot_acc_total = 0.,
      gen_mse_total = 0.,
    )

    self.enable_eval()

    with torch.no_grad():

      for i, traj in enumerate(val_dl):

        agent_view = traj['agent_view'].to(device) # B, T, C, H, W
        map_top_down = traj['map_top_down'].to(device) # B, T, C, H, W
        actions = traj['actions'].to(device) # B, T, 1
        pos = traj['pos'].to(device) # B, T, 2
        rot = traj['rot'].to(device) # B, T

        B, T, C, H, W = map_top_down.shape

        action_0 = actions.new_zeros(B, 1, 1)
        actions = torch.cat([action_0, actions[:, :-1]], dim=1) # B, T+1, 1
        act_emb = self.action_emb_layer(actions.long()).reshape(B, T, act_emb_dim)

        state_h, state_c = self.policy(agent_view, act_emb) # B, T+1, H

        us = torch.cat([state_h[:, 1:-1], state_c[:, 1:-1].tanh()], dim=-1)
        map_pred, map_loss = self.map_decoder(us.reshape(B*(T-1), -1).detach(),
                                              map_top_down[:, :-1].reshape(B*(T-1), C, H, W))

        pos_acc, pos_loss = self.pos_decoder(us.reshape(B*(T-1), -1).detach(), pos[:,:-1].reshape(B*(T-1), -1))
        rot_acc, rot_loss = self.rot_decoder(us.reshape(B*(T-1), -1).detach(), rot[:,:-1].reshape(B*(T-1)))

        agent_view_rec, agent_view_gen, agent_view_gt, constraint, kl_loss, agent_mse = self.sim_core_rollout((state_h,state_c), act_emb.reshape(B, T, -1), agent_view, is_training=False)

        seq_gt, seq_gen, gen_mse = self.sim_core_rollout_seq((state_h,state_c), act_emb.reshape(B, T, -1), agent_view)
        rollout_t = seq_gt.shape[1]

        model_loss = kl_loss + self.lambd * constraint

        val_log['map_loss_total'] += map_loss
        val_log['pos_loss_total'] += pos_loss
        val_log['rot_loss_total'] += rot_loss
        val_log['kl_loss_total'] += kl_loss
        val_log['model_loss_total'] += model_loss
        val_log['pos_acc_total'] += pos_acc
        val_log['rot_acc_total'] += rot_acc
        val_log['gen_mse_total'] += gen_mse

      map_traj = plot_traj_step(map_top_down[0, 0].permute(1,2,0).cpu().numpy(), pos[0].cpu().numpy())
      map_traj = torch.tensor(map_traj).float().permute(0, 3, 1, 2)

    self.writer.add_image("Validation/map_dec",
        vutils.make_grid(torch.cat([map_traj.reshape(T, C, H, W)[:-1],
                                    map_pred.detach().cpu().reshape(B, T-1, C, H, W)[0]], dim=0), nrow=100),
        e)

    self.writer.add_image("Validation/rollout_agent_views",
        vutils.make_grid(torch.cat([seq_gt[0], seq_gen.detach()[0]], dim=0), nrow=rollout_t),
        e)

    total_loss = 0.
    for n, l in val_log.items():
      self.writer.add_scalar("Validation/" + n, l / len(val_dl), e)
      total_loss += l.item()

    total_loss = total_loss / len(val_dl)

    if total_loss < self.best_loss:
      filename = os.path.join(self.model_save_path, 'best_model.pth')
      self.save(e, filename=filename)
      self.best_loss = total_loss

    self.save(e)

  def learn(self, e, train_dl):

    self.enable_training()

    for i, traj in enumerate(train_dl):

      self.update_count = e * len(train_dl) + i

      map_loss = 0.
      pos_loss = 0.
      model_loss = 0.

      traj_batch = []
      start = time.time()
      
      agent_view = traj['agent_view'].to(device) # B, T, C, H, W
      map_top_down = traj['map_top_down'].to(device) # B, T, C, H, W
      actions = traj['actions'].to(device) # B, T, 1
      pos = traj['pos'].to(device) # B, T, 2
      rot = traj['rot'].to(device) # B, T
      B, T, C, H, W = map_top_down.shape

      action_0 = actions.new_zeros(B, 1, 1)
      actions = torch.cat([action_0, actions[:, :-1]], dim=1) # B, T+1, 1
      act_emb = self.action_emb_layer(actions.long()).reshape(B, T, act_emb_dim)

      end = time.time()
      data_proc_time = end - start
      self.timer['data_proc_time'] = data_proc_time
      start = end

      state_h, state_c = self.policy(agent_view, act_emb) # B, T+1, H

      end = time.time()
      agent_core_time = end - start
      self.timer['agent_core_time'] = agent_core_time
      start = end

      us = torch.cat([state_h[:, 1:-1], state_c[:, 1:-1].tanh()], dim=-1)
      map_pred, map_loss = self.map_decoder(us.detach().reshape(B*(T-1), -1),
                                            map_top_down[:, :-1].reshape(B*(T-1), C, H, W))

      pos_acc, pos_loss = self.pos_decoder(us.detach().reshape(B*(T-1), -1), pos[:,:-1].reshape(B*(T-1), -1))
      rot_acc, rot_loss = self.rot_decoder(us.detach().reshape(B*(T-1), -1), rot[:,:-1].reshape(B*(T-1)))

      end = time.time()
      loss_time = end - start
      self.timer['pos_map_time'] = loss_time
      start = end

      if self.update_count % 100 == 0:

        agent_view_rec, agent_view_gen, agent_view_gt, constraint, kl_loss, agent_mse = self.sim_core_rollout((state_h,state_c), act_emb.reshape(B, T, -1), agent_view, is_training=False)

      else:

        agent_view_rec, agent_view_gt, constraint, kl_loss, agent_mse = self.sim_core_rollout((state_h,state_c), act_emb.reshape(B, T, -1), agent_view)

      model_loss = kl_loss + self.lambd * constraint

      with torch.no_grad():
        if self.update_count == 0:
          self.constrain_ma = constraint.clone()
        else:
          self.constrain_ma = 0.99 * self.constrain_ma.detach_() + (1 - 0.99) * constraint

        if self.update_count % 100 == 0:
          self.lambd = self.lambd * torch.clamp(torch.exp(self.constrain_ma), 0.9, 1.1)
          self.lambd = torch.clamp(self.lambd, 0., 1000.)

      end = time.time()
      loss_time = end - start
      self.timer['conv_draw_time'] = loss_time
      start = end

      total_loss = model_loss + map_loss + pos_loss + rot_loss

      self.optimizer.zero_grad()
      total_loss.backward()

      torch.nn.utils.clip_grad_norm_(self.params, grad_clip_norm)

      end = time.time()
      backward_time = end - start
      self.timer['backward_time'] = backward_time
      start = end

      if self.update_count % 500 == 0:
        self.log_grad()

      self.optimizer.step()
      #TODO fix lr scheduler
      lr = poly_lr_scheduler(self.optimizer, init_lr, self.update_count, max_frame)

      end = time.time()
      step_time = end - start
      self.timer['step_time'] = step_time
      start = end

      if self.update_count % 100 == 0:
        losses = dict(
          map_dec_loss = map_loss.detach().item(),
          position_loss = pos_loss.detach().item(),
          rot_loss = rot_loss.detach().item(),
          kl_loss = kl_loss.detach().item(),
          overshoot_loss = constraint.detach().item(),
          model_loss = model_loss.detach().item(),
          total_loss = total_loss.detach().item()
        )
        self.log_loss(losses)

        self.writer.add_scalar("learner/lr", lr, self.update_count)
        self.writer.add_scalar("geco/geco-lambda", self.lambd, self.update_count)
        self.writer.add_scalar("geco/geco-rec-mse", agent_mse.item(), self.update_count)
        self.writer.add_scalar("prediction/pos/batch_rot_acc", rot_acc, self.update_count)
        self.writer.add_scalar("prediction/pos/batch_pos_acc", pos_acc, self.update_count)

        # self.logger_q.put(
        #   ( SummaryType.VIDEO, 
        #     "learner/agent_views/agent_views",
        #     agent_view[:1])
        # )
        # self.logger_q.put(
        #   ( SummaryType.VIDEO, 
        #     "learner/agent_views/map_top_down",
        #     map_top_down[:1])
        # )

        self.writer.add_image("learner/map_dec",
            vutils.make_grid(torch.cat([map_top_down.reshape(B, T, C, H, W)[0, 50:-1],
                                        map_pred.detach().reshape(B, T-1, C, H, W)[0, 50:]], dim=0), nrow=50),
            self.update_count)

        self.writer.add_image("learner/overshooting/agent_views",
            vutils.make_grid(torch.cat([agent_view_gt[:6], agent_view_rec.detach()[:6], agent_view_gen.detach()[:6]], dim=0), nrow=6),
            self.update_count)

        self.writer.flush()

        print("=========== Losses ===========\n")
        for n, l in losses.items():
          print("{}:\t\t {:.6f}".format(n, l))

        print("=========== Timer Measure ===========\n")
        for key in self.timer.values:
          print("{}:\t\t {:.6f}".format(key, self.timer[key].avg))

  def log_grad(self):

    for tag, parm in self.policy.named_parameters():
      self.writer.add_scalar("gradient/policy/" + tag, parm.grad.data.norm(2), self.update_count)

    for tag, parm in self.sim_core.named_parameters():
      self.writer.add_scalar("gradient/sim_core/" + tag, parm.grad.data.norm(2), self.update_count)

    for tag, parm in self.conv_draw.named_parameters():
      self.writer.add_scalar("gradient/conv_draw/" + tag, parm.grad.data.norm(2), self.update_count)

    for tag, parm in self.pos_decoder.named_parameters():
      self.writer.add_scalar("gradient/pos_dec/" + tag, parm.grad.data.norm(2), self.update_count)

    for tag, parm in self.rot_decoder.named_parameters():
      self.writer.add_scalar("gradient/rot_dec/" + tag, parm.grad.data.norm(2), self.update_count)

    for tag, parm in self.map_decoder.named_parameters():
      self.writer.add_scalar("gradient/map_decoder/" + tag, parm.grad.data.norm(2), self.update_count)

  def log_loss(self, losses):

    for n, l in losses.items():
      self.writer.add_scalar("loss/"+n, l, self.update_count)

  def sim_core_rollout_seq(self, bt, actions, agent_view, t_i=None):
    B, T, C, H, W = agent_view.shape
    # sample starting point for trajectories
    if t_i == None:
      t_i = torch.randint(1, lu, (1, ), device=device)
    else:
      t_i = torch.tensor([t_i]).to(device)

    rollout_t = lu - t_i

    h, c = bt
    bt = h[:, t_i]
    bct = c[:, t_i]
    h_sim, c_sim = bt.reshape(B, hidden_dim), bct.reshape(B, hidden_dim)

    h_sim_list = []

    for j in range(rollout_t):
      h_sim, c_sim = self.sim_core(actions[:, t_i+j].reshape(B, act_emb_dim).float(), (h_sim, c_sim))
      h_sim_list.append(h_sim.reshape(B, hidden_dim))

    h_sim = torch.stack(h_sim_list, dim=1).reshape(B*rollout_t, hidden_dim) # B, t, h

    agent_view_gt = agent_view[:, t_i[0]:t_i[0]+rollout_t].reshape(B*rollout_t, C, H, W) # B, t, C, H, W, shift one step

    agent_view_gen = self.conv_draw(h_sim, agent_view_gt, is_training=False)

    with torch.no_grad():
      gen_mse = (agent_view_gen - agent_view_gt).pow(2).sum(dim=(1,2,3)).mean()

    return agent_view_gt.reshape(B, rollout_t, C, H, W), agent_view_gen.reshape(B, rollout_t, C, H, W), gen_mse

  def sim_core_rollout(self, bt, actions, agent_view, is_training=True):

    B, T, C, H, W = agent_view.shape
    # sample starting point for trajectories
    t_i = torch.randint(1, lu, (2, ), device=device)

    # sample points to eval likelihood
    delta_k_i_list = []
    for tt in t_i:
      delta_k_i = torch.randint(1, min(lu-tt+1, lo+1), (6, ), device=device)
      delta_k_i_list.append(delta_k_i)

    rollout_t = torch.stack(delta_k_i_list).max()

    h, c = bt
    bt = h[:, t_i]
    bct = c[:, t_i]
    h_sim, c_sim = bt.reshape(B*2, hidden_dim), bct.reshape(B*2, hidden_dim)

    h_sim_list_1 = []
    h_sim_list_2 = []

    action_sim = self.pad_actions(actions, t_i, rollout_t)

    for j in range(rollout_t):
      h_sim, c_sim = self.sim_core(action_sim[:, t_i+j].reshape(B*2, act_emb_dim).float(), (h_sim, c_sim))
      h1, h2 = h_sim.reshape(B, 2, hidden_dim).chunk(2, dim=1)
      h_sim_list_1.append(h1.reshape(B, hidden_dim))
      h_sim_list_2.append(h2.reshape(B, hidden_dim))

    h_sim_1 = torch.stack(h_sim_list_1, dim=1)[:, delta_k_i_list[0] - 1] # B, 6, h
    h_sim_2 = torch.stack(h_sim_list_2, dim=1)[:, delta_k_i_list[1] - 1]

    agent_view_1 = agent_view[:, t_i[0]:t_i[0]+rollout_t][:, delta_k_i_list[0]-1] # B, 6, C, H, W, shift one step
    agent_view_2 = agent_view[:, t_i[1]:t_i[1]+rollout_t][:, delta_k_i_list[1]-1]


    h_sim = torch.cat([h_sim_1, h_sim_2], dim=1).reshape(B*6*2, hidden_dim)
    agent_view_gt = torch.cat([agent_view_1, agent_view_2], dim=1).reshape(B*6*2, C, H, W)

    agent_view_rec, constraint, kl_loss = self.conv_draw(h_sim, agent_view_gt, is_training=True)

    with torch.no_grad():
      rec_mse = (agent_view_rec - agent_view_gt).pow(2).sum(dim=(1,2,3)).mean()

    if not is_training:

      with torch.no_grad():
        agent_view_gen = self.conv_draw(h_sim, agent_view_gt, is_training=False)

      return agent_view_rec, agent_view_gen, agent_view_gt, constraint, kl_loss, rec_mse

    else:
      return agent_view_rec, agent_view_gt, constraint, kl_loss, rec_mse

  def pad_actions(self, action, start_t, rollout_len):
    pad_len = (start_t + rollout_len).max() - 5
    if pad_len > 0:
      action_pad = action.new_zeros(action.shape[0], pad_len, act_emb_dim)
      action = torch.cat([action, action_pad], dim=1)

    return action

  def load(self, path):
    """ Load model parameters """
    checkpoint = torch.load(path)
    self.policy.load_state_dict(checkpoint["policy"])
    self.sim_core.load_state_dict(checkpoint["sim_core"])
    self.conv_draw.load_state_dict(checkpoint["conv_draw"])
    self.pos_decoder.load_state_dict(checkpoint["pos_decoder"])
    self.rot_decoder.load_state_dict(checkpoint["rot_decoder"])
    self.map_decoder.load_state_dict(checkpoint["map_decoder"])
    self.action_emb_layer.load_state_dict(checkpoint["action_emb_layer"])
    global_step = checkpoint['global_step']

  def save(self, global_step, maxnum=3, filename=None):
    if filename is None:
      filename = os.path.join(self.model_save_path, 'model_{:09}.pth'.format(global_step+1))

      with open(self.model_file_list, 'rb+') as f:
        model_list = pickle.load(f)
        if len(model_list) >= maxnum:
          if os.path.exists(model_list[0]):
            os.remove(model_list[0])
          del model_list[0]
        model_list.append(filename)
      with open(self.model_file_list, 'rb+') as f:
        pickle.dump(model_list, f)

    torch.save(
        {
            "policy": self.policy.state_dict(),
            "map_decoder": self.map_decoder.state_dict(),
            "rot_decoder": self.rot_decoder.state_dict(),
            "pos_decoder": self.pos_decoder.state_dict(),
            "conv_draw": self.conv_draw.state_dict(),
            "sim_core": self.sim_core.state_dict(),
            "action_emb_layer": self.action_emb_layer.state_dict(),
            "global_step": global_step,
        },
        filename,
    )
    print("save model to {}".format(filename))


def poly_lr_scheduler(optimizer, init_lr, num_frame, max_frame, lr_decay_iter=1,
                      power=1.):
  """Polynomial decay of learning rate
    :param init_lr is base learning rate
    :param iter is a current iteration
    :param lr_decay_iter how frequently decay occurs, default is 1
    :param max_iter is number of maximum iterations
    :param power is a polymomial power

  """
  if num_frame % lr_decay_iter or num_frame > max_frame:
    return optimizer

  ratio = min(2/3., num_frame/max_frame)
  lr = init_lr*(1 - ratio)**power
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr
  
  return lr

# Env Utils

In [28]:
from typing import List
import torch
import torch.multiprocessing as mp
import numpy as np
import os.path
import shutil
from PIL import Image, ImageDraw
from collections import deque, defaultdict, namedtuple
import pdb

# device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

class Counter:
  def __init__(self, init_val: int = 0):
    self._val = mp.RawValue("i", init_val)
    self._lock = mp.Lock()

  def increment(self):
    with self._lock:
      self._val.value += 1

  @property
  def value(self):
    with self._lock:
      return self._val.value

class Timer():
  def __init__(self, maxsize=20):
    self.values = deque(maxlen=maxsize)
    self.count = 0.
    self.sum = 0.

  def update(self, value):
    self.values.append(value)
    self.sum += value
    self.count += 1

  @property
  def avg(self):
    return np.mean(self.values)

  @property
  def global_avg(self):
    return self.sum / self.count

class MetricLogger():
  def __init__(self):
    self.values = defaultdict(Timer)

  def update(self, **kargs):
    for key, value in kargs.items():
      self.values[key].update(value)

  def __getitem__(self, key):
    return self.values[key]

  def __setitem__(self, key, value):
    self.values[key].update(value)

def dec_agent_pos_rot(pos_rot_vec):
  pos = pos_rot_vec[:3]
  rot = pos_rot_vec[3:]
  return pos, rot

def dec_building_pos_scale(pos_scale_vec):
  pos = pos_scale_vec[:27].reshape(9, 3)
  scale = pos_scale_vec[27:54].reshape(9, 3)
  
  return pos, scale

def comp_min_distance_angle(agent_pos, building_pos, building_scale):
  """
  agent_pos: (3,)
  building_pos: (9, 3)
  building_scale: (9, 3)
  
  compute agent and building distance on x-z plane.
  """
  
  num_building = np.sum(np.sum(building_pos, axis=1) != 0)
  
  dis = (agent_pos - building_pos)**2
  dis = np.sqrt(dis[:, 0] + dis[:, 2])
  
  building_radius = (building_scale / 2.)**2
  building_radius = np.sqrt(building_radius[:, 0] + building_radius[:, 2])
  
  dis = dis - building_radius + 5. # distance to the building surface
  
  dis_min_idx = np.argmin(dis[:num_building])
  
  min_building_pos = building_pos[dis_min_idx]
  dis_min = dis[dis_min_idx]
  
  x_offset = min_building_pos[0] - agent_pos[0]
  z_offset = min_building_pos[2] - agent_pos[2]
  cos_theta = x_offset / (np.sqrt(x_offset**2 + z_offset**2) + 1e-5)
  sin_theta = z_offset / (np.sqrt(x_offset**2 + z_offset**2) + 1e-5)
  
  theta = np.degrees(np.arcsin(sin_theta))
  if np.isnan(theta):
    pdb.set_trace()
  if cos_theta < 0.:
    theta = 180. - theta
    
  theta = theta % 360.
  return dis_min, theta

def dec_top_down_map(building_pos_scale_color):
  pos = building_pos_scale_color[:27].reshape(9, 3)
  scale = building_pos_scale_color[27:54].reshape(9, 3)
  color = building_pos_scale_color[54:].reshape(9, 3)
  
  num_building = np.sum(np.sum(pos, axis=1) != 0)

  height = scale[:num_building, 1]
  
  top_down_map = np.zeros((num_building, 64, 64, 3))
  
  for i in range(num_building):
    pos_i_x = (pos[i, 0] + 50) / 100 * 64
    pos_i_z = (pos[i, 2] + 50) / 100 * 64
    
    scale_i_x = scale[i, 0] / 100 * 64
    scale_i_z = scale[i, 2] / 100 * 64
    
    top_down_map[i][64 - int(pos_i_z + scale_i_z // 2) : 64 - int(pos_i_z - scale_i_z // 2 + 1),
                   int(pos_i_x - scale_i_x // 2) : int(pos_i_x + scale_i_x // 2 + 1)] = color[i]
    
  sort_idx = np.argsort(height)
  maps = np.zeros((64, 64, 3))

  for idx in sort_idx:
    
    overlap = (maps > 0.) * (top_down_map[idx] > 0.)
    
    maps = maps * (1. - overlap) + top_down_map[idx]
    
  return maps

def plot_traj(map_top_down, agent_pos):
  """
  agent_pos: T, 2
  """
  
  agent_pos = np.array(agent_pos) + 50
  
  im = Image.fromarray((map_top_down * 255).astype(np.uint8))
  draw = ImageDraw.Draw(im)
  num = agent_pos.shape[0]
  start = agent_pos[0]
  for i in range(1, num):
    end = agent_pos[i]
    draw.line(
      (int(start[0] / 100 * 64), int((100 - start[1]-1) / 100 * 64),
        int(end[0] / 100 * 64), int((100 - end[1]-1) / 100 * 64)),
      fill='red',
      width = 1,
    )
    start = end
    
  return np.asarray(im, dtype='int32')

def plot_traj_step(map_top_down, agent_pos):
  """
  agent_pos: T, 2
  map_top_down: H, W, C
  """
  
  agent_pos = np.array(agent_pos) + 50
  
  im = Image.fromarray((map_top_down * 255).astype(np.uint8))
  draw = ImageDraw.Draw(im)
  num = agent_pos.shape[0]
  start = agent_pos[0]
  map_traj_list = []
  for i in range(num):
    end = agent_pos[i]
    draw.line(
      (int(start[0] / 100 * 64), int((100 - start[1]-1) / 100 * 64),
        int(end[0] / 100 * 64), int((100 - end[1]-1) / 100 * 64)),
      fill='red',
      width = 1,
    )
    start = end
    map_traj_list.append(np.asarray(im, dtype='int32'))

  map_traj = np.array(map_traj_list).astype(np.float) / 255. # B, H, W, C
    
  return map_traj

def agent_building_angle(agent_pos, agent_rot, building_pos, building_scale):
  """
  agent_rot: scalar
  agent_pos: (3,)
  building_pos: (9, 3)
  building_scale: (9, 3)
  
  compute agent and building distance on x-z plane.
  """
  
  num_building = np.sum(np.sum(building_pos, axis=1) != 0)
  
  x_offset = building_pos[:num_building, 0] - agent_pos[0]
  z_offset = building_pos[:num_building, 2] - agent_pos[2]
  cos_theta = x_offset / (np.sqrt(x_offset**2 + z_offset**2) + 1e-5)
  sin_theta = z_offset / (np.sqrt(x_offset**2 + z_offset**2) + 1e-5)
  
  theta = np.degrees(np.arcsin(sin_theta))
  if np.isnan(theta).any():
    pdb.set_trace()
  for i in range(num_building):
    if cos_theta[i] < 0.:
      theta[i] = 180. - theta[i]
    
  theta = theta % 360.
  
  agent_rot_ = (-agent_rot + 90) % 360 # opposite rotation direction, agent facing y+ is 0 rotation
  angle_diff = abs(agent_rot_ - theta)

  return theta, angle_diff

# Random city dataset

In [29]:
from torch.utils.data import Dataset
from glob import glob
import gzip
import pickle

class RandomCityDataset(Dataset):
  def __init__(self, split='training'):
    traj_files = sorted(glob('./Datasets/50k/act*'))
    if split == 'training':
      start_idx = 0
      end_idx = 50
    elif split == 'val':
      start_idx = 50
      end_idx = 100
    elif split == 'testing':
      start_idx = 100
      end_idx = 150

    self.traj_files = traj_files[start_idx: end_idx]

  def __len__(self):
    return len(self.traj_files)

  def __getitem__(self, idx):
    traj_file = self.traj_files[idx]

    with gzip.open(traj_file, 'r') as f:
      traj = pickle.load(f)

    agent_view = traj.agent_view
    map_top_down = traj.map_top_down
    actions = traj.a
    pos = traj.pos
    rot = traj.rot

    sample = {
      'agent_view': agent_view,
      'map_top_down': map_top_down,
      'actions': actions,
      'pos': pos,
      'rot': rot
    }
    return sample

# train_unity_stored_data

In [30]:
Trajectory = namedtuple(
    "Trajectory",
    [
        "agent_view",
        "map_top_down",
        "a",
        "pos",
        "rot",
    ],
)

In [31]:
import torch
# from learner_memory import Learner
# from env_utils import *
# from modules_memory import *
# from unity_dataset import RandomCityDataset
from torch.utils.data import DataLoader
# from config import cfg_argparser
# import argparse
import os
import pdb

if __name__ == "__main__":

  log_path = os.path.join(log_path, exp_name + '-lr{}'.format(init_lr))
  if not os.path.exists:
    os.mkdirs(log_path)

  train_ds = RandomCityDataset()
  train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8)
  print("data_loader length: {}".format(len(train_dl)))

  val_ds = RandomCityDataset(split='val')
  val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=8)
  print("data_loader length: {}".format(len(val_dl)))
    
  policy = Policy()
  learner = Learner(1, policy, train_dl, val_dl)
  learner.run()

data_loader length: 5
data_loader length: 5

map_dec_loss:		 971.870300
position_loss:		 3.231124
rot_loss:		 3.020029
kl_loss:		 420.614136
overshoot_loss:		 829.927979
model_loss:		 830348.625000
total_loss:		 831326.750000

data_proc_time:		 0.027853
agent_core_time:		 0.100649
pos_map_time:		 0.025995
conv_draw_time:		 0.285321
backward_time:		 0.463342
step_time:		 0.062497
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000001.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000002.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000003.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000004.pth
save m

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000059.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000060.pth

map_dec_loss:		 181.363220
position_loss:		 2.716399
rot_loss:		 2.954719
kl_loss:		 1555.620728
overshoot_loss:		 56.015743
model_loss:		 57571.363281
total_loss:		 57758.394531

data_proc_time:		 0.019716
agent_core_time:		 0.044544
pos_map_time:		 0.003018
conv_draw_time:		 0.131515
backward_time:		 0.481531
step_time:		 0.021028
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000061.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000062.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000063.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000064.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000065.pth
save model to ./log/unity/mem/50k/exp

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000129.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000130.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000131.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000132.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000133.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000134.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000135.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000136.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000137.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_00

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000201.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000202.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000203.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000204.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000205.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000206.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000207.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000208.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000209.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000210.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/mod

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000280.pth

map_dec_loss:		 128.262268
position_loss:		 2.155519
rot_loss:		 2.359056
kl_loss:		 1129.895630
overshoot_loss:		 13.312765
model_loss:		 14442.660156
total_loss:		 14575.437500

data_proc_time:		 0.020278
agent_core_time:		 0.046138
pos_map_time:		 0.003298
conv_draw_time:		 0.131290
backward_time:		 0.479063
step_time:		 0.021150
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000281.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000282.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000283.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000284.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000285.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000286.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000287.pth
save model to ./log/unity/mem/50


map_dec_loss:		 83.421837
position_loss:		 1.919315
rot_loss:		 2.258483
kl_loss:		 1159.153687
overshoot_loss:		 11.660733
model_loss:		 12819.886719
total_loss:		 12907.486328

data_proc_time:		 0.019840
agent_core_time:		 0.045072
pos_map_time:		 0.003086
conv_draw_time:		 0.132299
backward_time:		 0.484500
step_time:		 0.022494
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000361.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000362.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000363.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000364.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000365.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000366.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000367.pth
save model to ./log/unity/mem/50k/exp-

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000438.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000439.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000440.pth

map_dec_loss:		 43.704590
position_loss:		 1.524694
rot_loss:		 1.698326
kl_loss:		 1222.848511
overshoot_loss:		 8.190513
model_loss:		 9413.361328
total_loss:		 9460.289062

data_proc_time:		 0.019350
agent_core_time:		 0.045476
pos_map_time:		 0.003068
conv_draw_time:		 0.131312
backward_time:		 0.485810
step_time:		 0.020636
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000441.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000442.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000443.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000444.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000445.pth
save model to ./log/unity/mem/50k/ex

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000520.pth

map_dec_loss:		 22.649168
position_loss:		 1.432682
rot_loss:		 1.606212
kl_loss:		 1050.640869
overshoot_loss:		 6.105153
model_loss:		 7155.793945
total_loss:		 7181.481934

data_proc_time:		 0.019936
agent_core_time:		 0.045541
pos_map_time:		 0.003115
conv_draw_time:		 0.132498
backward_time:		 0.483903
step_time:		 0.021730
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000521.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000522.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/best_model.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000523.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000524.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000525.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000526.pth
save model to ./log/unity/mem/50k/exp-lr0

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000600.pth

map_dec_loss:		 14.167305
position_loss:		 1.055043
rot_loss:		 1.239987
kl_loss:		 1066.264160
overshoot_loss:		 6.085365
model_loss:		 7151.629395
total_loss:		 7168.091797

data_proc_time:		 0.020201
agent_core_time:		 0.045386
pos_map_time:		 0.003139
conv_draw_time:		 0.133338
backward_time:		 0.483004
step_time:		 0.022717
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000601.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000602.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000603.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000604.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000605.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000606.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000607.pth
save model to ./log/unity/mem/50k/ex

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000681.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000682.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000683.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000684.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000685.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000686.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000687.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000688.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000689.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000690.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000691.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000692.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000762.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000763.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000764.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000765.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000766.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000767.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000768.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000769.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000770.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000771.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000772.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000773.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000843.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000844.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000845.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000846.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000847.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000848.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000849.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000850.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000851.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000852.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000853.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000854.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model

save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000927.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000928.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000929.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000930.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000931.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000932.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000933.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000934.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000935.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000936.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000937.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model_saved/model_000000938.pth
save model to ./log/unity/mem/50k/exp-lr0.0002/model

In [117]:
env.close()

AttributeError: 'str' object has no attribute 'close'

In [118]:
display.stop()

$DISPLAY was already unset.


<pyvirtualdisplay.display.Display at 0x7f8c30147d90>

# Config file

In [40]:
env.close()

AttributeError: 'str' object has no attribute 'close'

In [41]:
display.stop()

$DISPLAY was already unset.


<pyvirtualdisplay.display.Display at 0x7f8c3673aed0>

In [68]:
import pickle
import gzip


# with open('./Datasets/50k/actor-26-traj-1.pkl', 'rb') as f:
#     data = pickle.load()

import pickle

# import class_def
# from class_def import Foo # Import Foo into main_module's namespace explicitly

if __name__=='__main__':
    with gzip.open('./Datasets/50k/actor-26-traj-1.pkl', 'r') as f:
          traj = pickle.load(f)

AttributeError: Can't get attribute 'Trajectory' on <module '__main__'>