In [None]:
import os
import sys

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import SimpleITK as sitk
import nrrd
import vtk

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
import pickle
import monai 
import glob 
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/')
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/dl')
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.loaders.mr_us_dataset as mr_us_dataset
import dl.nets.us_simulation_jit as us_simulation_jit
import dl.nets.us_simu as us_simu

import importlib

from dl.nets.layers import TimeDistributed


import ocnn
from ocnn.octree import Octree, Points

In [None]:
# Root of the data mount; handed to the sweep sampler below.
mount_point = '/mnt/raid/C1_ML_Analysis'

# Pick up edits made to us_simu since the kernel first imported it.
importlib.reload(us_simu)

vs = us_simu.VolumeSamplingBlindSweep(
    mount_point=mount_point,
    simulation_fov_grid_size=[128, 256, 256],
)
vs.cuda()

In [None]:

# diffusor = sitk.ReadImage('/mnt/famli_netapp_shared/C1_ML_Analysis/src/blender/simulated_data_export/studies_merged/FAM-025-0447-5.nrrd')
# diffusor_np = sitk.GetArrayFromImage(diffusor)
# diffusor_t = torch.tensor(diffusor_np.astype(int))

# diffusor_spacing = torch.tensor(diffusor.GetSpacing()).flip(dims=[0])
# diffusor_size = torch.tensor(diffusor.GetSize()).flip(dims=[0])

# diffusor_origin = torch.tensor(diffusor.GetOrigin()).flip(dims=[0])
# diffusor_end = diffusor_origin + diffusor_spacing * diffusor_size
# print(diffusor_size)
# print(diffusor_spacing)
# print(diffusor_t.shape)
# print(diffusor_origin)
# print(diffusor_end)

# Read the label volume and its header with pynrrd, then reverse the array
# axes so the tensor layout is the reverse of the order stored in the file.
diffusor_np, diffusor_head = nrrd.read('/mnt/raid//C1_ML_Analysis/simulated_data_export/placenta/FAM-025-0664-4_label11_resampled.nrrd')
diffusor_t = torch.tensor(diffusor_np.astype(int)).permute(2, 1, 0)
print(diffusor_head)

# The header lists sizes, spacing and origin in file-axis order. The origin is
# flipped below, so size and spacing are flipped the same way; otherwise the
# per-axis products in `diffusor_end` would be added to the wrong origin
# component for any volume that is not cubic and isotropic.
# NOTE(review): taking the diagonal assumes axis-aligned 'space directions'.
diffusor_size = torch.tensor(diffusor_head['sizes']).flip(dims=[0])
diffusor_spacing = torch.tensor(np.diag(diffusor_head['space directions'])).flip(dims=[0])

diffusor_origin = torch.tensor(diffusor_head['space origin']).flip(dims=[0])
# Far corner of the volume: origin + extent, all in the same (flipped) axis order.
diffusor_end = diffusor_origin + diffusor_spacing * diffusor_size
print(diffusor_spacing)
print(diffusor_t.shape)
print(diffusor_origin)
print(diffusor_end)


In [None]:
# fig = px.imshow(diffusor_t.flip(dims=[1]).squeeze().cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [None]:
# diffusor_batch_t = diffusor_t.permute([2, 1, 0]).cuda().float().unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1, 1)
# print(diffusor_batch_t.shape)

# diffusor_origin_batch = diffusor_origin[None, :].repeat(3, 1) + torch.randn(3, 3) * 0.01
# diffusor_end_batch = diffusor_end[None, :].repeat(3, 1) + + torch.randn(3, 3) * 0.01
# print(diffusor_origin, diffusor_origin_batch)
# # print(diffusor_origin_batch.shape)

# diffusor_in_fov_t = vs.diffusor_in_fov(diffusor_batch_t, diffusor_origin_batch.cuda(), diffusor_end_batch.cuda())


In [None]:
# fig = px.imshow(diffusor_in_fov_t[0].squeeze().flip(dims=[1]).cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()
# fig = px.imshow(diffusor_in_fov_t[1].squeeze().flip(dims=[1]).cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [None]:
# Pick up edits made to us_simulation_jit since the kernel first imported it.
importlib.reload(us_simulation_jit)

us_simulator_cut = us_simulation_jit.MergedCutLabel11()

# Positional arguments are passed through unchanged; the last one is pi / 4.
grid, inverse_grid, mask_fan = us_simulator_cut.init_grids(
    256, 256, 128.0, -30.0, 20.0, 215.0, np.pi / 4
)

