In [3]:
import sys, os
# Make the parent directory (relative to the current working directory)
# importable, presumably so `wholebrain_tools` below can be found.
# Guarded so that re-running this cell does not append duplicate sys.path entries.
parentDir = os.path.abspath(os.pardir)
if parentDir not in sys.path:
    sys.path.append(parentDir)

import pandas as pd 
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

from scipy.stats import zscore, spearmanr, pearsonr

#import custom modules
from wholebrain_tools import aba, dataIO
import wholebrain_tools.graphics as gt

# Instantiate an Atlas object from the aba module
# The first time you run this it will download the structures.json file from the Allen Institute server
paths = dataIO.pathParser()
structuresFile = paths.structures
A = aba.Atlas(nodes=structuresFile)
# Helper used throughout the notebook to reshape per-region DataFrames
# (multi-indexing, coarse- and mid-level aggregation).
DFM = aba.AnatomyDataFrameManager(A)


# Load data

In [None]:
# --------------------------------------------------------------------
searchPath = paths.alldata
# --------------------------------------------------------------------

def loadChannel(channelName):
    """Load and aggregate the per-region data of one staining channel.

    Parameters
    ----------
    channelName : str
        Channel to load from `searchPath` (e.g. 'wfa' or 'pv').

    Returns
    -------
    fineDf : DataFrame
        Fine-level table re-indexed by `DFM.multiIndexDf_from_fineDf`.
    coarseDf : DataFrame
        Normalized coarse-level aggregation from `DFM.regionsDf_to_coarse`.
    """
    fineDf = dataIO.allMiceRegions(searchPath=searchPath, channelName=channelName, normCellIntens=True)
    fineDf = DFM.multiIndexDf_from_fineDf(fineDf, verbose=False)
    coarseDf = DFM.regionsDf_to_coarse(fineDf, verbose=True, normalize=True)
    return fineDf, coarseDf

# NOTE(review): no control-group selection is performed here; confirm that
# `paths.alldata` contains only control mice.
df_wfa, wfaCoarse = loadChannel('wfa')   # WFA (PNN) staining
df_pv, pvCoarse = loadChannel('pv')      # PV staining

pvCoarse

# Correlation at the coarse level

In [None]:
# Coarse-level table with both stainings side by side, under a new outer
# column level named 'staining' ('wfa' / 'pv').
# NOTE(review): `totalDf` is reassigned in the mid-level section before any
# visible use of this coarse version — confirm it is still needed.
totalDf = pd.concat(
    objs=[wfaCoarse, pvCoarse],
    axis=1,
    keys=['wfa', 'pv'],
    names=['staining'],
)

## PV energy vs WFA diffuse fluorescence

### Prepare data

In [None]:
# Per-region summary: WFA diffuse fluorescence vs PV energy.
wfaDiff = wfaCoarse.xs('diffuseFluo', axis=1, level='params')
pvEn = pvCoarse.xs('energy', axis=1, level='params')

# Row-wise (axis=1) mean and SEM over the remaining columns (presumably one
# column per mouse). `.rename(label)` sets the Series name, which becomes the
# column label in the table built below.
wfaMean = wfaDiff.mean(axis=1).rename('wfaMean')
wfaSem = wfaDiff.sem(axis=1).rename('wfaSem')
pvMean = pvEn.mean(axis=1).rename('pvMean')
pvSem = pvEn.sem(axis=1).rename('pvSem')

dataToPlot = pd.concat([wfaMean, pvMean, wfaSem, pvSem], axis=1)
dataToPlot

### Plot

In [None]:
# Per-region WFA mean (y) against PV mean (x), with the SEM columns passed as
# the x / y errors.
gt.metricsWithErrors(
    data=dataToPlot,
    atlas=A,
    x='pvMean',
    err_x='pvSem',
    xlabel='PV Energy (A.U.)',
    y='wfaMean',
    err_y='wfaSem',
    ylabel='WFA Diffuse Fluorescence (A.U.)',
    fontScaling=.8,
)

## PV energy vs PNN energy

### Prepare data

In [None]:
# Per-region summary: WFA (PNN) energy vs PV energy.
wfaEnergy = wfaCoarse.xs('energy', axis=1, level='params')
pvEn = pvCoarse.xs('energy', axis=1, level='params')

# Row-wise (axis=1) mean and SEM over the remaining columns (presumably one
# column per mouse); each Series' name becomes its column label below.
wfaMean = wfaEnergy.mean(axis=1).rename('wfaMean')
wfaSem = wfaEnergy.sem(axis=1).rename('wfaSem')
pvMean = pvEn.mean(axis=1).rename('pvMean')
pvSem = pvEn.sem(axis=1).rename('pvSem')

dataToPlot = pd.concat([wfaMean, pvMean, wfaSem, pvSem], axis=1)
dataToPlot

### Plot

In [None]:
# Per-region PNN (WFA) energy mean (y) against PV energy mean (x), with the
# SEM columns passed as the x / y errors.
gt.metricsWithErrors(
    data=dataToPlot,
    atlas=A,
    x='pvMean',
    err_x='pvSem',
    xlabel='PV Energy (A.U.)',
    y='wfaMean',
    err_y='wfaSem',
    ylabel='PNN Energy (A.U.)',
    fontScaling=.8,
)
                    

# Correlation at the mid level

## Prepare data

In [None]:
# Mid-level region tables for each staining.
# NOTE(review): no control-group selection is performed here; confirm that the
# loaded data contain only control mice.
wfaMid = DFM.regionsDf_to_mid(df_wfa, verbose=False, normalize=True)
pvMid = DFM.regionsDf_to_mid(df_pv, verbose=False, normalize=True)

