In [1]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from statsmodels.formula.api import mixedlm
import os
import sys
from analysisUtils import construct_where_str
sys.path.append('../utils')
from matio import loadmat
from protocols import load_params
from plotting import hide_spines
from db import execute_sql, get_db_info, select_db
sys.path.append('../behavior_analysis')
from traceUtils import setUpLickingTrace

In [2]:
plt.style.use('paper_export')

In [3]:
use_cv = 'cv1.0'
paths = get_db_info()

protocols = ['SameRewDist']  #, 'DistributionalRL_6Odours', 'Bernoulli']
n_prot = len(protocols)
all_results = {}
all_ret_dfs = {}

for prot in protocols:
    
    tmp = loadmat(os.path.join('fano', use_cv, prot, prot + '_results.mat'))
    all_results[prot] = tmp['Results']
    
    # sql string from fano.m
    sql = 'SELECT ephys.figure_path, behavior_path, file_date_id, ephys.file_date, ephys.processed_data_path, ' + \
        'ephys.meta_time, stats, session.name, session.mid, sid, rid, session.exp_date, session.probe1_AP, ' + \
        'session.probe1_ML, session.probe1_DV, session.significance FROM ephys LEFT JOIN session ON ' + \
        'ephys.behavior_path = session.raw_data_path WHERE protocol="SameRewDist" AND exclude=0 AND has_ephys=1 ' + \
        'AND phase>=3 AND n_trial>=150 AND quality>=2 AND curated=1 AND session.significance=1 AND ' + \
        'probe1_region="striatum" ORDER BY session.mid ASC, ephys.file_date ASC'

    all_rets = execute_sql(sql, paths['db'])
    all_ret_dfs[prot] = pd.DataFrame(all_rets, columns=all_rets[0].keys())

In [6]:
all_results['SameRewDist']

array([[<scipy.io.matlab._mio5_params.mat_struct object at 0x14e567aba760>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e56794ffd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e56788ffd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e591263d60>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e567761fd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e567708fd0>],
       [<scipy.io.matlab._mio5_params.mat_struct object at 0x14e5679f1fd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e56792efd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e56785bfd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e567805fd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e56772dfd0>,
        <scipy.io.matlab._mio5_params.mat_struct object at 0x14e567657fd0>],
       [array([], dtype=float64), array([], dtype=float64),
        array([], dtype=fl

In [5]:
# ignore lesion sessions
colors, pi, periods, kwargs = load_params(prot)
kwargs['manipulation'] = 'combined'

# create SQL query based on keyword arguments passed to function
_, sql = construct_where_str(prot, kwargs, 'ephys')
rets = execute_sql(sql, paths['db'])
ret_df = pd.DataFrame(rets, columns=rets[0].keys())

In [6]:
# this method ensures the same color mapping across lesion and helper 
mice_rets = select_db(paths['db'], 'session', 'name', 'protocol=? AND has_ephys=1 AND significance=1', (prot,), unique=False)
all_mice = sorted(np.unique([ret['name'] for ret in mice_rets]))
# color_set = mpl.cm.get_cmap('tab20').colors[6::2] + mpl.cm.get_cmap('tab20').colors[:6:2] + mpl.cm.get_cmap('tab20').colors[1::2]
color_set = mpl.cm.get_cmap('Set3').colors
mouse_colors = {k: v for k, v in zip(all_mice, color_set)}

In [7]:
trace_dict = {'cs_in': 0,
              'cs_out': 1,
              'trace_end': 3,
              'xlim': (-1, 5),
              'ylabel': '',
              'xlabel': 'Time from CS (s)'
              }

In [8]:
# from fano.m; would be better to import directly from file
bin_centers = np.arange(50, 5501, 50)
bin_width = 100
bin_half_width = bin_width / 2
align_time = 1000

MS_PER_S = 1000
n_sess = all_results[prot].shape[0]
per_lims = [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4)]
n_per = len(per_lims)
per_bounds = []

for (x, y) in per_lims:
    start_samps = align_time + x * MS_PER_S
    end_samps = align_time + y * MS_PER_S
    start_idx = np.argmin(np.abs(bin_centers - bin_half_width - start_samps))
    end_idx = np.argmin(np.abs(bin_centers + bin_half_width - end_samps))
    per_bounds.append((start_idx, end_idx))

fano_dfs = {}
for prot in protocols:
    fano_dfs[prot] = pd.DataFrame()
    fano_df = fano_dfs[prot]
    for ff in ['FanoFactor', 'FanoFactorAll']:
        urets = pd.DataFrame(all_ret_dfs[prot]['figure_path'])
        for i_tt in range(all_results[prot].shape[1]):
            fano_cache = np.zeros((n_sess, n_per))
            for i_ret in range(n_sess):
                data = getattr(all_results[prot][i_ret, i_tt], ff)
                for i_per, (x, y) in enumerate(per_bounds):
                    fano_cache[i_ret, i_per] = np.mean(data[x:y+1])
            urets[[f'{ff}_{i_tt}_{i_per}' for i_per in range(n_per)]] = fano_cache
        melt_df = urets.melt(id_vars='figure_path', var_name=f'trial_type_{ff}', value_name=ff)
        fano_df = pd.concat([fano_df, melt_df], axis=1)
        fano_df = fano_df.loc[:,~fano_df.columns.duplicated()].copy()
        
    fano_df['trial_type'] = fano_df['trial_type_FanoFactor'].str[-3].astype(int)
    fano_df['per'] = fano_df['trial_type_FanoFactor'].str[-1].astype(int)
    fano_df['lesion'] = 'lesioned'
    is_combined = np.isin(fano_df['figure_path'], ret_df['figure_path'])
    fano_df.loc[is_combined, 'lesion'] = 'combined'
    fano_df['names'] = fano_df['figure_path'].str[-13:-9]

AttributeError: 'numpy.ndarray' object has no attribute 'FanoFactor'

In [None]:
combined_df = fano_df[fano_df['lesion'] == 'combined']
fano_avg = combined_df.groupby(['names', 'trial_type', 'per'], as_index=False).mean()
fano_avg
for ff in ['FanoFactor', 'FanoFactorAll']:
    
    g = sns.relplot(fano_avg, x='trial_type', y=ff, hue='names', col='per', palette=mouse_colors, legend=False,
                    kind='line', estimator=None, errorbar=None, aspect=1, height=2.5, zorder=1)
    g.map_dataframe(sns.pointplot, x='trial_type', y=ff, hue='trial_type', palette=colors['palette'],
                    errwidth=4, errorbar=('ci', 95))
    
    for i_ax, ax in enumerate(g.axes.flat):
        ax.set_xticks(np.arange(.5, 6, 2))
        ax.set_xticklabels(['Nothing', 'Fixed', 'Variable'])
        ax.set_xlabel('')
        ax.set_title(periods['period_names'][i_ax])
    plt.savefig('figs/fano_{}_summary_{}_across_tt.pdf'.format('_'.join(protocols), ff))
    
    plt.figure(figsize=(2.3, 1.7))
    sns.lineplot(fano_avg, x='per', y=ff, hue='names', palette=mouse_colors, legend=False, estimator='mean',
                 errorbar=None, zorder=1)
    sns.pointplot(fano_avg, x='per', y=ff, hue='per', errwidth=4, errorbar=('ci', 95))
    plt.legend().remove()
    plt.xticks(np.arange(n_per), periods['period_names'], rotation=45, ha='right', rotation_mode='anchor')
    hide_spines()
    plt.savefig('figs/fano_{}_summary_{}_across_per.pdf'.format('_'.join(protocols), ff))