In [1]:
import torch,sys,os,time,skimage
import numpy as np
import mcubes,trimesh
from tqdm import tqdm
from skimage import measure
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import torch.nn.functional as F
from torch.utils.data import DataLoader

sys.path.append('..')

from models.FactorFields import FactorFields 

from utils import SimpleSampler
from dataLoader import dataset_dict

# Target device for tensors created in this notebook; the training cell builds
# the model with it and `eval_point` moves its query points to it.
device = 'cuda'
# Make GPU 0 the current CUDA device so plain 'cuda' resolves to it.
torch.cuda.set_device(0)

# IPython autoreload: re-import edited modules (e.g. the project code found via
# sys.path.append('..')) before executing each cell, without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:

@torch.no_grad()
def eval_sdf(reso, bbox, chunk=10240, model=None):
    """Evaluate the field on a dense regular grid spanning [0, bbox] per axis.

    Args:
        reso: number of samples per axis, ordered (x, y, z).
        bbox: upper end of the sampling range; each axis is sampled with
            ``torch.linspace(0, bbox, reso[axis])`` (endpoints included).
        chunk: number of query points sent to the model per call.
        model: object exposing ``get_coding``, ``linear_mat`` and ``device``.
            Defaults to the notebook-level ``model`` global (kept for
            backward compatibility with existing calls).

    Returns:
        CPU tensor of shape ``(reso[0], reso[1], reso[2])`` whose entry
        ``[i, j, k]`` is the prediction at ``(x[i], y[j], z[k])``.
    """
    if model is None:
        model = globals()['model']

    x = torch.linspace(0, bbox, reso[0])
    y = torch.linspace(0, bbox, reso[1])
    z = torch.linspace(0, bbox, reso[2])

    # 'ij' indexing gives grids of shape (reso[0], reso[1], reso[2]). Stacking
    # them (rather than writing into a buffer allocated as
    # (reso[2], reso[1], reso[0], 3)) makes non-cubic resolutions work and keeps
    # the flattened point order consistent with the final reshape(*reso).
    grid = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1)
    coords = grid.reshape(-1, 3)

    res = torch.empty(coords.shape[0])
    count = 0
    for chunk_coords in tqdm(torch.split(coords, chunk, dim=0)):
        feats, _ = model.get_coding(chunk_coords.to(model.device))
        y_recon = model.linear_mat(feats)

        n = y_recon.shape[0]
        res[count:count + n] = y_recon.cpu().view(-1)
        count += n
    return res.reshape(*reso)

def eval_point(points):
    """Query the notebook-level ``model`` at ``points`` and return the result on CPU.

    ``points`` is moved to the global ``device`` first; the output of
    ``model.linear_mat`` is squeezed (all size-1 dims removed) before the copy.
    NOTE(review): this uses ``model.get_basis`` while the other helpers call
    ``model.get_coding`` — confirm the model actually provides ``get_basis``.
    """
    on_device = points.to(device)
    basis_feats, _unused = model.get_basis(on_device)
    prediction = model.linear_mat(basis_feats)
    return prediction.squeeze().cpu()

def marchcude_to_world(vertices, reso_WHD):
    """Rescale marching-cubes vertices from voxel-index space to [-1, 1].

    Per axis, index 0 maps to -1 and index ``reso - 1`` maps to +1.
    """
    last_index = np.array(reso_WHD) - 1
    unit = vertices / last_index
    return 2 * unit - 1

