In [None]:
import os
import torch
from torch import nn
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes, knn_points, knn_gather
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.express as px

import sys
sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI')
from shapeaxi import utils
from shapeaxi.saxi_layers import TimeDistributed, MHA, Residual, FeedForward, UnpoolMHA, SmoothAttention, SmoothMHA
from shapeaxi.saxi_nets import SaxiMHAEncoder
import pandas as pd

In [None]:
# Compute device for this notebook: CUDA GPU with index 0 (same as "cuda:0").
device = torch.device("cuda", 0)


In [None]:
# Load the target surface, normalise it with utils.ScaleSurf (which also returns
# target_mean_bb / target_scale_factor — presumably the bounding-box mean and the
# scale applied; confirm in shapeaxi.utils), then wrap it as a one-mesh pytorch3d batch.
target_fn = '/mnt/famli_netapp_shared/C1_ML_Analysis/src/diffusion-models/blender/studies/placenta/FAM-025-0499-5/brain/leftWhiteMatter.stl'
target, target_mean_bb, target_scale_factor = utils.ScaleSurf(utils.ReadSurf(target_fn))
target_v, target_f, target_e = utils.PolyDataToTensors(target, device=device)
target_mesh = Meshes(verts=[target_v], faces=[target_f])

In [None]:
def plot_pointcloud(mesh, title="", num_samples=5000):
    """Show an interactive 3-D scatter plot of points sampled from a mesh surface.

    Args:
        mesh (Meshes): pytorch3d mesh batch to sample from. If the batch holds
            several meshes, their samples are overlaid in one plot.
        title (str): Figure title.
        num_samples (int): Number of surface points sampled per mesh (default 5000).

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    points = sample_points_from_meshes(mesh, num_samples)  # (B, num_samples, 3)
    # Flatten the batch and split into coordinate columns. reshape(-1, 3) (instead of
    # squeeze()) also works when num_samples == 1 or the batch has more than one mesh.
    x, y, z = points.detach().cpu().reshape(-1, 3).numpy().T
    # Scatter3d defaults to mode="lines+markers", which would join the randomly
    # ordered samples with line segments; a point cloud needs markers only.
    fig = go.Figure(
        data=[go.Scatter3d(x=x, y=y, z=z, mode="markers", marker=dict(size=2))]
    )
    fig.update_layout(title=title)
    fig.show()

In [None]:
# Record vertex and face counts of ico-spheres at subdivision levels 6, 5, 4, 3, 2
# (finest first). The loop variables are intentionally left as notebook globals.
sample_levels, sample_levels_faces = [], []
for l in reversed(range(2, 7)):
    ico_s = utils.IcoSphere(l)
    source_v, source_f = utils.PolyDataToTensors_v_f(ico_s)
    sample_levels += [len(source_v)]
    sample_levels_faces += [len(source_f)]

In [None]:
# Encode the target mesh with a freshly constructed SaxiMHAEncoder (no weights are
# loaded in this cell). Everything is placed on `device` rather than mixing in bare
# .cuda() calls, so the notebook's device choice lives in one place.
model = SaxiMHAEncoder().to(device)
target_mesh = Meshes(verts=[target_v.to(device)], faces=[target_f.to(device)])
# target_mesh is already on `device`; no further .cuda() move is needed.
target_mesh_encoded = model(target_mesh)

In [None]:


class SaxiMHADecoder(nn.Module):
    """Attention-based decoder mapping per-point features to ``output_dim`` values per point.

    Input features are first projected to ``embed_dim``. Then ``L`` stages run. Each stage:
      1. gathers, for every point, its ``K`` nearest points in the current feature space;
      2. applies a residual multi-head-attention block over that neighbourhood;
      3. applies a residual feed-forward block;
      4. applies the shared ``self.unpool`` (a single UnpoolMHA reused by every stage);
      5. applies the stage's ``smooth{i}`` (SmoothMHA) layer.
    A final linear layer projects to ``output_dim``.

    Sub-modules are registered via ``setattr`` under indexed names (``mha_td{i}``,
    ``residual_mha_td{i}``, ``ff{i}``, ``residual_ff{i}``, ``smooth{i}``). Those
    attribute names become the state_dict keys, so keep them (and their creation
    order) stable for checkpoint compatibility.

    NOTE(review): ``knn_points`` at the top of each stage needs a 3-D (B, N, D)
    tensor. This relies on ``unpool``/``smooth{i}`` returning that shape; confirm
    against shapeaxi.saxi_layers.

    Args:
        input_dim (int): Size of the last dim of the input features (fed to ``self.embedding``).
        L (int): Number of decoding stages.
        embed_dim (int): Working feature width inside the stages.
        output_dim (int): Size of the last dim of the output.
        num_heads (int): Number of attention heads passed to each MHA.
        K (int): Neighbours gathered per point; also passed to SmoothMHA.
        hidden_dim (int): Passed to FeedForward and as the second argument of SmoothMHA.
        dropout (float): Dropout rate passed to MHA and FeedForward.
        return_sorted (bool): Forwarded to ``knn_points``.
    """

    def __init__(self, input_dim=256, L=4, embed_dim=128, output_dim=3, num_heads=4,  K=4, hidden_dim=64, dropout=0.1, return_sorted=True):
        super().__init__()

        self.input_dim = input_dim
        self.L = L
        self.K = K
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.return_sorted = return_sorted

        self.embedding = nn.Linear(input_dim, embed_dim)

        # A single UnpoolMHA instance shared by all L stages.
        self.unpool = UnpoolMHA()

        # The creation order below fixes the parameter-init RNG sequence and the
        # state_dict layout; do not reorder.
        for i in range(self.L):
            mha = MHA(embed_dim, num_heads, dropout)
            mha_td = TimeDistributed(mha)
            setattr(self, f'mha_td{i}', mha_td)
            setattr(self, f'residual_mha_td{i}', Residual(mha_td, dimension=embed_dim))

            ff = FeedForward(embed_dim, hidden_dim, dropout)
            setattr(self, f'ff{i}', ff)
            setattr(self, f'residual_ff{i}', Residual(ff, dimension=embed_dim))

            setattr(self, f'smooth{i}', SmoothMHA(self.embed_dim, self.hidden_dim, K=self.K))

        self.output = nn.Linear(embed_dim, output_dim)

    def sample_points(self, x, Ns):
        """
        Randomly samples Ns points (with replacement) from each batch element of x.

        Args:
            x (torch.Tensor): Input tensor of shape (Bs, N, F).
            Ns (int): Number of points to sample from each batch element.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - the sampled points, shape (Bs, Ns, F);
                - the sampled indices, shape (Bs, Ns, 1), int64. Same layout as a
                  K=1 knn index, so it can be passed to knn_gather.
        """
        Bs, N, _ = x.shape

        # Uniform indices in [0, N). torch.randint samples with replacement, so the
        # same point may be drawn more than once. The trailing singleton dim makes
        # the index tensor look like a K=1 neighbour index for knn_gather.
        indices = torch.randint(low=0, high=N, size=(Bs, Ns), device=x.device).unsqueeze(-1)

        # knn_gather yields (Bs, Ns, 1, F); drop the K=1 axis.
        x = knn_gather(x, indices).squeeze(-2).contiguous()

        return x, indices

    def forward(self, x):
        """Decode per-point features.

        Args:
            x (torch.Tensor): Features whose last dim is ``input_dim``. Presumably
                (B, N, input_dim) from SaxiMHAEncoder — TODO confirm.

        Returns:
            torch.Tensor: The last stage's features projected by ``self.output``,
            so the last dim is ``output_dim``.
        """
        x = self.embedding(x)

        for i in range(self.L):

            # For every point, find its K nearest points in the current feature space.
            # This normally includes the point itself, since the query set equals the
            # reference set. Gather them: (B, N, E) -> (B, N, K, E).
            dists = knn_points(x, x, K=self.K, return_sorted=self.return_sorted)
            x = knn_gather(x, dists.idx)

            x = getattr(self, f'residual_mha_td{i}')(x)
            x = getattr(self, f'residual_ff{i}')(x)
            # Reduce x in the "time" (neighbour) dim.
            x = self.unpool(x)
            x = getattr(self, f'smooth{i}')(x)

        x = self.output(x)
        return x

# Decode the encoder output with a freshly constructed decoder. The decoder is placed
# on `device` (not a bare .cuda()), consistent with the encoder, so both models and
# their inputs live on the same explicitly chosen device.
decoder = SaxiMHADecoder().to(device)
X_hat = decoder(target_mesh_encoded)

In [None]:
# Padded vertex tensor shape: (num_meshes, max_verts, 3).
target_mesh.verts_padded().size()