In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import einops
import pandas as pd

import json
import cv2
import os

import jax
import jax.numpy as jnp
import jax.nn as jnn

import json
from tqdm.notebook import tqdm
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# rbf kernel
def kernel(x, x2, a, l):
    """Batched squared-exponential (RBF) kernel with per-dimension length-scales.

    Parameters
    ----------
    x : array of shape (N, D)
        Points indexed by the last output axis.
    x2 : array of shape (M, D)
        Points indexed by the middle output axis.
    a : array of shape (B,)
        Amplitude per batch member; the kernel is scaled by ``a**2``.
    l : array of shape (B, D)
        Length-scale per batch member and input dimension.

    Returns
    -------
    array of shape (B, M, N)
        ``K[b, m, n] = a[b]**2 * exp(-0.5 * sum_d ((x[n, d] - x2[m, d]) / l[b, d])**2)``.
    """
    scaled = (x[None, ...] - x2[:, None, :]) / l[:, None, None]  # (B, M, N, D)
    # Sum of squares instead of norm(...)**2: same value, but no sqrt whose
    # gradient is NaN at zero distance (the diagonal when x is x2).
    sq_dist = jnp.sum(scaled ** 2, axis=-1)
    return a[:, None, None] ** 2 * jnp.exp(-0.5 * sq_dist)

key = jax.random.PRNGKey(0)

# Full n x n grid over the unit square, flattened row-major into (n*n, 2)
# points with columns (x, y).
n = 50
x, y = np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n))
xy = np.column_stack((x.ravel(), y.ravel()))

# Cropped prediction grid: x in [350/2048, 1] (n columns), y in [0, 1800/2048]
# (n2 rows); the bounds look like pixel limits of a 2048-px image (confirm).
n2 = 50
x2, y2 = np.meshgrid(np.linspace(350 / 2048, 1, n), np.linspace(0, 1800 / 2048, n2))
xy2 = np.column_stack((x2.ravel(), y2.ravel()))

