In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import functools
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision import transforms
from math import pi
import SimpleITK as sitk
import plotly
import plotly.express as px
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import vtk
from vtk.util.numpy_support import vtk_to_numpy
import pickle
from monai.transforms import (RandSimulateLowResolution)
import monai 


# Rendering hyper-parameters read by UltrasoundRendering.
# Scale in the boundary map 1 - exp(-diff**2 / alpha): smaller values saturate impedance steps to 1 faster.
alpha_coeff_boundary_map = 0.1
# Sharpness of sigmoid(beta * (mu_1 - p)), the differentiable stand-in for the hard scatterer threshold.
beta_coeff_scattering = 10  #100 approximates it closer
# Time-gain compensation: the gain ramps linearly from 1 to TGC across the rendered axis.
TGC = 8
# When True, forward() clamps the looked-up tissue maps to fixed ranges before rendering.
CLAMP_VALS = True


def gaussian_kernel(size: int, mean: float, std: float):
    """Build a normalised, separable 2D Gaussian kernel of shape (2*size+1, 2*size+1).

    The profile along dim 0 uses standard deviation ``std``; the profile along
    dim 1 is three times wider (``3 * std``). Both are evaluated at the integer
    offsets -size..size, multiplied as an outer product and scaled to sum to 1.
    """
    offsets = torch.arange(-size, size + 1, dtype=torch.float32)
    row_profile = torch.distributions.Normal(mean, std).log_prob(offsets).exp()
    col_profile = torch.distributions.Normal(mean, std * 3).log_prob(offsets).exp()

    # Outer product via broadcasting: kernel[i, j] = row_profile[i] * col_profile[j].
    kernel = row_profile.unsqueeze(1) * col_profile.unsqueeze(0)
    total = kernel.sum().reshape(1, 1)
    return kernel / total

# Point-spread function convolved in UltrasoundRendering.rendering, shaped
# (out_channels=1, in_channels=1, 7, 7) as F.conv2d expects.
# Indexing with None adds the two leading dims directly; re-wrapping an existing
# tensor in torch.tensor() copies it and emits a "copy construct" UserWarning.
g_kernel = gaussian_kernel(3, 0., 0.5)[None, None, :, :]


