In [None]:
import sys

# Make the repository root importable so `cortex_etl` resolves. The guard keeps
# re-runs of this cell from appending the same entry to sys.path repeatedly.
if '../..' not in sys.path:
    sys.path.append('../..')
import cortex_etl as c_etl

# Run cortex_etl's initial processing on the full-circuit stimulus-scan (10x) config.
ma = c_etl.analysis_initial_processing(
    "../../configs/5-FullCircuit/5-FullCircuit-2-BetterMinis-Fpr15-StimScan-10x.yaml",
    loglevel="ERROR",
)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def get_xy_for_single_hex(hex_mean_flatspace_coords, hex_ind):
    """Return the (x, y) flat-map coordinates recorded for one hex.

    Parameters
    ----------
    hex_mean_flatspace_coords : pandas.DataFrame
        Table with at least the columns ``hex``, ``x`` and ``y``.
    hex_ind
        Value of the ``hex`` column to look up. If several rows match, the
        first one is used.

    Returns
    -------
    tuple
        ``(x, y)`` taken from the first matching row.

    Raises
    ------
    IndexError
        If no row has ``hex == hex_ind`` (same exception type as a failed
        positional lookup, but with a message naming the missing hex).
    """
    matches = hex_mean_flatspace_coords[hex_mean_flatspace_coords['hex'] == hex_ind]
    if matches.empty:
        raise IndexError(
            "No row with hex == {0!r} in hex_mean_flatspace_coords".format(hex_ind))
    # Fetch the row once and read both coordinates from it.
    first_row = matches.iloc[0]
    return first_row['x'], first_row['y']

import pandas as pd

# Flat-map coordinates per hex (columns `hex`, `x`, `y` are what
# get_xy_for_single_hex reads from this table).
hex_mean_flatspace_coords_path = '../../scripts/hex_mean_flatspace_coords.parquet'
hex_mean_flatspace_coords = pd.read_parquet(hex_mean_flatspace_coords_path)

# Coordinates of hex 0; the tuple is the cell's displayed output.
hex_x, hex_y = get_xy_for_single_hex(hex_mean_flatspace_coords, 0)
hex_x, hex_y


In [None]:
import sys

# Make the repository root importable; guarded so re-running does not keep
# appending the same sys.path entry.
if '../..' not in sys.path:
    sys.path.append('../..')
import cortex_etl as c_etl

# Initial processing + post-analysis for the Hex39 variant of the stim-scan config.
ma_hex39 = c_etl.analysis_initial_processing(
    "../../configs/5-FullCircuit/5-FullCircuit-2-BetterMinis-Fpr15-StimScan-10x-Hex39.yaml",
    loglevel="ERROR",
)
c_etl.post_analysis(ma_hex39.hex39_spikes)

In [None]:
# Raster plots of the hex39 spike analysis created in the previous cell.
c_etl.plot_rasters(
    ma_hex39.hex39_spikes
)

In [None]:
import numpy
import pandas
import os
import tqdm
import cortex_etl as c_etl
from blueetl.parallel import call_by_simulation
import pandas as pd
from functools import partial
from matplotlib import pyplot as plt



def make_t_bins(t_start, t_end, t_step):
    """Return time-bin edges ``t_start, t_start + t_step, ...`` covering ``t_end``.

    The last edge is the first one that reaches ``t_end`` (within a small
    floating-point tolerance), so the final bin ends at or just after ``t_end``.

    Parameters
    ----------
    t_start, t_end : float
        Start and end of the interval to cover.
    t_step : float
        Bin width; must be positive.

    Returns
    -------
    numpy.ndarray
        Bin edges.

    Raises
    ------
    ValueError
        If ``t_step`` is not positive.
    """
    if t_step <= 0:
        raise ValueError("t_step must be positive, got {0}".format(t_step))
    # numpy.arange(t_start, t_end + t_step, t_step) can emit one edge past
    # t_end for float steps (e.g. 1.0 -> 1.2 in steps of 0.1 yields 1.3),
    # adding a spurious empty bin. Count the bins explicitly instead; the
    # small tolerance absorbs rounding in the division.
    n_bins = int(numpy.ceil((t_end - t_start) / t_step - 1e-9))
    # A negative count gives an empty arange, matching the old behaviour when
    # t_end lies well before t_start.
    return t_start + t_step * numpy.arange(n_bins + 1)

