In [None]:
'''
Author: Conor Lane, March 2024
contact: conor.lane1995@gmail.com

Analysis code for calculating the bandwidths and shifts in bandwidths, both for all the active cells at a population level,
and a subset of matched cells that were sound-responsive in both recordings. 

INPUTS: filepath to the evoked cohort megadicts (collected recordings for each condition from Evoked Cohort)
        z_thresh - minimum z-score threshold over which we declare a significant response (default is 4)
'''

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu

In [None]:
# INPUTS:
# filepath - folder that is scanned for the evoked cohort megadict pickle files
#            (saline_pre_dict.pkl, saline_post_dict.pkl, psilo_pre_dict.pkl, psilo_post_dict.pkl).
# z_thresh - minimum z-score a response must reach to be declared significant; read as a
#            notebook-level variable by get_bandwidth_all_cells.

filepath = "F:/Two-Photon/Psilocybin Project/Evoked Cohort Mice/megadicts"
z_thresh = 4

LOAD EVOKED COHORT DICTS:

In [None]:
# Dictionary to map each expected megadict filename to the notebook variable it populates
file_variable_mapping = {
    'saline_pre_dict.pkl': 'saline_pre',
    'saline_post_dict.pkl': 'saline_post',
    'psilo_pre_dict.pkl': 'psilo_pre',
    'psilo_post_dict.pkl': 'psilo_post'
}

# Start every condition as an empty dictionary so that a missing file leaves its
# variable defined (as an empty dict) rather than undefined.
loaded_megadicts = {variable_name: {} for variable_name in file_variable_mapping.values()}
found_files = set()

# Iterate through files in megadict folder and load the ones listed in the mapping
for filename in os.listdir(filepath):
    if filename in file_variable_mapping:
        file_path = os.path.join(filepath, filename)
        with open(file_path, 'rb') as file:
            # NOTE: pickle.load can execute arbitrary code - only load megadicts from a trusted source.
            loaded_megadicts[file_variable_mapping[filename]] = pickle.load(file)
        found_files.add(filename)

# Report any expected file that was not found, so an empty condition is never silent
missing_files = sorted(set(file_variable_mapping) - found_files)
if missing_files:
    print(f"WARNING: megadict file(s) not found in {filepath}: {missing_files}")

# Explicit assignments (rather than writing through globals()) keep the variable
# names visible to readers and to static analysis tools.
saline_pre = loaded_megadicts['saline_pre']
saline_post = loaded_megadicts['saline_post']
psilo_pre = loaded_megadicts['psilo_pre']
psilo_post = loaded_megadicts['psilo_post']

GENERAL FUNCTIONS:

In [None]:
# Helpers for plotting a double-bar graph of relative frequencies for two chosen bandwidth arrays.
# INPUTS (plot_comparison): values_set1 and values_set2 - the two sets of bandwidths to compare
#         title - title of graph as string
#         label_1 - first bar label e.g. Pre-Saline
#         label_2 - second bar label e.g. Post-Saline

def calculate_relative_frequencies(values, unique_values):
    """Return the fraction of entries in `values` that equal each entry of `unique_values`.

    INPUTS:  values - sequence (list or NumPy array) of observed values.
             unique_values - the values to compute a relative frequency for; sets the output order.
    RETURNS: NumPy float array with one relative frequency per entry of unique_values.
             All zeros when `values` is empty.
    """
    # Convert up front so the elementwise comparison below also works for plain Python
    # lists (a list compared with a Python scalar is a single False, not a mask).
    values = np.asarray(values)
    total_values = values.size

    # Nothing observed: every frequency is zero (avoids dividing by zero).
    if total_values == 0:
        return np.zeros(len(unique_values))

    frequencies = np.array([np.sum(values == value) / total_values for value in unique_values])
    return frequencies

