In [None]:
'''
Author: Conor Lane, February 2024
Contact: conor.lane1995@gmail.com

Analysis code for calculating the sensitivity of all individual cells, and the changes in sensitivity of matched cells between the pre- and post- recordings, for the randomized stim cohort. 
Sensitivity = the lowest intensity at which a cell shows a significant response to sound-stimulation. 

INPUTS: filepath to the evoked cohort megadicts (collected recordings for each condition from Evoked Cohort)
        z_thresh - minimum z-score threshold over which we declare a significant response (default is 4).
'''

In [None]:
import os
import pickle
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as ss
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import kstest
from scipy.stats import mannwhitneyu
from sklearn.metrics import pairwise_distances

# Minimum z-score for a response to count as significant; the sensitivity
# functions further down read this module-level name.
z_thresh = 4
# Folder that holds the pickled megadicts (one .pkl per drug condition / time point).
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
filepath = "F:/Two-Photon/Psilocybin Project/Evoked Cohort Mice/megadicts"

LOAD RANDOMIZED STIM COHORT DICTS:

In [None]:
# Map each expected megadict filename to the notebook variable that holds it.
file_variable_mapping = {
    'saline_pre_dict.pkl': 'saline_pre',
    'saline_post_dict.pkl': 'saline_post',
    'psilo_pre_dict.pkl': 'psilo_pre',
    'psilo_post_dict.pkl': 'psilo_post'
}

# Every condition starts as an empty dict so the variable exists even if its file is absent.
_loaded_megadicts = {variable_name: {} for variable_name in file_variable_mapping.values()}

# List the folder once; only exact filename matches are loaded.
_available_files = set(os.listdir(filepath))

for filename, variable_name in file_variable_mapping.items():
    if filename not in _available_files:
        # Make a missing file visible instead of silently analysing an empty dict.
        warnings.warn(
            f"Megadict '{filename}' not found in '{filepath}'; '{variable_name}' is left empty."
        )
        continue

    file_path = os.path.join(filepath, filename)
    # NOTE: pickle.load can execute arbitrary code - only load megadicts from a trusted source.
    with open(file_path, 'rb') as file:
        _loaded_megadicts[variable_name] = pickle.load(file)

# Explicit assignments (instead of writing through globals()) keep the data flow readable.
saline_pre = _loaded_megadicts['saline_pre']
saline_post = _loaded_megadicts['saline_post']
psilo_pre = _loaded_megadicts['psilo_pre']
psilo_post = _loaded_megadicts['psilo_post']

GENERAL FUNCTIONS:

In [None]:
# Helpers for comparing two sets of values as a double-bar graph of relative frequencies.
# plot_comparison INPUTS: values_set1 and values_set2 - the two sets of values to compare
#         title - title of graph as string
#         label_1 - first bar label e.g. Pre-Saline
#         label_2 - second bar label

def calculate_relative_frequencies(values, unique_values):
    """Return the fraction of `values` equal to each entry of `unique_values`.

    Parameters
    ----------
    values : array-like
        The observations (a list or a NumPy array).
    unique_values : iterable
        The values whose relative frequency is wanted, in output order.

    Returns
    -------
    numpy.ndarray
        One relative frequency per entry of `unique_values`. All zeros when
        `values` is empty.
    """
    # Coerce to an array: for a plain Python list, `values == value` is a single
    # list-vs-scalar comparison (always False), which made every frequency 0.
    values = np.asarray(values)
    total_values = len(values)

    # Avoid a 0/0 division (NaN plus RuntimeWarning) when there are no observations.
    if total_values == 0:
        return np.zeros(len(unique_values), dtype=float)

    frequencies = np.array([np.sum(values == value) / total_values for value in unique_values])
    return frequencies