In [None]:
def gen(mec, fit, combined, xy2, batch_size=25, grid_shape=None, jitter=1e-2, rng=None):
    """Sample fields of mechanism ``mec`` at the prediction points ``xy2``, one per posterior draw.

    For every location group ``idx`` (observation points
    ``combined['loc_ids_1'][idx]`` .. ``combined['loc_ids_2'][idx]``, 1-based and
    inclusive) and every row of ``fit``:

    * stack the group's observed coordinates with ``xy2``;
    * Cholesky-factor the RBF kernel plus ``(sigma**2 + jitter) * I``;
    * form ``L @ [eta_obs, eta_pred] + offset``;
    * keep the entries for ``xy2`` and apply a link function: ``softplus`` for
      ``mec == 'g'``, ``arcsin(sigmoid(.))`` otherwise.

    Parameters
    ----------
    mec : str
        Column-name infix, e.g. ``'g'`` or ``'phi'``. Columns matched:
        ``alpha_{mec}_locations``, ``rho_{mec}_locations``, ``eta_{mec}``,
        ``offset_{mec}_location`` (regex search), plus ``sigma``.
    fit : pandas.DataFrame
        One row per posterior draw.
    combined : dict
        Needs ``'x'`` (observation coordinates, 2 columns), ``'loc_ids_1'`` and
        ``'loc_ids_2'``.
    xy2 : array of shape (M, 2)
        Prediction points.
    batch_size : int
        Draws per batched Cholesky; decrease if OOM.
    grid_shape : (int, int), optional
        ``(rows, cols)`` used to reshape the M predictions (row-major, matching
        a flattened ``np.meshgrid``). Defaults to a square grid.
    jitter : float
        Extra diagonal term added to ``sigma**2`` for numerical stability.
    rng : optional
        Object with a numpy-style ``normal(loc, scale, size)``; defaults to the
        global ``np.random``. Pass ``np.random.default_rng(seed)`` for
        reproducible draws.

    Returns
    -------
    list of numpy.ndarray
        One array of shape ``(n_draws, rows, cols)`` per location group.

    Raises
    ------
    ValueError
        If ``xy2`` cannot be reshaped into ``grid_shape``.
    """
    if rng is None:
        rng = np.random
    n_draws = len(fit)
    alpha = jnp.array(fit.filter(regex=f'alpha_{mec}_locations').values)  # (draws, groups)
    n_groups = alpha.shape[1]
    # rho columns are (group, dim) pairs with dim varying fastest -> (draws, groups, 2)
    rho = jnp.array(einops.rearrange(fit.filter(regex=f'rho_{mec}_locations').values,
                                     'i (k j) -> i k j', i=n_draws, j=2, k=n_groups))
    sigma = jnp.array(fit['sigma'].values)  # (draws,)
    eta_obs = jnp.array(fit.filter(regex=f'eta_{mec}').values)  # (draws, observation points)
    offset = jnp.array(fit.filter(regex=f'offset_{mec}_location').values)  # (draws, groups)
    data_xy = jnp.array(combined['x'])

    n_pred = xy2.shape[0]
    if grid_shape is None:
        side = int(round(np.sqrt(n_pred)))
        grid_shape = (side, side)
    rows, cols = grid_shape
    if rows * cols != n_pred:
        raise ValueError(f'cannot reshape {n_pred} prediction points into grid_shape={grid_shape}; '
                         'pass grid_shape=(rows, cols) explicitly')

    Gs = []
    for idx, (i, j) in enumerate(zip(combined['loc_ids_1'], combined['loc_ids_2'])):
        # Loop-invariant per group: observed points of this group followed by xy2.
        xy_c = jnp.concatenate([data_xy[(i - 1):j, :], xy2])
        eye = jnp.eye(xy_c.shape[0])[None, ...]
        GG = []
        for s1 in tqdm(range(0, n_draws, batch_size)):
            s2 = min(s1 + batch_size, n_draws)  # last batch may be partial
            # Fresh N(0, 1) innovations for the prediction points of every draw,
            # so no two posterior draws share the same predictive noise.
            eta_pred = jnp.array(rng.normal(0, 1, (s2 - s1, n_pred)))
            eta = jnp.hstack([eta_obs[s1:s2, (i - 1):j], eta_pred])
            K = kernel(xy_c, xy_c, alpha[s1:s2, idx], rho[s1:s2, idx])
            L_K = jnp.linalg.cholesky(K + eye * (sigma[s1:s2, None, None] ** 2 + jitter))
            f = jnp.einsum('nij,nj->ni', L_K, eta) + offset[s1:s2, idx, None]
            f_pred = f[:, -n_pred:]  # keep only the xy2 entries
            if mec == 'g':
                G = jnn.softplus(f_pred)
            else:
                G = jnp.arcsin(jnn.sigmoid(f_pred))
            GG.append(np.array(G).reshape(s2 - s1, rows, cols))
        Gs.append(np.concatenate(GG))
    return Gs

## Gradient 

In [None]:
# Gradient experiment: load posterior draws and processed measurements, then
# sample |G*| on the cropped prediction grid xy2.
fit = pd.read_csv('data/processed/gradient.csv')
with open('data/processed/gradient.json', 'r') as fh:
    combined = json.load(fh)

# shift stored identifiers down by one (they appear to be 1-based)
ids = np.asarray(combined['location_identifier']) - 1
loc_ids = np.asarray(combined['loc_repeat']) - 1
xs = np.asarray(combined['x'])

Gs = gen('g', fit, combined, xy2)

In [None]:
# Measured |G*| and phase arrays (raw and "_mean" files).
# NOTE(review): |G*| is rescaled by a ratio of two constants whose origin is not
# documented here — confirm against the calibration/measurement protocol.
G_RESCALE = 316039.3 / 282743.3388230814

gv = np.load('data/abs_g.npy') * G_RESCALE
phiv = np.load('data/phi.npy')

gvm = np.load('data/abs_g_mean.npy') * G_RESCALE
phivm = np.load('data/phi_mean.npy')

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Red -> gray -> blue map. from_list spaces the six entries evenly, and each
# colour appears twice, so every colour holds flat over a band before blending.
_stops = [(1, 0, 0), (0.6, 0.6, 0.6), (0, 0, 1)]
colors = [c for c in _stops for _ in range(2)]
cm = LinearSegmentedColormap.from_list('smooth', colors, N=256)
cm_flip = LinearSegmentedColormap.from_list('smooth', np.array(colors)[::-1], N=256)
cm