def plot_comparison(values_set1, values_set2,title,label_1,label_2,xlabel='Bandwidth (Octaves)'):
    """Plot a double-bar graph comparing the relative frequencies of two sets of values.

    INPUTS:  values_set1 and values_set2 - the two sets of values (e.g. bandwidths) to compare
             title - title of graph as string
             label_1 - legend label for the first set e.g. Pre-Saline
             label_2 - legend label for the second set
             xlabel - x-axis label; defaults to 'Bandwidth (Octaves)' and can be overridden
                      when the values are not raw bandwidths (e.g. changes in bandwidth)
    RETURNS: None. Draws with pyplot and calls plt.show().

    The x tick labels are the unique values divided by two (values are given in number of
    frequencies, and the notebook treats a bandwidth of 1 as 0.5 octaves).
    """
    # Work on arrays so the frequency calculation does not depend on the input container type
    values_set1 = np.asarray(values_set1)
    values_set2 = np.asarray(values_set2)

    # Every value present in either set gets one pair of bars (np.unique also sorts them)
    unique_values = np.unique(np.concatenate((np.unique(values_set1), np.unique(values_set2))))

    # Halve the unique values except for 0 to build the tick labels
    halved_unique_values = [value / 2 if value != 0 else 0 for value in unique_values]

    # Relative frequency of each unique value within each set
    rel_freq_set1 = calculate_relative_frequencies(values_set1, unique_values)
    rel_freq_set2 = calculate_relative_frequencies(values_set2, unique_values)

    # Bar width, and x positions of the first and second bar of every pair
    bar_width = 0.35
    r1 = np.arange(len(unique_values))
    r2 = r1 + bar_width

    # Create the bar plot
    plt.bar(r1, rel_freq_set1, color='blue', width=bar_width, edgecolor='black', label=label_1)
    plt.bar(r2, rel_freq_set2, color='orange', width=bar_width, edgecolor='black', label=label_2)

    # Add labels and title
    plt.xlabel(xlabel)
    plt.ylabel('Probability')
    plt.title(title)

    # Centre each tick between the two bars of its pair and show the halved values
    plt.xticks(r1 + bar_width / 2, halved_unique_values)

    # Add legend
    plt.legend()
    plt.tight_layout()

    # Show plot
    plt.show()


In [None]:
# Calculate the bandwidth of the cell using the half-max criterion, i.e. the continuous range of frequencies around the max response
# that show a response of at least 50% of the maximum response.
# INPUTS:  Tuning array for the specified intensity (e.g. BF_column_1 in the get_bandwidth_all_cells func).

def count_above_half_max(array):
    """Count the contiguous run of entries around the maximum that are >= half the maximum.

    INPUTS:  array - 1-D sequence (list or NumPy array) of responses, one per frequency.
    RETURNS: Number of entries in the unbroken run that contains the maximum and in which
             every entry is >= max / 2. Returns 0 for an empty input, and 0 when the
             maximum is negative (the peak itself is then below max / 2).
    """
    # Accept plain lists as well as NumPy arrays
    array = np.asarray(array)
    if array.size == 0:
        return 0

    # Locate the peak and compute the half-max threshold once
    max_index = int(array.argmax())
    half_max = array[max_index] / 2

    count = 0

    # Start from the index of the maximum value (included) and iterate downwards
    index = max_index
    while index >= 0 and array[index] >= half_max:
        count += 1
        index -= 1

    # Start from the index after the maximum value and iterate upwards
    index = max_index + 1
    while index < len(array) and array[index] >= half_max:
        count += 1
        index += 1

    return count

ALL CELLS FUNCTIONS:

In [None]:
# Calculates the bandwidth of every responsive cell in the cohort for a given condition. 
# Inputs:  dict - the given megadict you want to extract bandwidths from.
#          Intensity - The intensity of sound stim you want in dB (0 = 35, 1 = 50, 2 = 65, 3 = 80)
# Returns: List containing the bandwidth value (in no. of frequencies so bandwidth of 1 = 0.5 octaves)

def get_bandwidth_all_cells(dict,intensity):
    """Collect the half-max bandwidth of every active cell that responds at one intensity.

    INPUTS:  dict - megadict keyed by recording; each recording maps cell keys to entries
                    holding an 'active' flag and a 'peak_tuning' array.
             intensity - column index into 'peak_tuning' selecting the sound intensity.
    RETURNS: List with one bandwidth (from count_above_half_max) per qualifying cell.

    Reads the notebook-level z_thresh: a cell only qualifies when at least one value in
    its selected tuning column is >= z_thresh.
    """
    bandwidths = []

    for recording_key in dict.keys():
        recording = dict[recording_key]

        for cell_key in recording:
            cell_info = recording[cell_key]

            # Cells not flagged as active are skipped
            if not (cell_info['active'] == True):
                continue

            # Tuning values across frequencies at the requested intensity
            intensity_column = cell_info['peak_tuning'][:,intensity]

            # Only cells with a response reaching the z-score threshold contribute
            if any(value >= z_thresh for value in intensity_column):
                bandwidths.append(count_above_half_max(intensity_column))

    return bandwidths

MATCHED CELLS FUNCTIONS:

