In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from generative.networks.nets import AutoencoderKL

import plotly
import plotly.express as px
import plotly.graph_objects as go

import numpy as np
import SimpleITK as sitk
import pickle

In [2]:
model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 384),
    latent_channels=8,
    num_res_blocks=1,
    norm_num_groups=32,
    attention_levels=(False, False, True),
)

In [3]:
model.encode(torch.rand(1, 1, 128, 128))[0].shape

torch.Size([1, 8, 32, 32])

In [4]:
def read_probe_params(probe_params_fn):
    return pickle.load(open(probe_params_fn, 'rb'))

In [5]:
fn_arr = ["/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/test_output/ultra-sim/rendering_cut_linear/v0.2/epoch=39-val_loss=3.66.ckpt/FAM-202-1960-2_mesh_sampling_autoencoderkl/AC/100.nrrd",
          "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/test_output/ultra-sim/rendering_cut_linear/v0.2/epoch=39-val_loss=3.66.ckpt/FAM-202-1960-2_mesh_sampling_autoencoderkl/AC/101.nrrd",
          "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/test_output/ultra-sim/rendering_cut_linear/v0.2/epoch=39-val_loss=3.66.ckpt/FAM-202-1960-2_mesh_sampling_autoencoderkl/AC/102.nrrd",
          "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/test_output/ultra-sim/rendering_cut_linear/v0.2/epoch=39-val_loss=3.66.ckpt/FAM-202-1960-2_mesh_sampling_autoencoderkl/AC/103.nrrd",]

fn_probe_arr = ["/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_probe_params/AC/100_probe_params.pickle",
                "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_probe_params/AC/101_probe_params.pickle",
                "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_probe_params/AC/102_probe_params.pickle",
                "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_probe_params/AC/103_probe_params.pickle",]


In [6]:
img_np = sitk.GetArrayFromImage(sitk.ReadImage(fn_arr[0]))

fig = go.Figure()
# fig.add_trace(go.Heatmap(z=np.flip(img.detach().numpy().squeeze(), axis=0), colorscale='gray', showscale=False, opacity=0.3))
# fig.add_trace(go.Heatmap(z=np.flip(seg.detach().numpy().squeeze(), axis=0), colorscale='jet', opacity=0.5))
fig.add_trace(go.Heatmap(z=np.rot90(np.flip(img_np.squeeze(), axis=0), k=3), colorscale='gray', opacity=1.0))

fig.update_layout(
            autosize=False,
            width=860,
            height=860
        )

In [7]:
probe_params = read_probe_params(fn_probe_arr[0])
probe_params

{'probe_origin': array([-0.03264857,  0.08562665,  0.2059159 ]),
 'probe_direction': array([[ 7.54979013e-08,  1.00000000e+00, -4.37113883e-08],
        [ 0.00000000e+00, -4.37113883e-08, -1.00000000e+00],
        [-1.00000000e+00,  7.54979013e-08, -3.30011808e-15]]),
 'ref_size': (256, 256, 1),
 'ref_origin': array([-0.09741297,  0.08627142,  0.2059159 ]),
 'ref_spacing': (0.0005059718969278038,
  0.0005059718969278038,
  0.0012895349645987153)}

In [8]:

# Assume the 2D image (3x3) is given
image_2d = torch.tensor(np.rot90(np.flip(img_np.squeeze(), axis=0), k=3), dtype=torch.float32).unsqueeze(0).unsqueeze(0).unsqueeze(-1)  # Add batch and channel dimensions

In [9]:
def create_3d_image_with_coordinates_torch(depth, height, width):
    """
    Create a 3D image with pixel values representing the 3D coordinates using PyTorch.

    Args:
    - width (int): Width of the 2D slice.
    - height (int): Height of the 2D slice.

    Returns:
    - torch.Tensor: A 3D image (height, width, 2) with coordinates as pixel values.
    """
    # Create 2D grids of x and y coordinates
    z_coords, y_coords, x_coords = torch.meshgrid(torch.arange(depth), torch.arange(height), torch.arange(width), indexing='ij')

    # Create a 3D tensor with [x, y, z] coordinates for each pixel
    coords_image = torch.stack([z_coords, y_coords, x_coords], dim=-1)

    return coords_image



In [31]:
coords = create_3d_image_with_coordinates_torch(1, 256, 256)

In [32]:
coords.shape


torch.Size([1, 256, 256, 3])

In [33]:
probe_direction = torch.tensor(probe_params['probe_direction'], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

In [37]:
target_idx = torch.matmul(coords.unsqueeze(0).to(torch.float32), probe_direction)
target_idx = ((target_idx - target_idx.min()) / (target_idx.max() - target_idx.min())*128).long()


In [36]:
target = torch.zeros(1, 1, 128, 128, 128)
target[target_idx].shape

RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 1649267441664 bytes. Error code 12 (Cannot allocate memory)

In [None]:
# volume = resampled_3d.squeeze().detach().numpy()
# sitk.WriteImage(sitk.GetImageFromArray(volume), "/mnt/famli_netapp_shared/C1_ML_Analysis/resampled_3d.nrrd")