In [None]:
# Quick check: print the raw |G*| values (gv) recorded at each observation
# point of location group 2.
idents = np.array(combined['idents'])
l1 = combined['loc_ids_1'][2]
l2 = combined['loc_ids_2'][2]
for i in range(l1, l2 + 1):
    mask = np.where(idents == i)[0]
    print(gv[mask])

In [None]:
# Supplementary figure for location groups 2 and 3: 3D surface of the
# posterior-median |G*| on the cropped grid (left) and |G*| profiles along the
# second coordinate at five grid columns (right, 2 x 5 panels).
import matplotlib

scale = 2048*0.325  # normalized coords -> µm (2048 px; 0.325 presumably µm/px)
m_point = scale/2   # half of the full 2048-px extent; centre of the gray band
idents = np.array(combined['idents'])
obs_xy = np.array(combined['x'])
# Physical coordinates of the prediction grid xy2 that produced Gs: axis 1 of
# Gs[loc] follows y2 (0 .. 1800/2048), axis 2 follows x2 (350/2048 .. 1).
x_range = y2[:, 0]*scale
col_pos = x2[0, :]*scale
# gray band: +-70 µm around m_point
ll = np.where(x_range > m_point-70)[0][0]
uu = np.where(x_range < m_point+70)[0][-1]

fig = plt.figure(figsize=(12,5.1), constrained_layout=True)
subfigs = fig.subfigures(1, 2, wspace=0.0, width_ratios=[1, 4])

ax0 = subfigs[0].subplots(2, 1, subplot_kw={"projection": "3d"})
ax = subfigs[1].subplots(2, 5)
spacing = np.array([0,12,24,36,48]).astype(int)[::-1]  # grid columns shown as profiles
for idx2, (aa2, aa, loc) in enumerate(zip(ax0, ax, [2, 3])):
    G_loc = Gs[loc]
    l1 = combined['loc_ids_1'][loc]
    l2 = combined['loc_ids_2'][loc]

    # 3D surface plus the mean measured |G*| at each observation point
    aa2.plot_surface(x2*scale, y2*scale, np.nanmedian(G_loc, axis=0), facecolors=cm(y2)[...,:-1], shade=True, alpha=0.8, linewidth=0)
    aa2.set_box_aspect(aspect=(2, 2, 2), zoom=1.)
    aa2.view_init(azim=110, elev=30)
    for i in range(l1, l2+1):
        px, py = obs_xy[i-1]*scale
        aa2.scatter3D(px, py, gv[idents == i].mean(), color='black')
    aa2.set_zlim(0, 60)
    aa2.set_xlabel(r'y [$\mu m$]')
    aa2.set_ylabel(r'x [$\mu m$]')

    # Quantities shared by every profile panel of this location
    x_coords = obs_xy[np.where(loc_ids == loc)[0]]*scale
    closest = np.abs(x_range[..., None] - x_coords[:, 1][None, ...]).argmin(axis=0)
    m1 = idents-1
    m2 = np.isin(idents, range(l1, l2+1))
    xc = obs_xy[m1][m2]*scale
    mask2 = xc[:, 1] > 500  # split this group's measurements at 500 µm
    g_far, g_near = gv[m2][mask2], gv[m2][~mask2]
    yerr = np.abs(np.median(g_far) - np.array([np.percentile(g_far, 5), np.percentile(g_far, 95)])[..., None])
    yerr2 = np.abs(np.median(g_near) - np.array([np.percentile(g_near, 5), np.percentile(g_near, 95)])[..., None])

    for idx, (a, b) in enumerate(zip(aa, spacing)):
        profile = G_loc[..., b:(b+1)]
        med = np.mean(profile, axis=(0, 2))
        low = np.percentile(profile, 5, axis=(0, 2))
        up = np.percentile(profile, 95, axis=(0, 2))
        a.plot(x_range, med, color='black')
        a.fill_between(x_range[:ll], low[:ll], up[:ll], alpha=0.2, color='crimson')
        a.fill_between(x_range[ll:uu], low[ll:uu], up[ll:uu], alpha=0.2, color='gray')
        a.fill_between(x_range[uu:], low[uu:], up[uu:], alpha=0.2, color='blue')
        # observation positions marked on the posterior-mean curve
        a.scatter(x_coords[:, 1], med[closest], color='black', s=20, edgecolors=None, alpha=1.0)

        # median and 5-95% range of the measurements on either side of 500 µm
        eb1 = a.errorbar(x=550, y=np.median(g_far), xerr=130, yerr=yerr, color='blue', linestyle='--')
        eb2 = a.errorbar(x=150, y=np.median(g_near), xerr=130, yerr=yerr2, color='red', linestyle='--')
        for eb in (eb1, eb2):
            for bar_lines in eb[-1]:
                bar_lines.set_linestyle('--')

        a.set_ylim(0, 62)
        a.spines[['right', 'top']].set_visible(False)
        if idx == 0:
            a.set_ylabel(r'$|G^*|$ [Pa]')
        if idx2 == 1:
            a.set_xlabel(r'x [$\mu m$]')
        else:
            # offset of this column from the central one, in µm on the x2 grid
            si = col_pos[spacing[2]] - col_pos[b]
            sign = '+' if si > 0 else ('-' if si < 0 else '')
            a.set_title(r'{:} {:.0f} [$\mu m$]'.format(sign, np.abs(si)))

