In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec #Useful tool to arrange multiple plots in one figure (https://matplotlib.org/stable/api/_as_gen/matplotlib.gridspec.GridSpec.html)

%matplotlib inline

import platform
platstring = platform.platform()

# Defaults for the data directory name and whether the cache is static (read-only).
data_dirname = 'visual-behavior-neuropixels'
use_static = False
if 'Darwin' in platstring or 'macOS' in platstring:
    # macOS: external drive mounted under /Volumes
    data_root = "/Volumes/Brain2021/"
elif 'Windows' in platstring:
    # Windows (replace with the drive letter of USB drive)
    data_root = "E:/"
elif 'amzn' in platstring:
    # AWS: data lives under /data with a differently named directory and a static cache
    data_root = "/data/"
    data_dirname = 'visual-behavior-neuropixels-data'
    use_static = True
else:
    # Your own Linux machine: point NEUROPIXELS_DATA_ROOT at the mounted drive instead of
    # editing this personal default path.
    data_root = "/home/andrew/Documents/tmp/5-1-24-change-detection"

# An explicit environment variable overrides the platform default on any OS, so the
# notebook can run without editing hard-coded absolute paths.
data_root = os.environ.get("NEUROPIXELS_DATA_ROOT", data_root)

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

# this path determines where downloaded data will be stored
manifest_path = os.path.join(data_root, "manifest.json")
# Project cache backed by the Allen Institute warehouse (from_warehouse); later cells load
# session tables and individual sessions through it.
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

# List the session types in the dataset; later cells filter on 'brain_observatory_1.1'
# and 'functional_connectivity'.
print(cache.get_all_session_types())




# from allensdk.brain_observatory.behavior.behavior_project_cache.\
#     behavior_neuropixels_project_cache \
#     import VisualBehaviorNeuropixelsProjectCache

# from allensdk.brain_observatory.behavior.

# # this path should point to the location of the dataset on your platform
# cache_dir = os.path.join(data_root, data_dirname)
# cache = VisualCodingNeuropixelsProjectCache.from_s3_cache(cache_dir=cache_dir)

# # cache = VisualBehaviorNeuropixelsProjectCache.from_local_cache(
# #             cache_dir=cache_dir, use_static_cache=use_static)




In [None]:
# Full session table, then only the Brain Observatory 1.1 sessions.
sessions = cache.get_session_table()
brain_observatory_type_sessions = sessions.loc[sessions["session_type"].eq("brain_observatory_1.1")]
brain_observatory_type_sessions.columns

In [None]:
# Brain Observatory 1.1 sessions (assign to the same name that is displayed below).
brain_observatory_type_sessions = sessions[sessions["session_type"] == "brain_observatory_1.1"]
brain_observatory_type_sessions.columns

In [None]:
# Brain Observatory 1.1 sessions (assign to the same name that is displayed below).
brain_observatory_type_sessions = sessions[sessions["session_type"] == "brain_observatory_1.1"]
brain_observatory_type_sessions.columns

In [None]:
# Single session used by the exploratory cells below.
session_id = 791319847
# Load that session through the project cache (NOTE(review): presumably downloads on first use — confirm in AllenSDK docs).
session = cache.get_session_data(session_id)

In [None]:
# Uncomment to show wide/long pandas tables in full instead of truncated:
# pd.set_option('display.max_rows', None)
# pd.set_option('display.max_columns', None)
# pd.set_option('display.max_colwidth', None)
# pd.set_option('display.width', None)
# Session-level metadata for the loaded session.
print(session.metadata)


In [None]:
session.structurewise_unit_counts  # per-structure unit counts (NOTE(review): AllenSDK property; confirm it matches session.units)


In [None]:
session.get_stimulus_table('flashes')  # flash presentations; later cells use its index and start_time/stop_time columns

In [None]:
# Flash presentations and the VISp units whose responses are binned.
presentations = session.get_stimulus_table("flashes")
units = session.units.loc[session.units["ecephys_structure_acronym"].eq('VISp')]

# 10 ms bins from 100 ms before to 500 ms after each flash onset.
time_step = 0.01
time_bins = np.arange(-0.1, 0.5 + time_step, time_step)

# Spike counts per (presentation, time bin, unit).
histograms = session.presentationwise_spike_counts(
    bin_edges=time_bins,
    stimulus_presentation_ids=presentations.index.values,
    unit_ids=units.index.values,
)

histograms.coords

In [None]:
# Average over presentations, leaving one PSTH per unit.
mean_histograms = histograms.mean(dim="stimulus_presentation_id")

fig, ax = plt.subplots(figsize=(8, 8))
# X/Y have the same length as C's columns/rows, so request shading="auto" explicitly;
# older matplotlib versions default to 'flat' shading, which drops the last row and column.
mesh = ax.pcolormesh(
    mean_histograms["time_relative_to_stimulus_onset"],
    np.arange(mean_histograms["unit_id"].size),
    mean_histograms.T,
    vmin=0,
    vmax=1,
    shading="auto",
)
# Without a colorbar the heat map values cannot be read.
fig.colorbar(mesh, ax=ax, label="mean spike count per bin")

ax.set_ylabel("unit", fontsize=24)
ax.set_xlabel("time relative to stimulus onset (s)", fontsize=24)
ax.set_title("peristimulus time histograms for VISp units on flash presentations", fontsize=24)

plt.show()

In [None]:
scene_presentations = session.get_stimulus_table("natural_scenes")
visp_units = session.units.query("ecephys_structure_acronym == 'VISp'")

# One row per spike, tagged with the presentation and unit it belongs to.
spikes = session.presentationwise_spike_times(
    stimulus_presentation_ids=scene_presentations.index.values,
    unit_ids=visp_units.index.values,
)

spikes

In [None]:
# Count spikes per (presentation, unit) into a new frame rather than overwriting `spikes`.
# This keeps the cell idempotent: a re-run still counts raw spikes, not already-grouped rows.
spike_counts = (
    spikes.groupby(["stimulus_presentation_id", "unit_id"])
    .size()
    .rename("count")
    .reset_index()
)

# Design matrix: one row per presentation (with at least one spike), one column per unit.
design = pd.pivot_table(
    spike_counts,
    values="count",
    index="stimulus_presentation_id",
    columns="unit_id",
    fill_value=0.0,
    aggfunc="sum",  # string form; passing np.sum triggers a pandas FutureWarning
)

design

In [None]:
# Natural-scene frame shown on each presentation that appears as a row of `design`.
targets = scene_presentations["frame"].loc[design.index.values]
targets

In [None]:
# session.get_stimulus_table("flashes").head(3)
# Inspect which columns the units table provides.
session.units.columns


In [None]:
# Define the regions here so this cell works on a fresh kernel; they are otherwise only
# set in a later cell.
regions = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam']
visual_units = session.units[session.units['ecephys_structure_acronym'].isin(regions)]
# visual_units

In [None]:

# Initial focus on the six visual-cortex areas.
regions = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam']
start_time, end_time = 1280, 1400  # (1288, 1292) is a narrower alternative window

