In [2]:
import os
import itertools
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.gridspec as gs
from matplotlib import colors as mc
from nilearn import plotting as nlp
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

In [1]:
def make_boxes(mat, cl_def, pad=1, edge=False):
    """
    Spread a square matrix apart by inserting empty lines at the cluster
    borders and build overlays that outline every cluster.

    mat:    the square 2D matrix you want to do stuff to
    cl_def: a list of tuples where the first position is the
            index of the first element in the cluster. the
            second position is the index of the last element
            in the cluster (inclusive). The order of the list does
            not matter, clusters are sorted by their first index.
    pad:    an integer value for the number of zero spaces to add
            around clusters
    edge:   boolean argument. If True, clusters at the corners will
            be drawn full. If False, clusters will be only drawn on
            the inside edge (no white line around the matrix).

    returns:
    omat:   the input matrix with the spaces added
    cmat_m: the overlayed cluster boxes in a masked array. Box lines
            hold the 1-based rank of the cluster (after sorting),
            everything else is masked
    lmat_m: a mask of the added empty spaces (the original data
            positions are masked, the added lines are not)
    ind:    the new index positions for the data (for x_ticks...)

    raises:
    ValueError if mat is not a square 2D matrix
    """
    mat = np.asarray(mat)
    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        raise ValueError('mat must be a square 2D matrix, got shape '
                         '{}'.format(mat.shape))
    n_el = mat.shape[0]

    # Sort the cluster definitions based on the start point
    order = np.argsort([c[0] for c in cl_def])
    cl_def = [(int(c[0]), int(c[1])) for c in np.array(cl_def)[order]]

    # Extract the positions in front of which empty lines get inserted.
    # Without edge painting, the outer border of the matrix gets no line.
    if edge:
        starts = [first for first, _ in cl_def]
        stops = [last + 1 for _, last in cl_def]
    else:
        starts = [first for first, _ in cl_def if first != 0]
        stops = [last + 1 for _, last in cl_def if last + 1 < n_el]

    # Find the (unique) breakpoints
    bkp = set(starts + stops)
    n_bkp = len(bkp)

    # Convert to new indices: every breakpoint passed shifts the
    # following elements by another pad positions
    shift = 0
    ind = list()
    for pos in range(n_el):
        if pos in bkp:
            shift += pad
        ind.append(pos + shift)

    # Create the output matrices
    omat = np.zeros([dim + n_bkp * pad for dim in mat.shape])
    cmat = np.zeros_like(omat)
    lmat = np.zeros_like(omat, dtype=bool)

    # Assign input mat to the open grid spanned by the new indices
    grid = np.ix_(ind, ind)
    omat[grid] = mat
    # Mask grid index for the line mask
    lmat[grid] = True
    lmat_m = np.ma.masked_where(lmat, lmat)

    # Loop through the clusters and paint their outlines
    size = omat.shape[0]
    for num, (first, last) in enumerate(cl_def, start=1):
        # Convert the input based borders to the new index
        start = ind[first] - pad
        stop = ind[last] + 1
        # Select the range of rows and columns to paint
        start_ind = np.arange(start, start + pad)
        stop_ind = np.arange(stop, stop + pad)

        if edge or start > 0:
            # Draw the top left corner
            cmat[start_ind, start:stop] = num
            cmat[start:stop, start_ind] = num
            box_from = start
        else:
            # Cluster touches the top left edge and we don't paint it:
            # the bottom right corner then runs from the matrix border
            box_from = 0

        # The bottom right corner is skipped for a cluster on the bottom
        # right edge when edges are not painted. This also covers a cluster
        # touching both edges at once, for which there is no line to paint
        # and no room in the output matrix to paint it.
        if edge or stop < size:
            cmat[stop_ind, box_from:stop + pad] = num
            cmat[box_from:stop + pad, stop_ind] = num

    # Mask the cluster matrix
    cmat_m = np.ma.masked_where(cmat == 0, cmat)
    return omat, cmat_m, lmat_m, ind

In [None]:
# Visualize: stack the matrix layers in one axes and save the result as png
# NOTE(review): o7, l7, c7, lin7 and fig_p are not defined in this cell and
# have to come from earlier cells. o7 / c7 / l7 look like the omat / cmat_m /
# lmat_m outputs of make_boxes - confirm.
low, high = 0, -1
# With high = -1 the window leaves out the last row and the last column
window = (slice(low, high), slice(low, high))

f = plt.figure(figsize=(15, 15), frameon=False)
ax = f.add_subplot(111)

# Layers are drawn in this order, so later layers cover earlier ones
layers = [
    (o7, dict(vmin=0, vmax=0.8, cmap=plt.cm.viridis, aspect='auto')),
    (l7, dict(cmap=plt.cm.Greys_r, aspect='auto', alpha=1)),
    (l7, dict(cmap=plt.cm.Greys, aspect='auto', alpha=1)),
    (c7, dict(cmap=lin7, vmin=1, vmax=7, aspect='auto')),
]
for layer, options in layers:
    ab = ax.matshow(layer[window], **options)

# Remove ticks and the axes frame so only the image is left
for clear_ticks in (ax.set_xticks, ax.set_yticks):
    ab = clear_ticks([])
ax.set_axis_off()

out_path = os.path.join(fig_p, 's7_full.png')
f.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0)