cbaxes = subfigs[1].add_axes([-0.3, 0.3, 0.01, 0.3])
norm = matplotlib.colors.Normalize(vmin=1, vmax=2)
cb = subfigs[1].colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm_flip), cbaxes, pad=0.1)
cb.ax.set_zorder(1000)
cb.ax.set_title('Collagen\nConcentration', rotation=0)
cb.set_label('[mg/ml]')

plt.show()
# fig.savefig('results/gradients_supp_line.png',dpi=300,bbox_inches='tight',transparent=False,facecolor='white')


In [None]:
# Main figure: 3D posterior-median |G*| surfaces for location groups 2 and 3
# (panels A, B) and |G*| profiles through the central grid column (C, D).
import matplotlib

scale = 2048*0.325  # normalized coords -> µm (2048 px; 0.325 presumably µm/px)
m_point = scale/2   # half of the full 2048-px extent; centre of the gray band
idents = np.array(combined['idents'])
obs_xy = np.array(combined['x'])
# Axis 1 of Gs[loc] follows y2 of the cropped prediction grid xy2
x_range = y2[:, 0]*scale
ll = np.where(x_range > m_point-70)[0][0]
uu = np.where(x_range < m_point+70)[0][-1]
centre_col = 24  # grid column used for the profiles

fig = plt.figure(figsize=(6,6), constrained_layout=True)
subfigs = fig.subfigures(1, 2, wspace=0.1, width_ratios=[1, 1])

ax = subfigs[0].subplots(2, 1, subplot_kw={"projection": "3d"})
for a, loc, alp in zip(ax, [2, 3], ['A', 'B']):
    a.plot_surface(x2*scale, y2*scale, np.nanmedian(Gs[loc], axis=0), facecolors=cm(y2)[...,:-1], alpha=0.8, linewidth=0.)
    # mean measured |G*| at the observation points of *this* location group
    l1 = combined['loc_ids_1'][loc]
    l2 = combined['loc_ids_2'][loc]
    for i in range(l1, l2+1):
        px, py = obs_xy[i-1]*scale
        a.scatter3D(px, py, gv[idents == i].mean(), color='black')
    a.view_init(azim=120, elev=30)
    a.set_xlabel(r'y [$\mu m$]')
    a.set_ylabel(r'x [$\mu m$]')
    a.text2D(0.02, 0.9, alp, transform=a.transAxes, fontsize=20, clip_on=True)

