In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import matplotlib.image as mpimg
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
from PIL import Image
from itertools import cycle
%matplotlib inline
pylab.rcParams['figure.figsize'] = 16, 12
%matplotlib inline
pylab.rcParams['figure.figsize'] = 16, 12

Let's start by implementing a naive meanshift in a 2D vector space from scratch. This is not a particularly efficient implementation. Credit goes to Eric Choi for the demo (http://www.chioka.in/meanshift-algorithm-for-the-rest-of-us-python/).

First let's start by generating a bunch of datapoints from 4 clusters. This is easy using scikit-learn's built-in `make_blobs` function.

In [None]:
# Draw 100 two-dimensional points from 4 Gaussian blobs.
# random_state pins the draw so the data (and every figure derived from it)
# is identical on each Restart & Run All.
# X_shapes receives make_blobs' second return value (the blob label of each
# point); it is not used by the cells below.
original_X, X_shapes = make_blobs(n_samples=100, n_features=2, centers=4,
                                  cluster_std=1.3, random_state=42)
print(original_X.shape)
plt.plot(original_X[:,0], original_X[:,1], 'bo', markersize = 10);

We need a few things to be able to calculate the mean-shift: a Euclidean distance function, a mechanism for determining the neighborhood of points to consider in the Gaussian kernel estimator, and the Gaussian kernel itself. The Gaussian kernel returns the element-wise kernel density, from which we will construct the full population density function.

In [None]:
def euclid_distance(x, xi):
    """Return the Euclidean (L2) distance between the points ``x`` and ``xi``.

    Both arguments must support element-wise subtraction with each other
    (e.g. numpy arrays of the same shape).
    """
    offset = x - xi
    return np.sqrt(np.sum(offset * offset))

def neighbourhood_points(X, x_centroid, distance = 5):
    """Return the points of ``X`` that lie within ``distance`` of ``x_centroid``.

    Args:
        X: iterable of points; each is compared against ``x_centroid``
            with ``euclid_distance``.
        x_centroid: the point whose neighbourhood is wanted.
        distance: inclusive search radius (a point exactly ``distance``
            away is kept).

    Returns:
        A list of the qualifying points, in the order they appear in ``X``.
    """
    return [
        candidate
        for candidate in X
        if euclid_distance(candidate, x_centroid) <= distance
    ]

def gaussian_kernel(distance, bandwidth):
    """Evaluate a 1-D Gaussian density with standard deviation ``bandwidth``.

    Args:
        distance: scalar or numpy array of distances from the kernel centre.
        bandwidth: kernel width (the sigma of the Gaussian).

    Returns:
        ``exp(-0.5 * (distance / bandwidth)**2) / (bandwidth * sqrt(2*pi))``,
        with the same shape as ``distance``.
    """
    normaliser = 1 / (bandwidth * math.sqrt(2 * math.pi))
    scaled = distance / bandwidth
    return normaliser * np.exp(-0.5 * scaled ** 2)

Here we have manually specified the distance `look_distance`, which is how far to look in space for neighbors and is passed as a parameter to our `neighbourhood_points` function, and `kernel_bandwidth`, the bandwidth of the Gaussian kernel.

In [None]:
look_distance = 6  # Search radius handed to neighbourhood_points: points farther than this from x are ignored.
kernel_bandwidth = 4  # Bandwidth handed to gaussian_kernel when weighting each neighbour.

The loop.

1. For each datapoint x ∈ X, find the neighbouring points N(x) of x.
2. For each datapoint x ∈ X, calculate the mean shift m(x).
3. For each datapoint x ∈ X, update x ← m(x) by computation of the mean-shift vector directly as the weighted average.
4. Repeat from step 1 for n_iterations, or until the points have (almost) stopped moving.

You can change the n_iterations below to run more iterations of Meanshift.

In [None]:
# Work on a copy so original_X keeps the starting positions for plotting.
X = np.copy(original_X)

past_X = []        # snapshot of every point's position after each sweep
n_iterations = 5
for iteration in range(n_iterations):
    # Points are overwritten in place, so a point processed later in the
    # same sweep already sees the new positions of the earlier ones.
    for idx, point in enumerate(X):
        # Step 1. Find the neighbouring points N(x) of x.
        nearby = neighbourhood_points(X, point, look_distance)

        # Step 2. Mean shift m(x): kernel-weighted average of the neighbours.
        weights = [
            gaussian_kernel(euclid_distance(candidate, point), kernel_bandwidth)
            for candidate in nearby
        ]
        weighted_sum = sum(w * candidate for w, candidate in zip(weights, nearby))
        total_weight = sum(weights)

        # Step 3. Update x <- m(x).
        X[idx] = weighted_sum / total_weight

    past_X.append(np.copy(X))

