In [None]:
import os
import sys

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import SimpleITK as sitk
import nrrd
import vtk

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
import pickle
import monai 
import glob 
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/')
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/dl')
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.loaders.mr_us_dataset as mr_us_dataset
import dl.nets.us_simulation_jit as us_simulation_jit
import dl.nets.us_simu as us_simu

import importlib

from dl.nets.layers import TimeDistributed


In [None]:
# Re-import the simulator module so edits to us_simu.py are picked up without restarting the kernel.
importlib.reload(us_simu)

mount_point = '/mnt/raid/C1_ML_Analysis'

# Blind-sweep volume sampler over a [64, 128, 128] simulation FOV grid, moved to the GPU.
vs = us_simu.VolumeSamplingBlindSweep(
    mount_point=mount_point,
    simulation_fov_grid_size=[64, 128, 128],
)
vs.cuda()

In [None]:

# Load the resampled label volume ("diffusor") and describe its physical extent.
# pynrrd's default index_order='F' returns the array as (x, y, z), so the array is permuted to
# (z, y, x) (the same layout SimpleITK's GetArrayFromImage produces). Every per-axis header vector
# (origin, spacing, size) must be flipped the same way, otherwise diffusor_end mixes axis orders.
diffusor_path = os.path.join(mount_point, 'simulated_data_export', 'placenta', 'FAM-025-0664-4_label11_resampled.nrrd')
diffusor_np, diffusor_head = nrrd.read(diffusor_path)
diffusor_t = torch.tensor(diffusor_np.astype(int)).permute(2, 1, 0)
print(diffusor_head)

diffusor_size = torch.tensor(diffusor_head['sizes']).flip(dims=[0])
# NOTE(review): np.diag assumes an axis-aligned 'space directions' matrix; signed diagonal entries
# are kept as-is (a negative entry would make end < origin on that axis) — confirm for these files.
diffusor_spacing = torch.tensor(np.diag(diffusor_head['space directions'])).flip(dims=[0])
diffusor_origin = torch.tensor(diffusor_head['space origin']).flip(dims=[0])
diffusor_end = diffusor_origin + diffusor_spacing * diffusor_size

# Invariant: the flipped header size must describe the permuted array.
assert tuple(diffusor_size.tolist()) == tuple(diffusor_t.shape), (diffusor_size, diffusor_t.shape)

print(diffusor_spacing)
print(diffusor_t.shape)
print(diffusor_origin)
print(diffusor_end)


In [None]:
# fig = px.imshow(diffusor_t.flip(dims=[1]).squeeze().cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [None]:
# diffusor_batch_t = diffusor_t.permute([2, 1, 0]).cuda().float().unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1, 1)
# print(diffusor_batch_t.shape)

# diffusor_origin_batch = diffusor_origin[None, :].repeat(3, 1) + torch.randn(3, 3) * 0.01
# diffusor_end_batch = diffusor_end[None, :].repeat(3, 1) + + torch.randn(3, 3) * 0.01
# print(diffusor_origin, diffusor_origin_batch)
# # print(diffusor_origin_batch.shape)

# diffusor_in_fov_t = vs.diffusor_in_fov(diffusor_batch_t, diffusor_origin_batch.cuda(), diffusor_end_batch.cuda())


In [None]:
# fig = px.imshow(diffusor_in_fov_t[0].squeeze().flip(dims=[1]).cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()
# fig = px.imshow(diffusor_in_fov_t[1].squeeze().flip(dims=[1]).cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [None]:
# Reload so edits to us_simulation_jit.py take effect, then build the label-11 merged-linear simulator.
importlib.reload(us_simulation_jit)
us_simulator_cut = us_simulation_jit.MergedLinearLabel11()

# Sampling grid, inverse grid and fan mask; computed once here and passed to every simulator call.
# np.pi / 4 is bit-identical to the previously hard-coded 0.7853981633974483
# (presumably an angle in radians — TODO confirm against init_grids' signature).
grid, inverse_grid, mask_fan = us_simulator_cut.init_grids(
    256, 256,
    128.0, -30.0, 20.0, 215.0,
    np.pi / 4,
)

# Apply the simulator independently along dim 2, in eval mode, on the GPU.
us_simulator_cut_td = TimeDistributed(us_simulator_cut, time_dim=2).eval().cuda()


In [None]:


# Simulate blind sweeps for a few sweep tags and resample each simulated sweep into the common FOV grid.
# To simulate every sweep, iterate over vs.tags instead of the explicit list below.
batch_size = 1
# NOTE(review): diffusor_t is (z, y, x) after loading and is permuted back here, but it is passed
# unpermuted to vs.diffusor_in_fov in later cells — confirm which layout each API expects.
diffusor_batch_t = diffusor_t.permute([2, 1, 0]).cuda().float().unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
print(diffusor_batch_t.shape)

# Jitter the diffusor bounds slightly (std 0.01) for each batch item.
diffusor_origin_batch = diffusor_origin[None, :].repeat(batch_size, 1) + torch.randn(batch_size, 3) * 0.01
diffusor_end_batch = diffusor_end[None, :].repeat(batch_size, 1) + torch.randn(batch_size, 3) * 0.01

# Loop-invariant device/dtype conversions, done once rather than once per tag / per sweep.
diffusor_origin_batch_cuda = diffusor_origin_batch.cuda().to(torch.float)
diffusor_end_batch_cuda = diffusor_end_batch.cuda().to(torch.float)
grid_cuda = grid.cuda()
inverse_grid_cuda = inverse_grid.cuda()
mask_fan_cuda = mask_fan.cuda()

# Set use_random = True to randomly perturb the probe origin and orientation for each tag.
use_random = False
rotation_ranges = ((-15, 15), (-15, 15), (-30, 30))  # ranges in degrees for x, y, and z rotations

out_fovs_list = []
for tag in ["M", "L0", "C1"]:
    probe_origin_rand = None
    probe_direction_rand = None
    if use_random:
        probe_origin_rand = (torch.rand(3) * 0.001).cuda()
        probe_direction_rand = vs.random_affine_matrix(rotation_ranges).cuda()

    sampled_sweep = vs.diffusor_sampling_tag(
        tag, diffusor_batch_t, diffusor_origin_batch_cuda, diffusor_end_batch_cuda,
        probe_origin_rand=probe_origin_rand, probe_direction_rand=probe_direction_rand, use_random=use_random)

    with torch.no_grad():
        # Simulate each element along dim 0 separately (presumably the batch), then re-stack.
        sampled_sweep_simu = torch.cat(
            [us_simulator_cut_td(ss.unsqueeze(dim=0), grid_cuda, inverse_grid_cuda, mask_fan_cuda) for ss in sampled_sweep],
            dim=0)

    out_fovs_list.append(vs.simulated_sweep_in_fov(tag, sampled_sweep_simu))

out_fovs = torch.cat(out_fovs_list, dim=0)
print(out_fovs.shape)

In [None]:
# Animate the first simulated sweep in the FOV grid: flip dim 0 of out_fovs[0], drop singleton
# dimensions, and step the animation along axis 1 of the result.
first_fov_volume = out_fovs[0].flip(dims=[0]).squeeze().cpu().numpy()
fig = px.imshow(first_fov_volume, animation_frame=1, binary_string=True)
fig.show()

In [None]:
# def fov_physical(self):
#     simulation_fov_mesh_grid_params = [torch.arange(end=s, device=self.simulation_fov_bounds.device) for s in self.simulation_fov_grid_size]
#     simulation_fov_mesh_grid_idx = torch.stack(torch.meshgrid(simulation_fov_mesh_grid_params), dim=-1).squeeze().to(torch.float32)

#     simulation_fov_origin = self.simulation_fov_bounds[[0,2,4]]
#     simulation_fov_end = self.simulation_fov_bounds[[1,3,5]]
#     simulation_fov_size = self.simulation_fov_grid_size

#     simulation_fov_spacing = (simulation_fov_end - simulation_fov_origin)/simulation_fov_size
#     return simulation_fov_origin + simulation_fov_mesh_grid_idx*simulation_fov_spacing


In [None]:
from torch.nn.utils.rnn import pad_sequence

# Turn every simulated sweep (already resampled into the FOV grid) into a point cloud: keep the
# voxels with positive intensity together with their physical coordinates.
fov_physical = vs.fov_physical()
# Flattened physical coordinate of every FOV voxel. This is kept separate from V, which below
# becomes the padded per-sweep point cloud.
V_grid = fov_physical.reshape(-1, 3).cuda()

V_ = []
VF_ = []
for sweep_in_fov in out_fovs:
    sweep_in_fov = sweep_in_fov.reshape(-1, 1)
    foreground = sweep_in_fov.squeeze(-1) > 0  # computed once, used for both coordinates and values
    V_.append(V_grid[foreground])
    VF_.append(sweep_in_fov[foreground])