def flatten_locations(locations, flatmap):
    """Project neuron locations onto flat-map coordinates.

    Parameters
    ----------
    locations : pandas.DataFrame
        Per-neuron locations; the index is carried over to the result.
    flatmap : list or str
        If a list, it names columns of ``locations`` that already hold flat
        coordinates, and those columns are returned as-is. Otherwise it is a
        path passed to ``voxcell.VoxelData.load_nrrd``; each row of
        ``locations`` is looked up in that volume and lookups equal to ``-1``
        (presumably "outside the flat map") become NaN.

    Returns
    -------
    pandas.DataFrame
        Flat coordinates with default integer column labels (0, 1, ...) and the
        same index as ``locations``.
    """
    if isinstance(flatmap, list):
        flat_locations = locations[flatmap].values
    else:
        # Imported lazily: voxcell is only needed for the NRRD branch.
        from voxcell import VoxelData
        fm = VoxelData.load_nrrd(flatmap)
        flat_locations = fm.lookup(locations.values).astype(float)
        # numpy.NaN was removed in NumPy 2.0; numpy.nan works on every version.
        flat_locations[flat_locations == -1] = numpy.nan
    return pandas.DataFrame(flat_locations, index=locations.index)


def make_spatial_bins(flat_locations, nbins=1000):
    """Build x/y bin edges spanning the flat-map extent with about ``nbins`` bins.

    ``nx`` (x bins) and ``ny`` (y bins) are chosen so that ``nx * ny`` is close
    to ``nbins`` while bins stay roughly square: with
    ``ratio = y_extent / x_extent``, ``ratio * nx ** 2 ~= nbins``.

    Parameters
    ----------
    flat_locations : array-like, shape (n, 2)
        Flat coordinates; column 0 is x, column 1 is y. NaNs are ignored.
    nbins : int
        Target total number of spatial bins.

    Returns
    -------
    tuple of numpy.ndarray
        ``(binsx, binsy)`` bin edges with ``nx + 1`` and ``ny + 1`` entries.
        Each axis always gets at least one bin.
    """
    mn = numpy.nanmin(flat_locations, axis=0)
    mx = numpy.nanmax(flat_locations, axis=0)
    ratio = (mx[1] - mn[1]) / (mx[0] - mn[0])
    if ratio > 0:
        # Clamp to 1: for very tall extents (ratio > nbins) the square root
        # rounds down to 0, which used to cause a ZeroDivisionError below.
        nx = max(1, int(numpy.sqrt(nbins / ratio)))
    else:
        # Zero-height extent: put all bins along x.
        nx = max(1, int(nbins))
    ny = max(1, int(nbins / nx))
    # Small pad on the upper edge so the maximum location falls inside the last bin.
    binsx = numpy.linspace(mn[0], mx[0] + 1E-3, nx + 1)
    binsy = numpy.linspace(mn[1], mx[1] + 1E-3, ny + 1)
    return binsx, binsy

