In [1]:
import os
import sys

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import SimpleITK as sitk
import nrrd
import vtk

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
import pickle
import monai 
import glob 
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/')
sys.path.append('/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl')
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.loaders.mr_us_dataset as mr_us_dataset
import dl.nets.us_simulation_jit as us_simulation_jit
import dl.nets.us_simu as us_simu

import importlib

from dl.nets.layers import TimeDistributed

sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI/src')
from shapeaxi import utils as saxi_utils
from shapeaxi.saxi_transforms import EvalTransform

from pytorch3d.ops import (sample_points_from_meshes,
                           knn_points, 
                           knn_gather,
                           sample_farthest_points)


In [None]:
# Root directory handed to the sampler below.
mount_point = '/mnt/raid/C1_ML_Analysis'

# Reload us_simu so edits to the module on disk are picked up without
# restarting the kernel, then build the blind-sweep sampler and move it to the GPU.
importlib.reload(us_simu)
vs = us_simu.VolumeSamplingBlindSweep(mount_point=mount_point).cuda()

In [None]:

# Earlier SimpleITK-based loader, kept for reference. Note that it flips spacing,
# size and origin alike, so the three vectors share one axis order.
# diffusor = sitk.ReadImage('/mnt/famli_netapp_shared/C1_ML_Analysis/src/blender/simulated_data_export/studies_merged/FAM-025-0447-5.nrrd')
# diffusor_np = sitk.GetArrayFromImage(diffusor)
# diffusor_t = torch.tensor(diffusor_np.astype(int))

# diffusor_spacing = torch.tensor(diffusor.GetSpacing()).flip(dims=[0])
# diffusor_size = torch.tensor(diffusor.GetSize()).flip(dims=[0])

# diffusor_origin = torch.tensor(diffusor.GetOrigin()).flip(dims=[0])
# diffusor_end = diffusor_origin + diffusor_spacing * diffusor_size
# print(diffusor_size)
# print(diffusor_spacing)
# print(diffusor_t.shape)
# print(diffusor_origin)
# print(diffusor_end)

diffusor_np, diffusor_head = nrrd.read('/mnt/raid//C1_ML_Analysis/simulated_data_export/placenta/FAM-025-0664-4_label11_resampled.nrrd')
# The array axes are reversed here, so every geometry vector read from the
# header below is reversed the same way to stay aligned with diffusor_t.
diffusor_t = torch.tensor(diffusor_np.astype(int)).permute(2, 1, 0)

print(diffusor_head)
# Previously only the origin was flipped, which mixed axis orders when computing
# diffusor_end (wrong for any volume whose extent differs between axes).
diffusor_size = torch.tensor(diffusor_head['sizes']).flip(dims=[0])
# np.diag keeps only the diagonal of 'space directions', i.e. this assumes an
# axis-aligned volume -- TODO confirm for other inputs.
diffusor_spacing = torch.tensor(np.diag(diffusor_head['space directions'])).flip(dims=[0])

diffusor_origin = torch.tensor(diffusor_head['space origin']).flip(dims=[0])
diffusor_end = diffusor_origin + diffusor_spacing * diffusor_size

# Sanity check: the flipped header sizes must describe the permuted tensor.
assert tuple(diffusor_t.shape) == tuple(diffusor_size.tolist()), "header sizes do not match the permuted volume"

print(diffusor_spacing)
print(diffusor_t.shape)
print(diffusor_origin)
print(diffusor_end)


In [4]:
# fig = px.imshow(diffusor_t.flip(dims=[1]).squeeze().cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [None]:
# Add leading batch and channel axes -> shape (1, 1, *diffusor_t.shape), as float32 on the GPU.
diffusor_batch_t = diffusor_t.cuda().float().unsqueeze(0).unsqueeze(0)

# Give the bounding-box corners a leading batch axis as well.
diffusor_origin_batch = diffusor_origin[None, :]
diffusor_end_batch = diffusor_end[None, :]

# Report the shape of each input (the end corner was previously missing: the
# origin shape was printed twice).
print(diffusor_batch_t.shape, diffusor_origin_batch.shape, diffusor_end_batch.shape)