@torch.no_grad()
def cal_l1_iou(test_dataset, chunk=10240, model=None):
    """Mean absolute SDF error and sign-IoU of the model on a test set.

    Args:
        test_dataset: object exposing ``sdf`` (ground-truth values) and
            ``coordiante`` (query points, split along dim 0). ``sdf`` is
            assumed to have the same shape as the concatenated predictions
            (otherwise the subtraction broadcasts) — TODO confirm.
        chunk: number of points evaluated per model call.
        model: object exposing ``get_coding``, ``linear_mat`` and ``device``.
            Defaults to the notebook-level ``model`` global (kept for
            backward compatibility with existing calls).

    Returns:
        ``(l1, iou)`` as 0-d CPU tensors, where
        ``iou = |{gt > 0} & {pred > 0}| / |{gt > 0} | {pred > 0}|``.
        ``iou`` is NaN when neither ground truth nor prediction has a positive value.
    """
    if model is None:
        model = globals()['model']

    sdf, coordiantes = test_dataset.sdf, test_dataset.coordiante

    sdf_pred = []
    for coordiante in torch.split(coordiantes, chunk, dim=0):
        feats, _ = model.get_coding(coordiante.to(model.device))
        sdf_pred.append(model.linear_mat(feats).cpu())
    sdf_pred = torch.cat(sdf_pred)

    l1 = (sdf_pred - sdf).abs().mean()
    pos_gt, pos_pred = sdf > 0, sdf_pred > 0
    iou = torch.sum(pos_gt & pos_pred) / torch.sum(pos_gt | pos_pred)
    return l1, iou

