## Mesh Extraction
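Run the cells below to extract a mesh from a trained NeRF model stored in a `.pth` file. Set the `pth_file` variable to the file of the model you want to mesh, or use the alternative cell that rebuilds the model and restores it from a training checkpoint.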

In [None]:
import mcubes
import trimesh
import torch
import numpy as np

from ml_helpers import load_checkpoint
from model import Nerf

In [None]:
device = 'cuda'
# Path to a whole pickled model object (architecture + weights), presumably
# written with torch.save(model).
# Alternative: pth_file = 'experiments/monkey_3_big_aug/monkey_3_big_aug.pth'
pth_file = 'nerf_models/monkey_biz.pth'
# map_location puts every stored tensor straight onto `device`, so loading does
# not depend on which device the model was saved from.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# trusted files. Recent PyTorch releases default to weights_only=True, which
# rejects whole pickled modules; pass weights_only=False there if needed.
model = torch.load(pth_file, map_location=device).to(device)
# Switch to inference mode (matters for dropout / batch-norm layers, if any),
# as the checkpoint-loading cell also does. The trailing ';' suppresses the
# notebook's echo of the module repr.
model.eval();


In [None]:
# OR: rebuild the network from its class and restore it from a training checkpoint.
device = 'cuda'

# Hyper-parameters, presumably mirroring the training run. Only `lr` and
# `gamma` feed the optimizer/scheduler objects that load_checkpoint receives.
nb_epochs = 5
lr = 1e-3
gamma = 0.5
tn = 1  # presumably the near ray bound; dataset-dependent
tf = 10  # presumably the far ray bound; dataset-dependent
nb_bins = 100

# Checkpoint file: experiments/<model_name>/ckpt100.pth
model_name = 'monkey_3_big_aug'
ckpt_name = '/ckpt100.pth'
checkpoint_path = f'experiments/{model_name}'

model = Nerf(hidden_dim=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# MultiStepLR multiplies the learning rate by `gamma` each time the scheduler's
# step count reaches a milestone, i.e. once at step 5 and once more at step 10.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[5, 10],
    gamma=gamma,
)

ckpt_file = load_checkpoint(checkpoint_path + ckpt_name, model, optimizer, scheduler)

In [None]:

# Restore the trained weights and the optimizer state from the checkpoint dict
# returned by load_checkpoint.
# NOTE(review): load_checkpoint already receives model/optimizer/scheduler and
# may load these states itself; if so, these two calls are redundant — confirm.
model.load_state_dict(ckpt_file['model_state_dict'])
optimizer.load_state_dict(ckpt_file['optimizer_state_dict'])
# Inference mode; the trailing ';' suppresses the notebook's echo of the module.
model.eval();


In [None]:
# Sample a regular N x N x N grid of query points covering [-scale, scale]^3.
N = 100
scale = 2.5

x = torch.linspace(-scale, scale, N)
y = torch.linspace(-scale, scale, N)
z = torch.linspace(-scale, scale, N)

# indexing='ij' (matrix indexing): axis 0 follows x, axis 1 y, axis 2 z.
# Stating it explicitly silences PyTorch's meshgrid warning and documents the
# layout that reshaping the densities back to (N, N, N) relies on.
x, y, z = torch.meshgrid(x, y, z, indexing='ij')

# Flatten into an (N**3, 3) tensor of [x, y, z] rows; z varies fastest.
xyz = torch.stack((x, y, z), dim=-1).reshape(-1, 3)

In [None]:
# Query the network's density at every grid point. The N**3 points are fed
# through the model in fixed-size chunks so device memory stays bounded even
# when N is raised; results are collected on the CPU.
chunk_size = 2 ** 16
density_chunks = []
with torch.no_grad():
    for start in range(0, xyz.shape[0], chunk_size):
        pts = xyz[start:start + chunk_size].to(device)
        # The second model input is presumably the viewing direction
        # (TODO confirm); zeros are passed since only the density is kept.
        _, sigma = model(pts, torch.zeros_like(pts))
        density_chunks.append(sigma.cpu())

# Back to the grid layout: axis 0 = x, axis 1 = y, axis 2 = z.
density = torch.cat(density_chunks).numpy().reshape(N, N, N)

In [None]:
# Iso-level for marching cubes: a multiple of the mean density over the grid.
# The multiplier is a hand-picked constant; tune it per scene (lower it if the
# surface comes out empty or holey, raise it if floaters appear).
iso_factor = 10
iso_level = iso_factor * np.mean(density)
vertices, triangles = mcubes.marching_cubes(density, iso_level)
if len(vertices) == 0:
    raise ValueError(
        f"marching cubes found no surface at iso-level {iso_level:.4g}; "
        "try lowering iso_factor"
    )

# marching_cubes returns vertices in voxel-index units (0 .. N-1 per axis).
# Map them back into the sampled [-scale, scale] cube so the mesh lives in the
# same coordinate frame as the NeRF. `vertices` itself is left in index units.
vertices_world = vertices * (2 * scale / (N - 1)) - scale
mesh = trimesh.Trimesh(vertices_world, triangles)
mesh.show()

In [None]:
# Build a point cloud from the marching-cubes vertices. These are in
# voxel-index units (0 .. N-1), so map them into the sampled [-scale, scale]
# cube first, giving the cloud the same coordinate frame as the NeRF.
point_cloud = trimesh.points.PointCloud(
    vertices=vertices * (2 * scale / (N - 1)) - scale
)


In [None]:
# Write the point cloud as PLY. Create the output directory first, since
# opening the file would otherwise fail when it does not exist yet.
# NOTE(review): the file name says "coral" while the models referenced in this
# notebook are "monkey" ones — rename the output as appropriate.
import os

ply_path = 'Clouds/coral_new.ply'
os.makedirs(os.path.dirname(ply_path), exist_ok=True)
point_cloud.export(ply_path)