# Plot correlation maps overview
This notebook plots an overview table of the correlation maps for several conditions and contrasts on a 2D grid (one row per condition, one column per contrast). Each panel can show the unthresholded map together with the contours of the corresponding thresholded (significant) map. All maps can be plotted over the same range of effect sizes so that they are directly comparable.

By Stephen Larroque from Coma Science Group, University of Liège, created on 2018-01-23.

Version v1.4.6

In [None]:
# Reload edited modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2
# BEWARE: autoreload works on functions and on general code, but NOT on new class methods:
# if you add or change the name of a method, you have to reload the kernel!
# also it will fail if you use super() calls in the classes you change
# ALSO AUTORELOAD SHOULD BE THE FIRST LINE EVER EXECUTED IN YOUR IPYTHON NOTEBOOK!!!

# Profilers:
# http://pynash.org/2013/03/06/timing-and-profiling/
# http://mortada.net/easily-profile-python-code-in-jupyter.html
# use %lprun -m module func(*args, **kwargs)
# The profiler extensions are optional: an ImportError raised while loading them
# is deliberately ignored so the notebook still runs without them (`exc` is unused).
try:
    %load_ext line_profiler
    %load_ext memory_profiler
except ImportError as exc:
    pass

In [None]:
# Render figures inside the IPython Notebook.
# Must be called before any import of matplotlib, direct or indirect, which is why
# this cell sits before the import cell.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as pltcol
import matplotlib as mpl
import numpy as np
import nibabel as nib
import os
from nilearn import image
from nilearn import plotting
from IPython import display

In [None]:
def replace_illegal_chars(s):
    """Return `s` with characters that are unsafe in filenames swapped for text.

    '>' becomes 'morethan', '<' becomes 'lessthan', '|' becomes 'or', and each
    space, slash or backslash becomes '-'.
    """
    substitutions = (
        ('>', 'morethan'),
        ('<', 'lessthan'),
        ('|', 'or'),
        (' ', '-'),
        ('/', '-'),
        ('\\', '-'),
    )
    for illegal, replacement in substitutions:
        s = s.replace(illegal, replacement)
    return s

def get_fig_filename(cond, contrast):
    """Get the path of the figure file for one condition/contrast pair.

    Only the '<cond>_<contrast>' part is sanitized with replace_illegal_chars().
    The 'figs/' folder prefix is added afterwards: sanitizing the whole path
    would turn the '/' separator into '-' and give 'figs-<name>.png' in the
    current directory instead of a file inside the figs/ folder.

    Returns a string of the form 'figs/<cond>_<contrast>.png'.
    """
    basename = replace_illegal_chars('%s_%s' % (cond, contrast))
    return 'figs/%s.png' % basename

def sort_it(d, sorter=None):
    """Iterate over a dictionary's items in a chosen order.

    d: the dictionary to walk through.
    sorter: optional list of keys of `d`, in the wanted order. If None, the
        items are returned sorted (by key first, as tuples are compared).

    Returns a sorted list of (key, value) tuples when `sorter` is None, else a
    generator of (key, value) tuples following `sorter`. A key of `sorter`
    missing from `d` raises KeyError when the generator reaches it.
    """
    if sorter is None:
        # items() exists on both Python 2 and 3, iteritems() only on Python 2
        return sorted(d.items())
    return ((k, d[k]) for k in sorter)

In [None]:
# Plot the seeds coordinates
# One figure per seed on an anatomical image (cut at the seed coordinates), plus
# one glass brain showing all the seeds. Figures are saved in the current folder.
coords = {
    "MPFC": (-1, 54, 27), # MPFC
    "PCC/Precuneus": (0, -52, 27), # PCC/precuneus
    }

# dpi_resolution is defined in the PARAMETERS cell further down: fall back on the
# same default so that this cell also works on a fresh kernel run from top to bottom.
try:
    dpi_resolution
except NameError:
    dpi_resolution = 300  # resolution to save images

