In [1]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil
from tensorboardX import SummaryWriter

In [2]:
writer = SummaryWriter()
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')
is_cuda = (torch.cuda.is_available())
device = torch.device("cuda" if is_cuda else "cpu")

In [3]:
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [4]:
model_selection_metric = cfg['training']['model_selection_metric']
if cfg['training']['model_selection_mode'] == 'maximize':
    model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
    model_selection_sign = -1
else:
    raise ValueError('model_selection_mode must be '
                     'either maximize or minimize.')

In [5]:
# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)
print(train_dataset[5])
print(train_dataset)
print(len(train_dataset[5]['points.occ']))
print(len(train_dataset))

{'points': array([[-0.49517384,  0.28118938, -0.5330159 ],
       [-0.32216063, -0.4197425 , -0.5473141 ],
       [ 0.47376126,  0.35932037, -0.17588124],
       ...,
       [-0.29364073,  0.14965375, -0.4432282 ],
       [-0.33884   , -0.49018496,  0.35189503],
       [-0.16472606, -0.05514001,  0.33647585]], dtype=float32), 'points.occ': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'inputs': array([[-0.16711426,  0.30029297, -0.37597656],
       [-0.10113525,  0.34838867, -0.2088623 ],
       [ 0.02238464,  0.31884766, -0.3930664 ],
       ...,
       [ 0.22631836,  0.32177734, -0.4736328 ],
       [-0.15307617,  0.32641602,  0.5019531 ],
       [-0.16296387,  0.07574463,  0.3491211 ]], dtype=float32), 'inputs.normals': array([[-0.04034424, -0.46826172, -0.8828125 ],
       [ 0.17578125,  0.93847656, -0.2980957 ],
       [-0.03451538, -0.9995117 , -0.01184845],
       ...,
       [-0.03689575, -0.99902344, -0.01186371],
       [-0.03012085, -0.88671875,  0.46166992],
       [

In [6]:
shape = train_dataset[0]

In [7]:
# train_loader = torch.utils.data.DataLoader(
#     train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

# print(len(train_loader))    
# print(train_loader)

# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
#      collate_fn=data.collate_remove_none,
#      worker_init_fn=data.worker_init_fn)


In [8]:
#Amine code subset

train_dataset = config.get_dataset("train", cfg)
ds = torch.utils.data.Subset(train_dataset, indices= [0]*len(train_dataset))
val_dataset = config.get_dataset("val", cfg)
val_ds = torch.utils.data.Subset(val_dataset, indices= [0]*len(val_dataset))
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=32, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
# val_loader = torch.utils.data.DataLoader(
#     ds, batch_size=1, num_workers=8, shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

In [9]:
# For visualizations
vis_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
model_counter = defaultdict(int)
data_vis_list = []

# Build a data dictionary for visualization
iterator = iter(vis_loader)
for i in range(len(vis_loader)):
    data_vis = next(iterator)
    idx = data_vis['idx'].item()
    model_dict = val_dataset.get_model_dict(idx)
    category_id = model_dict.get('category', 'n/a')
    category_name = val_dataset.metadata[category_id].get('name', 'n/a')
    category_name = category_name.split(',')[0]
    if category_name == 'n/a':
        category_name = category_id

    c_it = model_counter[category_id]
    if c_it < vis_n_outputs:
        data_vis_list.append({'category': category_name, 'it': c_it, 'data': data_vis})

    model_counter[category_id] += 1


KeyError: 'idx'

In [10]:
# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Generator
generator = config.get_generator(model, cfg, device=device)

# Intialize training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
# try:
#     load_dict = checkpoint_io.load('model.pt')
# except FileExistsError:
#     load_dict = dict()
# epoch_it = load_dict.get('epoch_it', 0)
# it = load_dict.get('it', 0)
# metric_val_best = load_dict.get(
#      'loss_val_best', -model_selection_sign * np.inf)

# if metric_val_best == np.inf or metric_val_best == -np.inf:
#     metric_val_best = -model_selection_sign * np.inf
# print('Current best validation metric (%s): %.8f'
#       % (model_selection_metric, metric_val_best))

In [11]:
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# Print model
nparameters = sum(p.numel() for p in model.parameters())
print('Total number of parameters: %d' % nparameters)

print('output path: ', cfg['training']['out_dir'])

Total number of parameters: 1978275
output path:  out/pointcloud/grid


In [12]:
batch = next(train_loader.__iter__())

In [13]:
c = model.encode_inputs(batch.get('inputs').cuda())
results = model.decode(batch.get('points').cuda(), c).logits

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [14]:
it = 0
epoch_it = 0
batch = next(train_loader.__iter__())
while epoch_it<5000:
    epoch_it += 1

    it += 1
    loss = trainer.train_step(batch)
    #logger.add_scalar('train/loss', loss, it)
    writer.add_scalar("NormLoss", loss, it)

    # Print output
    if print_every > 0 and (it % print_every) == 0:
        t = datetime.datetime.now()
        print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                    % (epoch_it, it, loss, time.time() - t0, t.hour, t.minute))

    #data_v = next(in data_vis_.__iter__())ist:
    # Visualize output
    if visualize_every > 0 and (it % visualize_every) == 0:
        print('Visualizing')
        for data_vis in data_vis_list:
            if cfg['generation']['sliding_window']:
                out = generator.generate_mesh_sliding(data_vis['data'])
            else:
                out = generator.generate_mesh(data_vis['data'])
            # Get statistics
            try:
                mesh, stats_dict = out
            except TypeError:
                mesh, stats_dict = out, {}

            mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it'])))


    # Save checkpoint
    if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
        print('Saving checkpoint')
        checkpoint_io.save('modelNorm.pt', epoch_it=epoch_it, it=it)

    # Backup if necessary
    if (backup_every > 0 and (it % backup_every) == 0):
        print('Backup checkpoint')
        checkpoint_io.save('modelNorm_%d.pt' % it, epoch_it=epoch_it, it=it)
        
    # Run validation
#     if validate_every > 0 and (it % validate_every) == 0:
#         eval_dict = trainer.evaluate(val_loader)
#         metric_val = eval_dict[model_selection_metric]
#         print('Validation metric (%s): %.4f'
#                 % (model_selection_metric, metric_val))

#         for k, v in eval_dict.items():
#             logger.add_scalar('val/%s' % k, v, it)

#         if model_selection_sign * (metric_val - metric_val_best) > 0:
#             metric_val_best = metric_val
#             print('New best model (loss %.4f)' % metric_val_best)
#             checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
#                                 loss_val_best=metric_val_best)

    # Exit if necessary
#     if exit_after > 0 and (time.time() - t0) >= exit_after:
#         print('Time limit reached. Exiting.')
#         checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)

#print(train_loader)

[Epoch 100] it=100, loss=0.0005, time: 70.67s, 23:47
[Epoch 200] it=200, loss=0.0001, time: 118.43s, 23:47
Visualizing
Saving checkpoint
[Epoch 300] it=300, loss=0.0000, time: 165.21s, 23:48
[Epoch 400] it=400, loss=0.0000, time: 211.91s, 23:49
Visualizing
Saving checkpoint
[Epoch 500] it=500, loss=0.0000, time: 258.69s, 23:50
[Epoch 600] it=600, loss=0.0000, time: 305.39s, 23:50
Visualizing
Saving checkpoint
[Epoch 700] it=700, loss=0.0000, time: 352.14s, 23:51
[Epoch 800] it=800, loss=0.0000, time: 398.85s, 23:52
Visualizing
Saving checkpoint
[Epoch 900] it=900, loss=0.0000, time: 446.68s, 23:53
[Epoch 1000] it=1000, loss=0.0000, time: 494.22s, 23:54
Visualizing
Saving checkpoint
[Epoch 1100] it=1100, loss=0.0000, time: 540.95s, 23:54
[Epoch 1200] it=1200, loss=0.0000, time: 587.66s, 23:55
Visualizing
Saving checkpoint
[Epoch 1300] it=1300, loss=0.0000, time: 635.52s, 23:56
[Epoch 1400] it=1400, loss=0.0000, time: 682.92s, 23:57
Visualizing
Saving checkpoint
[Epoch 1500] it=1500, los