In [None]:
import os
import sys
import warnings
from collections import Counter
from pathlib import Path

import cv2
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
import SimpleITK as sitk
from scipy.ndimage import zoom, center_of_mass, shift, affine_transform
from scipy.spatial import procrustes
from scipy.spatial.transform import Rotation as R
from skimage import measure
from skimage.filters import gaussian
from tqdm import tqdm
warnings.filterwarnings("ignore")
#"ipyvtklink", "panel", "ipygany", "static", "pythreejs", "client", "server", "trame", "none"
pv.set_jupyter_backend("panel")

In [None]:
# Reload edited modules automatically before each cell runs, so changes to the
# project's `library` helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Project root = grandparent of the current working directory; make it importable.
PIPELINE_ROOT = Path('./').absolute().parents[1].as_posix()
if PIPELINE_ROOT not in sys.path:  # re-running the cell must not stack duplicate entries
    sys.path.append(PIPELINE_ROOT)
print(PIPELINE_ROOT)

# Root of the per-brain data folders (<data_path>/<brain>/structure/<name>.npy).
# Set the ATLAS_DATA_PATH environment variable to use a different location.
data_path = os.environ.get('ATLAS_DATA_PATH', '/net/birdstore/Active_Atlas_Data/data_root/atlas_data')
from library.atlas.atlas_utilities import (register_volume, adjust_volume, affine_transform_volume,
                                           affine_transform_point)

In [None]:
def visualize_slices(volume, title="Slice View"):
    """Visualize the middle slices of a 3D volume."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    mid_slices = [s // 2 for s in volume.shape]
    
    axes[0].imshow(volume[mid_slices[0], :, :], cmap='gray')
    axes[0].set_title("Axial Slice")
    axes[1].imshow(volume[:, mid_slices[1], :], cmap='gray')
    axes[1].set_title("Coronal Slice")
    axes[2].imshow(volume[:, :, mid_slices[2]], cmap='gray')
    axes[2].set_title("Sagittal Slice")
    
    for ax in axes:
        ax.axis("on")
    plt.suptitle(title)
    plt.show()

def get_clockwise_edge_coords(array):
    """
    Return integer (x, y) points along the longest iso-contour of a mask.

    The input is binarized (non-zero -> 1) and traced with
    ``skimage.measure.find_contours`` at level 0.5; if several contours are
    found, the one with the most points is kept.

    Orientation: points are ordered so that the sum over edges of
    (x2 - x1) * (y2 + y1) is positive, i.e. clockwise in x-right / y-up axes.
    Because image rows grow downward, this order appears counter-clockwise when
    the array is displayed with ``plt.imshow`` (origin at top-left).
    NOTE(review): confirm which on-screen orientation callers expect.

    Parameters:
    - array: 2D array; any value > 0 counts as foreground.

    Returns:
    - (N, 2) int array of (x, y) = (column, row) points, or ``[]`` when no
      contour is found.
    """
    binary = (np.asarray(array) > 0).astype(np.uint8)

    contours = measure.find_contours(binary, 0.5)
    if not contours:
        return []

    # find_contours yields (row, col) float points; keep the longest contour
    # and flip to (x, y) = (col, row).
    xy = np.fliplr(max(contours, key=len))

    # Orientation from the float points, summing over EVERY edge including the
    # wrap-around edge last -> first. For closed contours (first == last) that
    # edge has zero length; for open contours touching the border it is needed
    # for a correct sign.
    nxt = np.roll(xy, -1, axis=0)
    edge_sum = np.sum((nxt[:, 0] - xy[:, 0]) * (nxt[:, 1] + xy[:, 1]))
    if edge_sum <= 0:
        xy = xy[::-1]  # reverse so the sum above becomes positive

    # Truncate to integer pixel coordinates (values are non-negative, so this
    # is a floor, not a rounding).
    return xy.astype(int)


def get_evenly_spaced_vertices(mask, num_points):
    """
    Return vertices spaced evenly by arc length around a mask's outer contour.

    Parameters:
    - mask: 2D array; any value > 0 is treated as foreground.
    - num_points: Number of vertices to return.

    Returns:
    - List of (x, y) integer tuples (coordinates truncated toward zero),
      starting at the first point of the contour OpenCV returns. Empty list
      when the mask has no foreground.
    """
    # Binarize before casting: a bare astype(np.uint8) would wrap values >= 256
    # to 0 and truncate fractional foreground values (e.g. 0.5) to 0.
    binary = (np.asarray(mask) > 0).astype(np.uint8)

    # OpenCV 4 returns (contours, hierarchy) and OpenCV 3 returns
    # (image, contours, hierarchy); the contours are second to last in both.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
    if not contours:
        return []

    # Largest contour by enclosed area. reshape (not squeeze) keeps a
    # single-point contour as shape (1, 2) instead of collapsing it to (2,).
    contour = max(contours, key=cv2.contourArea).reshape(-1, 2).astype(np.float64)

    # OpenCV contours are closed polygons without a repeated end point; append
    # the first point so the closing segment counts toward the perimeter.
    closed = np.vstack([contour, contour[:1]])
    segment_lengths = np.linalg.norm(np.diff(closed, axis=0), axis=1)
    cumulative = np.concatenate(([0.0], np.cumsum(segment_lengths)))
    perimeter = cumulative[-1]

    targets = np.linspace(0.0, perimeter, num_points, endpoint=False)
    if perimeter == 0:
        # Degenerate single-pixel contour: every sample is that pixel.
        points = np.repeat(contour[:1], len(targets), axis=0)
    else:
        # Linear interpolation of x and y as functions of arc length.
        points = np.column_stack([
            np.interp(targets, cumulative, closed[:, 0]),
            np.interp(targets, cumulative, closed[:, 1]),
        ])

    return [tuple(pt) for pt in points.astype(int)]



In [None]:
def resample_image(image, reference_image):
    """
    Resample ``image`` onto the voxel grid of ``reference_image``.

    The output takes the reference's size, origin, spacing and direction
    (``SetReferenceImage``), is computed with linear interpolation, and voxels
    falling outside ``image`` are filled with 0.

    Returns:
        sitk.Image: The resampled image.
    """
    filt = sitk.ResampleImageFilter()
    filt.SetReferenceImage(reference_image)
    filt.SetInterpolator(sitk.sitkLinear)
    filt.SetDefaultPixelValue(0)
    return filt.Execute(image)

def center_images_to_largest_volume(images):
    """
    Align a list of 3D SimpleITK images to the one with the largest physical extent.

    The reference is the image maximising
    size[0]*size[1]*size[2]*spacing[0]*spacing[1]*spacing[2] (the first one on
    ties). Every other image gets an Euler3D transform from
    ``CenteredTransformInitializer`` in MOMENTS mode (centre-of-mass alignment)
    and is resampled onto the reference grid with linear interpolation, a 0.0
    fill value and its own pixel type.

    Parameters:
        images (List[sitk.Image]): 3D images.

    Returns:
        List[sitk.Image]: Images in input order; the reference is returned as-is.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images provided.")

    extents = []
    for image in images:
        size, spacing = image.GetSize(), image.GetSpacing()
        extents.append(size[0] * size[1] * size[2] * spacing[0] * spacing[1] * spacing[2])
    ref_idx = max(range(len(extents)), key=extents.__getitem__)
    reference = images[ref_idx]

    def _align_to_reference(moving):
        # Transform mapping reference-space points to `moving`, from image moments.
        tx = sitk.CenteredTransformInitializer(reference, moving,
                                               sitk.Euler3DTransform(),
                                               sitk.CenteredTransformInitializerFilter.MOMENTS)
        return sitk.Resample(moving, reference, tx, sitk.sitkLinear, 0.0, moving.GetPixelID())

    return [img if idx == ref_idx else _align_to_reference(img)
            for idx, img in enumerate(images)]