def plot_comparison(values_set1, values_set2,title,label_1,label_2):
    """Draw a side-by-side bar chart of the relative frequencies of two value sets.

    values_set1, values_set2 - the two sets of values to compare
    title                    - title of the graph (string)
    label_1, label_2         - legend labels for the first and second set

    NOTE(review): the x-axis label and the halving of the tick values refer to
    bandwidth in octaves, whereas this notebook's analysis is about sensitivity
    (lowest response intensity) - confirm the labelling before using it here.
    """
    width = 0.35  # width of a single bar

    # Every distinct value that occurs in either set, sorted.
    categories = np.unique(
        np.concatenate((np.unique(values_set1), np.unique(values_set2)))
    )
    n_categories = len(categories)

    # Tick labels: each category value divided by two (0 is shown as a plain 0).
    tick_labels = []
    for category in categories:
        if category != 0:
            tick_labels.append(category / 2)
        else:
            tick_labels.append(0)

    # Relative frequency of each category within each set.
    freqs_first = calculate_relative_frequencies(values_set1, categories)
    freqs_second = calculate_relative_frequencies(values_set2, categories)

    # The second set is drawn one bar-width to the right of the first.
    left_positions = np.arange(n_categories)
    right_positions = left_positions + width

    plt.bar(left_positions, freqs_first, color='blue', width=width, edgecolor='black', label=label_1)
    plt.bar(right_positions, freqs_second, color='orange', width=width, edgecolor='black', label=label_2)

    plt.xlabel('Bandwidth (Octaves)')
    plt.ylabel('Probability')
    plt.title(title)

    # Centre each tick between the two bars of its category.
    tick_positions = [position + width / 2 for position in range(n_categories)]
    plt.xticks(tick_positions, tick_labels)

    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Create an array of all the matched cells that are sound-responsive in both recordings.  Each row is a matched cell pair.
def get_consistently_responsive_cells(dict_pre,dict_post,sub_dict_pre,sub_dict_post):
    """Return the matched cell pairs that are flagged active in both recordings.

    INPUTS:  dict_pre, dict_post         - pre- and post- megadicts for a given drug condition.
             sub_dict_pre, sub_dict_post - keys of the specific recording to use in each megadict.
             Code is written to be used with the matched cells bandwidth functions.
    OUTPUTS: (npairs x 2) array containing the matched cell pairs that were responsive
             in both recordings (column 0 = pre cell, column 1 = post cell).
    """
    recording_pre = dict_pre[sub_dict_pre]
    recording_post = dict_post[sub_dict_post]

    # The matched-pair array is read from the first cell entry of the post recording.
    first_post_cell = next(iter(recording_post))
    matched_cells = recording_post[first_post_cell]['matched_cells']
    pre_ids = matched_cells[:, 0]
    post_ids = matched_cells[:, 1]

    # Cells of each recording that belong to a matched pair and are flagged active.
    responsive_pre = [
        cell for cell in recording_pre
        if cell in pre_ids and recording_pre[cell]['active'] == True
    ]
    responsive_post = [
        cell for cell in recording_post
        if cell in post_ids and recording_post[cell]['active'] == True
    ]

    # Keep only the rows whose pre cell AND post cell are both responsive.
    both_responsive = np.isin(pre_ids, responsive_pre) & np.isin(post_ids, responsive_post)

    return matched_cells[both_responsive]


In [None]:
def calculate_cdf(data):
    """Build the empirical frequency / PDF / CDF table of `data`.

    Returns a DataFrame with one row per distinct value (ascending) and the
    columns 'value', 'frequency' (count), 'pdf' (count / total) and
    'cdf' (running sum of 'pdf').
    """
    observations = pd.DataFrame(pd.Series(data, name='value'))

    # Number of occurrences of each distinct value.
    counts = observations.groupby('value')['value'].count()
    stats_df = counts.to_frame(name='frequency')

    # Normalise the counts to probabilities, then accumulate them.
    total = stats_df['frequency'].sum()
    stats_df['pdf'] = stats_df['frequency'] / total
    stats_df['cdf'] = stats_df['pdf'].cumsum()

    # Move the distinct values out of the index into a regular 'value' column.
    return stats_df.reset_index()

In [None]:
def plot_cdf(pre,post,title,label_1,label_2,intensities=(35, 50, 65, 80)):
    """Plot the empirical CDFs of two sets of lowest-response-intensity indices.

    pre, post        - the two sets of values to compare (passed to calculate_cdf)
    title            - title of the graph (string)
    label_1, label_2 - legend labels for `pre` and `post` respectively
    intensities      - tick labels (dB) for x positions 0..len(intensities)-1;
                       defaults to the four levels previously hard-coded here.
    """
    cdf_frames = [calculate_cdf(pre), calculate_cdf(post)]
    series_labels = [label_1, label_2]

    plt.figure()

    # Distinct loop-variable names: the original loop rebound the `label` list it was zipping.
    for frame, series_label in zip(cdf_frames, series_labels):
        plt.plot(frame['value'], frame['cdf'], label=series_label)

    plt.title(title,pad=10)
    # x position i is relabelled with intensities[i].
    plt.xticks(range(len(intensities)), list(intensities))
    plt.xlabel("Lowest Response Intensity (dB)")
    plt.ylabel("Cumulative Probability")
    plt.grid()
    plt.legend()
    plt.tight_layout()
    plt.show()