class UltrasoundRendering(torch.nn.Module):
    """Differentiable ultrasound renderer driven by per-label tissue parameters.

    Every integer label of the input map indexes five learnable 1-D tables
    (acoustic impedance, attenuation, mu_0, mu_1, sigma_0). ``forward`` resamples
    the label map onto a polar "fan" grid, renders intensities there with
    ``rendering`` and warps the result back with an approximate inverse grid.

    Args:
        num_labels: number of distinct labels, i.e. the length of each table.
        grid_width: number of angular samples of the fan grid.
        grid_height: number of radial samples of the fan grid.
        center_x, center_y: polar centre of the fan in grid pixel coordinates
            (presumably the probe apex).
        r1, r2: first and last radius of the fan, in grid pixel coordinates.
        theta: half opening angle of the fan, in radians.
    """

    def __init__(self, num_labels=340, grid_width=1024, grid_height=1024, center_x=512, center_y=-80, r1=160.0, r2=896.0, theta=np.pi / 4):
        super(UltrasoundRendering, self).__init__()

        # Per-label lookup tables, randomly initialised; init_params can load them from a table.
        # accoustic_imped,attenuation,mu_0,mu_1,sigma_0
        self.num_labels = num_labels
        self.acoustic_impedance_dict = torch.nn.Parameter(torch.rand(self.num_labels))  # Z in MRayl
        self.attenuation_dict = torch.nn.Parameter(torch.rand(self.num_labels))         # alpha in dB cm^-1 at 1 MHz
        self.mu_0_dict = torch.nn.Parameter(torch.rand(self.num_labels))                # mu_0 - scattering_mu, mean brightness
        self.mu_1_dict = torch.nn.Parameter(torch.rand(self.num_labels))                # mu_1 - scattering density, nr of scatterers/voxel
        self.sigma_0_dict = torch.nn.Parameter(torch.rand(self.num_labels))             # sigma_0 - scattering_sigma, brightness std

        # Non-persistent buffers follow the module through .to(device)/.cuda()
        # without adding keys to state_dict, so existing checkpoints still load.
        grid = self.compute_grid(grid_width, grid_height, center_x, center_y, r1, r2, theta)
        self.register_buffer("grid", grid, persistent=False)
        self.register_buffer("inverse_grid", self.compute_grid_inverse(grid), persistent=False)

    def init_params(self, df):
        """Replace the per-label tables with values from an acoustic-parameter table.

        Args:
            df: a pandas DataFrame, or a path to a CSV readable by ``pd.read_csv``,
                with columns ``accoustic_imped``, ``attenuation``, ``mu_0``,
                ``mu_1`` and ``sigma_0`` (one row per label).
        """
        if isinstance(df, (str, os.PathLike)):
            df = pd.read_csv(df)

        device = self.acoustic_impedance_dict.device

        def _as_param(column):
            # torch.tensor always copies, so training never writes back into the DataFrame;
            # float32 keeps the tables consistent with the float32 rendering kernels.
            return torch.nn.Parameter(torch.tensor(df[column].to_numpy(), dtype=torch.float32, device=device))

        self.num_labels = len(df.index)
        self.acoustic_impedance_dict = _as_param('accoustic_imped')  # Z in MRayl
        self.attenuation_dict = _as_param('attenuation')             # alpha in dB cm^-1 at 1 MHz
        self.mu_0_dict = _as_param('mu_0')                           # mu_0 - scattering_mu, mean brightness
        self.mu_1_dict = _as_param('mu_1')                           # mu_1 - scattering density, nr of scatterers/voxel
        self.sigma_0_dict = _as_param('sigma_0')                     # sigma_0 - scattering_sigma, brightness std

    def rendering(self, shape, attenuation_medium_map, mu_0_map, mu_1_map, sigma_0_map, z_vals=None, refl_map=None, boundary_map=None):
        """Render intensities from per-pixel tissue maps.

        Args:
            shape: shape of the random texture/scattering fields (the label-map shape).
            attenuation_medium_map, mu_0_map, mu_1_map, sigma_0_map: per-pixel tissue maps.
            z_vals: 2-D sample positions from ``render_rays``.
            refl_map, boundary_map: reflection coefficient and boundary strength maps.

        Returns:
            (intensity_map, attenuation_total, reflection_total_plot, scatterers_map,
            scattering_probability, border_convolution, texture_noise, b, r)
        """
        device = attenuation_medium_map.device
        z_vals = z_vals.to(device)

        # Gap between consecutive samples; the last gap is repeated so dists matches z_vals.
        dists = torch.abs(z_vals[..., :-1, None] - z_vals[..., 1:, None])     # dists.shape=(W, H-1, 1)
        dists = dists.squeeze(-1)                                             # dists.shape=(W, H-1)
        dists = torch.cat([dists, dists[:, -1, None]], dim=-1)                # dists.shape=(W, H)

        # NOTE(review): accumulation and TGC run along dim 3, while forward() detects
        # boundaries along dim 2 and compute_grid puts the radius on dim 2 — TODO confirm
        # which axis is meant to be the ray direction.
        attenuation = torch.exp(-attenuation_medium_map * dists)
        attenuation_total = torch.cumprod(attenuation, dim=3, dtype=torch.float32)

        # Time-gain compensation, built in float32 on the maps' device (a float64 numpy
        # gain used to promote every downstream result to float64 and pin it to the CPU).
        gain_coeffs = torch.linspace(1, TGC, attenuation_total.shape[3], dtype=torch.float32, device=device)
        attenuation_total = attenuation_total * gain_coeffs     # apply TGC (broadcast over dim 3)

        reflection_total = torch.cumprod(1. - refl_map * boundary_map, dim=3, dtype=torch.float32)
        reflection_total = reflection_total.squeeze(-1)
        reflection_total_plot = torch.log(reflection_total + torch.finfo(torch.float32).eps)

        texture_noise = torch.randn(shape, dtype=torch.float32, device=device)
        scattering_probability = torch.randn(shape, dtype=torch.float32, device=device)

        # Differentiable approximation of Eq. (4):
        #   where(scattering_probability <= mu_1_map, texture_noise * sigma_0_map + mu_0_map, 0)
        # Eq. (6) is sigmoid * texture + (1 - sigmoid) * 0; the zero term is dropped.
        sigmoid_map = torch.sigmoid(beta_coeff_scattering * (mu_1_map - scattering_probability))
        scatterers_map = sigmoid_map * (texture_noise * sigma_0_map + mu_0_map)

        kernel = g_kernel.to(device=device, dtype=scatterers_map.dtype)
        psf_scatter_conv = F.conv2d(input=scatterers_map, weight=kernel, stride=1, padding="same")
        b = attenuation_total * psf_scatter_conv    # Eq. (3)

        border_convolution = F.conv2d(input=boundary_map, weight=kernel.to(boundary_map.dtype), stride=1, padding="same")
        r = attenuation_total * reflection_total * refl_map * border_convolution  # Eq. (2)

        intensity_map = torch.clamp(b + r, 0, 1)   # Eq. (1)

        return intensity_map, attenuation_total, reflection_total_plot, scatterers_map, scattering_probability, border_convolution, texture_noise, b, r

    def render_rays(self, W, H):
        """Return a (W, H) tensor whose rows are H evenly spaced positions from 0 to 4."""
        N_rays = W
        t_vals = torch.linspace(0., 1., H)
        z_vals = t_vals.unsqueeze(0).expand(N_rays, -1) * 4

        return z_vals

    def compute_grid(self, w, h, center_x, center_y, r1, r2, theta):
        """Sample a polar fan: ``h`` radii in [r1, r2] x ``w`` angles in [-theta, theta].

        Returns:
            float tensor of shape (h, w, 2). Channel 0 holds ``center_y + r*cos(a)``
            and channel 1 holds ``center_x + r*sin(a)`` (in that order).
        """
        angles = torch.linspace(-theta, theta, w)
        radii = torch.linspace(r1, r2, h)

        # Broadcast radii (rows) against angles (columns) instead of looping per radius.
        ys = center_y + radii[:, None] * torch.cos(angles)[None, :]
        xs = center_x + radii[:, None] * torch.sin(angles)[None, :]

        return torch.stack((ys, xs), dim=-1)

    def compute_grid_inverse(self, grid):
        """Approximate inverse of ``grid`` by rounding its coordinates to pixels.

        For each source cell (j, i) whose rounded coordinates (c0, c1) satisfy
        0 <= c0 < w and 0 <= c1 < h, sets ``inverse_grid[c1, c0] = (i, j)``. When
        several cells round to the same pixel, the one that comes last in row-major
        (j, i) order wins. Pixels hit by no cell stay (0, 0).
        """
        h, w, _ = grid.shape

        coords = torch.round(grid).to(torch.long).reshape(-1, 2)
        c0, c1 = coords[:, 0], coords[:, 1]
        # Row-major position j * w + i of each source cell.
        order = torch.arange(h * w)
        valid = (c0 >= 0) & (c0 < w) & (c1 >= 0) & (c1 < h)
        target = c1[valid] * w + c0[valid]

        # Highest source position per target pixel == last write in row-major order (-1: never hit).
        last_writer = torch.full((h * w,), -1, dtype=torch.long)
        last_writer = last_writer.scatter_reduce(0, target, order[valid], reduce="amax")

        hit = last_writer >= 0
        src = last_writer[hit]
        inverse_grid = torch.zeros(h * w, 2)
        inverse_grid[hit, 0] = (src % w).to(inverse_grid.dtype)
        inverse_grid[hit, 1] = torch.div(src, w, rounding_mode="floor").to(inverse_grid.dtype)

        return inverse_grid.reshape(h, w, 2)

    def grid_transform(self, x, grid, interpolation_mode='nearest'):
        """Resample the 4-D batch ``x`` at the pixel coordinates stored in ``grid``.

        ``grid`` (shape (h, w, 2)) is normalised to [-1, 1] by dividing channel 0 by
        ``grid.shape[0]`` and channel 1 by ``grid.shape[1]``, so coordinates live in the
        grid's own pixel space regardless of x's resolution. Returns float32 (B, C, h, w).
        """
        x = x.float()
        scale = grid.new_tensor([grid.shape[0], grid.shape[1]])
        grid_f = (grid / scale * 2.0 - 1.0).to(device=x.device, dtype=x.dtype)
        grid_f = grid_f.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)

        # NOTE(review): p / size * 2 - 1 maps pixel p to a cell edge rather than its centre
        # under align_corners=False (half-pixel offset) — kept as-is.
        return F.grid_sample(x, grid_f, mode=interpolation_mode, align_corners=False)

    def forward(self, x):
        """Render an ultrasound-like image from a label map.

        Args:
            x: integer label map; 4-D with a single channel is required (rot90 over
               dims 2, 3 and the 1-channel PSF conv), presumably (B, 1, H, W).

        Returns:
            (intensity_map_t, attenuation_medium_map, mu_0_map, mu_1_map, sigma_0_map,
            acoustic_imped_map, boundary_map, shifted_arr, intensity_map), where
            intensity_map_t is warped back through ``inverse_grid`` and intensity_map is
            the fan-grid rendering; both are rotated by k=3 at the end.
        """
        x = torch.rot90(x, k=1, dims=[2, 3])
        x = self.grid_transform(x, self.grid)
        # Nearest sampling keeps labels integral; cast back for table lookup.
        x = x.to(torch.long)

        acoustic_imped_map = self.acoustic_impedance_dict[x]
        attenuation_medium_map = self.attenuation_dict[x]
        mu_0_map = self.mu_0_dict[x]
        mu_1_map = self.mu_1_dict[x]
        sigma_0_map = self.sigma_0_dict[x]

        # Impedance difference along dim 2, padded at the top to keep the shape.
        diff_arr = torch.diff(acoustic_imped_map, dim=2)
        diff_arr = F.pad(diff_arr, (0, 0, 1, 0))

        boundary_map = -torch.exp(-(diff_arr ** 2) / alpha_coeff_boundary_map) + 1

        # Neighbour along dim 2; zero the wrapped-around last row. Index dim 2 explicitly:
        # `shifted_arr[-1:]` selects the last *batch item* and would wipe it entirely.
        shifted_arr = torch.roll(acoustic_imped_map, -1, dims=2)
        shifted_arr[:, :, -1] = 0

        # NOTE(review): diff_arr[k] = Z[k] - Z[k-1] while sum_arr[k] = Z[k] + Z[k+1]; a
        # reflection coefficient normally uses the same pair — TODO confirm the alignment.
        sum_arr = acoustic_imped_map + shifted_arr
        sum_arr[sum_arr == 0] = 1   # avoid division by zero
        div = diff_arr / sum_arr
        refl_map = torch.sigmoid(div ** 2)

        z_vals = self.render_rays(x.shape[2], x.shape[3])

        # NOTE(review): refl_map/boundary_map above use the unclamped impedance; clamping
        # acoustic_imped_map here only affects the returned map.
        if CLAMP_VALS:
            attenuation_medium_map = torch.clamp(attenuation_medium_map, 0, 10)
            acoustic_imped_map = torch.clamp(acoustic_imped_map, 0, 10)
            sigma_0_map = torch.clamp(sigma_0_map, 0, 1)
            mu_1_map = torch.clamp(mu_1_map, 0, 1)
            mu_0_map = torch.clamp(mu_0_map, 0, 1)

        ret_list = self.rendering(x.shape, attenuation_medium_map, mu_0_map, mu_1_map, sigma_0_map, z_vals=z_vals, refl_map=refl_map, boundary_map=boundary_map)
        intensity_map = ret_list[0]

        intensity_map_t = self.grid_transform(intensity_map, self.inverse_grid)

        intensity_map = torch.rot90(intensity_map, k=3, dims=[2, 3])
        intensity_map_t = torch.rot90(intensity_map_t, k=3, dims=[2, 3])

        return intensity_map_t, attenuation_medium_map, mu_0_map, mu_1_map, sigma_0_map, acoustic_imped_map, boundary_map, shifted_arr, intensity_map