# Plot seeds on anatomical image
# NOTE: absolute path to a local CONN toolbox install, adapt it to your machine.
for roi_name, coord in coords.items():
    fig_anat = plotting.plot_anat(r'C:\matlab_tools\conn17f\utils\surf\referenceT1_icbm.nii',
                                 cut_coords=coord, title="Seed placement of %s" % roi_name)
    # Coordinates of seed regions should be specified in first argument and second
    # argument `marker_color` denotes color of the sphere in this case yellow 'y'
    # and third argument `marker_size` denotes size of the sphere
    # All the seeds are marked on every figure, only the cut coordinates change.
    fig_anat.add_markers(list(coords.values()), marker_color=['y', 'g'], marker_size=100)
    fig_anat.savefig('seeds_anat_%s.png' % replace_illegal_chars(roi_name), dpi=dpi_resolution)

# Plot seeds on glass brain
fig_glass = plotting.plot_glass_brain(None, title="Seed placement of %s" % ' and '.join(coords.keys()))
# Convert to a list (as for the anatomical plots): a dict view is not a sequence of coordinates
fig_glass.add_markers(list(coords.values()), marker_color=['y', 'g'], marker_size=100)
fig_glass.savefig('seeds_glassbrain.png', dpi=dpi_resolution)

In [None]:
# PARAMETERS - EDIT ME
voxel_threshold = 0.0001 # minimum threshold to consider as a voxel and not just background noise (because background voxels can be 0.000001 for example), can be float or str ('1%' to give a percentage). TODO: autodetect minimum value (can be -4, 0.02, etc) as the background and use it as the threshold value.
# List of masks
# NOTE: the first image will be used as the template to resample other masks!
# list_cols: column headers of the final figure, one per contrast (their count sets the number of columns).
# list_imgs: one entry per condition/group (= one row of the figure), holding a list of contrasts
# in the same order as list_cols. Each contrast is a dict with two lists of image paths:
#   'maps': images drawn as colored overlays,
#   'contours': images of which only the contours are drawn (their values also count in the min/max ranges computed at loading).
list_cols = ["Average Conscious", "Average Unconscious", "Contrast Unconscious > Conscious", "Hyperconnectivity"]
list_imgs = {'Dex': [
                        {'maps': [r'..\unthresholded-contrasts\dex\spmT_0004.nii'],
                         'contours': [r'..\significant-contrasts\dex\dex-avgW1-pos.img', r'..\significant-contrasts\dex\dex-avgW1-neg.img'],
                        },
                        {'maps': [r'..\unthresholded-contrasts\dex\spmT_0003.nii'],
                         'contours': [r'..\significant-contrasts\dex\dex-avgS2-pos.img', r'..\significant-contrasts\dex\dex-avgS2-neg.img'],
                        },
                        {'maps': [r'..\unthresholded-contrasts\dex\spmT_0002.nii'],
                         'contours': [r'..\significant-contrasts\dex\dex-conS2morethanW1-pos.img', r'..\significant-contrasts\dex\dex-conS2morethanW1-neg.img'],
                        },
                        {'maps': [r'..\unthresholded\dex-dmn-hypercon-pos-unthresholded.nii'],
                         'contours': [r'..\significant\testconjunc-dex.nii'],
                        },
                    ],
            'Ketamine': [
                        {'maps': [r'..\unthresholded-contrasts\ket\spmT_0004.nii'],
                         'contours': [r'..\significant-contrasts\ket\ket-avgW1-pos.img', r'..\significant-contrasts\ket\ket-avgW1-neg.img'],
                        },
                        {'maps': [r'..\unthresholded-contrasts\ket\spmT_0003.nii'],
                         'contours': [r'..\significant-contrasts\ket\ket-avgS2-pos.img', r'..\significant-contrasts\ket\ket-avgS2-neg.img'],
                        },
                        {'maps': [r'..\unthresholded-contrasts\ket\spmT_0002.nii'],
                         'contours': [r'..\significant-contrasts\ket\ket-cons2morethanw1-pos.img', r'..\significant-contrasts\ket\ket-cons2morethanw1-neg.img'],
                        },
                        {'maps': [r'..\unthresholded\ket-dmn-hypercon-pos-unthresholded.nii'],
                         'contours': [r'..\significant\testconjunc-ket.nii'],
                        },
                    ],
            'MCS-': [
                        # NOTE(review): these contours come from the 'emcs' folder while every other MCS- entry uses 'mcs_minus' - confirm this is intended.
                        {'maps': [r'..\unthresholded-contrasts\mcs_minus\spmT_0004.nii'],
                         'contours': [r'..\significant-contrasts\emcs\emcs-avgCTR-pos.img', r'..\significant-contrasts\emcs\emcs-avgCTR-neg.img'],
                        },
                        {'maps': [r'..\unthresholded-contrasts\mcs_minus\spmT_0003.nii'],
                         'contours': [r'..\significant-contrasts\mcs_minus\mcsminus-avgmcsminus-pos.img'],
                        },
                        {'maps': [r'..\unthresholded-contrasts\mcs_minus\spmT_0002.nii'],
                         'contours': [r'..\significant-contrasts\mcs_minus\mcsminus-conmcsminusmorethanctr-pos.img', r'..\significant-contrasts\mcs_minus\mcsminus-conmcsminusmorethanctr-neg.img'],
                        },
                        {'maps': [r'..\unthresholded\doc-mcsminus-hyperconnectivity-pos-unthresholded.nii'],
                         'contours': [r'..\significant\testconjunc-mcsminus.nii'],
                        },
                    ],
            }
