In [None]:
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

from proto.visualization import slices_interactive, compare_volumes_interactive, compare_volumes, slider

%matplotlib inline

In [None]:
# Load the MRI volume from a MATLAB .mat file; the array is stored under the 'voxels' key.
# NOTE(review): the path is relative to the notebook's working directory — consider a DATA_DIR constant.
data = loadmat('../tmp/mri-001-axial.mat')['voxels']

In [None]:
# Reflect intensities about the global mean: values above the mean map below it and
# vice versa, while the mean itself is preserved (2*mean - mean == mean).
data_inverted = 2*np.mean(data) - data

In [None]:
# Browse slices of the inverted volume (project helper from proto.visualization).
slices_interactive(data_inverted)

In [None]:
# Import from the public `scipy.ndimage` namespace; `scipy.ndimage.filters` is a
# deprecated alias (SciPy >= 1.8) that warns on use and is scheduled for removal.
from scipy.ndimage import gaussian_filter

def unsharp_mask(image, sigma, weight):
    '''
    Subtract a weighted Gaussian blur of ``image`` from ``image``.

    Returns ``image - weight * gaussian_filter(image, sigma)``.

    NOTE(review): the textbook unsharp mask is
    ``image + amount * (image - blurred)``. This variant differs: with
    ``weight=1`` it returns only the high-pass residual ``image - blurred``.
    The formula is intentionally left unchanged so later cells see the same
    values.

    NOTE: there is no "threshold" here.  It would be easy to add, however.
    See:

        http://www.damiensymonds.net/tut_usm.html

    for details on how it is typically implemented.

    Parameters
    ----------
    image : array_like
        Input image or volume (any dimensionality ``gaussian_filter`` accepts).
    sigma : scalar or sequence of scalars
        Standard deviation of the Gaussian blur, in array-index units.
    weight : scalar
        Multiplier applied to the blurred image before it is subtracted.

    Returns
    -------
    ndarray
        ``image - weight * blurred``, same shape as ``image``.
    '''
    blurred = gaussian_filter(image, sigma)
    return image - weight*blurred

In [None]:
# Apply `unsharp_mask` with Gaussian sigma=5 and weight=1, then view before/after side by side.
data_inverted_unsharp = unsharp_mask(data_inverted, 5, 1)
compare_volumes_interactive(data_inverted, data_inverted_unsharp)

In [None]:
def adaptive_threshold(data):
    '''
    Binarize a volume slice-by-slice against each slice's own median.

    Each index ``z`` along axis 2 defines one slice ``data[:, :, z]``. A voxel
    becomes 1 when it is strictly greater than the median of its slice,
    otherwise 0, so every slice gets its own threshold.

    Returns an array with the same shape as ``data`` and numpy dtype code
    ``'b'`` (int8) holding 0/1 values.
    '''
    mask = np.empty_like(data, dtype='b')
    # Moving axis 2 to the front makes iteration yield data[:, :, z] for each z.
    for z, plane in enumerate(np.moveaxis(data, 2, 0)):
        mask[:, :, z] = plane > np.median(plane)
    return mask

# Binarize the unsharp-masked volume slice-by-slice against each slice's median.
thresholded = adaptive_threshold(data_inverted_unsharp)

compare_volumes_interactive(data_inverted, thresholded)

In [None]:
from scipy import signal

# Length (in samples) of the 1-D Gaussian smoothing kernel.
kernel_length = 31
# `scipy.signal.gaussian` was deprecated and then removed in SciPy 1.13;
# the window function lives in `scipy.signal.windows`.
kernel_unnormalized = signal.windows.gaussian(kernel_length, std=5)
# Normalize to unit sum so that convolving with the kernel preserves the input's mean level.
kernel = kernel_unnormalized / np.sum(kernel_unnormalized)

# Visual sanity check of the normalized 1-D kernel.
plt.plot(kernel, '-o')
plt.title("Gaussian Kernel")

In [None]:
# Convert the 0/1 mask to float32 ('f') before blurring.
thresholded_as_float = thresholded.astype('f')

# One copy of the 1-D kernel per direction: it lies along axis 0 ("x") or axis 1 ("y")
# and has length 1 on the other axes, so each convolution blurs in a single direction.
kernel_x = kernel.reshape(kernel_length, 1, 1)
kernel_y = kernel.reshape(1, kernel_length, 1)

# mode='same' keeps each result the same shape as the input volume.
blurred_x, blurred_y = (
    signal.convolve(thresholded_as_float, directional_kernel, mode='same')
    for directional_kernel in (kernel_x, kernel_y)
)

In [None]:
# Side-by-side view of the two directional blurs.
compare_volumes_interactive(blurred_x, blurred_y)

In [None]:
# Combine the two directional blurs; voxels that are high in both get the largest sums.
blurred_xy = blurred_x + blurred_y

In [None]:
# Assume that `blurred_x` and `blurred_y` reach (roughly) the same "maximum" value, M.
# If so, the overlapping points (where both blurs peak) have a value of about 2*M,
# while the intermediate lines of the grid (high in only one blur) have a value of
# about M. Thresholding at 3/4 of max(blurred_xy), i.e. about 3/4*(2*M) = 1.5*M,
# keeps the overlaps and rejects the single lines.
blurred_xy_thresholded = blurred_xy > np.max(blurred_xy)*3/4

In [None]:
# Compare the summed blur with its thresholded mask.
compare_volumes_interactive(blurred_xy, blurred_xy_thresholded)

In [None]:
from scipy.ndimage import distance_transform_cdt, label, watershed_ift