In [None]:
# One stacked panel for the initial state plus one per mean-shift iteration.
figure = plt.figure(1)
figure.set_size_inches((10, 16))
n_rows = n_iterations + 2

# Initial state: the red "current" markers coincide with the blue originals.
initial_ax = plt.subplot(n_rows, 1, 1)
initial_ax.set_title('Initial state')
initial_ax.plot(original_X[:,0], original_X[:,1], 'bo')
initial_ax.plot(original_X[:,0], original_X[:,1], 'ro')

# Blue = original positions, red = positions after the given iteration.
for step in range(n_iterations):
    panel = step + 2
    snapshot = past_X[step]
    ax = plt.subplot(n_rows, 1, panel)
    ax.set_title('Iteration: %d' % (panel - 1))
    ax.plot(original_X[:,0], original_X[:,1], 'bo')
    ax.plot(snapshot[:,0], snapshot[:,1], 'ro')

Now, an image-based example using sklearn. Let's make it somewhat medically-relevant and basically simple to understand the result.

In [None]:
from PIL import Image
from itertools import cycle

# Force three channels: a grayscale or CMYK JPEG would otherwise be
# mis-grouped (or raise) in the reshape to (-1, 3) below. The context manager
# closes the file once the pixel data has been copied into the array.
with Image.open('Spine.jpg') as spine_file:
    image = np.array(spine_file.convert('RGB'))
original_shape = image.shape
print(original_shape)

# Flatten the image to one row per pixel and one column per RGB channel.
X = np.reshape(image, [-1,3])
print(X.shape)

plt.imshow(image)

Scikit-learn contains its own bandwidth estimator, let's use it.

In [None]:
# Let scikit-learn pick the kernel bandwidth for the pixel colours in X,
# using the 0.5 quantile and 100 samples.
bandwidth = estimate_bandwidth(X, n_samples=100, quantile=0.5)
print(bandwidth)

Now let's run sklearn's mean-shift implementation, which is easily invoked, and let's check how many clusters we get and whether that makes intuitive sense to us.

In [None]:
# Cluster the pixel colours with scikit-learn's mean shift.
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X)

# One label per row of X, and one centre per discovered cluster.
labels, cluster_centers = ms.labels_, ms.cluster_centers_
print(labels.shape)
print(cluster_centers.shape)

# The number of distinct labels is the number of clusters found.
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

Now let's recover the original shape of the image (ignore the RGB channels, they're not relevant anymore; we just need the 2D information).

In [None]:
# Fold the per-pixel cluster labels back into the image's height x width grid.
segmented_image = labels.reshape(original_shape[:2])

# Original on the left, label map on the right, both without axes.
plt.figure(2)
left_ax = plt.subplot(1, 2, 1)
left_ax.imshow(image)
left_ax.axis('off')
right_ax = plt.subplot(1, 2, 2)
right_ax.imshow(segmented_image)
right_ax.axis('off')

In [None]:
from sklearn.cluster import MeanShift, estimate_bandwidth
# The old `sklearn.datasets.samples_generator` module was removed from
# scikit-learn; make_blobs is importable from `sklearn.datasets` directly.
from sklearn.datasets import make_blobs

Now, to a medical example. Begin by importing SimpleITK and creating an ImageSeriesReader. We will use this SimpleITK function to get a series of DICOM images and select from them which to use during our clustering segmentation pipeline.

In [None]:
import SimpleITK as sitk

# Directory (relative to the notebook) that holds the DICOM files.
dicomDirectory = "DICOM"

reader = sitk.ImageSeriesReader()
seriesIDs = reader.GetGDCMSeriesIDs(dicomDirectory)
numSeries = len(seriesIDs)
if numSeries == 0:
    # Fail with a clear message instead of an IndexError on seriesIDs[0] below.
    raise RuntimeError('No DICOM series found in directory "%s"' % dicomDirectory)

# Use the first series found in the directory.
dicomSeriesID = 0
dicomNames = reader.GetGDCMSeriesFileNames(dicomDirectory, seriesIDs[dicomSeriesID])
reader.SetFileNames(dicomNames)
try:
    image = reader.Execute()
except RuntimeError:
    # Was `RunTimeError`, which is undefined, so a failed read raised
    # NameError here instead of reaching this handler.
    print ("--> Something went wrong reading DICOM names!")
    # Re-raise: otherwise `image` would still be the JPEG array loaded
    # earlier and the following cells would silently process the wrong data.
    raise
    

Write out a more convenient format for research, namely, the compressed NIFTI.

In [None]:
# Save the image read above as "DICOM.nii.gz" in the working directory
# (per the notebook text, the compressed NIfTI format).
sitk.WriteImage(image, "DICOM.nii.gz")