visual_units = session.units[session.units['ecephys_structure_acronym'].isin(regions)]

# Print each visual unit's spike times that fall strictly inside (start_time, end_time).
for unit_id, times in session.spike_times.items():
    if unit_id not in visual_units.index:
        continue
    in_window = (times > start_time) & (times < end_time)
    print(times[in_window])

In [None]:
def extract_detailed_spike_data(session, start_time, end_time, regions):
    """
    Extract one row per spike for units in the given regions within a time window.

    Parameters:
        session (EcephysSession): Session providing ``units`` (indexed by unit id, with an
            ``ecephys_structure_acronym`` column) and ``spike_times`` (unit id -> array of
            spike times).
        start_time (float): Start of the window; spikes exactly at this time are included.
        end_time (float): End of the window; spikes exactly at this time are included.
        regions (list): Region acronyms to include.

    Returns:
        DataFrame: Columns ``unit_id``, ``spike_time`` and ``region``, one row per spike, in
        ``session.spike_times`` iteration order. The columns are present even when no spike
        falls in the window, so callers can always select ``region``/``spike_time``.
    """
    # Filter units based on specified regions
    visual_units = session.units[session.units['ecephys_structure_acronym'].isin(regions)]

    spikes_list = []
    for unit_id, times in session.spike_times.items():
        if unit_id not in visual_units.index:
            continue
        # Look the region up once per unit instead of once per spike.
        region = visual_units.loc[unit_id, 'ecephys_structure_acronym']
        relevant_spikes = times[(times >= start_time) & (times <= end_time)]
        spikes_list.extend(
            {'unit_id': unit_id, 'spike_time': spike_time, 'region': region}
            for spike_time in relevant_spikes
        )

    # Explicit columns keep the schema stable for an empty result.
    return pd.DataFrame(spikes_list, columns=['unit_id', 'spike_time', 'region'])

# Example: every spike from the selected regions in a 4-second window, previewed.
start_time, end_time = 1288, 1292
detailed_spike_data = extract_detailed_spike_data(
    session, start_time, end_time, regions
)
detailed_spike_data.head(5)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def extract_detailed_spike_data(session, start_time, end_time, regions):
    """
    Extract one row per spike for units in the given regions within a time window.

    Parameters:
        session (EcephysSession): Session providing ``units`` (indexed by unit id, with an
            ``ecephys_structure_acronym`` column) and ``spike_times`` (unit id -> array of
            spike times).
        start_time (float): Start of the window; spikes exactly at this time are included.
        end_time (float): End of the window; spikes exactly at this time are included.
        regions (list): Region acronyms to include.

    Returns:
        DataFrame: Columns ``unit_id``, ``spike_time`` and ``region``, one row per spike, in
        ``session.spike_times`` iteration order. The columns are present even when no spike
        falls in the window, so callers can always select ``region``/``spike_time``.
    """
    # Filter units based on specified regions
    visual_units = session.units[session.units['ecephys_structure_acronym'].isin(regions)]

    spikes_list = []
    for unit_id, times in session.spike_times.items():
        if unit_id not in visual_units.index:
            continue
        # Look the region up once per unit instead of once per spike.
        region = visual_units.loc[unit_id, 'ecephys_structure_acronym']
        relevant_spikes = times[(times >= start_time) & (times <= end_time)]
        spikes_list.extend(
            {'unit_id': unit_id, 'spike_time': spike_time, 'region': region}
            for spike_time in relevant_spikes
        )

    # Explicit columns keep the schema stable for an empty result.
    return pd.DataFrame(spikes_list, columns=['unit_id', 'spike_time', 'region'])


def plot_firing_rates_over_time(spike_data, bin_size, start_time, end_time):
    """
    Plot each region's pooled firing rate (all of the region's spikes together) over time.

    Parameters:
        spike_data (DataFrame): Must have 'spike_time' and 'region' columns, e.g. the
            output of ``extract_detailed_spike_data``.
        bin_size (float): Width of each time bin in seconds.
        start_time (float): Start time of the window for plotting.
        end_time (float): End time of the window for plotting.

    Returns:
        matplotlib.axes.Axes: The axes the rates were drawn on.
    """
    # Edges span the whole window, ending at end_time. np.arange(start, end, step) stops
    # one edge short, which silently dropped the spikes in the final bin.
    n_bins = max(int(round((end_time - start_time) / bin_size)), 1)
    bins = np.linspace(start_time, start_time + n_bins * bin_size, n_bins + 1)
    bin_centers = bins[:-1] + bin_size / 2

    fig, ax = plt.subplots(figsize=(12, 8))

    if spike_data.empty:
        print("No spikes to plot in the requested window.")
    else:
        for region in spike_data['region'].unique():
            region_spikes = spike_data.loc[spike_data['region'] == region, 'spike_time']
            spike_counts, _ = np.histogram(region_spikes, bins=bins)
            # Counts per bin -> spikes per second.
            ax.plot(bin_centers, spike_counts / bin_size, label=f'Region: {region}')
        ax.legend()

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Firing Rate (spikes/s)')
    ax.set_title('Firing Rates Over Time by Region')
    plt.show()
    return ax

# Zoom into a 2 s window (1288-1292 is a wider alternative) and plot per-region rates.
start_time, end_time = 1291, 1293
bin_size = 0.05  # 50 ms bins
detailed_spike_data = extract_detailed_spike_data(session, start_time, end_time, regions)
plot_firing_rates_over_time(detailed_spike_data, bin_size, start_time, end_time)


In [None]:
# Units per structure, using the same column every region filter in this notebook uses.
count_structure_acronyms = session.units['ecephys_structure_acronym'].value_counts()
count_structure_acronyms



In [None]:
print(session.stimulus_presentations.columns)  # inspect the columns of the full stimulus presentation table

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def get_time_window(session, buffer=1, num_flashes=50):
    """
    Get an analysis window around the first ``num_flashes`` flash presentations.

    Parameters:
        session (EcephysSession): Session whose ``get_stimulus_table("flashes")`` has
            ``start_time`` and ``stop_time`` columns.
        buffer (float): Seconds added before the first flash onset and after the last
            flash offset.
        num_flashes (int): Maximum number of flashes (from the start of the table) to
            include; fewer are used if the table is shorter.

    Returns:
        tuple: (start_time, end_time, flash_start_times, flash_end_times), where the last
        two are arrays with the onset/offset times of the included flashes.

    Raises:
        ValueError: If the session has no flash presentations.
    """
    flashes = session.get_stimulus_table("flashes").head(num_flashes)
    if flashes.empty:
        raise ValueError("session has no 'flashes' stimulus presentations")

    start_time = flashes['start_time'].iloc[0] - buffer
    end_time = flashes['stop_time'].iloc[-1] + buffer

    flash_start_times = flashes['start_time'].values
    flash_end_times = flashes['stop_time'].values

    return start_time, end_time, flash_start_times, flash_end_times

