In [148]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil

In [149]:
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')
is_cuda = (torch.cuda.is_available())
device = torch.device("cuda" if is_cuda else "cpu")

In [150]:
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [151]:
model_selection_metric = cfg['training']['model_selection_metric']
if cfg['training']['model_selection_mode'] == 'maximize':
    model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
    model_selection_sign = -1
else:
    raise ValueError('model_selection_mode must be '
                     'either maximize or minimize.')

In [152]:
# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)
print(train_dataset[5])
print(train_dataset)
print(len(train_dataset[5]['points.occ']))
print(len(train_dataset))

{'points': array([[-0.1177751 , -0.25705388,  0.4742397 ],
       [-0.36065817,  0.0323042 ,  0.34738368],
       [ 0.14904504, -0.24395354,  0.40379843],
       ...,
       [-0.40536687, -0.25214607, -0.11462478],
       [ 0.2797244 , -0.26814985,  0.12606093],
       [ 0.47204673,  0.05703427,  0.00278376]], dtype=float32), 'points.occ': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'inputs': array([[-0.2614746 ,  0.33984375,  0.30664062],
       [-0.02752686,  0.3491211 , -0.28466797],
       [-0.15917969,  0.16491699,  0.35229492],
       ...,
       [-0.16113281,  0.2980957 , -0.29541016],
       [ 0.16699219,  0.27392578,  0.37548828],
       [ 0.25048828,  0.34375   ,  0.4104004 ]], dtype=float32), 'inputs.normals': array([[-0.26831055,  0.9511719 ,  0.15124512],
       [-0.140625  ,  0.9536133 ,  0.26611328],
       [ 0.66748047,  0.03567505, -0.74365234],
       ...,
       [ 0.8417969 , -0.53515625, -0.06884766],
       [ 0.5527344 , -0.01221466,  0.8334961 ],
       [

In [153]:
# train_loader = torch.utils.data.DataLoader(
#     train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

# print(len(train_loader))    
# print(train_loader)

# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
#      collate_fn=data.collate_remove_none,
#      worker_init_fn=data.worker_init_fn)


In [154]:
#Amine code subset

train_dataset = config.get_dataset("train", cfg)
ds = torch.utils.data.Subset(train_dataset, indices= [0]*len(train_dataset))
val_dataset = config.get_dataset("val", cfg)
val_ds = torch.utils.data.Subset(val_dataset, indices= [0]*len(val_dataset))
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=32, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
# val_loader = torch.utils.data.DataLoader(
#     ds, batch_size=1, num_workers=8, shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

In [155]:
# For visualizations
vis_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
model_counter = defaultdict(int)
data_vis_list = []

# Build a data dictionary for visualization
iterator = iter(vis_loader)
for i in range(len(vis_loader)):
    data_vis = next(iterator)
    idx = data_vis['idx'].item()
    model_dict = val_dataset.get_model_dict(idx)
    category_id = model_dict.get('category', 'n/a')
    category_name = val_dataset.metadata[category_id].get('name', 'n/a')
    category_name = category_name.split(',')[0]
    if category_name == 'n/a':
        category_name = category_id

    c_it = model_counter[category_id]
    if c_it < vis_n_outputs:
        data_vis_list.append({'category': category_name, 'it': c_it, 'data': data_vis})

    model_counter[category_id] += 1


KeyError: 'idx'

In [156]:
# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Generator
generator = config.get_generator(model, cfg, device=device)

# Intialize training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

# checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
# try:
#     load_dict = checkpoint_io.load('model.pt')
# except FileExistsError:
#     load_dict = dict()
# epoch_it = load_dict.get('epoch_it', 0)
# it = load_dict.get('it', 0)
# metric_val_best = load_dict.get(
#     'loss_val_best', -model_selection_sign * np.inf)

# if metric_val_best == np.inf or metric_val_best == -np.inf:
#     metric_val_best = -model_selection_sign * np.inf
# print('Current best validation metric (%s): %.8f'
#       % (model_selection_metric, metric_val_best))

In [157]:
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# Print model
nparameters = sum(p.numel() for p in model.parameters())
print('Total number of parameters: %d' % nparameters)

print('output path: ', cfg['training']['out_dir'])

Total number of parameters: 1978209
output path:  out/pointcloud/grid


In [158]:
epoch_it = 0
batch = next(train_loader.__iter__())
while epoch_it<10000:
    epoch_it += 1

    it += 1
    loss = trainer.train_step(batch)
    #logger.add_scalar('train/loss', loss, it)

    # Print output
    if print_every > 0 and (it % print_every) == 0:
        t = datetime.datetime.now()
        print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                    % (epoch_it, it, loss, time.time() - t0, t.hour, t.minute))

    #data_v = next(in data_vis_.__iter__())ist:
    # Visualize output
    if visualize_every > 0 and (it % visualize_every) == 0:
        print('Visualizing')
        for data_vis in data_vis_list:
            if cfg['generation']['sliding_window']:
                out = generator.generate_mesh_sliding(data_vis['data'])    
            else:
                out = generator.generate_mesh(data_vis['data'])
            # Get statistics
            try:
                mesh, stats_dict = out
            except TypeError:
                mesh, stats_dict = out, {}

            mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it'])))


    # Save checkpoint
    if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
        print('Saving checkpoint')
        checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
                            loss_val_best=metric_val_best)

    # Backup if necessary
#     if (backup_every > 0 and (it % backup_every) == 0):
#         print('Backup checkpoint')
#         checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)
    # Run validation
    if validate_every > 0 and (it % validate_every) == 0:
        eval_dict = trainer.evaluate(val_loader)
        metric_val = eval_dict[model_selection_metric]
        print('Validation metric (%s): %.4f'
                % (model_selection_metric, metric_val))

#         for k, v in eval_dict.items():
#             logger.add_scalar('val/%s' % k, v, it)

#         if model_selection_sign * (metric_val - metric_val_best) > 0:
#             metric_val_best = metric_val
#             print('New best model (loss %.4f)' % metric_val_best)
#             checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
#                                 loss_val_best=metric_val_best)

    # Exit if necessary
#     if exit_after > 0 and (time.time() - t0) >= exit_after:
#         print('Time limit reached. Exiting.')
#         checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)

#print(train_loader)

[Epoch 32] it=14000, loss=29.4069, time: 19.42s, 11:26
Visualizing
Saving checkpoint
[Epoch 132] it=14100, loss=6.0227, time: 33.22s, 11:26
[Epoch 232] it=14200, loss=2.0390, time: 46.91s, 11:26
Visualizing
Saving checkpoint
[Epoch 332] it=14300, loss=1.3815, time: 60.87s, 11:27
[Epoch 432] it=14400, loss=1.0686, time: 74.69s, 11:27
Visualizing
Saving checkpoint
[Epoch 532] it=14500, loss=0.8570, time: 88.56s, 11:27
[Epoch 632] it=14600, loss=0.7152, time: 102.35s, 11:27
Visualizing
Saving checkpoint
[Epoch 732] it=14700, loss=0.6096, time: 116.27s, 11:28
[Epoch 832] it=14800, loss=0.5215, time: 130.09s, 11:28
Visualizing
Saving checkpoint
[Epoch 932] it=14900, loss=0.4509, time: 143.94s, 11:28
[Epoch 1032] it=15000, loss=0.3903, time: 157.64s, 11:28
Visualizing
Saving checkpoint
[Epoch 1132] it=15100, loss=0.3326, time: 171.51s, 11:29
[Epoch 1232] it=15200, loss=0.2959, time: 185.30s, 11:29
Visualizing
Saving checkpoint
[Epoch 1332] it=15300, loss=0.2603, time: 199.17s, 11:29
[Epoch 1


  0%|          | 0/850 [00:00<?, ?it/s][A

[Epoch 2032] it=16000, loss=0.1190, time: 295.62s, 11:31
Visualizing
Saving checkpoint



  0%|          | 1/850 [00:00<06:21,  2.23it/s][A
  0%|          | 4/850 [00:00<04:35,  3.07it/s][A
  1%|          | 8/850 [00:00<03:19,  4.21it/s][A
  1%|▏         | 11/850 [00:00<02:27,  5.68it/s][A
  2%|▏         | 15/850 [00:00<01:51,  7.52it/s][A
  2%|▏         | 18/850 [00:01<01:25,  9.69it/s][A
  2%|▏         | 21/850 [00:01<01:08, 12.14it/s][A
  3%|▎         | 24/850 [00:01<00:55, 14.77it/s][A
  3%|▎         | 27/850 [00:01<00:47, 17.42it/s][A
  4%|▎         | 30/850 [00:01<00:41, 19.87it/s][A
  4%|▍         | 33/850 [00:01<00:36, 22.09it/s][A
  4%|▍         | 36/850 [00:01<00:33, 23.97it/s][A
  5%|▍         | 40/850 [00:01<00:31, 25.65it/s][A
  5%|▌         | 43/850 [00:01<00:30, 26.79it/s][A
  5%|▌         | 46/850 [00:01<00:29, 27.56it/s][A
  6%|▌         | 49/850 [00:02<00:28, 28.17it/s][A
  6%|▌         | 52/850 [00:02<00:27, 28.63it/s][A
  6%|▋         | 55/850 [00:02<00:27, 29.03it/s][A
  7%|▋         | 58/850 [00:02<00:28, 27.71it/s][A
  7%|▋        

 63%|██████▎   | 535/850 [00:19<00:10, 30.08it/s][A
 63%|██████▎   | 539/850 [00:19<00:10, 30.49it/s][A
 64%|██████▍   | 543/850 [00:19<00:10, 30.68it/s][A
 64%|██████▍   | 547/850 [00:19<00:10, 27.78it/s][A
 65%|██████▍   | 550/850 [00:19<00:11, 26.81it/s][A
 65%|██████▌   | 554/850 [00:19<00:10, 28.15it/s][A
 66%|██████▌   | 558/850 [00:19<00:10, 29.07it/s][A
 66%|██████▌   | 562/850 [00:20<00:09, 29.77it/s][A
 67%|██████▋   | 566/850 [00:20<00:09, 30.20it/s][A
 67%|██████▋   | 570/850 [00:20<00:09, 30.46it/s][A
 68%|██████▊   | 574/850 [00:20<00:08, 30.95it/s][A
 68%|██████▊   | 578/850 [00:20<00:09, 29.43it/s][A
 68%|██████▊   | 581/850 [00:20<00:09, 27.54it/s][A
 69%|██████▉   | 585/850 [00:20<00:09, 28.85it/s][A
 69%|██████▉   | 589/850 [00:21<00:08, 29.46it/s][A
 70%|██████▉   | 593/850 [00:21<00:08, 30.16it/s][A
 70%|███████   | 597/850 [00:21<00:08, 30.93it/s][A
 71%|███████   | 601/850 [00:21<00:08, 28.11it/s][A
 71%|███████   | 604/850 [00:21<00:09, 27.05it

Validation metric (iou): 0.0409
[Epoch 2132] it=16100, loss=0.1063, time: 340.36s, 11:31
[Epoch 2232] it=16200, loss=0.0957, time: 354.95s, 11:32
Visualizing
Saving checkpoint
[Epoch 2332] it=16300, loss=0.0848, time: 369.53s, 11:32
[Epoch 2432] it=16400, loss=0.0747, time: 383.30s, 11:32
Visualizing
Saving checkpoint
[Epoch 2532] it=16500, loss=0.0650, time: 397.16s, 11:32
[Epoch 2632] it=16600, loss=0.0619, time: 410.92s, 11:33
Visualizing
Saving checkpoint
[Epoch 2732] it=16700, loss=0.0655, time: 424.80s, 11:33
[Epoch 2832] it=16800, loss=0.0599, time: 438.60s, 11:33
Visualizing
Saving checkpoint
[Epoch 2932] it=16900, loss=0.0478, time: 452.55s, 11:33
[Epoch 3032] it=17000, loss=0.0359, time: 466.34s, 11:33
Visualizing
Saving checkpoint
[Epoch 3132] it=17100, loss=0.0253, time: 480.24s, 11:34
[Epoch 3232] it=17200, loss=0.0172, time: 494.03s, 11:34
Visualizing
Saving checkpoint
[Epoch 3332] it=17300, loss=0.0138, time: 507.93s, 11:34
[Epoch 3432] it=17400, loss=0.0121, time: 521.6


  0%|          | 0/850 [00:00<?, ?it/s][A

[Epoch 6032] it=20000, loss=0.0015, time: 883.23s, 11:40
Visualizing
Saving checkpoint



  0%|          | 1/850 [00:00<07:01,  2.01it/s][A
  0%|          | 4/850 [00:00<05:03,  2.79it/s][A
  1%|          | 8/850 [00:00<03:39,  3.84it/s][A
  1%|▏         | 12/850 [00:00<02:40,  5.22it/s][A
  2%|▏         | 16/850 [00:00<01:59,  6.96it/s][A
  2%|▏         | 20/850 [00:01<01:31,  9.08it/s][A
  3%|▎         | 24/850 [00:01<01:11, 11.52it/s][A
  3%|▎         | 28/850 [00:01<00:57, 14.22it/s][A
  4%|▍         | 32/850 [00:01<00:48, 17.00it/s][A
  4%|▍         | 36/850 [00:01<00:41, 19.64it/s][A
  5%|▍         | 39/850 [00:01<00:37, 21.85it/s][A
  5%|▌         | 43/850 [00:01<00:33, 23.93it/s][A
  6%|▌         | 47/850 [00:01<00:31, 25.64it/s][A
  6%|▌         | 51/850 [00:02<00:29, 26.93it/s][A
  6%|▋         | 55/850 [00:02<00:28, 27.93it/s][A
  7%|▋         | 59/850 [00:02<00:27, 28.76it/s][A
  7%|▋         | 63/850 [00:02<00:26, 29.55it/s][A
  8%|▊         | 67/850 [00:02<00:26, 30.02it/s][A
  8%|▊         | 71/850 [00:02<00:25, 30.28it/s][A
  9%|▉        

 72%|███████▏  | 616/850 [00:20<00:07, 31.23it/s][A
 73%|███████▎  | 620/850 [00:20<00:07, 31.17it/s][A
 73%|███████▎  | 624/850 [00:20<00:07, 31.35it/s][A
 74%|███████▍  | 628/850 [00:20<00:07, 31.31it/s][A
 74%|███████▍  | 632/850 [00:20<00:06, 31.27it/s][A
 75%|███████▍  | 636/850 [00:21<00:06, 31.44it/s][A
 75%|███████▌  | 640/850 [00:21<00:06, 31.42it/s][A
 76%|███████▌  | 644/850 [00:21<00:06, 31.29it/s][A
 76%|███████▌  | 648/850 [00:21<00:06, 31.39it/s][A
 77%|███████▋  | 652/850 [00:21<00:06, 31.33it/s][A
 77%|███████▋  | 656/850 [00:21<00:06, 31.21it/s][A
 78%|███████▊  | 660/850 [00:21<00:06, 31.41it/s][A
 78%|███████▊  | 664/850 [00:21<00:05, 31.32it/s][A
 79%|███████▊  | 668/850 [00:22<00:05, 31.00it/s][A
 79%|███████▉  | 672/850 [00:22<00:05, 31.11it/s][A
 80%|███████▉  | 676/850 [00:22<00:05, 31.32it/s][A
 80%|████████  | 680/850 [00:22<00:05, 31.31it/s][A
 80%|████████  | 684/850 [00:22<00:05, 31.11it/s][A
 81%|████████  | 688/850 [00:22<00:05, 31.36it

Validation metric (iou): 0.0409
[Epoch 6132] it=20100, loss=0.0013, time: 925.35s, 11:41
[Epoch 6232] it=20200, loss=0.0012, time: 939.15s, 11:41
Visualizing
Saving checkpoint
[Epoch 6332] it=20300, loss=0.0011, time: 953.09s, 11:42
[Epoch 6432] it=20400, loss=0.0010, time: 966.94s, 11:42
Visualizing
Saving checkpoint
[Epoch 6532] it=20500, loss=0.0009, time: 980.90s, 11:42
[Epoch 6632] it=20600, loss=0.0008, time: 994.76s, 11:42
Visualizing
Saving checkpoint
[Epoch 6732] it=20700, loss=0.0007, time: 1008.72s, 11:42
[Epoch 6832] it=20800, loss=0.0006, time: 1022.55s, 11:43
Visualizing
Saving checkpoint
[Epoch 6932] it=20900, loss=0.0006, time: 1036.62s, 11:43
[Epoch 7032] it=21000, loss=0.0005, time: 1050.50s, 11:43
Visualizing
Saving checkpoint
[Epoch 7132] it=21100, loss=0.0005, time: 1064.47s, 11:43
[Epoch 7232] it=21200, loss=0.0005, time: 1078.42s, 11:44
Visualizing
Saving checkpoint
[Epoch 7332] it=21300, loss=0.0004, time: 1092.50s, 11:44
[Epoch 7432] it=21400, loss=0.0004, time

In [131]:
def ascent(p, model, num_steps):
    for t in range(num_steps):
        p += alpha*model(p)/2 + torch.randn(p.shape)*torch.sqrt(alpha)
        return p 
    
     