In [1]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil
from tensorboardX import SummaryWriter

In [2]:
# TensorBoard writer (event files under ./testrun) and the experiment config, loaded from
# grid.yaml with configs/default.yaml passed as the default config.
writer = SummaryWriter('testrun')
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')

# Prefer the second GPU (cuda:1) as before, but fall back to cuda:0 when only one GPU is
# visible: torch.device("cuda:1") is accepted at construction time, yet the first tensor
# moved onto it would fail with an invalid device ordinal on a single-GPU machine.
is_cuda = torch.cuda.is_available()
if is_cuda:
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:%d" % gpu_index)
else:
    device = torch.device("cpu")

In [3]:
# Report which device the model and batches will be placed on.
print(str(device))

cuda:1


In [4]:
# Wall-clock reference point for the elapsed-time column printed by the training loop.
t0 = time.time()

# Shorthands pulled out of the loaded config (looked up in the same order as before).
training_cfg = cfg['training']
out_dir = training_cfg['out_dir']
batch_size = training_cfg['batch_size']
backup_every = training_cfg['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [5]:
# Output directory. exist_ok=True replaces the exists()-then-makedirs() check, which
# could still raise if the directory appeared between the two calls.
os.makedirs(out_dir, exist_ok=True)

# Keep a copy of the experiment config next to the checkpoints.
# NOTE(review): only grid.yaml is copied; cfg was loaded with configs/default.yaml as
# well, so any values taken from that file are not captured here — confirm acceptable.
shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Datasets (the validation split also returns sample indices).
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)

In [6]:
# Pull one training example out for interactive inspection.
# NOTE(review): despite the name, `shape` holds the whole sample returned by the
# dataset (presumably a dict of fields), not a shape tuple — confirm.
first_sample_idx = 0
shape = train_dataset[first_sample_idx]

In [7]:
# Full-dataset loaders, disabled: the next cell builds a single-sample subset loader instead.
# NOTE: `train_dataset[5]` below is one sample, not a dataset; pass `train_dataset`
# if this loader is re-enabled.
# train_loader = torch.utils.data.DataLoader(
#     train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

# print(len(train_loader))
# print(train_loader)

# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)


In [8]:
# Overfitting subset: every index points at sample 0, so the loader only ever yields
# copies of the first training (and validation) example.
# The datasets built in the "Output directory" cell are reused instead of being
# constructed a second time; re-creating them here used to shadow those objects and
# silently drop return_idx=True from the validation set.
ds = torch.utils.data.Subset(train_dataset, indices=[0] * len(train_dataset))
val_ds = torch.utils.data.Subset(val_dataset, indices=[0] * len(val_dataset))

# NOTE(review): batch_size=32 is hard-coded and ignores cfg['training']['batch_size']
# (read into `batch_size` earlier) — confirm this override is intended.
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=32, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)

# Validation loader is disabled; if re-enabled it should iterate the validation
# subset (the previous draft pointed it at the training subset `ds`).
# val_loader = torch.utils.data.DataLoader(
#     val_ds, batch_size=1, num_workers=8, shuffle=False,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

In [9]:
# Build the model, its generator, the optimizer and the trainer from `cfg`, all on `device`.
model = config.get_model(cfg, device=device, dataset=train_dataset)
# The generator is only referenced by the (disabled) visualization code in the training loop.
generator = config.get_generator(model, cfg, device=device)

# Adam with a fixed learning rate; an SGD + momentum variant was also tried:
#   optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
learning_rate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

# Checkpoint helper rooted at `out_dir`, tracking both model and optimizer state.
checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)

# Resuming from 'model.pt' is disabled, so training always starts from scratch.
# NOTE(review): re-enabling this block needs `model_selection_sign` and
# `model_selection_metric`, which this notebook never defines.
# try:
#     load_dict = checkpoint_io.load('model.pt')
# except FileExistsError:
#     load_dict = dict()
# epoch_it = load_dict.get('epoch_it', 0)
# it = load_dict.get('it', 0)
# metric_val_best = load_dict.get(
#      'loss_val_best', -model_selection_sign * np.inf)
#
# if metric_val_best == np.inf or metric_val_best == -np.inf:
#     metric_val_best = -model_selection_sign * np.inf
# print('Current best validation metric (%s): %.8f'
#       % (model_selection_metric, metric_val_best))