def make_histogram_function(t_bins, loc_bins, location_dframe, spikes):
    """Compute a (time, x, y) histogram of per-neuron firing rates on the flat map.

    Parameters
    ----------
    t_bins : array-like
        Time bin edges, in the same unit as the spike-time index. Counts are
        scaled by ``1000 / t_step``, which gives spikes / s if times are in ms
        (assumed — TODO confirm).
    loc_bins : tuple or list of array-like
        ``(binsx, binsy)`` spatial bin edges.
    location_dframe : pandas.DataFrame
        Flat coordinates indexed by gid; columns 0 and 1 are used as x and y.
    spikes : pandas.DataFrame
        Spikes indexed by spike time with a ``gid`` column.

    Returns
    -------
    numpy.ndarray, shape (len(t_bins) - 1, len(binsx) - 1, len(binsy) - 1)
        Spike count per bin times ``1000 / t_step``, divided by the number of
        neurons in that spatial bin. Bins without neurons come out as 0; the
        1E-6 epsilon only prevents division by zero.
    """
    t_step = numpy.mean(numpy.diff(t_bins))
    fac = 1000.0 / t_step
    loc_values = location_dframe.values
    nrns_per_bin = numpy.histogram2d(loc_values[:, 0], loc_values[:, 1], bins=loc_bins)[0]
    # Leading singleton axis so the 2-D neuron counts broadcast over time bins.
    nrns_per_bin = nrns_per_bin[numpy.newaxis, ...]

    # Drop spikes from gids that have no flat location. Filtering on the
    # ``gid`` column (not the flattened ``spikes.values``) stays correct when
    # spikes carries extra columns, and avoids the deprecated numpy.in1d.
    keep = spikes['gid'].isin(location_dframe.index).values
    spikes = spikes.loc[keep]
    t = spikes.index.values
    loc = location_dframe.loc[spikes['gid']].values
    raw, _ = numpy.histogramdd((t, loc[:, 0], loc[:, 1]), bins=(t_bins,) + tuple(loc_bins))
    raw = fac * raw / (nrns_per_bin + 1E-6)
    return raw

def save(hist, t_bins, loc_bins, out_root):
    """Write histograms and bin edges to ``<out_root>/spiking_activity_3d.h5``.

    Layout: group ``bins`` with datasets ``t``, ``x``, ``y``; group
    ``histograms`` with one ``instance<i>`` dataset per histogram plus ``mean``.

    Parameters
    ----------
    hist
        Object whose ``get()`` returns an iterable of equally shaped arrays
        (presumably a deferred/parallel result — TODO confirm with caller).
    t_bins : array-like
        Time bin edges.
    loc_bins : sequence
        ``(binsx, binsy)`` spatial bin edges.
    out_root : str
        Output directory; created if missing.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of all instance histograms (also stored as
        ``histograms/mean``).
    """
    import h5py

    os.makedirs(out_root, exist_ok=True)
    # Call get() once and reuse the result for the per-instance datasets and the mean.
    instances = list(hist.get())
    # Context manager guarantees the file is flushed and closed, even on error.
    with h5py.File(os.path.join(out_root, "spiking_activity_3d.h5"), "w") as h5:
        grp_bins = h5.create_group("bins")
        grp_bins.create_dataset("t", data=t_bins)
        grp_bins.create_dataset("x", data=loc_bins[0])
        grp_bins.create_dataset("y", data=loc_bins[1])

        grp_data = h5.create_group("histograms")
        for i, val in enumerate(instances):
            grp_data.create_dataset("instance{0}".format(i), data=val)
        mn_data = numpy.mean(numpy.stack(instances, -1), axis=-1)
        grp_data.create_dataset("mean", data=mn_data)
    return mn_data