In [None]:
# Build a renderer whose label count matches the acoustic-parameter table.
# NOTE(review): only the CSV's row count is used; the per-label parameters stay
# random (torch.rand in UltrasoundRendering.__init__).
acoustic_params_df = pd.read_csv('/mnt/raid/C1_ML_Analysis/source/guibruss/LOTUS/lotus/acoustic_params.csv')
us_render = UltrasoundRendering(num_labels=len(acoustic_params_df.index))
# Smoke test on a random (2, 1, 256, 256) label map with labels 0-9 (high is exclusive).
fake_us = us_render(torch.randint(low=0, high=10, size=(2, 1, 256, 256)))

In [None]:
# Load sampled label slice AC/40 and display index 0 along its first array axis.
fn = "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_sampling/AC/40.nrrd"
img = sitk.ReadImage(fn)
img1_np = sitk.GetArrayFromImage(img)
px.imshow(img1_np[0])

In [None]:
# Load sampled label slice AC/98 (second batch item for the render below) and display it.
fn = "/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_sampling/AC/98.nrrd"
img = sitk.ReadImage(fn)
img2_np = sitk.GetArrayFromImage(img)
px.imshow(img2_np[0])

In [None]:
# Stack both AC label slices into one batch and render them.
# np.long does not exist in NumPy 1.24-1.26 (and means C `long` in NumPy 2),
# so request 64-bit integer labels explicitly.
img_np = np.stack([img1_np, img2_np]).astype(np.int64)
t = torch.tensor(img_np, dtype=torch.long)
fake_us = us_render(t)

