In [5]:
import sys, os
# Extend the module search path (relative to the notebook's working directory)
# so project-local modules such as `imports` below can be found.
# Each entry is added at most once, so re-running this cell is idempotent
# (the original appended '../..' twice and grew sys.path on every re-run).
for _extra_path in (
    os.path.join(os.getcwd(), '..'),
    os.path.join(os.getcwd(), '..', '..'),
    os.path.join(os.getcwd(), '..', '..', 'analysis'),
    os.path.join(os.getcwd(), '..', '..', 'session'),
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

import numpy as np
from imports import *
from matplotlib.patches import ConnectionPatch
from scipy.stats import pearsonr

In [6]:
def get_spike_counts(spk_times, pulse_times, hw=0.25, bin_count=51):
    collected = []
    for t_pulse in pulse_times:
        selected = spk_times[(spk_times > t_pulse - hw) & (spk_times < t_pulse + hw)]
        collected += [x for x in selected - t_pulse]
    collected = np.array(collected)

    bins = np.linspace(-hw, hw, bin_count)
    counts, _ = np.histogram(collected, bins=bins)
    counts = (counts / len(pulse_times))# * 1/((2. * hw)/float(bin_count - 1))
    
    return bins, counts

In [1]:
def plot_tgt_bgr_psth(example_units):
    unit_count = np.array([len(vals) for vals in example_units.values()]).sum()
    rows = int(np.ceil(unit_count/3))
    fig = plt.figure(figsize=(15, rows*4))
    count = 0

    for session, unit_ids in example_units.items():
        # read AEP events
        animal    = session.split('_')[0]
        aeps_file = os.path.join(source, animal, session, 'AEPs.h5')
        with h5py.File(aeps_file, 'r') as f:
            aeps_events = np.array(f['aeps_events'])

        # read single units
        spike_times = {}
        h5_file = os.path.join(source, animal, session, session + '.h5')
        with h5py.File(h5_file, 'r') as f:
            cfg = json.loads(f['processed'].attrs['parameters'])
            for unit_id in unit_ids:
                spike_times[unit_id] = np.array(f['units'][unit_id][H5NAMES.spike_times['name']])

        for unit_id in unit_ids:
            bins, counts_bgr = get_spike_counts(spike_times[unit_id], aeps_events[aeps_events[:, 1] == 1][:, 0])
            bins, counts_tgt = get_spike_counts(spike_times[unit_id], aeps_events[aeps_events[:, 1] == 2][:, 0])

            ax = fig.add_subplot(rows, 3, count+1)
            tgt_dur, bgr_dur = cfg['sound']['sounds']['target']['duration'], cfg['sound']['sounds']['background']['duration']
            label_tgt = "Tgt: %.2f" % tgt_dur
            label_bgr = "Bgr: %.2f" % bgr_dur
            #ax.hist(bins[:-1], bins=bins, weights=counts_tgt, edgecolor='black', color='tab:orange', alpha=0.9, label=label_tgt)
            ax.hist(bins[:-1], bins=bins, weights=counts_bgr, edgecolor='black', color='black', alpha=0.5, label=label_bgr)
            ax.axvline(0, color='black')
            ax.axvline(tgt_dur, color='tab:orange', ls='--')
            ax.axvline(tgt_dur - 0.25, color='tab:orange', ls='--')
            ax.axvline(bgr_dur, color='black', ls='--', alpha=0.5)
            ax.axvline(bgr_dur - 0.25, color='black', ls='--', alpha=0.5)
            #ax.set_xlabel('Pulse onset, s', fontsize=14)
            ax.axvspan(0, 0.05, alpha=0.3, color='gray')
            ax.set_title("%s : %s" % (session[21:], unit_id), fontsize=14)
            ax.legend(loc='upper right', prop={'size': 10})
            if count % 3 == 0:
                ax.set_ylabel("Firing Rate, Hz", fontsize=14)
            count += 1
        
    return fig

In [10]:
def plot_psth_by_metric(area, m_name, example_units):
    unit_count = np.array([len(vals) for vals in example_units.values()]).sum()
    rows = int(np.ceil(unit_count/3))
    fig = plt.figure(figsize=(15, rows*4))
    count = 0

    for session, unit_ids in example_units.items():
        # read AEP events
        animal    = session.split('_')[0]
        aeps_file = os.path.join(source, animal, session, 'AEPs.h5')
        with h5py.File(aeps_file, 'r') as f:
            aeps_events = np.array(f['aeps_events'])
            aeps = np.array(f[area]['aeps'])

        # TODO find better way. Remove outliers
        aeps[aeps > 5000]  =  5000
        aeps[aeps < -5000] = -5000

        # read single units
        spike_times = {}
        h5_file = os.path.join(source, animal, session, session + '.h5')
        with h5py.File(h5_file, 'r') as f:
            for unit_id in unit_ids:
                spike_times[unit_id] = np.array(f['units'][unit_id][H5NAMES.spike_times['name']])

        # load metrics
        AEP_metrics_lims = {}
        AEP_metrics_raw  = {}
        AEP_metrics_norm = {}
        with h5py.File(aeps_file, 'r') as f:
            grp = f[area]
            for metric_name in grp['raw']:
                AEP_metrics_raw[metric_name]  = np.array(grp['raw'][metric_name])
                AEP_metrics_norm[metric_name] = np.array(grp['norm'][metric_name])
                AEP_metrics_lims[metric_name] = [int(x) for x in grp['raw'][metric_name].attrs['limits'].split(',')]

        # separate high / low AEP metric states
        predictor = AEP_metrics_norm[m_name]
        low_state_idxs  = np.where(predictor < predictor.mean())[0]
        high_state_idxs = np.where(predictor > predictor.mean())[0]
        aeps_low_mean  = aeps[low_state_idxs].mean(axis=0)
        aeps_high_mean = aeps[high_state_idxs].mean(axis=0)

        for unit_id in unit_ids:
            bins, counts_low  = get_spike_counts(spike_times[unit_id], aeps_events[low_state_idxs][:, 0])
            bins, counts_high = get_spike_counts(spike_times[unit_id], aeps_events[high_state_idxs][:, 0])

            vals_max = np.array([counts_high.max(), counts_low.max()]).max()
            aep_low_profile  = (1/10) * vals_max * (aeps_low_mean/500)
            aep_high_profile = (1/10) * vals_max * (aeps_high_mean/500)

            ax = fig.add_subplot(rows, 3, count+1)
            ax.hist(bins[:-1], bins=bins, weights=counts_high, edgecolor='black', color='red', alpha=0.8, label='%s >' % m_name)
            ax.hist(bins[:-1], bins=bins, weights=counts_low, edgecolor='black', color='black', alpha=0.5, label='%s <' % m_name)
            for x_l, x_r in [(-0.25, -0.051), (0.0, 0.199)]:
                ax.plot(np.linspace(x_l, x_r, len(aeps_low_mean)),  aep_high_profile, color='red', lw=2)
                ax.plot(np.linspace(x_l, x_r, len(aeps_high_mean)), aep_low_profile, color='black', lw=2)
            ax.axvline(0, color='black', ls='--')
            #ax.set_xlabel('Pulse onset, s', fontsize=14)
            ax.axvspan(0, 0.05, alpha=0.3, color='gray')
            ax.set_title("%s : %s" % (session[21:], unit_id), fontsize=14)
            ax.legend(loc='upper right', prop={'size': 10})
            if count % 3 == 0:
                ax.set_ylabel("Firing Rate, Hz", fontsize=14)
            count += 1

    return fig