# Sweeps keep different numbers of points. Padding gives V shape (len(out_fovs), max_points, 3)
# and VF shape (len(out_fovs), max_points, 1).
# NOTE: padding rows are zeros, i.e. points at the origin with intensity 0. The encoder and knn
# cells below receive V without per-sweep lengths, so they see those rows as real points.
V = pad_sequence(V_, batch_first=True, padding_value=0.0)
VF = pad_sequence(VF_, batch_first=True, padding_value=0.0)

print(V.shape, VF.shape)

In [None]:

SN = 0      # which sweep of the batch to plot
N = 50000   # maximum number of points to draw

# Random subset of point indices; the later overlay cells reuse it.
random_indices = torch.randperm(V.size(1))[:N]

sweep_points = V[SN, random_indices].cpu().numpy()
sweep_intensities = VF[SN, random_indices].cpu().numpy().squeeze()

# Coordinates are drawn in reversed order (x <- index 2, z <- index 0).
fig = go.Figure(data=[go.Scatter3d(
    x=sweep_points[:, 2],
    y=sweep_points[:, 1],
    z=sweep_points[:, 0],
    mode='markers',
    marker=dict(size=2, color=sweep_intensities, colorscale='jet', opacity=0.5),
)])
fig.show()


In [None]:

from shapeaxi.saxi_layers import AttentionPooling_V, MHA_KNN_V, Residual, FeedForward, KNN_Embeding_V


# mha_k = MHA_KNN_V(embed_dim=4, num_heads=4, return_weights=True, K=6, return_sorted=True, return_v=True, use_direction=True).cuda()
# knn_e = KNN_Embeding_V(input_dim=1, embed_dim=64, K=27, return_sorted=True).cuda()
# attn_p = AttentionPooling_V(embed_dim=4, pooling_factor=0.125, hidden_dim=64, K=27).cuda()


In [None]:
from shapeaxi.saxi_nets import MHAEncoder_V

# Freshly constructed encoder (no weights are loaded), run only to inspect the point sets it produces.
# The forward pass runs under no_grad because nothing here is trained. This avoids keeping the
# autograd graph alive and yields tensors later cells can convert with .numpy() directly.
mhda_encoder = MHAEncoder_V(input_dim=4).cuda()
with torch.no_grad():
    # Per-point features: 3 coordinates + 1 simulated intensity = input_dim 4.
    x_, x_vs = mhda_encoder(torch.cat([V, VF], dim=-1), V)
x_v, x_s = x_vs[-1]  # last entry of x_vs, unpacked as (points, scores) — TODO confirm meaning


In [None]:
# Points from the last entry of the encoder's x_vs for sweep SN, coloured by their score.
pooled_points = x_v[SN].detach().cpu().numpy()
pooled_scores = x_s[SN].detach().cpu().numpy().squeeze()

fig = go.Figure(data=[go.Scatter3d(
    x=pooled_points[:, 0],
    y=pooled_points[:, 1],
    z=pooled_points[:, 2],
    mode='markers',
    marker=dict(size=2, color=pooled_scores, colorscale='jet', opacity=0.5),
)])
fig.show()

In [None]:

# Resample the full label map into the simulation FOV and keep the physical coordinates of the
# label-7 voxels (the label's meaning is not visible here — TODO document which structure 7 is).
# NOTE(review): diffusor_t is passed here in its (z, y, x) layout, while the sweep-simulation cell
# permutes it with [2, 1, 0] first — confirm which layout vs.diffusor_in_fov expects.
diffusor_in_fov = vs.diffusor_in_fov(diffusor_t.cuda().unsqueeze(0).unsqueeze(0).float(), diffusor_origin=diffusor_origin.cuda().unsqueeze(0), diffusor_end=diffusor_end.cuda().unsqueeze(0))
diffusor_in_fov = diffusor_in_fov.reshape(-1, 1).long()

V_diff = fov_physical.reshape(-1, 3).cuda()

# Replicate the same label-7 point set for every sweep so it can be batched against V.
V_diff_filtered = V_diff[diffusor_in_fov.squeeze(-1) == 7]
V_diff_filtered = V_diff_filtered.repeat(V.shape[0], 1, 1)