In [None]:
# First return value of forward (inverse-grid-warped intensity), first batch item.
px.imshow(fake_us[0][0].detach().numpy().squeeze())

In [None]:
# First return value of forward (inverse-grid-warped intensity), second batch item.
px.imshow(fake_us[0][1].detach().numpy().squeeze())

In [None]:

# class ResnetGenerator(nn.Module):
#     """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

#     We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
#     """

#     def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
#         """Construct a Resnet-based generator

#         Parameters:
#             input_nc (int)      -- the number of channels in input images
#             output_nc (int)     -- the number of channels in output images
#             ngf (int)           -- the number of filters in the last conv layer
#             norm_layer          -- normalization layer
#             use_dropout (bool)  -- if use dropout layers
#             n_blocks (int)      -- the number of ResNet blocks
#             padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
#         """
#         assert(n_blocks >= 0)
#         super(ResnetGenerator, self).__init__()
#         if type(norm_layer) == functools.partial:
#             use_bias = norm_layer.func == nn.InstanceNorm2d
#         else:
#             use_bias = norm_layer == nn.InstanceNorm2d

#         model = [nn.ReflectionPad2d(3),
#                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
#                  norm_layer(ngf),
#                  nn.ReLU(True)]

#         n_downsampling = 2
#         for i in range(n_downsampling):  # add downsampling layers
#             mult = 2 ** i
#             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
#                       norm_layer(ngf * mult * 2),
#                       nn.ReLU(True)]