rows_order = ['Dex', 'Ketamine', 'MCS-']  # list of list_imgs keys to print rows in the specified order. If set to None, will use alphabetical order.
dpi_resolution = 300  # resolution to save images

In [None]:
# Load masks and resample to first
# Builds `imgs` with the same layout as `list_imgs` (condition -> list of contrasts ->
# {'maps': [...], 'contours': [...]}), but holding loaded, thresholded images instead of paths.
# Also tracks the extreme voxel values over all images and per condition. All the
# ranges start at 0, so they always include 0.
imgs = {}
firstimg = None  # first loaded image, used as the resampling template for the others
countload = 0
allmax = 0
allmin = 0
allmax_sourceimpath = None
allmin_sourceimpath = None
percond_max = {k:0 for k in list_imgs.keys()}
percond_min = {k:0 for k in list_imgs.keys()}
for cond, contrasts in list_imgs.items():
    imgs[cond] = []
    for contrast in contrasts:
        imgs[cond].append({})
        for imtype, imageset in contrast.items():
            imgs[cond][-1][imtype] = []
            for imagepath in imageset:
                # Load image data
                im = image.load_img(imagepath)
                # Resample to first image (only when the shapes differ)
                if firstimg is None:
                    firstimg = im
                elif im.shape != firstimg.shape:
                    im = image.resample_to_img(im, firstimg)
                # Remove background noise voxels
                im = image.threshold_img(im, voxel_threshold)
                # Get the minimal and maximum range over all provided images, so as to plot on a comparable range
                # np.asanyarray(img.dataobj) is nibabel's documented equivalent of the deprecated img.get_data()
                imdata = np.asanyarray(im.dataobj)
                curmax = imdata.max()
                curmin = imdata.min()
                if curmax > allmax:
                    allmax = curmax
                    allmax_sourceimpath = imagepath
                if curmin < allmin:
                    allmin = curmin
                    allmin_sourceimpath = imagepath
                if curmax > percond_max[cond]:
                    percond_max[cond] = curmax
                if curmin < percond_min[cond]:
                    percond_min[cond] = curmin
                # Save image data in our list
                imgs[cond][-1][imtype].append(im)
                countload += 1

# Display some infos on loaded images
print('All %i images successfully loaded!' % countload)
print('Overall max and min ranges found for effect size: (%g, %g)' % (allmin, allmax))
print('The overall max comes from the image file: %s' % allmax_sourceimpath)
print('The overall min comes from the image file: %s' % allmin_sourceimpath)
print('Per condition max:')
print(percond_max)
print('Per condition min:')
print(percond_min)

