In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import scipy
from sklearn.decomposition import PCA

In [2]:
from helper_funcs import *

In [3]:
data_latent_vec = np.load('../Data/latent_vecs2-new.npz')

In [4]:
list(data_latent_vec.keys())

['salient_vec_abide',
 'background_vec_abide',
 'vae_vec_abide',
 'salient_vec_sfari',
 'background_vec_sfari',
 'vae_vec_sfari']

In [5]:
salient_vec_abide = data_latent_vec['salient_vec_abide']
background_vec_abide = data_latent_vec['background_vec_abide']
vae_vec_abide = data_latent_vec['vae_vec_abide']
salient_vec_sfari = data_latent_vec['salient_vec_sfari']
background_vec_sfari = data_latent_vec['background_vec_sfari']
vae_vec_sfari = data_latent_vec['vae_vec_sfari']

In [6]:
def data2cmat(data):
    return np.array([squareform(pdist(data[s,:,:],metric='euclidean')) for s in range(data.shape[0])])

In [7]:
def plot_nice_bar(key,rsa,ax=None,figsize=None,dpi=None,fontsize=None,fontsize_star=None,fontweight=None,line_width=None,marker_size=None,title=None,report_t=False,do_pairwise_stars=True,do_one_sample_stars=True):
    #key = 'ADOS_gotham_PC1'
    import seaborn as sns
    from scipy.stats import ttest_1samp
    
    from scipy.stats import ttest_ind as ttest
    pallete = sns.color_palette()
    pallete_new = sns.color_palette()
    if not figsize:
        figsize = (5,2)
    if not dpi:
        dpi = 300
        
    if not ax:
        fig, ax = plt.subplots(1,1,figsize=figsize,dpi=dpi)

    pallete_new[1]=pallete[0]
    pallete_new[0]=pallete[1]
    pallete_new[0] = tuple(np.array((.5,.5,.5)))

    data=rsa[key]
    n = data.shape[0]
    c = data.shape[1]
    x = np.arange(c)
    
    if not fontsize:
        fontsize = 16
        
    if not fontsize_star:
        fontsize_star = 25
    if not fontweight:        
        fontweight = 'bold'
    if not line_width:    
        line_width = 2.5
    if not marker_size:            
        marker_size = .1
    


    for i in range(c):
        plot_data = np.zeros(data.shape)
        plot_data[:,i] = data[:,i]

        xs = np.repeat(i,n)+(np.random.rand(n)-.5)*.25
        sc = plt.scatter(xs,data[:,i],c='k',s = marker_size)
        b = sns.barplot(data=plot_data,errcolor='r',linewidth=line_width,errwidth=line_width,facecolor=np.hstack((np.array(pallete_new[i]),.3)),edgecolor=np.hstack((np.array(pallete_new[i]),1)))
        #sns.barplot(x=i,y=data.mean(axis=0)[i],errcolor='r',linewidth=2.5,facecolor=np.hstack((np.array(pallete_new[i]),.7)),edgecolor=np.hstack((np.array(pallete_new[i]),.7)))

    #xlbls = ['VAE','BG','SL']
    locs, labels = plt.yticks()  
    new_y = locs
    new_y = np.linspace(locs[0],locs[-1],6)
    plt.yticks(new_y,labels=[f'{yy:.2f}' for yy in new_y],fontsize=fontsize,fontweight=fontweight)
    plt.ylabel('model fit (r)',fontsize=fontsize,fontweight=fontweight)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)

    for axis in ['top','bottom','left','right']:
            ax.spines[axis].set_linewidth(line_width)

    #xlbls = ['VAE','BG','SL']
    xlbls = ['VAE','Shared','Specific']
    plt.xticks(np.arange(3),labels=xlbls,fontsize=fontsize,fontweight=fontweight)


    #do_one_sample_stars = True
    if do_one_sample_stars:
        one_sample = np.array([ttest_1samp(data[:,i],0) for i in range(3)])
        one_sample_thresh = np.array((1,.05,.001,.0001))
        one_sample_stars = np.array(('n.s.','*','**','***'))
        #xlbls = ['VAE','BG','SL']
        for i in range(c):
            these_stars = one_sample_stars[max(np.nonzero(one_sample[i,1]<one_sample_thresh)[0])]
            xlbls[i] = f'{xlbls[i]}\n({these_stars})'
        plt.xticks(np.arange(3),labels=xlbls,fontsize=fontsize,fontweight=fontweight,horizontalalignment='center',multialignment='center')

    pairwise_t = np.zeros((3,3))
    pairwise_p = np.zeros((3,3))

    pairwise_sample_thresh = np.array((1,.05,.001,.0001))
    pairwise_sample_stars = np.array(('n.s.','*','**','***'))

    if report_t:
        for i in range(c):
            for j in range(c):
                t,p = ttest(data[:,i],data[:,j])
                #mnames = ['VAE','BG','SL']
                mnames = xlbls

                if p > .001:
                    print(f'{key} {mnames[i]} >  {mnames[j]} | t({data.shape[0]-1}) = {t:.2f} p = {p:.2f}')
                else:
                    print(f'{key} {mnames[i]} >  {mnames[j]} | t({data.shape[0]-1}) = {t:.2f} p $<$ .001')
                pairwise_t[i,j] = t
                pairwise_p[i,j] = p

    comps = [[1,2]]
    #do_pairwise_stars = True
    if do_pairwise_stars:
        for comp_idx in range(len(comps)):
            this_comp = comps[comp_idx]
            sig_idx = max(np.nonzero(pairwise_p[this_comp[0],this_comp[1]]<pairwise_sample_thresh)[0])
            max_y = new_y[-1] + comp_idx*.05
            xs = np.array(this_comp)
            stars = pairwise_sample_stars[sig_idx]
            plt.plot(xs,[max_y,max_y],'k',linewidth=line_width)
            plt.text(xs.mean(),max_y,stars,fontsize=fontsize_star,horizontalalignment='center',fontweight=fontweight)
        
    #plt.plot(plt.xlim(),[0,0],'k',linewidth=line_width)
    ylim = plt.ylim()
    plt.ylim(np.array(ylim)*(1,1.1))
        
    #plt.xlabel(key,fontsize=fontsize,labelpad=25.1,fontweight=fontweight)    
    
    if not title:
        plt.title(key,fontsize=fontsize*1.5,pad=2,fontweight=fontweight)    
    else:
        plt.title(title,fontsize=fontsize*1.5,pad=2,fontweight=fontweight)