ax2 = subfigs[1].subplots(2, 1)
for idx, (loc, a, alp) in enumerate(zip([2, 3], ax2, ['C', 'D'])):
    profile = Gs[loc][..., centre_col:(centre_col+1)]
    med = np.mean(profile, axis=(0, 2))
    low = np.percentile(profile, 5, axis=(0, 2))
    up = np.percentile(profile, 95, axis=(0, 2))
    a.plot(x_range, med, color='black')
    a.fill_between(x_range[:ll], low[:ll], up[:ll], alpha=0.2, color='crimson')
    a.fill_between(x_range[ll:uu], low[ll:uu], up[ll:uu], alpha=0.2, color='gray')
    a.fill_between(x_range[uu:], low[uu:], up[uu:], alpha=0.2, color='blue')

    # observation positions marked on the posterior-mean curve
    x_coords = obs_xy[np.where(loc_ids == loc)[0]]*scale
    closest = np.abs(x_range[..., None] - x_coords[:, 1][None, ...]).argmin(axis=0)
    a.scatter(x_coords[:, 1], med[closest], color='black', s=20, edgecolors=None, alpha=1.0)

    # median and 5-95% range of the measurements on either side of 500 µm
    l1 = combined['loc_ids_1'][loc]
    l2 = combined['loc_ids_2'][loc]
    m1 = idents-1
    m2 = np.isin(idents, range(l1, l2+1))
    xc = obs_xy[m1][m2]*scale
    mask2 = xc[:, 1] > 500
    g_far, g_near = gv[m2][mask2], gv[m2][~mask2]
    yerr = np.abs(np.median(g_far) - np.array([np.percentile(g_far, 5), np.percentile(g_far, 95)])[..., None])
    yerr2 = np.abs(np.median(g_near) - np.array([np.percentile(g_near, 5), np.percentile(g_near, 95)])[..., None])
    eb1 = a.errorbar(x=550, y=np.median(g_far), xerr=130, yerr=yerr, color='blue', linestyle='--')
    eb2 = a.errorbar(x=150, y=np.median(g_near), xerr=130, yerr=yerr2, color='red', linestyle='--')
    for eb in (eb1, eb2):
        for bar_lines in eb[-1]:
            bar_lines.set_linestyle('--')

    if idx == 0:
        a.set_ylim(0, 62)
    else:
        a.set_ylim(10, 50)
    a.set_ylabel(r'$|G^*|$ [Pa]')
    a.spines[['right', 'top']].set_visible(False)
    a.text(0.02, 0.85, alp, transform=a.transAxes, fontsize=20, clip_on=True, zorder=1)
ax2[1].set_xlabel(r'x [$\mu m$]')

norm = matplotlib.colors.Normalize(vmin=1, vmax=2)
cbaxes = subfigs[1].add_axes([-0.4, 0.96, 0.4, 0.02])
cb = subfigs[1].colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm_flip), cbaxes, pad=0.1, location='bottom')
cb.ax.set_title('Collagen concentration', rotation=0)
cb.set_label('[mg/ml]')

plt.show()
# fig.savefig('results/gradients_line.png',dpi=300,bbox_inches='tight',transparent=False,facecolor='white')

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Example spatial maps

In [None]:
# Example dataset: load posterior draws and processed measurements, then sample
# the 'g' (|G*|) and 'phi' fields on the full unit-square grid xy.
fit = pd.read_csv('data/processed/sample.csv')
with open('data/processed/sample.json', 'r') as fh:
    combined = json.load(fh)

# shift stored identifiers down by one (they appear to be 1-based)
ids = np.asarray(combined['location_identifier']) - 1
loc_ids = np.asarray(combined['loc_repeat']) - 1
xs = np.asarray(combined['x'])

# evaluated in order: 'g' first, then 'phi'
Gs, Gs_phi = (gen(mechanism, fit, combined, xy) for mechanism in ('g', 'phi'))

In [None]:
# number of location groups, then (posterior draws, grid rows, grid cols) of the first group
print(f'{len(Gs)} {Gs[0].shape}')

In [None]:
def bar(im, ax, fig):
    """Attach a vertical colorbar for ``im`` in a new axis appended to the right of ``ax``.

    Parameters
    ----------
    im : mappable (e.g. the AxesImage returned by ``imshow``)
    ax : matplotlib Axes the colorbar is placed next to
    fig : Figure that owns ``ax``

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar, so callers can label or restyle it.
    """
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    return fig.colorbar(im, cax=cax, orientation='vertical')


