In [1]:
import glob, os
import numpy as np
import math, random

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F

from geometry.model import Model, combine_observations, get_mesh
from geometry.utils.visualisation import illustrate_points, illustrate_mesh, illustrate_voxels
from geometry.voxel_grid import VoxelGrid

from rl.environment import Environment, CombiningObservationsWrapper
from rl.environment import StepPenaltyRewardWrapper, DepthMapWrapper
from rl.environment import VoxelGridWrapper, VoxelWrapper
from rl.environment import FrameStackWrapper, ActionMaskWrapper
from rl.environment import MeshReconstructionWrapper
from rl.validation import validate
from rl.utils import build_epsilon_func, plot


from rl.dqn import CnnDQN, CnnDQNA, VoxelDQN
from rl.agent import DQNAgent, DDQNAgent
from rl.replay_buffer import DiskReplayBuffer, ReplayBuffer


# !conda install -c conda-forge pyembree
# !conda install -c conda-forge igl
# !pip install Cython
# !pip install gym

In [2]:
# Hardware: preferred CUDA device index. It is clamped to the GPUs that are
# actually visible, so the notebook also runs on machines with fewer than
# three GPUs; falls back to the CPU when CUDA is unavailable.
gpu_index = 2
if torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(min(gpu_index, torch.cuda.device_count() - 1)))
else:
    device = torch.device('cpu')

# Reproducibility: seed the Python, NumPy and PyTorch global RNGs. Components
# that keep their own private RNGs (if any) are not covered by this.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Output directory for checkpoints ("last-*.pt", "best-*.pt") and the loss plot.
experiment_save_path = "./models/abc-vddqn-final_rec_only/"

# Data: training / validation model directories and the number of view points
# passed to the Environment.
train_dataset_path = "./data/1kabc/simple/train/"
val_dataset_path = "./data/1kabc/simple/val/"
number_of_view_points = 100

# Observation settings.
num_stack = 4               # passed to FrameStackWrapper
reconstruction_depth = 7    # mesh reconstruction depth
grid_size = 64              # passed to VoxelGridWrapper
raycast_resolution = 1024   # image_size of the base Environment

# Optimisation.
optim_name = "Adam"         # NOTE(review): not referenced in the visible cells
learning_rate = 0.0004
weight_decay = 0.01
buffer_capacity = 100000    # replay buffer capacity
epsilon_decay = 10000       # passed to build_epsilon_func
batch_size = 256
start_frame = 0             # the training loop runs frames start_frame + 1 .. num_frames
num_frames = 150000

# Loop cadence (all in frames).
log_interval = 100          # redraw loss.png
save_interval = 500         # write a "last-<frame>.pt" checkpoint
val_interval = 1000         # run validation and possibly save a "best-*" checkpoint
train_interval = 10         # one gradient update every N frames
max_novp = 50               # an episode is cut once it exceeds this many view points

# True: depth-map observation pipeline; False: voxel pipeline.
use_depth_observations = False

In [3]:
# Base environment over the training models; image_size and the number of view
# points come from the config cell.
env = Environment(models_path=train_dataset_path,
                  image_size=raycast_resolution,
                  number_of_view_points=number_of_view_points)

# Each wrapper wraps the env built so far, so the order of the lines below is
# the order in which the wrappers are nested.
if use_depth_observations:
    # Depth-map observation pipeline.
    # NOTE(review): no ActionMaskWrapper here, yet the training loop unpacks a
    # mask from reset()/step() — confirm this path still works.
    env = CombiningObservationsWrapper(env)
    env = StepPenaltyRewardWrapper(env, weight=1.0)
    env = DepthMapWrapper(env)
else:
    # Voxel observation pipeline. Use the configured reconstruction depth rather
    # than a hard-coded literal so the config cell stays the single source of truth.
    env = MeshReconstructionWrapper(env, reconstruction_depth=reconstruction_depth,
                                    do_step_reconstruction=False,
                                    scale_factor=8)

    env = VoxelGridWrapper(env, grid_size=grid_size)
    env = CombiningObservationsWrapper(env)
    env = VoxelWrapper(env, occlusion_reward=False)
    env = StepPenaltyRewardWrapper(env, weight=1.0)
    env = FrameStackWrapper(env, num_stack=num_stack, lz4_compress=False)
    # Presumably supplies the action mask that reset()/step() return and
    # agent.act consumes — verify in rl.environment.
    env = ActionMaskWrapper(env)



In [5]:
# Checkpoint the agent is warm-started from (an earlier "last-*" snapshot of
# this experiment).
# NOTE(review): start_frame is still 0, so the epsilon schedule and the frame
# counter restart from scratch when resuming from this checkpoint — confirm.
prev_ckpt = "./models/abc-vddqn-final_rec_only/last-3000.pt"