In [None]:
from numpy import ma
from matplotlib import cbook as mplcbook
from matplotlib import colors as mplcolors
# set the colormap and centre the colorbar
class MidpointNormalize(mplcolors.Normalize):
    """
    Piecewise-linear normalization around a midpoint, without centering the range.

    vmin is mapped to 0, midpoint to 0.5 and vmax to 1, with linear interpolation
    in between, so each side of a diverging colormap spans its own data range.

    e.g. im=ax1.imshow(array, norm=MidpointNormalize(midpoint=0.,vmin=-100, vmax=100))
    Courtesy of Joe Kington: https://matplotlib.org/gallery/userdemo/colormap_normalizations_custom.html and http://chris35wills.github.io/matplotlib_diverging_colorbar/
    """
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        mplcolors.Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Simple example implementation: `clip` is ignored and the only masked
        # entries of the result are the NaN inputs.
        anchors = [self.vmin, self.midpoint, self.vmax]
        targets = [0, 0.5, 1]
        interpolated = np.interp(value, anchors, targets)
        nan_mask = np.isnan(value)
        return np.ma.masked_array(interpolated, nan_mask)

class OffsetNorm(mplcolors.Normalize):
    """Normalize to midpoint but do not center. Accept multiple control points (not just midpoint).

    The sorted, de-duplicated control points are mapped to evenly spaced values
    between 0 and 1, with linear interpolation in between.
    """
    def __init__(self, control_points, clip=False):
        """
        control_points: iterable of numbers, in any order (duplicates are dropped).
        clip: stored through the base class. Values outside the control points
            are always clamped to 0 or 1 by the interpolation, whatever `clip` is.
        """
        self._control_points = np.unique(sorted(control_points)).astype(float)
        # Go through the base initializer rather than setting vmin/vmax/clip by hand,
        # so that any extra state the base class sets up is initialized too.
        mplcolors.Normalize.__init__(self,
                                     vmin=self._control_points.min(),
                                     vmax=self._control_points.max(),
                                     clip=clip)

    def __call__(self, value, clip=None):
        """Map `value` (scalar or array-like) to [0, 1]; masked inputs stay masked."""
        # process_value() homogenizes the input into a masked array and tells us
        # whether a scalar was given
        result, is_scalar = self.process_value(value)

        n_points = self._control_points.shape[0]
        targets = np.linspace(0, 1, num=n_points, endpoint=True)
        mapped = np.interp(np.ma.getdata(result), self._control_points, targets)
        normalized = np.ma.masked_array(mapped, mask=np.ma.getmask(result))
        if is_scalar:
            normalized = normalized[0]
        return normalized

class MidPointNorm(mplcolors.Normalize):
    """Normalize and center to 0

    Values are mapped to [0, 1] with `midpoint` always landing on 0.5: the
    [vmin, midpoint] range is scaled onto [0, 0.5] and the [midpoint, vmax]
    range onto [0.5, 1], each side with its own scale.
    """
    # Courtesy of Tillsten: https://stackoverflow.com/a/7746125/1121352
    # Future of colorbar normalization: https://github.com/matplotlib/matplotlib/pull/7294
    def __init__(self, midpoint=0, vmin=None, vmax=None, clip=False):
        # Only mplcolors is imported in this notebook: the bare name Normalize is not defined
        mplcolors.Normalize.__init__(self, vmin, vmax, clip)
        self.midpoint = midpoint

    def __call__(self, value, clip=None):
        """Normalize `value` (scalar or array-like).

        Raises ValueError if vmin > vmax, or if midpoint is not strictly between
        vmin and vmax. If vmin == vmax, everything is mapped to 0.
        """
        if clip is None:
            clip = self.clip

        result, is_scalar = self.process_value(value)

        self.autoscale_None(result)
        vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint

        # The degenerate ranges are tested first: tested after the midpoint check,
        # they could never be reached.
        if vmin == vmax:
            result.fill(0) # Or should it be all masked? Or 0.5?
        elif vmin > vmax:
            raise ValueError("maxvalue must be bigger than minvalue")
        elif not (vmin < midpoint < vmax):
            raise ValueError("midpoint must be between maxvalue and minvalue.")
        else:
            vmin = float(vmin)
            vmax = float(vmax)
            if clip:
                mask = np.ma.getmask(result)
                result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
                                     mask=mask)

            # ma division is very slow; we can take a shortcut
            resdat = result.data

            # First scale to the -1 to 1 range, then to the 0 to 1 range.
            resdat -= midpoint
            resdat[resdat>0] /= abs(vmax - midpoint)
            resdat[resdat<0] /= abs(vmin - midpoint)

            resdat /= 2.
            resdat += 0.5
            result = np.ma.array(resdat, mask=result.mask, copy=False)

        if is_scalar:
            result = result[0]
        return result

    def inverse(self, value):
        """Map normalized value(s) in [0, 1] back to data values.

        Raises ValueError if vmin or vmax is not set yet.
        """
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint

        if np.iterable(value):
            val = np.ma.asarray(value)
            val = 2 * (val-0.5)  # new float array: the caller's data is not modified
            val[val>0] *= abs(vmax - midpoint)
            val[val<0] *= abs(vmin - midpoint)
            val += midpoint
            return val
        else:
            # Scalar case: start from `value` (`val` is not assigned yet at this point)
            val = 2 * (value - 0.5)
            if val < 0:
                return val*abs(vmin-midpoint) + midpoint
            else:
                return val*abs(vmax-midpoint) + midpoint

