In [1]:
import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time, datetime
import matplotlib; matplotlib.use('Agg')
from src import config, data
from src.checkpoints import CheckpointIO
from collections import defaultdict
import shutil
from tensorboardX import SummaryWriter

In [2]:
writer = SummaryWriter('testrun')
cfg = config.load_config('configs/pointcloud/grid.yaml', 'configs/default.yaml')
is_cuda = (torch.cuda.is_available())
device = torch.device("cuda:2" if is_cuda else "cpu")

In [3]:
print(device)

cuda:2


In [4]:
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']

In [5]:
# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile('configs/pointcloud/grid.yaml', os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)

In [6]:
shape = train_dataset[0]

In [7]:
# train_loader = torch.utils.data.DataLoader(
#     train_dataset[5], batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

# print(len(train_loader))    
# print(train_loader)

# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False,
#      collate_fn=data.collate_remove_none,
#      worker_init_fn=data.worker_init_fn)


In [8]:
# code subset
train_dataset = config.get_dataset("train", cfg)
ds = torch.utils.data.Subset(train_dataset, indices= [0]*len(train_dataset))
val_dataset = config.get_dataset("val", cfg)
val_ds = torch.utils.data.Subset(val_dataset, indices= [0]*len(val_dataset))
train_loader = torch.utils.data.DataLoader(
    ds, batch_size=32, num_workers=8, shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
# val_loader = torch.utils.data.DataLoader(
#     ds, batch_size=1, num_workers=8, shuffle=True,
#     collate_fn=data.collate_remove_none,
#     worker_init_fn=data.worker_init_fn)

In [9]:
# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Generator
generator = config.get_generator(model, cfg, device=device)

# Intialize training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
# try:
#     load_dict = checkpoint_io.load('model.pt')
# except FileExistsError:
#     load_dict = dict()
# epoch_it = load_dict.get('epoch_it', 0)
# it = load_dict.get('it', 0)
# metric_val_best = load_dict.get(
#      'loss_val_best', -model_selection_sign * np.inf)

# if metric_val_best == np.inf or metric_val_best == -np.inf:
#     metric_val_best = -model_selection_sign * np.inf
# print('Current best validation metric (%s): %.8f'
#       % (model_selection_metric, metric_val_best))

In [10]:
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# Print model
nparameters = sum(p.numel() for p in model.parameters())
print('Total number of parameters: %d' % nparameters)

print('output path: ', cfg['training']['out_dir'])

Total number of parameters: 1978275
output path:  out/pointcloud/grid


In [11]:
#batch = next(train_loader.__iter__())

In [12]:
# c = model.encode_inputs(batch.get('inputs').cuda())
# results = model.decode(batch.get('points').cuda(), c).logits

In [13]:
it = 0
epoch_it = 0
batch = next(train_loader.__iter__())
while epoch_it <= 30000:
    epoch_it += 1

    it += 1
    loss = trainer.train_step(batch)
    #logger.add_scalar('train/loss', loss, it)
    writer.add_scalars("metropolis+sigma10", {'global_loss':loss[0],
                                        'norm_loss':loss[1],
                                        'gradient_loss':loss[2], 
                                        'ratio_relu':loss[3]}, it)
    
    
    #writer.add_scalar("dotProductReg", loss[0], it)
    # Print output
    if print_every > 0 and (it % print_every) == 0:
        t = datetime.datetime.now()
        print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d'
                    % (epoch_it, it, loss[0], time.time() - t0, t.hour, t.minute))

    #data_v = next(in data_vis_.__iter__())ist:
    # Visualize output
#     if visualize_every > 0 and (it % visualize_every) == 0:
#         print('Visualizing')
#         for data_vis in data_vis_list:
#             if cfg['generation']['sliding_window']:
#                 out = generator.generate_mesh_sliding(data_vis['data'])
#             else:
#                 out = generator.generate_mesh(data_vis['data'])
#             # Get statistics
#             try:
#                 mesh, stats_dict = out
#             except TypeError:
#                 mesh, stats_dict = out, {}

#             mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it'])))


    # Save checkpoint
    if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
        print('Saving checkpoint')
        checkpoint_io.save('metropolis+sigma10.pt', epoch_it=epoch_it, it=it)

    # Backup if necessary
    if (backup_every > 0 and (it % backup_every) == 0):
        print('Backup checkpoint')
        checkpoint_io.save('metropolis+sigma10Model_%d.pt' % it, epoch_it=epoch_it, it=it)
        
    # Run validation
#     if validate_every > 0 and (it % validate_every) == 0:
#         eval_dict = trainer.evaluate(val_loader)
#         metric_val = eval_dict[model_selection_metric]
#         print('Validation metric (%s): %.4f'
#                 % (model_selection_metric, metric_val))

#         for k, v in eval_dict.items():
#             logger.add_scalar('val/%s' % k, v, it)

