Define various useful functions and classes.

ImageSliceViewer3D is for the interactive visualization of a 3D image. The slices can be selected through a slider.

InteractiveStem adds a movable horizontal line to a stem plot. It also constructs an instance of ImageSliceViewer3D from a 3D image. Each point in the stem plot corresponds to a 2D image. The position of the movable line in the stem plot determines which images are displayed in the viewer.

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import SimpleITK as sitk
import math
import os
import matplotlib.lines as lines
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
from sklearn.preprocessing import StandardScaler
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import cv2
import csv
import pandas
import ipywidgets as ipyw
import scipy.misc
from IPython.display import clear_output

import esmraldi.segmentation as seg
import esmraldi.imzmlio as imzmlio
import esmraldi.fusion as fusion
import esmraldi.speciesrule as sr
import esmraldi.spectrainterpretation as si

from esmraldi.theoreticalspectrum import TheoreticalSpectrum

%matplotlib notebook

def transpose_eigenvectors(M, eigenvectors, eigenvalues):
    """Project ``M`` onto each eigenvector and rescale each projection.

    Parameters
    ----------
    M : array of shape (p, n)
        Data matrix.
    eigenvectors : array of shape (k, n)
        One eigenvector per row (k <= n components).
    eigenvalues : sequence of length >= k
        Scale associated with each eigenvector; only the first ``k`` are used.

    Returns
    -------
    numpy.ndarray of shape (k, p)
        Row ``i`` is ``M @ eigenvectors[i]`` divided by
        ``sqrt(2 * eigenvalues[i] * (n - 1) / (p - 1))``.
    """
    p, n = M.shape
    eigenvectors = np.asarray(eigenvectors)
    # Number of components actually supplied (previously assumed to be n,
    # which raised IndexError when fewer components were kept).
    k = len(eigenvectors)
    scale = np.sqrt(2 * np.asarray(eigenvalues[:k]) * (n - 1) / (p - 1))
    # One matrix product instead of k matrix-vector products; each column
    # (one per component) is divided by its own scale, then transposed so
    # that each row corresponds to one component.
    return (np.dot(M, eigenvectors.T) / scale).T

class ImageSliceViewer3D:
    """
    ImageSliceViewer3D is for viewing volumetric image slices in jupyter or
    ipython notebooks.

    User can interactively change the slice plane selection for the image and
    the slice plane being viewed.

    Arguments:
    volume = 3D input image; the slider indexes its last axis
    mzs = sequence of labels, one per slice, shown in the slice title
    ax = default(None), matplotlib axes to draw on; when None, a new figure
    is created each time a slice is displayed
    figsize = default(8,8), size of the figure created when ax is None
    cmap = default('gray'), name of the matplotlib colormap. You can find
    more matplotlib colormaps on the following link:
    https://matplotlib.org/users/colormaps.html
    """

    def __init__(self, volume, mzs, ax=None, figsize=(8,8), cmap='gray'):
        self.volume = volume
        self.mzs = mzs
        self.figsize = figsize
        self.cmap = cmap
        self.ax = ax
        # Intensity range of the input volume; kept as a public attribute,
        # the drawing code in this class does not read it.
        self.v = [np.min(volume), np.max(volume)]
        self.slider = ipyw.IntSlider(min=0, max=self.volume.shape[-1] - 1, step=1, continuous_update=False,
            description='Image Slice:')

    def show(self):
        """Display the slice-plane radio buttons; each choice calls view_selection."""
        ipyw.interact(self.view_selection, view=ipyw.RadioButtons(
            options=['x-y','y-z', 'z-x'], value='x-y',
            description='Slice plane selection:', disabled=False,
            style={'description_width': 'initial'}))

    def view_selection(self, view):
        """Reorient the volume for ``view`` and display a slider over its slices."""
        # Axis permutation that brings the selected plane into the first two axes.
        orient = {"y-z":[1,2,0], "z-x":[2,0,1], "x-y": [0,1,2]}
        self.vol = np.transpose(self.volume, orient[view])
        # A fresh slider per orientation: the number of slices depends on it.
        self.slider = ipyw.IntSlider(min=0, max=self.vol.shape[-1] - 1, step=1, continuous_update=False,
            description='Image Slice:')
        # Call to view a slice within the selected slice plane
        ipyw.interact(self.plot_slice, z=self.slider)

    def update(self, image, mzs):
        """Replace the displayed images and their labels.

        ``image`` is used as-is (the current slice-plane orientation is not
        re-applied). An ``image`` with no slice along its last axis is
        accepted; the viewer then reports that no image is selected.
        """
        self.volume = image
        self.vol = image
        self.mzs = mzs
        previous_value = self.slider.value
        # IntSlider rejects max < min (0), which an empty selection would give.
        self.slider.max = max(self.volume.shape[-1] - 1, 0)
        # Changing max only re-triggers the slider callback when it clamps the
        # value; otherwise redraw explicitly so the new images are shown.
        if self.ax is not None and self.slider.value == previous_value:
            self.plot_slice(self.slider.value)

    def plot_slice(self, z):
        """Draw slice ``z`` of the current volume, titled with its label."""
        if self.ax is None:
            self.fig = plt.figure(figsize=self.figsize)
            ax = self.fig.gca()
        else:
            ax = self.ax
            # Remove the previous slice so images do not pile up on the axes.
            ax.clear()
        ax.axis('off')
        if self.vol.shape[-1] == 0:
            ax.set_title("No image selected")
            return
        ax.set_title("value: " + str(self.mzs[z]))
        # imshow accepts the colormap name (or object) directly.
        ax.imshow(self.vol[:, :, z], cmap=self.cmap)
            
            