In [None]:
# Index (along the first array axis) of the slice to display and, later, segment.
imageSlice = 45
# Convert the SimpleITK image to a numpy array so it can be indexed and plotted.
numpyImage = sitk.GetArrayFromImage(image)
my2DImage = numpyImage[imageSlice]  # presumably one axial slice, assuming (slice, row, column) ordering — TODO confirm

imgplot = plt.imshow(my2DImage, cmap='gray')

We can use this later if we just want to start with the DICOM data but in our compressed NIFTI format. NIFTI also has the advantage of being naturally de-identified, however, the DICOM data I have given you in this tutorial is already de-identified for obvious reasons.

Now we do some simple conversion. First, we will apply a total variation denoising procedure to the images, which is available in the skimage package and removes some random noise. In this case we are working on the numpy array representation of the image. Secondly, we scale the data and convert it to 8 bit in preparation for using our pymeanshift library, which works only with 8 bit image data.

In [None]:
from skimage.restoration import denoise_tv_chambolle
import cv2
import numpy as np

# NOTE: numpyImage is overwritten below, so re-running this cell on its own
# would process the already-scaled 8-bit data; re-run the previous cell first.
numpyImage = denoise_tv_chambolle(numpyImage, weight=0.001)

# Min-max scale to 0..255. convertScaleAbs computes |src * alpha + beta|, so
# beta must be the *negated, scaled* minimum; using the raw minimum as beta
# (as before) shifts the data instead of anchoring the minimum at 0.
minValue = np.amin(numpyImage)
valueRange = np.amax(numpyImage) - minValue
# A constant image has no range to stretch; map it to all zeros rather than
# dividing by zero.
alpha = 255.0 / valueRange if valueRange > 0 else 0.0
beta = -minValue * alpha

numpyImage = cv2.convertScaleAbs(numpyImage, alpha=float(alpha), beta=float(beta))

# The 8-bit slice that the mean-shift segmenter will work on.
myImageFinal = numpyImage[imageSlice]
imgplot = plt.imshow(myImageFinal, cmap='gray')

For the purposes of this demonstration our goal is to perform the mean shift segmentation on a single axial slice of the MRI which contains a significant amount of tumor. In this demo I have pre-configured the axial slice value corresponding to this 2D slice of the image, however, one could imagine a straightforward 3D version of this code (but beware! this is computationally expensive!) or an interactive image picker written in, for example, matplotlib.

Three significant configurable options exist within the construction of the mean shift segmenter. The first is the spatial radius. This is the radius around each seed point which will be considered for the centroid calculation. The second is the range radius, which is defined as the range in the image density that will be considered to be part of the centroid calculation at each iteration and for each seed point. The third is a minimum density, set so that no pixels below a certain threshold are considered in the calculation. This has the effect of removing the background. I have pre-populated some values for you here, however, feel free to play around with this.

In [None]:
import pymeanshift as pms
from PIL import Image

# Tunable segmenter settings (described in the text above).
spatialRadius = 2
rangeRadius = 5
minDensity = 100

mySegmenter = pms.Segmenter()
mySegmenter.spatial_radius = spatialRadius
mySegmenter.range_radius = rangeRadius
mySegmenter.min_density = minDensity

# Segment the prepared 8-bit slice.
segmentedImage, labelImage, numRegions = mySegmenter(myImageFinal)

# The label at position [100][150] is taken as the tumor's region label.
tumorLabelIndex = labelImage[100][150]
print ("--> Generated", numRegions, "regions")
print ("--> Label of the intratumoral region appears to be", tumorLabelIndex)
imgplot = plt.imshow(labelImage, cmap='gray')

Now that we have successfully segmented our tumor, let's go ahead and generate a simple mask on the image which is 1 in the tumor and 0 everywhere else. We can perhaps most simply do this by reshaping the image into a one-dimensional array and then simply converting it back.

In [None]:
# Build a binary mask: 1.0 where the label equals the tumor's label, else 0.0.
imageShape = labelImage.shape
myLabelImage = np.ravel(labelImage)

# A vectorised comparison replaces the per-pixel Python loop; the result is
# the same float64 array of zeros and ones.
myLabelImageSegmentation = (myLabelImage == tumorLabelIndex).astype(np.float64)

# Fold the flat mask back into rows of imageShape[1] columns.
myLabelImageSegmentation = np.reshape(myLabelImageSegmentation, (-1, imageShape[1]))
imgplot = plt.imshow(myLabelImageSegmentation, cmap='gray')

Congratulations on your segmented tumor! Now, just do this in 3D, or find a way to concatenate the results between successive mean-shift segmentation in each relevant axial slice along the tumor itself.