In [None]:
import os
import torch
from torch import nn
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes, knn_points, knn_gather
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.express as px

import sys
sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI')
from shapeaxi import utils
from shapeaxi.saxi_layers import TimeDistributed, MHA, Residual, FeedForward, UnpoolMHA, SmoothAttention, SmoothMHA
from shapeaxi.saxi_nets import SaxiMHAEncoder, SaxiMHADecoder
import pandas as pd

In [None]:
# Prefer the first CUDA GPU, but fall back to the CPU when no GPU is visible
# instead of failing on the first `device=` transfer.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [None]:
# Absolute path to the target surface (an STL file).
# NOTE(review): hard-coded mount point; this only resolves on the original
# machine -- consider a configurable data directory.
target_fn = '/mnt/famli_netapp_shared/C1_ML_Analysis/src/diffusion-models/blender/studies/placenta/FAM-025-0499-5/brain/leftWhiteMatter.stl'
# Read the surface with the ShapeAXI reader (return type is defined in
# shapeaxi.utils -- presumably a VTK polydata; confirm).
target = utils.ReadSurf(target_fn)
# Rescale the surface; judging by the names, the helper also returns the
# bounding-box mean and the scale factor it applied -- confirm in shapeaxi.utils.
target, target_mean_bb, target_scale_factor = utils.ScaleSurf(target)
# Convert the surface to tensors placed on `device` (vertices, faces and,
# going by the name, edges).
target_v, target_f, target_e = utils.PolyDataToTensors(target, device=device)
# Wrap vertices and faces as a PyTorch3D batch holding this single mesh.
target_mesh = Meshes(verts=[target_v], faces=[target_f])

In [None]:
def plot_pointcloud(mesh, title=""):
    """Show an interactive 3D scatter of points sampled from a mesh surface.

    Args:
        mesh: PyTorch3D ``Meshes`` object. 5000 points are sampled per mesh;
            points from every mesh in the batch are drawn together.
        title: Figure title (an empty string leaves the figure untitled).
    """
    points = sample_points_from_meshes(mesh, 5000)
    # Flatten the batch axis rather than squeeze(): squeeze() only yields an
    # (N, 3) array when the batch holds exactly one mesh.
    xyz = points.detach().cpu().reshape(-1, 3).numpy()
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=xyz[:, 0],
                y=xyz[:, 1],
                z=xyz[:, 2],
                # Markers only: the samples are unordered, so the default
                # line-connected mode would draw a meaningless scribble.
                mode="markers",
                marker=dict(size=2),
            )
        ]
    )
    # The title argument was previously accepted but never applied.
    fig.update_layout(title_text=title)
    fig.show()

In [None]:
# Vertex / face counts of icospheres at subdivision levels 5, 4 and 3
# (finest first). sample_levels[0] is used further down as the number of
# surface points to sample from the target mesh.
sample_levels = []
sample_levels_faces = []
for l in range(5, 2, -1):
    ico_s = utils.IcoSphere(l)
    source_v, source_f = utils.PolyDataToTensors_v_f(ico_s)
    sample_levels.append(len(source_v))
    sample_levels_faces.append(len(source_f))
# Finest-to-coarsest, then coarsest-to-finest (the slice does not mutate the
# list, so a repeated print of sample_levels was redundant and is dropped).
print(sample_levels)
print(sample_levels[::-1])

In [None]:
model = SaxiMHAEncoder()
model = model.to(device)
# Use the configured `device` rather than a bare .cuda() so this cell follows
# the same device the target tensors were created on.
target_mesh = Meshes(verts=[target_v.to(device)], faces=[target_f.to(device)])

# Sample as many surface points (with normals) as the finest icosphere level
# has vertices.
X, X_N = sample_points_from_meshes(target_mesh, sample_levels[0], return_normals=True)
# NOTE(review): dim=1 appends the normals after the points along the point
# axis, giving (1, 2*S, 3). If per-point [xyz, normal] features were intended
# this should be dim=-1 with a matching encoder input width -- confirm.
X = torch.cat([X, X_N], dim=1)
target_mesh_encoded, _ = model(X.to(device))

In [None]:

# NOTE(review): input_dim=256 presumably matches the encoder's output width;
# confirm against the SaxiMHAEncoder defaults.
decoder = SaxiMHADecoder(input_dim=256)
# Follow the configured `device` instead of a hard-coded .cuda().
decoder = decoder.to(device)
X_hat = decoder(target_mesh_encoded)

In [None]:
# A notebook cell only displays its last expression, so return both shapes
# together as one tuple: (padded batch shape, first mesh's vertex shape).
target_mesh.verts_padded().shape, target_mesh.verts_list()[0].shape

In [None]:

