In [1]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil
from tensorboardX import SummaryWriter

In [2]:
writer = SummaryWriter()
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')
is_cuda = (torch.cuda.is_available())
device = torch.device("cuda" if is_cuda else "cpu")

In [3]:
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [4]:
model_selection_metric = cfg['training']['model_selection_metric']
if cfg['training']['model_selection_mode'] == 'maximize':
    model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
    model_selection_sign = -1
else:
    raise ValueError('model_selection_mode must be '
                     'either maximize or minimize.')

In [5]:
# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)
print(train_dataset[5])
print(train_dataset)
print(len(train_dataset[5]['points.occ']))
print(len(train_dataset))

{'points': array([[-0.09609246, -0.16188681,  0.01934641],
       [-0.11124776,  0.1005385 ,  0.15549001],
       [ 0.38469836,  0.4537726 ,  0.08635703],
       ...,
       [-0.3643776 ,  0.06455163,  0.5312809 ],
       [-0.42909572,  0.4304115 , -0.08003872],
       [ 0.38643727,  0.00438961, -0.17688385]], dtype=float32), 'points.occ': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'inputs': array([[ 0.11865234,  0.34814453, -0.09179688],
       [ 0.21643066,  0.34594727,  0.16113281],
       [ 0.22253418,  0.32128906, -0.3371582 ],
       ...,
       [-0.16503906,  0.29418945,  0.17651367],
       [-0.21936035,  0.32104492,  0.31713867],
       [-0.03616333,  0.32836914,  0.20617676]], dtype=float32), 'inputs.normals': array([[-0.06939697,  0.97314453,  0.2199707 ],
       [ 0.13537598,  0.98583984, -0.10125732],
       [ 0.05429077, -0.9946289 , -0.08917236],
       ...,
       [ 0.8417969 , -0.5361328 ,  0.06677246],
       [ 0.02293396, -0.9995117 ,  0.00432205],
       [

In [6]:
shape = train_dataset[0]

In [7]:
# train_loader = torch.utils.data.DataLoader(
#     train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

# print(len(train_loader))    
# print(train_loader)

# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
#      collate_fn=data.collate_remove_none,
#      worker_init_fn=data.worker_init_fn)


In [8]:
#Amine code subset

train_dataset = config.get_dataset("train", cfg)
ds = torch.utils.data.Subset(train_dataset, indices= [0]*len(train_dataset))
val_dataset = config.get_dataset("val", cfg)
val_ds = torch.utils.data.Subset(val_dataset, indices= [0]*len(val_dataset))
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=32, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
# val_loader = torch.utils.data.DataLoader(
#     ds, batch_size=1, num_workers=8, shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

In [9]:
# For visualizations
vis_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
model_counter = defaultdict(int)
data_vis_list = []

# Build a data dictionary for visualization
iterator = iter(vis_loader)
for i in range(len(vis_loader)):
    data_vis = next(iterator)
    idx = data_vis['idx'].item()
    model_dict = val_dataset.get_model_dict(idx)
    category_id = model_dict.get('category', 'n/a')
    category_name = val_dataset.metadata[category_id].get('name', 'n/a')
    category_name = category_name.split(',')[0]
    if category_name == 'n/a':
        category_name = category_id

    c_it = model_counter[category_id]
    if c_it < vis_n_outputs:
        data_vis_list.append({'category': category_name, 'it': c_it, 'data': data_vis})

    model_counter[category_id] += 1


KeyError: 'idx'

In [10]:
# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Generator
generator = config.get_generator(model, cfg, device=device)

# Intialize training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
# try:
#     load_dict = checkpoint_io.load('model.pt')
# except FileExistsError:
#     load_dict = dict()
# epoch_it = load_dict.get('epoch_it', 0)
# it = load_dict.get('it', 0)
# metric_val_best = load_dict.get(
#      'loss_val_best', -model_selection_sign * np.inf)

# if metric_val_best == np.inf or metric_val_best == -np.inf:
#     metric_val_best = -model_selection_sign * np.inf
# print('Current best validation metric (%s): %.8f'
#       % (model_selection_metric, metric_val_best))

In [11]:
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# Print model
nparameters = sum(p.numel() for p in model.parameters())
print('Total number of parameters: %d' % nparameters)

print('output path: ', cfg['training']['out_dir'])

Total number of parameters: 1978275
output path:  out/pointcloud/grid


In [12]:
batch = next(train_loader.__iter__())

In [13]:
c = model.encode_inputs(batch.get('inputs').cuda())
results = model.decode(batch.get('points').cuda(), c).logits

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [14]:
it = 0
epoch_it = 0
batch = next(train_loader.__iter__())
while epoch_it<5000:
    epoch_it += 1

    it += 1
    loss = trainer.train_step(batch)
    #logger.add_scalar('train/loss', loss, it)
    writer.add_scalar("surfaceCurl", loss, it)

    # Print output
    if print_every > 0 and (it % print_every) == 0:
        t = datetime.datetime.now()
        print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                    % (epoch_it, it, loss, time.time() - t0, t.hour, t.minute))

    #data_v = next(in data_vis_.__iter__())ist:
    # Visualize output
    if visualize_every > 0 and (it % visualize_every) == 0:
        print('Visualizing')
        for data_vis in data_vis_list:
            if cfg['generation']['sliding_window']:
                out = generator.generate_mesh_sliding(data_vis['data'])
            else:
                out = generator.generate_mesh(data_vis['data'])
            # Get statistics
            try:
                mesh, stats_dict = out
            except TypeError:
                mesh, stats_dict = out, {}

            mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it'])))


    # Save checkpoint
    if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
        print('Saving checkpoint')
        checkpoint_io.save('surfaceCurlmodel.pt', epoch_it=epoch_it, it=it)

    # Backup if necessary
    if (backup_every > 0 and (it % backup_every) == 0):
        print('Backup checkpoint')
        checkpoint_io.save('surfaceCurlmodel_%d.pt' % it, epoch_it=epoch_it, it=it)
        
    # Run validation
#     if validate_every > 0 and (it % validate_every) == 0:
#         eval_dict = trainer.evaluate(val_loader)
#         metric_val = eval_dict[model_selection_metric]
#         print('Validation metric (%s): %.4f'
#                 % (model_selection_metric, metric_val))

#         for k, v in eval_dict.items():
#             logger.add_scalar('val/%s' % k, v, it)

#         if model_selection_sign * (metric_val - metric_val_best) > 0:
#             metric_val_best = metric_val
#             print('New best model (loss %.4f)' % metric_val_best)
#             checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
#                                 loss_val_best=metric_val_best)

    # Exit if necessary
#     if exit_after > 0 and (time.time() - t0) >= exit_after:
#         print('Time limit reached. Exiting.')
#         checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)

#print(train_loader)

[Epoch 100] it=100, loss=1.7430, time: 56.60s, 10:34
[Epoch 200] it=200, loss=1.0153, time: 91.88s, 10:34
Visualizing
Saving checkpoint
[Epoch 300] it=300, loss=0.7830, time: 127.30s, 10:35
[Epoch 400] it=400, loss=0.6436, time: 163.29s, 10:35
Visualizing
Saving checkpoint
[Epoch 500] it=500, loss=0.5604, time: 210.16s, 10:36
[Epoch 600] it=600, loss=0.4929, time: 257.69s, 10:37
Visualizing
Saving checkpoint
[Epoch 700] it=700, loss=0.4439, time: 305.37s, 10:38
[Epoch 800] it=800, loss=0.4071, time: 352.23s, 10:39
Visualizing
Saving checkpoint
[Epoch 900] it=900, loss=0.3737, time: 399.82s, 10:39
[Epoch 1000] it=1000, loss=0.3467, time: 446.65s, 10:40
Visualizing
Saving checkpoint
[Epoch 1100] it=1100, loss=0.3224, time: 493.89s, 10:41
[Epoch 1200] it=1200, loss=0.3042, time: 540.67s, 10:42
Visualizing
Saving checkpoint
[Epoch 1300] it=1300, loss=0.2855, time: 587.33s, 10:43
[Epoch 1400] it=1400, loss=0.2726, time: 634.03s, 10:43
Visualizing
Saving checkpoint
[Epoch 1500] it=1500, loss

In [14]:
from glob import glob

In [15]:
glob('data/ShapeNet/*/*/')

['data/ShapeNet/04090263/92431d034edd34c760c81723f0d4ce20/',
 'data/ShapeNet/04090263/d326ce10d768da152c3271e911ffe19/',
 'data/ShapeNet/04090263/782f3821e6b638d6fb6bde4b7e6c6613/',
 'data/ShapeNet/04090263/9397ae7d40c327044da9f09deacee7d4/',
 'data/ShapeNet/04090263/98572b8a17031500c2c44977d8755d41/',
 'data/ShapeNet/04090263/b28220a981c84b25427e47767269c4b/',
 'data/ShapeNet/04090263/36299a0fd2aebb5b1cb4c4614a9a037e/',
 'data/ShapeNet/04090263/6cff6f4bd6a5d73e8411da876c84603f/',
 'data/ShapeNet/04090263/487330fd2ba7d55f97020a1f4453e3a4/',
 'data/ShapeNet/04090263/b449f16a0cbad90088be2a30dd556a09/',
 'data/ShapeNet/04090263/27257aee4b0f91b1a16c70da5e24216f/',
 'data/ShapeNet/04090263/f5ab909cc5813c7ebe8eb764bcb3c31e/',
 'data/ShapeNet/04090263/7ad1c8369ecca95ffb5c1b0f759e2bc1/',
 'data/ShapeNet/04090263/dcf13ca83d9580bd44c069e8827241aa/',
 'data/ShapeNet/04090263/10e60e0eb0d7915c8de11d571206924/',
 'data/ShapeNet/04090263/561430988b5af11bd04b05b0f20a897b/',
 'data/ShapeNet/04090263/98

In [6]:
test['dist'].shape

(200,)

In [8]:
import kdtree

In [9]:
hh=kdtree.KDTree(test)

AttributeError: module 'kdtree' has no attribute 'KDTree'