<font size=7 face="courier">Cart-Pole Analysis Source Code</font>

This notebook contains the code used to create the diagrams in the notebook `cartpole_analysis.ipynb`. I recommend reading through at least some of these functions to understand how to implement your own analysis functions.

In [None]:
# Announce that this notebook has started executing.
print("Loading: final_analysis_source_code.ipynb...")

hello


Import packages

In [None]:
import pandas as pd
import numpy as np
import pickle
import pytz
import sys
import json
import zipfile
import glob
import argparse  # Import argparse for command-line arguments
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors # For color normalization

from pathlib import Path
from braingeneers.analysis.analysis import SpikeData

import scipy.io as sio
import scipy
from scipy.ndimage import gaussian_filter1d

import spikedata
from spikedata.spikedata import SpikeData
#from braindance.analysis.mapping import Mapping # Moved import higher

Here is some minor setup used by the rest of the notebook:

In [None]:
# Timezone used to format the "Done at" timestamp printed by the final cell.
Timezone = pytz.timezone("America/Los_Angeles")

# <font color="blue">Loading Data</font>

In [None]:
def load_pickle(p):
    """
    Read one pickled object from disk.
        p : path of the pickle file to open (opened in binary mode)
        returns : the object stored in the file

    NOTE: unpickling can execute arbitrary code, so only point this at
    files that come from a trusted source.
    """
    with open(p, 'rb') as handle:
        return pickle.load(handle)

def get_save_path(filename_base, directory=None):
    """
    Return a ``.png`` path that does not exist yet.
        filename_base : file name without directory or extension
        directory : folder in which the file will be saved; when omitted the
                    module-level ``output_dir`` is used (it must be defined
                    before calling, otherwise a NameError is raised)
        returns : ``<directory>/<filename_base>.png`` if that name is free,
                  otherwise the first free ``<filename_base>_<k>.png`` with
                  k = 1, 2, 3, ...
    """
    target_dir = Path(directory) if directory is not None else output_dir

    candidate = target_dir / f"{filename_base}.png"
    suffix = 0
    # Probe candidates one by one instead of counting glob matches: counting
    # can hand back a name that is already taken (e.g. "_1" was deleted while
    # "_2" still exists) and also counts unrelated files such as
    # "<base>_old.png".
    while candidate.exists():
        suffix += 1
        candidate = target_dir / f"{filename_base}_{suffix}.png"
    return candidate


In [None]:
# All pickled inputs are expected inside a single dataset directory.
dataset_dir = Path("data")

# --- Shared files: experiment metadata and the baseline recording ---------
metadata_path = dataset_dir.joinpath('metadata.pkl')
baseline_sd_path = dataset_dir.joinpath('baseline_spike_data.pkl')

# --- One spike-data file and one log file per recording -------------------
# The variable names tie run 6 to "adaptive", run 7 to "random" and run 8
# to "none".
adaptive_sd_path = dataset_dir.joinpath('exp1_cartpole_long_6_spike_data.pkl')
adaptive_logs_path = dataset_dir.joinpath('exp1_cartpole_long_6_logs.pkl')

random_sd_path = dataset_dir.joinpath('exp1_cartpole_long_7_spike_data.pkl')
random_logs_path = dataset_dir.joinpath('exp1_cartpole_long_7_logs.pkl')

none_sd_path = dataset_dir.joinpath('exp1_cartpole_long_8_spike_data.pkl')
none_logs_path = dataset_dir.joinpath('exp1_cartpole_long_8_logs.pkl')

# --- Load everything into memory ------------------------------------------
metadata = load_pickle(metadata_path)
baseline_sd = load_pickle(baseline_sd_path)

adaptive_sd = load_pickle(adaptive_sd_path)
adaptive_logs = load_pickle(adaptive_logs_path)  # logs might be needed later

random_sd = load_pickle(random_sd_path)
random_logs = load_pickle(random_logs_path)

none_sd = load_pickle(none_sd_path)
none_logs = load_pickle(none_logs_path)



# <font color="Green">Analysis Plots</font>

Here is the source code for all the figures used in the main notebook. Feel free to use these as reference/inspiration when working on your own analysis methods!

In [None]:
def plotRaster(sd, title, ax):
    """
    Draw a spike raster on an existing axes.
        sd : object providing ``idces_times()``, which yields the unit index
             and the time of every spike as two parallel sequences
        title : text placed above the axes
        ax : matplotlib-style axes to draw on
        returns : None (the axes is modified in place)

    NOTE(review): the x axis is labelled "Time(s)" but the times are plotted
    exactly as ``idces_times()`` returns them - confirm their unit.
    """
    unit_ids, spike_times = sd.idces_times()

    # Label the axes first, then add one thin vertical tick per spike.
    ax.set_title(title)
    ax.set_xlabel("Time(s)")
    ax.set_ylabel('Unit #')
    ax.scatter(spike_times, unit_ids, marker='|', s=1)