def extract_and_bin_spikes(session, start_time, end_time, bin_size, regions, flash_start_times, flash_end_times):
    """
    Bin every unit's spikes and average the binned firing rates within each region.

    Spikes and flash times are all histogrammed on the same edges,
    ``np.arange(start_time, end_time, bin_size)``, so they line up bin for bin.

    Parameters:
        session (EcephysSession): Session providing ``units`` (with an
            ``ecephys_structure_acronym`` column) and ``spike_times`` (unit id -> spike times).
        start_time (float): Start time of the window.
        end_time (float): End time of the window.
        bin_size (float): Bin size in seconds.
        regions (list): Region acronyms to include.
        flash_start_times (array): Flash onset times.
        flash_end_times (array): Flash offset times.

    Returns:
        tuple:
            dict: region -> pd.Series of the mean firing rate (spikes/s) across that
                region's units, one value per bin. All zeros (and an alert is printed) when
                the region has no units.
            np.ndarray: number of flash onsets per bin.
            np.ndarray: number of flash offsets per bin.
    """
    bin_edges = np.arange(start_time, end_time, bin_size)
    n_bins = max(len(bin_edges) - 1, 0)

    binned_flash_starts, _ = np.histogram(flash_start_times, bins=bin_edges)
    binned_flash_ends, _ = np.histogram(flash_end_times, bins=bin_edges)

    unit_regions = session.units['ecephys_structure_acronym']
    per_unit_rates = {region: [] for region in regions}
    for unit_id, spikes in session.spike_times.items():
        # Units absent from session.units are skipped instead of raising KeyError.
        region = unit_regions.get(unit_id)
        if region in regions:
            binned_spikes, _ = np.histogram(spikes, bins=bin_edges)
            per_unit_rates[region].append(pd.Series(binned_spikes / bin_size, name=unit_id))

    # Average across units in each region, or provide zeros where no units exist.
    region_data = {}
    for region in regions:
        if per_unit_rates[region]:
            region_data[region] = pd.concat(per_unit_rates[region], axis=1).mean(axis=1)
        else:
            print("------------------------alert missing region data: ", region)
            region_data[region] = pd.Series(np.zeros(n_bins))

    return region_data, binned_flash_starts, binned_flash_ends


# def plot_firing_rates(region_data, bin_size, average_flash_starts, average_flash_ends):
#     """
#     Plot the firing rates over time for each region, ensuring data alignment.

#     Parameters:
#         region_data (dict): Dictionary containing firing rates binned over time for each region.
#         bin_size (float): Bin size used in seconds.
#     """
#     plt.figure(figsize=(12, 8))

#     for region, data in region_data.items():
#         if not data.empty:
#             # Generate time bins based on the data index
#             time_bins = np.linspace(data.index[0], data.index[-1], num=len(data))
            
#             # Plotting each region's data
#             plt.plot(time_bins, data.values, label=f'Region: {region}')
#         else:
#             # Handle cases where there is no data for a region
#             print(f"---------------No data available for region {region}")

#     plt.xlabel('Time (s)')
#     plt.ylabel('Firing Rate (spikes/s)')
#     plt.title('Average Firing Rate Over Time by Region Across Selected Sessions')
#     plt.legend()
#     plt.show()

# Main analysis: bin spikes around the flash block for every selected session.
sessions = cache.get_session_table()
brain_observatory_type_sessions = sessions[sessions["session_type"] == "brain_observatory_1.1"]
# head(600) keeps every session of every type; use brain_observatory_type_sessions instead
# to restrict the analysis to Brain Observatory 1.1 sessions.
selected_sessions = sessions.head(600)
print(len(selected_sessions))

