In [None]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
from skimage.filters import gaussian
from scipy.ndimage import zoom, center_of_mass, shift, affine_transform
import cv2

import ipywidgets
from pathlib import Path
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings("ignore")
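# Other backend names seen for pv.set_jupyter_backend (availability depends on the installed pyvista version): "ipyvtklink", "panel", "ipygany", "static", "pythreejs", "client", "server", "trame", "none"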
pv.set_jupyter_backend("panel")

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Put the directory two levels above the notebook's working directory on
# sys.path so that the `library.*` import below resolves.
PIPELINE_ROOT = Path('./').absolute().parents[1].as_posix()
# Guard the append: re-running this cell must not pile up duplicate entries.
if PIPELINE_ROOT not in sys.path:
    sys.path.append(PIPELINE_ROOT)
print(PIPELINE_ROOT)

# Root of the atlas data. Set the ATLAS_DATA_PATH environment variable to use
# a different location; otherwise the original hardcoded path is used.
data_path = os.environ.get('ATLAS_DATA_PATH', '/net/birdstore/Active_Atlas_Data/data_root/atlas_data')
from library.atlas.atlas_utilities import register_volume, adjust_volume, affine_transform_volume, apply_affine_transform

In [None]:
def filter_top_n_values(volume: np.ndarray, n: int, set_value: int = 1) -> np.ndarray:
    """
    Binarize an array, keeping only its `n` most frequent nonzero values.

    Voxels equal to 0 are ignored when ranking frequencies. Values with equal
    counts keep the ascending order produced by ``np.unique``. The selected
    values are printed as a side effect.

    Parameters:
        volume (np.ndarray): Input array (documented for 3D volumes; any shape works).
        n (int): How many of the most frequent nonzero values to keep.
        set_value (int, optional): Output value for kept voxels. Defaults to 1.

    Returns:
        np.ndarray: Array shaped like `volume` holding `set_value` where the
        voxel matched one of the kept values and 0 everywhere else.
    """
    foreground = volume[volume != 0]
    labels, frequencies = np.unique(foreground, return_counts=True)
    ranking = Counter(dict(zip(labels, frequencies)))

    top_n_values = []
    for label, _frequency in ranking.most_common(n):
        top_n_values.append(label)
    # The variable name is part of the printed text (f-string `=` specifier).
    print(f'top {n} {top_n_values=}')

    keep = np.isin(volume, top_n_values)
    return np.where(keep, set_value, 0)


def center_3d_volume(volume: np.ndarray, order: int = 3) -> np.ndarray:
    """
    Centers a 3D volume by shifting its center of mass to the geometric center.

    The center of mass is intensity-weighted (``scipy.ndimage.center_of_mass``).
    The geometric center of an axis of length ``s`` is ``(s - 1) / 2``.

    Parameters:
    volume (np.ndarray): A 3D numpy array representing the volume.
    order (int, optional): Spline order passed to ``scipy.ndimage.shift``.
        Defaults to 3, the previous behaviour. Use 0 for label or mask
        volumes so no interpolated in-between values are introduced.

    Returns:
    np.ndarray: The centered volume, with the same shape and dtype as the input.
        A volume whose voxels sum to 0 has no defined center of mass and is
        returned as an unshifted copy.

    Raises:
    ValueError: If `volume` is not 3D.
    """
    if volume.ndim != 3:
        raise ValueError("Input volume must be a 3D numpy array")

    # center_of_mass divides by the total mass; a zero total yields NaN
    # coordinates and therefore a meaningless shift.
    if volume.sum() == 0:
        return volume.copy()

    com = np.array(center_of_mass(volume))
    geometric_center = (np.array(volume.shape) - 1) / 2
    shift_values = geometric_center - com

    return shift(volume, shift_values, order=order, mode='constant', cval=0)

def crop_nonzero_3d(volume):
    """
    Crops a 3D volume to the bounding box of its nonzero voxels.

    Parameters:
        volume (numpy.ndarray): A 3D NumPy array.

    Returns:
        numpy.ndarray: A view of `volume` restricted to the smallest box
        that contains every nonzero voxel.

    Raises:
        ValueError: If `volume` is not 3D, or if it has no nonzero voxels
            (there is no bounding box to crop to).
    """
    if volume.ndim != 3:
        raise ValueError("Input volume must be a 3D NumPy array")

    nonzero_coords = np.argwhere(volume)
    if nonzero_coords.size == 0:
        # Without this check .min() below fails with an opaque
        # "zero-size array to reduction operation" ValueError.
        raise ValueError("Input volume has no nonzero voxels; nothing to crop to")

    min_coords = nonzero_coords.min(axis=0)
    max_coords = nonzero_coords.max(axis=0) + 1  # exclusive upper bound

    return volume[tuple(slice(lo, hi) for lo, hi in zip(min_coords, max_coords))]

def normalize16(img):
    """
    Rescale an image to the full unsigned 16-bit range [0, 65535].

    uint32 input is not rescaled. It is cast directly to uint16 after a
    notice is printed, and that cast wraps values above 65535. NOTE(review):
    this presumably assumes the data already fits in 16 bits; confirm.

    Parameters:
        img (np.ndarray): Input image of any numeric dtype.

    Returns:
        np.ndarray: uint16 array of the same shape. The minimum maps to 0 and
        the maximum to 65535; a constant image maps to all zeros.
    """
    if img.dtype == np.uint32:
        print('image dtype is 32bit')
        return img.astype(np.uint16)

    # Work in float64 so ``img - mn`` cannot overflow narrow signed ints
    # (e.g. int8: 127 - (-128) does not fit in int8).
    img = np.asarray(img, dtype=np.float64)
    mn = img.min()
    value_range = img.max() - mn
    if value_range == 0:
        # Constant image: avoid 0/0 and map everything to 0.
        return np.zeros(img.shape, dtype=np.uint16)
    # Parenthesize the scale factor: ``x * 2**16 - 1`` would send the minimum
    # to -1, which wraps when cast to uint16.
    scaled = (img - mn) / value_range * (2**16 - 1)
    return np.round(scaled).astype(np.uint16)


def apply_affine_transformation(volume, matrix, translation):
    """
    Resample `volume` under an affine map using linear (order=1) interpolation.

    Thin wrapper over ``scipy.ndimage.affine_transform``. `matrix` and
    `translation` are forwarded as its ``matrix`` and ``offset`` arguments.
    They therefore map output coordinates to input coordinates (a "pull"
    mapping), not input to output.
    """
    return affine_transform(volume, matrix, offset=translation, order=1)

def visualize_slices(volume, title="Slice View"):
    """
    Show the three central orthogonal slices of a 3D volume side by side.

    Panel i shows the slice through the middle index (``size // 2``) of
    axis i. The Axial/Coronal/Sagittal labels assume axes 0/1/2 correspond
    to those directions. NOTE(review): confirm against the atlas axis
    convention.
    """
    _, axes = plt.subplots(1, 3, figsize=(12, 4))
    labels = ("Axial Slice", "Coronal Slice", "Sagittal Slice")

    for axis, (ax, label) in enumerate(zip(axes, labels)):
        # np.take at the midpoint equals volume[mid, :, :] / [:, mid, :] / [:, :, mid].
        middle = np.take(volume, volume.shape[axis] // 2, axis=axis)
        ax.imshow(middle, cmap='gray')
        ax.set_title(label)
        ax.axis("on")

    plt.suptitle(title)
    plt.show()



In [None]:
# Load one structure's volume from the AtlasV8 structure directory.
structure = 'SC'
structure_dir = os.path.join(data_path, 'AtlasV8', 'structure')
structure_path = os.path.join(structure_dir, '{}.npy'.format(structure))
arr = np.load(structure_path)
print(arr.shape)

In [None]:
# Middle index along axis 2; the next cell reuses `z` to show the same plane.
z = arr.shape[2] // 2
plt.imshow(arr[:, :, z], cmap='gray')
plt.title(f'Raw structure volume, slice {z} along axis 2')
plt.show()

In [None]:
# Adjust a copy of the raw volume with the project helper. The copy keeps
# `arr` intact in case adjust_volume modifies its input (not verified here).
# NOTE(review): 255 is presumably the target maximum intensity; confirm.
volume = adjust_volume(arr.copy(), 255)
plt.imshow(volume[:, :, z], cmap="gray")

In [None]:
# Interactive volume render of the adjusted `volume` from the previous cell.
# The cell previously wrapped `v2`, which is not defined anywhere in this
# notebook and raised NameError on a fresh kernel.
data = pv.wrap(volume)
data.plot(volume=True)  # Volume render

In [None]:
# 4x4 homogeneous affine: a 3x3 linear part plus a translation column, with
# [0, 0, 0, 1] as the last row.
# TODO: record where these coefficients come from (presumably a
# registration run) and the units of the translation column.
transformation_matrix = np.array([
    [ 1.00020121e+00, -7.85455904e-02, -1.64689004e-02, -1.46273691e+03],
    [ 1.81477262e-01,  1.14218291e+00,  7.63646430e-02, -4.31843852e+03],
    [-3.29134061e-02, -7.44862557e-02,  1.07530126e+00,  1.31840353e+03],
    [ 0.0,             0.0,             0.0,             1.0           ],
])
print(transformation_matrix)

In [None]:
transformed_volume = affine_transform_volume(volume, transformation_matrix)

# Compare central slices before and after the transform.
# NOTE(review): the translation column is ~1e3 in magnitude. If
# affine_transform_volume treats it as voxels, the result may be mostly empty
# for a small structure volume; check the slices and shapes below.
for vol, label in ((volume, "Original Volume"), (transformed_volume, "Transformed Volume")):
    visualize_slices(vol, title=label)


In [None]:
# Shapes before and after the affine transform, one per line.
print(volume.shape, transformed_volume.shape, sep='\n')