In [None]:
# Create an array of all the matched cells that are sound-responsive in both recordings.  Each row is a matched cell pair. 
# INPUTS:  pre- and post- megadicts for a given drug condition. 
#          The specific recording to get matched cells for in the sub-dictionaries of pre- and post.
#          Code is written to be used with the matched cells bandwidth functions. 
# OUTPUTS: (npairs x 2) array containing the matched cell pairs that were responsive in both recordings. 

def get_consistently_responsive_cells(dict_pre,dict_post,sub_dict_pre,sub_dict_post):
    """Return the matched cell pairs whose cells are flagged active in both recordings.

    INPUTS:  dict_pre, dict_post - pre- and post- megadicts for a given drug condition.
             sub_dict_pre, sub_dict_post - keys of the recording to use in each megadict.
    RETURNS: The rows of the 'matched_cells' array (column 0 = pre cell, column 1 = post
             cell) for which the pre cell is active in the pre recording AND the post
             cell is active in the post recording.
    """
    recording_pre = dict_pre[sub_dict_pre]
    recording_post = dict_post[sub_dict_post]

    # The array of matched cell pairs is read from the first cell key of the post recording.
    matched_cells = recording_post[next(iter(recording_post))]['matched_cells']
    matched_ids_pre = matched_cells[:,0]
    matched_ids_post = matched_cells[:,1]

    # Cells of the pre recording that take part in a match and are flagged active.
    matched_responsive_1 = [cell for cell in recording_pre
                            if cell in matched_ids_pre and recording_pre[cell]['active'] == True]

    # Same operation for the post recording.
    matched_responsive_2 = [cell for cell in recording_post
                            if cell in matched_ids_post and recording_post[cell]['active'] == True]

    # Boolean masks over the rows of matched_cells: is the pre / post member active?
    indices_col1 = np.isin(matched_ids_pre, matched_responsive_1)
    indices_col2 = np.isin(matched_ids_post, matched_responsive_2)

    # Keep only the rows where both members of the pair are active.
    coactive = matched_cells[np.logical_and(indices_col1, indices_col2)]

    return coactive


In [None]:
# Extract the bandwidths of the matched cell pairs across all recordings of that condition.  
# INPUTS:  pre- and post-drug dictionaries to extract bandwidths.
#          intensity - The intensity of sound stim you want in dB (0 = 35, 1 = 50, 2 = 65, 3 = 80)
# OUTPUTS: (npairs x 2) array containing the pre- and post-bandwidths of every cell pair.  

def get_bandwidth_matched_cells(dict_pre,dict_post,intensity):
    """Extract pre- and post-drug bandwidths for every consistently responsive matched pair.

    INPUTS:  dict_pre, dict_post - pre- and post-drug megadicts to extract bandwidths from.
             intensity - column index into 'peak_tuning' selecting the sound intensity.
    RETURNS: (npairs x 2) integer array; column 0 is the pre bandwidth and column 1 the
             post bandwidth of each pair, pooled over all recordings. An empty (0 x 2)
             array is returned when there are no recordings.

    NOTE(review): recordings are paired by position via zip over the two key sequences, so
    this assumes both megadicts list their recordings in the same order; zip stops at the
    shorter of the two - confirm the megadicts are built that way.
    NOTE(review): unlike get_bandwidth_all_cells, no z_thresh check is applied at the
    selected intensity here - confirm this is intended.
    """
    matched_bandwidths = []

    for sub_dict_pre, sub_dict_post in zip(dict_pre.keys(),dict_post.keys()):

        # Matched cell pairs that are active in both the pre- and post-recording.
        coactive = get_consistently_responsive_cells(dict_pre,dict_post,sub_dict_pre,sub_dict_post)

        # One (pre, post) row per pair. The shape and a signed integer dtype are set
        # explicitly instead of being inherited from the cell-ID array, so later
        # post - pre differences cannot wrap around if the IDs are stored unsigned.
        bandwidth = np.zeros((len(coactive), 2), dtype=int)

        # i keeps track of the pair's row position.
        for i, (cell_1, cell_2) in enumerate(zip(coactive[:,0], coactive[:,1])):
            tuning_array_1 = dict_pre[sub_dict_pre][cell_1]['peak_tuning']
            tuning_array_2 = dict_post[sub_dict_post][cell_2]['peak_tuning']

            # Responses across frequencies at the selected intensity.
            column_pre = tuning_array_1[:,intensity]
            column_post = tuning_array_2[:,intensity]

            # Calculate the bandwidths from each tuning column.
            bandwidth[i,0] = count_above_half_max(column_pre)
            bandwidth[i,1] = count_above_half_max(column_post)

        matched_bandwidths.append(bandwidth)

    # np.concatenate raises on an empty list, so handle the no-recordings case explicitly.
    if not matched_bandwidths:
        return np.zeros((0, 2), dtype=int)

    return np.concatenate(matched_bandwidths)

