In [1]:
import os
import torch
from torch import nn
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes, knn_points, knn_gather
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.express as px

import sys
sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI')
from shapeaxi import utils
from shapeaxi.saxi_layers import TimeDistributed, MHA, Residual, FeedForward, MHA_KNN
from shapeaxi.saxi_nets import SaxiMHAEncoder, SaxiMHADecoder, SaxiMHAClassification
import pandas as pd

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook can still
# be executed on a machine without a usable CUDA device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
# Surface file used as the target shape for the experiments below.
target_fn = '/mnt/famli_netapp_shared/C1_ML_Analysis/src/diffusion-models/blender/studies/placenta/FAM-025-0499-5/brain/leftWhiteMatter.stl'

# Read the surface and rescale it in one step; ScaleSurf returns the scaled
# surface plus two extra values (presumably the bounding-box mean and the
# scale factor that were applied — confirm against shapeaxi.utils).
target, target_mean_bb, target_scale_factor = utils.ScaleSurf(utils.ReadSurf(target_fn))

# Convert the surface to tensors placed on `device`
# (presumably vertices, faces and edges, judging by the names).
target_v, target_f, target_e = utils.PolyDataToTensors(target, device=device)

# Wrap the single surface as a pytorch3d batch containing one mesh.
target_mesh = Meshes(faces=[target_f], verts=[target_v])

[0m[2m2024-06-17 09:22:37.932 (   0.968s) [         AE48280]    vtkExtractEdges.cxx:435   INFO| [0mExecuting edge extractor: points are renumbered[0m
[0m[2m2024-06-17 09:22:38.087 (   1.123s) [         AE48280]    vtkExtractEdges.cxx:551   INFO| [0mCreated 491520 edges[0m


In [4]:
def plot_pointcloud(mesh, title=""):
    """Sample 5000 points from the surface of ``mesh`` and show them in 3D.

    Args:
        mesh: Object accepted by ``sample_points_from_meshes``. The
            ``squeeze()`` below assumes a batch holding a single mesh —
            TODO confirm for other callers.
        title: Optional figure title; an empty string leaves the figure untitled.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    points = sample_points_from_meshes(mesh, 5000)
    # Detach from autograd and move to host memory before plotting.
    x, y, z = points.detach().cpu().squeeze().unbind(1)
    # Randomly sampled points have no meaningful order, so draw markers only
    # instead of Scatter3d's default line-connected trace.
    fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode="markers")])
    if title:
        fig.update_layout(title_text=title)
    fig.show()

In [5]:
# Record how many vertices and faces an icosphere has at levels 5, 4 and 3,
# i.e. ordered from the largest count to the smallest.
sample_levels, sample_levels_faces = [], []
for level in (5, 4, 3):
    ico_s = utils.IcoSphere(level)
    source_v, source_f = utils.PolyDataToTensors_v_f(ico_s)
    sample_levels += [len(source_v)]
    sample_levels_faces += [len(source_f)]

# Show the vertex counts as stored, reversed, and as stored again
# (the reversal is a copy, so the list itself is unchanged).
print(sample_levels)
print(list(reversed(sample_levels)))
print(sample_levels)

[10242, 2562, 642]
[642, 2562, 10242]
[10242, 2562, 642]


In [6]:
# Encode a point sample of the target surface with a default-configured encoder.
model = SaxiMHAEncoder()
model = model.to(device)

# Place everything on the notebook-wide `device` rather than calling .cuda(),
# so tensor placement is controlled from a single setting.
target_mesh = Meshes(verts=[target_v.to(device)], faces=[target_f.to(device)])

# Draw sample_levels[0] surface points together with their normals.
X, X_N = sample_points_from_meshes(target_mesh, sample_levels[0], return_normals=True)
# NOTE(review): dim=1 appends the normals after the points along the second
# axis (doubling its length) instead of pairing each point with its normal
# along the last axis. Left unchanged — confirm which layout SaxiMHAEncoder expects.
X = torch.cat([X, X_N], dim=1)
target_mesh_encoded = model(X.to(device))

In [7]:
# Stand-alone check of torch's multi-head attention: one query token attending
# over 64 key/value tokens, with 128-dimensional embeddings and a batch of 2.
multihead_attn_td = nn.MultiheadAttention(
    embed_dim=128,
    num_heads=64,
    dropout=0.1,
    bias=False,
    batch_first=True,
)

query = torch.rand((2, 1, 128))
key = torch.rand((2, 64, 128))

# The keys double as the values.
output, attn_output_weights = multihead_attn_td(query=query, key=key, value=key)
print(output.shape)

torch.Size([2, 1, 128])


In [8]:
# Inspect the shape of the attention weights returned by the
# nn.MultiheadAttention call that produced `attn_output_weights`.
attn_output_weights.shape

torch.Size([2, 1, 64])

In [9]:
# Smoke test: build an MHA_KNN layer and check the output shape for a random
# (2, 12, 128) input. The positional arguments are presumably
# (embed_dim, num_heads) — confirm against shapeaxi.saxi_layers.
MHA_KNN(128, 64)(torch.rand(2, 12, 128)).shape

torch.Size([2, 12, 128])

In [10]:
# TODO: unfinished experiment. The stated goal was to select the K closest
# points to the query; so far this cell only defines toy dimensions and
# performs no selection.
batch_size, V_n, Embed_dim = 2, 4, 5




In [11]:
# Build the classification network with an explicit configuration,
# one keyword per line so individual settings are easy to tweak.
mha_c = SaxiMHAClassification(
    input_dim=3,
    embed_dim=256,
    num_heads=256,
    output_dim=32,
    K=32,
    sample_levels=[40962, 10242, 2562, 642, 162],
    dropout=0.1,
    num_classes=4,
)
# Move the network to `device`; as the last expression of the cell, the
# returned module is also what the notebook displays.
mha_c.to(device)


In [None]:
# NOTE(review): this cell was never executed and invokes the model with no
# arguments; presumably the forward pass requires an input tensor — confirm
# the expected input against SaxiMHAClassification before running.
mha_c()