In [1]:
import os
import glob
import cv2
import random
import pandas as pd
from skimage import io
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import h5py

# Network building stuff
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics
import torch.distributions as dist


#mesh
from notebooks.utils.libmise.mise import  MISE
from notebooks.utils.libmcubes.mcubes import marching_cubes
import trimesh
 

ModuleNotFoundError: No module named 'notebooks.utils.libmise.mise'

'/scratch/shan/project-noisypixel'

In [2]:
import sys
# sys.path.append("/home2/sdokania/all_projects/project-noisypixel/")

In [3]:
from src.models import *
from src.dataset.dataloader import OccupancyNetDatasetHDF
from src.trainer import ONetLit
from src.utils import Config, count_parameters

In [4]:
config = Config()
config.data_root = "hdf_data/"
config.batch_size = 32
config.output_dir = 'experiment/'
config.exp_path = 'experiment/'

Setting sexperiment path as : /home2/sdokania/all_projects/occ_artifacts/initial


In [5]:

vars(config)

{'c_dim': 128,
 'h_dim': 128,
 'p_dim': 3,
 'data_root': 'hdf_data/',
 'batch_size': 32,
 'output_dir': 'experiment/',
 'exp_name': 'initial',
 'encoder': 'efficientnet-b0',
 'decoder': 'decoder-cbn',
 'exp_path': 'experiment/',
 'lr': 0.0003}

In [6]:
onet = ONetLit(config)

Loaded pretrained weights for efficientnet-b0


In [7]:
net = ONetLit.load_from_checkpoint("efficient_cbn_bs_64_full_data/lightning_logs/version_1/checkpoints/epoch=70-step=34079.ckpt", cfg=config)

Loaded pretrained weights for efficientnet-b0


In [8]:
dataset = OccupancyNetDatasetHDF(config.data_root, num_points=2048)

In [9]:
def make_3d_grid(bb_min, bb_max, shape):
    ''' Makes a 3D grid.
    Args:
        bb_min (tuple): bounding box minimum
        bb_max (tuple): bounding box maximum
        shape (tuple): output shape
    '''
    size = shape[0] * shape[1] * shape[2]

    pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
    pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
    pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])

    pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
    pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
    pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
    p = torch.stack([pxs, pys, pzs], dim=1)

    return p

In [10]:
def eval_points(p,  c ,points_batch_size=100000 ):
    
    p_split = torch.split(p, points_batch_size)
    occ_hats = []

    for pi in p_split:
        pi = pi.unsqueeze(0) 
        with torch.no_grad():
            occ_hat = net.net.decoder(pi, c) 
            

        occ_hats.append(occ_hat.squeeze(0).detach().cpu())

    occ_hat = torch.cat(occ_hats, dim=0)

    return occ_hat

In [24]:
def extract_mesh(occ_hat,padding = 0.1,threshold_g = 0.2):
    n_x, n_y, n_z = occ_hat.shape
    box_size = 1 + padding
    threshold = np.log( threshold_g) - np.log(1. - threshold_g)
    
    occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
    print(threshold,occ_hat_padded.shape, np.min(occ_hat_padded), np.max(occ_hat_padded))
    vertices, triangles = marching_cubes(occ_hat_padded, threshold)
  
    vertices -= 0.5
    # Undo padding
    vertices -= 1
    # Normalize to bounding box
    vertices /= np.array([n_x-1, n_y-1, n_z-1])
    vertices = box_size * (vertices - 0.5)
    
    normals = None

    # Create mesh
    mesh = trimesh.Trimesh(vertices, triangles, vertex_normals=normals,process=False)

    return mesh

In [25]:
def get_mesh(dataset,padding=0.1, resolution0 = 32,upsampling_steps=2,threshold_g = 0.2 ):
    test_img, test_pts, test_gt = dataset[0]
    threshold = np.log( threshold_g) - np.log(1. - threshold_g)
    box_size = 1 +  padding
    nx = 32
    pointsf = 2 * make_3d_grid((-0.5,)*3, (0.5,)*3, (nx,)*3    )
    c = net.net.encoder(test_img.unsqueeze(0))#.detach()
    
    if(upsampling_steps==0):
     
        values = eval_points(pointsf,c ).cpu().numpy()
        value_grid = values.reshape(nx, nx, nx)

    else:
        mesh_extractor = MISE(resolution0, upsampling_steps, threshold)
        points = mesh_extractor.query()
 

        while points.shape[0] != 0:
            # Query points
            pointsf = torch.FloatTensor(points) 
            # Normalize to bounding box
            pointsf = pointsf / mesh_extractor.resolution
            pointsf = box_size * (pointsf - 0.5)
            # Evaluate model and update
            print(pointsf.shape, c.shape)
            values = eval_points( pointsf, c).cpu().numpy()


            values = values.astype(np.float64)

            mesh_extractor.update(points, values)
            points = mesh_extractor.query()


        value_grid = mesh_extractor.to_dense()

     
    mesh = extract_mesh(value_grid )
    
    return mesh

In [26]:
mesh = get_mesh(dataset)


mesh_out_file = os.path.join('./', '%s.off' % 'onet')
mesh.export(mesh_out_file)

torch.Size([35937, 3]) torch.Size([1, 128])
torch.Size([8472, 3]) torch.Size([1, 128])
torch.Size([32700, 3]) torch.Size([1, 128])
torch.Size([23508, 3]) torch.Size([1, 128])
torch.Size([4959, 3]) torch.Size([1, 128])
torch.Size([49, 3]) torch.Size([1, 128])
-1.3862943611198906 (131, 131, 131) -1000000.0 3.153803586959839


'OFF\n7372 10376 0\n-0.4813233934 -0.0687500000 -0.0343750000\n-0.4812500000 -0.0688252955 -0.0343750000\n-0.4812500000 -0.0687500000 -0.0344516971\n-0.4812500000 -0.0687500000 -0.0342975445\n-0.4814420410 -0.0687500000 0.0000000000\n-0.4812500000 -0.0689459392 0.0000000000\n-0.4812500000 -0.0687500000 -0.0002053239\n-0.4812500000 -0.0687500000 0.0002145037\n-0.4814837566 -0.0687500000 0.0343750000\n-0.4812500000 -0.0689983355 0.0343750000\n-0.4812500000 -0.0687500000 0.0341188447\n-0.4812500000 -0.0687500000 0.0346304002\n-0.4812500000 -0.0686733557 -0.0343750000\n-0.4812500000 -0.0685443083 0.0000000000\n-0.4812500000 -0.0685003871 0.0343750000\n-0.4814631739 -0.0343750000 -0.0343750000\n-0.4812500000 -0.0345968054 -0.0343750000\n-0.4812500000 -0.0343750000 -0.0345998610\n-0.4812500000 -0.0343750000 -0.0341478571\n-0.4816089526 -0.0343750000 0.0000000000\n-0.4812500000 -0.0347345791 0.0000000000\n-0.4812500000 -0.0343750000 -0.0003713691\n-0.4812500000 -0.0343750000 0.0003891476\n-0.