#         mult = 2 ** n_downsampling
#         for i in range(n_blocks):       # add ResNet blocks

#             model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

#         for i in range(n_downsampling):  # add upsampling layers
#             mult = 2 ** (n_downsampling - i)
#             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
#                                          kernel_size=3, stride=2,
#                                          padding=1, output_padding=1,
#                                          bias=use_bias),
#                       norm_layer(int(ngf * mult / 2)),
#                       nn.ReLU(True)]
#         model += [nn.ReflectionPad2d(3)]
#         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
#         model += [nn.Tanh()]

#         self.model = nn.Sequential(*model)

#     def forward(self, input):
#         """Standard forward"""
#         return self.model(input)
    
# class ResnetBlock(nn.Module):
#     """Define a Resnet block"""

#     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
#         """Initialize the Resnet block

#         A resnet block is a conv block with skip connections
#         We construct a conv block with build_conv_block function,
#         and implement skip connections in <forward> function.
#         Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
#         """
#         super(ResnetBlock, self).__init__()
#         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

#     def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
#         """Construct a convolutional block.

#         Parameters:
#             dim (int)           -- the number of channels in the conv layer.
#             padding_type (str)  -- the name of padding layer: reflect | replicate | zero
#             norm_layer          -- normalization layer
#             use_dropout (bool)  -- if use dropout layers.
#             use_bias (bool)     -- if the conv layer uses bias or not

#         Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
#         """
#         conv_block = []
#         p = 0
#         if padding_type == 'reflect':
#             conv_block += [nn.ReflectionPad2d(1)]
#         elif padding_type == 'replicate':
#             conv_block += [nn.ReplicationPad2d(1)]
#         elif padding_type == 'zero':
#             p = 1
#         else:
#             raise NotImplementedError('padding [%s] is not implemented' % padding_type)

#         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
#         if use_dropout:
#             conv_block += [nn.Dropout(0.5)]

#         p = 0
#         if padding_type == 'reflect':
#             conv_block += [nn.ReflectionPad2d(1)]
#         elif padding_type == 'replicate':
#             conv_block += [nn.ReplicationPad2d(1)]
#         elif padding_type == 'zero':
#             p = 1
#         else:
#             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
#         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

#         return nn.Sequential(*conv_block)