def setup_cmap(hist, plotting_options, hist_mean=None):
    """Prepare a colormap, colour limits and a NaN-masked copy of ``hist``.

    Parameters
    ----------
    hist : numpy.ndarray
        A 2-D map, or a 3-D (time, x, y) stack when ``hist_mean`` is given.
    plotting_options : dict
        ``'cmap'``: colormap object with ``set_bad`` (e.g. a matplotlib Colormap);
        ``'mask_fr'``: bins whose mask value is <= this are set to NaN;
        ``'max_lim_pct'`` / ``'min_lim_pct'``: percentiles of ``hist`` used as
        the colour limits.
    hist_mean : numpy.ndarray, optional
        2-D map deciding which (x, y) bins are masked in every frame of a 3-D
        ``hist``. If None or empty, the mask is computed from ``hist`` itself.

    Returns
    -------
    cmap, clim, masked_hist
        A copy of the colormap that draws NaN ("bad") values white, the
        ``[min, max]`` colour limits, and a float copy of ``hist`` with masked
        bins set to NaN. Neither ``hist`` nor the input colormap is modified.
    """
    import copy

    use_mean = hist_mean is not None and len(hist_mean) > 0
    hist_for_mask = hist_mean if use_mean else hist

    # Colour limits are taken from the unmasked data.
    mx_clim = numpy.percentile(hist, plotting_options['max_lim_pct'])
    mn_clim = numpy.percentile(hist, plotting_options['min_lim_pct'])

    # Work on a float copy so the caller's array is not overwritten with NaNs.
    masked_hist = numpy.array(hist, dtype=float, copy=True)
    mask = numpy.asarray(hist_for_mask) <= plotting_options['mask_fr']
    if use_mean:
        # Apply the 2-D mask to every time frame.
        masked_hist[:, mask] = numpy.nan
    else:
        masked_hist[mask] = numpy.nan

    # Copy the colormap so set_bad does not alter a shared instance
    # (e.g. matplotlib.cm.cividis) for every later plot in the session.
    cmap = copy.copy(plotting_options['cmap'])
    cmap.set_bad('white', 1.)

    clim = [mn_clim, mx_clim]
    return cmap, clim, masked_hist
    

def plot_and_save_single_image(hist, plotting_options, path,
                               marker_xy=(67.87893675169182, 26.50221238938053)):
    """Render one 2-D flat-map histogram with a colour bar and save it to ``path``.

    Masking, colour limits and colormap come from ``setup_cmap``. The axes frame
    and ticks are hidden.

    Parameters
    ----------
    hist : numpy.ndarray
        2-D map drawn with ``imshow``.
    plotting_options : dict
        Options forwarded to ``setup_cmap``.
    path : str
        Output file; matplotlib picks the format from the extension.
    marker_xy : tuple of float or None
        Point (in image data coordinates) marked with a scatter dot. The default
        is the fixed point this function has always drawn; what it marks is not
        stated here (TODO confirm — possibly a reference hex centre). Pass None
        to omit the marker.
    """
    cmap, clim, masked_hist = setup_cmap(hist, plotting_options)

    fig = plt.figure()
    ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
    img = ax.imshow(masked_hist, cmap=cmap, clim=clim)
    fig.colorbar(img, ax=ax, label='FR (spikes / s)')

    if marker_xy is not None:
        ax.scatter(marker_xy[0], marker_xy[1])

    ax.set_frame_on(False)
    ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
    fig.savefig(path)
    plt.close(fig)


import os
import numpy
import tqdm
import matplotlib
from matplotlib import pyplot as plt

def plot(hist, t_bins, loc_bins, images_dir, delete_images, video_output_root, plotting_options, min_color_lim_pct=-1, hist_mean=()):
    """Render one PNG frame per time bin of ``hist`` and assemble them into an MP4.

    Parameters
    ----------
    hist : numpy.ndarray
        (time, x, y) stack; frame ``i`` is drawn for time bin
        ``[t_bins[i], t_bins[i + 1]]``.
    t_bins : array-like
        Time bin edges, used for the frame titles (labelled in ms).
    loc_bins
        Unused; kept for call compatibility.
    images_dir : str
        Directory for the frames (created if missing). Frame 0 is also saved
        as a PDF there.
    delete_images : bool
        Remove the PNG frames after the video is written. The PDF is kept.
    video_output_root : str
        Video path without extension; ``.mp4`` is appended.
    plotting_options : dict
        Options forwarded to ``setup_cmap``.
    min_color_lim_pct
        Unused; kept for call compatibility.
    hist_mean : numpy.ndarray, optional
        2-D map that decides which bins are masked in every frame (see
        ``setup_cmap``). Empty means mask per-value from ``hist``.
    """
    os.makedirs(images_dir, exist_ok=True)

    cmap, clim, masked_hist = setup_cmap(hist, plotting_options, hist_mean=hist_mean)

    fps = []
    frame_edges = list(zip(t_bins[:-1], t_bins[1:]))
    for bin_index, (t_start, t_end) in enumerate(tqdm.tqdm(frame_edges)):
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
        img = ax.imshow(masked_hist[bin_index, :, :], cmap=cmap, clim=clim)
        fig.colorbar(img, ax=ax, label='FR (spikes / s)')

        ax.set_title("{0} - {1} ms".format(t_start, t_end))
        ax.set_frame_on(False)
        ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
        fp = os.path.join(images_dir, "frame{:04d}.png".format(bin_index))
        fig.savefig(fp)
        fps.append(fp)
        if bin_index == 0:
            # Also keep a vector (PDF) copy of the first frame.
            fig.savefig(os.path.join(images_dir, "frame{:04d}.pdf".format(bin_index)))
        # Close each figure to avoid accumulating open figures across frames.
        plt.close(fig)

    c_etl.video_from_image_files(fps, video_output_root + ".mp4")
    if delete_images:
        for fp in fps:
            os.remove(fp)