class InteractiveStem():
    """Stem plot with a draggable horizontal line that filters an image viewer.

    ``ax[1]`` must already hold a stem plot whose first line
    (``ax[1].lines[0]``) provides the x positions and the scores of the stems;
    ``ax[2]`` receives an ImageSliceViewer3D built from ``image``, whose last
    axis is expected to hold one slice per stem.

    Click the green line to start dragging it, move the mouse over the stem
    axes, then click inside the stem axes to drop it at height ``a``. The
    viewer then shows the slices whose score is above ``a`` if ``a > 0``,
    below ``a`` otherwise.

    Parameters
    ----------
    ax : indexable collection of at least 3 matplotlib axes
    image : array whose last axis indexes the slices
    labels : optional per-slice labels; defaults to the stem x positions
    """

    def __init__(self, ax, image, labels=None):
        self.ax = ax
        self.ax_stem = ax[1]
        self.ax_viewer = ax[2]
        self.c = self.ax_stem.get_figure().canvas
        self.image = image

        line = self.ax_stem.lines[0]
        # np.asarray so boolean/integer indexing in releaseonclick also works
        # when labels are given as a plain list.
        self.mzs = np.asarray(line.get_xdata() if labels is None else labels)
        self.eigenvector = np.asarray(line.get_ydata())

        self.position = 0.1
        left, right = self.ax_stem.get_xlim()
        self.line = lines.Line2D([left, right], [self.position, self.position], picker=5, color="g")
        self.ax_stem.add_line(self.line)

        self.text = self.ax_stem.text(0,0, "", va="bottom", ha="left")

        # Connection ids of the drag handlers; None while the line is not dragged.
        self.follower = None
        self.releaser = None

        self.viewer_restricted = ImageSliceViewer3D(image, [i for i in range(image.shape[-1])], ax=self.ax_viewer)
        self.viewer_restricted.show()

        self.c.draw_idle()
        self.sid = self.c.mpl_connect('pick_event', self.clickonline)

    def clickonline(self, event):
        """Start dragging when the green line is picked."""
        # The release click lands on the line (it follows the cursor), so a
        # pick also fires while dragging. Registering a second pair of
        # handlers would overwrite the stored ids and leave the current pair
        # connected forever.
        if event.artist is not self.line or self.follower is not None:
            return
        self.text.set_text(event.artist)
        self.follower = self.c.mpl_connect("motion_notify_event", self.followmouse)
        self.releaser = self.c.mpl_connect('button_press_event', self.releaseonclick)

    def _in_stem_axes(self, event):
        # Only events over the stem axes carry y data in the stem's coordinates.
        return event.inaxes is self.ax_stem and event.ydata is not None

    def followmouse(self, event):
        """Move the line to the cursor height while it is over the stem axes."""
        if not self._in_stem_axes(event):
            return
        self.line.set_ydata([event.ydata, event.ydata])
        self.c.draw_idle()

    def releaseonclick(self, event):
        """Drop the line and show the images whose score lies beyond it."""
        if not self._in_stem_axes(event):
            return
        self.position = event.ydata
        self.text.set_text("")
        # Disconnect first so a failing viewer update cannot leave the drag stuck.
        self.c.mpl_disconnect(self.releaser)
        self.c.mpl_disconnect(self.follower)
        self.follower = None
        self.releaser = None
        if self.position > 0:
            cond = self.eigenvector > self.position
        else:
            cond = self.eigenvector < self.position
        selected = np.flatnonzero(cond)
        self.viewer_restricted.update(self.image[..., selected], self.mzs[selected])
        self.c.draw_idle()