#     def forward(self, x):
#         """Forward function (with skip connections)"""
#         out = x + self.conv_block(x)  # add skip connections
#         return out

In [None]:
# Read the gestational-sac surface mesh (STL); presumably meant as `surf` for VolumeSlicingDataset.
reader = vtk.vtkSTLReader()
reader.SetFileName('/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh/gestational/gestational_sac.stl')
reader.Update()
surf = reader.GetOutput()

In [None]:
# Full volume that the slicing datasets below resample (with nearest-neighbour by default).
volume = sitk.ReadImage('/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh.nrrd')

In [None]:
# Probe-parameter table (BPD variant). VolumeSlicingProbeParamsDataset reads the
# pickled parameter file named in its `probe_params_fn` column for each row.
# probe_params_df = pd.read_csv("/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/CSV_files/FAM-202-1960-2_mesh_probe_params.csv")
probe_params_df = pd.read_csv("/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/CSV_files/FAM-202-1960-2_mesh_probe_params_BPD.csv")

In [None]:

class VolumeSlicingDataset(Dataset):
    """Random planar slices of a volume, positioned on a surface mesh.

    Each item picks a random vertex of ``surf``, builds a frame whose rotation maps
    (-1, 0, 0) onto the vertex normal, and resamples ``volume`` on a ``ref_size``
    grid with ``ref_spacing`` anchored at that vertex.

    Args:
        volume: SimpleITK image to slice.
        surf: vtkPolyData surface; point normals are computed here.
        num_samples: value reported by ``__len__`` (items are random, not indexed).
        ref_size, ref_spacing: size/spacing of the reference (output) image.
        transform: optional callable applied to the numpy slice.
        interpolator: SimpleITK interpolator; a falsy value keeps the filter default.
    """

    def __init__(self, volume, surf, num_samples, ref_size=(256, 256, 1), ref_spacing=(0.0004119873046875, 0.0004119873046875, 0.0010500028729438782), transform=None, interpolator=sitk.sitkNearestNeighbor):
        self.volume = volume
        self.surf = self.ComputeNormals(surf)
        self.surf_points = vtk_to_numpy(self.surf.GetPoints().GetData())
        self.surf_normals = vtk_to_numpy(self.surf.GetPointData().GetArray("Normals"))
        self.num_samples = num_samples
        self.transform = transform
        self.ref_size = ref_size
        self.ref_spacing = ref_spacing
        self.interpolator = interpolator

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is ignored: every access draws a new random slice.
        img = self.sample_image()
        img_np = sitk.GetArrayFromImage(img)

        if self.transform:
            return self.transform(img_np)

        return img_np

    def ComputeNormals(self, surf):
        """Return a copy of ``surf`` with point and cell normals (no edge splitting)."""
        normals = vtk.vtkPolyDataNormals()
        normals.SetInputData(surf)
        normals.ComputeCellNormalsOn()
        normals.ComputePointNormalsOn()
        normals.SplittingOff()
        normals.Update()

        return normals.GetOutput()

    def align_matrix(self, N, A=np.array([-1.0, 0.0, 0.0])):
        """Return a 3x3 rotation matrix that rotates the direction of A onto that of N."""
        A = np.asarray(A, dtype=float)
        N = np.asarray(N, dtype=float)

        axis = np.cross(A, N)
        axis_norm = np.linalg.norm(axis)
        if axis_norm == 0:
            if np.dot(A, N) >= 0:
                # Already parallel (or N is degenerate): no rotation needed.
                return np.identity(3)
            # Anti-parallel: rotate 180 degrees about any axis perpendicular to A.
            helper = np.eye(3)[np.argmin(np.abs(A))]
            return self.rotation_matrix(np.cross(A, helper), 180.0)
        axis = axis / axis_norm

        angle = np.arccos(np.clip(np.dot(A, N) / (np.linalg.norm(A) * np.linalg.norm(N)), -1.0, 1.0))

        # arccos returns radians but rotation_matrix takes degrees; passing radians
        # straight through shrank every rotation by a factor of 180/pi.
        return self.rotation_matrix(axis, np.degrees(angle))

    def rotation_matrix(self, axis, angle):
        """
        Build the rotation matrix for a rotation of ``angle`` degrees about ``axis``.

        Parameters:
        axis (np.array): The axis of rotation (normalised here).
        angle (float): The angle of rotation in degrees.

        Returns:
        np.array: The 3x3 rotation matrix (Rodrigues' formula).
        """
        angle_rad = np.radians(angle)

        axis = axis / np.linalg.norm(axis)
        K = np.array([[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]])
        R = np.eye(3) + np.sin(angle_rad) * K + (1 - np.cos(angle_rad)) * np.dot(K, K)

        return R

    def create_4x4_matrix(self, R, P):
        """
        Create a transformation matrix in homogeneous coordinates from a rotation matrix and a translation vector.

        Parameters:
        R (np.array): The 3x3 rotation matrix.
        P (np.array): The translation vector.

        Returns:
        np.array: The 4x4 transformation matrix in homogeneous coordinates.
        """
        T = np.zeros((4, 4))
        T[:3, :3] = R
        T[:3, 3] = P
        T[3, 3] = 1
        return T

    def sample_image(self):
        """Resample ``self.volume`` on a plane anchored at a random surface vertex."""
        random_i = np.random.randint(low=0, high=self.surf_points.shape[0])

        surf_point = self.surf_points[random_i]
        surf_normal = self.surf_normals[random_i]

        # Rotation that maps align_matrix's reference vector (-1, 0, 0) onto the normal.
        rotation_matrix = self.align_matrix(surf_normal)
        matrix_world = self.create_4x4_matrix(rotation_matrix, surf_point)
        direction = matrix_world[0:3, 0:3]

        # Move the origin back by half the extent along image axes 1 and 2 so the
        # plane is centred on the vertex in those axes (axis 0 starts at the vertex).
        delta_origin = np.array([0, -self.ref_size[1]*self.ref_spacing[1]/2.0, -self.ref_size[2]*self.ref_spacing[2]/2.0, 1.0])
        delta_origin = np.dot(matrix_world, delta_origin)

        ref = sitk.Image(int(self.ref_size[0]), int(self.ref_size[1]), int(self.ref_size[2]), sitk.sitkFloat32)
        ref.SetOrigin(delta_origin[0:3].astype(np.double))
        ref.SetSpacing(self.ref_spacing)
        ref.SetDirection(direction.flatten().astype(np.double))

        resampler = sitk.ResampleImageFilter()
        if self.interpolator:
            resampler.SetInterpolator(self.interpolator)
        resampler.SetReferenceImage(ref)

        return resampler.Execute(self.volume)

