##  Short-term Dynamics of Excitation and Inhibition Balance in Hippocampal CA3-CA1 Circuit
### Paper Figures

### Figure 1:
* Expt design schematic
* Expt design patterns in space and time
* Example set of traces
* Calib of stimuli: desensitization
* Control for CA3 responses to patterns.
* CA3 heatmap
* CA3 analysis with respect to freq
* CA3 expected pattern response

### Figure 2:
* Deconv to extract STP from the trace: schematic
* Comparison between deconv fit, P2P, and peak fitting
* Fitting of reference trace
* Example deconv and fit thereof.
* STP traces under different conditions: E/I, #sq and frequency
* Diversity in STP traces.
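
The deconvolution step in this outline can be sketched as a linear least-squares problem. If every pulse evokes the same PSC waveform, scaled by an unknown per-pulse amplitude, then the recorded train is a sum of time-shifted kernels, and the amplitudes can be solved for directly. This is a minimal sketch under stated assumptions. The double-exponential kernel, its time constants, and the helper names `double_exp_kernel` and `deconvolve_train` are hypothetical. They are not taken from `fit_PSC` or from Upi's fitting code.

```python
import numpy as np

def double_exp_kernel(t, tau_rise=0.002, tau_decay=0.02):
    """Hypothetical PSC kernel: difference of exponentials, zero for t < 0, unit peak."""
    s = np.clip(t, 0.0, None)                       # clip so negative times cannot overflow exp
    k = (np.exp(-s / tau_decay) - np.exp(-s / tau_rise)) * (t >= 0)
    peak = k.max()
    return k / peak if peak > 0 else k

def deconvolve_train(trace, stim_times, fs, **kernel_kw):
    """Per-pulse amplitudes, assuming the trace is a linear sum of identical shifted kernels."""
    t = np.arange(trace.size) / fs
    # design matrix: column j is the kernel starting at stimulus j
    X = np.column_stack([double_exp_kernel(t - s, **kernel_kw) for s in stim_times])
    amps, *_ = np.linalg.lstsq(X, trace, rcond=None)
    return amps

# synthetic check: 9 pulses at 20 Hz with depressing amplitudes
fs = 2e4
stim_times = 0.1 + np.arange(9) / 20.0
true_amps = np.array([1.0, 0.8, 0.7, 0.65, 0.6, 0.58, 0.56, 0.55, 0.54])
t = np.arange(int(fs)) / fs
trace = sum(a * double_exp_kernel(t - s) for a, s in zip(true_amps, stim_times))

amps = deconvolve_train(trace, stim_times, fs)
stp = amps / amps[0]          # normalize to the first pulse
print(np.round(stp, 2))
```

The fit assigns overlapping tails to the pulse that produced them. This is the property that lets such a fit separate responses at high stimulus frequencies, where a simple peak-to-peak (P2P) measure would be contaminated by the decay of the previous pulse.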

<img src="notes_figures/Figure_outline_Fig2.jpg" width="500"/>

### Figure 3:
* Model schematic for plasticity
* Fitting free params of model for E and I
* Fitting diversity for E and I
* Fitting stochasticity for E and I
* How strongly are E and I averaged, respectively? This is crucial for the escape calculation.
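
For the plasticity model, one standard choice is the event-based Tsodyks-Markram synapse, written as a per-pulse recurrence. It is shown here only as an illustration and is not necessarily the model fitted in this project. In this formulation, a utilization variable `u` facilitates, a resource variable `R` depletes, and each pulse's amplitude is `A * u * R`. The parameter values and the function name `tm_train` are hypothetical.

```python
import numpy as np

def tm_train(stim_times, U=0.3, tau_D=0.3, tau_F=0.1, A=1.0):
    """Per-pulse amplitudes of a Tsodyks-Markram-style synapse (times in s)."""
    u, R = U, 1.0
    amps = []
    for n, t in enumerate(stim_times):
        if n > 0:
            dt = t - stim_times[n - 1]
            # utilization relaxes toward U and is incremented by each pulse (facilitation)
            u_new = u * np.exp(-dt / tau_F) + U * (1 - u * np.exp(-dt / tau_F))
            # resources used by the previous pulse recover toward 1 (depression)
            R = R * (1 - u) * np.exp(-dt / tau_D) + 1 - np.exp(-dt / tau_D)
            u = u_new
        amps.append(A * u * R)
    return np.array(amps)

for f in (20, 50):
    times = np.arange(9) / f
    resp = tm_train(times)
    print(f, "Hz:", np.round(resp / resp[0], 2))
```

One way to obtain separate free parameters for each pathway is to fit `U`, `tau_D`, `tau_F` and `A` to the E (-70 mV) and I (0 mV) normalized responses separately, for example with `scipy.optimize.curve_fit`.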

### Figure 4:
* Model integration: schematic of chem and elec components
* Model validation 1: Match to random pattern sequence
* Model validation 2: Match to Curr clamp recordings
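
A single-compartment passive membrane can illustrate how the chemical (synaptic) and electrical (membrane) components are integrated. The membrane is driven by excitatory and inhibitory conductances, C dV/dt = -g_L(V - E_L) - g_E(t)(V - E_E) - g_I(t)(V - E_I), and is integrated with forward Euler. This is a minimal sketch, not the project's model. The alpha-function conductances and all parameter values are assumptions. The reversal potentials (E_E = 0 mV, E_I = -70 mV) are chosen so that clamping at -70 mV and at 0 mV would isolate E and I respectively, as in the recordings.

```python
import numpy as np

def alpha_train(t, onsets, peak, tau):
    """Sum of alpha-function conductances (nS), each reaching `peak` at `tau` ms after onset."""
    g = np.zeros_like(t)
    for t0 in onsets:
        s = np.clip(t - t0, 0.0, None)          # time since onset, 0 before the pulse
        g += peak * (s / tau) * np.exp(1.0 - s / tau)
    return g

def simulate_vm(t, g_E, g_I, C=100.0, g_L=10.0, E_L=-70.0, E_E=0.0, E_I=-70.0):
    """Forward-Euler passive membrane. Units: ms, mV, nS, pF (nS * mV = pA)."""
    dt = t[1] - t[0]
    V = np.empty_like(t)
    V[0] = E_L
    for k in range(1, t.size):
        I_total = (-g_L * (V[k - 1] - E_L)
                   - g_E[k - 1] * (V[k - 1] - E_E)
                   - g_I[k - 1] * (V[k - 1] - E_I))
        V[k] = V[k - 1] + dt * I_total / C      # pA * ms / pF = mV
    return V

t = np.arange(0.0, 600.0, 0.05)                 # ms; 0.05 ms step = 20 kHz, like Fs
onsets = 100.0 + 50.0 * np.arange(9)            # 9 pulses at 20 Hz
g_E = alpha_train(t, onsets, peak=5.0, tau=2.0)
g_I = alpha_train(t, onsets + 2.0, peak=10.0, tau=8.0)   # delayed, slower inhibition
V_EI = simulate_vm(t, g_E, g_I)
V_E_only = simulate_vm(t, g_E, np.zeros_like(t))
print(f"peak Vm: E only {V_E_only.max():.1f} mV, E+I {V_EI.max():.1f} mV")
```

`V_EI` and `V_E_only` differ only in the delayed inhibition, so comparing them shows how much that inhibition truncates the depolarization from the same excitatory input. Replacing the synthetic conductances with ones derived from the voltage-clamp recordings would turn this into the comparison with current-clamp recordings listed above.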

### Figure 5:
* Example E-I traces over a burst (train)
* Separate panel to compare E and I trace shapes (E can be scaled to I, with each trace given its own scale bar)
* Plot field responses corresponding to E and I
* E/I ratio of successive pulses across freq and numSq
* Summary figure across cells
Note: this figure should explain why we don't see escape from E-I balance.
* Current clamp summary
* Differences between stimulus frequencies across pulses
* Differences between numSq values across pulses
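
The per-pulse E/I ratio can be computed from a long-format table like the `df_melt` frames built later in this notebook. This sketch makes three assumptions. The `peak_response` column holds raw (not normalized) peak amplitudes. E is measured at -70 mV and I at 0 mV. Magnitudes are compared, because E is inward (negative) and I outward (positive) under those clamps. The function name `ei_ratio_per_pulse` is hypothetical.

```python
import numpy as np
import pandas as pd

def ei_ratio_per_pulse(long_df):
    """E/I ratio per (cell, stimFreq, numSq, pulseIndex) from -70 mV (E) and 0 mV (I) peaks."""
    keys = ["cellID", "stimFreq", "numSq", "pulseIndex"]
    mean_peaks = (long_df.assign(peak=long_df["peak_response"].abs())
                         .groupby(keys + ["clampPotential"])["peak"]
                         .mean()
                         .unstack("clampPotential"))      # columns: -70 (E) and 0 (I)
    ratio = (mean_peaks[-70] / mean_peaks[0]).rename("EI_ratio")
    return ratio.reset_index()

toy = pd.DataFrame({
    "cellID": 1, "stimFreq": 20, "numSq": 5,
    "pulseIndex":     [0, 0, 1, 1],
    "clampPotential": [-70, 0, -70, 0],
    "peak_response":  [-100.0, 200.0, -80.0, 200.0],
})
print(ei_ratio_per_pulse(toy))
```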
 
### Figure 6: Why no escape (can be merged with figure 5)
* Comparison of current clamp response with response generated from a model cell with summated E-I currents
Likely reason: inhibition builds up slowly during the train
* Gabazine vs control, which pulse, selectivity
* For the sweeps that cause escape: which pulse index, numSq, and frequency are they from?
* EPSP peak time should be in the same range as spike peak time in the sweeps where there is escape
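
The peak-timing check can be sketched in two steps. First, for each pulse, take the latency of the largest voltage excursion within a fixed window, and flag it as a spike if it crosses a threshold. Then compare the latencies of subthreshold peaks with those of spikes. The window length, the 0 mV spike threshold, the helper name `pulse_peaks`, and the synthetic trace (Gaussian bumps standing in for EPSPs and one spike) are all assumptions for illustration.

```python
import numpy as np

def pulse_peaks(vm, fs, stim_times, window=0.02, spike_threshold=0.0):
    """Per pulse: latency (s) of the largest Vm value in the window, and whether it is a spike."""
    n_win = int(round(window * fs))
    latencies, is_spike = [], []
    for s in stim_times:
        i0 = int(round(s * fs))
        seg = vm[i0:i0 + n_win]
        k = int(np.argmax(seg))
        latencies.append(k / fs)
        is_spike.append(bool(seg[k] > spike_threshold))
    return np.array(latencies), np.array(is_spike)

# synthetic 20 Hz train: subthreshold EPSPs, plus one pulse (index 4) that escapes and spikes
fs = 2e4
t = np.arange(int(0.6 * fs)) / fs
stim_times = 0.1 + np.arange(9) / 20.0
vm = np.full_like(t, -70.0)
for n, s in enumerate(stim_times):
    amp, lat = (90.0, 0.006) if n == 4 else (5.0, 0.005)
    vm += amp * np.exp(-((t - s - lat) / 0.001) ** 2)

lat, spk = pulse_peaks(vm, fs, stim_times)
print("EPSP peak latencies (ms):", np.round(1e3 * lat[~spk], 2))
print("spike peak latencies (ms):", np.round(1e3 * lat[spk], 2))
```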

### Figure 7: SDN across pulses for frequencies
* Example plot of SDN
* Grid
* Heatmap of gamma
* Other ideas:
  * SDN across frequencies: The envelope of E and I change with frequency of stimulation
  * relationship of peak of PSCs with the fall tau for both E and I
  * SDN expected vs observed plotted for different frequencies and consecutive pulse # 
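
A sketch for the gamma heatmap follows. It assumes the divisive-normalization form V_obs = gamma * V_exp / (gamma + V_exp), where V_exp is the expected (linearly summed) response and gamma sets where the observed response saturates. Under that assumption, gamma can be fitted with `scipy.optimize.curve_fit`, which is already imported in this notebook. The functional form, the synthetic data, and the names `sdn` and `fit_gamma` are assumptions, not the analysis used for the figure.

```python
import numpy as np
from scipy.optimize import curve_fit

def sdn(expected, gamma):
    """Divisive normalization: ~linear when expected << gamma, saturating toward gamma."""
    return gamma * expected / (gamma + expected)

def fit_gamma(expected, observed, p0=10.0):
    """Least-squares estimate of gamma, constrained to be positive."""
    popt, _ = curve_fit(sdn, expected, observed, p0=[p0], bounds=(1e-6, np.inf))
    return popt[0]

rng = np.random.default_rng(0)
expected = np.linspace(0.5, 20.0, 40)                 # mV, synthetic expected responses
observed = sdn(expected, 8.0) + rng.normal(0.0, 0.1, expected.size)
gamma_hat = fit_gamma(expected, observed)
print(f"fitted gamma = {gamma_hat:.2f} mV")
```

Fitting one gamma per (stimFreq, pulseIndex) group would give one value per heatmap entry.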

<img src="notes_figures/Figure_Outline_Fig7_SDN.jpg" width="500"/>

### Figure 8: Estimations of Sparsity/Overlap from model vs obs
* Example trace vs model
* Estimation of sparsity/overlap from CA3 recordings

(Paper Figures 5, 6, and 7)

<img src="notes_figures/Figure_Outline_Fig5_6_7_May2023.png" width="500"/>


### Update 1 Nov 2023
Other ideas:
* **Split Poisson-train responses across patterns and numSq to check whether different patterns cause escape at different points in the frequency train**
* Check the same for the oddball experiment
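
Splitting escape by pattern can be sketched as a group-wise spike probability: for every pattern and numSq, compute the fraction of trials in which each pulse evoked a spike. The column names `patternID` and `spiked` are hypothetical. The dataframes here store `patternList` and sweep-level spike flags, so a per-pulse spike column would have to be derived first.

```python
import pandas as pd

def escape_probability(pulse_df, by=("numSq", "patternID")):
    """Fraction of trials in which each pulse evoked a spike, one row per `by` group."""
    return (pulse_df.groupby(list(by) + ["pulseIndex"])["spiked"]
                    .mean()
                    .unstack("pulseIndex"))

toy = pd.DataFrame({
    "numSq":      [5, 5, 5, 5, 15, 15, 15, 15],
    "patternID":  [1, 1, 1, 1, 2, 2, 2, 2],
    "pulseIndex": [0, 1, 0, 1, 0, 1, 0, 1],
    "spiked":     [False, True, False, False, True, True, False, True],
})
print(escape_probability(toy))
```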

<img src="notes_figures/Figure_updates_Fig5_6_7_1Nov2023.png" width="500"/>

In [1]:
import sys
import os
import importlib
from   pathlib      import Path
import pickle
import psutil

import numpy                as np
import matplotlib           as mpl
import matplotlib.pyplot    as plt
import seaborn              as sns
import pandas               as pd

from scipy.signal   import find_peaks, peak_widths
from scipy.signal   import butter, bessel, decimate, sosfiltfilt
from scipy.stats    import kruskal, wilcoxon, mannwhitneyu, ranksums
from scipy.signal   import filter_design
from scipy.optimize import curve_fit

from PIL            import Image

from eidynamics     import utils, data_quality_checks, ephys_classes, plot_tools, expt_to_dataframe
from eidynamics     import pattern_index
from eidynamics     import abf_to_data
from eidynamics     import fit_PSC
from Findsim        import tab_presyn_patterns_LR_43
import parse_data
import all_cells
import plotFig2
from stat_annotate import *

%matplotlib widget
%tb
# import plotly.express as px
# import plotly.graph_objects as go
# sns.set_context('paper')

# make a colour map viridis
viridis = mpl.colormaps["viridis"]
flare   = mpl.colormaps["flare"]
crest   = mpl.colormaps["crest"]
magma   = mpl.colormaps["magma"]
edge    = mpl.colormaps['edge']

color_E = "flare"
color_I = "crest"
color_freq = {1:magma(0.05), 5:magma(0.1), 10:magma(0.2), 20:magma(.4), 30:magma(.5), 40:magma(.6), 50:magma(.7), 100:magma(.9)}
color_squares = {1:viridis(0.2), 5:viridis(.4), 7:viridis(.6), 15:viridis(.8), 20:viridis(1.0)}
color_EI = {-70:flare(0), 0:crest(0)}

Fs = 2e4

freq_sweep_pulses = np.arange(9)

Data parsing program imported
>> Working on:  C:\Users\adity\OneDrive\NCBS




### Path

In [2]:
figure_raw_material_location = Path("paper_figure_matter")
paper_figure_export_location = Path("paper_figures")
data_path                    = Path("parsed_data")
cell_data_path               = Path(r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Data\Screened_cells")

### Data

### Data for figures 1 and 2

In [3]:
# Full data set for FreqSweep protocol (df) (raw dataframe with metadata)
freq_sweep_cc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_CC_long.h5" 
freq_sweep_vc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_VC_long.h5" 
cc_FS_longdf = pd.read_hdf(freq_sweep_cc_datapath, key='data')

# print the size of the dataframe
print(f"Full FreqSweep Dataframe has {cc_FS_longdf.shape[0]} sweeps and {cc_FS_longdf.shape[1]} columns")

Full FreqSweep Dataframe has 4971 sweeps and 80073 columns


In [4]:
# Short dataframe for all protocols (df2): processed dataframe with metadata and analysed parameters
dfshortpath     = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_allprotocols_with_fpr_values.h5"
xc_all_shortdf  = pd.read_hdf(dfshortpath, key='data')
print(f"Short Dataframe has {xc_all_shortdf.shape[0]} sweeps and {xc_all_shortdf.shape[1]} columns")

Short Dataframe has 16870 sweeps and 63 columns


In [5]:
# expanded dataframe (processed dataframe with metadata and analysed params)
expanded_data_path = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_combined_expanded.h5"
xc_FS_analyseddf = pd.read_hdf(expanded_data_path, key='data')
print(xc_FS_analyseddf.shape)

(7770, 100)


In [6]:
# DataFrame metadata columns
num_metadata_columns = 49
num_dataflag_columns = -10
column_name_abbreviations = utils.analysed_properties1_abbreviations
metadata_fields = (cc_FS_longdf.columns[:num_metadata_columns]).to_list()
dataflag_fields = (cc_FS_longdf.columns[-24:-14]).to_list()

### Data for figures 5, 6, 7

In [None]:
# expanded dataframe with channels separated into rows
channel_split_data_path_vc = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\screened_cells_FreqSweep_VC_long_channel_split_into_rows.h5"
channel_split_data_path_cc = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\screened_cells_FreqSweep_CC_long_channel_split_into_rows.h5"
# df4 = pd.read_hdf(channel_split_data_path_vc, key='data')
df5 = pd.read_hdf(channel_split_data_path_cc, key='data')

# # Load CA3 data
# df_CA3_props = pd.read_csv(r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Data\screened_cells\23-06-06_GrikAA316\3161\CA3_recording_3161_response_properties.csv")

# # cell list
# vc_cells = np.unique( df2 [ (df2['clampMode']=='VC') ]['cellID'])
# cc_cells = np.unique(df['cellID'])
# vc_cells_screened = np.array([7492, 1931, 1621, 1541, 1524, 1523, 1522, 1491]) # 7491, 6201, 6301, 111, 1531


In [None]:
vc_cells_screened = np.array([7492, 1931, 1621, 1541, 1524, 1523, 1522, 1491])

---

### Figure 1

* Expt design schematic
* Expt design patterns in space and time
* Example set of traces
* Calib of stimuli: desensitization
* Control for CA3 responses to patterns.
* CA3 heatmap
* CA3 analysis with respect to freq
* CA3 expected pattern response

In [7]:
# data screening based on dataflag_fields
df = cc_FS_longdf
dfslice = df[
            (df['location'] == 'CA1') &
            (df['numSq'].isin([1,5,15])) &
            (df['AP'] == 0) &
            (df['IR'] >50) & (df['IR'] < 300) &
            (df['tau'] < 0.05) & (df['tau'] > 0) &
            (df['intensity'] == 100) &
            (df['pulseWidth'] == 2) &
            (df['sweepBaseline'] < -50) &
            (df['condition'] == 'Control') &
            (df['ch0_response']==1) &
            (df['spike_in_stim_period'] == 0) &
            (df['spike_in_baseline_period'] == 0) &
            (df['ac_noise_power_in_ch0'] < 20) 
        ]

screened_trialIDs = dfslice['trialID'].unique()

print(f"Unique cells in screened data: {dfslice['cellID'].nunique()}")
print(f"Unique sweeps in screened data: {dfslice['trialID'].nunique()}")

Unique cells in screened data: 15
Unique sweeps in screened data: 2155


In [8]:
df2 = xc_all_shortdf[(xc_all_shortdf['trialID'].isin(screened_trialIDs)) ]

### Figure 2: CA1 and CA3 responses to grid spots
* Deconv to extract STP from the trace: schematic
* Comparison between deconv fit, P2P, and peak fitting
* Fitting of reference trace
* Example deconv and fit thereof.
* STP traces under different conditions: E/I, #sq and frequency
* Diversity in STP traces.

<img src="notes_figures/Figure_outline_Fig2.jpg" height="300"/>

Updates from the meeting on 24 Jan 2024:
* Panel A: normalized pulse responses from raw data for E (x='pulseIndex', y='pulseResponse', hue='stimFreq')
* Panel B: normalized pulse responses from raw data for I (x='pulseIndex', y='pulseResponse', hue='stimFreq')
* Panel C: across all cells
* Panels D, E: Upi's Fig 2 panels A, B
* Panels F, G: Upi's Fig 2 panels C, D
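
Panels A and B use pulse responses normalized within each sweep. This minimal sketch assumes that normalization means dividing every pulse peak by the first pulse of the same sweep. The `pcn` columns in the analysed dataframe come from `eidynamics` and may be defined differently. The sketch then applies the same wide-to-long `melt` used for the Figure 2 panels. The raw `pc` values here are synthetic.

```python
import numpy as np
import pandas as pd

n_pulses = 9
# synthetic wide table: one row per sweep, raw peak columns pc0..pc8 (pA)
wide = pd.DataFrame({"cellID": [1, 1], "stimFreq": [20, 50]})
for i in range(n_pulses):
    wide[f"pc{i}"] = [100.0 * 0.9 ** i, 100.0 * 0.8 ** i]

# normalize each pulse to the first pulse of the same sweep (assumed definition of pcn)
pc_cols = [f"pc{i}" for i in range(n_pulses)]
pcn = wide[pc_cols].div(wide["pc0"], axis=0)
pcn.columns = [f"pcn{i}" for i in range(n_pulses)]
wide = pd.concat([wide, pcn], axis=1)

# wide -> long: one row per (sweep, pulse), ready for seaborn with hue='stimFreq'
long_df = wide.melt(id_vars=["cellID", "stimFreq"], value_vars=list(pcn.columns),
                    var_name="pulseIndex", value_name="pulseResponse")
long_df["pulseIndex"] = long_df["pulseIndex"].str.replace("pcn", "", regex=False).astype(int)
print(long_df.head())
```

Converting `pulseIndex` to integers keeps the x axis in numeric order. With that conversion, the manual tick relabelling done in the plotting cells may not be needed.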

<img src="notes_figures/Figure_updates_Fig2_24Jan2024.jpg" height="500"/>

#### Load VC data

In [None]:
# Load the dataset
freq_sweep_vc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_VC_long.h5" 
df = pd.read_hdf(freq_sweep_vc_datapath, key='data')

#### Screen for cells with good responses

In [None]:
# data screening based on dataflag_fields
dfslice = df[
            (df['location'] == 'CA1') &
            (df['numSq'].isin([1,5,15])) &
            (df['AP'] == 0) &
            (df['IR'] >50) & (df['IR'] < 300) &
            (df['tau'] < 40) & 
            (df['intensity'] == 100) &
            (df['pulseWidth'] == 2) &
            # (df['sweepBaseline'] < -50) &
            (df['condition'] == 'Control') &
            (df['ch0_response']==1) &
            # (df['spike_in_stim_period'] == 0) &
            (df['spike_in_baseline_period'] == 0) &
            (df['ac_noise_power_in_ch0'] < 40) 
        ]

vc_screened_trialIDs = dfslice['trialID'].unique()

print(f"Unique cells in screened data: {dfslice['cellID'].nunique()}")
print(f"Unique sweeps in screened data: {dfslice['trialID'].nunique()}")
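
The current-clamp and voltage-clamp screening cells repeat the same chain of conditions with different thresholds. A small helper could keep the criteria for each clamp mode in a single dictionary. `screen` and `vc_criteria` are hypothetical names. The bounds are open intervals, to match the strict `<` and `>` comparisons used above.

```python
import pandas as pd

def screen(df, equals=None, isin=None, bounds=None):
    """Rows passing all equality, membership and open-interval (lo, hi) criteria."""
    mask = pd.Series(True, index=df.index)
    for col, val in (equals or {}).items():
        mask &= df[col] == val
    for col, vals in (isin or {}).items():
        mask &= df[col].isin(vals)
    for col, (lo, hi) in (bounds or {}).items():
        if lo is not None:
            mask &= df[col] > lo
        if hi is not None:
            mask &= df[col] < hi
    return df[mask]

# the VC criteria from the cell above, expressed as data; usage: dfslice = screen(df, **vc_criteria)
vc_criteria = dict(
    equals={"location": "CA1", "AP": 0, "intensity": 100, "pulseWidth": 2,
            "condition": "Control", "ch0_response": 1, "spike_in_baseline_period": 0},
    isin={"numSq": [1, 5, 15]},
    bounds={"IR": (50, 300), "tau": (None, 40), "ac_noise_power_in_ch0": (None, 40)},
)

toy = pd.DataFrame({"location": ["CA1", "CA1", "CA3"], "IR": [100, 400, 100], "numSq": [5, 5, 5]})
kept = screen(toy, equals={"location": "CA1"}, isin={"numSq": [1, 5, 15]}, bounds={"IR": (50, 300)})
print(kept.index.tolist())
```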

In [None]:
df3 = xc_FS_analyseddf[xc_FS_analyseddf['trialID'].isin(vc_screened_trialIDs)]

In [None]:
plot_kind = 'strip' # 'line' or 'violin' or 'strip'
column_name_abbreviations = ['pc','pcn','ac','sc','dc','pf','pfn']
df3 = xc_FS_analyseddf[xc_FS_analyseddf['trialID'].isin(vc_screened_trialIDs)]

# ---------------------------------------------------------------------------------------------------------------------------------
# Setup the figure
w, h =  21, 29.7
# 11 panels (A-K) on a 7-row x 2-column grid; panels C, D and E each span both columns
Fig2, ax2 = plt.subplot_mosaic([['A', 'B'],['C', 'C'],['D','D'],['E', 'E'],['F', 'G'],['H','I'],['J','K']], figsize=(w, h), )
# change the spacing between plots
Fig2.subplots_adjust(hspace=1.0, wspace=0.25)
# linearize the ax2 list
ax2 = [ax2[key] for key in ax2.keys()]

# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2A: STP vs stimFreq for a sample cell (numSq = 15), cp = -70 mV
ax2[0].text(-0.1, 1.1, 'A', transform=ax2[0].transAxes, size=20, weight='bold')
cell =  7492
sq   =  15
cp   = -70        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['numSq'] == sq) & (df3['clampPotential'] == cp)] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[0], df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind=plot_kind, palette=color_E, 
                                        stat_across='hue', stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
ax2[0].set_xlabel('Pulse Index')
# set xtick location and labels
ax2[0].set_xticks(range(9))
ax2[0].set_xticklabels(freq_sweep_pulses)

# y-axis
ax2[0].set_ylabel('Norm. Response')
ax2[0].set_ylim(0,3)
# no legend
ax2[0].legend([],[], frameon=False)

# # get the legend labels of ax2d_top and add ' Sq' to each one
# handles, labels = ax2['A'].get_legend_handles_labels()
# labels = [label + ' Sq.' for label in labels]
# ax2d_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)
# # add a text in the top left corner of ax2d_top showing '20 Hz' in color = color_freq[20]
# ax2d_top.text(0.0, 0.9, f'{f} Hz', transform=ax2d_top.transAxes, size=12, color=color_freq[f], zorder=10)

# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2B: STP vs stimFreq for the same cell, cp = 0 mV
ax2[1].text(-0.1, 1.1, 'B', transform=ax2[1].transAxes, size=20, weight='bold')

cp = 0        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['numSq'] == sq)  & (df3['clampPotential'] == cp)] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[1], df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind=plot_kind, palette=color_I, 
                                        stat_across='hue', stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
# plot_tools.simplify_axes( ax2[1], splines_to_keep=['bottom'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=range(3), ytick_labels=range(3),)
ax2[1].set_xlabel('Pulse Index')
# set xtick location and labels
ax2[1].set_xticks(range(9))
ax2[1].set_xticklabels(freq_sweep_pulses)
# remove y-ticks and y-axis label
ax2[1].set_ylim(0,3)
ax2[1].set_ylabel('')
ax2[1].legend([],[], frameon=False)


# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2C: STP vs stimFreq for all screened cells, cp = -70 mV
ax2[2].clear()
ax2[2].text(-0.1, 1.1, 'C', transform=ax2[2].transAxes, size=20, weight='bold')

cp = -70        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['numSq'] == sq) & (df3['clampPotential'] == cp) ] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[2], df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='strip', palette=color_E, 
                                        stat_across='hue', stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
# plot_tools.simplify_axes( ax2[2], splines_to_keep=['left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=range(5), ytick_labels=range(5),)
ax2[2].set_ylabel('Norm. Response')
# ax2[2].set_ylim(0,5)
# remove x-ticks and x-axis label
ax2[2].set_xlabel('')
ax2[2].set_xticklabels([])
ax2[2].legend([],[], frameon=False)

# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2D: STP vs stimFreq (< 100 Hz) for all screened cells, cp = 0 mV
ax2[3].clear()
ax2[3].text(-0.1, 1.1, 'D', transform=ax2[3].transAxes, size=20, weight='bold')

cp = 0        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[  (df3['numSq'] == sq) & (df3['stimFreq'] <100 )& (df3['clampPotential'] == cp)  ] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[3], df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='strip', palette=color_I, 
                                        stat_across='hue', stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
# plot_tools.simplify_axes( ax2[3], splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=range(5), ytick_labels=range(5),)
ax2[3].set_xlabel('Pulse Index')
ax2[3].set_ylabel('Norm. Response')
ax2[3].set_ylim(0,5)

# ---------------------------------------------------------------------------------------------------------------------------------
# Upi's panels
_ = plotFig2.main(df, Fig2, ax2[4:], cellNum=cell, numSq=sq)

# ---------------------------------------------------------------------------------------------------------------------------------
for ax in ax2:
    sns.despine(fig=Fig2, ax=ax, top=True, right=True, left=False, bottom=False, offset=0.1, trim=True)

# ---------------------------------------------------------------------------------------------------------------------------------
# Save Figure 2 (create the export subfolder if it does not exist yet)
figure_name = 'Figure2'
subfolder2  = 'Figure2'
(paper_figure_export_location / subfolder2).mkdir(parents=True, exist_ok=True)
Fig2.savefig(paper_figure_export_location / subfolder2 / (figure_name + '.png'), dpi=300, bbox_inches='tight')
Fig2.savefig(paper_figure_export_location / subfolder2 / (figure_name + '.svg'), dpi=300, bbox_inches='tight')

In [None]:
importlib.reload(plotFig2)

In [None]:
plot_kind = 'strip' # 'line' or 'violin' or 'strip'
column_name_abbreviations = ['pc','pcn','ac','sc','dc','pf','pfn']
df3 = xc_FS_analyseddf[xc_FS_analyseddf['trialID'].isin(vc_screened_trialIDs)]

# ---------------------------------------------------------------------------------------------------------------------------------
# Setup the figure
w, h =  21, 29.7
# 11 panels (A-K) on a 7-row x 2-column grid; panels C, D and E each span both columns
Fig2, ax2 = plt.subplot_mosaic([['A', 'B'],['C', 'C'],['D','D'],['E', 'E'],['F', 'G'],['H','I'],['J','K']], figsize=(w, h), )
# change the spacing between plots
Fig2.subplots_adjust(hspace=1.0, wspace=0.25)
# linearize the ax2 list
ax2 = [ax2[key] for key in ax2.keys()]

# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2A: STP for a sample cell (numSq = 15, all stimFreq pooled), cp = -70 mV
ax2[0].text(-0.1, 1.1, 'A', transform=ax2[0].transAxes, size=20, weight='bold')
cell =  7492
sq   =  15
cp   = -70        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['numSq'] == sq) & (df3['clampPotential'] == cp)] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[0], df_melt, x='pulseIndex', y='peak_response', hue='pulseIndex', draw=True, kind=plot_kind, palette="flare", dodge=False, 
                                        stat_across=None, stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
ax2[0].set_xlabel('Pulse Index')
# set xtick location and labels
ax2[0].set_xticks(range(9))
ax2[0].set_xticklabels(freq_sweep_pulses)

# y-axis
ax2[0].set_ylabel('Norm. Response')
ax2[0].set_ylim(0,3)
# no legend
ax2[0].legend([],[], frameon=False)

# # get the legend labels of ax2d_top and add ' Sq' to each one
# handles, labels = ax2['A'].get_legend_handles_labels()
# labels = [label + ' Sq.' for label in labels]
# ax2d_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)
# # add a text in the top left corner of ax2d_top showing '20 Hz' in color = color_freq[20]
# ax2d_top.text(0.0, 0.9, f'{f} Hz', transform=ax2d_top.transAxes, size=12, color=color_freq[f], zorder=10)

# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2B: STP for the same cell (all stimFreq pooled), cp = 0 mV
ax2[1].text(-0.1, 1.1, 'B', transform=ax2[1].transAxes, size=20, weight='bold')

cp = 0        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['numSq'] == sq)  & (df3['clampPotential'] == cp)] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[1], df_melt, x='pulseIndex', y='peak_response', hue='pulseIndex', draw=True, kind=plot_kind, palette="crest", dodge=False, 
                                        stat_across=None, stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
# plot_tools.simplify_axes( ax2[1], splines_to_keep=['bottom'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=range(3), ytick_labels=range(3),)
ax2[1].set_xlabel('Pulse Index')
# set xtick location and labels
ax2[1].set_xticks(range(9))
ax2[1].set_xticklabels(freq_sweep_pulses)
# remove y-ticks and y-axis label
ax2[1].set_ylim(0,3)
ax2[1].set_ylabel('')
ax2[1].legend([],[], frameon=False)


# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2C: STP for all screened cells (all stimFreq pooled), cp = -70 mV
ax2[2].clear()
ax2[2].text(-0.1, 1.1, 'C', transform=ax2[2].transAxes, size=20, weight='bold')

cp = -70        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['numSq'] == sq) & (df3['clampPotential'] == cp) ] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[2], df_melt, x='pulseIndex', y='peak_response', hue='pulseIndex', draw=True, kind='strip', palette="flare", dodge=False, 
                                        stat_across=None, stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
# plot_tools.simplify_axes( ax2[2], splines_to_keep=['left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=range(5), ytick_labels=range(5),)
ax2[2].set_ylabel('Norm. Response')
ax2[2].set_ylim(0,6)
# remove x-ticks and x-axis label
ax2[2].set_xlabel('')
ax2[2].set_xticklabels([])
ax2[2].legend([],[], frameon=False)

# ---------------------------------------------------------------------------------------------------------------------------------
# Fig 2D: STP for all screened cells (stimFreq < 100 Hz pooled), cp = 0 mV
ax2[3].clear()
ax2[3].text(-0.1, 1.1, 'D', transform=ax2[3].transAxes, size=20, weight='bold')

cp = 0        # clamping potential subset for the plot
# f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[  (df3['numSq'] == sq) & (df3['stimFreq'] <100 )& (df3['clampPotential'] == cp)  ] #
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2[3], df_melt, x='pulseIndex', y='peak_response', hue='pulseIndex', draw=True, kind='strip', palette="crest", dodge=False, 
                                        stat_across=None, stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
# plot_tools.simplify_axes( ax2[3], splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=range(5), ytick_labels=range(5),)
ax2[3].set_xlabel('Pulse Index')
ax2[3].set_ylabel('Norm. Response')
ax2[3].set_ylim(0,6)

# ---------------------------------------------------------------------------------------------------------------------------------
# Upi's panels
_ = plotFig2.main(df, Fig2, ax2[4:], cellNum=cell, numSq=sq)

# ---------------------------------------------------------------------------------------------------------------------------------
for ax in ax2:
    sns.despine(fig=Fig2, ax=ax, top=True, right=True, left=False, bottom=False, offset=0.1, trim=True)

# ---------------------------------------------------------------------------------------------------------------------------------
# Save Figure 2 (create the export subfolder if it does not exist yet)
figure_name = 'Figure2x2_combined_stim_freq'
subfolder2  = 'Figure2'
(paper_figure_export_location / subfolder2).mkdir(parents=True, exist_ok=True)
Fig2.savefig(paper_figure_export_location / subfolder2 / (figure_name + '.png'), dpi=300, bbox_inches='tight')
Fig2.savefig(paper_figure_export_location / subfolder2 / (figure_name + '.svg'), dpi=300, bbox_inches='tight')

### Figure 2 (old version, voltage clamp)

In [None]:
# figure size in inches (width, height)
w, h = [15, 9]
fig2, [[ax2a, ax2b, ax2c],[ax2d_top, ax2e_top, ax2f_top],[ax2d_bottom, ax2e_bottom, ax2f_bottom]] = plt.subplots(3,3, figsize=(w, h), sharey=False)
fig2.subplots_adjust(hspace=0.5, wspace=0.5)
plot_kind = 'violin' # 'line' or 'violin' or 'strip'

column_name_abbreviations = ['pc','pcn','ac','sc','dc','pf','pfn']
cell =  1491
sq   =  15

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2A,B,C: Upi's deconvolution fits
ax2a.text(-0.1, 1.1, 'A', transform=ax2a.transAxes, size=20, weight='bold')
ax2b.text(-0.1, 1.1, 'B', transform=ax2b.transAxes, size=20, weight='bold')
ax2c.text(-0.1, 1.1, 'C', transform=ax2c.transAxes, size=20, weight='bold')

sweepnum = 0
df_temp = dfslice[(dfslice['cellID'] == cell) & (dfslice['numSq'] == sq)]
sweep = df_temp.iloc[sweepnum, :]
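# sweep layout: 49 metadata columns, then 1 s of data (20000 samples at Fs = 20 kHz) per channel;
# the first channel block is the cell recording and the third is the stimulus
tracecell, tracestim, stimfreq = sweep.iloc[49:20049], sweep.iloc[40049:60049], sweep['stimFreq']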

ax2a.plot(np.linspace(0, 1, 20000),       tracecell, label='recording') # Raw data
ax2a.plot(np.linspace(0, 1, 20000), 100 * tracestim, label='Stim')      # Stimulus

# find the expected response, add to the axes
fits, _, _ = find_sweep_expected(tracecell, stimfreq, fig2, [ax2a, ax2b, ax2c])

# simplify axes
plot_tools.simplify_axes(ax2a, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=[0, 0.2, 0.4, 0.6, 0.8, 1.0], xtick_labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], ytick_locs=[-150, -100, -50, 0, 50], ytick_labels=[-150, -100, -50, 0, 50],)
ax2a.legend(bbox_to_anchor=(0, 1), loc='upper left', borderaxespad=0., frameon=False)
ax2a.set_ylabel('Membrane current (pA)')

plot_tools.simplify_axes(ax2b, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=[0, 0.2, 0.4, 0.6, 0.8, 1.0], xtick_labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], ytick_locs=[-50, 0, 50], ytick_labels=[-50, 0, 50],)
ax2b.legend([], frameon=False)
ax2b.set_ylabel('Residual (pA)')

plot_tools.simplify_axes( ax2c, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1.0, 2.0], ytick_labels=[0, 1.0, 2.0],)
ax2c.legend([], frameon=False)
ax2c.set_ylabel('Normalized Response')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2D: STP vs numSq
ax2d_top.text(-0.1, 1.1, 'D', transform=ax2d_top.transAxes, size=20, weight='bold')

cp = -70        # clamping potential subset for the plot
f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['clampPotential'] == cp) ] 
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2d_top, df_melt, x='pulseIndex', y='peak_response', hue='numSq', draw=True, kind=plot_kind, palette=color_squares, 
                                        stat_across='hue', stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2d_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2d_top.set_xlabel('Pulse Index')
ax2d_top.set_ylabel('Norm. Response')

# get the legend labels of ax2d_top and add ' Sq' to each one
handles, labels = ax2d_top.get_legend_handles_labels()
labels = [label + ' Sq.' for label in labels]
ax2d_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)
# add a text in the top left corner of ax2d_top showing '20 Hz' in color = color_freq[20]
ax2d_top.text(0.0, 0.9, f'{f} Hz', transform=ax2d_top.transAxes, size=12, color=color_freq[f], zorder=10)
ax2d_top.text(0.2, 0.9, f'{cp} mV', transform=ax2d_top.transAxes, size=12, color=color_EI[cp], zorder=10)

### Fig 2D bottom: field-response STP vs numSq
# field plot in the bottom subplot
to_plot = [f'pfn{i}' for i in range(9)]

df_temp = df3[ (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['clampPotential'] == cp) ]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2d_bottom, df_melt, x='pulseIndex', y='peak_response', hue='numSq', draw=True, kind=plot_kind, palette=color_squares,
                                        stat_across='hue', stat=kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2d_bottom, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)

ax2d_bottom.set_xlabel('Pulse Index')
ax2d_bottom.set_ylabel('Norm. Response')
ax2d_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2E: STP vs Frequency
ax2e_top.text(-0.1, 1.1, 'E', transform=ax2e_top.transAxes, size=20, weight='bold')

s = 5  # squares
cp = -70
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['clampPotential'] == cp) & (df3['numSq'] == s) ]
df_melt = pd.melt( df_temp[df_temp['stimFreq'] < 100], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2e_top, df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='violin', palette=color_freq, stat_across='hue',
                                        stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2e_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2e_top.set_xlabel('Pulse Index')
ax2e_top.set_ylabel('Norm. Response')

# get the legend labels of ax2e_top and append ' Hz' to each one
handles, labels = ax2e_top.get_legend_handles_labels()
labels = [label + ' Hz' for label in labels]
ax2e_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)

# annotate the top-left corner of ax2e_top with the number of squares and clamp potential, in their palette colors
ax2e_top.text(0.0, 0.9, f'{s} Sq', transform=ax2e_top.transAxes, size=12, color=color_squares[s], zorder=10)
ax2e_top.text(0.2, 0.9, f'{cp} mV', transform=ax2e_top.transAxes, size=12, color=color_EI[cp], zorder=10)


### 2E Bottom: Field plot in the bottom
to_plot = [f'pfn{i}' for i in range(9)]

df_temp = df3[ (df3['cellID'] == cell) & (df3['clampPotential'] == cp) & (df3['numSq'] == s) ]
df_melt = pd.melt( df_temp[df_temp['stimFreq'] < 100], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2e_bottom, df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='violin', palette=color_freq,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2e_bottom, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2e_bottom.set_xlabel('Pulse Index')
ax2e_bottom.set_ylabel('Norm. Response')
ax2e_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2F: STP vs E/I
ax2f_top.text(-0.1, 1.1, 'F', transform=ax2f_top.transAxes, size=20, weight='bold')

f = 20  # Hz
s = 5  # squares
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['numSq'] == s)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2f_top, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', draw=True, kind='violin', palette=color_EI,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)

plot_tools.simplify_axes( ax2f_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2f_top.set_xlabel('Pulse Index')
ax2f_top.set_ylabel('Norm. Response')

# get the legend labels of ax2f_top and append ' mV' to each one
handles, labels = ax2f_top.get_legend_handles_labels()
labels = [label + ' mV' for label in labels]
ax2f_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)

# annotate the top-left corner of ax2f_top with the stimulus frequency and number of squares, in their palette colors
ax2f_top.text(0.0, 0.9, f'{f} Hz', transform=ax2f_top.transAxes, size=12, color=color_freq[f], zorder=10)
ax2f_top.text(0.2, 0.9, f'{s} Sq', transform=ax2f_top.transAxes, size=12, color=color_squares[s], zorder=10)


### Fig 2F Bottom: Field plot in the bottom
to_plot = [f'pfn{i}' for i in range(9)]
df_temp = df3[ (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['numSq'] == s)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2f_bottom, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', draw=True, kind='violin', palette=color_EI,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10)
plot_tools.simplify_axes(ax2f_bottom, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0,1,2], ytick_labels=[0,1,2])
ax2f_bottom.set_xlabel('Pulse Index')
ax2f_bottom.set_ylabel('Norm. Response')
ax2f_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # save figure
figure_name = 'Figure2_VC_old'
fig2.savefig(paper_figure_export_location / misc / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig2.savefig(paper_figure_export_location / misc / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
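
The Fig 2 panels above annotate each pulse index with a Kruskal-Wallis test across the `hue` groups (`stat=stats.kruskal`, `stat_across='hue'`, first pulse skipped). The internals of `pairwise_draw_and_annotate_line_plot` are not shown here, so the following is only a minimal sketch, assuming a long-format table shaped like `df_melt`, of how a per-pulse test across groups could be computed with `scipy.stats.kruskal`. The helper `kruskal_per_pulse` and the synthetic data are hypothetical.

```python
import numpy as np
import pandas as pd
from scipy import stats

def kruskal_per_pulse(df_long, x='pulseIndex', y='peak_response', hue='numSq', skip_first=True):
    """Hypothetical helper: Kruskal-Wallis H-test across hue groups at each x value."""
    x_values = list(pd.unique(df_long[x]))
    if skip_first:
        x_values = x_values[1:]  # mirrors skip_first_xvalue=True (first pulse is the reference)
    records = []
    for xv in x_values:
        at_x = df_long[df_long[x] == xv]
        groups = [g[y].to_numpy() for _, g in at_x.groupby(hue)]
        if len(groups) < 2:
            continue  # the test needs at least two groups
        h_stat, p_value = stats.kruskal(*groups)
        records.append({x: xv, 'H': h_stat, 'p': p_value})
    return pd.DataFrame(records, columns=[x, 'H', 'p']).set_index(x)

# synthetic long-format data: 3 pulses x 2 square counts x 10 trials
rng = np.random.default_rng(0)
rows = []
for pulse in ['pcn0', 'pcn1', 'pcn2']:
    for n_sq, shift in [(5, 0.0), (15, 0.5)]:
        for value in rng.normal(1.0 + shift, 0.1, size=10):
            rows.append({'pulseIndex': pulse, 'numSq': n_sq, 'peak_response': value})
df_long = pd.DataFrame(rows)
print(kruskal_per_pulse(df_long))
```

Testing each pulse separately produces one p-value per pulse, so a multiple-comparison correction may be worth considering when reporting them.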

### Figure 3:
* Model schematic for plasticity
* Fitting free params of model for E and I
* Fitting diversity for E and I
* Fitting stochasticity for E and I
* How averaged are E and I respectively? Crucial for escape calculation.

In [None]:
w, h = [15, 9]

fig3, [[ax3a, ax3b, ax3c],[ax3d, ax3e, ax3f]] = plt.subplots(2,3, figsize=(w, h), sharey=False)
# add more space between subplots
fig3.subplots_adjust(hspace=0.5, wspace=0.5)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 3A: Model schematic for plasticity
ax3a.text(-0.1, 1.1, 'A', transform=ax3a.transAxes, size=20, weight='bold')
image3_path = all_cells.project_path_root / r"Lab\Projects\EI_Dynamics\Analysis\paper_figure_matter\model_schematic.png"
im3 = Image.open(image3_path)
# get the size of the image
# im3_width, im3_height = im3.size
# get the aspect ratio of the ax3a axes object
# ax3a_ratio = ax3a.get_window_extent().width / ax3a.get_window_extent().height
# change the axis limits so that the vertical side of the image fits the axis and the horizontal side is aligned left on the axis
# ax3a.set_ylim(0, im3_height)
# ax3a.set_xlim(0, im3_height*ax3a_ratio)
# plot the image
ax3a.imshow(im3)#, extent=[0, im3_width, 0, im3_height], aspect=1)
ax3a.axis('off')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 3B: Model schematic for plasticity
# make the ax3b aspect ratio same as ax3a
ax3b.set_aspect(ax3a.get_aspect())

ax3b.text(-0.1, 1.1, 'B', transform=ax3b.transAxes, size=20, weight='bold')
image3_path = all_cells.project_path_root / r"Lab\Projects\EI_Dynamics\Analysis\paper_figure_matter\model73.png"
im3 = Image.open(image3_path)
# get the size of the image
# im3_width, im3_height = im3.size
# get the aspect ratio of the ax3b axes object
# ax3b_ratio = ax3b.get_window_extent().width / ax3b.get_window_extent().height
# change the axis limits so that the vertical side of the image fits the axis and the horizontal side is aligned left on the axis
# ax3b.set_ylim(0, im3_height)
# ax3b.set_xlim(0, im3_height*ax3b_ratio)
# plot the image
ax3b.imshow(im3)#, extent=[0, im3_width, 0, im3_height], aspect=1)
ax3b.axis('off')


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # save figure
figure_name = 'Figure3'
fig3.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig3.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
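
The commented-out lines in Fig 3A and 3B try to scale each schematic so that its vertical side fills the axes and the image sits at the left edge. One way to get that layout without computing limits by hand, sketched here on a synthetic array instead of the PNG files, is to keep `imshow` at `aspect='equal'` and pin the shrunken axes box to the left with `Axes.set_anchor('W')`. Whether this matches the intended layout is an assumption, and `show_image_left_aligned` is a hypothetical helper.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def show_image_left_aligned(ax, image):
    """Draw an image with square pixels and pin the resulting axes box to the left edge."""
    ax.imshow(image, aspect='equal')  # keep the pixel aspect ratio; the axes box shrinks to fit
    ax.set_anchor('W')  # place the shrunken box at the west (left) side of its original slot
    ax.axis('off')
    return ax

# synthetic "schematic": taller than the axes slot allows at full width (300 px tall x 200 px wide)
image = np.random.default_rng(1).random((300, 200, 3))
fig, ax = plt.subplots(figsize=(6, 4))
show_image_left_aligned(ax, image)
fig.canvas.draw()  # force layout so the aspect-adjusted axes position is computed
```

Because the image is relatively tall for the slot, the box keeps the full height, loses width, and stays flush with the left edge of the slot.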


### Figure 4: in preparation (Upi)

### Figure 5
* Example E-I traces over a burst (train)
* E/I ratio of successive pulses across freq and numSq
* Summary figure across cells

__Note: This figure should explain why we don't see escape from E-I balance__

<img src="notes_figures/Figure_updates_Fig5_24Jan2024.jpg" height="500"/>
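
The outline lists the E/I ratio of successive pulses across frequency and numSq. A minimal sketch of one way to compute it, assuming a wide table with one absolute peak-current column per pulse (the names `peak0`, `peak1`, ... are hypothetical), rows at -70 mV holding excitatory peaks and rows at 0 mV holding inhibitory peaks. Near those potentials the excitatory current is inward (negative) and the inhibitory current outward (positive), so magnitudes are compared. The helper `ei_ratio_per_pulse` and the numbers are illustrative only.

```python
import numpy as np
import pandas as pd

def ei_ratio_per_pulse(df_wide, pulse_cols, keys=('cellID', 'stimFreq', 'numSq')):
    """Hypothetical helper: ratio of trial-averaged |E| to |I| peaks, per pulse and condition."""
    keys = list(keys)
    magnitude = df_wide[pulse_cols].abs()  # compare magnitudes: E is negative, I is positive
    table = pd.concat([df_wide[keys + ['clampPotential']], magnitude], axis=1)
    mean_peaks = table.groupby(keys + ['clampPotential'])[pulse_cols].mean()
    E = mean_peaks.xs(-70, level='clampPotential')  # excitatory sweeps
    I = mean_peaks.xs(0, level='clampPotential')    # inhibitory sweeps
    return E / I  # aligned on keys; NaN where one clamp potential is missing

# synthetic example: one cell, 20 Hz, 5 sq, two trials per clamp potential, three pulses (pA)
pulse_cols = ['peak0', 'peak1', 'peak2']
df_wide = pd.DataFrame({
    'cellID': [1, 1, 1, 1],
    'stimFreq': [20, 20, 20, 20],
    'numSq': [5, 5, 5, 5],
    'clampPotential': [-70, -70, 0, 0],
    'peak0': [-100.0, -120.0, 200.0, 240.0],
    'peak1': [-80.0, -100.0, 220.0, 260.0],
    'peak2': [-60.0, -80.0, 240.0, 280.0],
})
print(ei_ratio_per_pulse(df_wide, pulse_cols))
```

In this toy example the ratio falls from 0.5 to about 0.27 over three pulses because inhibition grows while excitation depresses; with real data the trend is whatever the recordings show.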


In [None]:
df5.iloc[0,:49]

In [None]:
def calculate_expected_response(celldf, freq, patternID):
    """
    Calculate the expected response to a pattern as the linear sum of the responses to its individual spots.

    Returns (pattern_response_df, df_temp, constituent_spots_of_pattern, expected_response),
    or None if the cell has no 1Sq data.
    """
    from eidynamics import pattern_index
    Fs = 2e4
    # the expected response is built from single-spot responses, so the cell must have 1Sq data
    if 1 not in celldf['numSq'].unique():
        print('No 1Sq data for this cell', celldf['numSq'].unique())
        return None
    # spot IDs that make up the given pattern
    constituent_spots_of_pattern    = pattern_index.get_patternIDlist_for_nSq_pattern(patternID)
    ipi                             = int(Fs/freq)                                  # inter-pulse interval in samples
    probePulseStart                 = int(Fs * celldf['probePulseStart'].iloc[0])   # 2e4 * 0.2 = 4000
    pulseTrainStart                 = int(Fs * celldf['pulseTrainStart'].iloc[0])   # 2e4 * 0.5 = 10000 (not used below)

    # get the response to the given pattern
    celldf                          = celldf.loc[:, ~celldf.columns.isin(celldf.columns[28:49])]
    celldf                          = celldf.assign(patternList=celldf['patternList'].astype('int32'))
    pattern_response_df             = celldf[(celldf['patternList'] == patternID) & (celldf['stimFreq'] == freq)]

    # To get the expected response, add up the average responses to each individual spot in the pattern:
    # step1: keep only the rows whose 'patternList' is one of the constituent spots
    # step2: get the peak of each row between columns probePulseStart and probePulseStart+ipi
    # step3: average the peaks across trials (3 per spot) and sum them to get the expected response
    # Use pandas slicing rather than for loops wherever possible.

    # step1 (copy, so that adding the 'peaks' column below does not write into a view of celldf)
    df_temp = celldf.loc[celldf['patternList'].isin(constituent_spots_of_pattern), :].copy()

    # step2: iloc is positional, so this window matches samples probePulseStart..probePulseStart+ipi
    # only if each sample's column position equals its sample index
    df_temp['peaks'] = pd.to_numeric(df_temp.iloc[:, probePulseStart:probePulseStart+ipi].max(axis=1))

    # step3: mean peak per spot across trials, then sum over spots
    expected_response = df_temp.groupby('patternList')['peaks'].mean().sum()

    return pattern_response_df, df_temp, constituent_spots_of_pattern, expected_response
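
The three steps in `calculate_expected_response` amount to a linear-summation prediction: the peak of each single-spot trial within the window after the probe pulse, averaged over trials, summed over the spots in the pattern. A minimal sketch on synthetic arrays, independent of the dataframe's column layout; `expected_linear_sum` is a hypothetical helper and, like the `.max` in the function above, it assumes upward (depolarizing) responses.

```python
import numpy as np

def expected_linear_sum(spot_traces, window_start, window_len):
    """Sum over spots of the trial-averaged peak inside [window_start, window_start + window_len).

    spot_traces maps spot ID -> array of shape (n_trials, n_samples).
    """
    total = 0.0
    for spot_id, traces in spot_traces.items():
        window = traces[:, window_start:window_start + window_len]
        peak_per_trial = window.max(axis=1)  # step 2: peak of each trial in the window
        total += peak_per_trial.mean()       # step 3: average over trials, then accumulate
    return total

Fs = 2e4
freq = 20
ipi = int(Fs / freq)          # 1000 samples between pulses at 20 Hz
probe_start = int(Fs * 0.2)   # 4000: probe pulse at 0.2 s
n_samples = 20000

# three spots, three trials each, with a single response peak 5 ms after the probe pulse
spot_traces = {}
for spot_id, amplitude in [(1, 1.0), (2, 2.0), (3, 0.5)]:
    traces = np.zeros((3, n_samples))
    traces[:, probe_start + 100] = amplitude
    spot_traces[spot_id] = traces

print(expected_linear_sum(spot_traces, probe_start, ipi))  # 3.5
```

Comparing this sum with the measured response to the full pattern is what indicates sub- or supra-linear summation.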

In [None]:
o,e,spotlist,expt = calculate_expected_response(cc_FS_longdf[cc_FS_longdf['cellID']==5211], 20, 46)

#### Fig 5 for 15sq patterns

In [None]:
# figure size (width, height) in inches
w, h = [15, 9]

fig5, [[ax5a, ax5b],[ax5c,ax5d],[ax5e, ax5f]] = plt.subplots(3,2, figsize=(w, h), sharey=False)
# add more space between subplots
fig5.subplots_adjust(hspace=0.5, wspace=0.5)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5A: Example E-I traces over a train
ax5a.text(-0.1, 1.1, 'A', transform=ax5a.transAxes, size=20, weight='bold')
any_good_vc_cell = vc_cells_screened[0]
df_temp = df4[(df4['cellID'] == any_good_vc_cell) & (df4['stimFreq'] == 20) & (df4['numSq'] == 15) & (df4['channel'] == 'cell')]
num_traces, num_metadata_columns = int(df_temp.shape[0]/2), df_temp.shape[1]-20000
E = df_temp[df_temp['clampPotential']==-70].iloc[:,num_metadata_columns:].to_numpy()
I = df_temp[df_temp['clampPotential']==  0].iloc[:,num_metadata_columns:].to_numpy()
for i in range(num_traces):
    ax5a.plot(np.linspace(0,1,20000), I[i,:], alpha=0.2, linewidth=1, color=color_EI[0], )
    ax5a.plot(np.linspace(0,1,20000), E[i,:], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5a.plot(np.linspace(0,1,20000), np.mean(I, axis=0), linewidth=1, color=color_EI[0], )
ax5a.plot(np.linspace(0,1,20000), np.mean(E, axis=0), linewidth=1, color=color_EI[-70], )

# simplify the axis
plot_tools.simplify_axes(ax5a, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=np.round(np.linspace(0.1,0.5,5),1), xtick_labels=np.round(np.linspace(0.1,0.5,5),1), ytick_locs=np.linspace(-500,1000,4), ytick_labels=np.linspace(-500,1000,4))
# add floating scalebar to the plots if splines are not retained
# plot_tools.add_floating_scalebar(ax5a, scalebar_origin=[0.3,600], xlength=0.1, ylength=100, labelx='100', labely='100', unitx='ms', unity='pA', fontsize=8, color=color_EI[0], linewidth=2, pad=0.1, show_labels=False)
# plot_tools.add_floating_scalebar(ax5a, scalebar_origin=[0.3,-200], xlength=0.1, ylength=100, labelx='100', labely='100', unitx='ms', unity='pA', fontsize=8, color=color_EI[-70], linewidth=2, pad=0.1, show_labels=False)

# add an inset axes in ax5a that zooms in on 0.205-0.225 s
ax5a_inset = ax5a.inset_axes([0.9,0.5,0.1,0.5]) # left, bottom, width, height
t1,t2 = 0.205, 0.225
Fs = 2e4
t = np.linspace(t1,t2, int((t2-t1)*Fs))
# for i in range(num_traces):
#     ax5a_inset.plot(t, I[i,int(t1*Fs):int(t1*Fs)+len(t)], alpha=0.2, linewidth=1, color=color_EI[0], )
#     ax5a_inset.plot(t, E[i,int(t1*Fs):int(t1*Fs)+len(t)], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5a_inset.plot(t, np.mean(I[:,int(t1*Fs):int(t1*Fs)+len(t)], axis=0), linewidth=2, color=color_EI[0], )
ax5a_inset.plot(t, np.mean(E[:,int(t1*Fs):int(t1*Fs)+len(t)], axis=0), linewidth=2, color=color_EI[-70], )
plot_tools.simplify_axes(ax5a_inset, splines_to_keep=['left'], axis_offset=10, remove_ticks=False, xtick_locs=[t1,t2], xtick_labels=[t1,t2], ytick_locs=np.linspace(-250,750,5), ytick_labels=np.linspace(-250,750,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5B: Whole train
ax5b.text(-0.1, 1.1, 'B', transform=ax5b.transAxes, size=20, weight='bold')

for i in range(num_traces):
    ax5b.plot(np.linspace(0,1,20000), I[i,:], alpha=0.2, linewidth=1, color=color_EI[0], )
    ax5b.plot(np.linspace(0,1,20000), E[i,:], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5b.plot(np.linspace(0,1,20000), np.mean(I, axis=0), linewidth=1, color=color_EI[0], )
ax5b.plot(np.linspace(0,1,20000), np.mean(E, axis=0), linewidth=1, color=color_EI[-70], )

# plot_tools.add_floating_scalebar(ax5b, scalebar_origin=[0.3,600], xlength=0.1, ylength=100, labelx='100', labely='100', 
                                # unitx='ms', unity='pA', fontsize=8, color=color_EI[0], linewidth=2, pad=0.1, show_labels=False)
# plot_tools.add_floating_scalebar(ax5b, scalebar_origin=[0.3,-200], xlength=0.1, ylength=100, labelx='100', labely='100', 
                                # unitx='ms', unity='pA', fontsize=8, color=color_EI[-70], linewidth=2, pad=0.1, show_labels=False)
# simplify the axis
plot_tools.simplify_axes(ax5b, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=np.round(np.linspace(0,1.0,5),1), xtick_labels=np.round(np.linspace(0,1.0,5),1), ytick_locs=np.linspace(-500,1000,4), ytick_labels=np.linspace(-500,1000,4))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5C: STP vs frequency, inhibitory currents (clamped at 0 mV), 15 sq
ax5c.text(-0.1, 1.1, 'C', transform=ax5c.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pcn2'])<500)& (df3['cellID'].isin(vc_cells_screened))]
df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax5c, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

plot_tools.simplify_axes(ax5c, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
ax5c.set_xlabel('Pulse Index')
ax5c.set_ylabel('Norm. Response')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5D: STP vs frequency, excitatory currents (clamped at -70 mV), 15 sq
ax5d.text(-0.1, 1.1, 'D', transform=ax5d.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pcn2'])<500) & (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax5d, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

plot_tools.simplify_axes(ax5d, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
ax5d.set_xlabel('Pulse Index')
ax5d.set_ylabel('Norm. Response')


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5E: STP vs frequency for field data recorded with the inhibitory (0 mV) sweeps, 15 sq
ax5e.text(-0.1, 1.1, 'E', transform=ax5e.transAxes, size=20, weight='bold')
to_plot = [ f'pfn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pfn2'])<500)& (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax5e, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

# ax5e.set_xticks(range(9), labels=freq_sweep_pulses)
ax5e.set_xlabel('Pulse Index')
ax5e.set_ylabel('Norm. Response')
plot_tools.simplify_axes(ax5e, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5F: STP vs frequency for field data recorded with the excitatory (-70 mV) sweeps, 15 sq
ax5f.text(-0.1, 1.1, 'F', transform=ax5f.transAxes, size=20, weight='bold')
to_plot = [ f'pfn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pfn2'])<500) & (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax5f, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)


ax5f.set_xlabel('Pulse Index')
ax5f.set_ylabel('Norm. Response')
plot_tools.simplify_axes(ax5f, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# save figure
figure_name = 'Figure5_15sq'
fig5.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig5.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
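
Fig 5A and 5B draw 20000-sample traces against `np.linspace(0, 1, 20000)`, whose step is 1/19999 s rather than 1/Fs = 50 µs at the 20 kHz rate used for the inset (`Fs = 2e4`). The mismatch is small here (the last sample is drawn at 1.0 s instead of 0.99995 s), but it grows with longer sweeps. The inset also converts times to indices with `int(t * Fs)`, which truncates and can drop a sample when the product lands just below an integer. A minimal sketch of a sample-accurate time axis and window extraction; `time_axis` and `extract_window` are hypothetical helpers and the traces are synthetic.

```python
import numpy as np

def time_axis(n_samples, Fs):
    """Time (s) of each sample: sample k is acquired at k / Fs."""
    return np.arange(n_samples) / Fs

def extract_window(traces, t1, t2, Fs):
    """Return (t, segment) for samples in [t1, t2) of a (n_trials, n_samples) array."""
    i1 = int(round(t1 * Fs))  # round instead of truncating: int(0.29 * 100) == 28
    i2 = int(round(t2 * Fs))
    return np.arange(i1, i2) / Fs, traces[:, i1:i2]

Fs = 2e4
n_samples = 20000
# synthetic traces whose value equals the sample index, so the slice is easy to check
traces = np.tile(np.arange(n_samples, dtype=float), (3, 1))

t_full = time_axis(n_samples, Fs)
t_inset, segment = extract_window(traces, 0.205, 0.225, Fs)
mean_segment = segment.mean(axis=0)  # trial average, as plotted in the inset
```

Returning the time vector together with the slice keeps the two the same length by construction.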

### Fig 5 supplementary: All screened VC cells

In [None]:
fig5B, axs5B = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5B.subplots_adjust(hspace=0.5, wspace=0.5)
# flatten the 2D array of axes
axs5B = axs5B.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5B[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Norm. Response')
    ax.set_title(f'VC - {selected_cell},5sq')

# save figure
figure_name = 'Figure5_supp_5sq_E_all_screenedVCcells.png'
fig5B.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')
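
The STP panels (Fig 5C to 5F and the supplementary grids) melt the per-pulse columns into long format and let `sns.lineplot(..., errorbar=("se", 1))` draw the mean with a ±1 standard-error bar at each pulse. Seaborn aggregates every row that shares an x value, so sweeps (and patterns) are pooled and n is the number of sweeps, not cells. A minimal sketch with made-up values that tabulates the pooled mean and SEM, plus a per-cell-first alternative; column names follow `df_melt`, and pandas' `sem` is the sample standard error (ddof=1).

```python
import numpy as np
import pandas as pd

# synthetic wide table: one row per sweep, one normalized-peak column per pulse
rng = np.random.default_rng(3)
n_sweeps = 6
df_wide = pd.DataFrame({
    'cellID': np.repeat([101, 102], 3),
    'stimFreq': [20, 20, 50, 20, 50, 50],
    'pcn0': np.ones(n_sweeps),
    'pcn1': rng.normal(1.2, 0.1, n_sweeps),
    'pcn2': rng.normal(1.4, 0.1, n_sweeps),
})

# wide -> long: one row per (sweep, pulse), as in the df_melt steps above
df_long = pd.melt(df_wide, id_vars=['cellID', 'stimFreq'], value_vars=['pcn0', 'pcn1', 'pcn2'],
                  var_name='pulseIndex', value_name='peak_response')

# pooled mean and standard error per frequency and pulse (each sweep counts once)
summary = df_long.groupby(['stimFreq', 'pulseIndex'])['peak_response'].agg(['mean', 'sem', 'count'])

# alternative: average within each cell first, so that cells are the unit of n
per_cell = df_long.groupby(['cellID', 'stimFreq', 'pulseIndex'])['peak_response'].mean()
summary_by_cell = per_cell.groupby(level=['stimFreq', 'pulseIndex']).agg(['mean', 'sem', 'count'])
print(summary)
```

Which n is appropriate depends on the question; the per-cell version avoids letting cells with many sweeps dominate the error bars.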

In [None]:
fig5C, axs5C = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5C.subplots_adjust(hspace=0.5, wspace=0.5)
# flatten the 2D array of axes
axs5C = axs5C.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5C[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Norm. Response')
    ax.set_title(f'VC - {selected_cell},15sq')

# save figure
figure_name = 'Figure5_supp_15sq_E_all_screenedVCcells.png'
fig5C.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')

In [None]:
fig5Binh, axs5Binh = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5Binh.subplots_adjust(hspace=0.5, wspace=0.5)
# flatten the 2D array of axes
axs5Binh = axs5Binh.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5Binh[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    if df_melt.shape[0]>0:
        sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Norm. Response')
    ax.set_title(f'VC - {selected_cell},5sq')

# save figure
figure_name = 'Figure5_supp_5sq_I_all_screenedVCcells.png'
fig5Binh.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')

In [None]:
fig5Binh, axs5Binh = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5Binh.subplots_adjust(hspace=0.5, wspace=0.5)
# flatten the 2D array of axes
axs5Binh = axs5Binh.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5Binh[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    if df_melt.shape[0]>0:
        sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Norm. Response')
    ax.set_title(f'VC - {selected_cell},15sq')

# save figure
figure_name = 'Figure5_supp_15sq_I_all_screenedVCcells.png'
fig5Binh.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')
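
Each supplementary cell above indexes a fixed 3x3 grid with `axs[i]`: if `vc_cells_screened` held more than nine cells, the indexing would raise an `IndexError`, and with fewer, the unused panels are drawn as empty axes. A minimal sketch that sizes the grid from the number of panels and hides the leftover slots; `make_panel_grid` is a hypothetical helper and the cell labels are placeholders.

```python
import math
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

def make_panel_grid(n_panels, n_cols=3, panel_size=(5, 3), **subplots_kwargs):
    """Create just enough rows for n_panels, hide unused axes, return the used axes in order."""
    n_rows = max(1, math.ceil(n_panels / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(panel_size[0] * n_cols, panel_size[1] * n_rows),
                            squeeze=False, **subplots_kwargs)
    axs = axs.flatten()
    for ax in axs[n_panels:]:
        ax.set_visible(False)  # leftover slots in the last row
    return fig, axs[:n_panels]

cells = ['cell_a', 'cell_b', 'cell_c', 'cell_d']  # placeholder labels
fig, axs = make_panel_grid(len(cells))
for ax, cell in zip(axs, cells):
    ax.set_title(f'VC - {cell}')
```

Iterating with `zip(axs, cells)` also removes the need for an explicit index into the axes array.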

### Fig 6: Why no escape?

In [None]:
# Fig 6: Show that in current clamp the cells do not fire (no escape from E-I balance), and that they do fire when inhibition is blocked with gabazine
# Earlier approach, kept for reference: flag a sweep as spiking if any sample in the data columns (40:20049) exceeds 30 mV
# df['spike'] = df.iloc[:,40:20049].gt(30).any(axis=1)


cells_that_fire = df[(df['clampMode']=='CC') & (df['AP']==True)]['cellID'].unique()
all_cc_cells = df[(df['clampMode']=='CC')]['cellID'].unique()
print(len(cells_that_fire), len(all_cc_cells))
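
The cell above relies on a precomputed `AP` column whose definition is not shown here; the commented line describes a simpler rule, flagging a sweep when any sample exceeds 30 mV. A minimal sketch of that threshold rule on a toy current-clamp table, with the sample columns bounded explicitly so that metadata or derived columns after the samples are not scanned. `flag_spiking_sweeps` is a hypothetical helper; a fixed absolute threshold is the only criterion used (no dV/dt check).

```python
import pandas as pd

def flag_spiking_sweeps(df, data_start_column, data_stop_column=None, threshold_mV=30.0):
    """Return a boolean Series: True where any sample of the sweep exceeds threshold_mV."""
    samples = df.iloc[:, data_start_column:data_stop_column]
    return samples.gt(threshold_mV).any(axis=1)

# toy current-clamp table: 2 metadata columns followed by 5 voltage samples (mV)
df_cc = pd.DataFrame(
    [[1, 'CC', -70, -65, -60, -66, -70],
     [1, 'CC', -70, -50,  35, -40, -68],
     [2, 'CC', -72, -69, -55, -60, -71]],
    columns=['cellID', 'clampMode', 0, 1, 2, 3, 4],
)
df_cc['spike'] = flag_spiking_sweeps(df_cc, data_start_column=2, data_stop_column=7)
cells_that_fire = df_cc.loc[df_cc['spike'], 'cellID'].unique()
```

Computing the flag before appending the `spike` column matters: a later call with an open-ended slice would otherwise scan the flag column too.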

In [None]:
all_cc_cells

In [None]:
for selected_cell in all_cc_cells:
    for f in [20,30,40,50]:
        for s in [1,5,15]:
            figure_name = f'raw_data_CC_{selected_cell}_numsq_{s}_freq_{f}'
            df_temp = df[(df['clampMode']=='CC') & (df['cellID']==selected_cell) & (df['condition']=='Control') & (df['numSq']==s) & (df['stimFreq']==f)]
            # skip combinations without recorded sweeps instead of opening an empty figure
            if df_temp.shape[0] == 0:
                continue
            print(figure_name)
            fig_cc, ax_cc = plt.subplots(figsize=(20, 10), )
            _, ax_cc = plot_tools.plot_data_from_df(df_temp, data_start_column = 49, combine=False, fig=fig_cc, ax=ax_cc)
            ax_cc[0].set_title(str(selected_cell))
            # save fig, then close it so that open figures do not accumulate across iterations
            fig_cc.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
            plt.close(fig_cc)
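
The nested loops above filter the full table once per (cell, frequency, numSq) combination, including combinations that have no sweeps. An alternative, sketched on a toy table with made-up values, is to filter once and iterate over `groupby` groups, which yields only the combinations that exist. Plotting is left out; each group body would do what the loop body above does.

```python
import pandas as pd

# toy table with the metadata columns used by the loop above
df = pd.DataFrame({
    'cellID': [1, 1, 1, 2, 2],
    'clampMode': ['CC'] * 5,
    'condition': ['Control', 'Control', 'Gabazine', 'Control', 'Control'],
    'stimFreq': [20, 20, 50, 30, 100],
    'numSq': [5, 5, 15, 1, 5],
})

# filter once: current clamp, control condition, frequencies and square counts of interest
subset = df[(df['clampMode'] == 'CC') & (df['condition'] == 'Control')
            & df['stimFreq'].isin([20, 30, 40, 50]) & df['numSq'].isin([1, 5, 15])]

figure_names = []
for (cell, freq, n_sq), sweeps in subset.groupby(['cellID', 'stimFreq', 'numSq']):
    # every group is non-empty by construction; plotting `sweeps` would go here
    figure_names.append(f'raw_data_CC_{cell}_numsq_{n_sq}_freq_{freq}')
```

Here only two figures would be produced: the gabazine sweep and the 100 Hz sweep are excluded by the single filter.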

### Fig 6 Supplementary: All gabazine cells

All gabazine cells: [3131, 5291, 3882, 3872, 3791]

Screened gabazine cells: [3131, 3882, 3872]

In [None]:
# figure size (width, height) in inches
w, h = [15, 9]

fig6_supp1, [[ax6supp1_a, ax6supp1_b],[ax6supp1_c,ax6supp1_d]] = plt.subplots(2,2, figsize=(w, h), sharey=False)
# add more space between subplots
fig6_supp1.subplots_adjust(hspace=0.5, wspace=0.5)

gabazine_cells = df[df['condition']=='Gabazine']['cellID'].unique()
print(f'Gabazine Cells: {gabazine_cells}')
selected_cell = gabazine_cells[0]

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 6 Supp A: gabazine vs control traces, example cell, 50 Hz, 5 sq
ax6supp1_a.text(-0.1, 1.1, 'A', transform=ax6supp1_a.transAxes, size=20, weight='bold')

data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_a, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_a)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_a, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_a, signal_mapping=signal_mapping)

# simplify
plot_tools.simplify_axes(ax6supp1_a, splines_to_keep=[], )
ax6supp1_a.set_title(str(selected_cell)+' 50Hz 5sq')

#remove legend
ax6supp1_a.legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 6 Supp B: gabazine vs control traces, example cell, 50 Hz, 15 sq
ax6supp1_b.text(-0.1, 1.1, 'B', transform=ax6supp1_b.transAxes, size=20, weight='bold')

data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_b, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_b)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_b, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_b, signal_mapping=signal_mapping)

# simplify
plot_tools.simplify_axes(ax6supp1_b, splines_to_keep=[], )
ax6supp1_b.set_title(str(selected_cell)+' 50Hz 15sq')
ax6supp1_b.legend([],[], frameon=False)
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Fig 6C: All gabazine cells
ax6supp1_c.text(-0.1, 1.1, 'C', transform=ax6supp1_c.transAxes, size=20, weight='bold')
# cellID is in gabazine cells
data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_c, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_c)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_c, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_c, signal_mapping=signal_mapping)



# simplify
plot_tools.simplify_axes(ax6supp1_c, splines_to_keep=[], )
ax6supp1_c.set_title('All cells 50Hz 5sq')
#remove legend
ax6supp1_c.legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 6D: All gabazine cells
ax6supp1_d.text(-0.1, 1.1, 'D', transform=ax6supp1_d.transAxes, size=20, weight='bold')
# cellID is in gabazine cells
data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_d, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_d)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_d, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_d, signal_mapping=signal_mapping)



# simplify
plot_tools.simplify_axes(ax6supp1_d, splines_to_keep=[], )
ax6supp1_d.set_title('All cells 50Hz 15sq')
#remove legend
ax6supp1_d.legend([],[], frameon=False)
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # save figure
figure_name = 'Figure6_supp1_cell3131.png'
fig6_supp1.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')


In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

gabazine_cells = df[df['condition']=='Gabazine']['cellID'].unique()
print(f'Gabazine Cells: {gabazine_cells}')

# one row per gabazine cell, one column per stimulus frequency; squeeze=False keeps ax6_supp2 2D
fig6_supp2, ax6_supp2 = plt.subplots(len(gabazine_cells), 4, figsize=(w, h), sharey=False, squeeze=False)
# add more space between subplots
fig6_supp2.subplots_adjust(hspace=0.5, wspace=0.5)


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
for i,selected_cell in enumerate(gabazine_cells):
    for j,f in enumerate([20,30,40,50]):

        

        data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == f) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j])

        data = df[(df['condition'] == 'Control') & (df['stimFreq'] == f) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j], signal_mapping=signal_mapping)

        # simplify
        plot_tools.simplify_axes(ax6_supp2[i,j], splines_to_keep=[], )
        ax6_supp2[i,j].set_title(f'{selected_cell} {f}Hz 5sq')
        #remove legend
        ax6_supp2[i,j].legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# # save figure
figure_name = 'Figure6_supp2_all_gabazine_cells_5sq.png'
fig6_supp2.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')


In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

gabazine_cells = df[df['condition']=='Gabazine']['cellID'].unique()
print(f'Gabazine Cells: {gabazine_cells}')

# one row per gabazine cell, one column per stimulus frequency; squeeze=False keeps ax6_supp2 2D
fig6_supp2, ax6_supp2 = plt.subplots(len(gabazine_cells), 4, figsize=(w, h), sharey=False, squeeze=False)
# add more space between subplots
fig6_supp2.subplots_adjust(hspace=0.5, wspace=0.5)


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
for i,selected_cell in enumerate(gabazine_cells):
    for j,f in enumerate([20,30,40,50]):

        

        data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == f) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j])

        data = df[(df['condition'] == 'Control') & (df['stimFreq'] == f) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j], signal_mapping=signal_mapping)

        # simplify
        plot_tools.simplify_axes(ax6_supp2[i,j], splines_to_keep=[], )
        ax6_supp2[i,j].set_title(f'{selected_cell} {f}Hz 15sq')
        #remove legend
        ax6_supp2[i,j].legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# # save figure
figure_name = 'Figure6_supp2_all_gabazine_cells_15sq.png'
fig6_supp2.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')


### SDN: Subthreshold Divisive Normalization

In [None]:
# cells with 1sq experiments
freq_sweep_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_CC_long.h5" 
df = pd.read_hdf(freq_sweep_datapath, key='data')
# df = df[ pd.notnull(df['peaks_cell'])] # remove all sweeps that had NaNs in analysed params (mostly due to bad pulse detection)
# reset index
# df = df.reset_index(drop=True)
print(f'All cells freqSweep data has {df.shape[0]} sweeps or {df.shape[0]/3} presentations')

cell_with_1sq_data = df[df['numSq']==1]['cellID'].unique()
print(f'There are {len(cell_with_1sq_data)} cells with 1sq data')

# only keep cells that have 1sq data and remove led and field data
df = df[(df['cellID'].isin(cell_with_1sq_data))]
print(f'Cells that have 1sq freqSweep data have total {df.shape[0]} sweeps or {df.shape[0]/3} presentations')

# remove unwanted metadata and signals in columns 40049 to 80049
columns_to_keep = ['cellID','exptID','condition','numSq','pulseWidth','clampPotential','stimFreq','patternList','probePulseStart','pulseTrainStart','sweep']
data1 = df[columns_to_keep]
data2 = df.iloc[:,49:40049]
data = pd.concat([data1,data2], axis=1)
del data1, data2, df
# cast the first 10 metadata columns (all except 'condition') to float32
for i in range(10):
    if data.columns[i] == 'condition':
        continue
    # else make the datatype float
    data[data.columns[i]] = data[data.columns[i]].astype('float32')

# data2 = data.groupby(columns_to_keep[:-1]).mean().reset_index()
print(f'final data shape = {data.shape}')



In [None]:
# flag sweeps that have a separate probe pulse: True when probePulseStart differs from pulseTrainStart
data['probe_pulse'] = data['probePulseStart'] != data['pulseTrainStart']

# count the sweeps that have a separate probe pulse
print(f'Number of sweeps with a separate probe pulse: {int(data["probe_pulse"].sum())}')

In [None]:
value, freq = np.unique(data['pulseTrainStart'], return_counts=True)
# print value and freq
print(f'value = {value}')
print(f'freq = {freq}')

In [None]:
fig7_supp1, ax7_supp1 = plt.subplots(4,3, figsize=(10,8), sharey=True, layout='constrained')
fig7_supp2, ax7_supp2 = plt.subplots(4,3, figsize=(10,8), sharey=True, layout='constrained')
ax7_supp1 = ax7_supp1.flatten()
ax7_supp2 = ax7_supp2.flatten()

for i, cell in enumerate(cell_with_1sq_data):
    # cell=cell_with_1sq_data[4]
    df_temp_1sq = data[(data['cellID']==cell) & (data['numSq']==1)]
    # print(df_temp_1sq[df_temp_1sq["patternList"]==3.0].iloc[:,:14])
    # average across sweeps for the same patternList value
    df_temp_1sqx = df_temp_1sq.groupby(['cellID','condition','numSq','pulseWidth','clampPotential','stimFreq','probePulseStart','pulseTrainStart','patternList']).mean().reset_index()
    # print(df_temp_1sqx[df_temp_1sqx["patternList"]==3.0].iloc[:,:14])
    # plot the data
    t0 = df_temp_1sqx['probePulseStart'][0]
    # time axis for 1000 samples (50 ms at 20 kHz) starting at the probe pulse
    time = t0 + np.arange(1000) / 2e4
    frame = np.zeros((24,24))
    for sweep in range(df_temp_1sqx.shape[0]):
        pattern = df_temp_1sqx.iloc[sweep]['patternList']
        spotloc = pattern_index.patternID[pattern][0]
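        # convert the flat spot index to (row, col) on the 24x24 grid
        x, y = spotloc//24, spotloc%24
        # signal samples start after the 11 metadata columns; take 1000 samples (50 ms) from the probe pulse
        trace = df_temp_1sqx.iloc[sweep,11+int(t0*2e4):11+int(t0*2e4)+1000]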
        tracemax = np.max(trace)
        frame[x,y] = tracemax
        ax7_supp1[i].plot(time, trace, color='k', alpha=0.4)

    ax7_supp1[i].set_title(f'{cell}')
    ax7_supp2[i].set_title(f'{cell}')
    # plot heatmap of the frame on ax7_supp2 axis
    sns.heatmap(frame, ax=ax7_supp2[i], cmap='viridis', vmin=0, vmax=2)

# save figures
figure1_name = 'Figure7_supp1_SDN_CC_1sq_traces.png'
figure2_name = 'Figure7_supp2_SDN_CC_1sq_max_heatmap.png'

fig7_supp1.savefig(paper_figure_export_location / (figure1_name), dpi=300, bbox_inches='tight')
fig7_supp2.savefig(paper_figure_export_location / (figure2_name), dpi=300, bbox_inches='tight')
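The heatmap above places each single-spot response on a 24 x 24 grid by splitting a flat spot index into a row and a column with integer division and modulo. A minimal sketch of that mapping, assuming 0-based, row-major spot indices (which is what the `//` and `%` arithmetic in the cell implies); `spot_to_rowcol`, `peaks_to_frame` and the peak values are hypothetical, for illustration only:

```python
import numpy as np

GRID = 24  # the single-spot patterns sit on a 24 x 24 grid

def spot_to_rowcol(spot_index, grid=GRID):
    """Split a flat, 0-based, row-major spot index into (row, col)."""
    return divmod(spot_index, grid)

def peaks_to_frame(spot_peaks, grid=GRID):
    """Place each spot's peak response in a grid x grid array; unvisited spots stay 0."""
    frame = np.zeros((grid, grid))
    for spot, peak in spot_peaks.items():
        row, col = spot_to_rowcol(spot, grid)
        frame[row, col] = peak
    return frame

# hypothetical peak responses (mV) for three spots
spot_peaks = {197: 1.2, 255: 0.4, 347: 0.9}
frame = peaks_to_frame(spot_peaks)
print(spot_to_rowcol(197))                   # (8, 5)
print(np.unravel_index(197, (GRID, GRID)))   # the same split via NumPy
```

Passing `frame` to `sns.heatmap` gives the same layout as the panels above; if the spot indices turn out to be 1-based, subtract 1 before splitting.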

    
    

Divisive inhibition (DI) is given by the following equation:
$$\theta = \epsilon - \beta\epsilon = (1-\beta)\,\epsilon$$

SDN is parameterized by $\gamma$ and given by the following equation:
$$\theta = \frac{\gamma\epsilon}{\gamma+\epsilon}$$

In both equations, $\epsilon$ is the expected response and $\theta$ is the observed response. $\beta$ is the strength of divisive inhibition (DI) and $\gamma$ sets the strength of divisive normalization (DN): a smaller $\gamma$ means stronger normalization, since $\theta \to \epsilon$ as $\gamma \to \infty$.
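The two models differ in shape, which is what lets a fit distinguish them. Rewriting and differentiating the DI and SDN equations:

$$
\begin{aligned}
\text{DI:}\quad & \theta = (1-\beta)\,\epsilon, && \frac{d\theta}{d\epsilon} = 1-\beta \\
\text{SDN:}\quad & \theta = \frac{\gamma\epsilon}{\gamma+\epsilon} = \frac{\epsilon}{1+\epsilon/\gamma}, && \frac{d\theta}{d\epsilon} = \frac{\gamma^{2}}{(\gamma+\epsilon)^{2}} \\
& \theta \approx \epsilon\left(1-\frac{\epsilon}{\gamma}\right) \ \text{for } \epsilon \ll \gamma, && \theta \to \gamma \ \text{as } \epsilon \to \infty
\end{aligned}
$$

DI scales every input by the same factor $(1-\beta)$, so its input-output relation stays a straight line through the origin. SDN has unit slope at $\epsilon = 0$ and bends toward a ceiling of $\gamma$, so it leaves small expected responses nearly unchanged and compresses large ones.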

In [None]:
# SDN model: observed response given the expected response and gamma
def sdn_observed(expected, gamma):
    observed = (gamma*expected)/(gamma+expected)
    return observed

def di_observed(expected, beta): #divisive inhibition
    observed = expected - beta*expected
    return observed
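To check that `curve_fit` can tell the two models apart, here is a minimal sketch that fits both to synthetic data generated from the SDN equation with a known $\gamma$. The true value of 10, the noise level, the input range, the `p0` starting guesses and the `sse` helper are arbitrary choices for illustration, not values from the recordings. The function bodies repeat `sdn_observed` and `di_observed` so the block runs on its own.

```python
import numpy as np
from scipy.optimize import curve_fit

def sdn_observed(expected, gamma):
    # subthreshold divisive normalization: saturates at gamma
    return (gamma * expected) / (gamma + expected)

def di_observed(expected, beta):
    # divisive inhibition: scales every input by (1 - beta)
    return expected - beta * expected

rng = np.random.default_rng(0)
expected = np.linspace(0.5, 30.0, 40)  # synthetic "expected" sums (mV)
true_gamma = 10.0                      # assumed ground truth for the demo
observed = sdn_observed(expected, true_gamma) + rng.normal(0.0, 0.2, expected.size)

# p0 gives curve_fit a starting guess; the notebook relies on the default (1.0)
(gamma_hat,), _ = curve_fit(sdn_observed, expected, observed, p0=[5.0])
(beta_hat,), _ = curve_fit(di_observed, expected, observed, p0=[0.5])

def sse(model, param):
    # sum of squared residuals for a one-parameter model
    return float(np.sum((observed - model(expected, param)) ** 2))

print(f"gamma = {gamma_hat:.2f}, SSE = {sse(sdn_observed, gamma_hat):.2f}")
print(f"beta  = {beta_hat:.2f}, SSE = {sse(di_observed, beta_hat):.2f}")
```

Because both models have a single free parameter, comparing their residuals directly is reasonable here; with real per-cell data, the spike-exclusion threshold and the number of patterns will limit how cleanly the two can be separated.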

In [None]:
def plot_first_pulse_SDN(df, cell, ax, remove_spikes_from_fitting=True):
    dfcell = df[(df['cellID']==cell)]
    # make a spotmax dict that stores tracemax value for each spot
    spotmax = {}
    expected = np.zeros(dfcell.shape[0])
    observed = np.zeros(dfcell.shape[0])
    for i,row in dfcell.iterrows():
        pattern = row['patternList']
        if row['numSq'] >1:
            continue
        spotloc = pattern_index.patternID[pattern][0]
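        # signal starts after the 11 metadata columns; skip 200 samples (10 ms at 20 kHz) after probePulseStart
        start = 11+int(row['probePulseStart']*2e4)+200
        trace = row.iloc[start:start+600] # 600 samples = 30 ms window
        tracemax = np.max(trace)
        # if a spot was presented in several sweeps, the last sweep's peak is kept
        spotmax[pattern] = tracemax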

    for i in range(dfcell.shape[0]):
        row = dfcell.iloc[i,:]
        numSq = row['numSq']
        patternList = int(row['patternList']) #48	:[197,255,347,401,439]
        spotloclist = pattern_index.patternID[patternList] #[197,255,347,401,439]
        # exp = 0
        for spotloc in spotloclist:
            pattern = pattern_index.get_patternID([spotloc])
            expected[i] += spotmax[pattern]
        # expected[i] = exp/len(spotloclist)

        # calculate observed (for now only for the probe pulse)
        # skip 200 samples (10 ms at 20 kHz) after probePulseStart, then take 600 samples (30 ms)
        start = 11+int(row['probePulseStart']*2e4)+200
        trace = row.iloc[start:start+600]
        observed[i] = np.max(trace)

    # add expected and observed to a copy of the metadata columns (copy avoids SettingWithCopyWarning)
    dfcell_short = dfcell.iloc[:,:11].copy()
    dfcell_short['expected'] = expected
    dfcell_short['observed'] = observed

    # scatterplot between expected and observed
    sns.scatterplot(data=dfcell_short[dfcell_short['numSq']>1], x='expected', y='observed', hue='patternList', size='numSq', sizes=[200,200], palette='deep', ax=ax, edgecolor='black', alpha=0.5, legend=False)
    # delete legend
    ax.legend([],[], frameon=False)

    # fit the linear, DI and SDN models to the multi-spot responses;
    # optionally exclude spiking responses (observed peak >= 25 mV) from the fits
    if remove_spikes_from_fitting:
        fit_df = dfcell_short[(dfcell_short['numSq']>1) & (dfcell_short['observed']<25)]
    else:
        fit_df = dfcell_short[dfcell_short['numSq']>1]
    # sort by expected so the fitted curves are drawn left to right
    fit_df = fit_df.sort_values('expected')
    x = fit_df['expected']
    y = fit_df['observed']

    # linear fit
    slope, intercept, r_value, p_value, std_err = stats.linregress(x,y)
    # plot the line
    ax.plot(x, intercept + slope*x, 'cornflowerblue', label='Linear fit',)

    # fit di_observed to the data
    popt, _ = curve_fit(di_observed, x, y)
    ydi = di_observed(x, *popt)
    # plot the line
    ax.plot(x, ydi, 'mediumblue', label=r'Divisive Inhibition fit: $\beta$ ='+f'{popt[0]:.2f}')

    # fit sdn_observed to the data
    popt, _ = curve_fit(sdn_observed, x, y)
    ysdn = sdn_observed(x, *popt)
    # plot the line
    ax.plot(x, ysdn, 'darkviolet', label=r'SDN fit: $\gamma$ ='+f'{popt[0]:.2f}')

    # show labels
    ax.legend()

    return dfcell_short, ax
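`plot_first_pulse_SDN` builds the expected response of a multi-spot pattern as the linear sum of the peak responses to its constituent single spots. A stripped-down sketch of that bookkeeping with plain dictionaries; `pattern_spots`, `single_spot_peak` and `observed_peak` are hypothetical stand-ins for `pattern_index.patternID`, the `spotmax` dict and the measured peaks, and all numbers except the spot list of pattern 48 (taken from the code comment above) are invented:

```python
# hypothetical stand-in for pattern_index.patternID: pattern ID -> spot locations
pattern_spots = {48: [197, 255, 347, 401, 439], 50: [197, 255]}
# hypothetical stand-in for spotmax: peak response (mV) to each spot presented alone
single_spot_peak = {197: 1.2, 255: 0.4, 347: 0.9, 401: 0.3, 439: 0.7}
# invented observed peaks (mV) for the multi-spot patterns
observed_peak = {48: 2.6, 50: 1.5}

def expected_peak(pattern_id):
    """Linear-summation prediction: sum of the single-spot peaks in the pattern."""
    return sum(single_spot_peak[spot] for spot in pattern_spots[pattern_id])

for pid, spots in pattern_spots.items():
    pred = expected_peak(pid)
    ratio = observed_peak[pid] / pred
    # ratio < 1 means the spots summed sublinearly
    print(f"pattern {pid}: {len(spots)} spots, expected {pred:.2f}, "
          f"observed {observed_peak[pid]:.2f}, ratio {ratio:.2f}")
```

Plotting `observed` against `expected` for every pattern, as the function does, turns these ratios into the scatter that the DI and SDN curves are fitted to.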


In [None]:
fig7, ax7 = plt.subplots(4,3, figsize=(20,15), layout='constrained')
ax7 = ax7.flatten()
remove_spikes_from_fitting = True
for i,cell in enumerate(cell_with_1sq_data):
    try:
        dfcell_short, ax7[i] = plot_first_pulse_SDN(data, cell, ax7[i],remove_spikes_from_fitting=remove_spikes_from_fitting)
        ax7[i].set_title(f'Cell {cell}')
        ax7[i].set_xlabel('Expected (mV)')
        ax7[i].set_ylabel('Observed (mV)')
        # ax7[i].set_xlim([0,25])
        # ax7[i].set_ylim([0,25])
    except Exception as e:
        print(f'Cell {cell} failed: {e}')
# save fig7
if remove_spikes_from_fitting:
    figure_name = 'Figure7_CC_SDN_first_pulse_withoutfittingspikes.png'
else:
    figure_name = 'Figure7_CC_SDN_first_pulse.png'
fig7.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')

In [None]:
# For each cell, for each pattern, combine the mean trace for each constituent 1sq spot
# get peak and other responses for each trace
# compare expected vs observed

def plot_raw_expected_observed(data, cell, ax, mode='raw'):
    dfcell = data[(data['cellID']==cell) & (data['stimFreq']==20)]
    print(dfcell['exptID'])
    dfcell2 = dfcell[dfcell['numSq']!=1]
    # one viridis color per pulse index (1-8)
    clr = {k: viridis(0.1*(k-1)) for k in range(1, 9)}
    for j in range(dfcell2.shape[0]):
        row = dfcell2.iloc[j,:]
        numSq = row['numSq']
        observed = row.iloc[11:20011].to_numpy(dtype=float) # 20000 samples = 1 s at 20 kHz
        time = np.linspace(0,1,20000)
        patternList = int(row['patternList']) #48	:[197,255,347,401,439]
        spotloclist = pattern_index.patternID[patternList] #[197,255,347,401,439]
        
        
        p0 =row['probePulseStart']
        pt = row['pulseTrainStart']
        pulse_times = (2e4* utils.get_pulse_times(8,pt,20)).astype(int)
        
        expected = np.zeros(20000)
        for spotloc in spotloclist:
            pattern = pattern_index.get_patternID([spotloc])
            trace = dfcell[(dfcell['patternList']==pattern)].iloc[:,11:20011].to_numpy()
            expected += np.mean(trace, axis=0)
        # subtract baseline from expected
        expected = expected - np.mean(expected[:4000])

        if mode=='raw':
            # save expected to the dfcell2
            # dfcell2.iloc[j, 20011:40011] = expected   
            # plot expected and observed on the ax
            ax.plot(time, 10*j+observed, color='g', alpha=0.5)
            ax.plot(time, 10*j+expected, color='r', alpha=0.5)
        
        elif mode=='p1-pn':
            # expected = np.linspace(0,0,20000)
            # get the first pulse response for expected
            first_pulse_response_trace = expected[pulse_times[0]:pulse_times[1]]
            e1 = np.max(first_pulse_response_trace )

            constructed = np.zeros(20000)
            for i,p in enumerate(pulse_times):
                c = clr[i+1]
                t0 = int(p)
                t1 = t0 + 1000
                constructed[t0:t1] += first_pulse_response_trace
                pmax_exp = np.max(constructed[t0:t1])
                pmax_obs = np.max(observed[t0:t1])
                ax.scatter(pmax_exp, pmax_obs, 30, c )

        elif mode=='pn-pn':
            # expected = np.linspace(0,0,20000)
            # get the first pulse response for expected
            # first_pulse_response_trace = expected[pulse_times[0]:pulse_times[1]]
            # e1 = np.max(first_pulse_response_trace )

            # constructed = np.linspace(0,0,20000)
            for i,p in enumerate(pulse_times):
                c = clr[i+1]
                t0 = int(p)
                t1 = t0 + 1000
                # print(t0,t1, expected.shape)
                # constructed[t0:t1] += first_pulse_response_trace
                pmax_exp = np.max(expected[t0:t1])
                pmax_obs = np.max(observed[t0:t1])
                ax.scatter(pmax_exp, pmax_obs, 30, c )
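In the `'pn-pn'` and `'p1-pn'` modes, each pulse's response is summarized by the maximum of the trace in a fixed window after that pulse (1000 samples, i.e. 50 ms at the 20 kHz sampling rate the notebook assumes). A self-contained sketch of that windowing on a synthetic trace; `pulse_window_peaks` is a hypothetical helper, the pulse onsets are computed directly rather than with `utils.get_pulse_times`, and the alpha-shaped response and 20% per-pulse depression are made up:

```python
import numpy as np

FS = 20_000    # sampling rate (Hz), as assumed by the 2e4 factors in the notebook
WINDOW = 1000  # samples per pulse window (50 ms)

def pulse_window_peaks(trace, pulse_samples, window=WINDOW):
    """Return the maximum of `trace` in [p, p + window) for each pulse onset p."""
    return np.array([trace[p:p + window].max() for p in pulse_samples])

# synthetic 1 s trace: 8 pulses at 20 Hz starting at 0.2 s, each an alpha function
t = np.arange(FS) / FS
pulse_samples = (0.2 * FS + np.arange(8) * FS / 20).astype(int)
tau = 0.005  # 5 ms time constant (arbitrary)
trace = np.zeros_like(t)
for k, p in enumerate(pulse_samples):
    s = t[p:] - t[p]
    amp = 0.8 ** k  # made-up depression: each response 20% smaller than the last
    trace[p:] += amp * (s / tau) * np.exp(1 - s / tau)

peaks = pulse_window_peaks(trace, pulse_samples)
print(np.round(peaks, 2))
```

Because each alpha response has decayed almost completely by the next pulse (50 ms is 10 time constants), the window maxima track the per-pulse amplitudes; with slower responses, residual depolarization from earlier pulses would add to later windows.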

In [None]:
fig7_4, ax7_4 = plt.subplots(figsize=(10,8), layout='constrained')
plot_raw_expected_observed(data, 3402, ax7_4, mode='raw')

In [None]:
fig7_5, ax7_5 = plt.subplots(figsize=(10,8), layout='constrained')
plot_raw_expected_observed(data, 3402, ax7_5, mode='pn-pn')

In [None]:
fig7_6, ax7_6 = plt.subplots(figsize=(10,8), layout='constrained')
for cell in cell_with_1sq_data:
    plot_raw_expected_observed(data, cell, ax7_6, mode='p1-pn')

In [None]:
# 1st pulse vs nth pulse responses
freq_sweep_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_combined_short.h5" 
df2 = pd.read_hdf(freq_sweep_datapath, key='data')
df2 = df2[ pd.notnull(df2['peaks_cell'])] # remove all sweeps that had NaNs in analysed params (mostly due to bad pulse detection)
df2 = df2.reset_index(drop=True)

In [None]:
# selected cell with 1sq data
df2 = df2[(df2['cellID'].isin(cell_with_1sq_data))]
# pick every cell and separate into 1sq and >1sq data, plot for each pattern, expected constructed from 1sq data, observed from >1sq data

def plot_expected_observed(data, cell, ax):
    dfcell = data[(data['cellID']==cell) & (data['stimFreq']==20)]
    dfcell1sq = dfcell[dfcell['numSq']==1]

    for i in range(dfcell.shape[0]):
        row = dfcell.iloc[i,:]
        numSq = row['numSq']
        peaks = row['peaks_cell'].tolist()
        patternList = int(row['patternList']) #48	:[197,255,347,401,439]
        spotloclist = pattern_index.patternID[patternList] #[197,255,347,401,439]
        expected = np.zeros((9,))
        for spotloc in spotloclist:
            pattern = pattern_index.get_patternID([spotloc])
            pattern_peaks = dfcell1sq[(dfcell1sq['patternList']==pattern)]['peaks_cell'].to_numpy()
            # average the per-sweep peak arrays across repeats (works for any number of repeats)
            pattern_peaks = np.mean(np.stack(pattern_peaks), axis=0)
            expected += pattern_peaks
        


        


In [None]:
csv_path = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Data\23-11-24_GrikAA375\data001.txt"
df = pd.read_csv(csv_path, sep='\t')
# use the coord column as the index
df = df.set_index('coord')

In [None]:
# make a fig with two subpanels
fig8, ax8 = plt.subplots(1,2, figsize=(8,8), sharey=True, layout='constrained', gridspec_kw={'width_ratios': [1, 1]})

'''
vals = df_CA3_heatmap[pulse].values
# get index values
idx = (df_CA3_heatmap[pulse].index.get_level_values(0).values) - 1
'''
vals1 = df['field1']
idx1 = (df['field1'].index.get_level_values(0).values) 

vals2 = df['field2']
idx2 = (df['field2'].index.get_level_values(0).values)
# idx2 = (df['coord'].values) #index.get_level_values(0).values)

plot_tools.plot_grid(spot_locs=idx1, spot_values=vals1, ax=ax8[0], cmap='grey', vmin=-2.0, vmax=0)
plot_tools.plot_grid(spot_locs=idx2, spot_values=vals2, ax=ax8[1], cmap='grey', vmin=-2.0, vmax=0)
# # plot field1
# sns.heatmap(field2grid, ax=ax8[0], cmap='grey')
# # plot field2
# sns.heatmap(field1grid, ax=ax8[1], cmap='grey')
# #remove ticks
# ax8[0].set_xticks([])
# ax8[0].set_yticks([])
# ax8[1].set_xticks([])
# ax8[1].set_yticks([])

# # only 1 colorbar
# fig8.colorbar(ax8[0].get_children()[0], ax=ax8, location="top", use_gridspec=False, pad=0.1, shrink=0.5)
# # remove the colorbar from the subplots
# ax8[0].collections[0].colorbar.remove()
# ax8[1].collections[0].colorbar.remove()


In [None]:
vals1

Fig 5 comment from Upi:

Show that the variance in the field response is smaller than the variance in the voltage-clamp (VC) currents, using a multivariate ANOVA. This serves as a substitute for CA3 recordings.
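The field potential (mV) and the VC currents (pA) have different units, so one way to compare their trial-to-trial variability is to normalize each signal's per-trial peaks by that signal's own mean (equivalently, compare coefficients of variation) before testing for unequal variances. The sketch below uses Levene's test from SciPy on invented numbers; it is a simpler univariate stand-in, not the multivariate ANOVA suggested above, and `coefficient_of_variation` is a hypothetical helper.

```python
import numpy as np
from scipy import stats

def coefficient_of_variation(x):
    """Sample standard deviation (ddof=1) over |mean|; unitless."""
    x = np.asarray(x, dtype=float)
    return x.std(ddof=1) / abs(x.mean())

# invented per-trial peak amplitudes for one pattern and one pulse
field_peaks = np.array([-0.52, -0.49, -0.55, -0.50, -0.53, -0.51])  # mV
vc_peaks = np.array([110.0, 85.0, 140.0, 95.0, 125.0, 70.0])        # pA

# normalize each signal by its own mean so the spreads are comparable
field_norm = field_peaks / field_peaks.mean()
vc_norm = vc_peaks / vc_peaks.mean()

print(f"CV field = {coefficient_of_variation(field_peaks):.3f}")
print(f"CV VC    = {coefficient_of_variation(vc_peaks):.3f}")
result = stats.levene(field_norm, vc_norm)
print(f"Levene W = {result.statistic:.2f}, p = {result.pvalue:.3g}")
```

With real data, the same comparison would be repeated per cell, pattern and pulse index, which is where a multivariate test across pulses becomes useful.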