Read the MALDI image (.nii) and the m/z information (.csv). An .imzML file, which already contains the m/z values, is also supported.

In [10]:
inputname = "data/peaksel_prominence75.nii"
mzsname = "data/peaksel_prominence75.csv"
theoreticalname="data/species_rule.json"
is_ratio = False
threshold = 0
normname = None

## Theoretical spectrum
species = sr.json_to_species(theoreticalname)
ions = [mol for mol in species if mol.category=="Ion"]
adducts = [mol for mol in species if mol.category=="Adduct"]
theoretical_spectrum = TheoreticalSpectrum(ions, adducts)

## Observed spectrum
if inputname.lower().endswith(".imzml"):
    imzml = imzmlio.open_imzml(inputname)
    image = imzmlio.to_image_array(imzml)
    # The m/z values of the first spectrum label the whole image.
    mzs, intensities = imzml.getspectrum(0)
else:
    image = sitk.GetArrayFromImage(sitk.ReadImage(inputname)).T
    if mzsname:
        # One m/z value per row, in the first ";"-separated column.
        # newline="" is the mode the csv module documents for reading files.
        with open(mzsname, newline="") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=";")
            mzs = [float(row[0]) for row in csv_reader]
    else:
        # No m/z file: label the slices by their index.
        mzs = [i for i in range(image.shape[2])]
# Both branches end with an array so the element-wise comparisons below work.
mzs = np.asarray(mzs)

# Drop the slices (last axis) whose m/z is below the threshold, then round.
image = image[..., mzs >= threshold]
mzs = mzs[mzs >= threshold]
mzs = np.around(mzs, decimals=2)

if normname is not None:
    print("Norm image detected")
    norm_img = sitk.ReadImage(normname)
    norm_img = sitk.GetArrayFromImage(norm_img).T
    norm_img_3D = norm_img[..., None]
    before = image.max()
    # Divide every slice by the normalization image, leaving 0 where it is 0.
    # np.float (an alias of float, i.e. float64) was removed in NumPy 1.24.
    image = np.divide(image, norm_img_3D, out=np.zeros_like(image, dtype=np.float64), where=norm_img_3D!=0)
    after = image.max()
    print("Before max=", before, ", after=", after)

if is_ratio:
    ratio_images, ratio_mzs = fusion.extract_ratio_images(image, mzs)
    image = np.concatenate((image, ratio_images), axis=2)
    mzs = np.concatenate((mzs, ratio_mzs))

