In [5]:
import glob, os
import numpy as np
import math, random

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F

from geometry.model import Model, combine_observations, get_mesh
from geometry.utils.visualisation import illustrate_points, illustrate_mesh, illustrate_voxels
from geometry.voxel_grid import VoxelGrid

from rl.environment import Environment, CombiningObservationsWrapper
from rl.environment import StepPenaltyRewardWrapper, DepthMapWrapper
from rl.environment import VoxelGridWrapper, VoxelWrapper
from rl.environment import FrameStackWrapper, ActionMaskWrapper
from rl.environment import MeshReconstructionWrapper
from rl.validation import validate
from rl.utils import build_epsilon_func, plot


from rl.dqn import CnnDQN, CnnDQNA, VoxelDQN
from rl.agent import DQNAgent, DDQNAgent
from rl.replay_buffer import DiskReplayBuffer, ReplayBuffer


# !conda install -c conda-forge pyembree
# !conda install -c conda-forge igl
# !pip install Cython
# !pip install gym

In [2]:
# Experiment configuration: every tunable constant used by the cells below.

# Use the third CUDA device when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

# Directory that receives the 'last-<frame>.pt' checkpoints and 'loss.png'.
experiment_save_path = "./models/abc-vddqn-occl-rew-fix-rew-fix-fine/"

# train_dataset_path is passed to Environment(models_path=...).
# val_dataset_path is only referenced by the commented-out validation step.
train_dataset_path = "./data/1kabc/simple/train/"
val_dataset_path = "./data/1kabc/simple/val/"
# Passed to Environment(number_of_view_points=...); also the initial best_metric.
number_of_view_points = 100

# num_stack -> FrameStackWrapper, grid_size -> VoxelGridWrapper,
# raycast_resolution -> Environment(image_size=...).
# NOTE(review): reconstruction_depth is not read by any visible cell; the
# MeshReconstructionWrapper call below hard-codes reconstruction_depth=8.
num_stack = 4
reconstruction_depth = 7
grid_size = 64
raycast_resolution = 1024

# NOTE(review): optim_name is not read by any visible cell.
optim_name = "Adam"
# learning_rate and weight_decay are forwarded to DDQNAgent.
learning_rate = 0.0004
weight_decay = 0.01
# Capacity of the DiskReplayBuffer.
buffer_capacity = 100000
# Forwarded to build_epsilon_func(epsilon_decay=...).
epsilon_decay = 10000
# Replay sample size; training also waits until frame_idx > batch_size.
batch_size = 256
# The training loop iterates frame_idx over start_frame + 1 .. num_frames (inclusive).
start_frame = 0
num_frames = 150000

# Periods (in frames) for plotting, checkpointing and TD updates.
# val_interval is only referenced by the commented-out validation step.
log_interval = 100
save_interval = 500
val_interval = 1000
train_interval = 10
# An episode is cut off once its step count exceeds this value (nof_vp > max_novp).
max_novp = 50

In [3]:
# Build the training environment: a base Environment that is then wrapped,
# one wrapper at a time, each wrapper receiving the previously built env.
# The order of the wrappers below is therefore significant.
env = Environment(models_path=train_dataset_path,
#                   similarity_threshold=similarity_threshold,
                  image_size=raycast_resolution,
                  number_of_view_points=number_of_view_points)
# env = CombiningObservationsWrapper(env)
# env = StepPenaltyRewardWrapper(env, weight=1.0)
# env = DepthMapWrapper(env)

# NOTE(review): reconstruction_depth is hard-coded to 8 here although the
# config cell defines reconstruction_depth = 7 -- confirm which is intended.
env = MeshReconstructionWrapper(env, reconstruction_depth=8,
                                do_step_reconstruction=False, 
                                scale_factor=8)

env = VoxelGridWrapper(env, grid_size=grid_size)
env = CombiningObservationsWrapper(env)
env = VoxelWrapper(env, occlusion_reward=True)
env = StepPenaltyRewardWrapper(env, weight=1.0)
env = FrameStackWrapper(env, num_stack=num_stack, lz4_compress=False)
# Outermost wrapper; the training loop unpacks a mask from reset()/step() of this env.
env = ActionMaskWrapper(env)