def load_volumes(structure, brains=('MD585', 'MD589', 'MD594')):
    """
    Load one structure from several brains and bring the volumes onto a common grid.

    Each ``<data_path>/<brain>/structure/<structure>.npy`` array is cast to
    float32 and wrapped as a SimpleITK image. All images are resampled onto the
    grid of the image with the most voxels (``resample_image``) and then centered
    with ``center_images_to_largest_volume``.

    Parameters:
        structure (str): Structure name, used as the ``.npy`` file stem.
        brains (iterable of str): Brain IDs (sub-directories of ``data_path``).
            Defaults to the three brains this function previously hard-coded.

    Returns:
        tuple: (list of resampled, centered sitk.Image in ``brains`` order,
        the original — not resampled — input image with the most voxels).

    Raises:
        ValueError: If ``brains`` is empty.
    """
    images = []
    for brain in brains:
        structure_path = os.path.join(data_path, brain, 'structure', f'{structure}.npy')
        arr = np.load(structure_path).astype(np.float32)
        print(arr.dtype, arr.shape)
        images.append(sitk.GetImageFromArray(arr, isVector=False))

    if not images:
        raise ValueError("No brains provided.")

    reference_image = max(images, key=lambda img: np.prod(img.GetSize()))
    resampled_images = [resample_image(img, reference_image) for img in images]
    resampled_images = center_images_to_largest_volume(resampled_images)

    return resampled_images, reference_image


    