# Wrap the simulator so it is applied along dimension 2, then switch to
# eval mode and move to the GPU.
us_simulator_cut_td = TimeDistributed(us_simulator_cut, time_dim=2)
us_simulator_cut_td = us_simulator_cut_td.eval().cuda()


In [None]:


# for tag in vs.tags:

batch_size = 1

# Reverse the axes again, add batch and channel dimensions, and replicate
# the volume once per batch element.
diffusor_batch_t = diffusor_t.permute([2, 1, 0]).cuda().float()[None, None].repeat(batch_size, 1, 1, 1, 1)
print(diffusor_batch_t.shape)

# Jitter the volume corners slightly, independently for each batch element.
diffusor_origin_batch = diffusor_origin[None, :].repeat(batch_size, 1) + torch.randn(batch_size, 3) * 0.01
diffusor_end_batch = diffusor_end[None, :].repeat(batch_size, 1) + torch.randn(batch_size, 3) * 0.01

# Probe perturbation is disabled; the random origin/direction stay None
# unless `use_random` is switched on.
use_random = False
probe_origin_rand = None
probe_direction_rand = None

out_fovs_list = []
for tag in ["M", "L0", "C1"]:

    if use_random:
        probe_origin_rand = (torch.rand(3) * 0.001).cuda()
        rotation_ranges = ((-15, 15), (-15, 15), (-30, 30))  # ranges in degrees for x, y, and z rotations
        probe_direction_rand = vs.random_affine_matrix(rotation_ranges).cuda()

    sampled_sweep = vs.diffusor_sampling_tag(
        tag,
        diffusor_batch_t,
        diffusor_origin_batch.cuda().float(),
        diffusor_end_batch.cuda().float(),
        probe_origin_rand=probe_origin_rand,
        probe_direction_rand=probe_direction_rand,
        use_random=use_random,
    )

    # Simulate each sampled sweep separately without tracking gradients,
    # then stack the results along the batch dimension.
    with torch.no_grad():
        simulated = []
        for ss in sampled_sweep:
            simulated.append(us_simulator_cut_td(ss[None], grid.cuda(), inverse_grid.cuda(), mask_fan.cuda()))
        sampled_sweep_simu = torch.cat(simulated, dim=0)

    out_fovs_list.append(vs.simulated_sweep_in_fov(tag, sampled_sweep_simu))

# One entry per (tag, batch element) along dimension 0.
out_fovs = torch.cat(out_fovs_list, dim=0)
print(out_fovs.shape)
# fig = px.imshow(sampled_sweep[0].squeeze().cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()
# fig = px.imshow(sampled_sweep_simu[0].squeeze().cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [None]:
# Animate the first simulated field of view along axis 1.
first_fov_np = out_fovs[0].flip(dims=[0]).squeeze().cpu().numpy()
fig = px.imshow(first_fov_np, animation_frame=1, binary_string=True)
# fig = px.imshow(out_fovs[2].flip(dims=[1]).squeeze().detach().cpu().numpy(), animation_frame=0, binary_string=True)
# fig = px.imshow(out_fovs[2].squeeze().detach().cpu().numpy(), animation_frame=0, binary_string=True)
fig.show()

In [None]:
# Coordinates of every FOV sample, flattened to one (x, y, z) row per sample.
fov_physical = vs.transform_fov_norm(vs.fov_physical())

# repeats = [1,]*len(out_fovs.shape)
# repeats[0] = out_fovs.shape[0]

# fov_physical = fov_physical.repeat(repeats)

V = fov_physical.reshape(-1, 3).cuda()

octrees = []
points = []

for sweep_in_fov in out_fovs:
    # One intensity per coordinate row of V.
    sweep_in_fov = sweep_in_fov.reshape(-1, 1)

    # Keep only samples brighter than 0.1 so the octree stays sparse.
    keep = sweep_in_fov.squeeze() > 0.1
    V_filtered = V[keep]
    sweep_in_fov_filtered = sweep_in_fov[keep]

    p = Points(V_filtered, features=sweep_in_fov_filtered)
    # p = Points(V, features=sweep_in_fov)

    octree = Octree(7)
    octree.build_octree(p)

    octrees.append(octree)
    points.append(p)


# Merge the per-sweep point clouds and octrees into single batched objects.
points = ocnn.octree.merge_points(points)
octree = ocnn.octree.merge_octrees(octrees)
# NOTE: remember to construct the neighbor indices
octree.construct_all_neigh()

# octree_1 = Octree(16)
# points_1 = Points(V, features=)
# octree_1.build_octree(points_1)
# out_fovs.shape

In [None]:

# Node coordinates and batch ids at depth 7 (the depth passed to Octree above).
x, y, z, b = octree.xyzb(7)