FUNCTIONS ALL CELLS:

In [None]:
def get_sensitivity_all_cells(dict, z_threshold=None):
    """Collect the sensitivity of every active cell in a megadict.

    For each cell flagged 'active', the sensitivity is the index of the first
    column of its 'peak_tuning' array that contains at least one value strictly
    above the threshold. Active cells with no value above the threshold
    contribute nothing to the output.

    Parameters
    ----------
    dict : mapping
        Megadict: recording key -> cell key -> cell entry with 'active' and
        'peak_tuning'. (The parameter name shadows the builtin; it is kept so
        existing callers keep working.)
    z_threshold : float, optional
        Significance threshold. Defaults to the notebook-level `z_thresh`.

    Returns
    -------
    list of float
        One column index per cell that has a supra-threshold response.

    NOTE(review): this function uses a strict `>` comparison, while
    get_lowest_response_intensity_matched uses `>=` - confirm which is intended.
    """
    recordings = dict
    threshold = z_thresh if z_threshold is None else z_threshold

    sensitivity_all = []

    for cells in recordings.values():
        for cell_data in cells.values():
            # Explicit comparison kept so the flag is matched exactly as before.
            if cell_data['active'] == True:
                tuning_array = np.asarray(cell_data['peak_tuning'])

                # Scan columns in ascending order; record the first with any supra-threshold value.
                for intensity_index in range(tuning_array.shape[1]):
                    if np.any(tuning_array[:, intensity_index] > threshold):
                        sensitivity_all.append(float(intensity_index))
                        break

    return sensitivity_all

FUNCTIONS MATCHED CELLS

In [None]:
def get_lowest_response_intensity_matched(dict_pre,dict_post,z_threshold=None):
    """Lowest response intensity index for every consistently responsive matched pair.

    Recordings of `dict_pre` and `dict_post` are paired in iteration order
    (zip stops at the shorter of the two). For each pair returned by
    get_consistently_responsive_cells, the index of the first 'peak_tuning'
    column containing a value >= the threshold is stored for the pre cell
    (column 0) and the post cell (column 1).

    Parameters
    ----------
    dict_pre, dict_post : mapping
        Pre- and post- megadicts for one drug condition.
    z_threshold : float, optional
        Significance threshold. Defaults to the notebook-level `z_thresh`.

    Returns
    -------
    numpy.ndarray
        Float array of shape (n_pairs, 2). A cell with no value at or above the
        threshold is stored as NaN (it used to be left at 0, which could not be
        told apart from a response at the lowest intensity). An empty (0, 2)
        array is returned when there are no recordings.
    """
    threshold = z_thresh if z_threshold is None else z_threshold

    def _lowest_response_index(tuning):
        # Index of the first column with any value >= threshold, or NaN if there is none.
        tuning = np.asarray(tuning)
        for intensity_index in range(tuning.shape[1]):
            if np.any(tuning[:, intensity_index] >= threshold):
                return float(intensity_index)
        return np.nan

    lowest_response_intensity_all = []

    for sub_dict_pre, sub_dict_post in zip(dict_pre.keys(), dict_post.keys()):

        # Get the array of consistently responsive matched cell pairs for the pre- and post-conditions.
        coactive = get_consistently_responsive_cells(dict_pre, dict_post, sub_dict_pre, sub_dict_post)

        # Float array pre-filled with NaN, so "no response found" stays distinguishable
        # and the result does not inherit the dtype of the cell-pair array.
        lowest_response_intensity = np.full(coactive.shape, np.nan, dtype=float)

        for i, (cell_1, cell_2) in enumerate(zip(coactive[:, 0], coactive[:, 1])):
            lowest_response_intensity[i, 0] = _lowest_response_index(
                dict_pre[sub_dict_pre][cell_1]['peak_tuning'])
            lowest_response_intensity[i, 1] = _lowest_response_index(
                dict_post[sub_dict_post][cell_2]['peak_tuning'])

        lowest_response_intensity_all.append(lowest_response_intensity)

    # np.concatenate raises on an empty list; return an empty result instead.
    if not lowest_response_intensity_all:
        return np.empty((0, 2), dtype=float)

    return np.concatenate(lowest_response_intensity_all)

