In [None]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
from skimage.filters import gaussian
from scipy.ndimage import zoom, center_of_mass, shift, affine_transform
from scipy import ndimage as ndi
import cv2
import SimpleITK as sitk
from pathlib import Path
from tqdm import tqdm
from shapely.geometry import MultiPoint, LineString
from shapely.ops import unary_union, polygonize
from scipy.spatial import Delaunay

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Make the pipeline repository importable: its root is taken to be two
# directories above the current working directory (normally the notebook's
# own folder when launched from Jupyter).
PIPELINE_ROOT = Path.cwd().parents[1].as_posix()
sys.path.append(PIPELINE_ROOT)
print(PIPELINE_ROOT)

# NOTE(review): hard-coded cluster path; not used in the cells visible here.
data_path = '/net/birdstore/Active_Atlas_Data/data_root/atlas_data'
from library.atlas.atlas_utilities import register_volume, adjust_volume, \
    affine_transform_point, resample_image, load_transformation, get_min_max_mean, \
    create_subvolume_from_boundary_vertices, interpolate_points, order_points_concave_hull

from library.controller.sql_controller import SqlController


In [None]:
def visualize_slices(volume, title="Slice View"):
    """Show the central slice of a 3D volume along each of its three axes.

    Parameters
    ----------
    volume : ndarray with 3 dimensions
        Volume to display. Panel k shows the slice taken at the middle index
        of axis k (``shape[k] // 2``).
    title : str
        Figure super-title.

    Notes
    -----
    The panel labels ("Axial", "Coronal", "Sagittal") are fixed names for
    axes 0, 1 and 2. They are only anatomically correct if the volume is
    stored in the matching orientation.
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    centers = [extent // 2 for extent in volume.shape]
    views = (
        ("Axial Slice", volume[centers[0], :, :]),
        ("Coronal Slice", volume[:, centers[1], :]),
        ("Sagittal Slice", volume[:, :, centers[2]]),
    )
    for ax, (label, plane) in zip(axes, views):
        ax.imshow(plane, cmap='gray')
        ax.set_title(label)
        ax.axis("on")
    plt.suptitle(title)
    plt.show()

In [None]:
def _scipy_affine_from_homogeneous(M4x4, input_to_output=True):
    """
    Convert a 4x4 homogeneous transform into ``(matrix, offset)`` for
    ``scipy.ndimage.affine_transform``.

    ``scipy.ndimage.affine_transform`` *pulls* values. For every output
    coordinate ``x_out`` it samples the input at ``matrix @ x_out + offset``,
    so it needs the output -> input mapping.

    Write ``M4x4 = [[A, t], [0, 0, 0, 1]]``.

    * ``input_to_output=True`` (default): ``M4x4`` maps input -> output,
      ``x_out = A @ x_in + t``. SciPy then needs the inverse
      ``x_in = A_inv @ x_out + (-A_inv @ t)``.
    * ``input_to_output=False``: ``M4x4`` already maps output -> input
      (the convention SimpleITK resampling transforms use), so ``(A, t)`` is
      returned unchanged.

    Parameters
    ----------
    M4x4 : array-like, shape (4, 4)
        Homogeneous affine matrix.
    input_to_output : bool
        Direction of ``M4x4`` as described above.

    Returns
    -------
    matrix : ndarray, shape (3, 3)
    offset : ndarray, shape (3,)

    Raises
    ------
    ValueError
        If ``M4x4`` is not 4x4.
    numpy.linalg.LinAlgError
        If inversion is required and ``A`` is singular.
    """
    M = np.asarray(M4x4, dtype=float)
    if M.shape != (4, 4):
        raise ValueError("Affine must be 4x4 homogeneous matrix")

    A = M[:3, :3]
    t = M[:3, 3]
    if not input_to_output:
        # Already output -> input: SciPy can use it as-is.
        return A, t

    A_inv = np.linalg.inv(A)
    offset = -A_inv @ t
    return A_inv, offset


def apply_affine_3d(
    vol_zyx,
    M4x4,
    output_shape=None,
    order=0,        # nearest neighbour by default so 0/255 label values stay crisp
    mode="constant",
    cval=0.0,
    prefilter=False
):
    """
    Resample a 3D volume through a 4x4 homogeneous affine using
    ``scipy.ndimage.affine_transform``.

    ``M4x4`` is turned into SciPy's ``(matrix, offset)`` pair by
    ``_scipy_affine_from_homogeneous``; see that helper for the direction the
    matrix is assumed to map.

    Parameters
    ----------
    vol_zyx : ndarray, shape (Z, Y, X)
        Volume to resample (e.g. a 0/255 label mask).
    M4x4 : array-like, shape (4, 4)
        Homogeneous affine laid out as ``[[R, t], [0, 0, 0, 1]]``.
    output_shape : tuple or None
        Shape (Z, Y, X) of the result; ``None`` reuses the input's shape.
    order : int
        Interpolation order 0..5 forwarded to SciPy. Use 0 for label volumes.
    mode : str
        SciPy boundary mode for samples falling outside the input.
    cval : float
        Fill value used when ``mode="constant"``.
    prefilter : bool
        Forwarded to SciPy. Keeping it False with ``order=0`` avoids needless
        spline prefiltering.

    Returns
    -------
    ndarray
        Resampled volume of shape ``output_shape`` with ``vol_zyx``'s dtype.
    """
    target_shape = vol_zyx.shape if output_shape is None else output_shape
    linear, shift_vec = _scipy_affine_from_homogeneous(M4x4)
    resampled = ndi.affine_transform(
        vol_zyx,
        matrix=linear,
        offset=shift_vec,
        output_shape=target_shape,
        order=order,
        mode=mode,
        cval=cval,
        prefilter=prefilter,
    )
    # Defensive cast back to the caller's dtype.
    return resampled.astype(vol_zyx.dtype)

def get_3d_scale_from_affine(matrix):
    """Return the per-axis scale factors encoded in an affine matrix.

    Each factor is the Euclidean length of one column of the upper-left 3x3
    (rotation / scale / shear) block. It says how much a unit vector along
    input axis 0, 1 or 2 is stretched by the transform.

    Parameters
    ----------
    matrix : ndarray
        Affine matrix at least 3x3 (typically 4x4 homogeneous); only
        ``matrix[:3, :3]`` is read.

    Returns
    -------
    tuple
        ``(sx, sy, sz)``: the norms of columns 0, 1 and 2.
    """
    linear_part = matrix[:3, :3]
    return tuple(np.linalg.norm(linear_part[:, col]) for col in range(3))

In [None]:
# Configuration: which brain to process, its DB controller and its registration transform.
animal = 'ALLEN771602'
sqlController = SqlController(animal)
# NOTE(review): the same transform is loaded again (via `um = 10.0`) in the
# matrix-building cell further down. The two 10.0 arguments are presumably
# resolutions in um; confirm against load_transformation's signature.
transform = load_transformation(animal, 10.0, 10.0, inverse=False)

In [None]:
# Load the annotation polygons for one structure. `polygons` behaves like a
# dict (its values() and keys() are used), and later cells treat its keys as
# section (z) indices.
# NOTE(review): 8357 and 10 are unexplained magic numbers (annotation or
# structure id, and resolution?). TODO: name them in a config cell.
polygons = sqlController.get_annotation_volume(8357, 10)
coords = list(polygons.values())
# Per-axis statistics over all polygon points. The next cell uses index 0 as x
# and index 1 as y.
min_vals, max_vals, mean_vals = get_min_max_mean(coords)

In [None]:
# First and last annotated section indices (the dict keys).
min_z, max_z = min(polygons), max(polygons)

In [None]:
# In-plane bounding box of all annotation points (index 0 = x, index 1 = y).
min_x, min_y = min_vals[0], min_vals[1]
max_x, max_y = max_vals[0], max_vals[1]
xlength = max_x - min_x
ylength = max_y - min_y
# Canvas is (rows=y, cols=x). After shifting by (min_x, min_y), the points span
# [0, length] inclusive, so each axis needs length + 1 pixels. Otherwise the
# outermost row and column are clipped by cv2.fillPoly.
# (Assumes min_vals/max_vals are the per-axis extremes of the points.)
slice_size = (int(round(ylength)) + 1, int(round(xlength)) + 1)
print(f'slice size={slice_size}')
print(f'{min_x=} {max_x=}')
print(f'{min_y=} {max_y=}')
print(f'{min_z=} {max_z=}')

In [None]:
# Rasterise each section's polygon into a 2D uint8 mask and stack the masks
# into a (z, y, x) volume covering every section from min_z to max_z inclusive.
# Sections without an annotation reuse the most recently seen polygon.
slices = []
points_dict = {}   # slice position i -> int32 polygon drawn for that annotated section
last_points = None
for i, idx in enumerate(range(min_z, max_z + 1)):  # +1: max_z is itself an annotated section
    volume_slice = np.zeros(slice_size, dtype=np.uint8)
    if idx in polygons:
        # Shift into the cropped canvas coordinates.
        points = np.array(polygons[idx]) - np.array((min_x, min_y))
        points = order_points_concave_hull(points, alpha=0.5)
        points = interpolate_points(points, 250)
        points = np.array(points).astype(np.int32)  # cv2.fillPoly needs int32 vertices
        points_dict[i] = points
        last_points = points
    if last_points is not None:
        cv2.fillPoly(volume_slice, pts=[last_points], color=255)
    slices.append(volume_slice)

print(f'len slices={len(slices)}')
# Keep this at uint8: skimage's gaussian converts integer input to floats in [0, 1].
volume = np.stack(slices, axis=0).astype(np.uint8)

upper = 0
allen_id = 255
# Blur, then binarise. Every voxel with any positive blurred value becomes
# allen_id, so the mask is smoothed and slightly dilated; all others become 0.
volume = gaussian(volume, 1.0)
volume = np.where(volume > upper, allen_id, 0).astype(np.uint32)

print('shape', volume.shape)
print(f'length keys {len(polygons.keys())} z length {volume.shape[0]}')
visualize_slices(volume, title="Original Volume")
# @10um section min=297, max=1187

In [None]:
# Rebuild the registration transform as a 4x4 homogeneous matrix
# M = [[R, t], [0, 0, 0, 1]], then derive its scales and the inverse offset.
animal = 'ALLEN771602'
um = 10.0
transform = load_transformation(animal, um, um, inverse=False)
params = np.asarray(transform.GetParameters(), dtype=float)
R = params[0:9].reshape(3, 3)   # row-major 3x3 linear part
t_raw = params[9:]              # translation exactly as stored in the transform
print(R)
print()
print('t', t_raw)
print()
t = t_raw.reshape(3, 1)
# ITK-style affines apply T(x) = R @ (x - c) + t + c, with the centre c held in
# the fixed parameters. Fold c into t so that M reproduces T.
# NOTE(review): assumes `transform` is an (Sim)ITK AffineTransform whose fixed
# parameters are its 3D centre; confirm for load_transformation's output.
center = np.asarray(transform.GetFixedParameters(), dtype=float)
if center.size == 3:
    t = t + (center - R @ center).reshape(3, 1)
M = np.vstack([np.hstack([R, t]), [0, 0, 0, 1]])
print(M)
scale_x, scale_y, scale_z = get_3d_scale_from_affine(M)
print(scale_x, scale_y, scale_z)
# Offset for pulling with R_inv: x_in = R_inv @ x_out - R_inv @ t.
R_inv = np.linalg.inv(R)
offset = (-R_inv @ t).reshape(3,)

print(offset)
print(volume.shape)

In [None]:
# Rebuild the loaded registration as a SimpleITK AffineTransform for resampling.
# NOTE: this rebinds `affine_transform`, shadowing the
# scipy.ndimage.affine_transform function imported at the top of the notebook.
# Call SciPy's version as `ndi.affine_transform` after this cell.
affine_transform = sitk.AffineTransform(3)  # identity 3D affine, filled in below

# Parameter layout: 9 matrix entries (row-major 3x3), then a 3-vector translation.
matrix = transform.GetParameters()[0:9]
translation = transform.GetParameters()[9:]
affine_transform.SetMatrix(matrix)
affine_transform.SetTranslation(translation)

# ITK affines act about a centre c, T(x) = A(x - c) + t + c, which is stored in
# the fixed parameters. Copy it so a non-zero centre is not silently dropped.
# NOTE(review): assumes `transform` is AffineTransform-like; confirm.
transform_center = transform.GetFixedParameters()
if len(transform_center) == 3:
    affine_transform.SetCenter(transform_center)

In [None]:
# Resample the label mask through the SimpleITK affine built in the previous cell.
resampler = sitk.ResampleImageFilter()
resampler.SetTransform(affine_transform)
# swapaxes turns the (z, y, x) stack into an (x, y, z) array. GetImageFromArray
# reads numpy axes in reverse order, so SimpleITK's first index runs along the
# original z.
# NOTE(review): confirm this matches the frame the transform was fitted in.
# NOTE(review): no origin/spacing is set (SimpleITK defaults: origin 0,
# spacing 1), even though the volume was cropped at (min_x, min_y, min_z).
# Check whether the transform expects full-volume coordinates.
swapped = np.swapaxes(volume.copy(), 0, 2)
image = sitk.GetImageFromArray(swapped.astype(np.float32))
# The reference image fixes the output grid (size, spacing, origin, direction);
# here it is the input grid itself.
resampler.SetReferenceImage(image)
# Nearest neighbour keeps the 0/255 label values intact, matching the order=0
# policy of apply_affine_3d. Linear interpolation would smear the mask edges
# into intermediate values.
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0.0)  # fill for samples outside the input
resampled = resampler.Execute(image)
# Same axis order as `swapped`, so the panel labels below refer to swapped axes.
vol_affine = sitk.GetArrayFromImage(resampled)
visualize_slices(vol_affine, title="Affine Volume")

In [None]:
# NOTE(review): hard-coded reference point. The factor of 10 presumably
# converts 10 um voxel coordinates to um; TODO document where these
# coordinates come from.
sc_origin = np.array((9.31496e+02, 2.40000902e+02, 2.7000e+02)) * 10
sc_origin[0]