In [8]:
cmat_salient_vec_abide = data2cmat(salient_vec_abide)
cmat_background_vec_abide = data2cmat(background_vec_abide)
cmat_vae_vec_abide = data2cmat(vae_vec_abide)
cmat_salient_vec_sfari = data2cmat(salient_vec_sfari)
cmat_background_vec_sfari = data2cmat(background_vec_sfari)
cmat_vae_vec_sfari = data2cmat(vae_vec_sfari)

In [12]:
list(np.load(dataFnOut).keys())

['data']

In [9]:
dataFnOut = '../Data/ABIDE-Anat-64iso-S982.npz'
dfFnOut = '../Data/ABIDE_legend_S982.csv'

ABIDE_data = np.load(dataFnOut)['data']
ABIDE_subs = np.load(dataFnOut)['subs']

df = pd.read_csv(dfFnOut)

patients = df['DxGroup'].values==1
controls = df['DxGroup'].values==2
abide_asd = ABIDE_data[patients,:,:,:]

arr = np.load('../Data/SFARI-Anat-64iso-S121.npz')
dfs = pd.read_csv('../Data/sfari_legend_S121.csv')

SFARI_data = arr['data']
SFARI_subs = arr['subs']

sfari_subs_td = dfs['family_type'].values=='non-familial-control'
sfari_subs_dupl = dfs['family_type'].values=='16p-duplication'
sfari_subs_del = dfs['family_type'].values=='16p-deletion'

KeyError: 'subs is not a file in the archive'

In [None]:
assert len(df['BIDS_ID'].values)==len(ABIDE_subs), 'mistmatch lengths'
assert all([df['BIDS_ID'].values[s]==ABIDE_subs[s] for s in range(len(ABIDE_subs))]), 'mismatch order'

## ABIDE

In [None]:
%time
# Make RSA models for ABIDE data
plt.figure(figsize=(15,15))
default_keys = ['ADOS_Total','ADOS_Social','DSMIVTR','AgeAtScan','Sex','ScannerID','ScanSiteID','FIQ']
scales_ = ['ratio','ratio','ordinal','ratio','ordinal','ordinal','ordinal','ratio','ratio','ratio']

model_rdms = dict()
model_idxs = dict()
for i in range(len(default_keys)):

    inVec = df[default_keys[i]].values[patients];
    idx = ~np.isnan(inVec)
    inVec = inVec[idx];
    this_rdm = make_RDM(inVec,data_scale=scales_[i])
    
    model_rdms.update({default_keys[i] : this_rdm})
    model_idxs.update({default_keys[i] : idx})