In [None]:
# Definition of custom matplotlib colormaps
# https://matplotlib.org/examples/pylab_examples/custom_cmap.html
# Each dictionary is a matplotlib segment data definition: for every channel, a sequence of
# (x, y0, y1) rows with x going from 0 to 1, y0 the channel value just below x and y1 the
# value just above x (values are linearly interpolated between rows).
# All maps below are meant to be used with a norm putting the midpoint (0) at x = 0.5.

# cdict: blue -> white -> red, fully transparent for x in [0.45, 0.55]
cdict = {'red':   ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0)),

         'green': ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.4, 0.9, 0.9),
                   (0.5, 1.0, 1.0),
                   (0.6, 0.9, 0.9),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),

         'blue':  ((0.0, 1.0, 1.0),
                   (0.25,1.0, 1.0),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),
         'alpha': ((0.0, 1.0, 1.0),
                   (0.25,1.0, 1.0),
                   (0.45,1.0, 0.0),
                   (0.5, 0.0, 0.0),
                   (0.55,0.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0)),
        }
# cdict2: same blue -> white -> red gradient with a simpler green channel, transparent only for x in [0.48, 0.52]
cdict2 = {'red':   ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0)),

         'green': ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),

         'blue':  ((0.0, 1.0, 1.0),
                   (0.25,1.0, 1.0),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),
         'alpha': ((0.0, 1.0, 1.0),
                   (0.25,1.0, 1.0),
                   (0.48,1.0, 0.0),
                   (0.5, 0.0, 0.0),
                   (0.52,0.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0)),
        }
# cdict3: same colors as cdict2, transparent only in the tiny [0.4999, 0.5] interval
cdict3 = {'red':   ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0)),

         'green': ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),

         'blue':  ((0.0, 1.0, 1.0),
                   (0.25,1.0, 1.0),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),
         'alpha': ((0.0, 1.0, 1.0),
                   (0.4999,1.0, 0.0),
                   (0.5, 0.0, 1.0),
                   (1.0, 1.0, 1.0)),
        }
# cdict4: like cdict3 but green starts at 0.3 (instead of 0) and rises up to the midpoint on the negative side
cdict4 = {'red':   ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (1.0, 1.0, 1.0)),

         'green': ((0.0, 0.3, 0.3),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),

         'blue':  ((0.0, 1.0, 1.0),
                   (0.5, 1.0, 1.0),
                   (0.75,0.0, 0.0),
                   (1.0, 0.0, 0.0)),
         'alpha': ((0.0, 1.0, 1.0),
                   (0.4999,1.0, 0.0),
                   (0.5, 0.0, 1.0),
                   (1.0, 1.0, 1.0)),
        }
# cdict_posonly: no blue channel at all, and fully transparent for x <= 0.55 (so the lower half is never drawn)
cdict_posonly = {'red':  
                  ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.5, 1.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 0.5, 0.5)),

         'green': ((0.0, 0.0, 0.0),
                   (0.2,0.0, 0.0),
                   (0.4, 0.9, 0.9),
                   (0.5, 1.0, 1.0),
                   (0.6, 0.9, 0.9),
                   (0.80,0.0, 0.0),
                   (1.0, 0.0, 0.0)),

         'blue':  ((0.0, 0.0, 0.0),
                   (1.0, 0.0, 0.0)),
         'alpha': ((0.0, 0.0, 0.0),
                   (0.25,0.0, 0.0),
                   (0.45,0.0, 0.0),
                   (0.5, 0.0, 0.0),
                   (0.55,0.0, 1.0),
                   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0)),
        }
