In [1]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil
from tensorboardX import SummaryWriter

In [2]:
writer = SummaryWriter()
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')
is_cuda = (torch.cuda.is_available())
device = torch.device("cuda" if is_cuda else "cpu")

In [3]:
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [4]:
model_selection_metric = cfg['training']['model_selection_metric']
if cfg['training']['model_selection_mode'] == 'maximize':
    model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
    model_selection_sign = -1
else:
    raise ValueError('model_selection_mode must be '
                     'either maximize or minimize.')

In [5]:
# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)
print(train_dataset[5])
print(train_dataset)
print(len(train_dataset[5]['points.occ']))
print(len(train_dataset))

{'points': array([[ 0.06367018, -0.23341139, -0.35384396],
       [ 0.06717573,  0.40972158, -0.5414992 ],
       [-0.48746377, -0.18413779, -0.45706937],
       ...,
       [ 0.2570478 ,  0.07432955,  0.41330966],
       [-0.28128406, -0.3550816 , -0.46653956],
       [ 0.38124532,  0.01783335, -0.27706045]], dtype=float32), 'points.occ': array([1., 0., 0., ..., 0., 0., 0.], dtype=float32), 'inputs': array([[ 0.12231445,  0.31811523,  0.04150391],
       [ 0.23632812,  0.3215332 ,  0.32861328],
       [ 0.11230469, -0.22375488, -0.37231445],
       ...,
       [ 0.1940918 ,  0.32080078, -0.21105957],
       [-0.1538086 , -0.02494812,  0.37280273],
       [ 0.07141113,  0.32006836,  0.3984375 ]], dtype=float32), 'inputs.normals': array([[-0.62597656, -0.7753906 ,  0.0836792 ],
       [ 0.00189972, -1.        , -0.01567078],
       [ 0.04974365, -0.15429688, -0.9868164 ],
       ...,
       [ 0.08111572, -0.99609375, -0.03234863],
       [ 0.8754883 , -0.15881348,  0.45629883],
       [

In [6]:
shape = train_dataset[0]

In [7]:
# train_loader = torch.utils.data.DataLoader(
#     train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

# print(len(train_loader))    
# print(train_loader)

# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
#      collate_fn=data.collate_remove_none,
#      worker_init_fn=data.worker_init_fn)


In [8]:
#Amine code subset

train_dataset = config.get_dataset("train", cfg)
ds = torch.utils.data.Subset(train_dataset, indices= [0]*len(train_dataset))
val_dataset = config.get_dataset("val", cfg)
val_ds = torch.utils.data.Subset(val_dataset, indices= [0]*len(val_dataset))
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=32, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
# val_loader = torch.utils.data.DataLoader(
#     ds, batch_size=1, num_workers=8, shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

In [9]:
# For visualizations
vis_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
model_counter = defaultdict(int)
data_vis_list = []

# Build a data dictionary for visualization
iterator = iter(vis_loader)
for i in range(len(vis_loader)):
    data_vis = next(iterator)
    idx = data_vis['idx'].item()
    model_dict = val_dataset.get_model_dict(idx)
    category_id = model_dict.get('category', 'n/a')
    category_name = val_dataset.metadata[category_id].get('name', 'n/a')
    category_name = category_name.split(',')[0]
    if category_name == 'n/a':
        category_name = category_id

    c_it = model_counter[category_id]
    if c_it < vis_n_outputs:
        data_vis_list.append({'category': category_name, 'it': c_it, 'data': data_vis})

    model_counter[category_id] += 1


KeyError: 'idx'

In [10]:
# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Generator
generator = config.get_generator(model, cfg, device=device)

# Intialize training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
# try:
#     load_dict = checkpoint_io.load('model.pt')
# except FileExistsError:
#     load_dict = dict()
# epoch_it = load_dict.get('epoch_it', 0)
# it = load_dict.get('it', 0)
# metric_val_best = load_dict.get(
#      'loss_val_best', -model_selection_sign * np.inf)

# if metric_val_best == np.inf or metric_val_best == -np.inf:
#     metric_val_best = -model_selection_sign * np.inf
# print('Current best validation metric (%s): %.8f'
#       % (model_selection_metric, metric_val_best))

In [11]:
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# Print model
nparameters = sum(p.numel() for p in model.parameters())
print('Total number of parameters: %d' % nparameters)

print('output path: ', cfg['training']['out_dir'])

Total number of parameters: 1978275
output path:  out/pointcloud/grid


In [12]:
batch = next(train_loader.__iter__())

In [13]:
c = model.encode_inputs(batch.get('inputs').cuda())
results = model.decode(batch.get('points').cuda(), c).logits

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [14]:
it = 0
epoch_it = 0
batch = next(train_loader.__iter__())
while epoch_it<5000:
    epoch_it += 1

    it += 1
    loss = trainer.train_step(batch)
    #logger.add_scalar('train/loss', loss, it)
    writer.add_scalar("basicModel", loss, it)

    # Print output
    if print_every > 0 and (it % print_every) == 0:
        t = datetime.datetime.now()
        print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                    % (epoch_it, it, loss, time.time() - t0, t.hour, t.minute))

    #data_v = next(in data_vis_.__iter__())ist:
    # Visualize output
    if visualize_every > 0 and (it % visualize_every) == 0:
        print('Visualizing')
        for data_vis in data_vis_list:
            if cfg['generation']['sliding_window']:
                out = generator.generate_mesh_sliding(data_vis['data'])
            else:
                out = generator.generate_mesh(data_vis['data'])
            # Get statistics
            try:
                mesh, stats_dict = out
            except TypeError:
                mesh, stats_dict = out, {}

            mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it'])))


    # Save checkpoint
    if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
        print('Saving checkpoint')
        checkpoint_io.save('basicModel.pt', epoch_it=epoch_it, it=it)

    # Backup if necessary
    if (backup_every > 0 and (it % backup_every) == 0):
        print('Backup checkpoint')
        checkpoint_io.save('basicModel_%d.pt' % it, epoch_it=epoch_it, it=it)
        
    # Run validation
#     if validate_every > 0 and (it % validate_every) == 0:
#         eval_dict = trainer.evaluate(val_loader)
#         metric_val = eval_dict[model_selection_metric]
#         print('Validation metric (%s): %.4f'
#                 % (model_selection_metric, metric_val))

#         for k, v in eval_dict.items():
#             logger.add_scalar('val/%s' % k, v, it)

#         if model_selection_sign * (metric_val - metric_val_best) > 0:
#             metric_val_best = metric_val
#             print('New best model (loss %.4f)' % metric_val_best)
#             checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
#                                 loss_val_best=metric_val_best)

    # Exit if necessary
#     if exit_after > 0 and (time.time() - t0) >= exit_after:
#         print('Time limit reached. Exiting.')
#         checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)

#print(train_loader)

[Epoch 100] it=100, loss=2.7975, time: 147.94s, 19:05
[Epoch 200] it=200, loss=1.6242, time: 218.73s, 19:06
Visualizing
Saving checkpoint
[Epoch 300] it=300, loss=1.1940, time: 290.05s, 19:07
[Epoch 400] it=400, loss=1.0181, time: 361.30s, 19:09
Visualizing
Saving checkpoint
[Epoch 500] it=500, loss=0.8496, time: 450.09s, 19:10
[Epoch 600] it=600, loss=0.7280, time: 568.99s, 19:12
Visualizing
Saving checkpoint
[Epoch 700] it=700, loss=0.6575, time: 677.83s, 19:14
[Epoch 800] it=800, loss=0.5833, time: 773.04s, 19:16
Visualizing
Saving checkpoint
[Epoch 900] it=900, loss=0.5414, time: 868.30s, 19:17
[Epoch 1000] it=1000, loss=0.4948, time: 963.39s, 19:19
Visualizing
Saving checkpoint
[Epoch 1100] it=1100, loss=0.4606, time: 1057.99s, 19:20
[Epoch 1200] it=1200, loss=0.4359, time: 1154.45s, 19:22
Visualizing
Saving checkpoint
[Epoch 1300] it=1300, loss=0.4164, time: 1251.00s, 19:23
[Epoch 1400] it=1400, loss=0.3938, time: 1347.01s, 19:25
Visualizing
Saving checkpoint
[Epoch 1500] it=1500