In [None]:
import os
import torch
from torch import nn
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes, knn_points, knn_gather
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.express as px

import sys
sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI/src')
from shapeaxi import utils

from shapeaxi.saxi_nets import SaxiMHAIcoEncoder, SaxiMHAIcoDecoder
import pandas as pd

In [None]:
# Prefer the first CUDA GPU; fall back to CPU so the notebook can still run
# (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [None]:
# Target surface to fit. Set the TARGET_FN environment variable to point at another
# surface without editing the notebook; the default is the original hard-coded path.
target_fn = os.environ.get(
    "TARGET_FN",
    '/mnt/famli_netapp_shared/C1_ML_Analysis/src/diffusion-models/blender/studies/placenta/FAM-025-0499-5/brain/leftWhiteMatter.stl',
)
# Fail fast with a clear message: surface readers may not raise on a missing file,
# which would otherwise surface later as a confusing error in the scaling/tensor steps.
if not os.path.isfile(target_fn):
    raise FileNotFoundError(f"Target surface not found: {target_fn}")

target = utils.ReadSurf(target_fn)
# ScaleSurf returns the rescaled surface plus two values kept for later use;
# presumably the bounding-box mean and scale factor used for normalization — TODO confirm.
target, target_mean_bb, target_scale_factor = utils.ScaleSurf(target)
# Vertex/face tensors (and `target_e`, presumably edges) placed on `device`.
target_v, target_f, target_e = utils.PolyDataToTensors(target, device=device)
target_mesh = Meshes(verts=[target_v], faces=[target_f])

In [None]:
def plot_pointcloud(mesh, title="", num_samples=5000, marker_size=2):
    """Show an interactive 3D scatter plot of points sampled from a mesh surface.

    Args:
        mesh: PyTorch3D ``Meshes`` to sample from. Only the first mesh in the
            batch is plotted.
        title: Figure title (previously accepted but ignored).
        num_samples: Number of surface points to sample (default 5000, as before).
        marker_size: Size of each plotted point.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    # sample_points_from_meshes returns a (batch, num_samples, 3) tensor.
    points = sample_points_from_meshes(mesh, num_samples)
    # Index the first mesh explicitly: squeeze() breaks for batch > 1 or num_samples == 1.
    x, y, z = points[0].detach().cpu().numpy().T
    # mode="markers": Scatter3d otherwise draws lines between consecutive random samples.
    fig = go.Figure(
        data=[go.Scatter3d(x=x, y=y, z=z, mode="markers", marker=dict(size=marker_size))]
    )
    fig.update_layout(title=title)
    fig.show()

In [None]:
# Reload `utils` so on-disk edits to ShapeAXI are picked up without restarting the
# kernel, then cross-check the two neighbor-lookup helpers on one randomly chosen
# vertex of the icosphere built by `utils.IcoSphere(5)`.
from importlib import reload
reload(utils)

ico_s = utils.IcoSphere(5)
source_v, source_f = utils.PolyDataToTensors_v_f(ico_s)

# Draw one vertex index uniformly from [0, number of vertices).
n_source_verts = len(source_v)
r_idx = int(torch.randint(0, n_source_verts, (1,))[0])

# First the helper that takes the icosphere object, then the one that takes the face tensor.
for neighbor_fn, mesh_repr in ((utils.GetNeighbors, ico_s), (utils.GetNeighborsT, source_f)):
    print(neighbor_fn(mesh_repr, r_idx))


In [None]:
# Smoke test for the encoder: push one random per-vertex feature tensor through it.
# (sample_points_from_meshes / knn_points / knn_gather are already imported in the first cell.)
sample_levels = 5
model = SaxiMHAIcoEncoder(input_dim=4, sample_levels=sample_levels)
model.to(device)

# Vertex count of a PyTorch3D icosphere at level `sample_levels`;
# assumes the encoder expects its input on this vertex set — TODO confirm.
N = ico_sphere(sample_levels).verts_packed().shape[0]

# One sample, N vertices, 4 features per vertex (matching input_dim=4),
# allocated directly on `device` rather than created on CPU and copied.
x = torch.rand(1, N, 4, device=device)

# Inference only: no_grad stops autograd from recording the forward graph,
# which the global `y` would otherwise keep alive in device memory.
with torch.no_grad():
    y = model(x)
print(y)



In [None]:


# Smoke test for the decoder: reuses `sample_levels` from the encoder cell above.
model = SaxiMHAIcoDecoder(input_dim=4, output_dim=4, sample_levels=sample_levels)
model.to(device)

# Vertex count of the base (level-0) PyTorch3D icosphere; assumes the decoder
# takes its input at this coarsest resolution — TODO confirm.
N = ico_sphere(0).verts_packed().shape[0]

# One sample, N vertices, 4 features per vertex (matching input_dim=4),
# allocated directly on `device`.
x = torch.rand(1, N, 4, device=device)

# Inference only: avoid retaining the autograd graph through the global `y`.
with torch.no_grad():
    y = model(x)
print(y.shape)