# Figure: bibiplots

In [35]:
%matplotlib notebook

import numpy as np
import pylab as plt
import seaborn as sns; sns.set()
import pickle
import matplotlib

import sparseRRR

def sns_styleset():
    """Switch to seaborn's 'paper' context and 'ticks' style, then override
    selected matplotlib rcParams (line widths, tick sizes, font sizes, dpi)."""
    sns.set_context('paper')
    sns.set_style('ticks')

    # Applied after the seaborn calls so these values take precedence.
    overrides = [
        # line and tick widths
        ('axes.linewidth',    .5),
        ('xtick.major.width', .5),
        ('ytick.major.width', .5),
        # tick lengths
        ('xtick.major.size', 2),
        ('ytick.major.size', 2),
        ('xtick.minor.size', 1),
        ('ytick.minor.size', 1),
        # font sizes
        ('font.size',       6),
        ('axes.titlesize',  7),
        ('axes.labelsize',  6),
        ('legend.fontsize', 6),
        ('xtick.labelsize', 6),
        ('ytick.labelsize', 6),
        # only affects the notebook
        ('figure.dpi', 120),
    ]
    for key, value in overrides:
        matplotlib.rcParams[key] = value

sns_styleset()

In [5]:
def preprocess(data):
    """Build standardized expression (X) and electrophysiology (Y) matrices.

    Parameters
    ----------
    data : dict
        Needs the keys ``'counts'`` (2-D count matrix, one row per sample;
        may be a dense ndarray, an ``np.matrix`` or a scipy sparse matrix),
        ``'mostVariableGenes'`` (column selector into ``counts``) and
        ``'ephys'`` (2-D array, one row per sample).

    Returns
    -------
    (X, Y) : tuple
        X: selected count columns divided by the row total over *all*
        columns, scaled to counts per million, log2(x + 1)-transformed and
        then z-scored column-wise. Y: ``data['ephys']`` z-scored column-wise.

    Notes
    -----
    A column with zero variance (or a row whose counts sum to zero) produces
    NaN/inf because of the divisions; inputs are not checked for this.
    """
    counts = data['counts']

    selected = counts[:, data['mostVariableGenes']]
    if hasattr(selected, 'toarray'):
        # scipy sparse slice: densify explicitly instead of relying on what
        # sparse-by-dense division happens to return
        selected = selected.toarray()
    selected = np.asarray(selected)

    # Row totals as an (n, 1) column so the division broadcasts per row for
    # every input type; a dense ndarray would otherwise give shape (n,).
    depth = np.asarray(np.sum(counts, axis=1)).reshape(-1, 1)

    X = selected / depth * 1e+6
    X = np.log2(X + 1)
    X = X - np.mean(X, axis=0)
    X = X / np.std(X, axis=0)

    Y = data['ephys']
    Y = Y - np.mean(Y, axis=0)
    Y = Y / np.std(Y, axis=0)

    return (X,Y)

In [145]:
def adjustlabels(fig, labels, max_iter=1000, eps=0.01, delta=0.1):
    """Iteratively push overlapping text labels apart, in place.

    Each label's bounding box is measured once (in the data coordinates of
    the axes that owns the label) and treated as a fixed-size rectangle
    around the label centre. While any two rectangles are closer than their
    combined half-extents plus ``delta`` in both x and y, the two centres are
    moved apart along the line joining them.

    Parameters
    ----------
    fig : figure whose ``canvas.get_renderer()`` is used to measure the text.
    labels : sequence of text artists; repositioned via ``set_position``.
    max_iter : maximum number of sweeps over all label pairs.
    eps : fraction of the centre-to-centre vector applied per push.
    delta : extra margin (data units) required between boxes.

    Notes
    -----
    Every ordered pair (a, b) is visited, so an overlapping pair is pushed
    twice per sweep. Labels with exactly coincident centres have a zero
    separation vector and are therefore never moved.
    """
    N = len(labels)
    if N == 0:
        return

    # One renderer is enough for all labels.
    renderer = fig.canvas.get_renderer()

    widths = np.zeros(N)
    heights = np.zeros(N)
    centers = np.zeros((N, 2))
    for i,l in enumerate(labels):
        # Use the label's own axes; the current axes may be a different
        # subplot. Fall back to the current axes for labels without one.
        ax = getattr(l, 'axes', None)
        if ax is None:
            ax = plt.gca()
        bb = l.get_window_extent(renderer=renderer)
        bb = bb.transformed(ax.transData.inverted())
        widths[i] = bb.width
        heights[i] = bb.height
        centers[i] = (bb.min + bb.max)/2

    for _ in range(max_iter):
        stop = True
        for a in range(N):
            for b in range(N):
                if ((a!=b) and
                    (np.abs(centers[a,0]-centers[b,0]) < (widths[a]+widths[b])/2 + delta) and
                    (np.abs(centers[a,1]-centers[b,1]) < (heights[a]+heights[b])/2 +  delta)):

                    d = centers[a] - centers[b]
                    centers[a] += d * eps
                    centers[b] -= d * eps
                    labels[a].set_position(centers[a])
                    labels[b].set_position(centers[b])
                    stop = False
        if stop:
            # a full sweep without any overlap: layout has converged
            break

In [64]:
# Sanity check: fit the rank-2 sparse RRR model on each dataset with its
# chosen regularisation and report how many genes have a non-zero weight in
# the first component.
# Entries are (pickle path, lambdau, alpha), in the order they are reported.
check_settings = [
    ('data/scala2020.pickle',   .45, 1),
    ('data/scala2019.pickle',   .95, .5),
    ('data/cadwell2016.pickle', 1.6, .5),
]