# One "<value>(<key>)" label per annotation entry.
# NOTE(review): presumably keys are the observed m/z values — confirm in si.annotation.
annotation = si.annotation(mzs, theoretical_spectrum.spectrum, 1)
labels = np.array([str(val) + "(" + str(key) +")" for key, val in annotation.items()])

Normalize the images such that their intensity is scaled between 0 and 255, according to the minimum and maximum value of each image.
Visualize the results with an image slider.

In [11]:
# Rescale the images, then flatten them: rows of M are pixels, columns are
# the m/z slices (see the DataFrame columns below).
image = imzmlio.normalize(image)
image_shape = image.shape[:2]
image_norm = fusion.flatten(image)
M = image_norm.T
spectra_image = pandas.DataFrame(data=M, columns=mzs)

# Interactive slice viewer over the normalized images.
viewerSlice = ImageSliceViewer3D(image, mzs, figsize=(4,4))
viewerSlice.show()

<IPython.core.display.Javascript object>

Compute the PCA on both the original and transposed matrix.

In [12]:
print("Computing PCA")
p, n = M.shape
# Request n components, one per column of M (assumed meaning of the 2nd argument).
fit_pca = fusion.pca(M, n)
# Presumably a fitted scikit-learn PCA (components_ / singular_values_).
# NOTE(review): "eigenvalues" actually holds the singular values.
eigenvectors, eigenvalues = fit_pca.components_, fit_pca.singular_values_
print(eigenvectors.shape, eigenvalues.shape)
# One rescaled projection of M per component (see transpose_eigenvectors).
eigenvectors_transposed = transpose_eigenvectors(M, eigenvectors, eigenvalues)
print("Done")

Computing PCA


(249, 249) (249,)


Done


Alternatively, use another dimension reduction technique, such as NMF (this overwrites the PCA results computed above):

In [13]:
from sklearn.decomposition import NMF
print("Computing NMF")
p, n = M.shape
nb_comp = 6
nmf = NMF(n_components=nb_comp, init='nndsvda', solver='cd', random_state=0)
# A single fit_transform yields both factors of M ~ W @ H. The previous
# fit() followed by fit_transform() fitted the same (deterministic) model twice.
eigenvalues = nmf.fit_transform(M)  # W, one row per row of M
fit_nmf = nmf  # NMF.fit returns the estimator itself; name kept for later cells
eigenvectors = fit_nmf.components_  # H, one row per component
eigenvectors_transposed = eigenvalues.T
print("Done")

Computing NMF


Done


In [14]:
# Shapes of both factors and the component matrix itself, one per line.
print(eigenvectors.shape, eigenvectors, eigenvalues.shape, sep="\n")

(6, 249)
[[5.55312207e+01 1.78923791e+02 3.80409272e+02 ... 2.05152291e+02
  1.80129795e+02 5.95556406e+01]
 [0.00000000e+00 1.63135789e+01 4.18146334e+00 ... 2.78530951e+01
  2.89462199e+01 4.06861496e+01]
 [2.12674958e+01 6.51035883e+01 0.00000000e+00 ... 2.30647464e+01
  1.48933872e+01 0.00000000e+00]
 [0.00000000e+00 2.09164754e-01 1.60666299e+00 ... 7.97426765e+01
  3.04579722e+01 1.05413495e+00]
 [2.51818165e+00 1.43115577e+01 0.00000000e+00 ... 1.95488406e+01
  5.85366978e+01 8.59743330e+01]
 [1.33249085e+01 2.80685609e+01 0.00000000e+00 ... 4.26928817e+01
  3.45805743e+01 1.03343152e+01]]
(20976, 6)


After the PCA is computed, explore the loadings of the transposed matrix (image for each axis, on the left) and the loadings of the original matrix on the same axis (stem plot, in the middle). Move the first slider to select an axis in the PCA. 

Move the green horizontal line ($y = a$) in the stem plot to visualize the images (on the right) :
  1. whose score is above $a$ for the selected axis, if $a > 0$.
  2. whose score is below $a$ for the selected axis, if $a \leq 0$.
  