In [None]:
def get_sttc(sd):
    """
    Spike time tiling coefficients of a spike data object.
        sd : spike data object exposing ``spike_time_tilings()``
        returns : whatever ``sd.spike_time_tilings()`` produces
                  (documented by the original author as an N x N matrix)
    """
    return sd.spike_time_tilings()

def correlation(sd, bin_size=1, sigma=5):
    """
    Returns the correlation matrix of a spike data object
        sd : spike data object from braingeneers; only ``sd.raster`` is used
        bin_size : bin width handed to ``sd.raster`` (the original code notes
                   the unit as ms); defaults to the previously hard-coded 1
        sigma : standard deviation, in bins, of the Gaussian used to blur
                each unit's binned spike train; defaults to the previously
                hard-coded 5
        returns : N x N matrix of correlation values between the rows of
                  the blurred raster
    """
    # One row per unit, one column per time bin.
    dense_raster = sd.raster(bin_size=bin_size)

    # Blur along the last (time) axis so that spikes landing in nearby bins
    # still count as correlated.
    smoothed = gaussian_filter1d(dense_raster.astype(float), sigma=sigma)

    # np.corrcoef treats every row as one variable. The result used to be
    # preceded by a throw-away np.zeros((N, N)) allocation, which is gone.
    return np.corrcoef(smoothed)

def eigenvalues_eigenvectors(sd):
    """
    Eigen-decomposition of a square symmetric matrix, largest eigenvalue first.
        sd : the matrix to decompose (despite the name this is a 2-D array
             such as a correlation or STTC matrix, not a spike data object)
        returns : eigenvalues, eigenvectors
                  eigenvalues are in descending order; eigenvectors are the
                  columns of the second array in matching order, each
                  negated when needed so its first entry is not negative
    """
    values, vectors = np.linalg.eigh(sd)

    # No more components than the smaller matrix dimension are kept.
    n_keep = min(*sd.shape)
    values = values[-n_keep:]
    vectors = vectors[:, -n_keep:]

    # An eigenvector is only defined up to sign: flip every column whose
    # first entry is negative so the output is reproducible.
    flip = np.where(vectors[0, :] < 0, -1, 1)
    vectors = vectors * flip[np.newaxis, :]

    # eigh() reports eigenvalues in ascending order; reverse both outputs so
    # the dominant component comes first.
    return values[::-1], vectors[:, ::-1]


In [None]:
def correlation_plot(sd):
    """
    Show the unit-by-unit correlation matrix of a spike data object as an image.
        sd : spike data object from braingeneers
        returns : None (the figure is displayed with ``plt.show()``)
    """
    # Reuse correlation() rather than repeating its raster/blur/corrcoef
    # steps here; this also removes the throw-away np.zeros((N, N))
    # allocation the inline copy carried.
    corr = correlation(sd)

    plt.imshow(corr)
    plt.xlabel("Neuron index")
    plt.ylabel("Neuron index")
    plt.show()

def STTC_plot(sd):
    """
    Show the spike time tiling coefficient matrix of a spike data object as an image.
        sd : spike data object from braingeneers
        returns : None (the figure is displayed with ``plt.show()``)
    """
    plt.imshow(get_sttc(sd))

    # Both axes index the same set of units, so they share one label.
    for set_label in (plt.xlabel, plt.ylabel):
        set_label("Neuron index")

    plt.show()

def plot_evectmatrix(sd):
    """
    Draw the eigenvectors of the correlation matrix and of the STTC matrix
    of a spike data object side by side.
        sd : spike data object from braingeneers
        returns : None (a two-panel figure is created; it is neither shown
                  explicitly nor returned)
    """
    fig, plot = plt.subplot_mosaic("AB", figsize=(14,7))

    corr_matrix = correlation(sd)
    sttc_matrix = get_sttc(sd)

    # Only the eigenvectors are drawn; the eigenvalues are discarded.
    _, corr_vectors = eigenvalues_eigenvectors(corr_matrix)
    _, sttc_vectors = eigenvalues_eigenvectors(sttc_matrix)

    # (mosaic key, eigenvector matrix, panel title) for the two panels.
    panels = (
        ("A", corr_vectors, 'Eigenvectors of Correlation'),
        ("B", sttc_vectors, 'Eigenvectors of STTC'),
    )

    for key, vectors, heading in panels:
        axis = plot[key]
        # Transposed so that every image row is one eigenvector. The column
        # slice stops at twice the row count, so for a square eigenvector
        # matrix it keeps every column.
        image = axis.imshow(vectors[:, :2*len(vectors)].T, interpolation='none', cmap="magma")
        axis.set_ylabel('Eigenvector Number')
        axis.set_xlabel('Observation Dimension')
        axis.set_title(heading)
        fig.colorbar(image, ax=axis, shrink=0.7)

# <font color="red">Bookend</font>

In [None]:
# Record when the notebook finished, expressed in the configured timezone.
now = datetime.now(tz=Timezone)
printNow = datetime.strftime(now, "%Y/%m/%d %H:%M:%S")

print(f"Done at: {printNow}")

Done at: 2024/06/03 16:27:28