In [None]:
def slice_cmat(data,idx):
    mat = data[patients,:][:,patients]
    mat = mat[idx,:][:,idx]
    return mat

In [None]:
def fit_rsa(data,key):
    corr = scipy.stats.stats.kendalltau
    r = np.array([corr(get_triu(slice_cmat(data[i,:,:],model_idxs[key])),get_triu(model_rdms[key]))[0] for i in range(10)])
    r = np.arctan(r) # Fisher Z transform
    return r

# PCA SCORES

In [None]:
# ABIDE FIT MODELS 
patients = df['DxGroup'].values==1

data = [cmat_vae_vec_abide,cmat_background_vec_abide,cmat_salient_vec_abide]

rsa_results = dict()
for key in default_keys:
    res = np.array([fit_rsa(datum,key) for datum in data]).transpose()
    rsa_results.update({key : res})
    
    
# PCA RSA
keys_pca = {}
keys_pca.update({'ADOS_PCA' :  ['ADOS_Total','ADOS_Comm', 'ADOS_Social', 'ADOS_StBeh']})
keys_pca.update({'ADI_PCA' :   ['ADI_R_SocialTotal', 'ADI_R_VerbalTotal', 'ADI_R_RRB','ADI_R_Onset Total']})
keys_pca.update({'Vineland_PCA' :   ['VINELAND_Receptive_Vscore',
 'VINELAND_Expressive_Vscore',
 'VINELAND_Written_Vscore',
 'VINELAND_CommunicationStandard',
 'VINELAND_Personal_Vscore',
 'VINELAND_Domestic_Vscore',
 'VINELAND_Community_Vscore',
 'VINELAND_DaylyLiving_Standard',
 'VINELAND_Interpersonal_Vscore',
 'VINELAND_Play_Vscore',
 'VINELAND_Coping_Vscore',
 'VINELAND_Socical_Standard',
 'VINELAND_Domestic_Standard',
 'VINELAND_ABC_Standard',
 'VINELAND_Informant']})
keys_pca.update({'WISC_PCA' :  ['WISC4 VCI Verbal Comprehension Index',
       'WISC4 PRI Perceptual Reasoning Index',
       'WISC4 WMI Working Memory Index', 'WISC4 PSI Processing Speed Index',
       'WISC4 Sim Scaled', 'WISC4 Vocab Scaled', 'WISC4 Info Scaled',
       'WISC4 Blk Dsn Scaled', 'WISC4 Pic Con Scaled', 'WISC4 Matrix Scaled',
       'WISC4 Dig Span Scaled', 'WISC4 Let Num Scaled', 'WISC4 Coding Scaled',
       'WISC4 Sym Scaled']})
# keys_pca.update({'' :   []})


# Calculate PCA RSA
pca_keys = list(keys_pca.keys())
model_pcas = dict()
for key in pca_keys:
    arr = np.array(df[keys_pca[key]])
    arr = arr[patients,:]

    idx = ~np.isnan(arr.mean(axis=1))
    mat = arr[idx,:]

    pca = PCA(n_components=1)
    pca_vec = pca.fit_transform(mat)
    rdm = make_RDM(pca_vec)
    model_rdms.update({key : rdm})
    model_idxs.update({key : idx})
    model_pcas.update({key : pca_vec})
    #model_idxs.update({key : pca.fit_transform(mat)})
    

    res = np.array([fit_rsa(datum,key) for datum in data]).transpose()
    rsa_results.update({key : res})
    
    df[key] = 0
    for i_rel,i_abs in enumerate(np.nonzero(model_idxs[key])[0]):
        #df[key].values[i_abs] = model_pcas[key][i_rel]
        df[key].values[np.nonzero(patients)[0][i_abs]] = model_pcas[key][i_rel]

In [None]:
#df.to_csv('../Data/ABIDE_legend_S982_pca.csv')

## SFARI

In [None]:
# data = SFARI_data
patients = sfari_subs_dupl+sfari_subs_del

In [None]:
%time
# SFARI
plt.figure(figsize=(15,15))
# default_keys = ['ADOS_Total','ADOS_Social','DSMIVTR','AgeAtScan','Sex','ScannerID','ScanSiteID','FIQ']
# scales_ = ['ratio','ratio','ordinal','ratio','ordinal','ordinal','ordinal','ratio','ratio','ratio']

