In [45]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil

In [46]:
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')
is_cuda = (torch.cuda.is_available())
device = torch.device("cuda" if is_cuda else "cpu")

In [47]:
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [48]:
model_selection_metric = cfg['training']['model_selection_metric']
if cfg['training']['model_selection_mode'] == 'maximize':
    model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
    model_selection_sign = -1
else:
    raise ValueError('model_selection_mode must be '
                     'either maximize or minimize.')

In [53]:
# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)
print(train_dataset[5])
print(train_dataset)
print(len(train_dataset[5]['points.occ']))
print(len(train_dataset))

{'points': array([[-0.04952807,  0.4819259 ,  0.01673259],
       [-0.03787785,  0.41909906,  0.20245078],
       [-0.5370819 , -0.35882676,  0.12334375],
       ...,
       [-0.5018162 ,  0.49840465,  0.43694407],
       [-0.48304105, -0.17943507, -0.28516206],
       [ 0.02512551, -0.42349476, -0.3951674 ]], dtype=float32), 'points.occ': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'inputs': array([[ 0.15783691,  0.04495239,  0.34887695],
       [ 0.1274414 ,  0.31152344,  0.2211914 ],
       [ 0.16784668,  0.30493164,  0.0035305 ],
       ...,
       [-0.14587402,  0.34545898,  0.03475952],
       [-0.12158203,  0.34716797,  0.3635254 ],
       [ 0.2548828 ,  0.34228516,  0.23010254]], dtype=float32), 'inputs.normals': array([[ 0.15515137,  0.17016602, -0.97314453],
       [-0.6464844 , -0.7626953 ,  0.01418304],
       [ 0.64746094, -0.76123047, -0.03527832],
       ...,
       [-0.07946777,  0.9921875 ,  0.09643555],
       [-0.1697998 ,  0.98095703,  0.09606934],
       [

In [70]:
train_loader = torch.utils.data.DataLoader(
    train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)

print(len(train_loader))    
print(train_loader)

val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)


1
<torch.utils.data.dataloader.DataLoader object at 0x7fd76d6d0ac8>


In [None]:
#Amine code subset

train_dataset = config.get_dataset(‘val’, cfg)
ds = torch.utils.data.Subset(train_dataset, indices= [0]*len(train_dataset))
val_dataset = config.get_dataset(‘val’, cfg)
#val_ds = torch.utils.data.Subset(val_dataset, indices= [0]*len(val_dataset))
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=16, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
val_loader = torch.utils.data.DataLoader(
    ds, batch_size=1, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)

In [59]:
# For visualizations
vis_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
model_counter = defaultdict(int)
data_vis_list = []

# Build a data dictionary for visualization
iterator = iter(vis_loader)
for i in range(len(vis_loader)):
    data_vis = next(iterator)
    idx = data_vis['idx'].item()
    model_dict = val_dataset.get_model_dict(idx)
    category_id = model_dict.get('category', 'n/a')
    category_name = val_dataset.metadata[category_id].get('name', 'n/a')
    category_name = category_name.split(',')[0]
    if category_name == 'n/a':
        category_name = category_id

    c_it = model_counter[category_id]
    if c_it < vis_n_outputs:
        data_vis_list.append({'category': category_name, 'it': c_it, 'data': data_vis})

    model_counter[category_id] += 1


In [60]:
# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Generator
generator = config.get_generator(model, cfg, device=device)

# Intialize training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
try:
    load_dict = checkpoint_io.load('model.pt')
except FileExistsError:
    load_dict = dict()
epoch_it = load_dict.get('epoch_it', 0)
it = load_dict.get('it', 0)
metric_val_best = load_dict.get(
    'loss_val_best', -model_selection_sign * np.inf)

if metric_val_best == np.inf or metric_val_best == -np.inf:
    metric_val_best = -model_selection_sign * np.inf
print('Current best validation metric (%s): %.8f'
      % (model_selection_metric, metric_val_best))

out/pointcloud/grid/model.pt
=> Loading checkpoint from local file...
Current best validation metric (iou): 0.04091307


In [61]:
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# Print model
nparameters = sum(p.numel() for p in model.parameters())
print('Total number of parameters: %d' % nparameters)

print('output path: ', cfg['training']['out_dir'])

Total number of parameters: 1978209
output path:  out/pointcloud/grid


In [71]:
batch = next(train_loader.__iter__())
while epoch_it<10000:
    epoch_it += 1

    it += 1
    loss = trainer.train_step(batch)
    #logger.add_scalar('train/loss', loss, it)

    # Print output
    if print_every > 0 and (it % print_every) == 0:
        t = datetime.datetime.now()
        print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                    % (epoch_it, it, loss, time.time() - t0, t.hour, t.minute))

    #data_v = next(in data_vis_.__iter__())ist:
    # Visualize output
    if visualize_every > 0 and (it % visualize_every) == 0:
        print('Visualizing')
        for data_vis in data_vis_list:
            if cfg['generation']['sliding_window']:
                out = generator.generate_mesh_sliding(data_vis['data'])    
            else:
                out = generator.generate_mesh(data_vis['data'])
            # Get statistics
            try:
                mesh, stats_dict = out
            except TypeError:
                mesh, stats_dict = out, {}

            mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it'])))


    # Save checkpoint
    if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
        print('Saving checkpoint')
        checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
                            loss_val_best=metric_val_best)

    # Backup if necessary
    if (backup_every > 0 and (it % backup_every) == 0):
        print('Backup checkpoint')
        checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it,
                            loss_val_best=metric_val_best)
    # Run validation
    if validate_every > 0 and (it % validate_every) == 0:
        eval_dict = trainer.evaluate(val_loader)
        metric_val = eval_dict[model_selection_metric]
        print('Validation metric (%s): %.4f'
                % (model_selection_metric, metric_val))

#         for k, v in eval_dict.items():
#             logger.add_scalar('val/%s' % k, v, it)

        if model_selection_sign * (metric_val - metric_val_best) > 0:
            metric_val_best = metric_val
            print('New best model (loss %.4f)' % metric_val_best)
            checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
                                loss_val_best=metric_val_best)

    # Exit if necessary
#     if exit_after > 0 and (time.time() - t0) >= exit_after:
#         print('Time limit reached. Exiting.')
#         checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)

#print(train_loader)

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/nizar/miniconda3/envs/conv_onet/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/nizar/miniconda3/envs/conv_onet/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/nizar/miniconda3/envs/conv_onet/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
KeyError: 3


In [40]:
def ascent(p, model, num_steps):
    for t in range(num_steps):
        p += alpha*model(p)/2 + torch.randn(p.shape)*torch.sqrt(alpha)
        return p 
    
     

config.yaml    model_2600.pt  model_3500.pt  model_4400.pt  model_5300.pt
model_1800.pt  model_2700.pt  model_3600.pt  model_4500.pt  model_5400.pt
model_1900.pt  model_2800.pt  model_3700.pt  model_4600.pt  model_5500.pt
model_2000.pt  model_2900.pt  model_3800.pt  model_4700.pt  model_5600.pt
model_2100.pt  model_3000.pt  model_3900.pt  model_4800.pt  model_5700.pt
model_2200.pt  model_3100.pt  model_4000.pt  model_4900.pt  model_best.pt
model_2300.pt  model_3200.pt  model_4100.pt  model_5000.pt  model.pt
model_2400.pt  model_3300.pt  model_4200.pt  model_5100.pt  [0m[01;34mvis[0m/
model_2500.pt  model_3400.pt  model_4300.pt  model_5200.pt