SN = 2  # sweep to overlay; the label-7 points are identical for every SN
x = V_diff_filtered[SN, :, 0].cpu().numpy()
y = V_diff_filtered[SN, :, 1].cpu().numpy()
z = V_diff_filtered[SN, :, 2].cpu().numpy()

scatter1 = go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(
        size=2,
        color=z,
        colorscale='jet',
        opacity=1.0
    ))
scatter2 = go.Scatter3d(x=V[SN,random_indices,0].cpu().numpy(), y=V[SN,random_indices,1].cpu().numpy(), z=V[SN,random_indices,2].cpu().numpy(), mode='markers', marker=dict(
        size=2,
        color=VF[SN,random_indices].cpu().numpy().squeeze(),
        colorscale='jet',
        opacity=0.5
    ))

fig = go.Figure()
fig.add_trace(scatter1)
fig.add_trace(scatter2)

fig.update_layout(title='Two Scatter Plots')
# 3D plots take axis titles from the scene; layout.xaxis/yaxis only apply to 2D cartesian plots.
fig.update_scenes(xaxis_title_text='X Axis', yaxis_title_text='Y Axis')
fig.show()

In [None]:
from pytorch3d.ops import knn_points, knn_gather, ball_query

# Nearest label-7 point (K=1) for every sweep point, sweep by sweep.
# NOTE(review): no lengths are passed, so V's zero padding rows get a nearest point too.
dists = knn_points(V, V_diff_filtered, K=1)
# knn_gather returns (B, P, K, 3); with K=1, take index 0 of the K axis.
V_diff_filtered_ = knn_gather(V_diff_filtered, dists.idx)[:, :, 0]

print(V.shape)
print(V.amin(dim=1).shape)
ball_query(V, V_diff_filtered, K=1, radius=0.1).idx.shape

In [None]:
# Nearest label-7 points for the sampled indices of sweep SN, coloured by their third coordinate.
nearest_points = V_diff_filtered_[SN, random_indices].cpu().numpy()

fig = go.Figure(data=[go.Scatter3d(
    x=nearest_points[:, 0],
    y=nearest_points[:, 1],
    z=nearest_points[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color=nearest_points[:, 2].squeeze(),
        colorscale='jet',
        opacity=0.5,
    ),
)])
fig.show()

In [None]:
# Overlay sweep SN's point cloud (grayscale by intensity) with each sweep point's nearest label-7
# point (jet-coloured by its third coordinate). This cell only reads V, VF, random_indices and
# V_diff_filtered_ from earlier cells. The label-7 extraction is not recomputed here because none
# of its outputs are used below.
SN = 2

sweep_points = V[SN, random_indices].cpu().numpy()
nearest_points = V_diff_filtered_[SN, random_indices].cpu().numpy()

scatter1 = go.Scatter3d(x=sweep_points[:, 0], y=sweep_points[:, 1], z=sweep_points[:, 2], mode='markers', marker=dict(
        size=2,
        color=VF[SN, random_indices].cpu().numpy().squeeze(),
        colorscale='gray',
        opacity=0.5
    ))
scatter2 = go.Scatter3d(x=nearest_points[:, 0], y=nearest_points[:, 1], z=nearest_points[:, 2], mode='markers', marker=dict(
        size=2,
        color=nearest_points[:, 2].squeeze(),
        colorscale='jet',
        opacity=1.0
    ))

fig = go.Figure()
fig.add_trace(scatter1)
fig.add_trace(scatter2)

fig.update_layout(title='Two Scatter Plots')
# 3D plots take axis titles from the scene; layout.xaxis/yaxis only apply to 2D cartesian plots.
fig.update_scenes(xaxis_title_text='X Axis', yaxis_title_text='Y Axis')
fig.show()