In [10]:
# Cadences (in iterations) for printing, checkpointing, validation and visualization.
training_cfg = cfg['training']
print_every, checkpoint_every, validate_every, visualize_every = (
    training_cfg[key]
    for key in ('print_every', 'checkpoint_every', 'validate_every', 'visualize_every')
)

# Report the model size and where outputs are written.
nparameters = sum(param.numel() for param in model.parameters())
print(f'Total number of parameters: {nparameters:d}')

print('output path: ', training_cfg['out_dir'])

Total number of parameters: 1978275
output path:  out/pointcloud/grid


In [11]:
#batch = next(train_loader.__iter__())

In [12]:
# c = model.encode_inputs(batch.get('inputs').cuda())
# results = model.decode(batch.get('points').cuda(), c).logits

In [13]:
# Training loop: optimise on one fixed batch, log the four loss terms to TensorBoard
# and write periodic / backup checkpoints through `checkpoint_io`.
RUN_NAME = "metropolis+sigma0.1"  # TensorBoard tag and checkpoint filename prefix
MAX_ITERS = 30000  # the previous `while epoch_it <= 30000` ran one extra step (30001)

it = 0
epoch_it = 0
# NOTE(review): a single batch is drawn once and reused for every step, i.e. this is an
# overfitting run (the loader's subset also repeats sample 0) — confirm intended.
batch = next(iter(train_loader))

# Visualization, validation-based model selection and the time-limit exit are not run
# here: they rely on names this notebook never defines (data_vis_list, val_loader,
# model_selection_metric / model_selection_sign, exit_after, logger).
try:
    while epoch_it < MAX_ITERS:
        # Every step is also counted as an "epoch", so epoch_it always equals it.
        epoch_it += 1
        it += 1

        # train_step's result is indexed 0..3 below:
        # global loss, norm loss, gradient loss and ratio_relu.
        loss = trainer.train_step(batch)
        writer.add_scalars(RUN_NAME, {'global_loss': loss[0],
                                      'norm_loss': loss[1],
                                      'gradient_loss': loss[2],
                                      'ratio_relu': loss[3]}, it)

        # Print output
        if print_every > 0 and (it % print_every) == 0:
            t = datetime.datetime.now()
            print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                  % (epoch_it, it, loss[0], time.time() - t0, t.hour, t.minute))

        # Save checkpoint (same filename every time)
        if checkpoint_every > 0 and (it % checkpoint_every) == 0:
            print('Saving checkpoint')
            checkpoint_io.save('%s.pt' % RUN_NAME, epoch_it=epoch_it, it=it)

        # Backup checkpoint (one filename per iteration count)
        if backup_every > 0 and (it % backup_every) == 0:
            print('Backup checkpoint')
            checkpoint_io.save('%sModel_%d.pt' % (RUN_NAME, it), epoch_it=epoch_it, it=it)
finally:
    # Runs on normal completion and on interruption (stopping a cell raises
    # KeyboardInterrupt): flush buffered TensorBoard events and save the latest state,
    # because the periodic save above only fires on multiples of checkpoint_every.
    writer.flush()
    if it > 0:
        print('Saving final checkpoint')
        checkpoint_io.save('%s.pt' % RUN_NAME, epoch_it=epoch_it, it=it)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[Epoch 100] it=100, loss=0.5443, time: 73.70s, 11:00