# Select the nodes whose batch id is 2. squeeze(1) keeps the index 1-D even
# when a single node matches, so x.size(0) below stays valid.
aw = torch.argwhere(b == 2).squeeze(1)

x = x[aw]
y = y[aw]
z = z[aw]
feat = octree.get_input_feature('F')[aw]

# Plot at most N randomly chosen nodes to keep the figure responsive.
N = 100000
random_indices = torch.randperm(x.size(0))[:N]

# random_indices address the batch-filtered tensors (x, y, z, feat); gather
# the sampled features once and use them for both the colours and the stats.
feat_sample = feat[random_indices]

fig = go.Figure(data=[go.Scatter3d(x=z[random_indices].cpu().numpy(), y=y[random_indices].cpu().numpy(), z=x[random_indices].cpu().numpy(), mode='markers', marker=dict(
        size=2,
        color=feat_sample.cpu().numpy().squeeze(),                # set color to an array/list of desired values
        colorscale='jet',   # choose a colorscale
        opacity=0.5
    ))])
fig.show()

# Report the range and shape of the features that were actually plotted.
# (Previously these indexed the unfiltered feature tensor, i.e. other nodes.)
print(torch.min(feat_sample), torch.max(feat_sample))
print(feat_sample.shape)

In [None]:
# diffusor_plane_t = diffusor_plane_t.squeeze().unsqueeze(1)

# diffusor_plane_t = vs.diffusor_sampling_tag('C1', diffusor_t.unsqueeze(0).unsqueeze(0).cuda().to(torch.float), diffusor_origin.cuda().to(torch.float), diffusor_end.cuda().to(torch.float))

# # print(diffusor_plane_t.shape)
# repeats = [1,]*4
# repeats[0] = diffusor_plane_t.shape[0]

# with torch.no_grad():
#     x = us_simulator_cut_td(diffusor_plane_t, grid.repeat(repeats).cuda(), inverse_grid.repeat(repeats).cuda(), mask_fan.repeat(repeats).cuda())

# fig = px.imshow(x.squeeze().cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

# torch.cuda.empty_cache()
# x, y, z, b = octree.xyzb(7)
# print(x.shape)
# print(octree.get_input_feature('F').shape)


In [None]:
# ocnn UNet with 4 input channels and 3 output channels, moved to the GPU.
unet = ocnn.models.UNet(in_channels=4, out_channels=3)
unet = unet.cuda()

In [None]:
# 'PF' input features of the merged octree, cast to float32.
data = octree.get_input_feature('PF').float()

# Query locations: point coordinates with the batch id appended as a last column.
query_pts = torch.cat((points.points, points.batch_id), dim=1)

logit = unet(data, octree, octree.depth, query_pts)

In [None]:
# Display the shape of the UNet output computed in the previous cell.
logit.shape

In [None]:
from ocnn.utils import meshgrid, scatter_add, resize_with_last_val, list2str
from typing import List

class OctreePoolBase(torch.nn.Module):
  r''' Common state for octree pooling modules.

  Stores the normalised kernel size, its string form (as produced by
  ``list2str``), the stride, and the ``nempty`` flag for subclasses to use
  in ``forward``.
  '''

  def __init__(self, kernel_size: List[int], stride: int, nempty: bool = False):
    super().__init__()
    normalised = resize_with_last_val(kernel_size)
    self.kernel_size = normalised
    self.kernel = list2str(normalised)
    self.stride = stride
    self.nempty = nempty

  def extra_repr(self) -> str:
    r''' Summarises the pooling configuration for the module repr. '''
    return (f'kernel_size={self.kernel_size}, stride={self.stride}, '
            f'nempty={self.nempty}')