# Demonstrate with two locations: posterior mean (left) and std (right) of
# |G*| (top row) and phi (bottom row); observation points marked on the std maps.
imshow_kw = dict(cmap='jet', extent=(0, 1, 0, 1), origin='lower', alpha=0.7)
for loc in [0, 2]:
    x_coords = xs[loc_ids == loc]
    fig, axes = plt.subplots(2, 2)
    for (mean_ax, std_ax), field in zip(axes, [Gs, Gs_phi]):
        draws = field[loc]
        im1 = mean_ax.imshow(draws.mean(axis=0), **imshow_kw)
        im2 = std_ax.imshow(draws.std(axis=0), **imshow_kw)
        std_ax.scatter(*x_coords.T, color='black', marker='x')
        bar(im1, mean_ax, fig)
        bar(im2, std_ax, fig)

    axes[0, 0].set_ylabel(r'$|G^*|$')
    axes[1, 0].set_ylabel(r'$\phi$')

    axes[0, 0].set_title('mean')
    axes[0, 1].set_title('std')

    # Remove ticks only: yaxis.set_visible(False) would also hide the row
    # labels set above.
    for panel in axes.ravel():
        panel.set_xticks([])
        panel.set_yticks([])

In [None]:
# 3D "waterfall" for sample location group 2: 5-95% bands of |G*| profiles at
# four grid columns, plus the mean measured |G*| at each observation point.
fig, a = plt.subplots(1, 1, figsize=(10, 5), subplot_kw={"projection": "3d"})
a.set_box_aspect(aspect=(2, 5, 2), zoom=1.3)

m_point = 2048*0.325/2
max_point = m_point*2  # full extent in µm (2048 px; 0.325 presumably µm/px)
spacing = np.linspace(0, 47, 4).astype(int)
loc = 2
G_loc = Gs[loc]
# xy spans [0, 1] in both directions: axis 1 of G_loc follows y, axis 2 follows x
x_range = np.linspace(0, max_point, G_loc.shape[1])
col_pos = np.linspace(0, max_point, G_loc.shape[2])  # physical position of each column
half = x_range.max()/2
ll = np.where(x_range > half-70)[0][0]
uu = np.where(x_range < half+70)[0][-1]

for b in spacing:
    profile = G_loc[..., b:(b+1)]
    low = np.percentile(profile, 5, axis=(0, 2))
    up = np.percentile(profile, 95, axis=(0, 2))
    # NOTE(review): relies on fill_between returning a 2D PolyCollection that
    # add_collection3d lifts into 3D; newer matplotlib gives Axes3D its own
    # fill_between with a different signature — confirm on upgrade.
    for part, colour in ((slice(None, ll), 'crimson'), (slice(ll, uu), 'gray'), (slice(uu, None), 'blue')):
        a.add_collection3d(a.fill_between(x_range[part], low[part], up[part], alpha=0.2, color=colour),
                           zdir='y', zs=col_pos[b])

# Observation points are drawn once (not once per slice).
# NOTE(review): gv is loaded from data/abs_g.npy while `combined` is the sample
# dataset here — confirm that its idents index gv.
idents = np.array(combined['idents'])
obs_xy = np.array(combined['x'])
l1 = combined['loc_ids_1'][loc]
l2 = combined['loc_ids_2'][loc]
for i in range(l1, l2+1):
    px, py = obs_xy[i-1]*2048*0.325
    a.scatter3D(py, px, gv[idents == i].mean(), color='black')

a.set_xlim(0, 700)
a.set_ylim(0, 700)  # slice axis; previously a duplicated set_zlim that was immediately overridden
a.set_zlim(0, 70)
a.spines[['right', 'top']].set_visible(False)
a.set_zlabel(r'$|G^*|$ [Pa]')
a.view_init(azim=-40, elev=20, vertical_axis='z')
plt.show()
#fig.savefig('results/gradients.png',dpi=300,bbox_inches='tight',transparent=False)