default_keys = ['best_full_scale_iq','rrb_css','sa_css','age_years', 'ord_diagnosis',
       'ord_gene', 'ord_sex','ord_scanner']
scales_ = ['ratio','ratio','ratio','ratio','ordinal','ordinal','ordinal','ordinal']

#model_rdms = dict()
#model_idxs = dict()
for i in range(8):
    #plt.subplot(4,4,i+1);
    inVec = dfs[default_keys[i]].values[patients];
    idx = ~np.isnan(inVec)
    inVec = inVec[idx];
    this_rdm = make_RDM(inVec,data_scale=scales_[i])
    #sns.heatmap(this_rdm,cbar=[],xticklabels=[],yticklabels=[]);
    #plt.title(default_keys[i]);
    
    model_rdms.update({default_keys[i] : this_rdm})
    model_idxs.update({default_keys[i] : idx})

In [None]:
# SFARI FIT MODELS 
data = [cmat_vae_vec_sfari,cmat_background_vec_sfari,cmat_salient_vec_sfari]
#rsa_results = dict()
for key in default_keys:
    res = np.array([fit_rsa(datum,key) for datum in data]).transpose()
    rsa_results.update({key : res})

In [None]:
keys = list(rsa_results.keys())
for key in keys:
    print(f'{key} | {model_rdms[key].shape}')

In [None]:
ncols = 3
nrows = int(np.ceil(len(keys)/3))
plt.figure(figsize=np.array((ncols,nrows))*4)

for i,key in enumerate(keys):
    ax = plt.subplot(nrows,ncols,i+1)
    plot_nice_bar(key,rsa_results,
                  ax=ax,figsize=None,
                  dpi=300,fontsize=12,
                  fontsize_star=12,
                  fontweight='bold',
                  line_width=2.5,
                  marker_size=12)
    
plt.subplots_adjust(
    left=None,
    bottom=None,
    right=None,
    top=None,
    wspace=.5,
    hspace=.5)

plt.savefig('../../bars/RSA_barplots_newData2.png')

In [None]:
keys

In [None]:
from matplotlib.ticker import FormatStrFormatter

ncols = 3
nrows = int(np.ceil(len(keys)/3))
plt.figure(figsize=np.array((ncols,nrows))*4)

keys = ['ord_scanner',
 'age_years',
 'ord_diagnosis',
 'ord_gene',
 'ord_sex',
'best_full_scale_iq']

ttls = ['Scanner Type',
 'Age',
 'DSM IV',
 'Genotype',
 'Sex',
'Full Scale IQ']

for i,key in enumerate(keys):
    ax = plt.subplot(nrows,ncols,i+1)
    plot_nice_bar(key,rsa_results,
                  ax=ax,figsize=None,
                  dpi=300,fontsize=12,
                  fontsize_star=12,
                  fontweight='bold',
                  line_width=2.5,
                  marker_size=12,title=ttls[i])
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    
plt.subplots_adjust(
    left=None,
    bottom=None,
    right=None,
    top=None,
    wspace=.5,
    hspace=.5)

plt.savefig('../../bars/RSA_barplots_newData2-SFARI.pdf')

In [None]:
from matplotlib.ticker import FormatStrFormatter
keys = ['ScannerID',
 'ScanSiteID',
 'AgeAtScan',
 'Sex',
 'DSMIVTR',
 'FIQ',
 'ADOS_PCA',
'ADOS_Total',
 'ADOS_Social',
 'ADI_PCA',
 'Vineland_PCA',
 'WISC_PCA']

ttls = ['Scanner Type',
 'Scanning Site',
 'Age',
 'Sex',
 'DSM IV',
 'Full Scale IQ',
 'ADOS (PCA)',
'ADOS Total',
 'ADOS Social',
 'ADI (PCA)',
 'Vineland (PCA)',
 'WISC (PCA)']


ncols = 3
nrows = int(np.ceil(len(keys)/3))
plt.figure(figsize=np.array((ncols,nrows))*4)

for i,key in enumerate(keys):
    ax = plt.subplot(nrows,ncols,i+1)
    plot_nice_bar(key,rsa_results,
                  ax=ax,figsize=None,
                  dpi=300,fontsize=12,
                  fontsize_star=12,
                  fontweight='bold',
                  line_width=2.5,
                  marker_size=12,title=ttls[i])
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    
plt.subplots_adjust(
    left=None,
    bottom=None,
    right=None,
    top=None,
    wspace=.5,
    hspace=.5)