import numpy
from scipy.ndimage import gaussian_filter
def single_flatspace_video(simulation_row, 
                           filtered_dataframes, 
                           flat_locations, 
                           flatspace_video_opt, 
                           analysis_config,
                           plotting_options,
                           flatspace_path_pre=None, 
                           images_dir=None):
    """Render the flat-map firing-rate video and time-averaged map for one simulation.

    Used as the ``func`` passed to ``call_by_simulation``, which supplies
    ``simulation_row`` and ``filtered_dataframes``.

    Parameters
    ----------
    simulation_row
        Row with ``simulation_id`` and ``simulation_string`` (read only when
        ``flatspace_path_pre`` / ``images_dir`` are None).
    filtered_dataframes : dict
        ``'windows'``: first row provides ``t_start``, ``t_stop`` and (when
        ``images_dir`` is None) ``flatspace_video_images_dir``;
        ``'spikes'``: must have ``time`` and ``gid`` columns.
    flat_locations : pandas.DataFrame
        Flat coordinates indexed by gid (columns 0 and 1).
    flatspace_video_opt : dict
        Reads ``t_step``, ``n_spatial_bins``, ``temporal_smoothing_sigma``,
        ``delete_images``, ``video_output_root``, ``vid_str`` and, optionally,
        ``stim_anal``.
    analysis_config
        Currently unused.
    plotting_options : dict
        Forwarded to the plotting helpers.
    flatspace_path_pre : str, optional
        Output prefix for the video (``.mp4``) and ``_hist_mean.pdf``.
    images_dir : str, optional
        Directory for the individual frames.

    Returns
    -------
    dict
        ``smoothed_spatial_temporal_hist`` (time, x, y) and ``t_bins``.
    """
    window_row = filtered_dataframes['windows'].iloc[0]

    if flatspace_path_pre is None:
        flatspace_path_pre = (flatspace_video_opt['video_output_root']
                              + str(simulation_row['simulation_id']) + "_"
                              + simulation_row['simulation_string'])
    if images_dir is None:
        images_dir = (str(window_row['flatspace_video_images_dir']) + "/"
                      + flatspace_video_opt['vid_str'] + "_"
                      + str(simulation_row['simulation_id']) + "/")

    t_bins = make_t_bins(window_row['t_start'], window_row['t_stop'], flatspace_video_opt['t_step'])
    spikes = filtered_dataframes['spikes'].loc[:, ['time', 'gid']].set_index('time')

    loc_bins = make_spatial_bins(flat_locations, flatspace_video_opt['n_spatial_bins'])
    spatial_temporal_hist = make_histogram_function(t_bins, loc_bins, flat_locations, spikes)
    # Gaussian smoothing: configured sigma along time, 1 bin along x and y.
    smoothed_spatial_temporal_hist = gaussian_filter(
        spatial_temporal_hist,
        [flatspace_video_opt['temporal_smoothing_sigma'], 1.0, 1.0])
    hist_mean = numpy.mean(smoothed_spatial_temporal_hist, axis=0)

    # Hand the plotting code a copy: its masking may write NaN into the array it
    # receives, and the smoothed histogram is returned to the caller below.
    plot(smoothed_spatial_temporal_hist.copy(), t_bins, loc_bins, images_dir,
         flatspace_video_opt['delete_images'], flatspace_path_pre, plotting_options,
         hist_mean=hist_mean)
    plot_and_save_single_image(hist_mean, plotting_options, flatspace_path_pre + '_hist_mean.pdf')

    if flatspace_video_opt.get('stim_anal') is not None:
        # TODO: stimulus-vs-spontaneous comparison (mean maps over the
        # 'stim_period' and 'spont_period' windows and their difference) is not
        # implemented; 'stim_anal' options are currently ignored.
        pass

    r_dict = {"smoothed_spatial_temporal_hist": smoothed_spatial_temporal_hist,
              "t_bins": t_bins}
    return r_dict