# Register the custom colormaps under a name, so that plt.get_cmap(name) can find them.
# The colormap objects are built explicitly and registered through matplotlib.colormaps
# when it exists: plt.register_cmap() (and its `data` argument) is not available in
# recent matplotlib releases. force=True lets this cell be re-run without a name clash.
for _cmap_name, _cmap_data in (
        ('BrainBlueRed', cdict),
        ('BrainBlueRedNarrowWhite', cdict2),
        ('BrainBlueRedNoWhite', cdict3),  # almost no white except at 0 value to keep background white
        ('BrainCyanRedNoWhite', cdict4),
        ('BrainRed', cdict_posonly),
        ):
    # Same number of color levels as plt.register_cmap(data=...) used (the image.lut setting)
    _cmap = pltcol.LinearSegmentedColormap(_cmap_name, _cmap_data, mpl.rcParams.get('image.lut', 256))
    if hasattr(mpl, 'colormaps'):
        mpl.colormaps.register(_cmap, name=_cmap_name, force=True)
    else:
        # older matplotlib without the colormap registry
        plt.register_cmap(name=_cmap_name, cmap=_cmap)

In [None]:
# Plot!
# Draws one glass brain per (condition, contrast) on a grid with one row per condition
# and one column per contrast, adds the maps as overlays and the contours on top, then
# saves the whole grid (and, when the range is global, a separate horizontal colorbar)
# in the figs/ folder.
# Requires the parameters, loading, normalizers and colormaps cells to have been run.
from nilearn import plotting
contourfill = False  # fill the contours?
show_maps = True  # hide maps to show only contours?
resize_pct = 1.0  # force resize brain images?
colorbar_percond = False  # if False, synchronize the same colorbar range for all plots, else if True, synchronize only per condition (per row), if None, do not synchronize colorbar at all (each plot will have its own colorbar range)
colorbar_labelsize = 30
colorbar_move = 0.1  # percentage of current x position (for one figure) to shift the colorbar on the left
colorbar_always_include_zero = True  # ensure the colorbar always include the 0 value
colorbar_map = plt.get_cmap('BrainCyanRedNoWhite')  # use plt.get_cmap('BrainRed') or plotting.cm.cold_hot_r for positive correlation only and plt.get_cmap('BrainBlueRed') for divergent blue-red gradient for both positive and negative correlation maps
colorbar_normalizer = MidPointNorm  # what function to use to normalize the colorbar? Use MidpointNormalize to normalize but not center to 0, or MidPointNorm to normalize and center to 0.
contour_allcolors = ['#081e7c', '#7c0808']  # define the contours colors as a list with 1st element being negative and 2nd for positive values. Can use anything matplotlib supports, such as standard colors (eg, 'r', 'b') or  html codes (eg, '#081e7c' for dark blue and '#7c0808' for dark red).
contour_markersize = 4.

if not os.path.isdir('figs'):
    os.mkdir('figs')

# Prepare for subplotting
total_figures = sum(len(x) for x in imgs.values())  # total number of figures we will plot (informative only)
total_rows = len(imgs.keys())
total_cols = len(list_cols)
# Use subplot
# squeeze=False: axes is always a 2D array indexed as axes[row][col], even with a single
# row and/or a single column, so no special case is needed below.
fig_root, axes = plt.subplots(total_rows, total_cols, figsize=(15*total_cols,8*total_rows), frameon=False, squeeze=False)
plt.subplots_adjust(hspace=0, wspace=0)  # remove spacing between subplots (= bigger figures!)