In [None]:
def plot_rsa_bar(data,key):
    import scipy
    import seaborn as sns
    ttest_1samp = scipy.stats.ttest_1samp
    ttest_rel = scipy.stats.ttest_rel    
    
    # Calculate Plotting data
    c = data.shape[1]
    n = data.shape[0]
    xs = np.arange(c)
    xlbls = ['VAE','BG','SL']

    m = data.mean(axis=0)
    sd = data.std(axis=0)
    
    # T statistics
    t_1samp = np.array([ttest_1samp(data[:,i],0)[0] for i in range(c)])
    p_1samp = np.array([ttest_1samp(data[:,i],0)[1] for i in range(c)])

    p_paired_t = np.zeros((c,c))
    t_paired_t = np.zeros((c,c))

    # Bar, Errorbar, Scatter 
    for i in range(c):
        for j in range(c):
            t_paired_t[i,j] = ttest_rel(data[:,i],data[:,j])[0]
            p_paired_t[i,j] = ttest_rel(data[:,i],data[:,j])[1]
            
    if p_paired_t[2,1] < .001:
        print(f'{key} {xlbls[2]} > {xlbls[1]}: t({data.shape[0]-1}) = {t_paired_t[2,1].round(2)}, p $<$ .001')
    else:
        print(f'{key} {xlbls[2]} > {xlbls[1]}: t({data.shape[0]-1}) = {t_paired_t[2,1].round(2)}, p = {p_paired_t[2,1].round(3)}')
        
    if p_paired_t[1,2] < .001:
        print(f'{key} {xlbls[1]} > {xlbls[2]}: t({data.shape[0]-1}) = {t_paired_t[1,2].round(2)}, p $<$ .001')
    else:
        print(f'{key} {xlbls[1]} > {xlbls[2]}: t({data.shape[0]-1}) = {t_paired_t[1,2].round(2)}, p = {p_paired_t[1,2].round(3)}')

    # Plotting
    fig, ax = plt.subplots(1,1)
    pallete = sns.color_palette()
    pallete_new = sns.color_palette()
    pallete_new[1]=pallete[0]
    pallete_new[0]=pallete[1]
    pallete_new[0] = tuple(np.array((.5,.5,.5)))

    linewidth = 3
    fontsize = 14
    fontweight='bold'
    for i in range(3):
        plt.bar(xs[i],m[i],edgecolor=np.hstack((pallete_new[i],1)),
                facecolor=np.hstack((pallete_new[i],.5)),
                linewidth=linewidth)

        plt.errorbar(xs[i],m[i],sd[i],fmt='r ',linewidth=linewidth)
        plt.scatter(x=(np.repeat(i,n)+(np.random.rand(n)-.5)*.25),y=data[:,i],
                   s=fontsize*3,
                   c='k')



    t_thresh = np.array((1,.05,.001,.0001));
    t_stars = np.array(('n.s.','*','**','***'));
    star_idx = (np.nonzero((p_paired_t[2,1] < t_thresh))[0]).max();
    plt.plot([1,2],np.repeat(data.max(),2)*1.25,'k',linewidth=linewidth);
    plt.text(1.4,data.max()*1.25+.002,t_stars[star_idx],fontsize=fontsize*1.25,fontweight='bold');

    ax.spines['top'].set_visible(False);
    ax.spines['right'].set_visible(False);
    for axis in ['top','bottom','left','right']:
            ax.spines[axis].set_linewidth(linewidth);

    xlabels = ['VAE','BG','SL'];
    one_sample_star_idx = [(np.nonzero((p_1samp[ii] < t_thresh))[0]).max() for ii in range(c)];
    xlabels = [f'{xlabels[ii]}\n({t_stars[one_sample_star_idx[ii]]})' for ii in range(c)];

    #xticks = np.linspace(data.min().round(2),data.max().round(2),5)
    #plt.yticks(xticks,labels=[str(ii) for ii in xticks],fontsize=fontsize)
    plt.yticks(fontsize=fontsize,fontweight=fontweight);
    plt.xticks(xs,xlabels,fontsize=fontsize,fontweight=fontweight);
    plt.title(key,fontweight=fontweight,fontsize=fontsize*1.25,pad=fontsize)

In [None]:
keys = list(rsa_results.keys())
for key in keys:
    data = rsa_results[key]
    plot_rsa_bar(data,key)
    #plt.savefig('/Users/aidasaglinskas/Desktop/bars/'+ 'bar_' + key + '.pdf')