Move the second slider to visualize all of these images.

In [15]:
def plot_slice(viewer, z, ax):
    """Show axis ``z`` of the decomposition across the three panels of ``ax``.

    Draws slice ``z`` of ``viewer`` on its own axes, replots the stem plot of
    ``eigenvectors[z]`` against ``mzs`` on ``ax[1]`` and attaches a new
    InteractiveStem driving ``ax[2]``. Reads the notebook-level ``mzs``,
    ``eigenvectors``, ``image`` and ``labels``.
    """
    ImageSliceViewer3D.plot_slice(viewer, z)
    # Detach the canvas callbacks of the InteractiveStem built for the previous
    # axis; otherwise every slider move leaves another set connected.
    # (Disconnecting an already-removed id is a no-op in Matplotlib.)
    previous = getattr(viewer, "hline", None)
    if previous is not None:
        for cid in (previous.sid, getattr(previous, "follower", None), getattr(previous, "releaser", None)):
            if cid is not None:
                previous.c.mpl_disconnect(cid)
    ax[1].clear()
    ax[2].clear()
    # use_line_collection was deprecated in Matplotlib 3.6 and removed in 3.8;
    # LineCollection stems have been the default since 3.3.
    ax[1].stem(mzs, eigenvectors[z])
    viewer.hline = InteractiveStem(ax, image, labels)

plt.close()

# One 2D image per axis: reshape each transposed loading vector back to the
# pixel grid.
image_eigenvectors = eigenvectors_transposed.T
new_shape = (*image_shape, image_eigenvectors.shape[-1])
print(new_shape)
image_eigenvectors = image_eigenvectors.reshape(new_shape)
fig, ax = plt.subplots(1, 3, figsize=(10,4))
viewer = ImageSliceViewer3D(image_eigenvectors, mzs=list(range(new_shape[-1])), ax=ax[0])
# Route slider changes through the notebook-level plot_slice so the stem plot
# and the restricted viewer are rebuilt for the selected axis.
viewer.plot_slice = lambda z: plot_slice(viewer, z, ax)
viewer.show()
print(theoretical_spectrum.spectrum)

{'AX4_Na+': 569.1696000000001, 'AX5_Na+': 701.2119000000001, 'AX6_Na+': 833.2542000000001, 'AX7_Na+': 965.2965, 'AX8_Na+': 1097.3388, 'AX9_Na+': 1229.3811, 'AX10_Na+': 1361.4234000000001, 'AX11_Na+': 1493.4657, 'AX12_Na+': 1625.508, 'AX13_Na+': 1757.5503, 'AX14_Na+': 1889.5926, 'AX15_Na+': 2021.6349, 'AX16_Na+': 2153.6772, 'AX17_Na+': 2285.7195, 'AX18_Na+': 2417.7618, 'AX19_Na+': 2549.8041000000003, 'AX20_Na+': 2681.8464000000004, 'AX21_Na+': 2813.8887, 'AX22_Na+': 2945.931, 'AX4_K+': 585.1435, 'AX5_K+': 717.1858000000001, 'AX6_K+': 849.2281, 'AX7_K+': 981.2704, 'AX8_K+': 1113.3127000000002, 'AX9_K+': 1245.3550000000002, 'AX10_K+': 1377.3973000000003, 'AX11_K+': 1509.4396000000002, 'AX12_K+': 1641.4819000000002, 'AX13_K+': 1773.5242000000003, 'AX14_K+': 1905.5665000000001, 'AX15_K+': 2037.6088000000002, 'AX16_K+': 2169.6511, 'AX17_K+': 2301.6934, 'AX18_K+': 2433.7357, 'AX19_K+': 2565.7780000000002, 'AX20_K+': 2697.8203000000003, 'AX21_K+': 2829.8626, 'AX22_K+': 2961.9049}