# Sample the labelled volume inside the sweep field of view.
diffusor_in_fov_t = vs.diffusor_in_fov(diffusor_batch_t, diffusor_origin_batch.cuda(), diffusor_end_batch.cuda())

print(diffusor_in_fov_t.shape)


In [None]:
from torch.nn.utils.rnn import pad_sequence

# Coordinates returned by the sampler, flattened to an (N, 3) list on the GPU.
fov_physical = vs.fov_physical()
V_fov = fov_physical.reshape(-1, 3).cuda()

V_, VF_ = [], []

for diff_in_fov in diffusor_in_fov_t:
    # One value per FOV sample, as a column vector.
    diff_in_fov = diff_in_fov.reshape(-1, 1)

    # Select the samples whose value equals 5; the same mask picks both the
    # coordinates and the values.
    keep = diff_in_fov.squeeze() == 5
    V_.append(V_fov[keep])
    VF_.append(diff_in_fov[keep])

# Batch items may keep different numbers of points: zero-pad to a common length.
V = pad_sequence(V_, batch_first=True, padding_value=0.0)
VF = pad_sequence(VF_, batch_first=True, padding_value=0.0)

print(V.shape, VF.shape)

In [None]:
x_v = V
SN = 0  # index of the batch item to display

# Copy the selected item's coordinates to host memory once.
sweep_xyz = x_v[SN].detach().cpu().numpy()

# 3D scatter of the points; colour follows the third coordinate.
fig = go.Figure(data=[
    go.Scatter3d(
        x=sweep_xyz[:, 0],
        y=sweep_xyz[:, 1],
        z=sweep_xyz[:, 2],
        mode='markers',
        marker=dict(size=2, color=sweep_xyz[:, 2], colorscale='jet'),
    )
])
fig.show()

In [8]:
# fig = px.imshow(diffusor_in_fov_t[0].squeeze().flip(dims=[1]).cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()
# fig = px.imshow(diffusor_in_fov_t[1].squeeze().flip(dims=[1]).cpu().numpy(), animation_frame=0, binary_string=True)
# fig.show()

In [9]:
# Read the fetus surface model and convert it to vertex / face tensors.
fetus_surf_path = '/mnt/raid/C1_ML_Analysis/simulated_data_export/studies_fetus/FAM-025-0664-4_Fetus_Model.vtk'
surf = saxi_utils.ReadSurf(fetus_surf_path)
V_surf, F_surf = saxi_utils.PolyDataToTensors_v_f(surf)

In [10]:
# Index array stored on disk, loaded as int64 on the same device as the vertices.
idx = torch.tensor(
    np.load('/mnt/raid/C1_ML_Analysis/simulated_data_export/dists_Fetus_Model.npy'),
    dtype=torch.int64,
    device=V_surf.device,
)

# Gather the indexed vertices (knn_gather expects a leading batch axis), then
# squeeze axis -2 and the batch axis again.
P = knn_gather(V_surf[None], idx)
P = P.squeeze(-2).squeeze(0).contiguous()

# Apply the ShapeAXI evaluation transform and restore a batch axis.
P = EvalTransform()(P).unsqueeze(0)


In [None]:
# Copy the selected batch item to host memory once (SN is set in an earlier cell).
surf_pts = P[SN].detach().cpu().numpy()

# 3D scatter with the columns plotted in reverse order (x <- column 2,
# z <- column 0); colour follows column 2.
fig = go.Figure(data=[
    go.Scatter3d(
        x=surf_pts[:, 2],
        y=surf_pts[:, 1],
        z=surf_pts[:, 0],
        mode='markers',
        marker=dict(size=2, color=surf_pts[:, 2], colorscale='jet'),
    )
])
fig.show()