a = ma.AllCompartments_spikes

# Colour/mask settings shared by every flat-space video rendered below.
FLATSPACE_PLOTTING_OPTIONS = {"cmap": matplotlib.cm.cividis,
                              "mask_fr": 0.0001,  # bins at or below this rate (spikes / s) are drawn white
                              "max_lim_pct": 99,  # upper colour limit: 99th percentile
                              "min_lim_pct": 0}   # lower colour limit: minimum

# Fixed output locations. NOTE(review): every video config and every selected
# simulation writes to these same paths, so later renders overwrite earlier ones.
FLATSPACE_PATH_PRE = 'figures/Fig-nbS1Evoked'
FLATSPACE_IMAGES_DIR = 'figures/Fig-nbS1Evoked/'

# Neither the circuit-0 flat locations nor the simulation subset depends on the
# video options, so compute them once instead of once per video.
gids = a.repo.neurons.df.etl.q(circuit_id=0)['gid']
locations = a.repo.simulations.df.iloc[0]['circuit'].cells.get(gids, ["x", "y", "z"])
flat_locations = c_etl.flatten_locations(locations, a.analysis_config.custom["flatmap"])
selected_simulations = a.repo.simulations.df.etl.q(ca=1.05,
                                                   depol_stdev_mean_ratio=0.4,
                                                   desired_connected_proportion_of_invivo_frs=0.15,
                                                   vpm_pct=20.0)

# Results per video key; `results` still holds the last video's results.
results_by_video = {}
for flatspace_video_key in a.analysis_config.custom['flatspace_videos']:
    flatspace_video_opt = a.analysis_config.custom['flatspace_videos'][flatspace_video_key]
    # Identifier built from window, time step, spatial bins and temporal smoothing.
    flatspace_video_opt['vid_str'] = "_".join([flatspace_video_opt['window'],
                                               str(flatspace_video_opt['t_step']),
                                               str(flatspace_video_opt['n_spatial_bins']),
                                               str(flatspace_video_opt['temporal_smoothing_sigma'])])
    flatspace_video_opt['video_output_root'] = str(a.figpaths.flatspace_videos) + "/" + flatspace_video_opt['vid_str'] + "/"
    os.makedirs(flatspace_video_opt['video_output_root'], exist_ok=True)

    dataframes = {
        "circuits": a.repo.simulations.df.loc[:, ['circuit', 'circuit_id', 'simulation_id']],
        "spikes": a.repo.report.df.etl.q(neuron_class="ALL", window=flatspace_video_opt['window']),
        "windows": a.repo.windows.df.etl.q(window=flatspace_video_opt['window']),
        "neurons": a.repo.neurons.df.etl.q(neuron_class="ALL"),
    }

    results = call_by_simulation(selected_simulations,
                                 dataframes,
                                 func=partial(single_flatspace_video,
                                              flat_locations=flat_locations,
                                              flatspace_video_opt=flatspace_video_opt,
                                              analysis_config=a.analysis_config.custom,
                                              plotting_options=FLATSPACE_PLOTTING_OPTIONS,
                                              flatspace_path_pre=FLATSPACE_PATH_PRE,
                                              images_dir=FLATSPACE_IMAGES_DIR),
                                 how='series')
    results_by_video[flatspace_video_key] = results