# Concatenate the 2 stainings under a new outer column level 'staining'
totalDf = pd.concat([wfaMid, pvMid], keys=['wfa', 'pv'], names=['staining'], axis=1)
# Average across mice: mean over all columns sharing the same (staining, params).
# `groupby(axis=1)` is deprecated since pandas 2.1, so group the transposed
# frame by its row levels and transpose back (same result).
totalDf = totalDf.T.groupby(level=['staining', 'params']).mean().T

# One row per region. Every column is suffixed with its staining
# (e.g. 'energy_pv', 'diffuseFluo_wfa'). `join(lsuffix=..., rsuffix=...)`
# would only suffix names present in both stainings.
dataToPlot = totalDf['wfa'].add_suffix('_wfa').join(totalDf['pv'].add_suffix('_pv'))
dataToPlot

## Plot WFA diffuse fluorescence vs PV energy

In [None]:
# All mid-level regions together: PV energy (x) vs WFA diffuse fluorescence (y).
# The return value is bound to `_` to keep the cell output to the figure.
_ = gt.metricsCorrelation(
    dataToPlot,
    A,
    x='energy_pv',
    y='diffuseFluo_wfa',
    xlabel='PV Energy (A.U.)',
    ylabel='WFA Diffuse\nFluorescence (A.U.)',
    txtLoc='tl',
    fontScaling=1,
)

# plt.savefig("allAreaCorr_diffuseVsPv.svg", bbox_inches="tight")

In [None]:
f, axs = plt.subplots(nrows=2, ncols=6, figsize=(23, 8), dpi=100, squeeze=True)

# One panel per coarse area, in the order the areas first appear in the index.
coarseIdList = dataToPlot.index.get_level_values('coarse').unique().tolist()
# Panel positions (indices into coarseIdList) whose correlation text goes
# top-left; all others go bottom-right. Hand-tuned per area; update these if
# the area list or its order changes.
topLeftPanels = {0, 5, 6, 9, 10, 11}
# Panel carrying the shared axis labels (first panel of the second row).
labelledPanel = 6

# zip stops at the shorter sequence, so fewer areas than panels no longer
# raises IndexError. Areas beyond the 12 panels are still not drawn.
for i, (ax, thisRegion) in enumerate(zip(axs.flat, coarseIdList)):
    toPlot = dataToPlot.xs(thisRegion, axis=0, level='coarse')

    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    gt.metricsCorrelation(toPlot, A,
        ax=ax,
        x='energy_pv',
        y='diffuseFluo_wfa',
        txtLoc='tl' if i in topLeftPanels else 'br',
        xlabel='PV Energy (A.U.)' if i == labelledPanel else None,
        ylabel='WFA Diffuse\nFluorescence (A.U.)' if i == labelledPanel else None,
        title=A.ids_to_names([thisRegion])[0],
    )

# Hide panels left over when there are fewer areas than panels.
for ax in axs.flat[len(coarseIdList):]:
    ax.set_visible(False)

plt.subplots_adjust(hspace=0.4, wspace=0.2)

# plt.savefig("allMidAreaCorr_diffuseVsPv.svg", bbox_inches="tight")

## Plot WFA (PNN) energy vs PV energy

In [None]:
# Rebuild `dataToPlot` from the mouse-averaged mid-level `totalDf`: WFA and PV
# params joined on region, overlapping column names suffixed _wfa / _pv.
dataToPlot = totalDf['wfa'].join(
    totalDf['pv'],
    lsuffix='_wfa',
    rsuffix='_pv',
)

# All mid-level regions together: PV energy (x) vs PNN (WFA) energy (y).
_ = gt.metricsCorrelation(
    dataToPlot,
    A,
    x='energy_pv',
    y='energy_wfa',
    xlabel='PV Energy (A.U.)',
    ylabel='PNN Energy (A.U.)',
    txtLoc='br',
    fontScaling=1,
)

# plt.savefig("allAreaCorr_energyVsPv.svg", bbox_inches="tight")

In [None]:
f, axs = plt.subplots(nrows=2, ncols=6, figsize=(23, 8), dpi=100, squeeze=True)

# One panel per coarse area, in the order the areas first appear in the index.
coarseIdList = dataToPlot.index.get_level_values('coarse').unique().tolist()
# Panel positions (indices into coarseIdList) whose correlation text goes
# top-left; all others go bottom-right. Hand-tuned per area; update these if
# the area list or its order changes.
topLeftPanels = {1, 2, 3, 4}
# Panel carrying the shared axis labels (first panel of the second row).
labelledPanel = 6

# zip stops at the shorter sequence, so fewer areas than panels no longer
# raises IndexError. Areas beyond the 12 panels are still not drawn.
for i, (ax, thisRegion) in enumerate(zip(axs.flat, coarseIdList)):
    toPlot = dataToPlot.xs(thisRegion, axis=0, level='coarse')

    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    gt.metricsCorrelation(toPlot, A,
        ax=ax,
        x='energy_pv',
        y='energy_wfa',
        txtLoc='tl' if i in topLeftPanels else 'br',
        xlabel='PV Energy (A.U.)' if i == labelledPanel else None,
        ylabel='PNN Energy (A.U.)' if i == labelledPanel else None,
        title=A.ids_to_names([thisRegion])[0],
    )

# Hide panels left over when there are fewer areas than panels.
for ax in axs.flat[len(coarseIdList):]:
    ax.set_visible(False)

plt.subplots_adjust(hspace=0.4, wspace=0.2)
# plt.savefig("allMidAreaCorr_energyVsPv.svg", bbox_inches="tight")