regions = ["LGd", "LGv", "LP", "VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]
bin_size = 0.005  # 5 ms bins

# Collect data from all selected sessions
per_unit_counts = {region: [] for region in regions}
all_region_data = {region: [] for region in regions}
blank_regions = {region: 0 for region in regions}
all_binned_flash_starts = []
all_binned_flash_ends = []
# Sessions are later averaged bin-by-bin, so all must produce the same number of bins;
# the first session sets the expected length.
expected_n_bins = None
for count, session_id in enumerate(selected_sessions.index):
    session = cache.get_session_data(session_id)

    # Count units with the same column extract_and_bin_spikes uses to assign regions.
    count_structure_acronyms = session.units['ecephys_structure_acronym'].value_counts()
    for region in regions:
        per_unit_counts[region].append(count_structure_acronyms.get(region, 0))

    start_time, end_time, flash_start_times, flash_end_times = get_time_window(session)
    print("start_time, end_time")
    print(start_time, end_time)
    region_data, binned_flash_starts, binned_flash_ends = extract_and_bin_spikes(
        session, start_time, end_time, bin_size, regions, flash_start_times, flash_end_times
    )

    bin_counts = {len(series) for series in region_data.values()}
    if expected_n_bins is None:
        expected_n_bins = len(next(iter(region_data.values())))
    if bin_counts != {expected_n_bins}:
        print("BIG PROBLEM BINS DON't MATCH: ", len(region_data["VISp"]))
        break

    # Track how many sessions have no activity at all in each region.
    for region in regions:
        if not region_data[region].any():
            print(f"{region} bumping empty count for region", region_data[region])
            blank_regions[region] += 1

    for region in regions:
        all_region_data[region].append(region_data[region])

    all_binned_flash_starts.append(binned_flash_starts)
    all_binned_flash_ends.append(binned_flash_ends)


In [None]:
# Only the last expression of a cell is rendered, so display the type breakdown explicitly.
display(sessions["session_type"].value_counts())
connectivity_sessions = sessions[sessions["session_type"] == "functional_connectivity"]
connectivity_sessions.columns

In [None]:

# # Assuming all_region_data is correctly populated
# average_region_data = {region: pd.concat(all_region_data[region], axis=1).mean(axis=1) if all_region_data[region] else pd.Series() for region in regions}
# plot_firing_rates(average_region_data, bin_size)

# Assuming all_region_data is correctly populated
# print(all_region_data["VISp"])

def plot_firing_rates(region_data, bin_size, average_flash_starts, average_flash_ends, flash_start_offset):
    """
    Plot each region's firing rate over the analysis window, with a dashed red line at
    every bin where all sessions had a flash onset.

    Parameters:
        region_data (dict): region acronym -> pd.Series of firing rates indexed by bin number.
        bin_size (float): Bin size in seconds; converts bin numbers to seconds on the x axis.
        average_flash_starts (array): Per-bin flash-onset indicator averaged across sessions;
            a marker is drawn only where it equals 1 (every session had an onset there).
        average_flash_ends (array): Per-bin flash-offset indicator (not drawn here).
        flash_start_offset (int): Bin number of ``average_flash_starts[0]``, so markers line
            up with ``region_data``'s bin index (e.g. the ``truncate_pre`` used for both).
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    for region, data in region_data.items():
        if data.empty:
            # Handle cases where there is no data for a region
            print(f"No data available for region {region}")
            continue
        # Bin numbers -> seconds from the start of the analysis window.
        bin_numbers = np.linspace(data.index[0], data.index[-1], num=len(data))
        ax.plot(bin_numbers * bin_size, data.values, label=f'Region: {region}')

    onset_bins = np.flatnonzero(np.asarray(average_flash_starts) == 1) + flash_start_offset
    for k, onset_bin in enumerate(onset_bins):
        # Label only the first line so the legend has a single "Flash Start" entry.
        ax.axvline(x=onset_bin * bin_size, color='r', linestyle='--',
                   label='Flash Start' if k == 0 else '_nolegend_')

    ax.set_xlabel('Time from window start (s)')
    ax.set_ylabel('Firing Rate (spikes/s)')
    ax.set_title('Average Firing Rate Over Time by Region Across Selected Sessions')
    ax.legend()
    plt.show()

# Bin range kept from each session; the upper bound is far past the binned length, so no
# bins are dropped.
truncate_pre = 0
truncate = 10000000000

# Average the per-session flash indicators. Starts and ends are truncated identically so
# they stay aligned; a bin equals 1 only if every session had a flash there.
all_binned_flash_starts_cp = [starts[truncate_pre:truncate] for starts in all_binned_flash_starts]
average_flash_starts = np.mean(all_binned_flash_starts_cp, axis=0)
average_flash_ends = np.mean([ends[truncate_pre:truncate] for ends in all_binned_flash_ends], axis=0)

print(len(all_region_data["VISp"]))
print(len(all_region_data["VISp"][0]))

# np.where returns a tuple; take its only element so len() counts flashes, not dimensions.
flash_indices = np.where(average_flash_starts == 1)[0]
num_flashes = len(flash_indices)
print("flashes shared by all sessions:", num_flashes, flash_indices)

all_region_data_cp = {
    region: [series[truncate_pre:truncate] for series in all_region_data[region]]
    for region in regions
}

# Mean across sessions for every bin.
average_region_data_cp = {region: pd.concat(all_region_data_cp[region], axis=1).mean(axis=1) for region in regions}
plot_firing_rates(average_region_data_cp, bin_size, average_flash_starts, average_flash_ends, flash_start_offset=truncate_pre)

In [None]:


from matplotlib.colors import LinearSegmentedColormap

def _plot_region_window(ax, region, data, std_err_data, color, start_bin, end_bin):
    """Draw one region's trace and a +/-1 standard-error band over positions [start_bin, end_bin)."""
    if data.empty:
        # Handle cases where there is no data for a region
        print(f"No data available for region {region}")
        return
    # x positions follow the data's bin labels.
    time_bins = np.linspace(data.index[0], data.index[-1], num=len(data))
    sliced_time_bins = time_bins[start_bin:end_bin]
    sliced_data = data.values[start_bin:end_bin]
    std_err = std_err_data[region].values[start_bin:end_bin]
    ax.plot(sliced_time_bins, sliced_data, label=f'Region: {region}', color=color)
    ax.fill_between(sliced_time_bins, sliced_data - std_err, sliced_data + std_err, alpha=0.3, color=color)


def _draw_flash_markers(ax, average_flash_starts, average_flash_ends, flash_offset, start_bin, end_bin):
    """Draw dashed lines where every session had a flash onset (red) or offset (blue) in the window."""
    window_starts = np.asarray(average_flash_starts)[start_bin:end_bin]
    window_ends = np.asarray(average_flash_ends)[start_bin:end_bin]
    for indicator, color, label in ((window_starts, 'r', 'Flash Start'), (window_ends, 'b', 'Flash End')):
        for k, bin_num in enumerate(np.flatnonzero(indicator == 1)):
            # Label only the first line of each kind so the legend has one entry per kind.
            ax.axvline(x=flash_offset + bin_num + start_bin, color=color, linestyle='--',
                       label=label if k == 0 else '_nolegend_')


def _finish_firing_rate_axes(ax):
    """Apply the shared axis labels, title and legend, then render the figure."""
    ax.set_xlabel('Time (bins)')
    ax.set_ylabel('Firing Rate (spikes/s)')
    ax.set_title('Average Firing Rate Over Time by Region Across Selected Sessions')
    ax.legend()
    plt.show()


def plot_firing_rates(region_data, bin_size, average_flash_starts, average_flash_ends, flash_offset, std_err_data,
                      start_bin=257, end_bin=280):
    """
    Plot firing rates with standard-error bands and flash markers over a window of bins.

    Draws one figure per visual area, then one figure overlaying all of them.

    Parameters:
        region_data (dict): region acronym -> pd.Series of firing rates, one value per bin.
        bin_size (float): Bin size in seconds (unused: the x axis is in bins).
        average_flash_starts (array): Per-bin flash-onset indicator; a marker is drawn where it equals 1.
        average_flash_ends (array): Per-bin flash-offset indicator; a marker is drawn where it equals 1.
        flash_offset (float): Offset added to marker positions so they line up with the data's bins.
        std_err_data (dict): region acronym -> pd.Series of standard errors aligned with ``region_data``.
        start_bin (int): First position (inclusive) of the plotted window.
        end_bin (int): Last position (exclusive) of the plotted window.
    """
    regions_for_fn = ["VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]

    # Light-to-dark blue ramp with one shade per region, in region_data order.
    region_names = list(region_data)
    colors = LinearSegmentedColormap.from_list("blue_grad", ["lightblue", "darkblue"], N=max(len(region_names), 1))
    color_map = {region: colors(i) for i, region in enumerate(region_names)}

    # One figure per visual area.
    for region in regions_for_fn:
        _, ax = plt.subplots(figsize=(12, 8))
        if region in region_data:
            _plot_region_window(ax, region, region_data[region], std_err_data, color_map[region], start_bin, end_bin)
        _draw_flash_markers(ax, average_flash_starts, average_flash_ends, flash_offset, start_bin, end_bin)
        _finish_firing_rate_axes(ax)

    # All visual areas overlaid on one figure.
    _, ax = plt.subplots(figsize=(12, 8))
    for region, data in region_data.items():
        if region in regions_for_fn:
            _plot_region_window(ax, region, data, std_err_data, color_map[region], start_bin, end_bin)
    _draw_flash_markers(ax, average_flash_starts, average_flash_ends, flash_offset, start_bin, end_bin)
    _finish_firing_rate_axes(ax)

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

regions = ["LGd", "LGv", "LP", "VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]

# Leading flashes left out of the flash-aligned average.
FIRST_FLASH_TO_AVERAGE = 3

# Bin range kept from each session; the upper bound exceeds the binned length, so all bins are kept.
truncate_pre = 0
truncate = 100000

# Recompute the across-session flash indicators BEFORE deriving anything from them.
# This cell trims average_flash_starts/ends to one chunk at the end, so reading them
# earlier would make a re-run of the cell compute a different chunk size.
all_binned_flash_starts_cp = [starts[truncate_pre:truncate] for starts in all_binned_flash_starts]
average_flash_starts = np.mean(all_binned_flash_starts_cp, axis=0)
average_flash_ends = np.mean([ends[truncate_pre:truncate] for ends in all_binned_flash_ends], axis=0)

# A bin counts as a flash onset only if every session had an onset there (mean == 1).
flash_indices = np.where(average_flash_starts == 1)[0]
num_flashes = len(flash_indices)
# Bins before the first onset and from the last onset to the end define one chunk.
num_time_steps_before_first_flash = flash_indices[0]
num_time_steps_after_last_flash = len(average_flash_starts) - flash_indices[-1]
chunk_size = num_time_steps_before_first_flash + num_time_steps_after_last_flash

print("num flashes: ", num_flashes)
print("Number of time steps before first flash:", num_time_steps_before_first_flash)
print("Number of time steps after last flash:", num_time_steps_after_last_flash)
print("Chunk size:", chunk_size)

all_region_data_cp = {
    region: [series[truncate_pre:truncate] for series in all_region_data[region]]
    for region in regions
}

# Each per-session series from extract_and_bin_spikes is already the MEAN rate across
# the region's units, so it is not divided by the unit count again (that would scale
# each region by 1/n_units). Sessions with no units in a region contribute all-zero
# traces; they are dropped before averaging so they don't pull the mean toward zero.
normalized_region_data_cp = {}
session_traces = {}
average_region_data_cp = {}
for region in regions:
    per_session = []
    for i, session_data in enumerate(all_region_data_cp[region]):
        if per_unit_counts[region][i] == 0:
            assert (session_data == 0).all(), f"{region}: session {i} has no units but non-zero rates"
        per_session.append(pd.Series(session_data.to_numpy()))
    normalized_region_data_cp[region] = per_session

    stacked = pd.concat(per_session, axis=1)
    session_traces[region] = stacked.loc[:, (stacked != 0).any(axis=0)]
    average_region_data_cp[region] = session_traces[region].sum(axis=1) / session_traces[region].shape[1]

print("Length of average_region_data_cp['VISam']:", len(average_region_data_cp["VISam"]))

# Cut a chunk around every flash from FIRST_FLASH_TO_AVERAGE on, all with the same
# alignment, and average them. The standard error uses exactly the same flash-aligned
# chunks, pooling one sample per (session, flash) for each bin of the chunk.
chunk_starts = flash_indices[FIRST_FLASH_TO_AVERAGE:] - num_time_steps_before_first_flash
if len(chunk_starts) == 0:
    raise ValueError(f"need more than {FIRST_FLASH_TO_AVERAGE} shared flashes, found {num_flashes}")

overlaid_region_data = {}
std_err_data = {}
for region in regions:
    mean_trace = average_region_data_cp[region]
    overlaid_chunks_concat = pd.concat(
        [mean_trace.iloc[s:s + chunk_size].reset_index(drop=True) for s in chunk_starts], axis=1
    )
    overlaid_region_data[region] = overlaid_chunks_concat.sum(axis=1) / overlaid_chunks_concat.shape[1]

    traces = session_traces[region].to_numpy()  # rows = bins, columns = sessions
    samples = np.concatenate([traces[s:s + chunk_size].T for s in chunk_starts], axis=0)
    std_err_data[region] = pd.DataFrame(samples).std(axis=0) / np.sqrt(samples.shape[0])

# TODO: optional baseline subtraction (mean over the pre-flash bins); if added, apply it
# before the standard error is computed.

average_flash_starts = average_flash_starts[0:chunk_size]
average_flash_ends = average_flash_ends[0:chunk_size]
plot_firing_rates(overlaid_region_data, bin_size, average_flash_starts, average_flash_ends, flash_offset=truncate_pre, std_err_data=std_err_data)

def find_peak_amplitude(data, margin):
    """
    Return the value of the last sample in ``data`` that lies within ``margin`` of the maximum.

    When several samples are near the maximum, the latest one is chosen. If no
    sample qualifies (e.g. ``data`` is empty/all-NaN, or ``margin`` is negative)
    the maximum itself is returned (NaN for empty or all-NaN input).

    Parameters:
        data (pd.Series): Values to search.
        margin (float): Tolerance below the maximum that still counts as "at the peak".
    """
    max_value = data.max()
    near_peak = data[data >= max_value - margin]
    if near_peak.empty:
        return max_value
    return near_peak.iloc[-1]

def find_time_to_percent_peak(data, peak_value, percent, start_bin, end_bin):
    """
    Return the index label of the first sample, at or after the peak, that is <= ``peak_value * percent``.

    The peak position is the first sample exactly equal to ``peak_value``.
    Samples before it are ignored.

    Parameters:
        data (pd.Series): Trace to search.
        peak_value (float): Value marking the peak (normally taken from ``data`` itself).
        percent (float): Fraction of ``peak_value`` to fall to.
        start_bin, end_bin: Accepted for call-site compatibility; not used.

    Returns:
        The index label of the matching sample, or -1 if the peak is not found
        or the trace never drops to the target afterwards.
    """
    threshold = peak_value * percent
    values = data.to_numpy()

    peak_hits = np.flatnonzero(values == peak_value)
    if peak_hits.size == 0:
        return -1
    peak_pos = peak_hits[0]

    below = np.flatnonzero(values[peak_pos:] <= threshold)
    if below.size == 0:
        return -1
    return data.index[peak_pos + below[0]]

# For each visual area, measure how long after `start_bin` the overlaid response
# takes to fall back to a given fraction of its peak, and draw grouped bars
# (one group per fraction, one bar per region).
# NOTE(review): `(time - start_bin)` treats index labels as bin positions, i.e.
# assumes `overlaid_region_data[region]` has a 0-based RangeIndex — confirm.
start_bin = 257
end_bin = 280
margin = 0.0125  # tolerance below the maximum that still counts as the peak

plt.figure(figsize=(12, 8))

percentages = np.arange(0.50, 0.96, 0.05)  # 0.50, 0.55, ..., 0.95
bar_width = 0.15
opacity = 0.8

regions = ["VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]
colors = plt.cm.viridis(np.linspace(0, 1, len(regions)))

# Group positions are the same for every region, so compute them once; this also
# keeps `index` defined for the tick placement below even if `regions` is empty.
index = np.arange(len(percentages))

for i, region in enumerate(regions):
    data = overlaid_region_data[region]
    sliced_data = data.iloc[start_bin:end_bin]
    peak_amplitude = find_peak_amplitude(sliced_data, margin)

    time_to_percent_peak = []
    for percent in percentages:
        time = find_time_to_percent_peak(sliced_data, peak_amplitude, percent, start_bin, end_bin)
        # -1 means no qualifying bin was found inside the window.
        time_to_percent_peak.append(np.nan if time == -1 else (time - start_bin) * bin_size)

    plt.bar(index + i * bar_width, time_to_percent_peak, bar_width, alpha=opacity, color=colors[i], label=f'Region: {region}')

plt.xlabel('Percentage of Peak Amplitude (%)')
plt.ylabel('Time After Stimulus (s)')
plt.title('Time to Reach Percentage of Peak Amplitude by Region')
# Round before casting: astype(int) truncates, so float noise such as
# 64.99999999999999 would otherwise be labelled 64.
plt.xticks(index + bar_width * (len(regions) - 1) / 2, np.round(percentages * 100).astype(int))
plt.legend()
plt.tight_layout()
plt.show()


In [None]:


from matplotlib.colors import LinearSegmentedColormap

def _plot_region_window(region, data, std_err_data, start_bin, end_bin, color):
    """Plot one region's trace and its +/- standard-error band for bins [start_bin, end_bin) on the current axes."""
    if data is None or data.empty:
        # Handle cases where there is no data for a region
        print(f"No data available for region {region}")
        return
    # x positions: len(data) evenly spaced values from the first to the last index label
    time_bins = np.linspace(data.index[0], data.index[-1], num=len(data))
    sliced_time_bins = time_bins[start_bin:end_bin]
    sliced_data = data.values[start_bin:end_bin]
    std_err = std_err_data[region].values[start_bin:end_bin]
    plt.plot(sliced_time_bins, sliced_data, label=f'Region: {region}', color=color)
    plt.fill_between(sliced_time_bins, sliced_data - std_err, sliced_data + std_err, alpha=0.3, color=color)


def _plot_flash_markers(average_flash_starts, average_flash_ends, flash_offset, start_bin, end_bin):
    """Draw dashed vertical lines at bins in [start_bin, end_bin) whose flash-start (red) or flash-end (blue) indicator equals 1."""
    start_labelled = False
    end_labelled = False
    window = zip(average_flash_starts[start_bin:end_bin], average_flash_ends[start_bin:end_bin])
    for bin_num, (flash_start, flash_end) in enumerate(window):
        x = flash_offset + bin_num + start_bin
        # Only the first line of each kind gets a legend entry.
        if flash_start == 1:
            plt.axvline(x=x, color='r', linestyle='--', label='_nolegend_' if start_labelled else 'Flash Start')
            start_labelled = True
        if flash_end == 1:
            plt.axvline(x=x, color='b', linestyle='--', label='_nolegend_' if end_labelled else 'Flash End')
            end_labelled = True


def _finish_firing_rate_figure():
    """Label the current axes, add the legend and display the figure."""
    plt.xlabel('Time (bins)')
    plt.ylabel('Firing Rate (spikes/s)')
    plt.title('Average Firing Rate Over Time by Region Across Selected Sessions')
    plt.legend()
    plt.show()


def plot_firing_rates(region_data, bin_size, average_flash_starts, average_flash_ends, flash_offset, std_err_data,
                      start_bin=257, end_bin=280, regions_to_plot=None):
    """
    Plot per-region firing rates with a standard-error band over a window of bins.

    One figure is drawn per region in ``regions_to_plot``, followed by one figure
    overlaying all of them. Dashed vertical lines mark bins whose averaged
    flash-start (red) or flash-end (blue) indicator equals exactly 1.

    Parameters:
        region_data (dict): region -> pd.Series of firing rates, one value per time bin.
            Its key order sets the light-to-dark colour ramp.
        bin_size (float): Bin size in seconds (not used for plotting; kept for API compatibility).
        average_flash_starts (array): Per-bin flash-start indicator; a line is drawn where it == 1.
        average_flash_ends (array): Per-bin flash-end indicator; a line is drawn where it == 1.
        flash_offset (float): Added to the bin number when positioning the flash lines on the x axis.
        std_err_data (dict): region -> pd.Series of standard errors, sliced with the same
            [start_bin:end_bin] window as the firing rates.
        start_bin (int): First bin (inclusive) of the plotted window.
        end_bin (int): Last bin (exclusive) of the plotted window.
        regions_to_plot (list[str] | None): Regions to plot; defaults to
            ["VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"].
    """
    if regions_to_plot is None:
        regions_to_plot = ["VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]

    # Colours come from the data's own region order rather than from a notebook
    # global, which is reassigned elsewhere and could miss a plotted region.
    palette = LinearSegmentedColormap.from_list("blue_grad", ["lightblue", "darkblue"], N=max(len(region_data), 2))
    color_map = {region: palette(i) for i, region in enumerate(region_data)}

    # One figure per region (each created explicitly so all share the same size).
    for region in regions_to_plot:
        plt.figure(figsize=(12, 8))
        _plot_region_window(region, region_data.get(region), std_err_data, start_bin, end_bin, color_map.get(region))
        _plot_flash_markers(average_flash_starts, average_flash_ends, flash_offset, start_bin, end_bin)
        _finish_firing_rate_figure()

    # All selected regions overlaid, in region_data order.
    plt.figure(figsize=(12, 8))
    for region, data in region_data.items():
        if region in regions_to_plot:
            _plot_region_window(region, data, std_err_data, start_bin, end_bin, color_map[region])
    _plot_flash_markers(average_flash_starts, average_flash_ends, flash_offset, start_bin, end_bin)
    _finish_firing_rate_figure()

# Print every row and column of pandas objects from here on (applies to all later output).
pd.set_option("display.max_rows", None, "display.max_columns", None)

# Regions analysed below; this list's order is the order in which they are processed.
regions = ["LGd", "LGv", "LP", "VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]

# for region in regions:
# # blank_regions[region]) / count_structure_acronyms[region]
#     if region not in blank_regions:
#         blank_regions[region] = 0
#     if region not in count_structure_acronyms:
#         print("WARN: region not in count_structure_acronyms: ", region)
#         count_structure_acronyms[region] = 0

# Average the per-session flash indicators FIRST: the flash-timing quantities
# below must come from this cell's full-length averages. Previously they read
# `average_flash_starts` before it was recomputed here, i.e. a stale copy that
# the end of this cell truncates to `chunk_size`, so re-runs changed the result.
truncate_pre = 0
truncate = 100000

# Apply the same [truncate_pre:truncate] window to starts and ends so the two
# indicator arrays stay aligned bin for bin.
all_binned_flash_starts_cp = [starts[truncate_pre:truncate] for starts in all_binned_flash_starts]
all_binned_flash_ends_cp = [ends[truncate_pre:truncate] for ends in all_binned_flash_ends]

average_flash_starts = np.mean(all_binned_flash_starts_cp, axis=0)
average_flash_ends = np.mean(all_binned_flash_ends_cp, axis=0)

# Bins whose averaged start indicator is exactly 1 (presumably: every session
# has a flash start there, if per-session indicators are 0/1 — confirm).
flash_start_bins = np.flatnonzero(average_flash_starts == 1)
assert flash_start_bins.size > 0, "no bin has an averaged flash-start indicator equal to 1"

# determine from flash starts how many timesteps before first flash
num_time_steps_before_first_flash = flash_start_bins[0]

# determine from flash starts how many timesteps after last flash
num_time_steps_after_last_flash = len(average_flash_starts) - flash_start_bins[-1]
# print(average_flash_starts)

# print(len(all_region_data["VISp"]))
# print(len(all_region_data["VISp"][0]))
# print()
# Apply the same [truncate_pre:truncate] window to every session's series, per region.
all_region_data_cp = {
    region: [series[truncate_pre:truncate] for series in all_region_data[region]]
    for region in regions
}

# average_region_data_cp = {region: pd.concat(all_region_data_cp[region], axis=1).mean(axis=1) for region in regions}

    # iterate through all region data cp
        # for each region data cp[region][session][index] divide by number of units

# normalize based on number of units
normalized_region_data_cp = {}
std_err_data = {}
for region in all_region_data_cp:
    region_data_normalized_by_nuits = []
    for i, session_data in enumerate(all_region_data_cp[region]):
        build = []
        for measurement in session_data:
            if per_unit_counts[region][i] == 0:
                assert measurement == 0
                build.append(0)
            else:
                build.append(measurement/per_unit_counts[region][i])
        region_data_normalized_by_nuits.append(pd.Series(build))
    normalized_region_data_cp[region] = region_data_normalized_by_nuits

    # TODO: perform this calculation without zeros in the dataset
    # Calculate the standard error after normalization and before averaging
    concatenated_normalized_data = pd.concat(normalized_region_data_cp[region], axis=1)
    non_zero_columns = concatenated_normalized_data.loc[:, (concatenated_normalized_data != 0).any(axis=0)]
    print(concatenated_normalized_data.shape)
    print(non_zero_columns.shape)
    std_err_data[region] = non_zero_columns.std(axis=1) / np.sqrt(non_zero_columns.shape[1])
    print("-----------")
#     -----------
# (12059,)
# (12059, 58)
# 58
# (12059,)
# ++++++++
    print(std_err_data[region].shape)
    # print(std_err_data[region][0])
    # print(std_err_data[region][1])
    print(concatenated_normalized_data.shape[1])
    print(concatenated_normalized_data.std(axis=1).shape)
    print("++++++++")
    print(normalized_region_data_cp[region][0].shape)
    print(std_err_data[region][0])
    print(normalized_region_data_cp[region][0].shape)


# Average across sessions. The sum includes every session, but the denominator
# excludes the "blank" ones (presumably sessions with no units for the region,
# which were normalized to all zeros — verify against how blank_regions is built).
average_region_data_cp = {}
for region in normalized_region_data_cp:
    num_sessions = len(normalized_region_data_cp[region])
    num_blank = blank_regions.get(region, 0)
    summed = pd.concat(normalized_region_data_cp[region], axis=1).sum(axis=1)
    print(f"num sessions / blank regions: {num_sessions} / {num_blank}")

    num_active = num_sessions - num_blank
    if num_active > 0:
        # Vectorised division keeps a float result even if the summed series is integer-typed.
        average_region_data_cp[region] = summed / num_active
    else:
        print(f"WARN: no non-blank sessions for region {region}; its average is set to NaN")
        average_region_data_cp[region] = pd.Series(np.nan, index=summed.index)

# Bins whose averaged flash-start indicator equals 1, in increasing order.
flash_indices = np.flatnonzero(average_flash_starts == 1)
num_flashes = flash_indices.size

# Length of each flash-aligned window used for the overlay below.
chunk_size = num_time_steps_after_last_flash + num_time_steps_before_first_flash

print("num flashes: ", num_flashes)
print("Number of time steps before first flash:", num_time_steps_before_first_flash)
print("Number of time steps after last flash:", num_time_steps_after_last_flash)
print("Chunk size:", chunk_size)

print("Length of average_region_data_cp['VISam']:", len(average_region_data_cp["VISam"]))



# Flash-triggered average: for each flash, cut a `chunk_size`-bin window that
# starts `num_time_steps_before_first_flash` bins before it, then average the
# windows bin by bin.
# IMPORTANT: how many leading flashes to exclude from the average.
NUM_LEADING_FLASHES_TO_SKIP = 3

overlaid_region_data = {}
for region in regions:
    region_trace = average_region_data_cp[region]
    overlaid_chunks = []
    for flash_index in flash_indices[NUM_LEADING_FLASHES_TO_SKIP:]:
        start_index = flash_index - num_time_steps_before_first_flash
        end_index = start_index + chunk_size
        # A negative start would make .iloc count from the end and silently splice in unrelated bins.
        assert start_index >= 0, f"{region}: window for flash at bin {flash_index} starts before bin 0"
        # Reset the index so every window is labelled 0..chunk_size-1 and they align when concatenated.
        overlaid_chunks.append(region_trace.iloc[start_index:end_index].reset_index(drop=True))

    if not overlaid_chunks:
        raise ValueError(f"{region}: no flashes left after skipping the first {NUM_LEADING_FLASHES_TO_SKIP}")

    # mean(axis=1) skips NaN, so a window truncated at the end of the recording is
    # averaged over the chunks that actually cover each bin rather than divided by
    # the total number of chunks.
    overlaid_region_data[region] = pd.concat(overlaid_chunks, axis=1).mean(axis=1)

print("VISpm: ", overlaid_region_data["VISpm"])

# TODO: move this baseline subtraction to before the std_err calculation.
# # by region, compute the mean of the overlaid region data over timesteps 0-179 (end_average_timestep = 180)
# region_averages = {}
# end_average_timestep = 180
# for region in regions:
#     running_sum = 0
#     for i in range(0, end_average_timestep):
#         running_sum += overlaid_region_data[region][i]
#     average = running_sum / end_average_timestep
#     region_averages[region] = average

# # subtract that per-region average (over timesteps 0-179) from every timestep
# for region in regions:
#     for i in range(0, len(overlaid_region_data[region])):
#         overlaid_region_data[region][i] = overlaid_region_data[region][i] - region_averages[region]


# Keep only the first `chunk_size` bins of the flash indicators so they are the
# same length as the chunk-length overlaid traces, then plot.
average_flash_starts, average_flash_ends = (
    average_flash_starts[:chunk_size],
    average_flash_ends[:chunk_size],
)
plot_firing_rates(
    overlaid_region_data,
    bin_size,
    average_flash_starts,
    average_flash_ends,
    flash_offset=truncate_pre,
    std_err_data=std_err_data,
)

def find_peak_amplitude(data, margin):
    """
    Return the value of the last sample in ``data`` that lies within ``margin`` of the maximum.

    When several samples are near the maximum, the latest one is chosen. If no
    sample qualifies (e.g. ``data`` is empty/all-NaN, or ``margin`` is negative)
    the maximum itself is returned (NaN for empty or all-NaN input).

    Parameters:
        data (pd.Series): Values to search.
        margin (float): Tolerance below the maximum that still counts as "at the peak".
    """
    max_value = data.max()
    near_peak = data[data >= max_value - margin]
    if near_peak.empty:
        return max_value
    return near_peak.iloc[-1]

def find_time_to_percent_peak(data, peak_value, percent, start_bin, end_bin):
    """
    Return the index label of the first sample, at or after the peak, that is <= ``peak_value * percent``.

    The peak position is the first sample exactly equal to ``peak_value``.
    Samples before it are ignored.

    Parameters:
        data (pd.Series): Trace to search.
        peak_value (float): Value marking the peak (normally taken from ``data`` itself).
        percent (float): Fraction of ``peak_value`` to fall to.
        start_bin, end_bin: Accepted for call-site compatibility; not used.

    Returns:
        The index label of the matching sample, or -1 if the peak is not found
        or the trace never drops to the target afterwards.
    """
    threshold = peak_value * percent
    values = data.to_numpy()

    peak_hits = np.flatnonzero(values == peak_value)
    if peak_hits.size == 0:
        return -1
    peak_pos = peak_hits[0]

    below = np.flatnonzero(values[peak_pos:] <= threshold)
    if below.size == 0:
        return -1
    return data.index[peak_pos + below[0]]

# For each visual area, measure how long after `start_bin` the overlaid response
# takes to fall back to a given fraction of its peak, and draw grouped bars
# (one group per fraction, one bar per region).
# NOTE(review): `(time - start_bin)` treats index labels as bin positions, i.e.
# assumes `overlaid_region_data[region]` has a 0-based RangeIndex — confirm.
start_bin = 257
end_bin = 280
margin = 0.0125  # tolerance below the maximum that still counts as the peak

plt.figure(figsize=(12, 8))

percentages = np.arange(0.50, 0.96, 0.05)  # 0.50, 0.55, ..., 0.95
bar_width = 0.15
opacity = 0.8

regions = ["VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]
colors = plt.cm.viridis(np.linspace(0, 1, len(regions)))

# Group positions are the same for every region, so compute them once; this also
# keeps `index` defined for the tick placement below even if `regions` is empty.
index = np.arange(len(percentages))

for i, region in enumerate(regions):
    data = overlaid_region_data[region]
    sliced_data = data.iloc[start_bin:end_bin]
    peak_amplitude = find_peak_amplitude(sliced_data, margin)

    time_to_percent_peak = []
    for percent in percentages:
        time = find_time_to_percent_peak(sliced_data, peak_amplitude, percent, start_bin, end_bin)
        # -1 means no qualifying bin was found inside the window.
        time_to_percent_peak.append(np.nan if time == -1 else (time - start_bin) * bin_size)

    plt.bar(index + i * bar_width, time_to_percent_peak, bar_width, alpha=opacity, color=colors[i], label=f'Region: {region}')

plt.xlabel('Percentage of Peak Amplitude (%)')
plt.ylabel('Time After Stimulus (s)')
plt.title('Time to Reach Percentage of Peak Amplitude by Region')
# Round before casting: astype(int) truncates, so float noise such as
# 64.99999999999999 would otherwise be labelled 64.
plt.xticks(index + bar_width * (len(regions) - 1) / 2, np.round(percentages * 100).astype(int))
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Assuming `flashes_df` is your DataFrame from `session.get_stimulus_table("flashes")`
flashes_df = session.get_stimulus_table("flashes")

# Mark the first ten flash onsets (green) and offsets (red) on one horizontal
# line at y=1; only the x positions carry information. `.iloc` makes the
# first-ten selection explicitly positional.
plt.figure(figsize=(10, 5))
plt.plot(flashes_df['start_time'].iloc[:10], np.ones_like(flashes_df['start_time'].iloc[:10]), 'go', label='Start Time')
plt.plot(flashes_df['stop_time'].iloc[:10], np.ones_like(flashes_df['stop_time'].iloc[:10]), 'ro', label='Stop Time')
plt.title('Start and Stop Times of Flash Stimuli')
plt.xlabel('Time (s)')
plt.yticks([])  # the y value is a constant placeholder, so its ticks are hidden
plt.legend()
plt.show()

In [None]:
# Column labels of the stimulus-presentation table (DataFrame.keys() returns the columns).
session.stimulus_presentations.keys()

In [None]:
# First five rows of the session's unit table.
session.units.head(n=5)

In [None]:
from sklearn import svm
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix

In [None]:
# numpy copies for sklearn: float feature matrix and integer class targets.
design_arr, targets_arr = design.values.astype(float), targets.values.astype(int)

# Sorted distinct target values; fixes the row/column order of the confusion matrices.
labels = np.unique(targets_arr)

In [None]:
# Five-fold cross-validation of an RBF-kernel SVM; each fold's accuracy and
# confusion matrix are kept for the summary cells that follow.
accuracies, confusions = [], []

for train_indices, test_indices in KFold(n_splits=5).split(design_arr):
    # Fit on the training folds only.
    clf = svm.SVC(kernel="rbf", gamma="scale")
    clf.fit(design_arr[train_indices], targets_arr[train_indices])

    # Evaluate on the held-out fold.
    test_targets, test_predictions = targets_arr[test_indices], clf.predict(design_arr[test_indices])

    # Fraction of held-out samples whose prediction matches the target
    # (count_nonzero of the difference counts the mismatches).
    accuracy = 1 - np.count_nonzero(test_predictions - test_targets) / test_predictions.size
    print(accuracy)

    accuracies.append(accuracy)
    confusions.append(
        confusion_matrix(y_true=test_targets, y_pred=test_predictions, labels=labels)
    )

In [None]:
# Average over every CV fold in `accuracies`; `accuracy` only holds the last
# fold's value, so averaging it reported a single fold.
print(f"mean accuracy: {np.mean(accuracies)}")
# Accuracy of guessing uniformly among the observed labels.
print(f"chance: {1/labels.size}")

In [None]:
# Element-wise average of the per-fold confusion matrices.
mean_confusion = np.asarray(confusions).mean(axis=0)

fig, ax = plt.subplots(figsize=(8, 8))

# Rows are actual labels, columns are predicted labels.
img = ax.imshow(mean_confusion)
fig.colorbar(img)
ax.set(ylabel="actual", xlabel="predicted")

plt.show()

In [None]:
# confusion_matrix(y_true=..., y_pred=...) puts actual labels on rows, so
# diagonal / row total is per-class recall. Ranking by recall rather than by the
# raw diagonal keeps a class from looking "most decodable" just because it had
# more test samples. Every label comes from np.unique(targets_arr), so each row
# total is positive; nan-aware argmax/argmin guard against any NaN anyway.
per_class_recall = np.diag(mean_confusion) / mean_confusion.sum(axis=1)
best = labels[np.nanargmax(per_class_recall)]
worst = labels[np.nanargmin(per_class_recall)]

fig, ax = plt.subplots(1, 2, figsize=(16, 8))

best_image = cache.get_natural_scene_template(best)
ax[0].imshow(best_image, cmap=plt.cm.gray)
ax[0].set_title("most decodable", fontsize=24)

worst_image = cache.get_natural_scene_template(worst)
ax[1].imshow(worst_image, cmap=plt.cm.gray)
ax[1].set_title("least decodable", fontsize=24)


plt.show()