# Double-DQN agent sized from the wrapped env's observation and action spaces.
agent = DDQNAgent(
    env.observation_space.shape,
    env.action_space.n,
    prev_ckpt=prev_ckpt,
    device=device,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

# Disk-backed replay buffer stored under buffer_voxels/. overwrite=True
# presumably discards a buffer left there by an earlier run — TODO confirm.
replay_buffer = DiskReplayBuffer(
    capacity=buffer_capacity,
    overwrite=True,
    location="buffer_voxels/",
    num_actions=env.action_space.n,
    observation_dtype=env.observation_space.dtype,
    observation_shape=env.observation_space.shape,
)

# Exploration schedule: maps a frame index to the epsilon handed to agent.act.
epsilon_by_frame = build_epsilon_func(epsilon_decay=epsilon_decay)

### Training

In [None]:
# Checkpoints and the loss plot are written here; create the directory if needed.
os.makedirs(experiment_save_path, exist_ok=True)

# Training history: per-update losses, per-episode returns, per-episode view-point counts.
losses, all_rewards, all_nofs = [], [], []
episode_reward = 0
nof_vp = 0  # view points chosen so far in the current episode
# Best validation score so far: the number of view points validation needed
# (a snapshot is kept whenever it decreases).
best_metric = number_of_view_points

state, _, mask = env.reset()
for frame_idx in range(start_frame + 1, num_frames + 1):
    # Choose an action given the current state, action mask and exploration rate.
    epsilon = epsilon_by_frame(frame_idx)
    action = agent.act(state, mask, epsilon)

    # `mask` is rebound to the mask returned with next_state, so the buffer
    # stores the next-state mask alongside the transition.
    next_state, reward, done, _, mask = env.step(action)
    replay_buffer.push(state, action, reward, next_state, done, mask)

    state = next_state
    episode_reward += reward
    nof_vp += 1

    # End the episode when the env reports done or the view-point budget is
    # exceeded. A budget cut is not recorded as `done` in the buffer (the
    # transition was already pushed above).
    # NOTE(review): `>` allows max_novp + 1 view points per episode — confirm vs `>=`.
    if done or nof_vp > max_novp:
        final_reward = env.final_reward()
        episode_reward += final_reward
        print("Frame: ", frame_idx, "Number of View Points: ", nof_vp, final_reward)
        print()

        state, _, mask = env.reset()
        all_rewards.append(episode_reward)
        all_nofs.append(nof_vp)
        episode_reward = 0
        nof_vp = 0

    # This loop pushes exactly one transition per frame, so frame_idx - start_frame
    # is the number pushed in this run. Counting from start_frame (not from 0)
    # prevents sampling before a full batch exists when resuming with
    # start_frame > 0 (the buffer is created with overwrite=True, presumably empty).
    if frame_idx % train_interval == 0 and frame_idx - start_frame > batch_size:
        batch = replay_buffer.sample(batch_size)
        state_, action_, reward_, next_state_, done_, mask_ = batch
        loss = agent.compute_td_loss(state_, action_, reward_, next_state_, done_, mask_, frame_idx)
        # NOTE(review): if compute_td_loss returns a tensor, consider storing
        # loss.item() so the history does not keep tensors alive.
        losses.append(loss)

    if frame_idx % log_interval == 0:
        save_path = os.path.join(experiment_save_path, 'loss.png')
        plot(save_path, frame_idx, all_rewards, all_nofs, losses)

    if frame_idx % save_interval == 0:
        save_path = os.path.join(experiment_save_path, 'last-{}.pt'.format(frame_idx))
        # Write the new checkpoint before deleting older ones, so a failed or
        # interrupted save never leaves the experiment without a "last-*" checkpoint.
        torch.save(agent.model, save_path)
        for stale_ckpt in glob.glob(os.path.join(experiment_save_path, "last-*.pt")):
            if os.path.abspath(stale_ckpt) != os.path.abspath(save_path):
                os.remove(stale_ckpt)

    if frame_idx % val_interval == 0:
        # This rebinds `reward`; harmless, because env.step rebinds it on the next frame.
        reward, hausdorff, novp = validate(agent, models_path=val_dataset_path)
        print("Validation metrics: ", reward, hausdorff, novp)
        # Keep a snapshot whenever validation needs fewer view points than before.
        if novp < best_metric:
            best_metric = novp
            save_path = os.path.join(experiment_save_path,
                                     'best-{}-{:.2f}.pt'.format(frame_idx, best_metric))
            torch.save(agent.model, save_path)