def _mean_image(images):
    """Voxel-wise mean of same-shaped SimpleITK images, keeping the first image's geometry."""
    mean_img = sitk.GetImageFromArray(np.mean([sitk.GetArrayFromImage(img) for img in images], axis=0))
    # GetImageFromArray resets origin/spacing/direction to defaults; restore them.
    mean_img.CopyInformation(images[0])
    return mean_img


def _rigid_register(fixed, moving, seed):
    """Multi-resolution Mattes-MI rigid (Euler3D) registration; returns a fixed->moving transform."""
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100,
                                                      convergenceMinimumValue=1e-6, convergenceWindowSize=10)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetMetricSamplingPercentage(0.1, seed)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)

    initial_tx = sitk.CenteredTransformInitializer(fixed, moving, sitk.Euler3DTransform(),
                                                   sitk.CenteredTransformInitializerFilter.GEOMETRY)
    registration_method.SetInitialTransform(initial_tx, inPlace=False)
    registration_method.SetShrinkFactorsPerLevel([4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel([2, 1, 0])
    return registration_method.Execute(fixed, moving)


def build_unbiased_atlas(images, num_iterations=5, sampling_seed=None):
    """
    Groupwise rigid registration of ``images`` towards an iteratively refined mean.

    Each iteration:
      1. rigidly registers every input image to the current mean image;
      2. averages the resulting transforms (``average_rigid_transforms``);
      3. composes each transform with the inverse of that average, so the
         group is not biased towards the current mean's pose;
      4. resamples every input image onto the mean image's grid and recomputes
         the mean image from those registered images.

    Parameters:
        images (list of sitk.Image): 3D images to register. Presumably all on
            the same grid with a float pixel type (as produced by
            ``load_volumes``) — the mean is a voxel-wise numpy average, so the
            arrays must have identical shapes.
        num_iterations (int): Number of register/average iterations.
        sampling_seed (int or None): Seed for the metric's random sampling.
            ``None`` keeps wall-clock seeding (non-reproducible); pass an int
            for reproducible results.

    Returns:
        list of sitk.Image: The images resampled into atlas space by the last
        iteration, in input order. With ``num_iterations=0`` the inputs are
        returned unchanged (as a new list).
    """
    seed = sitk.sitkWallClock if sampling_seed is None else sampling_seed

    mean_image = _mean_image(images)
    registered_images = list(images)

    for iteration in range(num_iterations):
        print(f"Iteration {iteration + 1}/{num_iterations}")

        # fixed = current mean, moving = input image: each transform maps
        # mean-space points to image-space points.
        transforms = [_rigid_register(mean_image, img, seed) for img in images]

        # Remove the group's average pose: T_i' = T_i o inverse(T_avg).
        inv_avg_transform = average_rigid_transforms(transforms).GetInverse()

        registered_images = []
        for img, tx in zip(images, transforms):
            # SimpleITK composite transforms apply the LAST added transform
            # first, so this evaluates tx(inv_avg_transform(x)).
            total_tx = sitk.CompositeTransform(3)
            total_tx.AddTransform(tx)
            total_tx.AddTransform(inv_avg_transform)
            registered_images.append(
                sitk.Resample(img, mean_image, total_tx, sitk.sitkLinear, 0.0, img.GetPixelID()))

        # New mean in the de-biased space (same grid as before).
        mean_image = _mean_image(registered_images)

    return registered_images


def average_rigid_transforms(transforms):
    """
    Compute the average rigid transformation from a list of Euler3D transforms.

    The rotation angles (about X, Y, Z, in radians), translations and centers
    are each averaged component-wise. This is only a reasonable approximation
    when the rotations are small and similar: Euler angles are not averaged on
    the rotation group, and wrap-around at +/-pi is not handled.

    Parameters:
        transforms (list of sitk.Transform): Transforms accepted by
            ``sitk.Euler3DTransform(tx)``. Presumably registration results
            initialised with an Euler3DTransform — TODO confirm the
            SimpleITK version returns convertible transforms.

    Returns:
        sitk.Euler3DTransform: Averaged rigid transform.

    Raises:
        ValueError: If ``transforms`` is empty (previously this silently
            produced NaN parameters).
    """
    if not transforms:
        raise ValueError("average_rigid_transforms() needs at least one transform.")

    eulers = [sitk.Euler3DTransform(tx) for tx in transforms]
    # Euler3DTransform exposes its rotation through GetAngleX/Y/Z.
    angles = np.mean([(e.GetAngleX(), e.GetAngleY(), e.GetAngleZ()) for e in eulers], axis=0)
    translation = np.mean([e.GetTranslation() for e in eulers], axis=0)
    center = np.mean([e.GetCenter() for e in eulers], axis=0)

    avg_tx = sitk.Euler3DTransform()
    avg_tx.SetRotation(*(float(a) for a in angles))
    avg_tx.SetTranslation(translation.tolist())
    avg_tx.SetCenter(center.tolist())

    return avg_tx

In [None]:
# Structure whose per-brain volumes are loaded ('PBG_L' is another option used here before).
structure = 'SC'
# NOTE: the second value is load_volumes' reference image (largest input), not an average.
aligned_volumes, final_average = load_volumes(structure)

In [None]:
# Groupwise rigid registration: one registration per volume per iteration (slow).
registered = build_unbiased_atlas(images=aligned_volumes, num_iterations=5)

In [None]:
# NOTE(review): this averages `aligned_volumes` (centered, not rigidly registered);
# the output of build_unbiased_atlas (`registered`) is not used — confirm that is intended.
avg_volume = np.mean(
    [sitk.GetArrayFromImage(vol) for vol in aligned_volumes],
    axis=0,
)
print(avg_volume.dtype, avg_volume.shape)

In [None]:
arr = avg_volume.copy()
volume = arr.copy()
# Middle slice along the last axis of the averaged volume, without intensity adjustment.
z = volume.shape[2] // 2
slice = volume[:, :, z]
plt.imshow(slice, cmap='gray')

In [None]:
brains = ['MD585', 'MD589', 'MD594']
volumes = []
for brain in brains:
    # Raw structure volume for this brain and its centre of mass.
    structure_path = os.path.join(data_path, brain, 'structure', f'{structure}.npy')
    arr = np.load(structure_path)
    com = center_of_mass(arr)
    print(f'{brain} {arr.dtype} {arr.shape} {com}')
    volumes.append(arr)

# Middle slice (last axis) of each brain side by side.
# NOTE(review): astype(np.uint8) truncates/wraps values — confirm the stored value range.
_, axs = plt.subplots(1, 3, figsize=(12, 12), sharex=True, sharey=True)
axs = axs.flatten()
for ax, img in zip(axs, volumes):
    z = img.shape[2] // 2
    slice = img[:, :, z].astype(np.uint8)
    ax.imshow(slice, cmap="gray")
plt.show()

In [None]:
animal = 'AtlasV8'
structure_path = os.path.join(data_path, animal, 'structure', f'{structure}.npy')
arr = np.load(structure_path)

# List the distinct values only when there are few of them (label-like data).
ids, counts = np.unique(arr, return_counts=True)
if ids.size < 10:
    print(ids)
    print(counts)

print(arr.shape, arr.dtype, arr.mean(), arr.min(), arr.max())

In [None]:
# Raw (unadjusted) middle slice along the last axis.
volume = arr.copy()
z = volume.shape[2] // 2
slice = volume[:, :, z]
plt.imshow(slice, cmap='gray')

In [None]:
# Same middle slice after intensity adjustment with the project helper adjust_volume.
z = arr.shape[2] // 2
volume = adjust_volume(arr.copy(), 255)
slice = volume[:, :, z].astype(np.uint32)
plt.imshow(slice, cmap='gray')

In [None]:
# NOTE(review): when run in order, `volume` was already passed through adjust_volume by the
# cell above, so this applies the adjustment a second time — confirm that is intended.
adjusted = adjust_volume(volume, 255)
slice = adjusted[:, :, z].astype(np.uint32)
plt.imshow(slice, cmap='gray')

In [None]:
# Sample 20 vertices evenly along the outer contour of the slice and overlay them.
vertices = get_evenly_spaced_vertices(slice, 20)
plt.imshow(slice, cmap='gray')
if vertices:
    # Repeat the first vertex at the end so the outline is drawn closed.
    x, y = zip(*(vertices + vertices[:1]))
    plt.plot(x, y, 'r-')
    plt.scatter(x[:-1], y[:-1], c='blue')
else:
    # zip(*[]) cannot be unpacked into x, y; report instead of raising.
    print("No contour found in slice")
plt.title("Contour Vertices")
plt.show()

In [None]:
# uint32 copy of the current middle slice; report the full volume shape.
slice = volume[:, :, z].astype(np.uint32)
print(volume.shape)



In [None]:
# Wrap the numpy volume as a PyVista dataset and show a volume rendering.
data = pv.wrap(volume)
data.plot(volume=True)

In [None]:
# Central slices of the current `volume` along each axis.
visualize_slices(volume=volume, title="Original Volume")

In [None]:
# FIXME: `transformation_matrix` is not assigned in any cell of this notebook, so this
# cell raises NameError on Restart & Run All — define it before running this cell.
transformed_volume = affine_transform_volume(volume, transformation_matrix)
visualize_slices(volume=transformed_volume, title="Transformed Volume")