In [None]:
# Calculate the shift in bandwidth in matched cells as post minus pre (positive = bandwidth increased after treatment).

def get_bandwidth_shift(matched_bandwidths):
    """Return the per-pair change in bandwidth, computed as post minus pre.

    INPUTS:  matched_bandwidths - (npairs x 2) array-like; column 0 is pre, column 1 is post.
    RETURNS: 1-D array of length npairs holding column 1 - column 0 (positive values mean
             the bandwidth increased).
    """
    bandwidths = np.asarray(matched_bandwidths)

    # Unsigned integer arrays wrap around on negative differences (e.g. 1 - 3 -> 254 for
    # uint8), so promote them to a signed type before subtracting.
    if bandwidths.dtype.kind == 'u':
        bandwidths = bandwidths.astype(np.int64)

    bandwidth_change = bandwidths[:,1] - bandwidths[:,0]

    return bandwidth_change

ANALYSIS 1: BANDWIDTHS ALL CELLS

In [None]:
# Bandwidths of all responsive cells pre- and post-saline at intensity index 2 (65 dB).
# Swap in other megadicts or another intensity index as needed.

saline_pre_bandwidth_65, saline_post_bandwidth_65 = (
    get_bandwidth_all_cells(megadict, 2) for megadict in (saline_pre, saline_post)
)

In [None]:
# Double-bar comparison of the pre- and post-saline bandwidth distributions.
# Change the title and labels to suit the conditions being compared.

title, label_1, label_2 = 'Bandwidth at 65 dB, Pre- and Post-Saline', 'Pre-Saline', 'Post-Saline'
plot_comparison(
    values_set1=saline_pre_bandwidth_65,
    values_set2=saline_post_bandwidth_65,
    title=title,
    label_1=label_1,
    label_2=label_2,
)

ANALYSIS 2: BANDWIDTH MATCHED CELLS SUBSET

In [None]:
# Bandwidths of the matched cells at intensity index 2 (65 dB): column 0 is pre, column 1 is post.

saline_matched_65 = get_bandwidth_matched_cells(
    dict_pre=saline_pre,
    dict_post=saline_post,
    intensity=2,
)

In [None]:
# Double-bar comparison of pre (column 0) vs post (column 1) bandwidths of the matched cells.
title, label_1, label_2 = 'Bandwidth Matched Cells at 65 dB, Pre- and Post-Saline', 'Pre-Saline', 'Post-Saline'
plot_comparison(
    values_set1=saline_matched_65[:,0],
    values_set2=saline_matched_65[:,1],
    title=title,
    label_1=label_1,
    label_2=label_2,
)

ANALYSIS 3: MATCHED CELLS CHANGE IN BANDWIDTH

In [None]:
# Change in bandwidth of the matched cells at intensity index 2 (65 dB),
# for the saline and the psilocybin conditions.

change_bandwidth_saline_65 = get_bandwidth_shift(
    get_bandwidth_matched_cells(dict_pre=saline_pre, dict_post=saline_post, intensity=2)
)
change_bandwidth_psilo_65 = get_bandwidth_shift(
    get_bandwidth_matched_cells(dict_pre=psilo_pre, dict_post=psilo_post, intensity=2)
)


In [None]:
# Double-bar comparison of the bandwidth changes in the saline vs psilocybin conditions.
title, label_1, label_2 = 'Change in Bandwidth at 65 dB, Saline vs Psilocybin', 'Saline', 'Psilocybin'
plot_comparison(
    values_set1=change_bandwidth_saline_65,
    values_set2=change_bandwidth_psilo_65,
    title=title,
    label_1=label_1,
    label_2=label_2,
)

In [None]:
# Not normally distributed, so use a Mann-Whitney U test for a difference between the
# saline and psilocybin bandwidth changes.
# The two-sided alternative is requested explicitly so the reported p-value does not
# depend on the default of the installed SciPy version.
# (mannwhitneyu is imported with the other dependencies at the top of the notebook.)

statistic, p_value = mannwhitneyu(
    change_bandwidth_saline_65,
    change_bandwidth_psilo_65,
    alternative='two-sided',
)

print(f"Mann-Whitney U statistic: {statistic}")
print(f"P-value: {p_value}")