In [12]:
def random_rotate_3d_batch(image_tensor, angles=None):
    """
    Rotates a batch of 3D image tensors and returns the rotated images along with the rotation matrices.

    Each sample gets the rotation ``Rz(angle_z) @ Ry(angle_y) @ Rx(angle_x)``. The matrix is
    used as the linear part of the ``theta`` passed to ``F.affine_grid`` (zero translation),
    so it acts in the normalized sampling coordinates of ``affine_grid`` / ``grid_sample``.
    NOTE(review): because those coordinates are normalized per axis, a volume with unequal
    D/H/W extents is not rotated rigidly in voxel space -- confirm this is acceptable.

    Args:
        image_tensor (torch.Tensor): A floating-point tensor of shape (B, C, D, H, W).
        angles (torch.Tensor, optional): Rotation angles in radians of shape (B, 3), ordered
            (x, y, z). When omitted, angles are drawn uniformly from [0, 2*pi) using the
            global torch RNG (seed it with ``torch.manual_seed`` for reproducible output).

    Returns:
        rotated_images (torch.Tensor): Rotated 3D tensors of shape (B, C, D, H, W).
        rotation_matrices (torch.Tensor): Rotation matrices of shape (B, 3, 3), on the
            device and in the dtype of ``image_tensor``.
    """
    assert len(image_tensor.shape) == 5, "Input tensor must be 5D (B, C, D, H, W)."

    B = image_tensor.shape[0]
    device = image_tensor.device

    # affine_grid / grid_sample require theta, grid and input to share one dtype; the
    # matrices used to be float32 regardless of the input, which broke float64 volumes.
    out_dtype = image_tensor.dtype if image_tensor.is_floating_point() else torch.float32
    # Evaluate the trigonometry in at least float32 so half-precision inputs keep accuracy.
    calc_dtype = torch.float64 if out_dtype == torch.float64 else torch.float32

    if angles is None:
        # Drawn on the CPU in float32, exactly as before, so a given seed yields the same angles.
        angles = torch.rand(B, 3) * 2 * torch.pi
    else:
        angles = torch.as_tensor(angles)
        assert tuple(angles.shape) == (B, 3), "angles must have shape (B, 3)."
    angles = angles.to(device=device, dtype=calc_dtype)

    cos_x, cos_y, cos_z = torch.cos(angles).unbind(dim=-1)
    sin_x, sin_y, sin_z = torch.sin(angles).unbind(dim=-1)
    one = torch.ones_like(cos_x)
    zero = torch.zeros_like(cos_x)

    def _as_matrices(*entries):
        # Nine (B,) tensors given in row-major order -> (B, 3, 3).
        return torch.stack(entries, dim=-1).reshape(B, 3, 3)

    rot_x = _as_matrices(one, zero, zero,
                         zero, cos_x, -sin_x,
                         zero, sin_x, cos_x)
    rot_y = _as_matrices(cos_y, zero, sin_y,
                         zero, one, zero,
                         -sin_y, zero, cos_y)
    rot_z = _as_matrices(cos_z, -sin_z, zero,
                         sin_z, cos_z, zero,
                         zero, zero, one)

    # Same composition order as before: Z, then Y, then X.
    rotation_matrices = (rot_z @ rot_y @ rot_x).to(out_dtype)

    # Convert 3x3 rotation matrices to 3x4 affine matrices (zero translation column)
    translation = torch.zeros(B, 3, 1, device=device, dtype=out_dtype)
    affine_matrices = torch.cat([rotation_matrices, translation], dim=2)

    # Generate the sampling grid for each batch item
    grids = F.affine_grid(
        affine_matrices,
        size=image_tensor.size(),
        align_corners=False
    )

    # Resample the volume on the rotated grid; samples outside the volume read as zero
    rotated_images = F.grid_sample(
        image_tensor,
        grids,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False
    )

    return rotated_images, rotation_matrices


In [13]:
# Rotate the labelled volume and keep the matrices so the same rotation can be
# applied to other data later. NOTE: the angles are drawn from the unseeded global
# torch RNG, so every run produces a different rotation.
diffusor_batch_rotated_t, rotation_matrices = random_rotate_3d_batch(diffusor_batch_t)

# Sample the rotated volume inside the sweep field of view, using the same
# bounding-box corners as for the un-rotated volume.
diffusor_in_fov_rotated_t = vs.diffusor_in_fov(
    diffusor_batch_rotated_t,
    diffusor_origin_batch.cuda(),
    diffusor_end_batch.cuda(),
)

In [None]:
# Shape of the zero-padded point batch built from the un-rotated volume, shown for reference.
V.shape