In [None]:
def get_target(X, X_origin, X_end, sampler=None):
    """Resample a label volume into the simulation FOV and return its labelled voxels as points.

    Args:
        X: label volume batch, cast to float and passed to ``sampler.diffusor_in_fov``. In this
            notebook it is called with a (1, 1, D, H, W) tensor — TODO confirm the expected axis order.
        X_origin: per-item physical origin of the volume, forwarded as ``diffusor_origin``.
        X_end: per-item physical end of the volume, forwarded as ``diffusor_end``.
        sampler: object providing ``diffusor_in_fov`` and ``fov_physical``. Defaults to the
            notebook-global ``vs``, so existing calls are unchanged.

    Returns:
        (V_diff, F_diff):
            V_diff: (B, P, 3) physical coordinates of every FOV voxel whose label is > 0.
            F_diff: (B, P, 1) the corresponding labels.
        Items with fewer points are zero-padded (coordinate (0, 0, 0), label 0). Consumers such as
        knn must be told the per-item lengths if padding should be ignored.
    """
    if sampler is None:
        sampler = vs

    # Put the diffusor in the FOV.
    diffusor_in_fov = sampler.diffusor_in_fov(X.float(), diffusor_origin=X_origin, diffusor_end=X_end)

    # Flattened physical coordinate of every FOV voxel. It is moved to the labels' device so the
    # boolean mask below can index it even if fov_physical() returns a tensor on another device.
    V_fov = sampler.fov_physical().reshape(-1, 3).to(diffusor_in_fov.device)

    V_diff = []
    F_diff = []
    # Keep only non-background points and their labels, item by item.
    for d_fov in diffusor_in_fov:
        d_fov = d_fov.reshape(-1)
        foreground = d_fov > 0
        V_diff.append(V_fov[foreground])
        F_diff.append(d_fov[foreground])

    # Pad the ragged per-item lists into batch tensors.
    V_diff = pad_sequence(V_diff, batch_first=True, padding_value=0.0)
    F_diff = pad_sequence(F_diff, batch_first=True, padding_value=0.0).unsqueeze(-1)

    return V_diff, F_diff


# Ground-truth point cloud for the single, unjittered diffusor volume (batch of one).
V_diff, F_diff = get_target(
    diffusor_t.cuda()[None, None].float(),
    X_origin=diffusor_origin.cuda()[None],
    X_end=diffusor_end.cuda()[None],
)



In [None]:
# Target point cloud (all labels > 0, left panel) next to its label-7 subset (right panel).
N = 50000  # maximum number of target points drawn in the left panel

# This subsample uses its own name. Overwriting random_indices (drawn over V's points) would break
# the earlier sweep plots if they were re-run after this cell.
target_indices = torch.randperm(V_diff.size(1))[:N]
target_points = V_diff[0, target_indices].cpu().numpy()

scatter1 = go.Scatter3d(x=target_points[:, 0], y=target_points[:, 1], z=target_points[:, 2], mode='markers', marker=dict(
        size=2,
        color=F_diff[0, target_indices, 0].cpu().numpy(),
        colorscale='jet',
        opacity=0.5
    ))

# Label-7 subset; the mask is computed once and used for both coordinates and labels.
target_is_label7 = F_diff[0, :, 0] == 7
v_sk = V_diff[0, target_is_label7, :]
f_sk = F_diff[0, target_is_label7, 0]

scatter2 = go.Scatter3d(x=v_sk[:, 0].cpu().numpy(), y=v_sk[:, 1].cpu().numpy(), z=v_sk[:, 2].cpu().numpy(), mode='markers', marker=dict(
        size=2,
        color=f_sk.cpu().numpy(),
        colorscale='jet',
        opacity=0.5
    ))

fig = make_subplots(
    rows=1, cols=2,
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]]
)
fig.add_trace(scatter1, row=1, col=1)
fig.add_trace(scatter2, row=1, col=2)

fig.update_layout(title='Two Scatter Plots')
# Applies to both 3D subplots (scene and scene2); layout.xaxis/yaxis are ignored by 3D plots.
fig.update_scenes(xaxis_title_text='X Axis', yaxis_title_text='Y Axis')
fig.show()

In [None]:
# For each encoder output point of sweep 0, look up the label of its nearest target point.
dists = knn_points(x_v[0:1], V_diff, K=1)
# knn_gather returns (1, P, K=1, 1) here; take index 0 of the K axis to get (1, P, 1).
y_f = knn_gather(F_diff, dists.idx)[:, :, 0]

In [None]:
# Encoder output points of sweep 0, highlighted (1.0) where their nearest target label is 7.
# x_v may still carry autograd history from the encoder (the encoder-points plot detaches it for
# that reason), and Tensor.numpy() raises on such tensors, so detach before converting.
sweep0_points = x_v[0].detach().cpu().numpy()
label7_color = (y_f[0, :].cpu().numpy().squeeze() == 7).astype(np.double)

fig = go.Figure(data=[go.Scatter3d(
    x=sweep0_points[:, 0],
    y=sweep0_points[:, 1],
    z=sweep0_points[:, 2],
    mode='markers',
    marker=dict(size=2, color=label7_color, colorscale='jet', opacity=0.5),
)])
fig.show()