# Main plotting routine
fig_row = -1
for cond, contrasts in sort_it(imgs, rows_order):
    # For each condition/group
    fig_row += 1
    fig_col = -1
    # Get colorbar range to use
    if colorbar_percond is False:
        # Synchronize all colorbar ranges of all plots
        cmax = allmax
        cmin = allmin
    elif colorbar_percond is True:
        # Synchronize the colorbars only per condition (per row)
        cmax = percond_max[cond]
        cmin = percond_min[cond]
    elif colorbar_percond is None:
        # Do not do any colorbar synchronization (range decided per image by nilearn)
        cmax = cmin = None
    for con_id, contrast in enumerate(contrasts):
        # For each contrast
        fig_col += 1
        # Get current subplot axes
        cur_ax = axes[fig_row][fig_col]
        # Create a background glass brain plot (with no map, we will add it later)
        fig = plotting.plot_glass_brain(None, plot_abs=False, colorbar=False, symmetric_cbar=False, annotate=False, alpha=0.2, threshold=None, axes=cur_ax, figure=fig_root)
        for imtype, imageset in contrast.items():
            # For each map of this contrast, we add to the current figure!
            for im in imageset:
                # For each map of this type of imageset (each contrast can have two types: maps or contours, and then each can have several maps, usually one for negative and one for positive correlations)
                # np.asanyarray(img.dataobj) is nibabel's documented equivalent of the deprecated img.get_data()
                imdata = np.asanyarray(im.dataobj)
                if show_maps and imtype == 'maps':
                    # Plot the map as an overlay (ie, all activated voxels will be shown along with a color to reflect the intensity)
                    if colorbar_always_include_zero:
                        # Sanity check on colorbar range: min cannot be positive and max cannot be negative (else it will still work but the background might be negatively or positively colored instead of blank)
                        if colorbar_percond is None:
                            cmax = imdata.max()
                            cmin = imdata.min()
                        if cmin > 0:
                            cmin = 0
                        if cmax < 0:
                            cmax = 0
                    fig.add_overlay(im, colorbar=False if colorbar_percond is not None else True, vmax=cmax, vmin=cmin, cmap=colorbar_map, norm=colorbar_normalizer(midpoint=0.))
                    if colorbar_percond is None:
                        # If individual colorbar range per plot, then at least increase size of colorbar labels for readability
                        fig._colorbar_ax.tick_params(labelsize=colorbar_labelsize)
                elif imtype == 'contours':
                    # Plot only the contours of the provided map (great to just highlight significant regions)
                    # First we select the appropriate color
                    imdata_min = imdata.min()
                    imdata_max = imdata.max()
                    if imdata_min < 0 and imdata_max == 0:
                        # Only negative map, we use blue and we need to convert from negative to positive values (else nilearn won't plot)
                        contour_color = [contour_allcolors[0]]
                        # Compute absolute values, else nilearn cannot plot contours + convert back to nifti else cannot use as contours
                        im_abs = nib.Nifti1Image(np.absolute(imdata), affine=im.affine)
                    elif imdata_max > 0 and imdata_min == 0:
                        # Positive only map
                        contour_color = [contour_allcolors[1]]
                        im_abs = im
                    else:
                        # Both positive and negative map (or null map), we plot both
                        contour_color = contour_allcolors
                        im_abs = im
                    fig.add_contours(im_abs, levels=[-0.1, 0.1], colors=contour_color, linewidths=contour_markersize, auto_resize=False, filled=contourfill)
        # Resize the figure (once per contrast: inside the loop on image types, the
        # scaling would be applied once per type, ie, several times on the same plot)
        if resize_pct != 1.0:
            rect = fig.rect
            fig.rect = ([x*resize_pct for x in rect])
        # Refresh display after each contrast (IPython.display)
        display.clear_output(wait=True)
        display.display(plt.gcf())

    # Plot the colorbar after the last plot of the row (or the last plot of the last row if colorbar_percond is false)
    # The last row is detected from its index, which also works if the rows do not all have the same number of contrasts.
    if colorbar_percond is True or (colorbar_percond is False and fig_row == total_rows - 1):
        # Plot the colorbar at the last plot of the row
        fig._show_colorbar(colorbar_map, colorbar_normalizer(midpoint=0., vmin=cmin, vmax=cmax))
        box = fig._colorbar_ax.get_position()
        fig._colorbar_ax.set_position([box.x0+(box.x0*colorbar_move/total_cols), box.y0, box.width, box.height])  # move a bit the bar to the right, need to divide by number of columns (to move relative to last figure only, not to overall row, else will get too far away)
        fig._colorbar_ax.tick_params(labelsize=colorbar_labelsize)