def octree_attn_pool(data: torch.Tensor, octree: Octree, depth: int,
                    kernel: str, stride: int = 2, nempty: bool = False):
  r''' Pools octree features over each output node's neighbourhood.

  As currently written, every valid neighbour of an output node receives the
  same weight ``1 / (number of valid neighbours)``, so the result is a
  neighbourhood mean; no learned attention weights are involved yet.

  Args:
    data (torch.Tensor): The input features, one row per input node.
    octree (Octree): The corresponding octree; its ``get_neigh`` supplies the
        neighbour indices.
    depth (int): The depth of current octree.
    kernel (str): The kernel size, like '333', '222'.
    stride (int): The stride of the pooling.
    nempty (bool): If True, :attr:`data` contains only features of non-empty
        octree nodes.

  Returns:
    torch.Tensor: One row per row of the neighbour table, with the same
    number of columns as :attr:`data`.
  '''

  neigh = octree.get_neigh(depth, kernel, stride, nempty)

  num_in = data.shape[0]
  num_out, num_neigh = neigh.shape[0], neigh.shape[1]

  # Negative entries in the neighbour table are treated as missing.
  valid = neigh >= 0
  counts = torch.sum(valid, dim=1)
  valid_flat = valid.view(-1)

  # One uniform weight per (output node, valid neighbour) pair.
  weight = 1.0 / (counts + 1e-8)
  weight = weight.unsqueeze(1).repeat(1, num_neigh).reshape(-1)[valid_flat]

  # Sparse coordinates: row = output node, column = input node it reads from.
  out_idx = torch.arange(num_out, device=neigh.device)
  out_idx = out_idx.unsqueeze(1).repeat(1, num_neigh).view(-1)[valid_flat]
  in_idx = neigh.view(-1)[valid_flat]
  indices = torch.stack([out_idx, in_idx], dim=0).long()

  pool_mat = torch.sparse_coo_tensor(
      indices, weight, [num_out, num_in], device=data.device)
  return torch.sparse.mm(pool_mat, data)