def saxi_point_triangle_distance(X, X_hat, K_triangle=1, ignore_first=False, randomize=False):
    """
    Sum of point-to-plane distances between query points and triangles built
    from their nearest reference points.

    For every point in ``X_hat`` the nearest points of ``X`` are found with
    ``knn_points``; three of them are gathered with ``knn_gather`` and treated
    as a triangle. The absolute distance from the query point to the plane
    through that triangle is accumulated. Note that this is the distance to
    the supporting (infinite) plane, not to the bounded triangle.

    Args:
        X: (B, N0, 3) tensor of reference points.
        X_hat: (B, N1, 3) tensor of query points.
        K_triangle: 1-based index of the neighbour triple to use. Without
            ``randomize`` the triple consists of the neighbours ranked
            ``3*(K_triangle-1)+1 .. 3*K_triangle`` (after the optional skip).
        ignore_first: skip the single nearest neighbour (presumably for the
            case where the query points are themselves part of ``X``).
        randomize: shuffle the ``3*K_triangle`` candidate neighbours (one
            permutation shared by all points) before picking the triple.

    Returns:
        Scalar tensor: the sum over all batches and query points of the
        absolute point-to-plane distance.

    Raises:
        ValueError: if ``K_triangle`` is smaller than 1.
    """
    if K_triangle < 1:
        raise ValueError("K_triangle must be >= 1")

    k_ignore = 1 if ignore_first else 0

    dists = knn_points(X_hat, X, K=(3 * K_triangle + k_ignore))

    # Drop the ignored neighbour *before* shuffling; otherwise the permutation
    # could move it back into the selected triple and defeat ignore_first.
    idx = dists.idx[:, :, k_ignore:]

    if randomize:
        # Build the permutation on the index tensor's device to avoid a
        # CPU/GPU mismatch when indexing.
        perm = torch.randperm(idx.shape[2], device=idx.device)
        idx = idx[:, :, perm]

    start_idx = 3 * (K_triangle - 1)
    # (B, N1, 3, 3): the three triangle vertices for every query point.
    x = knn_gather(X, idx[:, :, start_idx:start_idx + 3])

    # Unit normal of each triangle. The norm must be taken over the xyz axis
    # (dim=-1), not over the point axis; normalize() also clamps the
    # denominator so degenerate (collinear) triples give 0 instead of NaN.
    N = torch.cross(x[:, :, 1] - x[:, :, 0], x[:, :, 2] - x[:, :, 0], dim=-1)
    N = torch.nn.functional.normalize(N, dim=-1)

    # Vector from the first triangle vertex to the query point; its projection
    # onto the unit normal is the signed distance to the triangle's plane.
    X_v = X_hat - x[:, :, 0]

    return torch.sum(torch.abs(torch.einsum('ijk,ijk->ij', X_v, N)))

In [None]:
# Plane-distance sum of the decoder output X_hat against the target vertices
# (unsqueeze adds the batch axis). K_triangle=3 makes the function query
# 3*3 = 9 nearest target vertices per point; randomize shuffles their order.
saxi_point_triangle_distance(target_v.unsqueeze(0), X_hat, K_triangle=3, randomize=True)

In [None]:
# Same inputs with ignore_first=True, which adds one extra neighbour to the
# query (3*3 + 1 = 10 nearest target vertices per point).
saxi_point_triangle_distance(target_v.unsqueeze(0), X_hat, ignore_first=True, K_triangle=3, randomize=True)

In [None]:
size = np.array([2, 3, 4, 4])

# torch.range is deprecated (and end-inclusive). arange over prod(size) with
# the default float dtype yields the same values 0 .. prod(size)-1.
test_v = torch.arange(int(np.prod(size)), dtype=torch.get_default_dtype()).reshape(size.tolist())
test_v

In [None]:
# Merge the middle axes of test_v while keeping the first and last ones:
# (2, 3, 4, 4) -> (2, 12, 4).
test_v.flatten(start_dim=1, end_dim=-2)

In [None]:
# NOTE(review): this import sits mid-notebook; the other shapeaxi.saxi_layers
# names are imported in the first cell.
from shapeaxi.saxi_layers import SelfAttention

# Example shapes used to build the random test input below
BS = 2  # Batch size
V_n = 1000  # Number of points per sample (second axis of the random input)
K = 4  # Number of neighbors
embed_dim = 128  # Embedding dimension (last axis of the random input)

class AttentionPooling(nn.Module):
    """Down-sample a point set by keeping its highest-scoring points.

    Each point is grouped with its K nearest neighbours (in feature space,
    within the same sample), the groups are passed through ``SelfAttention``
    and the top ``pooling_factor`` fraction of points, ranked by the score of
    the first neighbour slot, is kept.

    NOTE(review): ``SelfAttention`` is defined in shapeaxi.saxi_layers; it is
    assumed here to return ``(attended_features, scores)`` -- confirm.
    """

    def __init__(self, pooling_factor=0.5, embed_dim=128, hidden_dim=64, K=4):
        """
        Args:
            pooling_factor: fraction of the points kept by ``forward``.
            embed_dim: feature width handed to ``SelfAttention``.
            hidden_dim: hidden width handed to ``SelfAttention``.
            K: number of nearest neighbours gathered per point.
        """
        super().__init__()
        self.K = K
        self.embed_dim = embed_dim
        self.pooling_factor = pooling_factor
        self.attn = SelfAttention(embed_dim, hidden_dim, dim=2)

    def forward(self, x):
        """Pool ``x`` and return ``(kept_features, kept_scores)``.

        ``x`` is presumably (batch, points, embed_dim) -- confirm with caller.
        """
        # Each point looks up its K closest points inside the same sample.
        neighbours = knn_points(x, x, K=self.K)
        grouped = knn_gather(x, neighbours.idx)

        # Attention over every neighbourhood; only the score attached to the
        # first neighbour slot is used for ranking.
        attended, scores = self.attn(grouped, grouped)
        scores = scores[:, :, 0, :]

        # Number of points to keep, truncated towards zero.
        keep = int(attended.shape[1] * self.pooling_factor)
        ranking = torch.argsort(scores, descending=True, dim=1)
        top = ranking[:, :keep]

        # knn_gather is reused as a batched index-select; squeeze(2) drops the
        # singleton neighbour axis it adds.
        pooled = knn_gather(attended, top).squeeze(2)
        pooled_scores = knn_gather(scores, top).squeeze(2)

        return pooled, pooled_scores

x = torch.rand(BS, V_n, embed_dim)
# Pass the example constants through so the module is built for the same
# embedding width and neighbour count as the input; otherwise editing
# embed_dim or K above would silently disagree with the module defaults.
AttentionPooling(pooling_factor=0.25, embed_dim=embed_dim, K=K)(x)[0].shape