# Viewing 3D volumetric data with matplotlib

Most of you are familiar with image data, taken with ordinary cameras (these are often called "natural images" in the scientific literature), but also with specialized instruments, such as microscopes or telescopes. When you're working with images in Python, you'll most often make use of the Matplotlib library, Python's 2D plotting library which produces publication quality figures in a variety of hard copy formats and interactive environments across platforms.

Today's tutorial will show you how you can easily view 3D volumetric data with Matplotlib. "3D volumetric data" you say? Volumetric data is very common nowadays: it is captured by various technologies, e.g. MRI, CT, PET, USCT or echolocation. As you can already guess, this type of data plays an important role in medicine. 

This tutorial will show you how to import and display image data, and how you can make a fully functional slice viewer to view all the slices in our MRI volume without any interference.

But first, let's go over some basics. You probably know that displaying image data with Matplotlib takes nothing more than a straightforward call to the `imshow` function.

We'll start by enabling the interactive matplotlib mode in the notebook:

In [2]:
%matplotlib notebook

Now, we can import matplotlib and display some image data:

In [3]:
# Import the plotting interface and the sample data module
import matplotlib.pyplot as plt
from skimage import data

# Gather some data 
astronaut = data.astronaut()
ihc = data.immunohistochemistry()
hubble = data.hubble_deep_field()

In [4]:
# Initialize the subplots
fig, ax = plt.subplots(nrows=1, ncols=3)

# Show subplot images and set titles 
ax[0].imshow(astronaut)
ax[0].set_title('Natural image')
ax[1].imshow(ihc)
ax[1].set_title('Microscopy image')
ax[2].imshow(hubble)
ax[2].set_title('Telescope image');


**Note:** When you run matplotlib in the interactive notebook mode, the open figure remains the *only* active figure until you disable it, using the power symbol on the top-right of the figure. Be sure you do that before moving on from each plot.

These images are called 2-dimensional or 2D images because they are laid out along 2 dimensions: x and y, or, in NumPy parlance, rows and columns or r and c.

Some images are 3D, in that they have an additional *depth* dimension (z, or planes). These include magnetic resonance imaging (MRI) and *serial section transmission electron microscopy* (ssTEM), in which a sample is thinly sliced, like a salami, and each of the slices is imaged separately.

To view such images in matplotlib, we have to choose a slice, and display only that slice. Let's try it out on some freely available MRI data online.
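Before we download anything, the idea of "choosing a slice" can be tried on a small synthetic array. This is only a sketch: the shape below is made up, and NumPy is assumed to be installed.

```python
import numpy as np

# Synthetic stand-in for a 3D scan: 4 planes, 5 rows, 6 columns (made-up sizes).
volume = np.arange(4 * 5 * 6).reshape(4, 5, 6)

# Indexing the first axis picks one plane: a 2D array of rows x columns.
plane = volume[2]

# Slicing along the other two axes gives the orthogonal views.
row_cut = volume[:, 3, :]   # planes x columns
col_cut = volume[..., 1]    # planes x rows

print(plane.shape, row_cut.shape, col_cut.shape)  # (5, 6) (4, 6) (4, 5)
```

Each of these 2D arrays is something `imshow` can display directly.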

## Interlude: getting the data...

We're going to download a dataset described in Buchel and Friston, *Cortical Interactions Evaluated with Structural Equation Modelling and fMRI* (1997). First, we create a temporary directory in which to download the data. We must remember to delete it when we are done with our analysis! If you want to keep this dataset for later use, change `d` to a more permanent directory location of your choice.

In [5]:
# Import
import tempfile

# Create a temporary directory
d = tempfile.mkdtemp()
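If a whole analysis fits in a single block of code, `tempfile.TemporaryDirectory` is an alternative that cleans up on its own. Here is a minimal sketch using only the standard library; the file name is made up for illustration.

```python
import os
import tempfile

# The directory exists only inside the `with` block and is deleted on exit.
with tempfile.TemporaryDirectory() as tmp_dir:
    scratch_file = os.path.join(tmp_dir, 'example.txt')  # hypothetical file name
    with open(scratch_file, 'w') as f:
        f.write('temporary contents')
    existed_inside = os.path.exists(scratch_file)

exists_after = os.path.exists(tmp_dir)
print(existed_inside, exists_after)  # True False
```

In a notebook, where the directory has to survive across several cells, `mkdtemp` plus an explicit cleanup at the end remains the more practical choice.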

Now, let's download the data:

In [10]:
# Import
import os

# Return the tail of the path
os.path.basename('http://google.com/attention.zip')

'attention.zip'

In [11]:
# Import urlretrieve
from urllib.request import urlretrieve

# Define URL
url = 'http://www.fil.ion.ucl.ac.uk/spm/download/data/attention/attention.zip'

# Retrieve the data
fn, info = urlretrieve(url, os.path.join(d, 'attention.zip'))

('/var/folders/25/dnnk9t55369c3s4g4_3zdq4h0000gp/T/tmpm1m62372/attention.zip',
 <http.client.HTTPMessage at 0x11f745e80>)

And extract it from the `zip` file to our temporary directory:

In [17]:
# Import zipfile
import zipfile

# Extract the contents
zipfile.ZipFile(fn).extractall(path=d)

If you look at the actual contents of the file, you'll find a bunch of `.hdr` and `.img` files.

In [22]:
# List first 10 files
[f.filename for f in zipfile.ZipFile(fn).filelist[:10]]

['attention/',
 'attention/multi_block_regressors.mat',
 'attention/README_DATA.txt',
 'attention/factors.mat',
 'attention/functional/',
 'attention/functional/snffM00587_0201.hdr',
 'attention/functional/snffM00587_0040.img',
 'attention/functional/snffM00587_0458.hdr',
 'attention/functional/snffM00587_0185.img',
 'attention/functional/snffM00587_0018.hdr']

In [23]:
# List last 10 files
[f.filename for f in zipfile.ZipFile(fn).filelist[-10:]]

['attention/functional/snffM00587_0048.img',
 'attention/functional/snffM00587_0437.hdr',
 'attention/functional/snffM00587_0058.hdr',
 'attention/functional/snffM00587_0394.img',
 'attention/functional/snffM00587_0104.img',
 'attention/multi_condition.mat',
 'attention/block_regressors.mat',
 'attention/structural/',
 'attention/structural/nsM00587_0002.img',
 'attention/structural/nsM00587_0002.hdr']
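The listing is unordered, but each image is split across a header file (`.hdr`) and a data file (`.img`) that share a base name. The sketch below checks that every header has its partner. It runs on a tiny in-memory archive with made-up names rather than on the real download, so treat it as an illustration only.

```python
import io
import os
import zipfile
from collections import defaultdict

# Build a tiny in-memory archive with made-up names that mimic the layout above.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, 'w') as zf:
    for name in ['demo/README.txt', 'demo/scan_0001.hdr', 'demo/scan_0001.img',
                 'demo/scan_0002.hdr', 'demo/scan_0002.img']:
        zf.writestr(name, b'')

# Group archive members by base name so each header is matched with its data file.
pairs = defaultdict(set)
with zipfile.ZipFile(buffer) as zf:
    for name in zf.namelist():
        stem, ext = os.path.splitext(name)
        if ext in ('.hdr', '.img'):
            pairs[stem].add(ext)

complete = sorted(stem for stem, exts in pairs.items() if exts == {'.hdr', '.img'})
print(complete)  # ['demo/scan_0001', 'demo/scan_0002']
```

Pointing `zipfile.ZipFile` at `fn` instead of `buffer` should run the same check on the downloaded archive.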

These are in the NIfTI file format, and we'll need a reader for them. Thankfully, the excellent `nibabel` library provides such a reader. Make sure you install it with either `conda install -c conda-forge nibabel` or `pip install nibabel`, and then:

In [24]:
import nibabel

Now, we can finally read our image, and use the `.get_fdata()` method to get a NumPy array to view:

In [25]:
# Read the image 
struct = nibabel.load(os.path.join(d, 'attention/structural/nsM00587_0002.hdr'))

# Get a floating-point NumPy array (older nibabel releases used get_data())
struct_arr = struct.get_fdata()

## ... Back to plotting

Let's now look at a slice in that array:

In [39]:
plt.imshow(struct_arr[75])


<matplotlib.image.AxesImage at 0x1247ba8d0>

Whoa! That looks pretty squishy! That's because the resolution along the vertical axis in many MRIs is not the same as along the horizontal axes. We can fix that by passing the `aspect` parameter to the `imshow` function:

In [40]:
plt.imshow(struct_arr[75], aspect=0.5)


<matplotlib.image.AxesImage at 0x1263df198>
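Rather than guessing the `aspect` value, it can be derived from the voxel sizes. The helper below is a hypothetical sketch (the name `slice_aspect` and the voxel sizes are made up). It relies on `imshow` treating a numeric `aspect` as pixel height divided by pixel width.

```python
def slice_aspect(zooms, row_axis, col_axis):
    """Aspect ratio (pixel height / pixel width) for a 2D slice.

    zooms: voxel size along each array axis, e.g. in millimetres.
    row_axis, col_axis: the volume axes shown as image rows and columns.
    """
    return zooms[row_axis] / zooms[col_axis]


# Hypothetical voxel sizes: 1 mm along the first two axes, 2 mm along the third.
example_zooms = (1.0, 1.0, 2.0)

# volume[75] keeps axes 1 and 2, which imshow draws as rows and columns.
aspect = slice_aspect(example_zooms, row_axis=1, col_axis=2)
print(aspect)  # 0.5
```

With nibabel, the voxel sizes of a loaded image are available from `struct.header.get_zooms()`. Whether they match the made-up values used here is something to check on the real file.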

But, to make things easier, we will just *transpose* the data and only look at the horizontal slices, which don't need such fiddling.

In [41]:
struct_arr2 = struct_arr.T

In [42]:
plt.imshow(struct_arr2[34])


<matplotlib.image.AxesImage at 0x1265eac18>

Pretty! Of course, to then view another slice, or a slice along a different axis, we need another call to `imshow`:

In [43]:
plt.imshow(struct_arr2[5])


<matplotlib.image.AxesImage at 0x1248b1c18>
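One stop-gap is to lay a handful of slices side by side in a single figure. The sketch below uses a random synthetic volume in place of `struct_arr2`. The `Agg` line is only there so that it runs without a display, and should be left out in the notebook.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

# Synthetic volume standing in for struct_arr2 (made-up shape).
rng = np.random.default_rng(0)
volume = rng.random((20, 32, 32))

# Pick a few evenly spaced slice indices along the first axis.
indices = np.linspace(0, volume.shape[0] - 1, num=4, dtype=int)

fig, axes = plt.subplots(nrows=1, ncols=len(indices), figsize=(8, 2))
for ax, idx in zip(axes, indices):
    ax.imshow(volume[idx], cmap='gray')
    ax.set_title(f'slice {idx}')
    ax.axis('off')
```

This gives an overview, but it still cannot step through every slice, which is what the rest of the tutorial addresses.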

All these calls get rather tedious rather quickly. For a long time, I would view 3D volumes using tools outside Python, such as ITK-SNAP. But, as it turns out, it's quite easy to add 3D "scrolling" capabilities to the matplotlib viewer! This lets us explore 3D data within Python, minimizing the need to switch contexts between data exploration and data analysis.

The key is to use the matplotlib [event handler API](http://matplotlib.org/users/event_handling.html), which lets us define actions to perform on the plot (including changing the plot's data!) in response to particular key presses or mouse button clicks.

In our case, let's bind the J and K keys on the keyboard to "previous slice" and "next slice":

In [38]:
# Go to the previous slice 
def previous_slice():
    pass

# Go to the next slice
def next_slice():
    pass

# Process keyboard presses 
def process_key(event):
    if event.key == 'j':
        previous_slice()
    elif event.key == 'k':
        next_slice()

Simple enough! Of course, we need to figure out how to actually implement these actions *and* we need to tell the figure that it should use the `process_key` function to process keyboard presses! The latter is simple: we just need to use the figure canvas method `mpl_connect`:

    fig, ax = plt.subplots()
    ax.imshow(struct_arr[..., 43])
    fig.canvas.mpl_connect('key_press_event', process_key)

You can find the full documentation for `mpl_connect` [here](http://matplotlib.org/users/event_handling.html), including what other kinds of events you can bind (such as mouse button clicks).
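To see the connection mechanism in isolation, here is a small sketch that registers a handler, feeds it synthetic key presses, and then disconnects it. The handler `record_key` is made up for the demonstration, and the events are sent through the canvas callback registry instead of a real keyboard so that the code also runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch; omit in the notebook
import matplotlib.pyplot as plt
from matplotlib.backend_bases import KeyEvent

pressed = []  # records the keys our handler has seen

def record_key(event):
    pressed.append(event.key)

fig, ax = plt.subplots()

# mpl_connect returns a connection id that can later be passed to mpl_disconnect.
cid = fig.canvas.mpl_connect('key_press_event', record_key)

# Simulate two key presses by sending events through the canvas callback registry.
for key in ('j', 'k'):
    fig.canvas.callbacks.process('key_press_event',
                                 KeyEvent('key_press_event', fig.canvas, key))

# After disconnecting, further events no longer reach the handler.
fig.canvas.mpl_disconnect(cid)
fig.canvas.callbacks.process('key_press_event',
                             KeyEvent('key_press_event', fig.canvas, 'j'))

print(pressed)  # ['j', 'k']
```

Keeping the connection id around is handy when a handler should only be active for part of a session.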

It took me just a bit of exploring to find out that `imshow` returns an `AxesImage` object, which lives "inside" the matplotlib `Axes` object where all the drawing takes place, in its `.images` attribute. And this object provides a convenient `set_array` method that swaps out the image data being displayed! So, all we need to do is:

- plot an arbitrary index, and store that index, maybe as an additional runtime attribute on the `Axes` object.
- provide functions `next_slice` and `previous_slice` that change the index and use `set_array` to set the corresponding slice of the 3D volume.
- use the figure canvas `draw` method to redraw the figure with the new data.
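The swap-and-redraw step from the list above can be tried on its own. This sketch uses a synthetic volume and `set_data`, the image-specific setter on `AxesImage`. It also fixes `vmin` and `vmax` up front, on the assumption that you want every slice drawn on the same colour scale, because swapping the data does not rescale the colours by itself.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

# Synthetic volume (made-up shape) whose slice k is filled with the value k.
volume = np.ones((5, 8, 8)) * np.arange(5)[:, None, None]

fig, ax = plt.subplots()
# Fix the colour limits to the whole volume so that every slice shares one scale.
image = ax.imshow(volume[0], vmin=volume.min(), vmax=volume.max())

# The same object is reachable later through the Axes.
same_object = ax.images[0] is image

# Swap in another slice and redraw.
image.set_data(volume[3])
fig.canvas.draw()

shown_max = float(image.get_array().max())
print(same_object, shown_max)  # True 3.0
```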

In [44]:
# Bind keyboard presses to Figure canvas
def multi_slice_viewer(volume):
    fig, ax = plt.subplots()
    ax.volume = volume
    ax.index = volume.shape[0] // 2
    ax.imshow(volume[ax.index])
    fig.canvas.mpl_connect('key_press_event', process_key)

    
# Process keyboard presses 
def process_key(event):
    fig = event.canvas.figure
    ax = fig.axes[0]
    if event.key == 'j':
        previous_slice(ax)
    elif event.key == 'k':
        next_slice(ax)
    fig.canvas.draw()

# Go to the previous slice
def previous_slice(ax):
    volume = ax.volume
    ax.index = (ax.index - 1) % volume.shape[0]  # wrap around using %
    ax.images[0].set_array(volume[ax.index])

# Go to the next slice
def next_slice(ax):
    volume = ax.volume
    ax.index = (ax.index + 1) % volume.shape[0]
    ax.images[0].set_array(volume[ax.index])
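The wrap-around works because Python's `%` returns a non-negative result when the right-hand side is positive, even if the left-hand side is negative. A tiny check, with a made-up slice count and a hypothetical helper `step`:

```python
def step(index, delta, n_slices):
    """Move `delta` slices from `index`, wrapping around at either end."""
    return (index + delta) % n_slices

n_slices = 5  # made-up number of slices

# Stepping back from the first slice wraps to the last one...
print(step(0, -1, n_slices))  # 4

# ...and stepping forward from the last slice wraps to the first.
print(step(4, +1, n_slices))  # 0
```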

Let's try it out!

In [45]:
multi_slice_viewer(struct_arr2)


This works! Nice! But, if you try this out at home, you'll notice that scrolling up with K also squishes the horizontal scale of the plot. Huh?

In [46]:
multi_slice_viewer(struct_arr2)


What's happening is that adding event handlers to Matplotlib simply piles them on top of each other. In this case, K is a built-in keyboard shortcut to change the x-axis to use a logarithmic scale. If we want to use K exclusively, we have to remove it from matplotlib's default key maps. These live as lists in the `plt.rcParams` dictionary, which is matplotlib's repository for default system-wide settings:

    plt.rcParams['keymap.<command>'] = ['<key1>', '<key2>']

where pressing any of the keys in the list (i.e. `<key1>` or `<key2>`) will cause `<command>` to be executed.
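A quick way to see which built-in commands already claim a key is to scan the `keymap.*` entries. The helper name `commands_bound_to` is made up for this sketch, and what it reports depends on your matplotlib version and configuration.

```python
import matplotlib

def commands_bound_to(key, rc=None):
    """Return the 'keymap.*' entries of `rc` whose key list contains `key`."""
    if rc is None:
        rc = matplotlib.rcParams  # the currently active settings
    return {name: list(rc[name])
            for name in rc
            if name.startswith('keymap.') and key in rc[name]}

# Inspect the current settings for the two keys we want to use.
for wanted in ('j', 'k'):
    print(wanted, commands_bound_to(wanted))
```

An empty dictionary for a key means nothing built in is listening to it, so it is free to use.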

Thus, we'll need to write a helper function to remove keys that we want to use wherever they may appear in this dictionary. (This function doesn't yet exist in matplotlib, but would probably be a welcome contribution!)

In [47]:
def remove_keymap_conflicts(new_keys_set):
    for prop in plt.rcParams:
        if prop.startswith('keymap.'):
            keys = plt.rcParams[prop]
            remove_list = set(keys) & new_keys_set
            for key in remove_list:
                keys.remove(key)
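The helper above edits the lists stored in `plt.rcParams` in place. If you would rather not modify shared lists, one variant builds new lists instead. The sketch below demonstrates it on a plain dictionary with illustrative entries, so nothing global is touched. The function name `without_conflicts` and the `keymap.demo` entry are hypothetical.

```python
def without_conflicts(keymaps, new_keys):
    """Return a copy of `keymaps` with every key in `new_keys` removed."""
    return {command: [key for key in keys if key not in new_keys]
            for command, keys in keymaps.items()}

# Illustrative table in the same shape as the 'keymap.*' entries of rcParams.
example_keymaps = {
    'keymap.xscale': ['k', 'L'],
    'keymap.save': ['s', 'ctrl+s'],
    'keymap.demo': ['j'],  # hypothetical entry
}

cleaned = without_conflicts(example_keymaps, {'j', 'k'})
print(cleaned)
# {'keymap.xscale': ['L'], 'keymap.save': ['s', 'ctrl+s'], 'keymap.demo': []}
```

The original dictionary is left unchanged, which makes the effect easy to inspect before applying anything to the real settings.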

## A fully functional slice viewer

Ok, let's rewrite our function to make use of this new tool:

In [48]:
# Tie keyboard presses to Figure 
def multi_slice_viewer(volume):
    remove_keymap_conflicts({'j', 'k'})
    fig, ax = plt.subplots()
    ax.volume = volume
    ax.index = volume.shape[0] // 2
    ax.imshow(volume[ax.index])
    fig.canvas.mpl_connect('key_press_event', process_key)

# Process keyboard presses
def process_key(event):
    fig = event.canvas.figure
    ax = fig.axes[0]
    if event.key == 'j':
        previous_slice(ax)
    elif event.key == 'k':
        next_slice(ax)
    fig.canvas.draw()

# Go to the previous slice
def previous_slice(ax):
    volume = ax.volume
    ax.index = (ax.index - 1) % volume.shape[0]  # wrap around using %
    ax.images[0].set_array(volume[ax.index])

# Go to the next slice
def next_slice(ax):
    volume = ax.volume
    ax.index = (ax.index + 1) % volume.shape[0]
    ax.images[0].set_array(volume[ax.index])

Now, we should be able to view all the slices in our MRI volume without pesky interference from the default keymap!

In [49]:
multi_slice_viewer(struct_arr2)


One nice feature about this method is that it works on *any* interactive matplotlib backend! So, if you try this out in the IPython terminal console, you will still get the same interaction as you did in the browser! And the same is true for a Qt or Tkinter app embedding a matplotlib plot. This simple tool therefore lets you build ever more complex applications around matplotlib's visualisation capabilities.
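Because the handlers only touch the generic canvas API, they can even be exercised without a GUI by sending synthetic key events. The sketch below is a condensed variant of the viewer, not the exact code above. It folds the two step functions into a hypothetical `step_slice`, returns the figure, skips the keymap clean-up, and uses random data with the `Agg` backend.

```python
import matplotlib
matplotlib.use('Agg')  # no window is opened, but callbacks still dispatch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backend_bases import KeyEvent

def step_slice(ax, delta):
    """Move `delta` slices through the volume stored on `ax`, wrapping around."""
    ax.index = (ax.index + delta) % ax.volume.shape[0]
    ax.images[0].set_array(ax.volume[ax.index])

def process_key(event):
    fig = event.canvas.figure
    ax = fig.axes[0]
    if event.key == 'j':
        step_slice(ax, -1)
    elif event.key == 'k':
        step_slice(ax, +1)
    fig.canvas.draw()

def multi_slice_viewer(volume):
    fig, ax = plt.subplots()
    ax.volume = volume
    ax.index = volume.shape[0] // 2
    ax.imshow(volume[ax.index])
    fig.canvas.mpl_connect('key_press_event', process_key)
    return fig

def press(fig, key):
    """Send a synthetic key press to the figure (helper for this sketch only)."""
    fig.canvas.callbacks.process('key_press_event',
                                 KeyEvent('key_press_event', fig.canvas, key))

volume = np.random.default_rng(0).random((10, 16, 16))  # synthetic data
fig = multi_slice_viewer(volume)
start = fig.axes[0].index
press(fig, 'k')
after_k = fig.axes[0].index
press(fig, 'j')
press(fig, 'j')
after_jj = fig.axes[0].index
print(start, after_k, after_jj)  # 5 6 4
```

A check like this is also a convenient way to test a viewer automatically, before trying it by hand in a GUI backend.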

## Before you go...

Let's not forget to clean up after ourselves, and delete the temporary directory (if you made one):

In [6]:
import shutil
shutil.rmtree(d)