In [None]:
class VolumeSlicingProbeParamsDataset(Dataset):
    """Slices of a volume whose geometry comes from per-row pickled probe parameters.

    Args:
        df: table with one row per sample.
        volume: SimpleITK image to resample.
        probe_params_column_name: column holding the probe-parameter pickle path.
        mount_point: directory that the pickle paths are joined to.
        transform: optional callable applied to the numpy slice.
        interpolator: SimpleITK interpolator used by ``__getitem__``; a falsy value
            keeps the ResampleImageFilter default.
    """

    def __init__(self, df, volume, probe_params_column_name="probe_params_fn", mount_point="./", transform=None, interpolator=sitk.sitkNearestNeighbor):
        self.df = df
        self.volume = volume
        self.probe_params_column_name = probe_params_column_name
        self.mount_point = mount_point
        self.transform = transform
        self.interpolator = interpolator

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        probe_params_fn = self.df.iloc[idx][self.probe_params_column_name]
        probe_params = self.read_probe_params(probe_params_fn)

        # Forward the constructor's interpolator; sample_image's own default would
        # otherwise silently override it.
        img = self.sample_image(probe_params, interpolator=self.interpolator)
        img_np = sitk.GetArrayFromImage(img)

        if self.transform:
            return self.transform(img_np)

        return img_np

    def read_probe_params(self, probe_params_fn):
        """Load the pickled probe-parameter dict stored at ``mount_point/probe_params_fn``.

        Security: pickle.load can execute arbitrary code; only use trusted files.
        """
        with open(os.path.join(self.mount_point, probe_params_fn), 'rb') as f:
            return pickle.load(f)

    def sample_image(self, probe_params, interpolator=sitk.sitkNearestNeighbor, identity_direction=True):
        """Resample ``self.volume`` on the reference geometry in ``probe_params``.

        Args:
            probe_params: dict with 'probe_direction' (array flattened into the
                direction), 'ref_size', 'ref_origin' and 'ref_spacing'.
            interpolator: SimpleITK interpolator; a falsy value keeps the filter default.
            identity_direction: if True, squeeze the result to numpy, rotate it by 90
                degrees over axes (0, 1), flip axis 0 and rebuild an image that carries
                only ``ref_spacing`` (default origin/direction).
        """
        probe_direction = probe_params['probe_direction']
        ref_size = probe_params['ref_size']
        ref_origin = probe_params['ref_origin']
        ref_spacing = probe_params['ref_spacing']

        ref = sitk.Image(int(ref_size[0]), int(ref_size[1]), int(ref_size[2]), sitk.sitkFloat32)
        ref.SetOrigin(ref_origin)
        ref.SetSpacing(ref_spacing)
        ref.SetDirection(probe_direction.flatten().tolist())

        resampler = sitk.ResampleImageFilter()
        if interpolator:
            resampler.SetInterpolator(interpolator)
        resampler.SetReferenceImage(ref)

        sample = resampler.Execute(self.volume)
        if identity_direction:
            sample_np = sitk.GetArrayFromImage(sample).squeeze()
            sample_np = np.flip(np.rot90(sample_np, k=1, axes=(0, 1)), axis=0)
            sample = sitk.GetImageFromArray(sample_np)
            sample.SetSpacing(ref_spacing)
        return sample

