In [38]:
import matplotlib
import numpy as np
import pandas as pd
from matplotlib.path import Path
import os
import spikeinterface.extractors as se
from utilities.platforms_utils import get_platform_center, calculate_occupancy_plats, get_hd_distr_allplats, get_firing_rate_platforms, get_norm_hd_distr
from utilities.restrict_spiketrain_specialbehav import restrict_spiketrain_specialbehav
from calculate_occupancy import get_direction_bins
from population_sink.get_relDirDist import calculate_relDirDist
from population_sink.calculate_MRLval import mrlData
from population_sink.plot_plat_info import plot_plat_info
from utilities.mrl_func import resultant_vector_length
from utilities.utils import get_unit_ids
from astropy.stats import circmean
from utilities.load_and_save_data import load_pickle, save_pickle
from utilities.trials_utils import get_limits_from_json, get_goal_numbers, get_coords
import matplotlib.pyplot as plt
from matplotlib.patches import RegularPolygon
from tqdm import tqdm
matplotlib.use("TkAgg")
import pickle
import datetime


In [43]:
derivatives_base = r"S:\Honeycomb_maze_task\derivatives\sub-002_id-1R\ses-02_date-11092025\all_trials"
data_folder = os.path.join(derivatives_base, 'analysis', 'cell_characteristics', 'spatial_features', 'popsink_data_newmethod')
name = 'platform_occupancy_allgoals.pickle'
platform_occupancy_allgoals = load_pickle(name, data_folder)
limits = get_limits_from_json(derivatives_base)
sink_bins = load_pickle('sink_bins', data_folder)
direction_bins = load_pickle('direction_bins', data_folder)
goal_numbers = get_goal_numbers(derivatives_base)
hcoord, vcoord = get_coords(derivatives_base)
g1_info = load_pickle('session_vars_goal1_pyramidal.pkl', data_folder)
g2_info = load_pickle('session_vars_goal2_pyramidal.pkl', data_folder)
g3_info = load_pickle('session_vars_goal3_pyramidal.pkl', data_folder)

In [40]:
def plot_platform_occupancy(platform_occupancy_allgoals, hcoord, vcoord, limits, run_zero = True, plot_name = 'Platform Occupancy all goals', frame_rate=25):
    """ Plots occupancy for all platforms for each goal """
    x_min, x_max, y_min, y_max = limits
    
    if not run_zero:
        fig, axs = plt.subplots(1, 3, figsize=(24, 8))
    else:
        fig, axs = plt.subplots(1,4, figsize=(24,6))
    axs = axs.flatten()
    
    cmap = plt.get_cmap('RdYlGn')
    
    for j in range(4): # j = 0, rat going to g2 during g1. j = 1, goal 1. j = 2, goal 2, j = 3 full trial
        if not run_zero and j == 0:
            continue
        ax = axs[j - (0 if run_zero else 1)]
        occupancy = platform_occupancy_allgoals[j  - (0 if run_zero else 1)]
        occupancy = [el/frame_rate for el in occupancy]
        occupancy_normalized = occupancy / np.nanmax(occupancy)
        for i, (x, y) in enumerate(zip(hcoord, vcoord)):
            if occupancy_normalized[i] == 0:
                colour = 'grey'
                text = ''
            else:
                colour = cmap(occupancy_normalized[i])
                text = f'{np.int32(occupancy[i])}'

            hex = RegularPolygon((x, y), numVertices=6, radius=87.,
                                orientation=np.radians(28),  # Rotate hexagons to align with grid
                                facecolor=colour, alpha=0.2, edgecolor='k')
            ax.text(x, y, text, ha='center', va='center', size=15)  # Start numbering from 1
            ax.add_patch(hex)

        # Also add scatter points in hexagon centres
        ax.scatter(hcoord, vcoord, alpha=0, c = 'grey')
        # plot the goal positions
        ax.set_xlim([x_min, x_max])
        ax.set_ylim([y_max, y_min])
        ax.set_aspect('equal')
        
        # Add small text with MRL and angle on bottom
        if j == 3:
            title = 'Occupancy all trials (s)'
        elif j != 0:
            title = f'Occupancy goal {j} (s)'
        else:
            title = 'Occupancy going to G2 during G1 (s)'
        ax.set_title(title)
    plt.show()

In [41]:
plot_platform_occupancy(platform_occupancy_allgoals, hcoord, vcoord, limits, run_zero = False, plot_name = 'Platform Occupancy all goals')