In [7]:
# Agent sized from the wrapped environment's observation shape and action count.
agent = DDQNAgent(
    env.observation_space.shape,
    env.action_space.n,
    device=device,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

# Replay buffer stored under "buffer_voxels/"; its layout is derived from the
# environment's observation/action spaces.
replay_buffer = DiskReplayBuffer(
    capacity=buffer_capacity,
    overwrite=True,
    location="buffer_voxels/",
    num_actions=env.action_space.n,
    observation_dtype=env.observation_space.dtype,
    observation_shape=env.observation_space.shape,
)

# Maps a frame index to the exploration rate passed to agent.act().
epsilon_by_frame = build_epsilon_func(
    epsilon_decay=epsilon_decay,
)

In [6]:
# model = torch.load("./models/abc-vdqn-occl-sgd/last-72000.pt")
# agent.model = model.to(device)
# start_frame = 72000

### Training

1. + Validation
2. + Saver
3. + Save loss plot
4. Config saver
5. Starter

In [None]:
# Training loop: act in the environment, store transitions in the replay
# buffer, periodically run a TD update, plot progress and checkpoint the model.

# Create the experiment directory if needed (no-op when it already exists).
os.makedirs(experiment_save_path, exist_ok=True)
# To resume a run, load the last checkpoint and set start_frame accordingly, e.g.:
#     model = torch.load(os.path.join(experiment_save_path, "last-{}.pt".format(start_frame))).to(device)


losses, all_rewards, all_nofs = [], [], []
episode_reward = 0
nof_vp = 0  # number of view points taken in the current episode
best_metric = number_of_view_points  # only used by the commented-out validation step

state, _, mask = env.reset()
for frame_idx in range(start_frame + 1, num_frames + 1):
    epsilon = epsilon_by_frame(frame_idx)
    action = agent.act(state, mask, epsilon)

    # `mask` is rebound to the mask returned by env.step, i.e. the one that
    # accompanies next_state; that is the mask stored with the transition.
    next_state, reward, done, _, mask = env.step(action)
    replay_buffer.push(state, action, reward, next_state, done, mask)

    state = next_state
    episode_reward += reward
    nof_vp += 1

    # End the episode when the env reports done or the step budget is exceeded.
    # NOTE(review): `>` allows max_novp + 1 steps per episode -- confirm intent.
    if done or nof_vp > max_novp:
        # if done: final_reward = env.final_reward()
        print("Frame: ", frame_idx, "Number of View Points: ", nof_vp)
        print()

        state, _, mask = env.reset()
        all_rewards.append(episode_reward)
        all_nofs.append(nof_vp)
        episode_reward = 0
        nof_vp = 0

    # NOTE(review): the warm-up check uses the frame counter rather than the
    # buffer fill level; when resuming with start_frame > batch_size the buffer
    # may hold fewer than batch_size transitions -- confirm sample() copes.
    if frame_idx % train_interval == 0 and frame_idx > batch_size:
        batch = replay_buffer.sample(batch_size)
        state_, action_, reward_, next_state_, done_, mask_ = batch
        loss = agent.compute_td_loss(state_, action_, reward_, next_state_, done_, mask_, frame_idx)
        losses.append(loss)

    if frame_idx % log_interval == 0:
        save_path = os.path.join(experiment_save_path, 'loss.png')
        plot(save_path, frame_idx, all_rewards, all_nofs, losses)

    if frame_idx % save_interval == 0:
        checkpoint_name = 'last-{}.pt'.format(frame_idx)
        save_path = os.path.join(experiment_save_path, checkpoint_name)
        # Write the new checkpoint BEFORE deleting the old ones, so that a
        # failed save (disk full, interrupt, ...) never leaves the run with no
        # checkpoint at all.
        stale_checkpoints = [
            f for f in glob.glob(os.path.join(experiment_save_path, "last-*.pt"))
            if os.path.basename(f) != checkpoint_name
        ]
        torch.save(agent.model, save_path)
        for f in stale_checkpoints:
            os.remove(f)

#     if frame_idx % val_interval == 0:
#         reward, hausdorff, novp = validate(agent, models_path=val_dataset_path)
#         print ("Validation metrics: ", reward, hausdorff, novp)
#         if novp < best_metric:
#             best_metric = novp
#             save_path = os.path.join(experiment_save_path,
#                                  'best-{}-{:.2f}.pt'.format(frame_idx, best_metric))
#             torch.save(agent.model, save_path)


Action:  13 (random)
0.45813842482100237 0.3748398120461341 0.8157838530542503
Action:  88 (random)
0.4888305489260143 0.08885091841093551 0.9046347714651858
Action:  72 (random)
0.4953699284009547 0.03865869286629653 0.9432934643314823
Action:  39 (random)
0.4635799522673031 0.004805638615976049 0.9480991029474584
Action:  9 (random)
0.42935560859188543 0.005339598462195672 0.953438701409654
Frame:  5 Number of View Points:  5

Action:  46 (random)
0.39384288747346075 0.02631578947368418 0.503385018563005
Action:  4 (random)
0.46409120280624017 0.11913081458833807 0.6225158331513431
Action:  91 (random)
0.44969075971568356 0.0181262284341559 0.640642061585499
Action:  2 (random)
0.3932428690113542 0.0021838829438742824 0.6428259445293732
Action:  47 (random)
0.4263823502261608 0.2916575671544005 0.9344835116837737
Action:  93 (random)
0.4678759346441429 0.020419305525223863 0.9549028172089976
Frame:  11 Number of View Points:  6

Action:  88 (random)
0.4977518466973557 0.4844596521243

  q_value = F.softmax(q_value)


Action:  11
0.4399195575666164 0.00019269679159839015 0.9484536082474226
Action:  22 (random)
0.4217840982546865 0.0 0.9484536082474226
Action:  51 (random)
0.4419665302018243 0.0 0.9484536082474226
Action:  73 (random)
0.4658119658119658 0.017824453222853864 0.9662780614702765
Frame:  61 Number of View Points:  9

Action:  22 (random)
0.4551033591731266 0.02349869451697123 0.505781424841477
Action:  42 (random)
0.2603359173126615 0.22417008578888475 0.7299515106303618
Action:  96 (random)
0.4418604651162791 0.23013800820589336 0.9600895188362552
Frame:  64 Number of View Points:  3

Action:  80 (random)
0.4161814612722917 0.13965287049399205 0.5780373831775701
Action:  87 (random)
0.4295648123502795 0.14592790387182908 0.7239652870493992
Action:  37 (random)
0.3304498269896194 0.13805073431241655 0.8620160213618158
Action:  88 (random)
0.4189512909236093 0.047530040053404554 0.9095460614152203
Action:  23 (random)
0.42368412297045516 0.05687583444592792 0.9664218958611482
Frame:  69 N

0.49958696305196754 0.0 0.7146427249166976
Action:  86 (random)
0.4993992189846801 0.0 0.7146427249166976
Action:  34 (random)
0.4993992189846801 0.16290262865605332 0.8775453535727509
Action:  8 (random)
0.3587038149594473 0.0 0.8775453535727509
Action:  42 (random)
0.4567813157104236 0.0 0.8775453535727509
Action:  74 (random)
0.5004130369480324 0.12245464642724913 1.0
Frame:  157 Number of View Points:  7

Action:  0 (random)
0.48921496019255695 0.02312727272727272 0.5051636363636364
Action:  19 (random)
0.49319570449916683 0.005236363636363595 0.5104
Action:  17 (random)
0.5043047583780781 0.449309090909091 0.959709090909091
Frame:  160 Number of View Points:  3

Action:  88 (random)
0.46398578811369506 0.17366841710427605 0.6462865716429107
Action:  72 (random)
0.4699612403100775 0.0026256564141035055 0.6489122280570142
Action:  46 (random)
0.331233850129199 0.15828957239309838 0.8072018004501126
Action:  40 (random)
0.45591085271317827 0.06451612903225801 0.8717179294823706
Actio

0.2338888888888889 0.00037814331631680353 0.9464927207411609
Action:  8 (random)
0.4616358024691358 0.0 0.9464927207411609
Action:  44 (random)
0.31685185185185183 0.0 0.9464927207411609
Action:  13 (random)
0.5004938271604938 0.0 0.9464927207411609
Action:  53 (random)
0.2601851851851852 0.05275099262620542 0.9992437133673663
Frame:  255 Number of View Points:  7

Action:  15 (random)
0.4361518550474547 0.3814705425151562 0.8318047567231462
Action:  73 (random)
0.405152224824356 0.12451422353489816 0.9563189802580444
Frame:  257 Number of View Points:  2

Action:  79 (random)
0.31854558107167713 0.20357270453733478 0.4946052161486245
Action:  40 (random)
0.2964509394572025 0.004072883172561603 0.4986780993211861
Action:  50 (random)
0.35180352586406866 0.05687745623436946 0.5555555555555556
Action:  16 (random)
0.32422871723498026 0.029010360843158223 0.5845659163987138
Action:  5 (random)
0.30062630480167013 0.049446230796713175 0.634012147195427
Action:  91 (random)
0.32956390628624

0.44548713622943903 0.01671619613670139 0.9559187716691432
Frame:  352 Number of View Points:  6

Action:  5 (random)
0.49979018044481743 0.40572425112843663 0.6119203939269594
Action:  53 (random)
0.25400755350398657 0.22958555601148956 0.841505949938449
Action:  37 (random)
0.4873268988669744 0.002667213787443501 0.8441731637258925
Action:  47 (random)
0.22203105329416703 0.006257693885925342 0.8504308576118178
Action:  32 (random)
0.4996642887117079 0.0002051702913418163 0.8506360279031596
Action:  35 (random)
0.4970205623164079 0.012617972917521536 0.8632540008206812
Action:  51 (random)
0.1907679395719681 0.02421009437833399 0.8874640951990151
Action:  2 (random)
0.48166177087704576 0.00030775543701277996 0.8877718506360279
Action:  87 (random)
0.499832144355854 0.11202297907263026 0.9997948297086582
Frame:  361 Number of View Points:  9

Action:  65 (random)
0.3416645027265645 0.07824751434807209 0.4360197235470051
Action:  79 (random)
0.4331991690470008 0.1552825155605852 0.5913

0.44953496383052016 0.10869759542325907 0.9472602127469384
Action:  78 (random)
0.28839131932483636 0.0004469473496022669 0.9477071600965407
Action:  5 (random)
0.41911815363417154 8.938946992043117e-05 0.9477965495664611
Action:  1 (random)
0.378022735101619 0.0002681684097612935 0.9480647179762224
Action:  93 (random)
0.4370651050637272 0.03620273531777962 0.984267453294002
Frame:  452 Number of View Points:  8

Action:  77 (random)
0.4890513299413924 0.4618455254363164 0.8148904567396955
Action:  37 (random)
0.5065370000644039 0.15066839955440026 0.9655588562940958
Frame:  454 Number of View Points:  2

Action:  7 (random)
0.46322904267981635 0.10227151541004587 0.5807827663431802
Action:  84 (random)
0.46820268661792214 0.35576410365788624 0.9365468700010664
Action:  30 (random)
0.4702006461486142 0.06046710035192493 0.9970139703529913
Frame:  457 Number of View Points:  3

Action:  73 (random)
0.455183746638781 0.05223289994347091 0.5123798756359526
Action:  19 (random)
0.44597151

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Action:  95 (random)
0.4277647193716677 0.3375855188141391 0.935005701254276
Action:  19 (random)
0.4023390471822508 0.0010689851767388125 0.9360746864310148
Action:  62 (random)
0.43435762196869804 0.012685290763968182 0.948759977194983
Action:  16 (random)
0.48007796823940835 0.025014253135689835 0.9737742303306728
Frame:  504 Number of View Points:  5

Action:  13 (random)
0.49992003838157684 0.07144613871126415 0.49495818986719137
Action:  88 (random)
0.4996401727170958 0.48536645351697 0.9803246433841614
Frame:  506 Number of View Points:  2

Action:  73 (random)
0.3862988196149885 0.1589419926768163 0.5599826556176527
Action:  42 (random)
0.3417373049195912 0.14391019464251298 0.7038928502601657
Action:  63 (random)
0.3205260239245821 0.01753709770668732 0.721429947966853
Action:  82 (random)
0.3025825873405688 0.09149161688186547 0.8129215648487185
Action:  88 (random)
0.39984552008238927 0.0049624205049142445 0.8178839853536327
Action:  81 (random)
0.39406242573080885 0.0038543

0.4436281279345965 6.772773450725023e-05 0.8945479173721639
Action:  13 (random)
0.3587526929238248 6.772773450725023e-05 0.8946156451066711
Action:  32 (random)
0.4165055515660388 0.0 0.8946156451066711
Action:  62 (random)
0.4522178644423576 0.05838130714527601 0.9529969522519471
Frame:  603 Number of View Points:  11

Action:  33 (random)
0.4966230301008922 0.4262334828464056 0.6457752454843011
Action:  22 (random)
0.4958725923455349 0.017577888228876293 0.6633531337131774
Action:  30 (random)
0.4954973734678563 0.0008485877076008785 0.6642017214207783
Action:  53 (random)
0.2564412574001501 0.22208752576069823 0.8862892471814765
Action:  66 (random)
0.4927457683648795 0.08158564674506008 0.9678748939265366
Frame:  608 Number of View Points:  5

Action:  9 (random)
0.5004140630881578 0.16624930439621588 0.6749675384900761
Action:  72 (random)
0.5004517051870813 0.0 0.6749675384900761
Action:  19 (random)
0.4995482948129188 0.16351326284548318 0.8384808013355592
Action:  34 (random)


0.5007067970314525 0.03410997204100652 0.8067101584342964
Action:  73 (random)
0.5003533985157262 0.0031686859273065693 0.809878844361603
Action:  98 (random)
0.4960537165743904 0.0014911463187325946 0.8113699906803356
Action:  84 (random)
0.49634821533749557 0.002423112767940272 0.8137931034482758
Action:  26 (random)
0.5016491930733891 0.1832246039142591 0.9970177073625349
Frame:  703 Number of View Points:  9

Action:  41 (random)
0.42478786320390843 0.14904386951631043 0.46147356580427445
Action:  5 (random)
0.4764721007971201 0.05821147356580425 0.5196850393700787
Action:  54 (random)
0.38313191051684237 0.0655230596175479 0.5852080989876266
Action:  9 (random)
0.4543584469015171 0.16085489313835766 0.7460629921259843
Action:  21 (random)
0.4429159166880946 0.003655793025871712 0.749718785151856
Action:  94 (random)
0.4768578040627411 0.15466816647919013 0.9043869516310461
Action:  19 (random)
0.4502442787348933 0.02474690663667045 0.9291338582677166
Action:  86 (random)
0.4706865

0.4410574842412268 0.002523240371845925 0.951394422310757
Frame:  798 Number of View Points:  8

Action:  82 (random)
0.27062717084999016 0.06602893140457303 0.30759060507077307
Action:  70 (random)
0.3047381872992988 0.2530720174210608 0.5606626224918339
Action:  99 (random)
0.313749262730192 0.19272048530097996 0.7533831077928138
Action:  21 (random)
0.38816436201585947 0.11829211385907601 0.8716752216518898
Action:  78
0.26656399501933287 0.001166588894073728 0.8728418105459635
Action:  54 (random)
0.39520938462546695 0.0024109503810857413 0.8752527609270493
Action:  95 (random)
0.34425584900714334 0.060273759527142645 0.9355265204541919
Action:  40 (random)
0.27986761911003344 0.0038886296469123893 0.9394151501011043
Action:  5 (random)
0.3180745789370208 0.002644268159900487 0.9420594182610048
Action:  59 (random)
0.40844747362212463 0.0012443614870120134 0.9433037797480168
Action:  91 (random)
0.30349302051248445 0.0 0.9433037797480168
Action:  19 (random)
0.30273936693099157 0.0

In [1]:
def compute_metrics(env, iter_cnt=10, max_iter=30, model=None):
    """Roll out the greedy (argmax-Q) policy and report average episode metrics.

    Args:
        env: environment whose ``reset()`` returns ``(state, _, mask)`` and whose
            ``step(action)`` returns ``(state, reward, done, info, mask)``.
        iter_cnt: number of episodes to run.
        max_iter: maximum number of steps per episode.
        model: Q-network called as ``model.forward(state_batch)``. When omitted,
            the notebook-level global ``model`` is used (original behaviour).

    Returns:
        Tuple ``(mean episode reward, mean final reward, mean steps per episode)``.
        The final reward is currently always 0 (its computation is commented out).

    Raises:
        NameError: if ``model`` is not passed and no global ``model`` exists.
    """
    if model is None:
        # Backward compatibility: the original relied on a global `model`.
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError("name 'model' is not defined; pass model=... to compute_metrics")

    # Run on whatever device the model lives on instead of hard-coding .cuda().
    parameters = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(iter(parameters), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    rewards, final_rewards, novp = [], [], []
    for _episode in range(iter_cnt):
        state, _, mask = env.reset()
        episode_reward = 0.0
        steps = 0  # explicit counter so max_iter == 0 does not leave it undefined
        for _step in range(max_iter):
#             action = model.act(state, mask, epsilon=0.0)

            # torch.no_grad() replaces the removed `volatile=True` flag, which
            # no longer disables gradient tracking.
            with torch.no_grad():
                s = torch.as_tensor(np.asarray(state, dtype=np.float32),
                                    device=model_device).unsqueeze(0)
                m = torch.as_tensor(np.asarray(mask, dtype=np.float32),
                                    device=model_device)
                q_value = model.forward(s)
                # Multiplying Q-values by the mask turned forbidden actions into
                # 0, which beats every valid action with a negative Q-value.
                # Exclude them with -inf instead. Assumes mask == 0 marks a
                # forbidden action (implied by the original `q_value *= m`).
                q_value = q_value.masked_fill(m == 0, float("-inf"))
                # Softmax is monotonic, so argmax over raw Q-values is equivalent.
                action = q_value.argmax().item()
            print(action)

            state, reward, done, info, mask = env.step(action)
            # print("REWARD: ", reward)
#             env.render(action, state)
            episode_reward += reward
            steps += 1

            if done:
                break

        final_reward = 0
        # final_reward = env.final_reward()
        # episode_reward += 1.0 / final_reward
        rewards.append(episode_reward)
        final_rewards.append(final_reward)
        novp.append(steps)
    return np.mean(rewards), np.mean(final_rewards), np.mean(novp)

In [2]:
def agent_func(s):
    """Return the action chosen for state `s` with epsilon=0.0 by the notebook-level `model`."""
    return model.act(s, epsilon=0.0)


# Evaluate on the notebook-level `env` with the default episode/step limits.
result = compute_metrics(env)

NameError: name 'env' is not defined

In [56]:
# Show the tuple returned by compute_metrics:
# (mean episode reward, mean final reward, mean number of steps per episode).
result

(-6.88126904400267, 0.0, 12.2)

In [151]:
# Back-of-the-envelope storage estimates. The first two are divided by 1e9 and
# labelled 'G'; the last two are divided by 2 ** 40 and labelled 'T'.
# NOTE(review): the factors presumably mean (channels * height * width *
# bytes per value) * number of stored items -- confirm against the buffer layout.
print((6 * 256 * 256 * 8) / 1e9 * 3000, 'G')
print((10 * 512 * 512 * 8) / 1e9 * 5000, 'G')
print(1e6 * 10 * 512 * 512 * 4 / 2 ** 40, 'T')
# 1e5 matches buffer_capacity and 64 matches grid_size from the config cell;
# presumably the voxel replay buffer (4 stacked frames, 4 bytes each) -- TODO confirm.
print(1e5 * 64 * 64 * 64 * 4 * 4 / 2 ** 40, 'T')

9.437184 G
104.8576 G
9.5367431640625 T
0.3814697265625 T


In [None]:
TODO Today:
    1. + check DQN with A, S as input -> send Sergey
    2. + Smaller depth_maps (6 * 256 * 256)
    3. + Illustrate reward as area + novp
    4. + Distributed Buffer
    5. + Buffer on hard disk (memmap)
    6. Floats to int
    7. GPU raycasting
    8. * Greedy algo
    9. * Voxels
    10. CNN different shapes input
    11. Preprocessed depth_maps
    
    12. Overfit experiments
    13. ABC experiment
    
    14. Experiment radius = 1.0
    15. Experiments with smaller reward fine
    16. Experiment with different DQN-s
    17. ABC - write random model reading
    

Ideas:
    1. Add a penalty (fine) for selecting the same view_point again
    
Big Experiments:
    1. Voxels
    2. PointNet
    3. Context (N first view_points)
    4. Meta-learning: feed the parts in order from simple to complex

In [15]:
# File names of .obj models singled out as "difficult"
# (presumably hard cases from the ABC dataset -- confirm).
difficult = [
    "00020107_b27a1602d1d44a3d89140ce4_007.obj",
    "00010095_5ae1ee45b583467fa009adc4_006.obj",
    "00010163_ccef4063b69f428e91b498c9_008.obj",
    "00010145_77759770d8cd48af80775d86_002.obj",
    "00010153_556de37e0a7447fcbfbdfd22_000.obj",
    "00010162_ccef4063b69f428e91b498c9_007.obj",
    "00020074_37170a1ba80747f1a1478985_000.obj",
    "00020077_bf146f0c5dee4199be920a21_000.obj",
    "00020080_7a689565e1e0481ca3ad4a6f_000.obj",
    "00010164_ccef4063b69f428e91b498c9_009.obj",
    "00010179_f91d806ac1e34ea1b14e23be_000.obj",
    "00020095_842a932142a9431784488344_000.obj",
    "00020097_e24ecc9c647f4bd1832bfb1d_000.obj",
    "00020106_b27a1602d1d44a3d89140ce4_006.obj",
    "00020141_b27a1602d1d44a3d89140ce4_041.obj",
    "00020155_b27a1602d1d44a3d89140ce4_055.obj",
    "00020186_b27a1602d1d44a3d89140ce4_086.obj",
    "00020202_b27a1602d1d44a3d89140ce4_102.obj",
    "00020203_b27a1602d1d44a3d89140ce4_103.obj",
    "00020213_1f65839d7f6c42bf8c2b3391_000.obj",
]