class OctreeAttnPool(OctreePoolBase):
  r''' Module wrapper around :func:`octree_attn_pool`.

  The kernel string, stride and ``nempty`` flag stored by the base class are
  forwarded on every call.
  '''

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Pools :attr:`data` on :attr:`octree` at the given :attr:`depth`. '''

    pooled = octree_attn_pool(
        data, octree, depth,
        kernel=self.kernel, stride=self.stride, nempty=self.nempty)
    return pooled

# Smoke test: pool the input features once and compare shapes before/after.
attn_pool = OctreeAttnPool(kernel_size=[3,3,3], stride=2).cuda()

pooled = attn_pool(data, octree, octree.depth)
print(pooled.shape)
print(data.shape)

In [None]:
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/ShapeAXI')
from shapeaxi.saxi_layers import Residual, FeedForward

class OctreeMHA(nn.Module):
    """Multi-head attention of each octree node over its neighbourhood.

    Every node is the single query; the keys are the features of the nodes
    listed in its row of the octree neighbour table. Missing neighbours
    (negative indices) are excluded from the attention.

    Args:
        embed_dim: Feature size of the nodes; passed to nn.MultiheadAttention.
        num_heads: Number of attention heads.
        dropout: Dropout applied inside the attention layer.
        return_weights: If True, forward also returns the attention weights.
        use_direction: If True, the values are the offsets (neighbour - node)
            and the attended offset is added to the input; otherwise the
            values are the neighbour features and the attended value replaces
            the input.
        kernel_size: Neighbourhood size, converted to a kernel string with
            resize_with_last_val / list2str.
        nempty: Forwarded to octree.get_neigh.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, return_weights=False, use_direction=True, kernel_size=[3,3,3], nempty=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.return_weights = return_weights
        self.kernel = list2str(resize_with_last_val(kernel_size))
        self.nempty = nempty

        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=False, batch_first=True)
        self.use_direction = use_direction

    def forward(self, x, octree, depth):
        """Attend over the neighbourhood of every node.

        Args:
            x: Node features with one row per node; indexed by the neighbour
                table, so presumably shaped (num_nodes, embed_dim) — TODO confirm.
            octree: Object providing get_neigh(depth, kernel=, stride=, nempty=).
            depth: Octree depth whose neighbour table is used.

        Returns:
            The updated features, or (features, attention weights) when
            return_weights is True.
        """
        neigh = octree.get_neigh(depth, kernel=self.kernel, stride=1, nempty=self.nempty)

        # Negative entries mark neighbours that do not exist. Indexing x with
        # them directly would wrap around and silently gather the last node,
        # so gather from a clamped index and mask those slots out below.
        missing = neigh < 0
        safe_neigh = neigh.clamp(min=0).long()

        # Each node is its own (single) query.
        q = x.unsqueeze(-2)

        # Keys: features of the neighbours of each node.
        k = x[safe_neigh]

        # With use_direction the values are the offsets towards the
        # neighbours, so the attention output is a weighted displacement of
        # the node's embedding; otherwise the values are the neighbours.
        if self.use_direction:
            v = k - q
        else:
            v = k

        # key_padding_mask=True removes a slot from the softmax.
        # NOTE(review): a row whose neighbours are all missing would yield
        # NaNs; presumably the neighbourhood always contains the node itself
        # — confirm for kernels other than '333'.
        v, x_w = self.attention(q, k, v, key_padding_mask=missing)
        v = v.squeeze(-2)

        # The new embedding is the input plus the attended offset, or the
        # attended value itself when directions are not used.
        if self.use_direction:
            x = x + v
        else:
            x = v

        if self.return_weights:
            return x, x_w
        return x

# Flag used by the (currently disabled) OctreeMHA smoke test below.
nempty = True
# print(OctreeMHA(4, 4, use_direction=False, nempty=nempty).cuda()(octree.get_input_feature('PF', nempty=nempty).to(torch.float), octree, octree.depth).shape)

In [None]:
class OctreeMHAEncoder(nn.Module):
    """Work-in-progress encoder stacking OctreeMHA, feed-forward and pooling stages.

    __init__ registers, for every entry of `stages`, the sub-modules
    `octree_mha{i}`, `ff_{i}` and `pool_{i}`, plus an input embedding and an
    output projection.

    NOTE(review): forward() does not match __init__ and cannot run as written:
      * it iterates `self.sample_levels` and calls `self.sample_points`,
        neither of which is set in __init__;
      * it looks up `mha_{i}`, while __init__ registers `octree_mha{i}`;
      * it calls `knn_gather`, which is not defined or imported in the
        visible source;
      * it calls the attention stage with `x` only, whereas OctreeMHA.forward
        takes (x, octree, depth);
      * the `pool_{i}` modules are never used.
    """
    def __init__(self, input_dim=3, embed_dim=8, hidden_dim=64, kernel_size=[3,3,3], num_heads=8, output_dim=256, stages=[8, 16, 32, 64, 128], dropout=0.1):
        """Build the embedding, the per-stage modules and the output layer.

        NOTE(review): the `kernel_size` argument is not used (the stages
        hard-code [3,3,3]) and the values in `stages` are not used either;
        only the number of entries determines how many stages are created.
        """
        super(OctreeMHAEncoder, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        self.embedding = nn.Linear(input_dim, embed_dim)
        embed_dim = self.embed_dim
        # Every stage uses the same embed_dim; `st` is currently unused.
        for i, st in enumerate(stages):
            setattr(self, f"octree_mha{i}", OctreeMHA(embed_dim=embed_dim, num_heads=num_heads, return_weights=True, dropout=dropout, use_direction=True, kernel_size=[3,3,3], nempty=False))
            setattr(self, f"ff_{i}", Residual(FeedForward(embed_dim, hidden_dim=hidden_dim, dropout=dropout)))
            setattr(self, f"pool_{i}", OctreeAttnPool(kernel_size=[3,3,3], stride=2))

        
        self.output = nn.Linear(embed_dim, output_dim)
        
    def forward(self, x):
        """Embed x, run the stages and project to output_dim.

        Returns a tuple (x, weights), where `weights` accumulates the stage
        attention weights via scatter_add_.

        NOTE(review): this body relies on names that the visible source never
        defines (see the class docstring) and indexes x as if it had a
        leading batch dimension — TODO confirm the intended input layout.
        """
        
        x = self.embedding(x)

        # Accumulator for attention weights and the identity index map.
        weights = torch.zeros(x.shape[0], x.shape[1], device=x.device)
        idx = torch.arange(x.shape[1], device=x.device).unsqueeze(0).expand(x.shape[0], -1)
        
        indices = []
        
        # NOTE(review): self.sample_levels is not set in __init__.
        for i, sl in enumerate(self.sample_levels):
            
            if i > 0:
                # select the first sl points a.k.a. downsample/pooling                
                # NOTE(review): self.sample_points is not defined on this class.
                x, x_i = self.sample_points(x, sl)

                # initialize idx with the index of the current level
                idx = x_i
                
                for idx_prev in reversed(indices): # go through the list of the previous ones in reverse
                    # NOTE(review): knn_gather is not defined or imported in the visible source.
                    idx = knn_gather(idx_prev, idx).squeeze(-2).contiguous() # using the indices of the previous level update idx, at the end idx should have the indices of the first level
                
                idx = idx.squeeze(-1)
                indices.append(x_i)
            
            # the mha will select optimal points from the input
            # NOTE(review): __init__ registers this module as f"octree_mha{i}", not f"mha_{i}".
            x, x_w = getattr(self, f"mha_{i}")(x)
            x = getattr(self, f"ff_{i}")(x)
            
            # Add this stage's attention weights at the positions given by idx.
            weights.scatter_add_(1, idx, x_w)

        #output layer
        x = self.output(x)
        return x, weights
    
# OctreeMHAEncoder(input_dim=4, embed_dim=128, hidden_dim=64, kernel_size=[3,3,3], num_heads=8, output_dim=256, stages=[8, 16, 32, 64, 128], dropout=0.1).cuda()(octree.get_input_feature('PF').to(torch.float)).shape