In [None]:
# Dataset re-slicing `volume` with each row's pickled probe geometry; relative pickle paths are joined to mount_point.
ds = VolumeSlicingProbeParamsDataset(probe_params_df, volume, mount_point="/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export")

In [None]:
# Pick one row of the probe-parameter dataset at random and display its slice.
random_idx = np.random.randint(len(ds))
sample_np = ds[random_idx]
px.imshow(sample_np.squeeze())

In [None]:
# Scratch check of Kaiming-uniform initialisation on a 3x5 weight (fan_in = 5).
w = nn.init.kaiming_uniform_(torch.empty(3, 5), mode='fan_in', nonlinearity='relu')
w

In [None]:
# Load sampled BPD label slice 20 and display it (singleton axes squeezed).
img3_np = sitk.GetArrayFromImage(sitk.ReadImage("/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_mesh_sampling/BPD/20.nrrd"))
# img3_np.shape
px.imshow(img3_np.squeeze())

In [None]:
# Render the BPD slice: t[None] prepends a batch dim (assumes img3_np is (1, H, W) so the
# input becomes (1, 1, H, W) — TODO confirm). fake_us[8] is forward's last output, the
# fan-grid intensity map before the inverse warp.
t = torch.tensor(img3_np.astype(int)).to(torch.long)
fake_us = us_render(t[None])
px.imshow(fake_us[8][0].detach().numpy().squeeze())

In [None]:
# Same render, first output: intensity warped back through the inverse grid.
px.imshow(fake_us[0][0].detach().numpy().squeeze())

In [None]:

# Load frame 96 from the split-frames export (presumably a real acquisition) for comparison.
img4_np = sitk.GetArrayFromImage(sitk.ReadImage("/mnt/raid/C1_ML_Analysis/source/blender/simulated_data_export/FAM-202-1960-2_20211019_033119_split_frames/BPD/96.nrrd"))
# img4_np.shape
px.imshow(img4_np.squeeze())

In [None]:
# Simulate a low-resolution version of the frame with MONAI (prob=1.0: always applied).
# NOTE(review): permute([2, 0, 1]) assumes img4_np is 3-D with the axis to move first
# last (e.g. (H, W, 3) -> (3, H, W)); unsqueeze adds a leading channel dim — confirm shapes.
lowres = RandSimulateLowResolution(prob=1.0, zoom_range=(0.15, 0.3))
img4_t = torch.tensor(img4_np).permute([2, 0, 1]).unsqueeze(dim=0)
img4_t_lowres = lowres(img4_t)
px.imshow(img4_t_lowres[0].permute(1, 2, 0).numpy().squeeze())

In [None]:
from monai.networks.blocks import (
    ResidualUnit,
    MLPBlock
)

In [None]:
# Scratch: MONAI MLPBlock, presumably hidden_size=256 and mlp_dim=128 — confirm against MONAI docs.
model_f = MLPBlock(256, 128)

In [None]:
# Shape check on a (10, 1, 256, 256) input; the last dim matches the 256 passed to MLPBlock.
model_f(torch.rand(10, 1, 256, 256)).shape