# 2x down/up-sampling operators for the coarse-to-fine point pyramid used by
# get_surface_sliding: mean over 2x2x2 blocks, and nearest-neighbour 2x repeat.
avg_pool_3d = torch.nn.AvgPool3d(kernel_size=2, stride=2)
upsample = torch.nn.Upsample(mode='nearest', scale_factor=2)
@torch.no_grad()
def get_surface_sliding(sdf, path=None, resolution=512, grid_boundary=[-1.0, 1.0], return_mesh=False, level=0):
    """Extract an iso-surface of ``sdf`` with sliding-window, coarse-to-fine marching cubes.

    The cube ``[grid_boundary[0], grid_boundary[1]]^3`` is split into
    ``(resolution // 512)^3`` crops of ``512^3`` samples each. Inside a crop the
    field is first evaluated everywhere on a 64^3 grid (three rounds of 2x
    average pooling of the sample positions). At each finer level only the
    samples whose upsampled coarse value is within a threshold of zero are
    re-evaluated. The threshold halves per level. All other samples keep the
    nearest-neighbour upsampled coarse value.

    Args:
        sdf: callable mapping a float tensor of points (one point per row) to
            their field values. Marching cubes runs on the *negated* output.
        path: if given and ``return_mesh`` is False, the mesh is written to
            ``f'{path}.ply'``.
        resolution: samples per axis over the whole cube; must be a multiple of 512.
        grid_boundary: ``[min, max]`` shared by all three axes.
        return_mesh: return the combined ``trimesh.Trimesh`` instead of exporting it.
        level: iso-value for marching cubes, applied to the negated field.

    Returns:
        The combined mesh, with vertices rescaled from ``grid_boundary`` to
        [-1, 1], when ``return_mesh`` is True; otherwise None. Also None when
        no crop contains a crossing of ``level``.
    """
    assert resolution % 512 == 0, 'resolution must be a multiple of 512'
    cropN = 512
    N = resolution // cropN

    lo, hi = grid_boundary[0], grid_boundary[1]
    xs = np.linspace(lo, hi, N + 1)
    ys = np.linspace(lo, hi, N + 1)
    zs = np.linspace(lo, hi, N + 1)

    def evaluate(points):
        # Query in fixed-size batches to bound peak memory; negate the field.
        return torch.cat([-sdf(pnts) for pnts in torch.split(points, 100000, dim=0)], dim=0)

    meshes = []
    for i in range(N):
        for j in range(N):
            for k in range(N):
                x_min, x_max = xs[i], xs[i + 1]
                y_min, y_max = ys[j], ys[j + 1]
                z_min, z_max = zs[k], zs[k + 1]

                x = np.linspace(x_min, x_max, cropN)
                y = np.linspace(y_min, y_max, cropN)
                z = np.linspace(z_min, z_max, cropN)

                xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
                points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)

                # Point pyramid, channel-first: (3, 512^3) down to (3, 64^3), coarsest first.
                points = points.reshape(cropN, cropN, cropN, 3).permute(3, 0, 1, 2)
                points_pyramid = [points]
                for _ in range(3):
                    points = avg_pool_3d(points[None])[0]
                    points_pyramid.append(points)
                points_pyramid = points_pyramid[::-1]

                # Evaluate the pyramid, refining only near the surface.
                mask = None
                threshold = 2 * (x_max - x_min) / cropN * 8
                for pid, pts in enumerate(points_pyramid):
                    coarse_N = pts.shape[-1]
                    pts = pts.reshape(3, -1).permute(1, 0).contiguous()

                    if mask is None:
                        pts_sdf = evaluate(pts)
                    else:
                        mask = mask.reshape(-1)
                        pts_to_eval = pts[mask]
                        if pts_to_eval.shape[0] > 0:
                            pts_sdf[mask] = evaluate(pts_to_eval.contiguous())

                    if pid < 3:
                        # Mark near-surface samples and carry both the mask and
                        # the values up to the next (2x finer) level.
                        mask = torch.abs(pts_sdf) < threshold
                        mask = mask.reshape(coarse_N, coarse_N, coarse_N)[None, None]
                        mask = upsample(mask.float()).bool()

                        pts_sdf = pts_sdf.reshape(coarse_N, coarse_N, coarse_N)[None, None]
                        pts_sdf = upsample(pts_sdf).reshape(-1)

                    threshold /= 2.

                volume = pts_sdf.detach().cpu().numpy()

                # Skip crops that lie entirely on one side of the iso-level.
                if not (np.min(volume) > level or np.max(volume) < level):
                    volume = volume.astype(np.float32)
                    verts, faces, normals, values = measure.marching_cubes(
                        volume=volume.reshape(cropN, cropN, cropN),
                        level=level,
                        spacing=(
                            (x_max - x_min) / (cropN - 1),
                            (y_max - y_min) / (cropN - 1),
                            (z_max - z_min) / (cropN - 1)))

                    verts = verts + np.array([x_min, y_min, z_min])
                    meshes.append(trimesh.Trimesh(verts, faces))

    if not meshes:
        # No crop crosses the iso-level: nothing to concatenate or export.
        return None

    combined = trimesh.util.concatenate(meshes)

    # Map vertices from [lo, hi] to [-1, 1]. The previous formula
    # (v / hi * 2 - 1) is only correct when lo == 0; for lo == 0 this is identical.
    combined.vertices = (combined.vertices - lo) / (hi - lo) * 2 - 1.0

    if return_mesh:
        return combined
    elif path is not None:
        combined.export(f'{path}.ply')
        



In [10]:

# Train one FactorFields SDF model per scene, report the test-set sign IoU and
# training time, and optionally export a marching-cubes mesh.
base_conf = OmegaConf.load('../configs/defaults.yaml')
second_conf = OmegaConf.load('../configs/sdf.yaml')
cfg = OmegaConf.merge(base_conf, second_conf)

dataset = dataset_dict[cfg.dataset.dataset_name]

is_save_mesh = False
SEED = 0  # re-applied per scene so each run is reproducible on its own

# np.savetxt / mesh.export below do not create missing parent directories.
os.makedirs('../logs/SDF', exist_ok=True)
torch.set_printoptions(precision=6)