In [54]:
def plot_firing_rate(firing_rates, hcoord, vcoord, limits, run_zero = True, plot_name = 'Firing Rate all goals'):
    """ Plots firing rate for all platforms for each goal """
    x_min, x_max, y_min, y_max = limits
    if not run_zero:
        fig, axs = plt.subplots(1, 3, figsize=(24, 8))
    else:
        fig, axs = plt.subplots(1,4, figsize=(24,6))
    axs = axs.flatten()
    
    cmap = plt.get_cmap('RdYlGn')
    
    for j in range(4): # j = 0, rat going to g2 during g1. j = 1, goal 1. j = 2, goal 2, j = 3 full trial
        if not run_zero and j == 0:
            continue
        ax = axs[j - (0 if run_zero else 1)]
        frates = firing_rates[j  - (0 if run_zero else 1)]
        mean_frates = np.nanmean(frates, axis = 0)
        frates_normalized = mean_frates / np.nanmax(mean_frates)
        print(len(mean_frates))
        for i, (x, y) in enumerate(zip(hcoord, vcoord)):
            if frates_normalized[i] == 0:
                colour = 'grey'
                text = ''
            else:
                colour = cmap(frates_normalized[i])
                text = f'{mean_frates[i]:.2f}'
            hex = RegularPolygon((x, y), numVertices=6, radius=87.,
                                orientation=np.radians(28),  # Rotate hexagons to align with grid
                                facecolor=colour, alpha=0.2, edgecolor='k')
            ax.text(x, y, text, ha='center', va='center', size=15)  # Start numbering from 1
            ax.add_patch(hex)

        # Also add scatter points in hexagon centres
        ax.scatter(hcoord, vcoord, alpha=0, c = 'grey')
        # plot the goal positions
        ax.set_xlim([x_min, x_max])
        ax.set_ylim([y_max, y_min])
        ax.set_aspect('equal')
        
        # Add small text with MRL and angle on bottom
        if j == 3:
            title = 'Mean firing rate all trials (Hz)'
        elif j != 0:
            title = f'Mean firing rate goal {j} (Hz)'
        else:
            title = 'Mean firing rate going to G2 during G1 (Hz)'
        ax.set_title(title)
    plt.show()

In [55]:
firing_rates = [g1_info['allfiring_rates'], g2_info['allfiring_rates'], g3_info['allfiring_rates']]
plot_firing_rate(firing_rates, hcoord, vcoord, limits, run_zero = False, plot_name = 'Firing Rate all goals')

61
61
61


In [70]:
run_zero = False
x_min, x_max, y_min, y_max = limits
if not run_zero:
    fig, axs = plt.subplots(1, 3, figsize=(24, 8))
else:
    fig, axs = plt.subplots(1,4, figsize=(24,6))
axs = axs.flatten()

cmap = plt.get_cmap('RdYlGn')

for j in range(4): # j = 0, rat going to g2 during g1. j = 1, goal 1. j = 2, goal 2, j = 3 full trial
    if not run_zero and j == 0:
        continue
    ax = axs[j - (0 if run_zero else 1)]
    frates = firing_rates[j  - (0 if run_zero else 1)]
    frates = np.array(frates)
    frates_T = frates.T
    frates_T  = [el[el != 0] for el in frates_T]
    # remove nans
    frates_T = [el[~np.isnan(el)] for el in frates_T]
    mean_frates = [np.mean(el) if len(el) > 0 else 0 for el in frates_T]
    frates_normalized = mean_frates / np.nanmax(mean_frates)
    
    
    for i, (x, y) in enumerate(zip(hcoord, vcoord)):
        if frates_normalized[i] == 0:
            colour = 'grey'
            text = ''
        else:
            colour = cmap(frates_normalized[i])
            text = f'{np.nanmax(frates_T[i]):.1f}'
        hex = RegularPolygon((x, y), numVertices=6, radius=87.,
                            orientation=np.radians(28),  # Rotate hexagons to align with grid
                            facecolor=colour, alpha=0.2, edgecolor='k')
        ax.text(x, y, text, ha='center', va='center', size=15)  # Start numbering from 1
        ax.add_patch(hex)

    # Also add scatter points in hexagon centres
    ax.scatter(hcoord, vcoord, alpha=0, c = 'grey')
    # plot the goal positions
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_max, y_min])
    ax.set_aspect('equal')
    
    # Add small text with MRL and angle on bottom
    if j == 3:
        title = 'Mean firing rate all trials (Hz)'
    elif j != 0:
        title = f'Mean firing rate goal {j} (Hz)'
    else:
        title = 'Mean firing rate going to G2 during G1 (Hz)'
    ax.set_title(title)
plt.show()

In [66]:
frates_T

[array([0.64285714, 0.07142857, 0.14285714, 0.03571429, 0.07142857,
        0.60714286, 1.17857143, 0.28571429, 0.14285714, 0.03571429,
        0.5       , 0.03571429, 0.07142857, 0.10714286, 0.03571429,
        0.03571429, 0.57142857, 0.14285714, 0.17857143, 0.03571429,
        0.03571429, 0.07142857, 0.28571429, 0.53571429, 0.14285714,
        0.67857143, 0.07142857, 0.10714286, 0.07142857]),
 array([0.0021692 , 0.14533623, 0.02819957, 0.01735358, 0.00650759,
        0.02386117, 0.00433839, 0.164859  , 0.29934924, 0.06724512,
        0.00433839, 0.01952278, 0.0021692 , 0.02169197, 0.36008677,
        0.00650759, 0.0021692 , 0.0867679 , 0.0021692 , 0.02169197,
        0.05856833, 0.01952278, 0.19739696, 0.0867679 , 0.10629067,
        0.01518438, 0.04555315, 0.06073753, 0.01518438, 0.1626898 ,
        0.164859  , 0.22993492, 0.01952278, 0.00867679, 0.0021692 ,
        0.01084599, 0.10412148, 0.03904555, 0.06290672, 0.01301518,
        0.01735358, 0.28850325, 0.01301518, 0.59219089, 0.