## Load Modules

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('/home/zachpen87/clearmap/ClearMap2/')

import h5py
import holoviews as hv
import dask.array as da
import numpy as np
import scipy.ndimage as ndimage
import scipy.stats as stats
import skimage
from skimage import io
#import ClearMap.ImageProcessing.H5 as H5img
import ClearMap.Visualization.JupNtbk_vis as jvis
hv.notebook_extension('bokeh')

## Define directory and group information

In [None]:
# Per-animal inputs are read from <directory>/<animal id>/data.hdf5 (see the
# group-writing cell); per-group result files are written to res_directory.
directory = '/home/zachpen87/clearmap/data/SEFL17/'
res_directory = os.path.join(directory, 'results')

# Animal IDs per experimental group. The groups overlap:
#   - NoTrauma_lowf / Trauma_highf are subsets of NoTrauma / Trauma;
#   - NoStim is the union of NoStimNoTrauma and NoStimTrauma.
group_dict = {
    'NoTrauma_lowf' : ['ms_12','ms_18','ms_23','ms_05','ms_15','ms_07'],
    'Trauma_highf' : ['ms_06','ms_17','ms_22','ms_10','ms_09','ms_16'],
    'NoStim' : ['ms_01','ms_02','ms_03','ms_04','ms_25','ms_26','ms_27','ms_28'],
    'NoTrauma' : ['ms_05','ms_07','ms_12','ms_13','ms_14','ms_15','ms_18','ms_20','ms_21','ms_23'],
    'Trauma' : ['ms_06','ms_08','ms_09','ms_10','ms_11','ms_16','ms_17','ms_19','ms_22','ms_24'],
    'NoStimNoTrauma' : ['ms_02','ms_04','ms_25','ms_27'],
    'NoStimTrauma' : ['ms_01','ms_03','ms_26','ms_28']
}

# makedirs(exist_ok=True) also creates missing parents and avoids the
# check-then-create race of isdir() followed by mkdir().
os.makedirs(res_directory, exist_ok=True)

## Load ABA reference image and size

In [None]:
# Allen Brain Atlas reference volume (25 um per the file name) taken from the
# 'Horizontal' atlas directory. Its shape is later used as the expected shape
# of every per-animal heatmap when the group files are assembled.
ABA_directory = '/home/zachpen87/clearmap/AtlasDocs/Horizontal'
ABA_ref = ('ABA_25um_reference__1_2_3'
           '__slice_None_None_None'
           '__slice_None_None_None'
           '__slice_None_None_None'
           '__.tif')
ref = skimage.io.imread(
    os.path.join(ABA_directory, ABA_ref),
)
htmp_shape = ref.shape

## For each group, create a single HDF5 file with a 4D heatmap dataset (animal, Z, Y, X)

In [None]:
for group, animals in group_dict.items():

    fpath = os.path.join(res_directory, '.'.join([group, 'hdf5']))

    # Mode 'w' truncates: an existing group file (including any mean/std/z
    # datasets added by later cells) is rebuilt from scratch on every run.
    with h5py.File(fpath, 'w') as f:

        # Number of animals, stored as a 1-element uint8 array (read back
        # later as f['n'][0]).
        n_animals = len(animals)
        f.create_dataset('n', data=[n_animals], dtype='uint8')

        # 4D stack: axis 0 indexes animals in group_dict order; the remaining
        # axes match the reference atlas shape.
        allmaps_ds = f.create_dataset(
            'allmaps',
            (n_animals, *htmp_shape),
            dtype='float64')

        for idx, ms in enumerate(animals):
            print('writing {ms} to group: {group}'.format(ms=ms, group=group))
            subpath = os.path.join(directory, ms, 'data.hdf5')
            with h5py.File(subpath, 'r') as subf:
                heatmap = subf['heatmap']
                # Fail with a readable message rather than an h5py broadcast error.
                if heatmap.shape != tuple(htmp_shape):
                    raise ValueError(
                        '{ms}: heatmap shape {got} does not match reference shape {want}'
                        .format(ms=ms, got=heatmap.shape, want=tuple(htmp_shape)))
                allmaps_ds[idx] = heatmap[:]
                    

## Compute mean, standard deviation, and z (mean / std) heatmaps for each group

In [None]:
import dask

for group in group_dict:

    fpath = os.path.join(res_directory, '.'.join([group, 'hdf5']))
    with h5py.File(fpath, 'a') as f:

        print('creating dask array for all data in group {group}'.format(group=group))
        n = f['n'][0]
        # Each chunk spans the whole animal axis, so every chunk reduces on its own.
        allmaps = da.from_array(f['allmaps'], chunks=(n, 500, 500, 500))

        print('calculating mean heatmap for group: {group}'.format(group=group))
        print('calculating std dev heatmap for group: {group}'.format(group=group))
        # A single compute() evaluates both reductions with a shared read of
        # 'allmaps', and all writes below go through this one open handle
        # instead of re-opening fpath a second time.
        mean, std = dask.compute(allmaps.mean(axis=0), allmaps.std(axis=0))

        print('calculating z heatmap for group: {group}'.format(group=group))
        # 'z' is the per-voxel mean divided by the across-animal std (not a
        # z-score against another group); voxels with zero std are set to 0
        # instead of being divided by zero.
        z = mean.copy()
        z[std == 0] = 0
        spread = std > 0
        z[spread] = mean[spread] / std[spread]

        # create_dataset raises if the name already exists, so drop results
        # from a previous run first; this makes the cell safe to re-run.
        for name, data in (('mean', mean), ('std', std), ('z', z)):
            if name in f:
                del f[name]
            f.create_dataset(name, data=data)
        
        

## Visualize group heatmaps

In [None]:
%output size = 120

groups = ['NoTrauma_lowf','Trauma_highf']
dset = 'mean'
interval = 10
plane = 'c'
include_ref = True

hmap_dict = {}
if include_ref:
    hmap_dict['ref'] = jvis.gen_hmap(ref,plane=plane,title='Allen Brain Atlas',inter=interval,tools=['hover']) 

for group in groups:
    hdf5_file = os.path.join(res_directory, '{group}.hdf5'.format(group=group))
    with h5py.File(hdf5_file,'r') as f:
        hmap_dict[group] = jvis.gen_hmap((f[dset][:]*100)**2,
                               plane=plane,
                               title= '{group}: {var}'.format(group=group,var=dset),
                               inter=interval,lims=(0,50),
                               cmap='inferno',
                               alpha=1,
                               tools=['hover'])
    
hv.NdLayout(hmap_dict).cols(1)