# Data Visualization Program

### Figure 1:
* Expt design schematic
* Expt design patterns in space and time
* Example set of traces
* Calib of stimuli: desensitization
* Control for CA3 responses to patterns.
* CA3 heatmap
* CA3 analysis with respect to freq
* CA3 expected pattern response

### Figure 2:
* Deconv to extract STP from trace: schematic
* Comparison between deconv fit, P2P, and peak fitting
* Fitting of reference trace
* Example deconv and fit thereof.
* STP traces under different conditions: E/I, #sq and frequency
* Diversity in STP traces.

<img src="notes_figures/Figure_outline_Fig2.jpg" width="500"/>
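
A minimal sketch of one way to extract per-pulse amplitudes (and hence STP) from a train response. It treats the trace as a linear sum of one fixed PSP kernel per stimulus pulse and solves for the kernel amplitudes by least squares. The alpha-function kernel, its time constant `tau`, and the helper names are assumptions for illustration. This is not the deconvolution used for the figure.

```python
import numpy as np

def alpha_kernel(t, tau):
    """Alpha-function PSP kernel: 0 for t < 0, peak value 1 at t = tau (assumed shape)."""
    k = np.zeros_like(t)
    pos = t >= 0
    k[pos] = (t[pos] / tau) * np.exp(1 - t[pos] / tau)
    return k

def fit_pulse_amplitudes(trace, pulse_times, fs, tau):
    """Least-squares amplitude of one shifted kernel per pulse.

    Returns the raw amplitudes and the amplitudes normalized to the first pulse.
    """
    t = np.arange(len(trace)) / fs
    # design matrix: one column per pulse, each column the kernel shifted to that pulse's onset
    X = np.column_stack([alpha_kernel(t - tp, tau) for tp in pulse_times])
    amps, *_ = np.linalg.lstsq(X, trace, rcond=None)
    return amps, amps / amps[0]

# synthetic, noise-free 20 Hz train of 9 pulses sampled at 20 kHz
fs = 20000
pulse_times = 0.1 + np.arange(9) / 20          # first pulse at 100 ms
true_amps = np.array([1.0, 1.4, 1.6, 1.5, 1.3, 1.2, 1.1, 1.0, 0.9])
t = np.arange(14000) / fs                      # 700 ms
trace = sum(a * alpha_kernel(t - tp, 0.01) for a, tp in zip(true_amps, pulse_times))
amps, stp = fit_pulse_amplitudes(trace, pulse_times, fs, tau=0.01)
print(np.round(stp, 2))
```

On noise-free synthetic data the fit recovers the generating amplitudes. On real traces the result depends on how well the assumed kernel matches the actual PSP shape.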

### Figure 3:
* Model schematic for plasticity
* Fitting free params of model for E and I
* Fitting diversity for E and I
* Fitting stochasticity for E and I
* How averaged are E and I respectively? Crucial for escape calculation.

### Figure 4:
* Model integration: schematic of chem and elec components
* Model validation 1: Match to random pattern sequence
* Model validation 2: Match to Curr clamp recordings

### Figure 5
* Example E-I traces over a burst (train)
* Separate panel to compare E and I trace shapes (you can scale E to I and have their own scale bars)
* Plot field responses corresponding to E and I
* E/I ratio of successive pulses across freq and numSq
* Summary figure across cells
Note: This figure should explain why we don't see escape from E-I balance.
* Current clamp summary
* Differences between frequencies across pulses
* Differences between numSq across pulses
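
A sketch of the E/I ratio per pulse. It assumes per-pulse peak currents have already been extracted into arrays of shape (sweeps, pulses) for the excitatory (-70 mV) and inhibitory (0 mV) holding potentials. The function name and the toy values are hypothetical.

```python
import numpy as np

def ei_ratio_per_pulse(e_peaks, i_peaks):
    """E/I ratio for each pulse index from trial-averaged peak magnitudes.

    e_peaks, i_peaks: arrays of shape (n_sweeps, n_pulses) with per-pulse peak
    currents recorded at the excitatory and inhibitory holding potentials.
    """
    # EPSCs at -70 mV are inward (negative), so compare magnitudes
    e_mean = np.mean(np.abs(e_peaks), axis=0)
    i_mean = np.mean(np.abs(i_peaks), axis=0)
    return e_mean / i_mean

# toy example: 2 sweeps x 3 pulses, in pA
e = np.array([[-100.0, -80.0, -60.0], [-120.0, -90.0, -70.0]])
i = np.array([[200.0, 220.0, 260.0], [220.0, 240.0, 280.0]])
ratio = ei_ratio_per_pulse(e, i)
print(ratio)
```

Averaging across sweeps before dividing avoids blowing up the ratio on sweeps with a near-zero inhibitory peak. Computing per-sweep ratios and then averaging is the alternative if per-sweep variability is of interest.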
 
### Figure 6: Why no escape (can be merged with figure 5)
* Comparison of current clamp response with response generated from a model cell with summated E-I currents
Likely reason: inhibition builds up slowly during the train
* Gabazine vs control, which pulse, selectivity
* For the sweeps that cause escape: which pulse index, numSq, and frequency are they from?
* EPSP peak time should be in the same range as spike peak time in the sweeps where there is escape
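
A minimal sketch of the "model cell with summated E-I currents" comparison. A single-compartment passive membrane is driven by excitatory and inhibitory conductance waveforms, and the membrane potential is read out. All parameters are illustrative assumptions, not values fitted to the recordings: capacitance, leak, reversal potentials, alpha-shaped conductances, and a 2 ms lag of inhibition.

```python
import numpy as np

def alpha_conductance(t, onset, tau, g_max):
    """Alpha-shaped conductance (S) starting at `onset` (s) and peaking at g_max after tau."""
    s = np.clip(t - onset, 0.0, None)
    return g_max * (s / tau) * np.exp(1 - s / tau)

def simulate_passive_cell(g_e, g_i, fs, c_m=100e-12, g_leak=5e-9,
                          e_leak=-70e-3, e_exc=0.0, e_inh=-75e-3):
    """Forward-Euler solution of C dV/dt = -gL(V-EL) - gE(V-EE) - gI(V-EI); returns V in volts."""
    dt = 1.0 / fs
    v = np.empty(len(g_e))
    v[0] = e_leak
    for n in range(1, len(g_e)):
        i_ionic = (g_leak * (v[n - 1] - e_leak)
                   + g_e[n - 1] * (v[n - 1] - e_exc)
                   + g_i[n - 1] * (v[n - 1] - e_inh))
        v[n] = v[n - 1] - dt * i_ionic / c_m
    return v

fs = 20000
t = np.arange(4000) / fs                                         # 200 ms at 20 kHz
g_e = alpha_conductance(t, onset=0.050, tau=0.003, g_max=2e-9)
g_i = alpha_conductance(t, onset=0.052, tau=0.008, g_max=4e-9)   # inhibition lags by 2 ms
v_e_only = simulate_passive_cell(g_e, np.zeros_like(t), fs)
v_e_and_i = simulate_passive_cell(g_e, g_i, fs)
print(f"peak depolarization, E only : {1e3 * (v_e_only.max() + 70e-3):.2f} mV")
print(f"peak depolarization, E and I: {1e3 * (v_e_and_i.max() + 70e-3):.2f} mV")
```

Changing the inhibitory onset or `g_max` per pulse in a train is one way to probe the "inhibition builds up slowly" hypothesis in this toy setting.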

### Figure 7: SDN across pulses for frequencies
* Example plot of SDN
* Grid
* Heatmap of gamma
* Other ideas:
  * SDN across frequencies: The envelope of E and I change with frequency of stimulation
  * relationship of peak of PSCs with the fall tau for both E and I
  * SDN expected vs observed plotted for different frequencies and consecutive pulse # 

<img src="notes_figures/Figure_Outline_Fig7_SDN.jpg" width="500"/>
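
For the SDN panels, observed-versus-expected responses can be summarised with a divisive-normalization curve O = γE/(γ + E). Here E is the expected (linearly summed) response, O the observed response, and γ the normalization constant of the gamma heatmap. That functional form is an assumption here, not something defined in this notebook. The sketch fits γ with `scipy.optimize.curve_fit` on synthetic data.

```python
import numpy as np
from scipy.optimize import curve_fit

def sdn_model(expected, gamma):
    """Assumed divisive normalization: observed = gamma * expected / (gamma + expected)."""
    return gamma * expected / (gamma + expected)

def fit_gamma(expected, observed, gamma0=10.0):
    """Nonlinear least-squares fit of gamma (constrained > 0); returns (gamma, 1-sigma error)."""
    popt, pcov = curve_fit(sdn_model, expected, observed, p0=[gamma0], bounds=(0, np.inf))
    return popt[0], np.sqrt(pcov[0, 0])

# synthetic check: responses generated with gamma = 8 mV plus small Gaussian noise
rng = np.random.default_rng(0)
expected = np.linspace(0.5, 20, 40)
observed = sdn_model(expected, 8.0) + rng.normal(0, 0.05, expected.size)
gamma, gamma_err = fit_gamma(expected, observed)
print(f"gamma = {gamma:.2f} +/- {gamma_err:.2f} mV")
```

Small γ means strong normalization. As γ grows the curve approaches the identity line (linear summation), so a per-frequency, per-pulse γ fit would fill one cell of the heatmap.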

### Figure 8: Estimations of Sparsity/Overlap from model vs obs
* Example trace vs model
* Estimation of sparsity/overlap from CA3 recordings

(Paper Figures 5, 6, and 7)

<img src="notes_figures/Figure_Outline_Fig5_6_7_May2023.png" width="500"/>


### Update 1 Nov 2023
Other ideas:
* **Split Poisson train response across patterns and numSq to check if different patterns cause escape at different points in the freq train**
* Check the same for the oddball expt

<img src="notes_figures/Figure_updates_Fig5_6_7_1Nov2023.png" width="500"/>
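
A minimal pandas sketch of the first idea. It uses a long-format table with one row per sweep and pulse, groups by pattern and numSq, and takes the fraction of sweeps that escaped (spiked) at each pulse index. The column names `patternID`, `numSq`, `pulse`, and `spike` are hypothetical placeholders, not the notebook's actual columns.

```python
import pandas as pd

def escape_fraction(df):
    """Fraction of sweeps with a spike at each pulse index, per (patternID, numSq)."""
    return (df.groupby(['patternID', 'numSq', 'pulse'])['spike']
              .mean()
              .unstack('pulse'))  # rows: (patternID, numSq); columns: pulse index

# toy table: 2 patterns x 2 sweeps x 3 pulses
toy = pd.DataFrame({
    'patternID': [1] * 6 + [2] * 6,
    'numSq':     [5] * 6 + [15] * 6,
    'sweep':     [0, 0, 0, 1, 1, 1] * 2,
    'pulse':     [0, 1, 2] * 4,
    'spike':     [0, 0, 1, 0, 1, 1,   0, 0, 0, 1, 0, 0],
})
res = escape_fraction(toy)
print(res)
```

The resulting table can be drawn directly as a heatmap (pattern by pulse index), which makes pattern-specific escape points easy to spot.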

In [1]:
import sys
import os
import importlib
from   pathlib      import Path
import pickle
import psutil

import numpy                as np
import matplotlib           as mpl
import matplotlib.pyplot    as plt
import seaborn              as sns
import pandas               as pd

from scipy.signal   import find_peaks, peak_widths
from scipy.signal   import butter, bessel, decimate, sosfiltfilt
from scipy.stats    import kruskal, wilcoxon, mannwhitneyu, ranksums
from scipy.signal   import filter_design
from scipy.optimize import curve_fit

from PIL            import Image

from eidynamics     import utils, data_quality_checks, ephys_classes, plot_tools, expt_to_dataframe
from eidynamics     import pattern_index
from eidynamics     import abf_to_data
from eidynamics     import fit_PSC
from Findsim        import tab_presyn_patterns_LR_43
import parse_data
import all_cells

%matplotlib widget
%tb
# import plotly.express as px
# import plotly.graph_objects as go
# sns.set_context('paper')

# make a colour map viridis
viridis = mpl.colormaps["viridis"]
flare   = mpl.colormaps["flare"]
crest   = mpl.colormaps["crest"]
magma   = mpl.colormaps["magma"]
edge    = mpl.colormaps['edge']

color_E = flare
color_I = crest
color_freq = {1:magma(0.05), 5:magma(0.1), 10:magma(0.2), 20:magma(.4), 30:magma(.5), 40:magma(.6), 50:magma(.7), 100:magma(.9)}
color_squares = {1:viridis(0.2), 5:viridis(.4), 7:viridis(.6), 15:viridis(.8), 20:viridis(1.0)}
color_EI = {-70:flare(0), 0:crest(0)}

Fs = 2e4

freq_sweep_pulses = np.arange(9)

Data parsing program imported
>> Working on:  C:\Users\adity\OneDrive\NCBS




In [44]:
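def get_stat_stars(p, alpha=0.05):
    # p >= alpha -> 'n.s.'; otherwise the asterisks of the strictest threshold that p falls below
    # (a single '*' when p is below alpha but not below 0.05)
    significance_asterisks = {0.0001:'****', 0.001:'***', 0.01:'**', 0.05:'*'}
    if p >= alpha:
        return 'n.s.'
    for threshold, stars in significance_asterisks.items():
        if p < threshold:
            return stars
    return '*'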

In [45]:
def find_sweep_expected(trace, freq, fig, ax):
    
    time = np.arange(len(trace)) / Fs  # sample times in s; Fs = 20 kHz is defined in the first cell

    fits, fig, ax = fit_PSC.main(time, trace, freq, show_plots=True, fig=fig, ax=ax)
    
    return fits, fig, ax

### Path

In [46]:
figure_raw_material_location = Path("paper_figure_matter")
paper_figure_export_location = Path("paper_figures")
data_path                    = Path("parsed_data")
cell_data_path               = Path(r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Data\Screened_cells")

### Data

### Data for figures 1 and 2

In [47]:
# Full data set for FreqSweep protocol (df) (raw dataframe with metadata)
freq_sweep_cc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_CC_long.h5" 
freq_sweep_vc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_VC_long.h5" 
cc_FS_longdf = pd.read_hdf(freq_sweep_cc_datapath, key='data')
df = cc_FS_longdf  # 'df' is the name used by the figure cells below

# print new size of dataframe
print(f"Full FreqSweep Dataframe has {cc_FS_longdf.shape[0]} sweeps")

Full FreqSweep Dataframe has 4971 sweeps


In [48]:
# expanded dataframe (processed dataframe with metadata and analysed params)
expanded_data_path = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_combined_expanded.h5"
xc_FS_analyseddf = pd.read_hdf(expanded_data_path, key='data')
df3 = xc_FS_analyseddf  # assumed to be the 'df3' used by the Figure 2 cells below
print(xc_FS_analyseddf.shape)


In [None]:
# Short dataframe for all protocols (df2): processed dataframe with metadata and analysed params
dfshortpath     = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_allprotocols_with_fpr_values.h5"
xc_all_shortdf  = pd.read_hdf(dfshortpath, key='data')
df2             = xc_all_shortdf  # 'df2' is the name used by the figure cells below
print(f"Short Dataframe has {xc_all_shortdf.shape[0]} sweeps and df has {xc_all_shortdf.shape[1]} columns")

Short Dataframe has 16870 sweeps and df has 62 columns


In [14]:
# DataFrame metadata columns
column_name_abbreviations = utils.analysed_properties1_abbreviations
metadata_columns = 49
metadata_fields = cc_FS_longdf.columns[:metadata_columns].to_list()

### Data for figures 5, 6, 7

In [None]:
# expanded dataframe with channels separated into rows
channel_split_data_path_vc = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\screened_cells_FreqSweep_VC_long_channel_split_into_rows.h5"
channel_split_data_path_cc = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\screened_cells_FreqSweep_CC_long_channel_split_into_rows.h5"
# df4 = pd.read_hdf(channel_split_data_path_vc, key='data')
# df5 = pd.read_hdf(channel_split_data_path_cc, key='data')

# Load CA3 data
df_CA3_props = pd.read_csv(r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Data\screened_cells\23-06-06_GrikAA316\3161\CA3_recording_3161_response_properties.csv")

# cell list
vc_cells = np.unique(df2.loc[df2['clampMode']=='VC', 'cellID'])
cc_cells = np.unique(df['cellID'])
vc_cells_screened = np.array([7492, 1931, 1621, 1541, 1524, 1523, 1522, 1491]) # 7491, 6201, 6301, 111, 1531



---

### Figure 1

* Expt design schematic
* Expt design patterns in space and time
* Example set of traces
* Calib of stimuli: desensitization
* Control for CA3 responses to patterns.
* CA3 heatmap
* CA3 analysis with respect to freq
* CA3 expected pattern response

In [30]:
from scipy import stats

significance_asterisks = {1.0:'n.s.', 0.05:'*', 0.01:'**', 0.001:'***', 0.0001:'****'}

def annotate_stat_stars(ax, pval, alpha=0.05, star_loc=[0.5,0.5], add_line=True, line_locs=[0,1,2,2], offset_btw_star_n_line=0.1, color='k', coord_system='axes', fontsize=12, **kwargs):
    significance_asterisks = {1.0:'n.s.', 0.05:'*', 0.01:'**', 0.001:'***', 0.0001:'****'}
    annot_text = f'p={pval:.2f}'
    for k in significance_asterisks.keys():
        if pval > k:
            break
        else:
            annot_text = significance_asterisks[k]
     

    if coord_system == 'axes':    
        # add text annotation on the axis
        ax.text(star_loc[0], star_loc[1], annot_text, color=color, fontsize=fontsize, ha='center', transform=ax.transAxes, **kwargs)
    else:
        ax.text(star_loc[0], star_loc[1], annot_text, color=color, fontsize=fontsize, ha='center', **kwargs)

    ## add a line to connect the two groups for which annotation is added
    # also add two tiny lines at the end of the main line
    off = offset_btw_star_n_line
    if add_line:
        x0, x1 = line_locs[0], line_locs[1]
        y0, y1 = line_locs[2], line_locs[3]
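        # half-length of the end ticks: 1% of the axis height, in the coordinate system being used
        tick = 0.01 if coord_system == 'axes' else 0.01 * np.diff(ax.get_ylim())[0]
        y1a, y1b = y1 - tick, y1 + tick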
        # if coord_system == 'axes', draw a line in axes coordinates else data coordinates
        if coord_system == 'axes':
            ax.plot([x0, x1], [y0, y1], transform=ax.transAxes, color=color, linewidth=1)
            ax.plot([x0, x0], [y1a, y1b], transform=ax.transAxes, color=color, linewidth=1)
            ax.plot([x1, x1], [y1a, y1b], transform=ax.transAxes, color=color, linewidth=1)
        else:
            ax.plot([x0, x1], [y0, y1], color=color, linewidth=1)
            ax.plot([x0, x0], [y1a, y1b], color=color, linewidth=1)
            ax.plot([x1, x1], [y1a, y1b], color=color, linewidth=1)

def pairwise_annotate_violin_plot(ax, df, x='', y='', stat=stats.mannwhitneyu, add_line=False, offset=0.1, color='grey', coord_system='axes', fontsize=12, **kwargs):
    '''
    Annotate a violin plot with pairwise statistical significance values.
    Assumes the x-axis is categorical (tick labels are the integer values of column x)
    and the y-axis is continuous. The default test is the Mann-Whitney U (rank-sum) test,
    which suits independent groups; a paired test such as stats.wilcoxon needs paired, equal-length samples.
    '''
    kwargs.setdefault('zorder', 10)
    unique_values = np.unique(df[x])
    labels = ax.get_xticklabels()
    violin_locs = {int(label.get_text()): label.get_position() for label in labels}

    counter = 0
    for i in unique_values:
        for j in unique_values:
            if i > j:
                xipos, xjpos = violin_locs[i][0], violin_locs[j][0]
                xpos = (xipos + xjpos) / 2
                # stack successive annotations 5% apart above the current upper y-limit
                ypos = (1.0 + counter * 0.05) * ax.get_ylim()[1]
                counter += 1

                _, pval = stat(df[df[x] == i][y], df[df[x] == j][y])
                annotate_stat_stars(ax, pval, star_loc=[xpos, ypos], add_line=add_line, line_locs=[xipos, xjpos, ypos, ypos], offset_btw_star_n_line=offset, color=color, coord_system=coord_system, fontsize=fontsize, **kwargs)
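
The annotation helpers above run one test per pair of categories; the Fig 1D/1E calls use `stats.mannwhitneyu`. The same logic without plotting is sketched below. It returns a table of pairwise p-values and star labels, with thresholds copied from `significance_asterisks`. The names `pairwise_mwu` and `p_to_stars` are hypothetical.

```python
import itertools
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu

THRESHOLDS = [(0.0001, '****'), (0.001, '***'), (0.01, '**'), (0.05, '*')]

def p_to_stars(p):
    """Asterisks for the strictest threshold p does not exceed, else 'n.s.'."""
    for threshold, stars in THRESHOLDS:
        if p <= threshold:
            return stars
    return 'n.s.'

def pairwise_mwu(df, x, y):
    """Two-sided Mann-Whitney U test for every pair of levels in column x."""
    rows = []
    for a, b in itertools.combinations(sorted(df[x].unique()), 2):
        _, p = mannwhitneyu(df.loc[df[x] == a, y], df.loc[df[x] == b, y], alternative='two-sided')
        rows.append({'group_a': a, 'group_b': b, 'p': p, 'stars': p_to_stars(p)})
    return pd.DataFrame(rows)

# toy data: three numSq levels with clearly shifted responses
rng = np.random.default_rng(1)
toy = pd.DataFrame({
    'numSq': np.repeat([1, 5, 15], 30),
    'cell_fpr_max': np.concatenate([rng.normal(m, 1.0, 30) for m in (1.0, 4.0, 8.0)]),
})
res = pairwise_mwu(toy, 'numSq', 'cell_fpr_max')
print(res)
```

Keeping the p-values in a table also makes it straightforward to report them in a figure legend or apply a multiple-comparison correction before drawing stars.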
            

In [None]:
plt.close('all')

# figure size in inches (width, height)
w, h = [15, 9]

fig1, [[ax1a, ax1b, ax1c],[ax1d, ax1e, ax1f],[ax1g, ax1h, ax1i]] = plt.subplots(3,3, figsize=(w, h),)
# add more space between subplots
fig1.subplots_adjust(hspace=0.5, wspace=0.5)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 1A: Slice, polygon projection, and recording electrodes
ax1a.text(-0.1, 1.1, 'A', transform=ax1a.transAxes, size=20, weight='bold')
image1_path = all_cells.project_path_root / r"Lab\Projects\EI_Dynamics\Analysis\paper_figure_matter\slice_electrode_expression_cropped_with_scalebar_blue_polygon.png"
im1 = Image.open(image1_path)
# get the size of the image
im1_width, im1_height = im1.size
# get the aspect ratio of the ax1a axes object
ax1a_ratio = ax1a.get_window_extent().width / ax1a.get_window_extent().height

# change the axis limits so that the vertical side of the image fits the axis and the horizontal side is aligned left on the axis
ax1a.set_ylim(0, im1_height)
ax1a.set_xlim(0, im1_height*ax1a_ratio)
# plot the image
ax1a.imshow(im1, extent=[0, im1_width, 0, im1_height], aspect=1)
ax1a.axis('off')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 1B: Grid Pattern in Space overlaid on CA3 slice at 40x
ax1b.text(-0.1, 1.1, 'B', transform=ax1b.transAxes, size=20, weight='bold')
image2_path = all_cells.project_path_root / r'Lab\Projects\EI_Dynamics\Analysis\paper_figure_matter\CA3-polygonFrame_figure_with_cellboundaries.png'
im2 = Image.open(image2_path)
im2_width, im2_height = im2.size
# ax1b.imshow(im2)

ax1b.set_ylim(0, im1_height)
ax1b.set_xlim(0, im1_height*ax1a_ratio)
ax1b.imshow(im2, extent=[0, im1_width, 0, im1_height], aspect=1)
ax1b.axis('off')
# ax1b.set_anchor('W')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 1C: Protocol Freq Sweep
sample_cell = 1941
ax1c.text(-0.1, 1.1, 'C', transform=ax1c.transAxes, size=20, weight='bold')
data = df[(df['cellID'] == sample_cell) & (df['numSq']==15) &  (df['stimFreq']==20)]
fig1, ax1c, _ = plot_tools.plot_data_from_df(data, data_start_column=49,combine=True, fig=fig1, ax=ax1c, )
# set legend inside the plot
# ax1c.legend(loc='upper right', bbox_to_anchor=(0.9, 0.9), frameon=True) # legend removed to save space
# simplify
plot_tools.simplify_axes(ax1c, splines_to_keep=[], )
#remove legend
ax1c.legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Fig 1D: First peak response | cell_fpr_max vs numSq

ax1d.text(-0.1, 1.1, 'D', transform=ax1d.transAxes, size=20, weight='bold')
# ax1d.set_title('First Peak Response')
df_temp = df2[ (df2['clampMode']=='CC') & (df2['location']=='CA1') & (df2['cell_fpr_max']<25) & (df2['cell_fpr_max']>0) & (df2['numSq']!=7) & (df2['numSq']<=15)] #& (df2['cell_fpr']<10)  & (df2['AP']==0) 

# sns.stripplot(data=df_temp,  x="numSq", y="cell_fpr_max", hue="numSq", palette=color_squares, ax=ax1d, alpha=0.8, s=2, jitter=0.25, orient="v", linewidth=0.25)
sns.violinplot(data=df_temp, x="numSq", y="cell_fpr_max", hue="numSq", palette=color_squares, ax=ax1d, alpha=0.5, inner='quart', split=False, dodge=False,zorder=3)
# make the violin outlines fully transparent
for part in ax1d.get_children():
    if isinstance(part, mpl.collections.PolyCollection):
        part.set_edgecolor((0, 0, 0, 0))

ax1d.set_ylabel('First Peak Response (mV)')
ax1d.set_xlabel('Number of Squares per Pattern')
ax1d.legend([],[], frameon=False)
# remove top and right spines
ax1d.spines['top'].set_visible(False)
ax1d.spines['right'].set_visible(False)

## Statistics for the violin plots: mannwhitneyu Rank Sum Test across numSq values
pairwise_annotate_violin_plot(ax1d, df_temp, x='numSq', y='cell_fpr_max', stat=stats.mannwhitneyu, add_line=True, offset=0.1, color='grey', coord_system='', fontsize=12, zorder=10)
            

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Fig 1E: Field FPR histogram, should be same style as 1D
ax1e.text(-0.1, 1.1, 'E', transform=ax1e.transAxes, size=20, weight='bold')
# violin plot of the peak-to-peak field first-peak response for each numSq
df_temp = df2[ (df2['fieldData']==True)  & (df2['numSq']!=7) & (df2['numSq']<=15) & (df2['fieldunit']=='mV')   ] #& (df2['field_fpr']<5) (df2['field_fpr_max']!= 0.0) & & (df2['field_fpr_max']<25)

# sns.stripplot(data=df_temp, x="numSq", y="field_fpr_p2p", hue="numSq", palette=color_squares, ax=ax1e, alpha=0.8, s=2, jitter=0.25, orient="v", linewidth=0.25)
sns.violinplot(data=df_temp, x="numSq", y="field_fpr_p2p", hue="numSq", palette=color_squares, ax=ax1e, alpha=0.5, inner='quart', split=False, dodge=False,zorder=3)
# make the violin outlines fully transparent
for part in ax1e.get_children():
    if isinstance(part, mpl.collections.PolyCollection):
        part.set_edgecolor((0, 0, 0, 0))

ax1e.set_ylabel('First Peak Field Response (mV)')
ax1e.set_xlabel('Number of Squares per Pattern')
# no legend
ax1e.legend([],[], frameon=False)

# remove top and right spines
ax1e.spines['top'].set_visible(False)
ax1e.spines['right'].set_visible(False)

# statistics
pairwise_annotate_violin_plot(ax1e, df_temp, x='numSq', y='field_fpr_p2p', stat=stats.mannwhitneyu, add_line=True, offset=0.1, color='grey', coord_system='', fontsize=12, zorder=10)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Fig 1F: relationship between cell_fpr and field_fpr
ax1f.text(-0.1, 1.1, 'F', transform=ax1f.transAxes, size=20, weight='bold')
df_temp = df2[(df2['clampMode']=='CC') & (df2['location']=='CA1') & (df2['fieldunit']=='mV') & (df2['cell_fpr_max']<20) & (df2['numSq']!=7) & (df2['numSq']<20)].copy()  # copy: normalized columns are added below
# other filters tried: (df2['cell_fpr_max']>0), (df2['cellID']==3131), (df2['cell_fpr']<10), (df2['field_fpr']<10), (df2['field_fpr_max']!=0)
numCells = len(np.unique(df_temp['cellID']))
# pick as many colors as there are cells
colors = sns.color_palette('Paired', numCells)

# plot a relationship between cell_fpr and field_fpr using a scatter plot, but normalize for each cell
# sns.scatterplot(df_temp, x="field_fpr", y="cell_fpr", hue='cellID', style='numSq', palette='Dark2', ax=ax1f, alpha=0.8, s=49, legend=True)
# for each cell normalize the field_fpr and cell_fpr before plotting
markerstyle = {1: 'P', 5:'*', 15: 'X'}
for i,cell in enumerate(np.unique(df_temp['cellID'])):
    for s in np.unique(df_temp['numSq']):
        df_temp_cell = df_temp[ (df_temp['cellID']==cell) & (df_temp['numSq']==s)]
        if len(df_temp_cell) > 0:  # Check if the dataframe is not empty
            field_fpr = df_temp_cell['field_fpr_p2p']
            cell_fpr = df_temp_cell['cell_fpr_max']
            field_fpr_norm = (field_fpr - np.min(field_fpr))/(np.max(field_fpr) - np.min(field_fpr))
            cell_fpr_norm = (cell_fpr - np.min(cell_fpr))/(np.max(cell_fpr) - np.min(cell_fpr))
            # save the norm values back in the dataframe
            df_temp.loc[ (df_temp['cellID']==cell) & (df_temp['numSq']==s), 'field_fpr_norm'] = field_fpr_norm
            df_temp.loc[ (df_temp['cellID']==cell) & (df_temp['numSq']==s), 'cell_fpr_norm'] = cell_fpr_norm
            # the normalized values are plotted below with sns.jointplot; ax1f shows the raw values
sns.scatterplot(data=df_temp, x='field_fpr_p2p', y='cell_fpr_max', s=25, hue='numSq', palette=color_squares, ax=ax1f, alpha=0.5)
# numSq (hue) legend placed outside the axes; it is hidden again below
ax1f.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)   

# note: sns.jointplot creates its own figure, so this normalized-value plot is not part of fig1
sns.jointplot(data=df_temp, x="field_fpr_norm", y="cell_fpr_norm", hue="numSq", palette=color_squares, alpha=0.8, s=20, legend=True)

ax1f.set_xlabel('Field First Peak Response (mV)')
ax1f.set_ylabel('Cell First Peak Response (mV)')
# legend off
ax1f.legend([],[], frameon=False)
# remove top and right spines
ax1f.spines['top'].set_visible(False)
ax1f.spines['right'].set_visible(False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Fig 1G: CA3 heatmap
# import heatmap data
df_CA3_heatmap = pd.read_hdf(r"parsed_data\CA3_recording_3161_grid_response_pivot.h5")

grid_aspect_ratio = pattern_index.polygon_frame_properties['aspect_ratio']

h = 8
w = grid_aspect_ratio*h
# ax1g.suptitle('Peak Response to 9 pulses in 24 hexagonal grid squares', fontsize=16)

pulse = 0
# get the response to this pulse (column) at every grid spot (rows)
vals = df_CA3_heatmap[pulse].values
# get index values
idx = (df_CA3_heatmap[pulse].index.get_level_values(0).values) - 1

# make heatmaps
plot_tools.plot_grid(spot_locs=idx, spot_values=vals, grid=[24,24], cmap='viridis', ax=ax1g, vmax=15)

# add text to the plot on top left corner of the heatmap in white color
ax1g.text(0.0, 0.9, 'Basal', transform=ax1g.transAxes, size=8,  color='white')
# add text to the plot on bottom left corner of the heatmap
ax1g.text(0.0, 0.05, 'Apical', transform=ax1g.transAxes, size=8, color='white')
ax1g.set_title('CA3 Response to grid spots')
ax1g.text(-0.1, 1.1, 'G', transform=ax1g.transAxes, size=20, weight='bold')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Fig 1H: CA3 pulse response
ax1h.text(-0.1, 1.1, 'H', transform=ax1h.transAxes, size=20, weight='bold')
df_temp = df2[ (df2['cellID']==3161) ]

# strip plot of the first-peak response for each numSq (other options: box, swarm, violin, boxen)
sns.stripplot(data=df_temp, x="cell_fpr_max",  y='numSq', s=5, alpha=0.8, jitter=True, orient="h", ax=ax1h, hue='numSq', palette=color_squares,  linewidth=0.5)
# set labels
ax1h.set_xlabel('First Peak Depolarization (mV)')
ax1h.set_ylabel('Num Squares')
# legend off
ax1h.legend([],[], frameon=False)
# remove top and right spines
ax1h.spines['top'].set_visible(False)
ax1h.spines['right'].set_visible(False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Fig 1I: CA3 cell_fpr vs field_fpr
ax1i.text(-0.1, 1.1, 'I', transform=ax1i.transAxes, size=20, weight='bold')

# add scatterplot on ax1i
sns.scatterplot(data=df_temp, x="cell_fpr_max", y="field_fpr_p2p", hue='numSq', palette=color_squares, ax=ax1i, alpha=0.8, s=20, legend=True)
ax1i.set_xlabel('First Peak Optic Current (pA)')
ax1i.set_ylabel('Field response (mV)')
# legend off
ax1i.legend([],[], frameon=False)
# remove top and right spines
ax1i.spines['top'].set_visible(False)
ax1i.spines['right'].set_visible(False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # save figure
figure_name = 'Figure1'
fig1.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig1.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
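
The Fig 1F loop min-max normalizes `field_fpr_p2p` and `cell_fpr_max` within each (cellID, numSq) group. The same can be done without an explicit loop using `groupby(...).transform`, sketched here on a toy table with the notebook's column names. Groups whose values are all identical come out as NaN, as they do in the loop. The helper name is hypothetical.

```python
import pandas as pd

def minmax_within_groups(df, cols, by):
    """Return a copy of df with <col>_norm columns scaled to [0, 1] within each group."""
    out = df.copy()
    grouped = out.groupby(by)
    for col in cols:
        lo = grouped[col].transform('min')
        hi = grouped[col].transform('max')
        out[f'{col}_norm'] = (out[col] - lo) / (hi - lo)
    return out

toy = pd.DataFrame({
    'cellID':        [1, 1, 1, 2, 2, 2],
    'numSq':         [5, 5, 5, 15, 15, 15],
    'field_fpr_p2p': [0.2, 0.4, 0.6, 1.0, 2.0, 3.0],
    'cell_fpr_max':  [2.0, 4.0, 3.0, 5.0, 10.0, 7.5],
})
norm = minmax_within_groups(toy, ['field_fpr_p2p', 'cell_fpr_max'], ['cellID', 'numSq'])
print(norm[['field_fpr_p2p_norm', 'cell_fpr_max_norm']])
```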

In [None]:
# Recreate fig panel ax1e using plotly
import plotly.express as px

# # Fig 1E: Field FPR histogram, should be same style as 1D
# figtemp, axtemp = plt.subplots(1,1, figsize=(w, h))
# peak-to-peak field first-peak response for 15-square patterns, one violin per cell
df_temp = df2[ (df2['fieldData']==True) & (df2['fieldunit']=='mV') & (df2['numSq']==15) ]

# plotly violin plot: cellID is a categorical variable, field_fpr_p2p is a continuous variable
fig = px.violin(df_temp, x="cellID", y="field_fpr_p2p", color="cellID", box=False, points="all", hover_data=df_temp.columns)
fig.update_xaxes(type='category')
fig.show()

In [None]:
figtemp, axtemp = plt.subplots()
axtemp.text(-0.1, 1.1, 'H', transform=axtemp.transAxes, size=20, weight='bold')
df_temp = df2[ (df2['cellID']==3161) ]

# strip plot of the first-peak response for each numSq (other options: box, swarm, violin, boxen)
sns.stripplot(data=df_temp, x="cell_fpr_max",  y='numSq', s=5, alpha=0.8, jitter=True, orient="h", ax=axtemp, hue='numSq', palette=color_squares,  linewidth=0.5)
# set labels
axtemp.set_xlabel('First Peak Depolarization (mV)')
axtemp.set_ylabel('Num Squares')
# legend off
axtemp.legend([],[], frameon=False)
# remove top and right spines
axtemp.spines['top'].set_visible(False)
axtemp.spines['right'].set_visible(False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# expected (linear-sum) response for each 5- and 15-square pattern, plotted between the category rows
for numSq, ypos in [(5, 0.5), (15, 1.5)]:
    cell3161_sq = cell3161new[cell3161new['numSq']==numSq]
    # for each row: patternList -> spot locations of the pattern -> single-spot response at each location
    for i in range(len(cell3161_sq)):
        patternlist = int(cell3161_sq.iloc[i]['patternList'])
        print(i, patternlist)
        spotlist = pattern_index.patternID[patternlist]
        # collect cell_fpr_max for each spot in the pattern from the single-spot (coordloc) rows
        cell_fpr_max = [cell3161new[cell3161new['coordloc']==spot]['cell_fpr_max'].values[0] for spot in spotlist]
        # expected response for this pattern = sum of its single-spot responses
        expected_cell_fpr_max = np.sum(cell_fpr_max)
        sns.scatterplot(x=[expected_cell_fpr_max], y=[ypos], color=color_squares[numSq], marker='o', s=50, ax=axtemp)

# plot the distribution on axtemp
# sns.stripplot(x=bootstrap_distribution,  y=0.5, s=5, alpha=0.5, jitter=True, orient="h", ax=axtemp, color=color_squares[5],  linewidth=0.5)
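
The code above computes, for each multi-spot pattern, the response expected if single-spot responses summed linearly. A compact sketch of that step follows. It assumes single-spot responses are available as a mapping from spot ID to peak response, and each pattern as a list of spot IDs. The dictionaries are hypothetical toy stand-ins for `pattern_index.patternID` and the per-spot `cell_fpr_max` values.

```python
import numpy as np

def expected_linear_sum(pattern_spots, spot_response):
    """Expected response of each pattern = sum of its single-spot responses."""
    return {pid: float(np.sum([spot_response[s] for s in spots]))
            for pid, spots in pattern_spots.items()}

spot_response = {1: 0.5, 2: 1.0, 3: 0.25, 4: 2.0}   # mV, hypothetical single-spot responses
pattern_spots = {10: [1, 2], 11: [2, 3, 4]}          # hypothetical pattern -> spot IDs
expected = expected_linear_sum(pattern_spots, spot_response)
observed = {10: 1.2, 11: 2.5}                         # mV, hypothetical pattern responses
# observed / expected < 1 indicates sublinear summation
print({pid: observed[pid] / expected[pid] for pid in expected})
```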
    
    

In [None]:
fig1B, ax1B = plt.subplots(1,2, figsize=(8,4), sharey=True)
# figure supertitle
fig1B.suptitle('First Peak Response vs PPR and STPR', fontsize=16)

df_temp = df2[ (df2['cell_fpr']<10) & (df2['cell_fpr']>0) & (df2['clampMode']=='CC') & (df2['field_fpr']!=0)] #& (df2['field_fpr']<10)

sns.scatterplot(data=df_temp, x="cell_fpr", y="cell_ppr", hue='stimFreq', style='cellID', palette=color_freq, ax=ax1B[0], alpha=0.5, s=40, legend=True, linewidth=0)
ax1B[0].set_xlabel('First Peak Depolarization (mV)')
ax1B[0].set_ylabel('PPR')
# legend off
ax1B[0].legend([],[], frameon=False)
# remove top and right spines
ax1B[0].spines['top'].set_visible(False)
ax1B[0].spines['right'].set_visible(False)

sns.scatterplot(data=df_temp, x="cell_fpr", y="cell_stpr", hue='stimFreq', style='cellID', palette=color_freq, ax=ax1B[1], alpha=0.5, s=40, legend=True, linewidth=0)
ax1B[1].set_xlabel('First Peak Depolarization (mV)')
ax1B[1].set_ylabel('STPR')
# show the legend on the right panel only
ax1B[1].legend()
# remove top and right spines
ax1B[1].spines['top'].set_visible(False)
ax1B[1].spines['right'].set_visible(False)

# save fig
figure_name = 'Figure1B'
fig1B.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig1B.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
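
For reference, the paired-pulse ratio (PPR) on these y-axes is conventionally the second response divided by the first. A sketch from per-pulse peak arrays follows; the `peaks` layout (sweeps by pulses) is an assumption. How `cell_ppr` and `cell_stpr` were computed upstream is not shown in this notebook.

```python
import numpy as np

def paired_pulse_ratio(peaks):
    """PPR per sweep = peak of pulse 2 / peak of pulse 1; peaks has shape (n_sweeps, n_pulses)."""
    peaks = np.asarray(peaks, dtype=float)
    return peaks[:, 1] / peaks[:, 0]

peaks = np.array([[2.0, 3.0, 3.5], [4.0, 3.0, 2.0]])  # mV, toy values
ppr = paired_pulse_ratio(peaks)
print(ppr)  # > 1: facilitation, < 1: depression
```

The dashed line at y = 1 in the cell 3131 panels marks the boundary between facilitation and depression in this convention.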

In [None]:
fig1D, ax1D = plt.subplots(1,2, figsize=(8,4), sharey=False)
# figure supertitle
fig1D.suptitle('First Peak Response vs PPR and STPR', fontsize=16)

df_temp = df2[ (df2['cellID']==3131) & (df2['cell_fpr']<10) & (df2['cell_fpr']>0) & (df2['clampMode']=='CC') & (df2['field_fpr']!=0)] #& (df2['field_fpr']<10)

sns.scatterplot(data=df_temp, x="cell_fpr", y="cell_ppr", hue='numSq', style='cellID', palette=color_squares, ax=ax1D[0], alpha=0.5, s=40, legend=True, linewidth=0)
ax1D[0].set_xlabel('First Peak Depolarization (mV)')
ax1D[0].set_ylabel('PPR')
# legend off
ax1D[0].legend([],[], frameon=False)
# remove top and right spines
ax1D[0].spines['top'].set_visible(False)
ax1D[0].spines['right'].set_visible(False)
# draw a line at y=1
ax1D[0].axhline(y=1, color='black', linestyle='--', linewidth=1)

sns.scatterplot(data=df_temp, x="cell_fpr", y="cell_stpr", hue='numSq', style='cellID', palette=color_squares, ax=ax1D[1], alpha=0.5, s=40, legend=True, linewidth=0)
ax1D[1].set_xlabel('First Peak Depolarization (mV)')
ax1D[1].set_ylabel('STPR')
# show the legend on the right panel only
ax1D[1].legend()
# remove top and right spines
ax1D[1].spines['top'].set_visible(False)
ax1D[1].spines['right'].set_visible(False)
ax1D[1].axhline(y=1, color='black', linestyle='--', linewidth=1)

# save fig
figure_name = 'Figure1D_cell3131'
fig1D.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig1D.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')

In [None]:
fig1C, ax1C = plt.subplots(1,2, figsize=(8,4), sharey=True)
# figure supertitle
fig1C.suptitle('First Peak Response vs PPR and STPR', fontsize=16)

df_temp = df2[ (df2['cell_fpr']<10) & (df2['cell_fpr']>0) & (df2['clampMode']=='CC') & (df2['field_fpr']!=0)] #& (df2['field_fpr']<10)

sns.scatterplot(data=df_temp, x="cell_fpr", y="cell_ppr", hue='numSq', palette=color_squares, ax=ax1C[0], alpha=0.5, s=40, legend=True, linewidth=0)
ax1C[0].set_xlabel('First Peak Depolarization (mV)')
ax1C[0].set_ylabel('PPR')
# legend off
ax1C[0].legend([],[], frameon=False)
# remove top and right spines
ax1C[0].spines['top'].set_visible(False)
ax1C[0].spines['right'].set_visible(False)

sns.scatterplot(data=df_temp, x="cell_fpr", y="cell_stpr", hue='numSq', palette=color_squares, ax=ax1C[1], alpha=0.5, s=40, legend=True, linewidth=0)
ax1C[1].set_xlabel('First Peak Depolarization (mV)')
ax1C[1].set_ylabel('STPR')
# show the legend on the right panel only
ax1C[1].legend()
# remove top and right spines
ax1C[1].spines['top'].set_visible(False)
ax1C[1].spines['right'].set_visible(False)

# save fig
figure_name = 'Figure1C'
fig1C.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig1C.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')

### Figure 2: CA1 and CA3 responses to grid spots
* Deconv to extract STP from trace: schematic
* Comparison between deconv fit, P2P, and peak fitting
* Fitting of reference trace
* Example deconv and fit thereof.
* STP traces under different conditions: E/I, #sq and frequency
* Diversity in STP traces.

<img src="notes_figures/Figure_outline_Fig2.jpg" width="500"/>
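
The next cell slices one sweep row of the long FreqSweep dataframe by position. There are 49 metadata columns, then the cell channel in columns 49 to 20048 and the stimulus channel in columns 40049 to 60048, i.e. 20000 samples (1 s at 20 kHz) per channel. That layout is sketched below as a helper. It assumes every channel has the same length and channels are stored back to back. The helper name and the toy row are hypothetical.

```python
import numpy as np
import pandas as pd

def split_sweep(row, n_meta=49, n_samples=20000):
    """Split one wide sweep row into (metadata Series, 2-D array of shape (n_channels, n_samples))."""
    meta = row.iloc[:n_meta]
    data = row.iloc[n_meta:].to_numpy(dtype=float)
    n_channels = data.size // n_samples
    # drop any trailing columns that do not fill a whole channel
    return meta, data[: n_channels * n_samples].reshape(n_channels, n_samples)

# toy row: 2 metadata fields followed by 3 channels of 4 samples each
row = pd.Series([1941, 20] + list(range(12)), index=['cellID', 'stimFreq'] + list(range(12)))
meta, channels = split_sweep(row, n_meta=2, n_samples=4)
print(channels)
```

With the real defaults, `channels[0]` would correspond to `sweep.iloc[49:20049]` and `channels[2]` to `sweep.iloc[40049:60049]`.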

In [None]:
# Load the dataset
freq_sweep_cc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_CC_long.h5" 
df = pd.read_hdf(freq_sweep_cc_datapath, key='data')

# figure size in inches (width, height)
w, h = [15, 9]
fig2, [[ax2a, ax2b, ax2c],[ax2d_top, ax2e_top, ax2f_top],[ax2d_bottom, ax2e_bottom, ax2f_bottom]] = plt.subplots(3,3, figsize=(w, h), sharey=False)
fig2.subplots_adjust(hspace=0.5, wspace=0.5)
plot_kind = 'violin' # 'line' or 'violin' or 'strip'

column_name_abbreviations = ['pc','pcn','ac','sc','dc','pf','pfn']
cell =  1941
sq   =  15

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2A,B,C: Upi's deconvolution fits
ax2a.text(-0.1, 1.1, 'A', transform=ax2a.transAxes, size=20, weight='bold')
ax2b.text(-0.1, 1.1, 'B', transform=ax2b.transAxes, size=20, weight='bold')
ax2c.text(-0.1, 1.1, 'C', transform=ax2c.transAxes, size=20, weight='bold')

sweepnum = 0
df_temp = df[(df['cellID'] == cell) & (df['numSq'] == sq)]
sweep = df_temp.iloc[sweepnum, :]
# positional slices of the wide row: cell trace and stimulus trace (20000 samples each)
tracecell, tracestim, stimfreq = sweep.iloc[49:20049], sweep.iloc[40049:60049], sweep['stimFreq']

ax2a.plot(np.linspace(0, 1, 20000),       tracecell, label='recording') # Raw data
ax2a.plot(np.linspace(0, 1, 20000), 100 * tracestim, label='Stim')      # Stimulus

# find the expected response, add to the axes
fits, _, _ = find_sweep_expected(tracecell, stimfreq, fig2, [ax2a, ax2b, ax2c])

# simplify axes
plot_tools.simplify_axes(ax2a, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=[0, 0.2, 0.4, 0.6, 0.8, 1.0], xtick_labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], ytick_locs=[-150, -100, -50, 0, 50], ytick_labels=[-150, -100, -50, 0, 50],)
ax2a.legend(bbox_to_anchor=(0, 1), loc='upper left', borderaxespad=0., frameon=False)
ax2a.set_ylabel('membrane potential (mV)')

plot_tools.simplify_axes(ax2b, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=[0, 0.2, 0.4, 0.6, 0.8, 1.0], xtick_labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], ytick_locs=[-50, 0, 50], ytick_labels=[-50, 0, 50],)
ax2b.legend([], frameon=False)
ax2b.set_ylabel('Residual (mV)')

plot_tools.simplify_axes( ax2c, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1.0, 2.0], ytick_labels=[0, 1.0, 2.0],)
ax2c.legend([], frameon=False)
ax2c.set_ylabel('Normalized Response')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2D: STP vs numSq
ax2d_top.text(-0.1, 1.1, 'D', transform=ax2d_top.transAxes, size=20, weight='bold')

cp = -70        # clamping potential subset for the plot
f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'CC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['spikingFlag'] == False)] 
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2d_top, df_melt, x='pulseIndex', y='peak_response', hue='numSq', draw=True, kind=plot_kind, palette=color_squares, 
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2d_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2d_top.set_xlabel('Pulse Index')
ax2d_top.set_ylabel('Norm. Response')

# get the legend labels of ax2d_top and append ' Sq.' to each one
handles, labels = ax2d_top.get_legend_handles_labels()
labels = [label + ' Sq.' for label in labels]
ax2d_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)
# add a text in the top left corner of ax2d_top showing '20 Hz' in color = color_freq[20]
ax2d_top.text(0.0, 0.9, f'{f} Hz', transform=ax2d_top.transAxes, size=12, color=color_freq[f], zorder=10)

### Fig 2D_Bottom: STP vs numSq
# field plot in the bottom subplot
to_plot = [f'pfn{i}' for i in range(9)]

df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'CC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['spikingFlag'] == False)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2d_bottom, df_melt, x='pulseIndex', y='peak_response', hue='numSq', draw=True, kind=plot_kind, palette=color_squares,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2d_bottom, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)

ax2d_bottom.set_xlabel('Pulse Index')
ax2d_bottom.set_ylabel('Norm. Response')
ax2d_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2E: STP vs Frequency
ax2e_top.text(-0.1, 1.1, 'E', transform=ax2e_top.transAxes, size=20, weight='bold')

s = 5  # squares
cp = -70
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'CC') & (df3['cellID'] == cell) & (df3['numSq'] == s) & (df3['spikingFlag'] == False)]
df_melt = pd.melt( df_temp[df_temp['stimFreq'] < 100], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2e_top, df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='violin', palette=color_freq, stat_across='hue',
                                        stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2e_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2e_top.set_xlabel('Pulse Index')
ax2e_top.set_ylabel('Norm. Response')

# get the legend labels of ax2e_top and append ' Hz' to each one
handles, labels = ax2e_top.get_legend_handles_labels()
labels = [label + ' Hz' for label in labels]
ax2e_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)

# add a text in the top left corner of ax2e_top showing '5 Sq' in color = color_squares[5]
ax2e_top.text(0.0, 0.9, f'{s} Sq', transform=ax2e_top.transAxes, size=12, color=color_squares[s], zorder=10)

### 2E Bottom: Field plot in the bottom
to_plot = [f'pfn{i}' for i in range(9)]

df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'CC') & (df3['cellID'] == cell) & (df3['numSq'] == s) & (df3['spikingFlag'] == False)]
df_melt = pd.melt( df_temp[df_temp['stimFreq'] < 100], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2e_bottom, df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='violin', palette=color_freq,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2e_bottom, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2e_bottom.set_xlabel('Pulse Index')
ax2e_bottom.set_ylabel('Norm. Response')
ax2e_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2F: STP vs E/I
ax2f_top.text(-0.1, 1.1, 'F', transform=ax2f_top.transAxes, size=20, weight='bold')

f = 20  # Hz
s = 5  # squares
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['numSq'] == s)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2f_top, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', draw=True, kind='violin', palette=color_EI,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)

plot_tools.simplify_axes( ax2f_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2f_top.set_xlabel('Pulse Index')
ax2f_top.set_ylabel('Norm. Response')

# get the legend labels of ax2f_top and append ' mV' to each one
handles, labels = ax2f_top.get_legend_handles_labels()
labels = [label + ' mV' for label in labels]
ax2f_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)

# add text in the top left corner of ax2f_top showing '20 Hz' and '5 Sq' in their respective colors
ax2f_top.text(0.0, 0.9, f'{f} Hz', transform=ax2f_top.transAxes, size=12, color=color_freq[f], zorder=10)
ax2f_top.text(0.2, 0.9, f'{s} Sq', transform=ax2f_top.transAxes, size=12, color=color_squares[s], zorder=10)


### Fig 2F Bottom: Field plot in the bottom
to_plot = [f'pfn{i}' for i in range(9)]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['numSq'] == s)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2f_bottom, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', draw=True, kind='violin', palette=color_EI,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10)
plot_tools.simplify_axes(ax2f_bottom, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0,1,2], ytick_labels=[0,1,2])
ax2f_bottom.set_xlabel('Pulse Index')
ax2f_bottom.set_ylabel('Norm. Response')
ax2f_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# save figure
figure_name = 'Figure2_CC'
fig2.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig2.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
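Panels D to F all reshape a wide table (one column per pulse, e.g. `pcn0` ... `pcn8`) into long form with `pd.melt`, so that seaborn can place the pulse index on the x-axis and split by `numSq`, `stimFreq` or `clampPotential`. A toy example with made-up values and only three pulses shows the reshaping:

```python
import pandas as pd

# toy wide table: one row per sweep, one column per pulse (values are made up)
wide = pd.DataFrame({
    "cellID": [1, 1],
    "numSq": [5, 15],
    "pcn0": [1.0, 1.0],
    "pcn1": [0.8, 1.2],
    "pcn2": [0.7, 1.4],
})

long = pd.melt(
    wide,
    id_vars=["cellID", "numSq"],
    value_vars=["pcn0", "pcn1", "pcn2"],
    var_name="pulseIndex",
    value_name="peak_response",
)
# one row per (sweep, pulse): 2 sweeps x 3 pulses = 6 rows, stacked column by column
print(long)
```

The `pulseIndex` values are the original column names (strings such as `'pcn1'`), which is why the panels relabel the x-ticks with `freq_sweep_pulses`.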

### Figure 2 with voltage clamp

In [None]:
# Load the dataset 
# VC cells with field data: 111, 1491, 1522, 1523, 1524, 1531, 1541, 1621, 1931, 2941, 5501, 5502, 6201, 6301
freq_sweep_vc_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_VC_long.h5" 
df = pd.read_hdf(freq_sweep_vc_datapath, key='data')


# figure size (width, height) in inches
w, h = [15, 9]
fig2, [[ax2a, ax2b, ax2c],[ax2d_top, ax2e_top, ax2f_top],[ax2d_bottom, ax2e_bottom, ax2f_bottom]] = plt.subplots(3,3, figsize=(w, h), sharey=False)
fig2.subplots_adjust(hspace=0.5, wspace=0.5)
plot_kind = 'violin' # 'line' or 'violin' or 'strip'

column_name_abbreviations = ['pc','pcn','ac','sc','dc','pf','pfn']
cell =  1491
sq   =  15

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2A,B,C: Upi's deconvolution fits
ax2a.text(-0.1, 1.1, 'A', transform=ax2a.transAxes, size=20, weight='bold')
ax2b.text(-0.1, 1.1, 'B', transform=ax2b.transAxes, size=20, weight='bold')
ax2c.text(-0.1, 1.1, 'C', transform=ax2c.transAxes, size=20, weight='bold')

sweepnum = 0
df_temp = df[(df['cellID'] == cell) & (df['numSq'] == sq)]
sweep = df_temp.iloc[sweepnum, :]
# positional slices of the wide row: cell trace and stimulus trace (20000 samples each)
tracecell, tracestim, stimfreq = sweep.iloc[49:20049], sweep.iloc[40049:60049], sweep['stimFreq']

ax2a.plot(np.linspace(0, 1, 20000),       tracecell, label='recording') # Raw data
ax2a.plot(np.linspace(0, 1, 20000), 100 * tracestim, label='Stim')      # Stimulus

# find the expected response, add to the axes
fits, _, _ = find_sweep_expected(tracecell, stimfreq, fig2, [ax2a, ax2b, ax2c])

# simplify axes
plot_tools.simplify_axes(ax2a, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=[0, 0.2, 0.4, 0.6, 0.8, 1.0], xtick_labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], ytick_locs=[-150, -100, -50, 0, 50], ytick_labels=[-150, -100, -50, 0, 50],)
ax2a.legend(bbox_to_anchor=(0, 1), loc='upper left', borderaxespad=0., frameon=False)
ax2a.set_ylabel('Membrane current (pA)')

plot_tools.simplify_axes(ax2b, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=[0, 0.2, 0.4, 0.6, 0.8, 1.0], xtick_labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], ytick_locs=[-50, 0, 50], ytick_labels=[-50, 0, 50],)
ax2b.legend([], frameon=False)
ax2b.set_ylabel('Residual (pA)')

plot_tools.simplify_axes( ax2c, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1.0, 2.0], ytick_labels=[0, 1.0, 2.0],)
ax2c.legend([], frameon=False)
ax2c.set_ylabel('Normalized Response')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2D: STP vs numSq
ax2d_top.text(-0.1, 1.1, 'D', transform=ax2d_top.transAxes, size=20, weight='bold')

cp = -70        # clamping potential subset for the plot
f = 20          # stimFreq, Hz
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['clampPotential'] == cp) & (df3['spikingFlag'] == False)] 
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2d_top, df_melt, x='pulseIndex', y='peak_response', hue='numSq', draw=True, kind=plot_kind, palette=color_squares, 
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2d_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2d_top.set_xlabel('Pulse Index')
ax2d_top.set_ylabel('Norm. Response')

# get the legend labels of ax2d_top and append ' Sq.' to each one
handles, labels = ax2d_top.get_legend_handles_labels()
labels = [label + ' Sq.' for label in labels]
ax2d_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)
# add a text in the top left corner of ax2d_top showing '20 Hz' in color = color_freq[20]
ax2d_top.text(0.0, 0.9, f'{f} Hz', transform=ax2d_top.transAxes, size=12, color=color_freq[f], zorder=10)
ax2d_top.text(0.2, 0.9, f'{cp} mV', transform=ax2d_top.transAxes, size=12, color=color_EI[cp], zorder=10)

### Fig 2D_Bottom: STP vs numSq
# field plot in the bottom subplot
to_plot = [f'pfn{i}' for i in range(9)]

df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['clampPotential'] == cp) & (df3['spikingFlag'] == False)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2d_bottom, df_melt, x='pulseIndex', y='peak_response', hue='numSq', draw=True, kind=plot_kind, palette=color_squares,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2d_bottom, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)

ax2d_bottom.set_xlabel('Pulse Index')
ax2d_bottom.set_ylabel('Norm. Response')
ax2d_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2E: STP vs Frequency
ax2e_top.text(-0.1, 1.1, 'E', transform=ax2e_top.transAxes, size=20, weight='bold')

s = 5  # squares
cp = -70
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['clampPotential'] == cp) & (df3['numSq'] == s) & (df3['spikingFlag'] == False)]
df_melt = pd.melt( df_temp[df_temp['stimFreq'] < 100], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2e_top, df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='violin', palette=color_freq, stat_across='hue',
                                        stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2e_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2e_top.set_xlabel('Pulse Index')
ax2e_top.set_ylabel('Norm. Response')

# get the legend labels of ax2e_top and append ' Hz' to each one
handles, labels = ax2e_top.get_legend_handles_labels()
labels = [label + ' Hz' for label in labels]
ax2e_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)

# add text in the top left corner of ax2e_top showing '5 Sq' and the clamp potential in their respective colors
ax2e_top.text(0.0, 0.9, f'{s} Sq', transform=ax2e_top.transAxes, size=12, color=color_squares[s], zorder=10)
ax2e_top.text(0.2, 0.9, f'{cp} mV', transform=ax2e_top.transAxes, size=12, color=color_EI[cp], zorder=10)


### 2E Bottom: Field plot in the bottom
to_plot = [f'pfn{i}' for i in range(9)]

df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['clampPotential'] == cp) & (df3['numSq'] == s) & (df3['spikingFlag'] == False)]
df_melt = pd.melt( df_temp[df_temp['stimFreq'] < 100], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2e_bottom, df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', draw=True, kind='violin', palette=color_freq,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)
plot_tools.simplify_axes( ax2e_bottom, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2e_bottom.set_xlabel('Pulse Index')
ax2e_bottom.set_ylabel('Norm. Response')
ax2e_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 2F: STP vs E/I
ax2f_top.text(-0.1, 1.1, 'F', transform=ax2f_top.transAxes, size=20, weight='bold')

f = 20  # Hz
s = 5  # squares
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['numSq'] == s)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2f_top, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', draw=True, kind='violin', palette=color_EI,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10,)

plot_tools.simplify_axes( ax2f_top, splines_to_keep=['bottom', 'left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0, 1, 2], ytick_labels=[0, 1, 2],)
ax2f_top.set_xlabel('Pulse Index')
ax2f_top.set_ylabel('Norm. Response')

# get the legend labels of ax2f_top and append ' mV' to each one
handles, labels = ax2f_top.get_legend_handles_labels()
labels = [label + ' mV' for label in labels]
ax2f_top.legend(handles, labels, loc='upper right', borderaxespad=0., frameon=False)

# add text in the top left corner of ax2f_top showing '20 Hz' and '5 Sq' in their respective colors
ax2f_top.text(0.0, 0.9, f'{f} Hz', transform=ax2f_top.transAxes, size=12, color=color_freq[f], zorder=10)
ax2f_top.text(0.2, 0.9, f'{s} Sq', transform=ax2f_top.transAxes, size=12, color=color_squares[s], zorder=10)


### Fig 2F Bottom: Field plot in the bottom
to_plot = [f'pfn{i}' for i in range(9)]
df_temp = df3[ (df3['location'] == 'CA1') & (df3['clampMode'] == 'VC') & (df3['cellID'] == cell) & (df3['stimFreq'] == f) & (df3['numSq'] == s)]
df_melt = pd.melt( df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response',)

pairwise_draw_and_annotate_line_plot(   ax2f_bottom, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', draw=True, kind='violin', palette=color_EI,
                                        stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10)
plot_tools.simplify_axes(ax2f_bottom, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0,1,2], ytick_labels=[0,1,2])
ax2f_bottom.set_xlabel('Pulse Index')
ax2f_bottom.set_ylabel('Norm. Response')
ax2f_bottom.legend([], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# save figure
figure_name = 'Figure2_VC'
fig2.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig2.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
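The legend-relabelling pattern (fetch handles and labels, append a unit, redraw the legend) is repeated in several panels of both Figure 2 cells. A small helper could factor it out; `add_legend_suffix` is a hypothetical name and not part of `plot_tools`:

```python
import matplotlib.pyplot as plt

def add_legend_suffix(ax, suffix, **legend_kwargs):
    """Hypothetical helper: append `suffix` to every legend label on `ax` and redraw the legend."""
    handles, labels = ax.get_legend_handles_labels()
    return ax.legend(handles, [f"{label}{suffix}" for label in labels], **legend_kwargs)

# usage on a throwaway figure; seaborn hue labels arrive as strings such as '5' and '15'
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], label="5")
ax.plot([0, 1], [1, 0], label="15")
legend = add_legend_suffix(ax, " Sq.", loc="upper right", borderaxespad=0., frameon=False)
plt.close(fig)
```

In the cells above, the three-line blocks would then become, for example, `add_legend_suffix(ax2e_top, ' Hz', loc='upper right', borderaxespad=0., frameon=False)`.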

In [None]:
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['cellID']==cell) & (df3['stimFreq']==f) & (df3['numSq']==s) ] # 
df_melt = pd.melt(df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
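The statistic that `pairwise_draw_and_annotate_line_plot` annotates is a Kruskal-Wallis H-test across all hue groups at each pulse index (a single omnibus test, not pairwise post-hoc comparisons). Stripped of plotting, the computation on a long-form table such as `df_melt` reduces to the sketch below; the helper name `kruskal_per_x` and the toy data are hypothetical:

```python
import numpy as np
import pandas as pd
from scipy import stats

def kruskal_per_x(df_long, x, y, hue, skip_first=True):
    """Kruskal-Wallis p-value across `hue` groups at each value of `x`."""
    pvals = {}
    for ix, x_val in enumerate(df_long[x].unique()):
        if skip_first and ix == 0:
            # e.g. if responses are normalized to the first pulse, every group equals 1 there
            continue
        groups = df_long[df_long[x] == x_val].groupby(hue)[y].apply(list).tolist()
        if len(groups) < 2:
            continue  # the test needs at least two groups
        pvals[x_val] = stats.kruskal(*groups).pvalue
    return pvals

# toy data: E (-70 mV) and I (0 mV) groups, 10 sweeps each, that diverge only at the third pulse
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "pulseIndex": np.repeat(["pcn0", "pcn1", "pcn2"], 20),
    "clampPotential": np.tile(np.repeat([-70, 0], 10), 3),
    "peak_response": rng.normal(1.0, 0.1, 60),
})
toy.loc[(toy["pulseIndex"] == "pcn2") & (toy["clampPotential"] == 0), "peak_response"] += 1.0
pvals = kruskal_per_x(toy, "pulseIndex", "peak_response", "clampPotential")
print(pvals)
```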

In [None]:
def pairwise_draw_and_annotate_line_plot(ax, df, x='', y='', hue='', draw=True, kind='violin', palette='viridis', stat_across='hue', stat=stats.kruskal, skip_first_xvalue=True, annotate_wrt_data=False, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10):
    ''' Optionally draw `df` on `ax`, then, for each x value, test whether the `hue` groups differ
    using `stat` (default: Kruskal-Wallis H-test across all hue groups, not pairwise post-hoc tests).
    The p-value is annotated above each x position as significance stars.
    Note: `stat_across` and `offset_btw_star_n_line` are currently unused.'''

    if draw:
        # draw the plots from the dataframe passed in (not the global df_melt)
        if kind == 'violin':
            sns.violinplot(data=df, x=x, y=y, hue=hue, palette=palette, ax=ax, alpha=0.8, split=True, inner='quartile', linewidth=1)
        elif kind == 'strip':
            sns.stripplot(data=df, x=x, y=y, hue=hue, palette=palette, ax=ax, alpha=0.8, dodge=True,)
        elif kind == 'line':
            sns.lineplot(data=df, x=x, y=y, hue=hue, palette=palette, ax=ax, alpha=0.5, errorbar=('sd', 1), err_style='bars', linewidth=3, err_kws={"elinewidth": 3, 'capsize':5})

    x_values = df[x].unique() # x-axis categorical value labels, in order of appearance

    # x-tick positions drawn by seaborn, one per categorical x value
    xticks = ax.get_xticks()

    # default annotation height: 90% of the upper y-limit
    ypos = 0.9*ax.get_ylim()[1]

    # for each x value, compare the y values of all hue groups
    for ix, x_val in enumerate(x_values):
        if skip_first_xvalue and ix == 0:
            continue

        # one list of y values per hue group
        group_data = df[df[x]==x_val].groupby(hue)[y].apply(list).values.tolist()
        if len(group_data) < 2:
            continue  # the test needs at least two groups
        kruskal_statistic, kruskal_pval = stat(*group_data)

        # location of x_val on the x-axis of ax
        xpos = xticks[ix]

        # place the annotation above the largest value at this x across all groups (groups may differ in size)
        if annotate_wrt_data:
            ypos = 1.1 * max(max(g) for g in group_data)

        # convert (xpos, ypos) to axes coordinates as a pair, so neither is overwritten before it is used
        star_x, star_y = xpos, ypos
        if coord_system=='axes':
            star_x, star_y = ax.transAxes.inverted().transform(ax.transData.transform([xpos, ypos]))

        annotate_stat_stars(ax, kruskal_pval, star_loc=[star_x, star_y], add_line=False, color=color, coord_system=coord_system, fontsize=fontsize, zorder=zorder)

        # print(ix, x_val, kruskal_statistic, kruskal_pval, xpos, ypos)
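When `coord_system == 'axes'`, a star position given in data coordinates has to be mapped to axes-fraction coordinates. Matplotlib can do this in one step by composing `ax.transData` with the inverse of `ax.transAxes`; transforming the (x, y) pair together also avoids overwriting `xpos` before `ypos` is computed. A minimal check, with axis limits chosen for illustration:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(-0.5, 8.5)   # e.g. nine categorical pulse positions
ax.set_ylim(0, 2)

# data -> display -> axes fraction, applied to the (x, y) pair in one call
data_to_axes = ax.transData + ax.transAxes.inverted()
x_axes, y_axes = data_to_axes.transform((4, 1.8))
plt.close(fig)
```

Here pulse position 4 sits in the middle of the x-range and y = 1.8 is 90% of the way up, so the result is (0.5, 0.9).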
    



In [None]:
figtemp, ax2ftemp = plt.subplots(1,1, figsize=(6,4), sharey=False)
ax2ftemp.text(-0.1, 1.1, 'F', transform=ax2ftemp.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in freq_sweep_pulses ]
f = 20 #Hz
s = 5 # squares
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['cellID']==cell) & (df3['stimFreq']==f) & (df3['numSq']==s) ] # 
# df_temp = df_temp[ df_temp['cellID'].isin(vc_cells_screened) ]

# melt the dataframe
df_melt = pd.melt(df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
# sns.stripplot(data=df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', palette=color_EI, ax=ax2ftemp, alpha=0.8, dodge=1,)
# sns.violinplot(data=df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', palette=color_EI, ax=ax2ftemp, alpha=0.8, split=True, inner='quartile', linewidth=1)
# sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', palette=color_EI, ax=ax2ftemp, alpha=0.5, errorbar=('sd', 1), err_style='bars', linewidth=3,err_kws={"elinewidth": 3, 'capsize':5})

plot_tools.simplify_axes(ax2ftemp, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=[0,1,2], ytick_labels=[0,1,2])
ax2ftemp.set_xlabel('Pulse Index')
ax2ftemp.set_ylabel('Norm. Response')
ax2ftemp.set_xlim([-0.5,8.5])

pairwise_annotate_line_plot(ax2ftemp, df_melt, x='pulseIndex', y='peak_response', hue='clampPotential', stat_across='hue', stat=stats.kruskal, offset_btw_star_n_line=0.1, color='grey', coord_system='data', fontsize=12, zorder=10)

# reset the legend to best position
ax2ftemp.legend(loc='lower left', borderaxespad=0., frameon=False)

# filename = 'Figure2_trendplots_templatetest_lineplot'
# figtemp.savefig(paper_figure_export_location / (filename + '.png'), dpi=300, bbox_inches='tight')
# figtemp.savefig(paper_figure_export_location / (filename + '.svg'), dpi=300, bbox_inches='tight')

In [None]:
# close all figures
plt.close('all')

In [None]:
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='CC')  & (df3['stimFreq']==f) & (df3['spikingFlag']==False) & (np.abs(df3['pcn3'])<10)& (df3['numSq']==1)]
df_temp.shape

In [None]:
# Fig 2E (supplementary): STP vs frequency, one panel per cell
# define the columns to melt here instead of relying on a value left over from a previous cell
to_plot = [f'pcn{i}' for i in freq_sweep_pulses]

# melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='CC') & (df3['condition']=='Control') & (df3['spikingFlag']==False) & (df3['numSq']>1) & (df3['stimFreq']<100)]
df_melt = pd.melt(df_temp, id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')

figx = sns.relplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', style='numSq', col='cellID', col_wrap=3, palette=color_freq, kind='line', alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1), )
for ax in figx.axes.flat:
    ax.set(xticks=range(9), xticklabels=freq_sweep_pulses, ylim=(-2, 10), yticks=range(-2,2), yticklabels=range(-2,2))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Peak Response (mV)')
    ax.legend([],[], frameon=False)

# save fig
figure_name = 'Figure2E_supp'
figx.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
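The `relplot` above draws mean ± 1 standard error per pulse (`errorbar=("se", 1)`). To check the plotted numbers, the same summary can be computed directly with pandas; the toy values below are made up:

```python
import pandas as pd

# toy long-form table: two frequencies, two pulses, three sweeps each (values are made up)
df_long = pd.DataFrame({
    "stimFreq": [20] * 6 + [50] * 6,
    "pulseIndex": ["pcn0", "pcn0", "pcn0", "pcn1", "pcn1", "pcn1"] * 2,
    "peak_response": [1.0, 1.0, 1.0, 0.8, 1.0, 1.2,
                      1.0, 1.0, 1.0, 0.5, 0.6, 0.7],
})

# mean, standard error of the mean (pandas default ddof=1) and n per frequency and pulse
summary = (
    df_long.groupby(["stimFreq", "pulseIndex"])["peak_response"]
    .agg(["mean", "sem", "count"])
)
print(summary)
```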

### Figure 3:
* Model schematic for plasticity
* Fitting free params of model for E and I
* Fitting diversity for E and I
* Fitting stochasticity for E and I
* How averaged are E and I respectively? Crucial for escape calculation.

In [None]:
w, h = [15, 9]

fig3, [[ax3a, ax3b, ax3c],[ax3d, ax3e, ax3f]] = plt.subplots(2,3, figsize=(w, h), sharey=False)
# have more space between subplots
fig3.subplots_adjust(hspace=0.5, wspace=0.5)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 3A: Model schematic for plasticity
ax3a.text(-0.1, 1.1, 'A', transform=ax3a.transAxes, size=20, weight='bold')
image3_path = all_cells.project_path_root / r"Lab\Projects\EI_Dynamics\Analysis\paper_figure_matter\model_schematic.png"
im3 = Image.open(image3_path)
# get the size of the image
# im3_width, im3_height = im3.size
# get the aspect ratio of the ax3a axes object
# ax3a_ratio = ax3a.get_window_extent().width / ax3a.get_window_extent().height
# change the axis limits so that the vertical side of the image fits the axis and the horizontal side is aligned left on the axis
# ax3a.set_ylim(0, im3_height)
# ax3a.set_xlim(0, im3_height*ax3a_ratio)
# plot the image
ax3a.imshow(im3)#, extent=[0, im3_width, 0, im3_height], aspect=1)
ax3a.axis('off')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 3B: Model schematic for plasticity
# make the ax3b aspect ratio same as ax3a
ax3b.set_aspect(ax3a.get_aspect())

ax3b.text(-0.1, 1.1, 'B', transform=ax3b.transAxes, size=20, weight='bold')
image3_path = all_cells.project_path_root / r"Lab\Projects\EI_Dynamics\Analysis\paper_figure_matter\model73.png"
im3 = Image.open(image3_path)
# get the size of the image
# im3_width, im3_height = im3.size
# get the aspect ratio of the ax3b axes object
# ax3b_ratio = ax3b.get_window_extent().width / ax3b.get_window_extent().height
# change the axis limits so that the vertical side of the image fits the axis and the horizontal side is aligned left on the axis
# ax3b.set_ylim(0, im3_height)
# ax3b.set_xlim(0, im3_height*ax3b_ratio)
# plot the image
ax3b.imshow(im3)#, extent=[0, im3_width, 0, im3_height], aspect=1)
ax3b.axis('off')


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # save figure
figure_name = 'Figure3'
fig3.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig3.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
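The commented-out lines in panels A and B aim to scale each schematic so that its height fills the axes and its left edge sits on the left side of the axes. Below is a minimal sketch of that limit calculation. It assumes the axes keep an equal data aspect (`imshow`'s default) and that `ax_ratio` is the axes' on-screen width divided by height, for example from `ax.get_window_extent()`. `left_aligned_image_limits` is a hypothetical helper and is not part of `plot_tools`.

```python
def left_aligned_image_limits(img_width, img_height, ax_ratio):
    """Return (xlim, ylim) that fit an image's height to the axes and pin it to the left edge.

    img_width, img_height: image size in pixels (the order PIL's Image.size returns).
    ax_ratio: axes width / height in display units.
    With an equal data aspect, the visible x-span must be img_height * ax_ratio
    for the full image height to fill the axes vertically.
    """
    if img_width <= 0 or img_height <= 0 or ax_ratio <= 0:
        raise ValueError("image size and axes ratio must be positive")
    ylim = (0, img_height)
    xlim = (0, img_height * ax_ratio)
    return xlim, ylim


# Usage with matplotlib (not executed here):
#   w, h = im3.size
#   ext = ax3a.get_window_extent()
#   xlim, ylim = left_aligned_image_limits(w, h, ext.width / ext.height)
#   ax3a.imshow(im3, extent=[0, w, 0, h])
#   ax3a.set_xlim(*xlim); ax3a.set_ylim(*ylim)
```

If the image is wider than `img_height * ax_ratio`, its right-hand part falls outside the x-limits and is cropped. That is the trade-off of fitting height rather than width.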


### Figure 4: In preparation by Upi

### Figure 5
* Example E-I traces over a burst (train)
* E/I ratio of successive pulses across freq and numSq
* Summary figure across cells

__Note: This figure should explain why we don't see escape from E-I balance.__


In [None]:
importlib.reload(plot_tools)

### Fig 5 for 5sq patterns

In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

fig5, [[ax5a, ax5b],[ax5c,ax5d],[ax5e, ax5f]] = plt.subplots(3,2, figsize=(w, h), sharey=False)
# add more space between subplots
fig5.subplots_adjust(hspace=0.5, wspace=0.5)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5A: Example E-I traces over a train
ax5a.text(-0.1, 1.1, 'A', transform=ax5a.transAxes, size=20, weight='bold')
any_good_vc_cell = vc_cells_screened[0]
df_temp = df4[(df4['cellID'] == any_good_vc_cell) & (df4['stimFreq'] == 20) & (df4['numSq'] == 5) & (df4['channel'] == 'cell')]
num_traces, num_metadata_columns = int(df_temp.shape[0]/2), df_temp.shape[1]-20000
E = df_temp[df_temp['clampPotential']==-70].iloc[:,num_metadata_columns:].to_numpy()
I = df_temp[df_temp['clampPotential']==  0].iloc[:,num_metadata_columns:].to_numpy()
for i in range(num_traces):
    ax5a.plot(np.linspace(0,1,20000), I[i,:], alpha=0.2, linewidth=1, color=color_EI[0], )
    ax5a.plot(np.linspace(0,1,20000), E[i,:], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5a.plot(np.linspace(0,1,20000), np.mean(I, axis=0), linewidth=1, color=color_EI[0], )
ax5a.plot(np.linspace(0,1,20000), np.mean(E, axis=0), linewidth=1, color=color_EI[-70], )

# simplify the axis
plot_tools.simplify_axes(ax5a, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=np.round(np.linspace(0.1,0.5,5),1), xtick_labels=np.round(np.linspace(0.1,0.5,5),1), ytick_locs=np.linspace(-500,1000,4), ytick_labels=np.linspace(-500,1000,4))
# add floating scalebars to the plots if spines are not retained
# plot_tools.add_floating_scalebar(ax5a, scalebar_origin=[0.3,600], xlength=0.1, ylength=100, labelx='100', labely='100', unitx='ms', unity='pA', fontsize=8, color=color_EI[0], linewidth=2, pad=0.1, show_labels=False)
# plot_tools.add_floating_scalebar(ax5a, scalebar_origin=[0.3,-200], xlength=0.1, ylength=100, labelx='100', labely='100', unitx='ms', unity='pA', fontsize=8, color=color_EI[-70], linewidth=2, pad=0.1, show_labels=False)

# add an inset subplot in ax5a using matplotlib
ax5a_inset = ax5a.inset_axes([0.9,0.5,0.1,0.5]) # left, bottom, width, height
t1,t2 = 0.205, 0.225
Fs = 2e4
t = np.linspace(t1,t2, int((t2-t1)*Fs))
# for i in range(num_traces):
#     ax5a_inset.plot(t, I[i,int(t1*Fs):int(t1*Fs)+len(t)], alpha=0.2, linewidth=1, color=color_EI[0], )
#     ax5a_inset.plot(t, E[i,int(t1*Fs):int(t1*Fs)+len(t)], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5a_inset.plot(t, np.mean(I[:,int(t1*Fs):int(t1*Fs)+len(t)], axis=0), linewidth=2, color=color_EI[0], )
ax5a_inset.plot(t, np.mean(E[:,int(t1*Fs):int(t1*Fs)+len(t)], axis=0), linewidth=2, color=color_EI[-70], )
plot_tools.simplify_axes(ax5a_inset, splines_to_keep=['left'], axis_offset=10, remove_ticks=False, xtick_locs=[t1,t2], xtick_labels=[t1,t2], ytick_locs=np.linspace(-250,750,5), ytick_labels=np.linspace(-250,750,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5B: Whole train
ax5b.text(-0.1, 1.1, 'B', transform=ax5b.transAxes, size=20, weight='bold')

for i in range(num_traces):
    ax5b.plot(np.linspace(0,1,20000), I[i,:], alpha=0.2, linewidth=1, color=color_EI[0], )
    ax5b.plot(np.linspace(0,1,20000), E[i,:], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5b.plot(np.linspace(0,1,20000), np.mean(I, axis=0), linewidth=1, color=color_EI[0], )
ax5b.plot(np.linspace(0,1,20000), np.mean(E, axis=0), linewidth=1, color=color_EI[-70], )

# plot_tools.add_floating_scalebar(ax5b, scalebar_origin=[0.3,600], xlength=0.1, ylength=100, labelx='100', labely='100', 
                                # unitx='ms', unity='pA', fontsize=8, color=color_EI[0], linewidth=2, pad=0.1, show_labels=False)
# plot_tools.add_floating_scalebar(ax5b, scalebar_origin=[0.3,-200], xlength=0.1, ylength=100, labelx='100', labely='100', 
                                # unitx='ms', unity='pA', fontsize=8, color=color_EI[-70], linewidth=2, pad=0.1, show_labels=False)
# simplify the axis
plot_tools.simplify_axes(ax5b, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=np.round(np.linspace(0,1.0,5),1), xtick_labels=np.round(np.linspace(0,1.0,5),1), ytick_locs=np.linspace(-500,1000,4), ytick_labels=np.linspace(-500,1000,4))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5C: STP vs frequency, cell responses at 0 mV (inhibitory)
ax5c.text(-0.1, 1.1, 'C', transform=ax5c.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pcn2'])<500)& (df3['cellID'].isin(vc_cells_screened))]
df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax5c, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

plot_tools.simplify_axes(ax5c, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
ax5c.set_xlabel('Pulse Index')
ax5c.set_ylabel('Peak response (normalized)')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5D: STP vs frequency, cell responses at -70 mV (excitatory)
ax5d.text(-0.1, 1.1, 'D', transform=ax5d.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pcn2'])<500) & (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax5d, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

plot_tools.simplify_axes(ax5d, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
ax5d.set_xlabel('Pulse Index')
ax5d.set_ylabel('Peak response (normalized)')


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5E: STP vs frequency, field responses recorded during 0 mV sweeps
ax5e.text(-0.1, 1.1, 'E', transform=ax5e.transAxes, size=20, weight='bold')
to_plot = [ f'pfn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pfn2'])<500)& (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax5e, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

# ax5e.set_xticks(range(9), labels=freq_sweep_pulses)
ax5e.set_xlabel('Pulse Index')
ax5e.set_ylabel('Peak Response (mV)')
plot_tools.simplify_axes(ax5e, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5F: STP vs frequency, field responses recorded during -70 mV sweeps
ax5f.text(-0.1, 1.1, 'F', transform=ax5f.transAxes, size=20, weight='bold')
to_plot = [ f'pfn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pfn2'])<500) & (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax5f, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)


ax5f.set_xlabel('Pulse Index')
ax5f.set_ylabel('Peak Response (mV)')
plot_tools.simplify_axes(ax5f, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# save figure
figure_name = 'Figure5_5sq'
fig5.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig5.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
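Panel A makes two assumptions. First, it assumes exactly half of the selected rows are -70 mV sweeps (`num_traces = int(df_temp.shape[0]/2)`). Second, it builds the time axis with `np.linspace(0, 1, 20000)`, whose spacing is 1/19999 s rather than 1/Fs. Below is a minimal sketch that counts each holding potential separately and places sample k at k/Fs. As in the cell above, it assumes the last 20000 columns hold the sweep samples, recorded at Fs = 20 kHz. `split_ei_traces`, `sample_times`, and `window_slice` are hypothetical helpers.

```python
import numpy as np
import pandas as pd

def split_ei_traces(df_sweeps, n_samples=20000):
    """Return (E, I) arrays of shape (n_sweeps, n_samples).

    E: sweeps held at -70 mV (excitatory currents); I: sweeps held at 0 mV (inhibitory currents).
    Assumes the last n_samples columns are the recorded samples, so E and I may
    contain different numbers of sweeps without breaking the indexing.
    """
    E = df_sweeps[df_sweeps['clampPotential'] == -70].iloc[:, -n_samples:].to_numpy()
    I = df_sweeps[df_sweeps['clampPotential'] == 0].iloc[:, -n_samples:].to_numpy()
    return E, I

def sample_times(n_samples, fs):
    """Time in seconds of each sample: k / fs."""
    return np.arange(n_samples) / fs

def window_slice(t1, t2, fs):
    """Sample slice covering [t1, t2); round() guards against float products just below an integer."""
    return slice(int(round(t1 * fs)), int(round(t2 * fs)))


# Toy data (illustrative only): one metadata column followed by 4 samples per sweep.
toy = pd.DataFrame({'clampPotential': [-70, -70, 0],
                    's0': [1.0, 3.0, 5.0], 's1': [2.0, 4.0, 6.0],
                    's2': [0.0, 0.0, 1.0], 's3': [0.0, 2.0, 1.0]})
E, I = split_ei_traces(toy, n_samples=4)
E_mean = E.mean(axis=0)                   # average over sweeps, as in the mean traces above
win = window_slice(0.205, 0.225, 2e4)     # the inset window used in panel A
```

With these helpers, the inset's mean inhibitory trace would be `I[:, win].mean(axis=0)`, plotted against `sample_times(20000, 2e4)[win]`.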

### Fig 5 for 15sq patterns

In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

fig5, [[ax5a, ax5b],[ax5c,ax5d],[ax5e, ax5f]] = plt.subplots(3,2, figsize=(w, h), sharey=False)
# add more space between subplots
fig5.subplots_adjust(hspace=0.5, wspace=0.5)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5A: Example E-I traces over a train
ax5a.text(-0.1, 1.1, 'A', transform=ax5a.transAxes, size=20, weight='bold')
any_good_vc_cell = vc_cells_screened[0]
df_temp = df4[(df4['cellID'] == any_good_vc_cell) & (df4['stimFreq'] == 20) & (df4['numSq'] == 15) & (df4['channel'] == 'cell')]
num_traces, num_metadata_columns = int(df_temp.shape[0]/2), df_temp.shape[1]-20000
E = df_temp[df_temp['clampPotential']==-70].iloc[:,num_metadata_columns:].to_numpy()
I = df_temp[df_temp['clampPotential']==  0].iloc[:,num_metadata_columns:].to_numpy()
for i in range(num_traces):
    ax5a.plot(np.linspace(0,1,20000), I[i,:], alpha=0.2, linewidth=1, color=color_EI[0], )
    ax5a.plot(np.linspace(0,1,20000), E[i,:], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5a.plot(np.linspace(0,1,20000), np.mean(I, axis=0), linewidth=1, color=color_EI[0], )
ax5a.plot(np.linspace(0,1,20000), np.mean(E, axis=0), linewidth=1, color=color_EI[-70], )

# simplify the axis
plot_tools.simplify_axes(ax5a, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=np.round(np.linspace(0.1,0.5,5),1), xtick_labels=np.round(np.linspace(0.1,0.5,5),1), ytick_locs=np.linspace(-500,1000,4), ytick_labels=np.linspace(-500,1000,4))
# add floating scalebars to the plots if spines are not retained
# plot_tools.add_floating_scalebar(ax5a, scalebar_origin=[0.3,600], xlength=0.1, ylength=100, labelx='100', labely='100', unitx='ms', unity='pA', fontsize=8, color=color_EI[0], linewidth=2, pad=0.1, show_labels=False)
# plot_tools.add_floating_scalebar(ax5a, scalebar_origin=[0.3,-200], xlength=0.1, ylength=100, labelx='100', labely='100', unitx='ms', unity='pA', fontsize=8, color=color_EI[-70], linewidth=2, pad=0.1, show_labels=False)

# add an inset subplot in ax5a using matplotlib
ax5a_inset = ax5a.inset_axes([0.9,0.5,0.1,0.5]) # left, bottom, width, height
t1,t2 = 0.205, 0.225
Fs = 2e4
t = np.linspace(t1,t2, int((t2-t1)*Fs))
# for i in range(num_traces):
#     ax5a_inset.plot(t, I[i,int(t1*Fs):int(t1*Fs)+len(t)], alpha=0.2, linewidth=1, color=color_EI[0], )
#     ax5a_inset.plot(t, E[i,int(t1*Fs):int(t1*Fs)+len(t)], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5a_inset.plot(t, np.mean(I[:,int(t1*Fs):int(t1*Fs)+len(t)], axis=0), linewidth=2, color=color_EI[0], )
ax5a_inset.plot(t, np.mean(E[:,int(t1*Fs):int(t1*Fs)+len(t)], axis=0), linewidth=2, color=color_EI[-70], )
plot_tools.simplify_axes(ax5a_inset, splines_to_keep=['left'], axis_offset=10, remove_ticks=False, xtick_locs=[t1,t2], xtick_labels=[t1,t2], ytick_locs=np.linspace(-250,750,5), ytick_labels=np.linspace(-250,750,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5B: Whole train
ax5b.text(-0.1, 1.1, 'B', transform=ax5b.transAxes, size=20, weight='bold')

for i in range(num_traces):
    ax5b.plot(np.linspace(0,1,20000), I[i,:], alpha=0.2, linewidth=1, color=color_EI[0], )
    ax5b.plot(np.linspace(0,1,20000), E[i,:], alpha=0.2, linewidth=1, color=color_EI[-70], )
ax5b.plot(np.linspace(0,1,20000), np.mean(I, axis=0), linewidth=1, color=color_EI[0], )
ax5b.plot(np.linspace(0,1,20000), np.mean(E, axis=0), linewidth=1, color=color_EI[-70], )

# plot_tools.add_floating_scalebar(ax5b, scalebar_origin=[0.3,600], xlength=0.1, ylength=100, labelx='100', labely='100', 
                                # unitx='ms', unity='pA', fontsize=8, color=color_EI[0], linewidth=2, pad=0.1, show_labels=False)
# plot_tools.add_floating_scalebar(ax5b, scalebar_origin=[0.3,-200], xlength=0.1, ylength=100, labelx='100', labely='100', 
                                # unitx='ms', unity='pA', fontsize=8, color=color_EI[-70], linewidth=2, pad=0.1, show_labels=False)
# simplify the axis
plot_tools.simplify_axes(ax5b, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=np.round(np.linspace(0,1.0,5),1), xtick_labels=np.round(np.linspace(0,1.0,5),1), ytick_locs=np.linspace(-500,1000,4), ytick_labels=np.linspace(-500,1000,4))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5C: STP vs frequency, cell responses at 0 mV (inhibitory)
ax5c.text(-0.1, 1.1, 'C', transform=ax5c.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pcn2'])<500)& (df3['cellID'].isin(vc_cells_screened))]
df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax5c, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

plot_tools.simplify_axes(ax5c, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
ax5c.set_xlabel('Pulse Index')
ax5c.set_ylabel('Peak response (normalized)')

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5D: STP vs frequency, cell responses at -70 mV (excitatory)
ax5d.text(-0.1, 1.1, 'D', transform=ax5d.transAxes, size=20, weight='bold')
to_plot = [ f'pcn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pcn2'])<500) & (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax5d, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

plot_tools.simplify_axes(ax5d, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
ax5d.set_xlabel('Pulse Index')
ax5d.set_ylabel('Peak response (normalized)')


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5E: STP vs frequency, field responses recorded during 0 mV sweeps
ax5e.text(-0.1, 1.1, 'E', transform=ax5e.transAxes, size=20, weight='bold')
to_plot = [ f'pfn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pfn2'])<500)& (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax5e, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

# ax5e.set_xticks(range(9), labels=freq_sweep_pulses)
ax5e.set_xlabel('Pulse Index')
ax5e.set_ylabel('Peak Response (mV)')
plot_tools.simplify_axes(ax5e, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 5F: STP vs frequency, field responses recorded during -70 mV sweeps
ax5f.text(-0.1, 1.1, 'F', transform=ax5f.transAxes, size=20, weight='bold')
to_plot = [ f'pfn{i}' for i in range(9) ]

# # melt the dataframe
df_temp = df3[(df3['location']=='CA1') & (df3['clampMode']=='VC') & (df3['numSq']>1) & (np.abs(df3['pfn2'])<500) & (df3['cellID'].isin(vc_cells_screened))]

df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax5f, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)


ax5f.set_xlabel('Pulse Index')
ax5f.set_ylabel('Peak Response (mV)')
plot_tools.simplify_axes(ax5f, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# save figure
figure_name = 'Figure5_15sq'
fig5.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
fig5.savefig(paper_figure_export_location / (figure_name + '.svg'), dpi=300, bbox_inches='tight')
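Panels C to F repeat the same filter, melt, and plot steps. Only three things change: the column prefix (`pcn` for the cell, `pfn` for the field), the holding potential, and `numSq`. The 15-square cell duplicates all of it.

Below is a minimal sketch of a parametrized version. It also computes the summary that `sns.lineplot(..., errorbar=("se", 1))` draws: the mean ± 1 SEM per frequency and pulse. It assumes `df3` has the columns used above. `melt_stp` and `stp_summary` are hypothetical helpers.

One caveat: lineplot computes the SEM across all melted rows, so sweeps and patterns are pooled over cells. If the intended n is the number of cells, average within each cell first.

```python
import pandas as pd

ID_VARS = ['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList']

def melt_stp(df, prefix, clamp_potential, num_sq, cells, n_pulses=9, max_abs=500, max_freq=100):
    """Filter CA1 voltage-clamp rows and reshape per-pulse peaks to long format.

    prefix: 'pcn' (cell) or 'pfn' (field); uses columns f'{prefix}0' ... f'{prefix}{n_pulses-1}'.
    Rows whose pulse-2 value has magnitude >= max_abs are dropped, as in the cells above.
    """
    value_vars = [f'{prefix}{i}' for i in range(n_pulses)]
    mask = ((df['location'] == 'CA1') & (df['clampMode'] == 'VC')
            & df['cellID'].isin(cells) & (df[f'{prefix}2'].abs() < max_abs)
            & (df['stimFreq'] < max_freq) & (df['clampPotential'] == clamp_potential)
            & (df['numSq'] == num_sq))
    return pd.melt(df[mask], id_vars=ID_VARS, value_vars=value_vars,
                   var_name='pulseIndex', value_name='peak_response')

def stp_summary(df_melt):
    """Mean, standard error of the mean, and n per (stimFreq, pulseIndex)."""
    grouped = df_melt.groupby(['stimFreq', 'pulseIndex'])['peak_response']
    return grouped.agg(['mean', 'sem', 'count']).reset_index()
```

For example, panel C of the 15-square figure corresponds to `melt_stp(df3, 'pcn', 0, 15, vc_cells_screened)`, and the melted frame can be passed straight to the existing `sns.lineplot` call.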

### Fig 5 supplementary: All screened VC cells

In [None]:
fig5B, axs5B = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5B.subplots_adjust(hspace=0.5, wspace=0.5)
#linearize the axs
axs5B = axs5B.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5B[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Peak response (normalized)')
    ax.set_title(f'VC - {selected_cell},5sq')

# save figure
figure_name = 'Figure5_supp_5sq_E_all_screenedVCcells.png'
fig5B.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')
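The supplementary grids are fixed at 3 × 3. As a result, `axs5B[i]` raises an `IndexError` if `vc_cells_screened` ever holds more than nine cells, and the leftover axes stay empty if it holds fewer. Below is a minimal sketch that sizes the grid from the cell count. `grid_shape` and `unused_axes` are hypothetical helpers.

```python
import math

def grid_shape(n_panels, ncols=3):
    """Smallest (nrows, ncols) grid with at least n_panels axes, at most ncols wide."""
    if n_panels < 1:
        raise ValueError("need at least one panel")
    ncols = min(ncols, n_panels)
    return math.ceil(n_panels / ncols), ncols

def unused_axes(n_panels, nrows, ncols):
    """Flat indices of axes left empty; these could be hidden with ax.set_axis_off()."""
    return list(range(n_panels, nrows * ncols))
```

In the cells above, this would become `nrows, ncols = grid_shape(len(vc_cells_screened))` followed by `plt.subplots(nrows, ncols, squeeze=False, ...)`. Passing `squeeze=False` keeps the returned axes array 2-D, so `.flatten()` still works when there is only one cell.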

In [None]:
fig5C, axs5C = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5C.subplots_adjust(hspace=0.5, wspace=0.5)
#linearize the axs
axs5C = axs5C.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5C[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==-70) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_E, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Peak response (normalized)')
    ax.set_title(f'VC - {selected_cell},15sq')

# save figure
figure_name = 'Figure5_supp_15sq_E_all_screenedVCcells.png'
fig5C.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')

In [None]:
fig5Binh, axs5Binh = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5Binh.subplots_adjust(hspace=0.5, wspace=0.5)
#linearize the axs
axs5Binh = axs5Binh.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5Binh[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==5)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    if df_melt.shape[0]>0:
        sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)
    # sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Peak response (normalized)')
    ax.set_title(f'VC - {selected_cell},5sq')

# save figure
figure_name = 'Figure5_supp_5sq_I_all_screenedVCcells.png'
fig5Binh.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')

In [None]:
fig5Binh, axs5Binh = plt.subplots(3,3, figsize=(w, h), sharex=True, sharey=True)
# add more space between subplots
fig5Binh.subplots_adjust(hspace=0.5, wspace=0.5)
#linearize the axs
axs5Binh = axs5Binh.flatten()
to_plot = [ f'pcn{i}' for i in range(9) ]
for i,selected_cell in enumerate(vc_cells_screened):
    ax = axs5Binh[i]
    # # melt the dataframe
    df_temp = df3[(df3['clampMode']=='VC') & (df3['cellID']==selected_cell)]

    df_melt = pd.melt(df_temp[(df_temp['stimFreq']<100)& (df_temp['clampPotential']==0) & (df_temp['numSq']==15)], id_vars=['cellID', 'stimFreq', 'clampPotential', 'numSq', 'patternList'], value_vars=to_plot, var_name='pulseIndex', value_name='peak_response')
    if df_melt.shape[0]>0:
        sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)
    # sns.lineplot(data=df_melt, x='pulseIndex', y='peak_response', hue='stimFreq', palette=color_I, ax=ax, alpha=1, linewidth=2, err_style="bars", errorbar=("se", 1),)

    plot_tools.simplify_axes(ax, splines_to_keep=['bottom','left'], axis_offset=10, remove_ticks=False, xtick_locs=range(9), xtick_labels=freq_sweep_pulses, ytick_locs=np.linspace(0,2,5), ytick_labels=np.linspace(0,2,5))
    ax.set_xlabel('Pulse Index')
    ax.set_ylabel('Peak response (normalized)')
    ax.set_title(f'VC - {selected_cell},15sq')

# save figure
figure_name = 'Figure5_supp_15sq_I_all_screenedVCcells.png'
fig5Binh.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')

### Fig 6: Why no escape?

In [None]:
# Fig 6: Show that in current clamp the cells do not fire (no escape from E-I balance), and that they do fire when inhibition is blocked with Gabazine.
# A boolean 'spike' column can be derived by marking each row of df in which any sample in columns 40:20049 exceeds 30:
# df['spike'] = df.iloc[:,40:20049].gt(30).any(axis=1)


cells_that_fire = df[(df['clampMode']=='CC') & (df['AP']==True)]['cellID'].unique()
all_cc_cells = df[(df['clampMode']=='CC')]['cellID'].unique()
print(len(cells_that_fire), len(all_cc_cells))
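The commented-out line flags a whole sweep as spiking if any sample in columns 40:20049 exceeds 30. The Figure 6 outline also asks which pulse triggers an escape, so it helps to locate each spike rather than only flag the sweep.

Below is a minimal sketch built on three assumptions:
* each sweep is a 1-D voltage array sampled at `fs`;
* a spike is an upward crossing of the threshold;
* pulses start at `first_pulse_time` and repeat at `stim_freq`.

The caller must supply these timing values from the stimulation protocol. `spike_onsets` and `pulse_index_of_spikes` are hypothetical helpers.

```python
import numpy as np

def spike_onsets(v, threshold=30.0):
    """Sample indices where v crosses threshold from below (one index per spike).

    A sweep that starts above threshold at sample 0 does not count as a crossing there.
    """
    above = np.asarray(v) > threshold
    # a crossing is a sample above threshold whose predecessor was not
    return np.flatnonzero(above[1:] & ~above[:-1]) + 1

def pulse_index_of_spikes(onsets, fs, first_pulse_time, stim_freq, n_pulses):
    """Map each spike onset to the index of the most recent pulse.

    Returns -1 for spikes before the first pulse or after the last inter-pulse interval.
    """
    t = np.asarray(onsets) / fs
    idx = np.floor((t - first_pulse_time) * stim_freq).astype(int)
    idx[(idx < 0) | (idx >= n_pulses)] = -1
    return idx
```

Tallying `pulse_index_of_spikes` over the spiking sweeps, grouped by `numSq` and `stimFreq`, would give the pulse-index breakdown the Figure 6 outline calls for.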

In [None]:
all_cc_cells

In [None]:
for selected_cell in all_cc_cells:
    for f in [20,30,40,50]:
        for s in [1,5,15]:
            figure_name = f'raw_data_CC_{selected_cell}_numsq_{s}_freq_{f}'
            print(figure_name)
            df_temp = df[(df['clampMode']=='CC') & (df['cellID']==selected_cell) & (df['condition']=='Control') & (df['numSq']==s) & (df['stimFreq']==f)]
            if df_temp.shape[0]>0:
                # create a figure only when there is data to plot
                fig_cc, ax_cc = plt.subplots(figsize=(20, 10), )
                _, ax_cc = plot_tools.plot_data_from_df(df_temp, data_start_column = 49, combine=False, fig=fig_cc, ax=ax_cc)
                ax_cc[0].set_title(str(selected_cell))
                # save the figure, then close it so that open figures do not accumulate inside the loop
                fig_cc.savefig(paper_figure_export_location / (figure_name + '.png'), dpi=300, bbox_inches='tight')
                plt.close(fig_cc)

### Fig 6 Supplementary: All gabazine cells

All Gabazine cells: [3131, 5291, 3882, 3872, 3791]

Screened Gabazine cells: [3131, 3882, 3872]

In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

fig6_supp1, [[ax6supp1_a, ax6supp1_b],[ax6supp1_c,ax6supp1_d]] = plt.subplots(2,2, figsize=(w, h), sharey=False)
# add more space between subplots
fig6_supp1.subplots_adjust(hspace=0.5, wspace=0.5)

gabazine_cells = df[df['condition']=='Gabazine']['cellID'].unique()
print(f'Gabazine Cells: {gabazine_cells}')
selected_cell = gabazine_cells[0]

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 6 supp 1A: Gabazine (all signals) vs Control (cell, green) for one cell, 50 Hz, 5 sq
ax6supp1_a.text(-0.1, 1.1, 'A', transform=ax6supp1_a.transAxes, size=20, weight='bold')

data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_a, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_a)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_a, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_a, signal_mapping=signal_mapping)

# simplify
plot_tools.simplify_axes(ax6supp1_a, splines_to_keep=[], )
ax6supp1_a.set_title(str(selected_cell)+' 50Hz 5sq')

#remove legend
ax6supp1_a.legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 6 supp 1B: Gabazine (all signals) vs Control (cell, green) for one cell, 50 Hz, 15 sq
ax6supp1_b.text(-0.1, 1.1, 'B', transform=ax6supp1_b.transAxes, size=20, weight='bold')

data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_b, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_b)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
fig6_supp1, ax6supp1_b, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_b, signal_mapping=signal_mapping)

# simplify
plot_tools.simplify_axes(ax6supp1_b, splines_to_keep=[], )
ax6supp1_b.set_title(str(selected_cell)+' 50Hz 15sq')
ax6supp1_b.legend([],[], frameon=False)
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Fig 6C: All gabazine cells
ax6supp1_c.text(-0.1, 1.1, 'C', transform=ax6supp1_c.transAxes, size=20, weight='bold')
# cellID is in gabazine cells
data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_c, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_c)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 5) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_c, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_c, signal_mapping=signal_mapping)



# simplify
plot_tools.simplify_axes(ax6supp1_c, splines_to_keep=[], )
ax6supp1_c.set_title('All cells 50Hz 5sq')
#remove legend
ax6supp1_c.legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fig 6D: All gabazine cells
ax6supp1_d.text(-0.1, 1.1, 'D', transform=ax6supp1_d.transAxes, size=20, weight='bold')
# cellID is in gabazine cells
data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_d, signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp1, ax=ax6supp1_d)

data = df[(df['condition'] == 'Control') & (df['stimFreq'] == 50) & (df['numSq'] == 15) & (df['cellID'].isin(gabazine_cells))]
fig6_supp1, ax6supp1_d, _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp1, ax=ax6supp1_d, signal_mapping=signal_mapping)



# simplify
plot_tools.simplify_axes(ax6supp1_d, splines_to_keep=[], )
ax6supp1_d.set_title('All cells 50Hz 15sq')
#remove legend
ax6supp1_d.legend([],[], frameon=False)
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
# save figure
figure_name = 'Figure6_supp1_cell3131.png'
fig6_supp1.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')
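The panels above repeat the same boolean mask over `condition`, `stimFreq`, `numSq` and `cellID`. A minimal sketch of a helper that builds that mask in one place, assuming the column names used in this notebook; `select_sweeps` is a hypothetical name and is not part of `plot_tools`.

```python
import pandas as pd

def select_sweeps(df, condition, stim_freq, num_sq, cells):
    """Return rows for one condition, frequency and numSq, restricted to `cells`.

    Hypothetical helper; assumes columns 'condition', 'stimFreq', 'numSq', 'cellID'.
    `cells` may be a single cell ID or a list-like of IDs.
    """
    if pd.api.types.is_list_like(cells):
        cell_mask = df['cellID'].isin(cells)
    else:
        cell_mask = df['cellID'] == cells
    mask = (
        (df['condition'] == condition)
        & (df['stimFreq'] == stim_freq)
        & (df['numSq'] == num_sq)
        & cell_mask
    )
    return df[mask]

# usage on a small synthetic table (values are illustrative only)
demo = pd.DataFrame({
    'cellID': [3131, 3131, 3402, 3402],
    'condition': ['Control', 'Gabazine', 'Control', 'Gabazine'],
    'stimFreq': [50, 50, 50, 20],
    'numSq': [15, 15, 15, 15],
})
gabazine_50hz = select_sweeps(demo, 'Gabazine', 50, 15, [3131, 3402])
control_one_cell = select_sweeps(demo, 'Control', 50, 15, 3131)
```

Accepting either a scalar or a list keeps the single-cell panels (`== selected_cell`) and the all-cells panels (`.isin(gabazine_cells)`) on one code path.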


In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

gabazine_cells = df[df['condition']=='Gabazine']['cellID'].unique()
print(f'Gabazine Cells: {gabazine_cells}')

# one row per Gabazine cell, one column per stimulation frequency (20, 30, 40, 50 Hz);
# squeeze=False keeps ax6_supp2 two-dimensional so it can be indexed as ax6_supp2[i, j]
fig6_supp2, ax6_supp2 = plt.subplots(len(gabazine_cells), 4, figsize=(w, h), sharey=False, squeeze=False)
# add more space between subplots
fig6_supp2.subplots_adjust(hspace=0.5, wspace=0.5)


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
for i,selected_cell in enumerate(gabazine_cells):
    for j,f in enumerate([20,30,40,50]):

        

        data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == f) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j])

        data = df[(df['condition'] == 'Control') & (df['stimFreq'] == f) & (df['numSq'] == 5) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j], signal_mapping=signal_mapping)

        # simplify
        plot_tools.simplify_axes(ax6_supp2[i,j], splines_to_keep=[], )
        ax6_supp2[i,j].set_title(f'{selected_cell} {f}Hz 5sq')
        #remove legend
        ax6_supp2[i,j].legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# save figure
figure_name = 'Figure6_supp2_all_gabazine_cells_5sq.png'
fig6_supp2.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')


In [None]:
# figure size in inches (width, height)
w, h = [15, 9]

gabazine_cells = df[df['condition']=='Gabazine']['cellID'].unique()
print(f'Gabazine Cells: {gabazine_cells}')

# one row per Gabazine cell, one column per stimulation frequency (20, 30, 40, 50 Hz);
# squeeze=False keeps ax6_supp2 two-dimensional so it can be indexed as ax6_supp2[i, j]
fig6_supp2, ax6_supp2 = plt.subplots(len(gabazine_cells), 4, figsize=(w, h), sharey=False, squeeze=False)
# add more space between subplots
fig6_supp2.subplots_adjust(hspace=0.5, wspace=0.5)


# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
for i,selected_cell in enumerate(gabazine_cells):
    for j,f in enumerate([20,30,40,50]):

        

        data = df[(df['condition'] == 'Gabazine') & (df['stimFreq'] == f) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], signal_mapping = plot_tools.plot_data_from_df(data, data_start_column = 49 , combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j])

        data = df[(df['condition'] == 'Control') & (df['stimFreq'] == f) & (df['numSq'] == 15) & (df['cellID'] == selected_cell)]
        fig6_supp2, ax6_supp2[i,j], _ = plot_tools.plot_data_from_df(data, data_start_column = 49 , signals_to_plot=['Cell',], signal_colors=['green',], combine=True, fig=fig6_supp2, ax=ax6_supp2[i,j], signal_mapping=signal_mapping)

        # simplify
        plot_tools.simplify_axes(ax6_supp2[i,j], splines_to_keep=[], )
        ax6_supp2[i,j].set_title(f'{selected_cell} {f}Hz 15sq')
        #remove legend
        ax6_supp2[i,j].legend([],[], frameon=False)

# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------

# save figure
figure_name = 'Figure6_supp2_all_gabazine_cells_15sq.png'
fig6_supp2.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')
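Each panel in these grids overlays the Gabazine and Control sweeps for one cell and one frequency, so an empty panel usually means no sweeps matched the filter. A minimal sketch, assuming the same `condition`, `stimFreq`, `numSq` and `cellID` columns, that tabulates sweep counts per cell and frequency with `pandas.crosstab` before plotting; `sweep_counts` is a hypothetical helper.

```python
import pandas as pd

def sweep_counts(df, condition, num_sq, freqs=(20, 30, 40, 50)):
    """Count sweeps per (cellID, stimFreq) for one condition and numSq.

    Frequencies with no sweeps appear as columns of zeros.
    """
    sub = df[(df['condition'] == condition) & (df['numSq'] == num_sq)]
    table = pd.crosstab(sub['cellID'], sub['stimFreq'])
    return table.reindex(columns=list(freqs), fill_value=0)

# synthetic example (values are illustrative only)
demo = pd.DataFrame({
    'cellID': [3131, 3131, 3131, 3402],
    'condition': ['Gabazine'] * 4,
    'stimFreq': [20, 20, 50, 30],
    'numSq': [15, 15, 15, 15],
})
counts = sweep_counts(demo, 'Gabazine', 15)
```

Running the same table for `'Control'` shows whether the green overlay will be present in every panel.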


### SDN

In [None]:
# cells with 1sq experiments
freq_sweep_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_CC_long.h5" 
df = pd.read_hdf(freq_sweep_datapath, key='data')
# df = df[ pd.notnull(df['peaks_cell'])] # remove all sweeps that had NaNs in analysed params (mostly due to bad pulse detection)
# reset index
# df = df.reset_index(drop=True)
print(f'All cells freqSweep data has {df.shape[0]} sweeps or {df.shape[0]/3} presentations')

cell_with_1sq_data = df[df['numSq']==1]['cellID'].unique()
print(f'There are {len(cell_with_1sq_data)} cells with 1sq data')

# only keep cells that have 1sq data
df = df[(df['cellID'].isin(cell_with_1sq_data))]
print(f'Cells that have 1sq freqSweep data have total {df.shape[0]} sweeps or {df.shape[0]/3} presentations')

# keep the selected metadata and the cell signal (columns 49 to 40048); drop the remaining metadata and the LED and field signals stored from column 40049 onwards
columns_to_keep = ['cellID','exptID','condition','numSq','pulseWidth','clampPotential','stimFreq','patternList','probePulseStart','pulseTrainStart','sweep']
data1 = df[columns_to_keep]
data2 = df.iloc[:,49:40049]
data = pd.concat([data1,data2], axis=1)
del data1, data2, df
# cast the first 10 metadata columns to float32, except 'condition', which holds strings
for i in range(10):
    if data.columns[i] == 'condition':
        continue
    # else make the datatype float
    data[data.columns[i]] = data[data.columns[i]].astype('float32')

# data2 = data.groupby(columns_to_keep[:-1]).mean().reset_index()
print(f'final data shape = {data.shape}')
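The cell above keeps a fixed set of metadata columns by name and a block of signal columns by position, then casts the numeric metadata to `float32`. A minimal sketch of the same selection on a toy table; the layout (metadata first, then one column per sample) and the column positions are assumptions for illustration only.

```python
import numpy as np
import pandas as pd

# toy table: three metadata columns (one of them a string), then 6 signal samples
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(3, 6)), columns=[f's{k}' for k in range(6)])
df.insert(0, 'condition', ['Control', 'Control', 'Gabazine'])
df.insert(0, 'numSq', [1, 5, 15])
df.insert(0, 'cellID', [3131, 3131, 3402])

meta_cols = ['cellID', 'condition', 'numSq']
# keep metadata by name and only the first 4 signal samples by position
signals = df.iloc[:, 3:3 + 4]
data = pd.concat([df[meta_cols], signals], axis=1)

# cast numeric metadata to float32, leaving the string column untouched
numeric_meta = [c for c in meta_cols if c != 'condition']
data = data.astype({c: 'float32' for c in numeric_meta})
```

Passing a dict to `astype` converts only the named columns in one call, instead of looping over column positions.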



In [None]:
# flag rows where the probe pulse does not coincide with the start of the pulse train
# probe_pulse is boolean: False if probePulseStart == pulseTrainStart, True otherwise
data['probe_pulse'] = data['probePulseStart'] != data['pulseTrainStart']

# count how many rows have probe_pulse == True
print(f'Number of rows where probe_pulse == 1 is {data[data["probe_pulse"]==1].shape[0]}')
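Because `probe_pulse` is a boolean column, the same count can be taken with `.sum()`, which treats `True` as 1. A small check on synthetic start times (the values are illustrative only):

```python
import pandas as pd

demo = pd.DataFrame({
    'probePulseStart': [0.2, 0.2, 0.5],
    'pulseTrainStart': [0.2, 0.5, 0.5],
})
# True where the probe pulse does not coincide with the train start
demo['probe_pulse'] = demo['probePulseStart'] != demo['pulseTrainStart']
n_probe = int(demo['probe_pulse'].sum())
n_probe_filter = demo[demo['probe_pulse'] == 1].shape[0]
```

The exact `!=` comparison is reasonable here only because both start times come from the same stored metadata; computed times would need a tolerance.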

In [None]:
value, freq = np.unique(data['pulseTrainStart'], return_counts=True)
# print each unique pulseTrainStart value and how many rows have it
print(f'value = {value}')
print(f'freq = {freq}')

In [None]:
fig7_supp1, ax7_supp1 = plt.subplots(4,3, figsize=(10,8), sharey=True, layout='constrained')
fig7_supp2, ax7_supp2 = plt.subplots(4,3, figsize=(10,8), sharey=True, layout='constrained')
ax7_supp1 = ax7_supp1.flatten()
ax7_supp2 = ax7_supp2.flatten()

for i, cell in enumerate(cell_with_1sq_data):
    df_temp_1sq = data[(data['cellID']==cell) & (data['numSq']==1)]
    # average across sweeps for the same patternList value
    df_temp_1sqx = df_temp_1sq.groupby(['cellID','condition','numSq','pulseWidth','clampPotential','stimFreq','probePulseStart','pulseTrainStart','patternList']).mean().reset_index()
    # plot a 50 ms window starting at the probe pulse (1000 samples at 20 kHz)
    t0 = df_temp_1sqx['probePulseStart'][0]
    t1 = t0+0.05
    time = np.linspace(t0,t1,1000)
    # 24x24 grid of peak responses, one entry per stimulated spot
    frame = np.zeros((24,24))
    for sweep in range(df_temp_1sqx.shape[0]):
        pattern = df_temp_1sqx.iloc[sweep]['patternList']
        spotloc = pattern_index.patternID[pattern][0]
        # convert the spot index to (row, column) on the 24x24 grid
        x, y = spotloc//24, spotloc%24
        # signal columns start at position 11, after the metadata columns
        trace = df_temp_1sqx.iloc[sweep,11+int(t0*2e4):11+int(t0*2e4)+1000]
        frame[x,y] = np.max(trace)
        ax7_supp1[i].plot(time, trace, color='k', alpha=0.4)

    ax7_supp1[i].set_title(f'{cell}')
    ax7_supp2[i].set_title(f'{cell}')
    # plot heatmap of the frame on ax7_supp2 axis
    sns.heatmap(frame, ax=ax7_supp2[i], cmap='viridis', vmin=0, vmax=2)

# save figures
figure1_name = 'Figure7_supp1_SDN_CC_1sq_traces.png'
figure2_name = 'Figure7_supp2_SDN_CC_1sq_max_heatmap.png'

fig7_supp1.savefig(paper_figure_export_location / (figure1_name), dpi=300, bbox_inches='tight')
fig7_supp2.savefig(paper_figure_export_location / (figure2_name), dpi=300, bbox_inches='tight')
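The heatmap assumes each single-spot pattern maps to one cell of a 24 x 24 grid with spots numbered row by row from 0, so row = `spotloc // 24` and column = `spotloc % 24`. A minimal sketch of that mapping using `divmod`; the spot indices and peak values are made up, and `spots_to_frame` is a hypothetical helper.

```python
import numpy as np

GRID = 24  # grid side length, matching the 24x24 frame in the cell above

def spots_to_frame(spot_peaks, grid=GRID):
    """Place peak values at their (row, col) grid position.

    spot_peaks maps a 0-based, row-major spot index to a peak response.
    """
    frame = np.zeros((grid, grid))
    for spotloc, peak in spot_peaks.items():
        row, col = divmod(spotloc, grid)
        frame[row, col] = peak
    return frame

frame = spots_to_frame({0: 0.5, 26: 1.2, 575: 2.0})
```

Spot 26 lands at row 1, column 2, and the last spot (575) at the bottom-right corner.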

    
    

Divisive inhibition (DI) is given by the following equation:
$$\theta = \epsilon - \beta\epsilon$$

SDN is parameterized by $\gamma$ and is given by the following equation:
$$\theta = \frac{\gamma\epsilon}{\gamma + \epsilon}$$

where $\epsilon$ is the expected response and $\theta$ is the observed response; $\beta$ is the strength of divisive inhibition (DI) and $\gamma$ is the strength of divisive normalization (DN).
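The two models scale differently with input: DI multiplies every expected response by the same factor $(1-\beta)$, whereas SDN leaves small responses nearly unchanged ($\theta \approx \epsilon$ when $\epsilon \ll \gamma$), halves the response at $\epsilon = \gamma$, and saturates towards $\gamma$ for large $\epsilon$. A small numeric illustration with arbitrary parameter values ($\gamma = 10$, $\beta = 0.5$):

```python
import numpy as np

def di(expected, beta):
    # divisive inhibition: theta = epsilon - beta * epsilon
    return expected - beta * expected

def sdn(expected, gamma):
    # divisive normalization: theta = gamma * epsilon / (gamma + epsilon)
    return gamma * expected / (gamma + expected)

eps = np.array([1.0, 10.0, 100.0])
theta_di = di(eps, beta=0.5)       # [0.5, 5.0, 50.0]
theta_sdn = sdn(eps, gamma=10.0)   # approx [0.909, 5.0, 9.091]
```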

In [None]:
# SDN: observed = gamma * expected / (gamma + expected)
def sdn_observed(expected, gamma):
    observed = (gamma*expected)/(gamma+expected)
    return observed

# DI (divisive inhibition): observed = expected - beta * expected
def di_observed(expected, beta):
    observed = expected - beta*expected
    return observed
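`scipy.optimize.curve_fit` starts from an initial guess of 1 for each parameter unless `p0` is given. A minimal sketch on noise-free synthetic data ($\gamma = 12$, $\beta = 0.3$, chosen arbitrarily) checking that both fits recover the generating parameters. For DI the model is a line through the origin with slope $1-\beta$, so $\beta$ also has a closed-form least-squares estimate $1 - \sum xy / \sum x^2$.

```python
import numpy as np
from scipy.optimize import curve_fit

def sdn_observed(expected, gamma):
    return (gamma * expected) / (gamma + expected)

def di_observed(expected, beta):
    return expected - beta * expected

x = np.linspace(1, 40, 30)       # synthetic expected responses
y_sdn = sdn_observed(x, 12.0)    # synthetic observed responses, SDN with gamma = 12
y_di = di_observed(x, 0.3)       # synthetic observed responses, DI with beta = 0.3

# under SDN every observed value is below gamma, so max(y) is a sensible starting guess
(gamma_hat,), _ = curve_fit(sdn_observed, x, y_sdn, p0=[y_sdn.max()])
(beta_hat,), _ = curve_fit(di_observed, x, y_di)

# closed-form least-squares beta for a line through the origin with slope (1 - beta)
beta_closed = 1 - np.sum(x * y_di) / np.sum(x * x)
```

With real, noisy data the starting guess matters more for $\gamma$ than for $\beta$, because the DI model is linear in its parameter.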

In [None]:
def plot_first_pulse_SDN(df, cell, ax, remove_spikes_from_fitting=True):
    dfcell = df[(df['cellID']==cell)]
    # make a spotmax dict that stores tracemax value for each spot
    spotmax = {}
    expected = np.zeros(dfcell.shape[0])
    observed = np.zeros(dfcell.shape[0])
    for i,row in dfcell.iterrows():
        pattern = row['patternList']
        if row['numSq'] >1:
            continue
        spotloc = pattern_index.patternID[pattern][0]
        start = 11+int(row['probePulseStart']*2e4)+200 # signal columns start at position 11; skip 10 ms (200 samples at 20 kHz) after probePulseStart
        trace = row.iloc[start:start+600] # 30 ms window (600 samples)
        tracemax = np.max(trace)
        spotmax[pattern] = tracemax

    for i in range(dfcell.shape[0]):
        row = dfcell.iloc[i,:]
        numSq = row['numSq']
        patternList = int(row['patternList']) # e.g. pattern 48 -> spots [197, 255, 347, 401, 439]
        spotloclist = pattern_index.patternID[patternList]
        # expected response = linear sum of the single-spot peak responses
        for spotloc in spotloclist:
            pattern = pattern_index.get_patternID([spotloc])
            expected[i] += spotmax[pattern]

        # calculate observed (for now only for the probe pulse)
        start = 11+int(row['probePulseStart']*2e4)+200 # signal columns start at position 11; skip 10 ms (200 samples at 20 kHz) after probePulseStart
        trace = row.iloc[start:start+600] # 30 ms window (600 samples)
        observed[i] = np.max(trace)

    # add expected and observed columns to a copy of the metadata columns
    dfcell_short = dfcell.iloc[:,:11].copy()
    dfcell_short['expected'] = expected
    dfcell_short['observed'] = observed

    # scatterplot between expected and observed
    sns.scatterplot(data=dfcell_short[dfcell_short['numSq']>1], x='expected', y='observed', hue='patternList', size='numSq', sizes=[200,200], palette='deep', ax=ax, edgecolor='black', alpha=0.5, legend=False)
    # delete legend
    ax.legend([],[], frameon=False)

    # fit various equations; optionally exclude sweeps with observed peaks >= 25 mV (treated as spikes)
    if remove_spikes_from_fitting:
        x = dfcell_short[(dfcell_short['numSq']>1) & (dfcell_short['observed']<25) ]['expected']
        y = dfcell_short[(dfcell_short['numSq']>1) & (dfcell_short['observed']<25) ]['observed']
    else:
        x = dfcell_short[(dfcell_short['numSq']>1) ]['expected']
        y = dfcell_short[(dfcell_short['numSq']>1) ]['observed']

    # sorted expected values, so that the fitted curves are drawn left to right
    xs = np.sort(x.to_numpy())

    # linear fit
    slope, intercept, r_value, p_value, std_err = stats.linregress(x,y)
    ax.plot(xs, intercept + slope*xs, 'cornflowerblue', label='Linear fit',)

    # fit di_observed to the data
    popt, _ = curve_fit(di_observed, x, y)
    ax.plot(xs, di_observed(xs, *popt), 'mediumblue', label=r'Divisive Inhibition fit: $\beta$ ='+f'{popt[0]:.2f}')

    # fit sdn_observed to the data
    popt, _ = curve_fit(sdn_observed, x, y)
    ax.plot(xs, sdn_observed(xs, *popt), 'darkviolet', label=r'SDN fit: $\gamma$ ='+f'{popt[0]:.2f}')

    # show labels
    ax.legend()

    return dfcell_short, ax
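`plot_first_pulse_SDN` takes the expected response for a multi-spot pattern as the linear sum of the peak responses to its constituent single spots, each peak measured 10 ms to 40 ms after the probe pulse (an offset of 200 samples and a width of 600 samples at 20 kHz). A minimal sketch of both steps; the spot-to-peak dictionary is a hypothetical stand-in for the `spotmax` lookup via `pattern_index`, and `peak_in_window` and `expected_from_spots` are illustrative helpers.

```python
import numpy as np

FS = 2e4  # sampling rate (Hz), matching the 2e4 factors in the notebook

def peak_in_window(trace, t_start, offset_ms=10.0, width_ms=30.0, fs=FS):
    """Max of `trace` in [t_start + offset, t_start + offset + width)."""
    # round() rather than int() avoids losing a sample to floating-point truncation
    i0 = round(t_start * fs) + round(offset_ms * fs / 1000)
    i1 = i0 + round(width_ms * fs / 1000)
    return np.max(trace[i0:i1])

def expected_from_spots(pattern_spots, spot_peak):
    """Linear sum of single-spot peak responses for each spot in a pattern."""
    return sum(spot_peak[s] for s in pattern_spots)

# synthetic 1 s trace with a single bump 15 ms after a probe pulse at t = 0.1 s
trace = np.zeros(20000)
trace[round(0.115 * FS)] = 3.0
peak = peak_in_window(trace, t_start=0.1)

# hypothetical pattern made of three spots with known single-spot peaks
spot_peak = {197: 1.5, 255: 0.5, 347: 2.0}
expected = expected_from_spots([197, 255, 347], spot_peak)
```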


In [None]:
fig7, ax7 = plt.subplots(4,3, figsize=(20,15), layout='constrained')
ax7 = ax7.flatten()
remove_spikes_from_fitting = True
for i,cell in enumerate(cell_with_1sq_data):
    try:
        dfcell_short, ax7[i] = plot_first_pulse_SDN(data, cell, ax7[i],remove_spikes_from_fitting=remove_spikes_from_fitting)
        ax7[i].set_title(f'Cell {cell}')
        ax7[i].set_xlabel('Expected (mV)')
        ax7[i].set_ylabel('Observed (mV)')
        # ax7[i].set_xlim([0,25])
        # ax7[i].set_ylim([0,25])
    except Exception as e:
        print(f'Cell {cell} failed: {e}')
# save fig7
if remove_spikes_from_fitting:
    figure_name = 'Figure7_CC_SDN_first_pulse_withoutfittingspikes.png'
else:
    figure_name = 'Figure7_CC_SDN_first_pulse.png'
fig7.savefig(paper_figure_export_location / (figure_name), dpi=300, bbox_inches='tight')
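`plot_first_pulse_SDN` shows the fitted $\gamma$ in the legend but does not return it. To compare cells, one option is to refit $\gamma$ per cell from the `expected` and `observed` columns. A minimal sketch, assuming a long table with `cellID`, `numSq`, `expected` and `observed` columns like `dfcell_short`; `fit_gamma_per_cell` is a hypothetical helper and the data are synthetic.

```python
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

def sdn_observed(expected, gamma):
    return (gamma * expected) / (gamma + expected)

def fit_gamma_per_cell(df, spike_threshold=25.0):
    """Fit SDN gamma per cell on multi-spot rows below a spike threshold (mV)."""
    gammas = {}
    for cell, g in df.groupby('cellID'):
        g = g[(g['numSq'] > 1) & (g['observed'] < spike_threshold)]
        x, y = g['expected'].to_numpy(), g['observed'].to_numpy()
        if len(x) < 2:
            gammas[cell] = np.nan  # too few points to fit
            continue
        popt, _ = curve_fit(sdn_observed, x, y, p0=[y.max()])
        gammas[cell] = popt[0]
    return pd.Series(gammas, name='gamma')

# two synthetic cells generated with gamma = 8 and gamma = 20
x = np.linspace(2, 30, 10)
demo = pd.concat([
    pd.DataFrame({'cellID': 1, 'numSq': 5, 'expected': x, 'observed': sdn_observed(x, 8.0)}),
    pd.DataFrame({'cellID': 2, 'numSq': 15, 'expected': x, 'observed': sdn_observed(x, 20.0)}),
], ignore_index=True)
gamma_by_cell = fit_gamma_per_cell(demo)
```

The returned Series can be plotted or tabulated directly, and a failed fit for one cell shows up as `NaN` instead of stopping the loop.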

In [None]:
# For each cell, for each pattern, combine the mean trace for each constituent 1sq spot
# get peak and other responses for each trace
# compare expected vs observed

def plot_raw_expected_observed(data, cell, ax, mode='raw'):
    dfcell = data[(data['cellID']==cell) & (data['stimFreq']==20)]
    print(dfcell['exptID'])
    dfcell2 = dfcell[dfcell['numSq']!=1]
    # one color per pulse index (1-8), sampled from the viridis colormap
    clr = {1:viridis(0),2:viridis(0.1),3:viridis(0.2),4:viridis(0.3),5:viridis(0.4),6:viridis(0.5),7:viridis(0.6),8:viridis(0.7),}
    for j in range(dfcell2.shape[0]):
        row = dfcell2.iloc[j,:]
        numSq = row['numSq']
        # 1 s of signal (20000 samples at 20 kHz); signal columns start at position 11
        observed = row.iloc[11:20011].to_numpy(dtype=float)
        time = np.linspace(0,1,20000)
        patternList = int(row['patternList']) # e.g. pattern 48 -> spots [197, 255, 347, 401, 439]
        spotloclist = pattern_index.patternID[patternList]

        pt = row['pulseTrainStart']
        # sample indices of the 8 train pulses at 20 Hz
        pulse_times = (2e4* utils.get_pulse_times(8,pt,20)).astype(int)

        # expected trace = sum of the mean single-spot traces
        expected = np.zeros(20000)
        for spotloc in spotloclist:
            pattern = pattern_index.get_patternID([spotloc])
            trace = dfcell[(dfcell['patternList']==pattern)].iloc[:,11:20011].to_numpy()
            expected += np.mean(trace, axis=0)
        # subtract baseline from expected
        expected = expected - np.mean(expected[:4000])

        if mode=='raw':
            # plot observed (green) and expected (red) traces, offset vertically by 10 per pattern
            ax.plot(time, 10*j+observed, color='g', alpha=0.5)
            ax.plot(time, 10*j+expected, color='r', alpha=0.5)
        
        elif mode=='p1-pn':
            # response of the expected trace to the first pulse: one 50 ms interval (1000 samples at 20 Hz)
            first_pulse_response_trace = expected[pulse_times[0]:pulse_times[0]+1000]

            # constructed trace: the first-pulse response added at every pulse time
            constructed = np.zeros(20000)
            for i,p in enumerate(pulse_times):
                c = clr[i+1]
                t0 = int(p)
                t1 = t0 + 1000
                constructed[t0:t1] += first_pulse_response_trace
                pmax_exp = np.max(constructed[t0:t1])
                pmax_obs = np.max(observed[t0:t1])
                ax.scatter(pmax_exp, pmax_obs, s=30, color=c)

        elif mode=='pn-pn':
            # compare the peaks of the expected and observed traces in each pulse window
            for i,p in enumerate(pulse_times):
                c = clr[i+1]
                t0 = int(p)
                t1 = t0 + 1000
                pmax_exp = np.max(expected[t0:t1])
                pmax_obs = np.max(observed[t0:t1])
                ax.scatter(pmax_exp, pmax_obs, s=30, color=c)
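`plot_raw_expected_observed` builds the expected train response by summing the mean single-spot traces, subtracting the mean of the first 4000 samples (200 ms) as baseline, and taking the peak in a 1000-sample (50 ms) window after each pulse. A minimal sketch of those steps on synthetic traces; `pulse_sample_indices` is a hypothetical stand-in for `utils.get_pulse_times`, assuming evenly spaced pulses at `start + k / freq`.

```python
import numpy as np

FS = 20000  # samples per second

def pulse_sample_indices(n_pulses, start_s, freq_hz, fs=FS):
    """Sample index of each pulse, assuming evenly spaced pulses from start_s."""
    return np.round((start_s + np.arange(n_pulses) / freq_hz) * fs).astype(int)

def expected_trace(spot_traces, baseline_samples=4000):
    """Sum single-spot mean traces, then subtract the mean of the baseline period."""
    expected = np.sum(spot_traces, axis=0)
    return expected - np.mean(expected[:baseline_samples])

def per_pulse_peaks(trace, pulse_idx, window=1000):
    """Peak of `trace` in a fixed window after each pulse."""
    return np.array([np.max(trace[p:p + window]) for p in pulse_idx])

# two identical synthetic spot traces: offset of 1.0 plus a unit bump 5 ms after each 20 Hz pulse
pulses = pulse_sample_indices(8, 0.3, 20)
spot = np.ones(FS)
spot[pulses + 100] += 1.0
exp_trace = expected_trace([spot, spot])
peaks = per_pulse_peaks(exp_trace, pulses)
```

Rounding the pulse times before converting to `int` keeps consecutive pulses exactly 1000 samples apart, which the fixed-width windows rely on.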

In [None]:
fig7_4, ax7_4 = plt.subplots(figsize=(10,8), layout='constrained')
plot_raw_expected_observed(data, 3402, ax7_4, mode='raw')

In [None]:
fig7_5, ax7_5 = plt.subplots(figsize=(10,8), layout='constrained')
plot_raw_expected_observed(data, 3402, ax7_5, mode='pn-pn')

In [None]:
fig7_6, ax7_6 = plt.subplots(figsize=(10,8), layout='constrained')
for cell in cell_with_1sq_data:
    plot_raw_expected_observed(data, cell, ax7_6, mode='p1-pn')

In [None]:
# 1st pulse vs nth pulse responses
freq_sweep_datapath =  r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Analysis\parsed_data\all_cells_FreqSweep_combined_short.h5" 
df2 = pd.read_hdf(freq_sweep_datapath, key='data')
df2 = df2[ pd.notnull(df2['peaks_cell'])] # remove all sweeps that had NaNs in analysed params (mostly due to bad pulse detection)
df2 = df2.reset_index(drop=True)

In [None]:
# selected cell with 1sq data
df2 = df2[(df2['cellID'].isin(cell_with_1sq_data))]
# pick every cell and separate into 1sq and >1sq data, plot for each pattern, expected constructed from 1sq data, observed from >1sq data

def plot_expected_observed(data, cell, ax):
    dfcell = data[(data['cellID']==cell) & (data['stimFreq']==20)]
    dfcell1sq = dfcell[dfcell['numSq']==1]

    for i in range(dfcell.shape[0]):
        row = dfcell.iloc[i,:]
        numSq = row['numSq']
        peaks = row['peaks_cell'].tolist()
        patternList = int(row['patternList']) #48	:[197,255,347,401,439]
        spotloclist = pattern_index.patternID[patternList] #[197,255,347,401,439]
        expected = np.zeros((9,))
        for spotloc in spotloclist:
            pattern = pattern_index.get_patternID([spotloc])
            pattern_peaks = dfcell1sq[(dfcell1sq['patternList']==pattern)]['peaks_cell'].to_numpy()
            # average the single-spot peaks across sweeps (one row per sweep)
            pattern_peaks = np.mean(np.vstack(pattern_peaks), axis=0)
            expected += pattern_peaks
        


        


In [None]:
csv_path = r"C:\Users\adity\OneDrive\NCBS\Lab\Projects\EI_Dynamics\Data\23-11-24_GrikAA375\data001.txt"
df = pd.read_csv(csv_path, sep='\t')
# make the coord column as index
df = df.set_index('coord')

In [None]:
# make a fig with two subpanels
fig8, ax8 = plt.subplots(1,2, figsize=(8,8), sharey=True, layout='constrained', gridspec_kw={'width_ratios': [1, 1]})

'''
vals = df_CA3_heatmap[pulse].values
# get index values
idx = (df_CA3_heatmap[pulse].index.get_level_values(0).values) - 1
'''
vals1 = df['field1']
idx1 = (df['field1'].index.get_level_values(0).values) 

vals2 = df['field2']
idx2 = (df['field2'].index.get_level_values(0).values)
# idx2 = (df['coord'].values) #index.get_level_values(0).values)

plot_tools.plot_grid(spot_locs=idx1, spot_values=vals1, ax=ax8[0], cmap='grey', vmin=-2.0, vmax=0)
plot_tools.plot_grid(spot_locs=idx2, spot_values=vals2, ax=ax8[1], cmap='grey', vmin=-2.0, vmax=0)
# # plot field1
# sns.heatmap(field2grid, ax=ax8[0], cmap='grey')
# # plot field2
# sns.heatmap(field1grid, ax=ax8[1], cmap='grey')
# #remove ticks
# ax8[0].set_xticks([])
# ax8[0].set_yticks([])
# ax8[1].set_xticks([])
# ax8[1].set_yticks([])

# # only 1 colorbar
# fig8.colorbar(ax8[0].get_children()[0], ax=ax8, location="top", use_gridspec=False, pad=0.1, shrink=0.5)
# # remove the colorbar from the subplots
# ax8[0].collections[0].colorbar.remove()
# ax8[1].collections[0].colorbar.remove()
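`plot_tools.plot_grid` is not shown in this notebook. As a hedged stand-in, a sketch of turning a `coord`-indexed Series (like `df['field1']`) into a 2D array that `ax.imshow` or `sns.heatmap` can display, assuming a square grid of side 24 (as in the 24x24 spot frame earlier) and row-major spot numbering. Whether `coord` is 0- or 1-based is an assumption exposed as the `one_based` parameter, and `series_to_grid` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def series_to_grid(values, grid=24, one_based=False):
    """Place values indexed by spot coordinate into a grid x grid array (NaN elsewhere)."""
    frame = np.full((grid, grid), np.nan)
    idx = values.index.to_numpy(dtype=int) - (1 if one_based else 0)
    rows, cols = np.divmod(idx, grid)
    frame[rows, cols] = values.to_numpy(dtype=float)
    return frame

# synthetic field responses at three spots (values are illustrative only)
field1 = pd.Series([-1.5, -0.2, -2.0], index=pd.Index([1, 25, 576], name='coord'))
grid1 = series_to_grid(field1, one_based=True)
```

Unstimulated spots stay `NaN`, so they render as blank rather than as zero response when the array is shown with a fixed `vmin`/`vmax`.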


In [None]:
vals1

Fig 5 comment from Upi:

Show that the variance in the field response is smaller than the variance in the VC currents (multivariate ANOVA); use this as a substitute for CA3 recordings.