In [13]:
import numpy as np
import torch
from im2mesh.utils.libmise.mise import  MISE
from im2mesh.utils.libmcubes.mcubes import marching_cubes
import trimesh
import os

#model
from im2mesh import config
import argparse
from tqdm import tqdm
from im2mesh.checkpoints import CheckpointIO



In [28]:
cfg = config.load_config( 'configs/demo.yaml', 'configs/default.yaml') 
device = torch.device("cuda")


In [30]:
out_dir = cfg['training']['out_dir']
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')

batch_size = cfg['generation']['batch_size']
input_type = cfg['data']['input_type']
vis_n_outputs = cfg['generation']['vis_n_outputs']
if vis_n_outputs is None:
    vis_n_outputs = -1

# Dataset
dataset = config.get_dataset('test', cfg, return_idx=True)

 
model = config.get_model(cfg, device=device, dataset=dataset)

checkpoint_io = CheckpointIO(out_dir, model=model)
checkpoint_io.load(cfg['test']['model_file'])

# Generator
generator = config.get_generator(model, cfg, device=device)



https://s3.eu-central-1.amazonaws.com/avg-projects/occupancy_networks/models/onet_img2mesh_3-f786b04a.pt
=> Loading checkpoint from url...


In [31]:
test_loader = torch.utils.data.DataLoader( dataset, batch_size=1, num_workers=0, shuffle=False)
model.eval()



OccupancyNetwork(
  (decoder): DecoderCBatchNorm(
    (fc_p): Conv1d(3, 256, kernel_size=(1,), stride=(1,))
    (block0): CResnetBlockConv1d(
      (bn_0): CBatchNorm1d(
        (conv_gamma): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
        (conv_beta): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
        (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      )
      (bn_1): CBatchNorm1d(
        (conv_gamma): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
        (conv_beta): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
        (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      )
      (fc_0): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (fc_1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (actvn): ReLU()
    )
    (block1): CResnetBlockConv1d(
      (bn_0): CBatchNorm1d(
        (conv_gamma): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
        (conv_beta): Conv1d(256,

In [32]:
padding = 0.1
threshold_g = 0.2

In [3]:
def make_3d_grid(bb_min, bb_max, shape):
    ''' Makes a 3D grid.

    Args:
        bb_min (tuple): bounding box minimum
        bb_max (tuple): bounding box maximum
        shape (tuple): output shape
    '''
    size = shape[0] * shape[1] * shape[2]

    pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
    pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
    pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])

    pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
    pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
    pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
    p = torch.stack([pxs, pys, pzs], dim=1)

    return p

In [35]:
def extract_mesh(occ_hat):
    n_x, n_y, n_z = occ_hat.shape
    box_size = 1 + padding
    threshold = np.log( threshold_g) - np.log(1. - threshold_g)
    
    occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
    
    vertices, triangles = marching_cubes(occ_hat_padded, threshold)
  
    vertices -= 0.5
    # Undo padding
    vertices -= 1
    # Normalize to bounding box
    vertices /= np.array([n_x-1, n_y-1, n_z-1])
    vertices = box_size * (vertices - 0.5)
    
    normals = None

    # Create mesh
    mesh = trimesh.Trimesh(vertices, triangles, vertex_normals=normals,process=False)

    return mesh

In [68]:
def eval_points(p, z, c=None  ):
    points_batch_size=100000
    p_split = torch.split(p, points_batch_size)
    occ_hats = []

    for pi in p_split:
        pi = pi.unsqueeze(0).to(device)
        with torch.no_grad():
            occ_hat = model.decode(pi,z,c).logits

        occ_hats.append(occ_hat.squeeze(0).detach().cpu())

    occ_hat = torch.cat(occ_hats, dim=0)

    return occ_hat

In [69]:
def get_mesh(occ, points,threshold = 0.5,padding=0.1,resolution0=32,upsampling_steps=2):
    model.eval()
    threshold = np.log(threshold_g) - np.log(1. - threshold_g)
    
    nx = 32
    pointsf = 2 * make_3d_grid((-0.5,)*3, (0.5,)*3, (nx,)*3    )
    for i,data in enumerate(tqdm(test_loader)):
        test_sample = data
        break
    
    inputs = test_sample.get('inputs', torch.empty(1, 0)).to(device)
    
#     points = mesh_extractor.query()
#     value_grid = occ.reshape(nx, nx, nx)
    z = model.get_z_from_prior((1,), sample=False).to(device)
    c = model.encode_inputs(inputs)
    values = eval_points(pointsf, z,c ).cpu().numpy()
    value_grid = values.reshape(nx, nx, nx)
     
    mesh = extract_mesh(value_grid )

    return mesh

In [70]:
occ_file = '../project-noisypixel/sample_data/points/occupancies.npy'
points_file = '../project-noisypixel/sample_data/points/points.npy'

points = np.load(points_file)

occ = np.load(occ_file)
occ = np.unpackbits(occ)

# idx = np.random.choice(np.arange(100000), 32768, replace=False)
# occ_sample = occ[idx]
# points_sample = points[idx]



mesh = get_mesh(occ_sample,points_sample)


mesh_out_file = os.path.join('./', '%s.off' % 'onet')
mesh.export(mesh_out_file)

  0%|          | 0/9 [00:00<?, ?it/s]


'OFF\n470 936 0\n-0.2666191939 -0.0532258065 -0.0887096774\n-0.2661290323 -0.0573409627 -0.0887096774\n-0.2661290323 -0.0532258065 -0.0929659685\n-0.2673791842 -0.0532258065 -0.0532258065\n-0.2661290323 -0.0636678603 -0.0532258065\n-0.2675825161 -0.0532258065 -0.0177419355\n-0.2661290323 -0.0656581460 -0.0177419355\n-0.2675704905 -0.0532258065 0.0177419355\n-0.2661290323 -0.0653436481 0.0177419355\n-0.2674851017 -0.0532258065 0.0532258065\n-0.2661290323 -0.0643696959 0.0532258065\n-0.2668755606 -0.0532258065 0.0887096774\n-0.2661290323 -0.0590907636 0.0887096774\n-0.2661290323 -0.0532258065 0.0967099254\n-0.2661290323 -0.0425152834 -0.0887096774\n-0.2661290323 -0.0227933788 -0.0532258065\n-0.2661409607 -0.0177419355 -0.0177419355\n-0.2661290323 -0.0177419355 -0.0195624129\n-0.2661290323 -0.0190353710 0.0177419355\n-0.2661290323 -0.0177419355 -0.0115374126\n-0.2661290323 -0.0235377009 0.0532258065\n-0.2661290323 -0.0371818080 0.0887096774\n-0.2661290323 -0.0170838481 -0.0177419355\n-0.2