Load the data

In [9]:
from preprocessing import load_folder_paths, load_dcm_datasets, get_volume
from scipy import ndimage
import matplotlib.pyplot as plt
import os
import SimpleITK as sitk
import numpy as np

paths = load_folder_paths()
dataset_path = os.path.expanduser('~/Desktop/UniToBrain')
path = os.path.join(dataset_path, 'MOL-106')
v = get_volume(path, extract_brain=False, filter=True, window_params=(80, 160), correct_motion=False, standardize=False, spatial_downsampling_factor=2, temporal_downsampling_factor=1)

Loading /Users/simonma/Desktop/UniToBrain/MOL-106...
Dicom files loaded, count: 288
Processing /Users/simonma/Desktop/UniToBrain/MOL-106...
5 [0.488281, 0.488281]


ic| volume_seq.max(): 160.0
    volume_seq.min(): 0.0
    volume_seq.dtype: dtype('float32')


Total time taken: 2.87 seconds
Average time taken per volume: 0.16 seconds
Done!


Interactive Visualization Functions

In [10]:
from ipywidgets import interact, IntSlider
from math import ceil

def overlay_volume_sequence_interactive(volume_seq):
    num_overlays = len(volume_seq) - 1
    nrows = ceil(num_overlays ** 0.5)
    ncols = ceil(num_overlays / nrows)
    print(f"{nrows=} {ncols=}")
    def plot_slice(slice_idx):

        fig, axes = plt.subplots(nrows, 
                                ncols, 
                                figsize=(5*ncols, 5*nrows),
                                squeeze=True)
        if nrows == 1:
            if ncols == 1:
                axes = [[axes]]
            else:
                axes = [axes]
        print(f"{axes=}")
        for i in range(num_overlays):
            ax = axes[i // ncols][i % ncols]
            print(f"{ax=}")
            ax.imshow(volume_seq[0][slice_idx], cmap="gray")
            ax.imshow(volume_seq[i+1][slice_idx], cmap="hot", alpha=0.5)
        plt.show(block=True)
    
    interact(
        plot_slice,
        slice_idx=IntSlider(
            min=0,
            max=len(volume_seq[0])-1, 
            step=1,
            value=0,
            description='Slice:'
        )
    )

def multi_vol_seq_interactive(volume_seqs, titles=None):
    """
    Interactive plot of multiple volume sequences using ipywidgets
    
    Parameters:
    - volume_seqs: List of 4D volume sequences to display
    - titles: Optional list of titles for each sequence
    """
    if titles is None:
        titles = [f"Volume {i+1}" for i in range(len(volume_seqs))]
        
    num_volumes = len(volume_seqs)
    nrows = int(num_volumes ** 0.5)
    ncols = (num_volumes + nrows - 1) // nrows
    
    def plot_volumes(time_idx, slice_idx):
        fig, axes = plt.subplots(nrows, ncols, 
                                figsize=(5*ncols, 5*nrows),
                                squeeze=True)
        if nrows == 1:
            if ncols == 1:
                axes = [[axes]]
            else:
                axes = [axes]
                
        for i, (volume_seq, title) in enumerate(zip(volume_seqs, titles)):
            row, col = i // ncols, i % ncols
            ax = axes[row][col]
            
            t = min(time_idx, len(volume_seq) - 1)
            s = min(slice_idx, len(volume_seq[t]) - 1)
            
            im = ax.imshow(volume_seq[t][s], cmap='magma')
            ax.set_title(title)
            plt.colorbar(im, ax=ax)
            
        plt.tight_layout()
        plt.show(block=True)
        
    max_time = max(len(vol) for vol in volume_seqs) - 1
    max_slice = max(len(vol[0]) for vol in volume_seqs) - 1
    
    interact(
        plot_volumes,
        time_idx=IntSlider(min=0, max=max_time, step=1, value=0, description='Time:'),
        slice_idx=IntSlider(min=0, max=max_slice, step=1, value=0, description='Slice:')
    )

def multi_vol_seq_interactive(volume_seqs, titles=None):
    """
    Interactive plot of multiple volume sequences using ipywidgets
    
    Parameters:
    - volume_seqs: List of 4D volume sequences to display
    - titles: Optional list of titles for each sequence
    """
    if titles is None:
        titles = [f"Volume {i+1}" for i in range(len(volume_seqs))]
        
    num_volumes = len(volume_seqs)
    nrows = int(num_volumes ** 0.5)
    ncols = (num_volumes + nrows - 1) // nrows
    
    def plot_volumes(time_idx, slice_idx):
        fig, axes = plt.subplots(nrows, ncols, 
                                figsize=(5*ncols, 5*nrows),
                                squeeze=True)
        if nrows == 1:
            if ncols == 1:
                axes = [[axes]]
            else:
                axes = [axes]
                
        for i, (volume_seq, title) in enumerate(zip(volume_seqs, titles)):
            row, col = i // ncols, i % ncols
            ax = axes[row][col]
            
            t = min(time_idx, len(volume_seq) - 1)
            s = min(slice_idx, len(volume_seq[t]) - 1)
            
            im = ax.imshow(volume_seq[t][s], cmap='magma')
            ax.set_title(title)
            plt.colorbar(im, ax=ax)
            
            # Add format_coord function to display pixel values on hover
            def make_format_coord(img):
                def format_coord(x, y):
                    if x is None or y is None:
                        return ""
                    x, y = int(x + 0.5), int(y + 0.5)
                    if 0 <= y < img.shape[0] and 0 <= x < img.shape[1]:
                        val = img[y, x]
                        return f'x={x}, y={y}, value={val:.2f}'
                    return 'x={:.0f}, y={:.0f}'.format(x, y)
                return format_coord
            
            ax.format_coord = make_format_coord(volume_seq[t][s])
            
        plt.tight_layout()
        plt.show(block=True)
        
    max_time = max(len(vol) for vol in volume_seqs) - 1
    max_slice = max(len(vol[0]) for vol in volume_seqs) - 1
    
    interact(
        plot_volumes,
        time_idx=IntSlider(min=0, max=max_time, step=1, value=0, description='Time:'),
        slice_idx=IntSlider(min=0, max=max_slice, step=1, value=0, description='Slice:')
    )


In [11]:
multi_vol_seq_interactive([v])

interactive(children=(IntSlider(value=0, description='Time:', max=17), IntSlider(value=0, description='Slice:'…

In [12]:
from IPython.display import clear_output
# Callback invoked when the StartEvent happens, sets up our new data.
def start_plot():
    global metric_values, multires_iterations

    metric_values = []
    multires_iterations = []


# Callback invoked when the EndEvent happens, do cleanup of data and figure.
def end_plot():
    global metric_values, multires_iterations

    del metric_values
    del multires_iterations
    # Close figure, we don't want to get a duplicate of the plot latter on.
    plt.close()


# Callback invoked when the IterationEvent happens, update our data and display new figure.
def plot_values(registration_method, metric_name='Metric Value'):
    if registration_method.GetOptimizerIteration() % 10 == 0:
        global metric_values, multires_iterations

        metric_values.append(registration_method.GetMetricValue())
        # Clear the output area (wait=True, to reduce flickering), and plot current data
        clear_output(wait=True)
        # Plot the similarity metric values
        plt.plot(metric_values, "r")
        plt.plot(
            multires_iterations,
            [metric_values[index] for index in multires_iterations],
            "b*",
        )
        plt.xlabel("Iteration Number", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)
        plt.show()


# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the
# metric_values list.
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))

In [32]:
def plot_volume_slices(volume, title=None, cmap='magma'):
    """
    Plots all slices of a 3D volume in a rectangular grid arrangement.
    
    Parameters:
    - volume: 3D numpy array of shape (D, H, W)
    - title: Optional string for overall figure title
    - cmap: Colormap to use for plotting
    """
    D = volume.shape[0]
    vmin, vmax = np.min(volume), np.max(volume)

    # Calculate grid dimensions to arrange plots in a roughly square layout
    n_rows = int(np.ceil(np.sqrt(D)))
    n_cols = int(np.ceil(D / n_rows))
    
    # Create figure and subplots
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
    if title:
        fig.suptitle(title, fontsize=16)
    
    # Flatten axes array for easier iteration
    axes = axes.flatten()
    
    # Plot each slice
    for i in range(D):
        axes[i].imshow(volume[i], cmap=cmap, vmin=vmin, vmax=vmax)
        axes[i].axis('off')
        axes[i].set_title(f'Slice {i}')
    
    # Turn off any empty subplots
    for i in range(D, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

In [56]:
def register_volume_3D(moving_volume: np.ndarray, reference_volume: np.ndarray,
               lr: float = 0.5, n_iters: int = 200, relaxation_factor: float = 0.9, 
               gradient_magnitude_tolerance: float = 1e-4, max_step: float = 4.0, min_step: float = 1e-4, 
               spacing: tuple = (0.488281, 0.488281, 5.0),
               multi_res: bool = False):
    
    def command_iteration(method):
        if (method.GetOptimizerIteration() + 1) % 50 == 0:
            print(f"Iteration: {method.GetOptimizerIteration()}")
            print(f"Metric value: {method.GetMetricValue():.4f}")

    # Set min pixel value to 0 so that background aligns with pixels moved in by transformation during registration
    # min_pixel_value = np.min(moving_volume)
    # moving_volume, reference_volume = moving_volume - min_pixel_value, reference_volume - min_pixel_value

    # Convert to SimpleITK images
    moving = sitk.GetImageFromArray(moving_volume)
    fixed = sitk.GetImageFromArray(reference_volume)
    # if not multi_res:
    #     moving_image = sitk.DiscreteGaussian(moving_image, smoothing_sigma)
    #     reference_image = sitk.DiscreteGaussian(reference_image, smoothing_sigma)

    moving.SetSpacing(spacing)
    fixed.SetSpacing(spacing)
    
    print("moving size (x, y, z):", moving.GetSize())
    # Initialize 2D transform
    # initial_transform = sitk.CenteredTransformInitializer(
    #     fixed,
    #     moving,
    #     sitk.Euler3DTransform(),
    #     sitk.CenteredTransformInitializerFilter.GEOMETRY
    # )
    # print(initial_transform)
    # moving_array = sitk.GetArrayFromImage(moving)
    # plot_volume_slices(moving_array, title="Moving Volume")

    # # Setup registration method
    # reg_method = sitk.ImageRegistrationMethod()
    # # registration_method.SetMetricAsMeanSquares()
    # reg_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=100)
    # reg_method.SetMetricSamplingPercentage(0.1)
    # reg_method.SetMetricSamplingStrategy(sitk.ImageRegistrationMethod.RANDOM)
    
    # reg_method.SetInterpolator(sitk.sitkBSpline)
    # reg_method.SetInitialTransform(initial_transform)

    # # Optimizer settings
    # reg_method.SetOptimizerAsRegularStepGradientDescent(
    #     learningRate=lr,
    #     maximumStepSizeInPhysicalUnits=max_step,
    #     minStep=min_step,
    #     numberOfIterations=n_iters,
    #     gradientMagnitudeTolerance=gradient_magnitude_tolerance,
    #     relaxationFactor=relaxation_factor,
    # )

    # reg_method.SetOptimizerScalesFromPhysicalShift()

    # if multi_res:
    #     reg_method.SetShrinkFactorsPerLevel(shrinkFactors=[2, 1])
    #     reg_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[1, 2])


    # reg_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(reg_method))
    # reg_method.AddCommand(sitk.sitkStartEvent, start_plot)
    # reg_method.AddCommand(sitk.sitkEndEvent, end_plot)
    # reg_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations)
    # reg_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(reg_method, 'MSE'))


    # final_transform = reg_method.Execute(fixed, moving)
    # print(f"Stopping condition: {reg_method.GetOptimizerStopConditionDescription()}")

    # resampled = sitk.Resample(
    #     moving,
    #     fixed,
    #     final_transform,
    #     sitk.sitkBSpline,  # B-spline interpolation for final resampling
    #     0.0,               # Default value for out-of-range pixels
    #     moving.GetPixelID()
    # )
    # return sitk.GetArrayFromImage(resampled)

In [57]:
v1 = register_volume_3D(v[1], v[0], lr=0.1, n_iters=200)

moving size (x, y, z): (256, 256, 16)


In [53]:
multi_vol_seq_interactive([[v[0], v1], v[:2]])

interactive(children=(IntSlider(value=0, description='Time:', max=1), IntSlider(value=0, description='Slice:',…

In [16]:
unwindowed = get_volume(paths[13], extract_brain=False, windowing=False, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)

TypeError: get_volume() got an unexpected keyword argument 'windowing'

In [None]:
i = unwindowed[1][0]
# i = np.clip(i, -1024, 1000)
i = (i > -40) & (i < 120)
i = ndimage.binary_erosion(i, iterations=2)
i = ndimage.binary_fill_holes(i)
i = 1 - i
print(np.mean(i))
plt.imshow(i, cmap="gray")

More of these ideas a few cells below

In [None]:
volume_seq_2 = get_volume(paths[0], extract_brain=False, windowing=True, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)

In [None]:
overlay_volume_sequence_interactive(volume_seq_2)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(v1[0], cmap="gray")
ax[1].imshow(reg[0], cmap="gray")
plt.show()

In [None]:
volume_seq_2 = get_volume(paths[130], extract_brain=False, windowing=True, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)

In [None]:
overlay_volume_sequence_interactive(volume_seq_2)

In [None]:
reg = register_volume_inplane_weighted(volume_seq_2[0], volume_seq_2[1], n_samples=1, lr=1e-4, n_iters=500)
overlay_volume_sequence_interactive([volume_seq_2[0], reg])

## An registration framwork based on MSE
- optimized registration parameters
- added customizability

In [75]:
def register_2(moving_volume: np.ndarray, reference_volume: np.ndarray, n_samples: int = 5, 
               lr: float = 1.0, n_iters: int = 1000, relaxation_factor: float = 0.99, 
               gradient_magnitude_tolerance: float = 1e-5, max_step: float = 4.0, min_step: float = 5e-4, 
               weighting_scheme: str = 'inverse', spacing: tuple = (1, 1), plot_metrics: bool = True,
               multi_res: bool = False, smoothing_sigma: float = 2.0,
               verbose: bool = False):
    Y, Z, X = moving_volume.shape
    
    middle = Y // 2
    half_range = n_samples // 2
    # Sample slice indices evenly   
    slice_indices = np.linspace(middle - half_range, middle + half_range, n_samples, dtype=int)
    
    # Store transforms and their weights
    transforms = []
    metric_values = []

    def command_iteration(method):
        if (method.GetOptimizerIteration() + 1) % 50 == 0:
            print(f"Iteration: {method.GetOptimizerIteration()}")
            print(f"Metric value: {method.GetMetricValue():.4f}")

    # Register each sampled slice
    for slice_idx in slice_indices:
        print(f"Registering slice {slice_idx} of {Y}")
        # Get corresponding slices
        moving_slice = moving_volume[slice_idx]
        reference_slice = reference_volume[slice_idx]
        # Set min pixel value to 0 so that background aligns with pixels moved in by transformation during registration
        min_pixel_value = np.min(moving_slice)
        moving_slice, reference_slice = moving_slice - min_pixel_value, reference_slice - min_pixel_value
        if verbose:
            print(f"{min_pixel_value=}")
        # Convert to SimpleITK images
        moving_image = sitk.GetImageFromArray(moving_slice)
        reference_image = sitk.GetImageFromArray(reference_slice)
        if not multi_res:
            moving_image = sitk.DiscreteGaussian(moving_image, smoothing_sigma)
            reference_image = sitk.DiscreteGaussian(reference_image, smoothing_sigma)
        
        # Set 2D spacing
        moving_image.SetSpacing(spacing)
        reference_image.SetSpacing(spacing)
        
        # Initialize 2D transform
        initial_transform = sitk.CenteredTransformInitializer(
            moving_image,
            reference_image,
            sitk.Euler2DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY
        )
        
        # Setup registration method
        registration_method = sitk.ImageRegistrationMethod()
        registration_method.SetMetricAsMeanSquares()
        

        # Optimizer settings
        registration_method.SetOptimizerAsRegularStepGradientDescent(
            learningRate=lr,
            maximumStepSizeInPhysicalUnits=max_step,
            minStep=min_step,
            numberOfIterations=n_iters,
            gradientMagnitudeTolerance=gradient_magnitude_tolerance,
            relaxationFactor=relaxation_factor,
            
        )
        if multi_res:
            registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[2, 1])
            registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[1, 2])
            registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

        registration_method.SetInitialTransform(initial_transform)
        registration_method.SetInterpolator(sitk.sitkLinear)
        if plot_metrics:
            registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
            registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
            registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
            registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations)
            registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method, 'MSE'))
        try:
            final_transform = registration_method.Execute(reference_image, moving_image)
            metric_value = registration_method.GetMetricValue()
            print(slice_idx)
            print(f"Stopping condition: {registration_method.GetOptimizerStopConditionDescription()}")
            print(f"Metric value: {metric_value:.4f}")
            # Store transform parameters and weight
            params = final_transform.GetParameters()
            center = final_transform.GetCenter()
            transforms.append((params, center))
            metric_values.append(metric_value)
            
        except RuntimeError as e:
            print(f"Registration failed for slice {slice_idx}: {e}")
            continue
    
    if not transforms:
        print("No successful registrations - returning original volume")
        return moving_volume
    
    # Convert metric values to weights based on chosen scheme
    metric_values = np.array(metric_values)
    print(metric_values)

    if weighting_scheme == 'inverse':
        # Original inverse weighting
        weights = 1.0 / (metric_values + 1e-10)

    elif weighting_scheme == 'inverse_root':
        weights = 1.0 / (np.sqrt(metric_values) + 1e-10)
    
    elif weighting_scheme == 'exponential':
        # Exponential decay: w = exp(-metric_value)
        # More robust to outliers than inverse
        weights = np.exp(-metric_values)
    
    elif weighting_scheme == 'softmax':
        # Softmax-based weighting: emphasizes better matches while maintaining non-zero weights
        # Negative because lower metric values are better
        weights = np.exp(-metric_values) / np.sum(np.exp(-metric_values))
        
    elif weighting_scheme == 'rank':
        # Rank-based weighting: less sensitive to absolute metric values
        ranks = np.argsort(np.argsort(-metric_values))  # Higher rank for lower metric value
        weights = 1.0 / (ranks + 1)
    
    elif weighting_scheme == 'threshold':
        # Threshold-based: only keep transforms with metric values below mean
        mean_metric = np.mean(metric_values)
        weights = np.where(metric_values < mean_metric, 1.0, 0.0)
        if np.sum(weights) == 0:  # If all transforms are above mean
            weights = np.ones_like(metric_values)
    
    # Normalize weights
    weights = weights / np.sum(weights)
    
    # Compute weighted average transformation
    avg_angle = 0
    avg_tx = 0
    avg_ty = 0
    avg_cx = 0
    avg_cy = 0
    
    for (params, center), weight in zip(transforms, weights):
        avg_angle += params[0] * weight
        avg_tx += params[1] * weight
        avg_ty += params[2] * weight
        avg_cx += center[0] * weight
        avg_cy += center[1] * weight
    
    print(f"{avg_angle=:.4f} {avg_tx=:.4f} {avg_ty=:.4f} {avg_cx=:.4f} {avg_cy=:.4f}")
    # print(f"Time taken: {time() - t1}")
        # Create final average transform
    final_transform = sitk.Euler2DTransform()
    final_transform.SetAngle(avg_angle)
    final_transform.SetTranslation((avg_tx, avg_ty))
    final_transform.SetCenter((avg_cx, avg_cy))
    
    # Apply transform to each slice of moving volume
    registered_volume = np.zeros_like(moving_volume)
    for i in range(moving_volume.shape[0]):
        moving_slice = sitk.GetImageFromArray(moving_volume[i])
        registered_slice = sitk.Resample(
            moving_slice,
            reference_image,
            final_transform,
            sitk.sitkLinear,
            0.0, # min_pixel_value,
            moving_slice.GetPixelID()
        )
        registered_volume[i] = sitk.GetArrayFromImage(registered_slice) + min_pixel_value

    return registered_volume

In [None]:
reg_2 = register_2(v2, v1, n_samples=1, lr=1, n_iters=500, weighting_scheme='inverse', plot_metrics=True)

In [None]:
overlay_volume_sequence_interactive([v1, reg_2])

In [None]:
multi_vol_seq_interactive([v, [v1, reg_2]])

In [45]:
import os
dataset_path = os.path.expanduser('~/Desktop/UniToBrain')
path = os.path.join(dataset_path, 'MOL-112')


## Let's try registration with different data:
- less aggressive windowing
- Skull-stripping

In [None]:
v_raw = get_volume(paths[13], extract_brain=False, windowing=False, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)

In [None]:
from preprocessing import get_3d_mask, apply_mask
# v_raw[0] = apply_mask(v_raw[0], get_3d_mask(v_raw[0]))
# v_raw[1] = apply_mask(v_raw[1], get_3d_mask(v_raw[1]))
# v_raw = np.clip(v_raw, -200, 80)
multi_vol_seq_interactive([v_raw, [v_raw[0], v_raw[1]]])

In [None]:
reg_raw = register_2(v_raw[1], v_raw[0], n_samples=1, lr=1, n_iters=200, relaxation_factor=0.8)

In [None]:
multi_vol_seq_interactive([v_raw, [v_raw[0], reg_raw]])

In [None]:
import os
dataset_path = os.path.expanduser('~/Desktop/UniToBrain')
path = os.path.join(dataset_path, 'MOL-133')
v_133 = get_volume(path, extract_brain=False, windowing=True, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)

In [None]:
reg_133 = register_2(v_133[1], v_133[0], n_samples=1, lr=1, n_iters=200, weighting_scheme='inverse')

In [None]:
multi_vol_seq_interactive([v_133, [v_133[0], reg_133]])

This works! Notice the rotation correction from t=0 to t=1

In [None]:
# Let's try this on the previous scan
v_100_reg2 = register_2(v[1], v[0], n_samples=1, lr=1, n_iters=150, weighting_scheme='inverse')
multi_vol_seq_interactive([v, [v[0], v_100_reg2]])

My goodness this works too, I've spent wayyy to much time on this simple registration task

In [None]:
v_100_reg_multi_res = register_2(v[1], v[0], n_samples=1, lr=1, n_iters=1000, weighting_scheme='inverse', multi_res=True)
multi_vol_seq_interactive([v, [v[0], v_100_reg_multi_res]])

In [None]:
v_100_reg1 = register_volume_inplane_weighted(v[1], v[0], n_samples=1, lr=1e-3, n_iters=2000)

In [None]:
v_106 = get_volume(os.path.join(dataset_path, 'MOL-106'), extract_brain=False, windowing=True, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)

In [None]:
overlay_volume_sequence_interactive(v_106)

In [None]:
reg_106 = register_2(v_106[1], v_106[0], n_samples=1, lr=1.2, n_iters=1000, weighting_scheme='inverse', smoothing_sigma=5, )
overlay_volume_sequence_interactive([v_106[0], reg_106])

In [None]:
reg_106_2 = reg_106 + np.min(v_106[0]) - np.min(reg_106)
np.min(reg_106_2), np.min(v_106[0])
# overlay_volume_sequence_interactive([v_106[0], reg_106_2])

In [None]:
reg_106 = register_2(v_106[0], v_106[1], n_samples=1, lr=0.11, n_iters=1000, weighting_scheme='inverse')
multi_vol_seq_interactive([v_106, [reg_106, v_106[1]]])

Learning rate too low

# Without Windowing

In [None]:
v_106_no_window = get_volume(os.path.join(dataset_path, 'MOL-106'), extract_brain=False, windowing=False, correct_motion=False, spatial_downsampling_factor=2, temporal_downsampling_factor=7)
v_106_no_window = np.clip(v_106_no_window, -1024, 3000)

In [None]:
overlay_volume_sequence_interactive(volum)

In [None]:
reg_106_no_window = register_2(v_106_no_window[1], v_106_no_window[0], n_samples=1, lr=1, n_iters=1000)


In [None]:
multi_vol_seq_interactive([v_106_no_window, [v_106_no_window[0],reg_106_no_window]])

# Noise Estimation
- This part should help select a volume to act as the reference for registering other volumes in the sequence. It should be the least corrupted and least noisy out of the first few volumes of the sequence.
- As we are doing in-plane registration, is suffices to only look at the images which are being sampled in registration and do noise estimation in 2D.

In [312]:
import numpy as np
from scipy import ndimage
from sklearn.linear_model import LinearRegression

def get_local_statistics(img, patch_size):
    # Calculate local means and variances using sliding windows
    kernel = np.ones((patch_size, patch_size)) / (patch_size * patch_size)
    local_mean = ndimage.convolve(img, kernel, mode='reflect')
    local_var = ndimage.convolve(img**2, kernel, mode='reflect') - local_mean**2
    return local_mean, local_var

def estimate_noise_levels(image, patch_size=16):
    """
    Estimate quantum and electronic noise levels in a CT image.
    
    Parameters:
    -----------
    image : ndarray
        2D array representing the CT image
    patch_size : int
        Size of patches for local variance estimation
    
    Returns:
    --------
    quantum_noise : float
        Estimated quantum noise level
    electronic_noise : float
        Estimated electronic noise level
    """
    
    # Get local means and variances
    means, variances = get_local_statistics(image, patch_size)
    
    # Flatten arrays for regression
    means = means.flatten()
    variances = variances.flatten()
    
    # Remove outliers and negative values
    valid_indices = (variances > 0) & (means > 0)
    means = means[valid_indices]
    variances = variances[valid_indices]
    
    # Prepare data for linear regression
    X = means.reshape(-1, 1)
    y = variances
    
    # Fit linear model: variance = quantum_noise * mean + electronic_noise
    model = LinearRegression()
    model.fit(X, y)
    
    quantum_noise = model.coef_[0]  # Slope represents quantum noise
    electronic_noise = model.intercept_  # Y-intercept represents electronic noise
    
    return quantum_noise, electronic_noise

def validate_noise_estimation(image, patch_size=16, num_splits=5):
    """
    Validate noise estimation by comparing results across image subsets.
    
    Parameters:
    -----------
    image : ndarray
        2D array representing the CT image
    patch_size : int
        Size of patches for local variance estimation
    num_splits : int
        Number of random subsets to use for validation
    
    Returns:
    --------
    dict
        Dictionary containing mean and std of noise estimates
    """
    quantum_estimates = []
    electronic_estimates = []
    
    height, width = image.shape
    for _ in range(num_splits):
        # Create random mask for subset selection
        mask = np.random.rand(height, width) > 0.5
        subset = image.copy()
        subset[~mask] = 0
        
        # Estimate noise for subset
        q_noise, e_noise = estimate_noise_levels(subset, patch_size)
        quantum_estimates.append(q_noise)
        electronic_estimates.append(e_noise)
    
    return {
        'quantum_noise_mean': np.mean(quantum_estimates),
        'quantum_noise_std': np.std(quantum_estimates),
        'electronic_noise_mean': np.mean(electronic_estimates),
        'electronic_noise_std': np.std(electronic_estimates)
    }

def plot_noise_analysis(image, quantum_noise, electronic_noise):
    """
    Create visualization of noise estimation results.
    
    Parameters:
    -----------
    image : ndarray
        2D array representing the CT image
    quantum_noise : float
        Estimated quantum noise level
    electronic_noise : float
        Estimated electronic noise level
    """
    import matplotlib.pyplot as plt
    
    # Calculate local means and variances
    means, variances = get_local_statistics(image, 16)
    
    plt.figure(figsize=(10, 6))
    plt.scatter(means.flatten(), variances.flatten(), alpha=0.1, label='Local measurements')
    
    # Plot fitted line
    x_range = np.linspace(means.min(), means.max(), 100)
    y_fit = quantum_noise * x_range + electronic_noise
    plt.plot(x_range, y_fit, 'r-', label='Fitted noise model')
    
    plt.xlabel('Local mean intensity')
    plt.ylabel('Local variance')
    plt.title('Noise Analysis Plot')
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
print(estimate_noise_levels(v_106[0][0], patch_size=3))
plt.imshow(v_106[0][0], cmap='gray')

In [None]:
from preprocessing import bilateral_filter
filtered = bilateral_filter(v_106[0][0], 10, 10)
estimate_noise_levels(filtered, patch_size=3)

In [None]:
plt.imshow(filtered, cmap='gray')

In [None]:
plot_noise_analysis(filtered, *estimate_noise_levels(filtered, patch_size=3))