# NOTE: this watershed algorithm is not finished
# I stopped working on it, because after some consideration, it does not seem
# like we even need to bother doing watersheding, as all of the points appear 
# properly separated already.

def watershed_segmentation(binary_array):
    distance = distance_transform_cdt(binary_array)
    #local_maximums = peak_local_max(distance, indices=False, footprint=np.ones((3, 3, 3)), labels=binary_array)
    markers = label(local_maximums)
    # TODO: finish this

In [None]:
from scipy import ndimage

# Give each connected component of the thresholded mask its own positive integer
# label (scipy's default structuring element: face-connected neighbours).
labeled, num_features = ndimage.label(blurred_xy_thresholded)

slices_interactive(labeled, cmap='jet')

In [None]:
def histogram_of_feature_sizes(labeled_array, min_size=4, max_size=1000):
    '''
    Plot how many labeled features have each size (in voxels).

    Parameters
    ----------
    labeled_array : ndarray of non-negative int
        Label array such as the output of ``ndimage.label``: 0 is background
        and each positive integer identifies one feature.
    min_size, max_size : int, optional
        Half-open window ``[min_size, max_size)`` of feature sizes to plot.
        The defaults reproduce the original ``[4:1000]`` window.
    '''
    labels = np.ravel(labeled_array)
    # Voxel count of every label; drop index 0 (background).
    feature_sizes = np.bincount(labels[labels > 0])[1:]
    # Histogram of those sizes: entry s is the number of features with exactly s voxels.
    # (Previously the per-label counts were plotted directly, i.e. x was the label
    # index and y the size, contradicting the axis labels.)
    features_per_size = np.bincount(feature_sizes)
    sizes = np.arange(features_per_size.size)
    plt.plot(sizes[min_size:max_size], features_per_size[min_size:max_size])
    plt.xlabel('Size (# of Voxels)')
    plt.ylabel('Label Count')

def reject_outliers(data, m=2.0):
    '''
    Drop entries that lie far from the median, measured in units of the
    median absolute deviation (MAD).

    Keeps entries whose absolute deviation from the median is at most
    ``m * MAD``. If the MAD is zero, ``data`` itself is returned unchanged.
    '''
    deviation = np.abs(data - np.median(data))
    mad = np.median(deviation)
    if not mad:
        # A zero MAD would reject every non-median value; skip filtering instead.
        return data
    return data[deviation <= m * mad]
    

# Inspect the distribution of feature sizes to help choose pruning bounds below.
histogram_of_feature_sizes(labeled)

In [None]:
def prune_features(labeled_array, min_size, max_size):
    '''
    Keep only features whose voxel count lies within [min_size, max_size].

    Parameters
    ----------
    labeled_array : ndarray of non-negative int
        Label array such as the output of ``ndimage.label``; 0 is background.
    min_size, max_size : int
        Inclusive bounds on the number of voxels a feature may have.

    Returns
    -------
    ndarray of bool
        Same shape as ``labeled_array``; True where the voxel belongs to a
        feature that passed the size test. Background (label 0) is always False.
    '''
    feature_sizes = np.bincount(np.ravel(labeled_array))
    keep = (feature_sizes >= min_size) & (feature_sizes <= max_size)
    # Label 0 is background, never a feature. Without this, a background whose voxel
    # count happened to fall inside the bounds would be returned as foreground.
    keep[:1] = False
    remaining_labels = np.flatnonzero(keep)
    # np.isin keeps the input's shape; replaces the deprecated np.in1d + in-place resize.
    return np.isin(labeled_array, remaining_labels)

def min_max_thresholdb(min_val, max_val):
    '''
    Interactive helper for picking the feature-size bounds used by `prune_features`.

    Prunes the global `labeled` array to features with between `min_val` and
    `max_val` voxels (inclusive) and shows the result next to the unpruned
    global `blurred_xy_thresholded` mask via `compare_volumes`, once for each
    of three hard-coded third-argument values (presumably slice indices).

    Of course, this will have to be done automatically in the final algorithm!
    '''
    pruned_mask = prune_features(labeled, min_val, max_val)
    for slice_index in (63, 78, 96):
        compare_volumes(pruned_mask, blurred_xy_thresholded, slice_index)
    
# One widget per `min_max_thresholdb` argument. `slider` is a project helper
# (presumably its argument is the initial/maximum value — TODO confirm in proto.visualization).
keywords = {
    'min_val': slider(100),
    'max_val': slider(1000),
}
widgets.interact(min_max_thresholdb, **keywords)

In [None]:
# The min/max size thresholds were determined manually using the previous
# interactive cell; eventually this step will need to be automated.
blurred_xy_thresholded_pruned = prune_features(labeled, 50, 500)
# Relabel the surviving features as 1..number_of_labels.
labeled_pruned, number_of_labels = ndimage.label(blurred_xy_thresholded_pruned)
# NOTE: it is not clear whether `data` should be the first argument
# here; need to think about this more. As written, each centroid is weighted
# by the raw `data` intensities; passing the binary mask instead would give
# unweighted (geometric) centroids.
grid_intersections = ndimage.center_of_mass(data, labeled_pruned, range(1, number_of_labels + 1))
# Split the list of per-feature (axis0, axis1, axis2) centroids into three coordinate tuples.
# NOTE(review): this unpacking raises ValueError if no features survive pruning.
x, y, z = zip(*grid_intersections)

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(10, 10))
# projection='3d' relies on the Axes3D import in this cell (needed on older matplotlib to register the projection).
ax = fig.add_subplot(111, projection='3d')
# 3-D scatter of the detected grid-intersection centroids.
ax.scatter(x, y, z)