#         if model_selection_sign * (metric_val - metric_val_best) > 0:
#             metric_val_best = metric_val
#             print('New best model (loss %.4f)' % metric_val_best)
#             checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
#                                 loss_val_best=metric_val_best)

    # Exit if necessary
#     if exit_after > 0 and (time.time() - t0) >= exit_after:
#         print('Time limit reached. Exiting.')
#         checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
#                             loss_val_best=metric_val_best)

#print(train_loader)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[Epoch 100] it=100, loss=-0.0681, time: 74.78s, 11:01
[Epoch 200] it=200, loss=0.1831, time: 134.57s, 11:02
Saving checkpoint
[Epoch 300] it=300, loss=0.0026, time: 194.43s, 11:03
[Epoch 400] it=400, loss=-0.0886, time: 254.85s, 11:04
Saving checkpoint
[Epoch 500] it=500, loss=-0.1231, time: 314.86s, 11:05
[Epoch 600] it=600, loss=-0.0916, time: 374.26s, 11:06
Saving checkpoint
[Epoch 700] it=700, loss=-0.0741, time: 434.08s, 11:07
[Epoch 800] it=800, loss=-0.0266, time: 495.15s, 11:08
Saving checkpoint
[Epoch 900] it=900, loss=-0.0379, time: 555.69s, 11:09
[Epoch 1000] it=1000, loss=-0.0396, time: 615.41s, 11:10
Saving checkpoint
[Epoch 1100] it=1100, loss=0.0009, time: 675.07s, 11:11
[Epoch 1200] it=1200, loss=0.0381, time: 734.60s, 11:12
Saving checkpoint
[Epoch 1300] it=1300, loss=0.0252, time: 794.42s, 11:13
[Epoch 1400] it=1400, loss=0.0394, time: 854.14s, 11:14
Saving checkpoint
[Epoch 1500] it=1500, loss=0.0699, time: 913.93s, 11:15
[Epoch 1600] it=1600, loss=0.0811, time: 973.

[Epoch 12500] it=12500, loss=0.0875, time: 7490.60s, 13:05
[Epoch 12600] it=12600, loss=0.0964, time: 7550.27s, 13:06
Saving checkpoint
[Epoch 12700] it=12700, loss=0.0679, time: 7610.09s, 13:07
[Epoch 12800] it=12800, loss=0.0797, time: 7670.28s, 13:08
Saving checkpoint
[Epoch 12900] it=12900, loss=0.0171, time: 7730.17s, 13:09
[Epoch 13000] it=13000, loss=0.1010, time: 7790.33s, 13:10
Saving checkpoint
[Epoch 13100] it=13100, loss=0.1006, time: 7849.78s, 13:11
[Epoch 13200] it=13200, loss=0.1036, time: 7909.23s, 13:12
Saving checkpoint
[Epoch 13300] it=13300, loss=0.2299, time: 7968.88s, 13:13
[Epoch 13400] it=13400, loss=0.1699, time: 8030.38s, 13:14
Saving checkpoint
[Epoch 13500] it=13500, loss=0.1248, time: 8092.03s, 13:15
[Epoch 13600] it=13600, loss=0.1257, time: 8153.65s, 13:16
Saving checkpoint
[Epoch 13700] it=13700, loss=0.0455, time: 8214.64s, 13:17
[Epoch 13800] it=13800, loss=0.1124, time: 8274.22s, 13:18
Saving checkpoint
[Epoch 13900] it=13900, loss=0.0829, time: 8335.

[Epoch 24400] it=24400, loss=0.0423, time: 14677.52s, 15:05
Saving checkpoint
[Epoch 24500] it=24500, loss=0.0779, time: 14737.19s, 15:06
[Epoch 24600] it=24600, loss=0.0600, time: 14796.85s, 15:07
Saving checkpoint
[Epoch 24700] it=24700, loss=0.0235, time: 14856.71s, 15:08
[Epoch 24800] it=24800, loss=0.0581, time: 14915.46s, 15:09
Saving checkpoint
[Epoch 24900] it=24900, loss=0.0671, time: 14974.90s, 15:10
[Epoch 25000] it=25000, loss=0.0383, time: 15034.31s, 15:11
Saving checkpoint
[Epoch 25100] it=25100, loss=0.0086, time: 15093.81s, 15:12
[Epoch 25200] it=25200, loss=0.0851, time: 15153.24s, 15:13
Saving checkpoint
[Epoch 25300] it=25300, loss=0.0575, time: 15212.76s, 15:14
[Epoch 25400] it=25400, loss=0.0889, time: 15272.20s, 15:15
Saving checkpoint
[Epoch 25500] it=25500, loss=0.0729, time: 15331.73s, 15:16
[Epoch 25600] it=25600, loss=0.0459, time: 15390.88s, 15:17
Saving checkpoint
[Epoch 25700] it=25700, loss=0.0799, time: 15452.34s, 15:18
[Epoch 25800] it=25800, loss=-0.05