scores = []
for mode in ['8M']:
    for scene in ['armadillo','statuette','dragon','lucy']:
        torch.manual_seed(SEED)

        cfg.dataset.datadir = f'../data/SDF/{scene}_{mode}.npy'
        train_dataset = dataset(cfg.dataset, split='train')
        test_dataset = dataset(cfg.dataset, split='test')

        batch_size = cfg.training.batch_size
        n_iter = cfg.training.n_iters

        model = FactorFields(cfg, device)

        print(cfg.dataset.datadir)
        sdf, coordiantes = train_dataset.sdf.to(device), train_dataset.coordiante.to(device)

        grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
        optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

        # NOTE(review): this decays a multiplier on the *loss*, not the learning
        # rate. Adam divides by the running RMS of the gradients, so a slowly
        # varying loss scale has almost no effect on the step size. If a 10x LR
        # decay over training is intended, scale each optimizer.param_groups[i]['lr']
        # instead. Left as-is so the reported numbers stay reproducible.
        loss_scale = 1.0
        lr_factor = 0.1 ** (1 / n_iter)
        pbar = tqdm(range(n_iter))
        start = time.time()
        for iteration in pbar:
            loss_scale *= lr_factor

            pixel_idx = torch.randint(0, len(train_dataset), (batch_size,))

            feats, coeffs = model.get_coding(coordiantes[pixel_idx])
            sdf_recon = model.linear_mat(feats)

            loss_dist = torch.mean((sdf_recon - sdf[pixel_idx])**2)
            loss = loss_dist * loss_scale

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iteration % 100 == 0:
                pbar.set_description(
                    f'Iteration {iteration:05d}:'
                    + f' loss_dist = {loss_dist.item():.8f}'
                )

        time_takes = time.time() - start
        mae, gIoU = cal_l1_iou(test_dataset)
        # Keep plain floats so np.mean(scores) does not have to coerce tensors.
        scores.append(gIoU.item())
        print(gIoU, time_takes)
        np.savetxt(f'../logs/SDF/{scene}.txt', np.array([gIoU.item(), time_takes]))

        if is_save_mesh:
            _reso = 1024
            sdf_res = eval_sdf([_reso]*3, train_dataset.DHW[0])
            vertices, triangles, normals, values = skimage.measure.marching_cubes(
                    sdf_res.numpy(), level=0.0
                )
            vertices = marchcude_to_world(vertices, [_reso]*3)
            # Reverse each triangle's vertex order, i.e. flip the winding.
            triangles = triangles[..., ::-1]
            mesh = trimesh.Trimesh(vertices, triangles)
            mesh.export(f'../logs/SDF//{scene}.ply')

print(np.mean(scores))

=====> total parameters:  5342274
/vlg-nfs/anpei/Code/NeuBasis/data/mesh//armadillo_8M.npy


Iteration 09900: loss_dist = 0.00000025: 100%|███████████████████████████████████████████████████| 10000/10000 [00:30<00:00, 325.66it/s]


tensor(0.986945) 30.70767855644226


100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:39<00:00, 1055.10it/s]


=====> total parameters:  5342274
/vlg-nfs/anpei/Code/NeuBasis/data/mesh//statuette_8M.npy


Iteration 09900: loss_dist = 0.00001105: 100%|███████████████████████████████████████████████████| 10000/10000 [00:31<00:00, 316.70it/s]


tensor(0.966919) 31.576430797576904


100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:37<00:00, 1075.21it/s]


=====> total parameters:  5342274
/vlg-nfs/anpei/Code/NeuBasis/data/mesh//dragon_8M.npy


Iteration 09900: loss_dist = 0.00000032: 100%|███████████████████████████████████████████████████| 10000/10000 [00:30<00:00, 329.56it/s]


tensor(0.980855) 30.343759059906006


100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:37<00:00, 1070.92it/s]


=====> total parameters:  5342274
/vlg-nfs/anpei/Code/NeuBasis/data/mesh//lucy_8M.npy


Iteration 09900: loss_dist = 0.00000029: 100%|███████████████████████████████████████████████████| 10000/10000 [00:30<00:00, 333.20it/s]


tensor(0.983279) 30.012317180633545


100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:37<00:00, 1073.82it/s]


0.9794994