[Epoch 200] it=200, loss=0.6636, time: 133.53s, 11:01
Saving checkpoint
[Epoch 300] it=300, loss=0.8100, time: 193.87s, 11:02
[Epoch 400] it=400, loss=0.8754, time: 253.91s, 11:03
Saving checkpoint
[Epoch 500] it=500, loss=0.9108, time: 313.96s, 11:04
[Epoch 600] it=600, loss=0.9362, time: 373.96s, 11:05
Saving checkpoint
[Epoch 700] it=700, loss=0.9575, time: 434.49s, 11:06
[Epoch 800] it=800, loss=0.9777, time: 494.35s, 11:07
Saving checkpoint
[Epoch 900] it=900, loss=0.9811, time: 554.37s, 11:08
[Epoch 1000] it=1000, loss=0.9956, time: 614.31s, 11:09
Saving checkpoint
[Epoch 1100] it=1100, loss=0.9966, time: 675.08s, 11:10
[Epoch 1200] it=1200, loss=0.9936, time: 735.82s, 11:11
Saving checkpoint
[Epoch 1300] it=1300, loss=0.9899, time: 796.20s, 11:12
[Epoch 1400] it=1400, loss=0.9870, time: 855.24s, 11:13
Saving checkpoint
[Epoch 1500] it=1500, loss=0.9854, time: 914.35s, 11:14
[Epoch 1600] it=1600, loss=0.9816, time: 973.43s, 11:

[Epoch 12500] it=12500, loss=0.9962, time: 7378.59s, 13:02
[Epoch 12600] it=12600, loss=0.9952, time: 7432.52s, 13:03
Saving checkpoint
[Epoch 12700] it=12700, loss=0.9958, time: 7489.95s, 13:04
[Epoch 12800] it=12800, loss=0.9956, time: 7548.75s, 13:05
Saving checkpoint
[Epoch 12900] it=12900, loss=0.9962, time: 7607.76s, 13:06
[Epoch 13000] it=13000, loss=0.9952, time: 7666.69s, 13:07
Saving checkpoint
[Epoch 13100] it=13100, loss=0.9959, time: 7725.71s, 13:08
[Epoch 13200] it=13200, loss=0.9992, time: 7784.66s, 13:09
Saving checkpoint
[Epoch 13300] it=13300, loss=0.9958, time: 7843.76s, 13:10
[Epoch 13400] it=13400, loss=0.9965, time: 7902.74s, 13:11
Saving checkpoint
[Epoch 13500] it=13500, loss=0.9953, time: 7961.71s, 13:12
[Epoch 13600] it=13600, loss=0.9963, time: 8020.69s, 13:13
Saving checkpoint
[Epoch 13700] it=13700, loss=0.9276, time: 8079.70s, 13:14
[Epoch 13800] it=13800, loss=0.9355, time: 8138.62s, 13:15
Saving checkpoint
[Epoch 13900] it=13900, loss=0.9574, time: 8197.

[Epoch 24400] it=24400, loss=0.9988, time: 14396.99s, 14:59
Saving checkpoint
[Epoch 24500] it=24500, loss=0.9985, time: 14456.13s, 15:00
[Epoch 24600] it=24600, loss=0.9983, time: 14516.42s, 15:01
Saving checkpoint
[Epoch 24700] it=24700, loss=0.9988, time: 14577.31s, 15:02
[Epoch 24800] it=24800, loss=0.9987, time: 14638.37s, 15:03
Saving checkpoint
[Epoch 24900] it=24900, loss=0.9992, time: 14699.48s, 15:04
[Epoch 25000] it=25000, loss=0.9983, time: 14760.52s, 15:05
Saving checkpoint
[Epoch 25100] it=25100, loss=0.9988, time: 14821.71s, 15:06
[Epoch 25200] it=25200, loss=0.9987, time: 14882.71s, 15:07
Saving checkpoint
[Epoch 25300] it=25300, loss=0.9991, time: 14941.51s, 15:08
[Epoch 25400] it=25400, loss=0.9988, time: 15000.22s, 15:09
Saving checkpoint
[Epoch 25500] it=25500, loss=0.9986, time: 15059.24s, 15:10
[Epoch 25600] it=25600, loss=0.9989, time: 15118.25s, 15:11
Saving checkpoint
[Epoch 25700] it=25700, loss=0.9985, time: 15177.31s, 15:12
[Epoch 25800] it=25800, loss=0.998