In [None]:
# Lowest response intensities of the matched, consistently responsive cells (saline condition).
# NOTE(review): `low` is not referenced again in the visible part of this notebook -
# this looks like a scratch / smoke-test call; confirm whether the cell is still needed.
low = get_lowest_response_intensity_matched(saline_pre,saline_post)

ANALYSIS: ALL CELLS

In [None]:
# Lowest response intensity of every active cell - saline recordings (pre, post)
sensitivity_pre_saline, sensitivity_post_saline = [
    get_sensitivity_all_cells(recordings) for recordings in (saline_pre, saline_post)
]

# Lowest response intensity of every active cell - psilocybin recordings (pre, post)
sensitivity_pre_psilo, sensitivity_post_psilo = [
    get_sensitivity_all_cells(recordings) for recordings in (psilo_pre, psilo_post)
]

In [None]:
# Compare all-cell sensitivity between the post-saline and post-psilocybin recordings.
# `mannwhitneyu` is imported in the import cell at the top of the notebook.
# The alternative is stated explicitly because the default differs between SciPy versions.
statistic, p_value = mannwhitneyu(
    sensitivity_post_saline, sensitivity_post_psilo, alternative='two-sided'
)

print(f"Mann-Whitney U statistic: {statistic}")
print(f"P-value: {p_value}")

In [None]:
# CDF of all-cell sensitivity: post-saline vs post-psilocybin.
title, label_1, label_2 = (
    'Cell Sensitivity Post-Saline and Post-Psilocybin',
    'Post-Saline',
    'Post-Psilocybin',
)
plot_cdf(sensitivity_post_saline, sensitivity_post_psilo, title, label_1, label_2)

In [None]:
# CDF of all-cell sensitivity: pre-saline vs post-saline.
# The title now names the two data sets actually plotted.
title = 'Cell Sensitivity Pre-Saline and Post-Saline'
label_1 = 'Pre-Saline'
label_2 = 'Post-Saline'
plot_cdf(sensitivity_pre_saline,sensitivity_post_saline,title,label_1,label_2)

In [None]:
# CDF of all-cell sensitivity at baseline: pre-saline vs pre-psilocybin.
# Title and legend labels now describe the data passed to plot_cdf.
# NOTE(review): the previous labels read 'Pre-Psilocybin' / 'Post-Psilocybin'; if a
# pre- vs post-psilocybin comparison was intended, change the data arguments instead.
title = 'Cell Sensitivity Pre-Saline and Pre-Psilocybin'
label_1 = 'Pre-Saline'
label_2 = 'Pre-Psilocybin'
plot_cdf(sensitivity_pre_saline,sensitivity_pre_psilo,title,label_1,label_2)

In [None]:
# Compare all-cell sensitivity between the pre-saline and pre-psilocybin recordings.
# `mannwhitneyu` is imported in the import cell at the top of the notebook.
# The alternative is stated explicitly because the default differs between SciPy versions.
statistic, p_value = mannwhitneyu(
    sensitivity_pre_saline, sensitivity_pre_psilo, alternative='two-sided'
)

print(f"Mann-Whitney U statistic: {statistic}")
print(f"P-value: {p_value}")

In [None]:
# Matched cells: lowest response intensity per consistently responsive pair
# (column 0 = pre recording, column 1 = post recording).
lowest_saline = get_lowest_response_intensity_matched(saline_pre,saline_post)
lowest_psilo = get_lowest_response_intensity_matched(psilo_pre,psilo_post)

# Both plotted series are the post columns, so the legend labels are the two
# post conditions (they previously read 'Pre-Saline' / 'Post-Saline').
title = 'Cell Sensitivity Post-Saline and Post-Psilocybin'
label_1 = 'Post-Saline'
label_2 = 'Post-Psilocybin'
plot_cdf(lowest_saline[:,1],lowest_psilo[:,1],title,label_1,label_2)

In [None]:
#look at changes in tuning specifically for the non-coactive cells