In [15]:
# Same filtering as for the un-rotated volume, applied to the rotated FOV samples.
# Relies on V_fov and pad_sequence from the earlier filtering cell.
V_, VF_ = [], []

for diff_in_fov in diffusor_in_fov_rotated_t:
    # One value per FOV sample, as a column vector.
    diff_in_fov = diff_in_fov.reshape(-1, 1)

    # Select the samples whose value equals 5.
    selected = diff_in_fov.squeeze() == 5
    V_.append(V_fov[selected])
    VF_.append(diff_in_fov[selected])

# Zero-pad so batch items with different point counts stack into one tensor.
V_rotated = pad_sequence(V_, batch_first=True, padding_value=0.0)
VF_rotated = pad_sequence(VF_, batch_first=True, padding_value=0.0)

In [None]:
x_v = V_rotated
SN = 0  # index of the batch item to display

# Copy the selected item's coordinates to host memory once.
rotated_xyz = x_v[SN].detach().cpu().numpy()

# 3D scatter of the points from the rotated volume; colour follows the third coordinate.
fig = go.Figure(data=[
    go.Scatter3d(
        x=rotated_xyz[:, 0],
        y=rotated_xyz[:, 1],
        z=rotated_xyz[:, 2],
        mode='markers',
        marker=dict(size=2, color=rotated_xyz[:, 2], colorscale='jet'),
    )
])
fig.show()

In [17]:
def apply_batch_rotation(V, rotation_matrices):
    """
    Applies a batch of rotation matrices to a batch of 3D points.

    The product is ``V @ rotation_matrices`` with the points as row vectors and no
    transpose, which equals applying the transposed (inverse) matrix to each point
    taken as a column vector. Per the original author's note this is deliberate: the
    points are ordered XYZ while the rotation matrices are ordered ZYX.

    The two inputs may differ in dtype or device: both are promoted to their common
    dtype and the matrices are moved to the device of ``V`` before multiplying.

    Args:
        V (torch.Tensor): A tensor of shape (BS, N, 3) containing 3D points.
        rotation_matrices (torch.Tensor): A tensor of shape (BS, 3, 3) containing rotation matrices.

    Returns:
        rotated_points (torch.Tensor): A tensor of shape (BS, N, 3) with rotated points,
            on the device of ``V`` and in the promoted dtype.
    """
    assert V.shape[-1] == 3, "Points tensor must have the last dimension of size 3 (3D coordinates)."
    assert rotation_matrices.shape[-2:] == (3, 3), "Rotation matrices must have shape (BS, 3, 3)."
    assert V.shape[0] == rotation_matrices.shape[0], "Batch size of points and rotation matrices must match."

    # torch.matmul refuses mixed dtypes/devices; align the operands here so callers
    # do not have to cast by hand. Promotion (rather than a downcast) avoids losing precision.
    common_dtype = torch.promote_types(V.dtype, rotation_matrices.dtype)
    V = V.to(dtype=common_dtype)
    rotation_matrices = rotation_matrices.to(device=V.device, dtype=common_dtype)

    # Row-vector convention: (BS, N, 3) @ (BS, 3, 3) -> (BS, N, 3).
    rotated_points = torch.matmul(V, rotation_matrices)
    return rotated_points


In [18]:
# Rotate the fetus point set with the same matrices that were returned for the volume;
# P is cast to float32 and moved to the GPU before the call.
P_rotated = apply_batch_rotation(P.float().cuda(), rotation_matrices)

In [None]:
# Copy the selected batch item to host memory once (SN is set in an earlier cell).
rotated_surf_pts = P_rotated[SN].detach().cpu().numpy()

# 3D scatter with the columns plotted in reverse order (x <- column 2,
# z <- column 0); colour follows column 2.
fig = go.Figure(data=[
    go.Scatter3d(
        x=rotated_surf_pts[:, 2],
        y=rotated_surf_pts[:, 1],
        z=rotated_surf_pts[:, 0],
        mode='markers',
        marker=dict(size=2, color=rotated_surf_pts[:, 2], colorscale='jet'),
    )
])
fig.show()