for check_path, check_lambdau, check_alpha in check_settings:
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    # The context manager closes the file handle once the data is loaded.
    with open(check_path, 'rb') as check_file:
        data = pickle.load(check_file)
    X,Y = preprocess(data)
    genes = data['genes'][data['mostVariableGenes']]
    w,v = sparseRRR.relaxed_elastic_rrr(X, Y, rank=2, lambdau=check_lambdau, alpha=check_alpha)
    print('Genes selected: {}'.format(np.sum(w[:,0]!=0)))

Genes selected: 20
Genes selected: 20
Genes selected: 20


In [147]:
fig = plt.figure(figsize=(7.3, 5))

# Per-dataset settings; all lists are indexed by the loop variable `dataset`.
titles = ['Cadwell et al. 2016', 'Scala et al. 2019', 'Scala et al. 2020']
files = ['cadwell2016.pickle', 'scala2019.pickle', 'scala2020.pickle']
lambdas = [1.6, .95, .45]
alphas = [.5, .5, 1]
sizes = [5, 5, 3]      # scatter marker sizes
scaleFactor = 3        # a correlation of 1 is drawn at this radius
xylim = 3.4            # symmetric axis limits of every panel


def _biplot_panel(fig, position, Z, features, names, shown, colors, size, title=None):
    """Draw one biplot panel in slot `position` of the 2x3 subplot grid.

    Z        : latent scores, one row per sample; the first two columns are plotted.
    features : original variables, one row per sample; each column's
               correlation with the two plotted components gives its vector.
    names    : label text, indexable by feature column.
    shown    : iterable of feature column indices to draw.
    colors   : per-sample colours for the scatter.
    size     : scatter marker size.
    title    : optional panel title.
    """
    plt.subplot(2, 3, position, aspect='equal')
    plt.scatter(Z[:,0], Z[:,1], s=size, color=colors)

    plt.ylim([-xylim, xylim])
    plt.xlim([-xylim, xylim])
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)

    # Rows of L: correlation of each feature with component 1 and component 2.
    L = np.corrcoef(np.concatenate((Z[:,:2], features), axis=1), rowvar=False)[2:,:2]

    labels = []
    for i in shown:
        plt.plot([0, scaleFactor*L[i,0]], [0, scaleFactor*L[i,1]], linewidth=.75, color=[.4, .4, .4],
                 zorder=1)
        t = plt.text(scaleFactor*L[i,0], scaleFactor*L[i,1], names[i],
                     ha='center', va='center', color='k', fontsize=6,
                     bbox=dict(facecolor='w', edgecolor='#777777', boxstyle='round', linewidth=.5, pad=.2))
        labels.append(t)
    adjustlabels(fig, labels)

    # Circle of radius scaleFactor, i.e. where a correlation of magnitude 1 would lie.
    circ = plt.Circle((0,0), radius=scaleFactor, color=[.4, .4, .4], fill=False, linewidth=.5)
    plt.gca().add_patch(circ)


for dataset in range(3):
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open('data/' + files[dataset], 'rb') as f:
        data = pickle.load(f)
    X,Y = preprocess(data)
    genes = data['genes'][data['mostVariableGenes']]
    w,v = sparseRRR.relaxed_elastic_rrr(X, Y, rank=2, lambdau=lambdas[dataset],
                                        alpha=alphas[dataset])
    # Latent scores of both modalities, scaled to unit variance per component.
    Zx = X @ w
    Zy = Y @ v
    Zx = Zx / np.std(Zx, axis=0)
    Zy = Zy / np.std(Zy, axis=0)

    if dataset==0:
        # Five-step colour ramp between two RGB anchors, indexed by 1-based cell type.
        # NOTE(review): the weights (5-i)/4 and i/4 sum to 1.25, not 1, so the
        # middle colours are not a convex blend; left unchanged to keep the
        # figure's colours as they were -- confirm whether (4-i)/4 was meant.
        colors = np.zeros((5,3))
        colors[0,:] = [217,95,2]
        colors[-1,:] = [27,158,119]
        for i in range(1,4):
            colors[i,:] = colors[0,:] * (5-i)/4 + colors[-1,:] * i/4
        colors = colors/256
        colors = [colors[t-1] for t in data['cellTypes']]
    elif dataset==1:
        colors = ['red' if t=='S1' else 'orange' for t in data['regions']]
    elif dataset==2:
        colors = [data['colors'][t] for t in data['ttype']]

    # Top row (a-c): RNA expression. Points are the transcriptomic latent
    # scores Zx, matching the gene vectors, which are correlations with Zx
    # (previously Zy was scattered here). Only genes with a non-zero weight
    # in the first component are drawn.
    _biplot_panel(fig, dataset+1, Zx, X, genes, np.where(w[:,0]!=0)[0],
                  colors, sizes[dataset], title=titles[dataset])

    # Bottom row (d-f): electrophysiology. Points are Zy; every ephys
    # feature is drawn.
    _biplot_panel(fig, dataset+4, Zy, Y, data['ephysNames'], range(Y.shape[1]),
                  colors, sizes[dataset])

sns.despine(left=True, bottom=True)
plt.tight_layout()

# Panel letters, placed in figure coordinates after tight_layout.
fig.text(.02, .97, 'a', fontsize=8, fontweight='bold')
fig.text(.35, .97, 'b', fontsize=8, fontweight='bold')
fig.text(.68, .97, 'c', fontsize=8, fontweight='bold')
fig.text(.02, .45, 'd', fontsize=8, fontweight='bold')
fig.text(.35, .45, 'e', fontsize=8, fontweight='bold')
fig.text(.68, .45, 'f', fontsize=8, fontweight='bold')

plt.savefig('figures/bibiplots.png', dpi=200)
plt.savefig('figures/bibiplots.pdf')

<IPython.core.display.Javascript object>