# Add row and columns headers
for ax, col in zip(axes[0], list_cols):
    # Columns headers (contrasts names), on the first row of axes
    ax.set_title(col, size='50')

for ax, row in zip(axes[:, 0], (k for k, v in sort_it(imgs, rows_order))):
    # Rows headers (condition/group names), on the first column of axes
    ax.text(-0.1, 0.5, row, rotation=90,
        verticalalignment='center', horizontalalignment='left',
        transform=ax.transAxes,
        color='blue', fontsize=50)

# Save figure
#fig.savefig(get_fig_filename(cond, list_cols[con_id]), bbox_inches='tight', pad_inches=0, dpi=dpi_resolution)  # save each figure separately (need to disable subplot)
fig_root.savefig('figs/whole_figure.png', bbox_inches='tight', pad_inches=0, dpi=dpi_resolution)  # save all figures at once in one picture (needs subplot enabled)
# Save the colorbar separately (did not find another way of doing it nicely on the same picture, no way to reorient to horizontal...)
if colorbar_percond is False:
    a = np.array([[0,1]])  # dummy image, only there to give pyplot something to attach a colorbar to
    fig_cbar = plt.figure(figsize=(4, 0.25))
    img = plt.imshow(a, cmap=colorbar_map, norm=colorbar_normalizer(midpoint=0., vmin=allmin, vmax=allmax))
    plt.gca().set_visible(False)
    cax = plt.axes([0.1, 0.2, 0.8, 0.6])
    plt.colorbar(orientation="horizontal", cax=cax)
    fig_cbar.savefig('figs/whole_figure_colorbar.png', bbox_inches='tight', pad_inches=0, dpi=dpi_resolution*2)
# Print confirmation message
print('Figure(s) saved in figs/ folder.')

# Show the final figure with the headers
plt.show()

In [None]:
# Only save an alternative of the global colorbar (useful to quickly tune the plotting parameters)
# Reads allmin/allmax (loading cell), colorbar_map/colorbar_normalizer (plotting cell) and
# dpi_resolution (parameters cell), and writes into the figs/ folder, which the plotting cell creates.
a = np.array([[0,1]])  # dummy image, only there to give pyplot something to attach a colorbar to
fig_cbar = plt.figure(figsize=(4, 0.25))
img = plt.imshow(a, cmap=colorbar_map, norm=colorbar_normalizer(midpoint=0., vmin=allmin, vmax=allmax))
plt.gca().set_visible(False)  # hide the dummy image, only the colorbar axes stays visible
cax = plt.axes([0.1, 0.2, 1, 1])  # new axes receiving the colorbar
figc = plt.colorbar(orientation="horizontal", cax=cax)
#figc.set_ticks(list(figc.get_ticks()) + [allmin, allmax])  # add the minimum and maximum values
figc.set_ticks([allmin, 0, allmax])  # show only the minimum, 0 (midpoint) and maximum values
plt.xticks(fontsize=15)
plt.show()
fig_cbar.savefig('figs/whole_figure_colorbar2.png', bbox_inches='tight', pad_inches=0, dpi=dpi_resolution*2)

In [None]:
# TEST
# Scratch cell: empty glass brain with only a colorbar, to try out the colorbar rendering.
# Unlike the main plot, it always uses MidpointNormalize (not colorbar_normalizer).
# Relies on nilearn private attributes (_show_colorbar, _colorbar_ax), which may change between nilearn versions.
fig = plotting.plot_glass_brain(None, plot_abs=False, colorbar=False, symmetric_cbar=False, annotate=False, alpha=0.2, threshold=None)
fig._show_colorbar(colorbar_map, MidpointNormalize(midpoint=0., vmin=allmin, vmax=allmax))
box = fig._colorbar_ax.get_position()
fig._colorbar_ax.set_position([box.x0*1.1, box.y0, box.width, box.height])  # shift the colorbar to the right by 10% of its x position
#cb1 = mpl.colorbar.ColorbarBase(plt.gcf().get_axes()[0], cmap=colorbar_map,
#                                norm=MidpointNormalize(midpoint=0., vmin=allmin, vmax=allmax),
#                                orientation='horizontal')
#cb1.set_label('Some Units')
#cb1.draw_all()
plt.show()