In [1]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Reload imported modules (e.g. the project's `src` package) automatically
# before each cell executes, so edits to .py files are picked up live.
%reload_ext autoreload
%autoreload 2
# Draw inline figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'

In [2]:
import logging
import os

import matplotlib
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from src.parameters import FIGURE_DIR

# Global figure styling for every plot in this notebook.
rc_params = {}
# Type 42 (TrueType) font embedding keeps text editable in Adobe Illustrator.
rc_params['pdf.fonttype'] = 42
rc_params['ps.fonttype'] = 42
# Dark-grey labels and text instead of pure black.
rc_params['axes.labelcolor'] = '#222222'
rc_params['axes.labelsize'] = 9
rc_params['text.color'] = '#222222'
rc_params['font.sans-serif'] = 'Helvetica'
rc_params['text.usetex'] = False
rc_params['figure.figsize'] = (7.2, 4.45)
# Hide major tick marks on both axes.
rc_params['xtick.major.size'] = 0.00
rc_params['ytick.major.size'] = 0.00
rc_params['axes.labelpad'] = 0.1

sns.set(context='paper', style='white', font_scale=1.0, rc=rc_params)

# Emit INFO-level log messages (progress from load_data / the classifiers).
logging.basicConfig(level=logging.INFO)

In [20]:
import pandas as pd
import xarray as xr
from tqdm.auto import tqdm

from replay_trajectory_classification import (ClusterlessClassifier,
                                              SortedSpikesClassifier)
from src.analysis import (get_linear_position_order, get_place_field_max,
                          reshape_to_segments)
from src.load_data import load_data
from src.parameters import (ANIMALS, FIGURE_DIR, PROBABILITY_THRESHOLD,
                            PROCESSED_DATA_DIR, SAMPLING_FREQUENCY,
                            TRANSITION_TO_CATEGORY,
                            continuous_transition_types, discrete_diag,
                            knot_spacing, model, model_kwargs, movement_var,
                            place_bin_size, replay_speed, spike_model_penalty)
from src.visualization import plot_ripple_decode_1D, plot_ripple_decode_2D


def replot_clusterless_ripple_decode_1D(epoch_key, ripple_list,
                                        figure_format='png',
                                        speed_threshold=4):
    """Re-plot saved 1D clusterless ripple decodes for selected ripples.

    Loads the epoch with ``load_data`` and refits a ``ClusterlessClassifier``
    on linearized position. The fit is only used to log the model and to get
    the tetrode ordering for the spike raster. Each ripple's acausal posterior
    is read from the processed ``{animal}_{day}_{epoch}.nc`` file (group
    ``/clusterless/1D/classifier/ripples/``). One figure per ripple is saved to
    ``FIGURE_DIR/ripple_classifications``.

    Parameters
    ----------
    epoch_key : tuple
        ``(animal, day, epoch)``, e.g. ``('remy', 35, 4)``.
    ripple_list : iterable of int
        Values of the ``ripple_number`` coordinate to plot.
    figure_format : str, optional
        Extension of the saved figure files, e.g. ``'png'`` or ``'pdf'``.
    speed_threshold : float, optional
        Only samples with ``speed > speed_threshold`` are used to fit the
        classifier. Defaults to 4, the value previously hard-coded.
    """
    ripple_list = list(ripple_list)
    if not ripple_list:
        # Nothing to plot: skip the expensive data load and model fit.
        logging.info('No ripples to plot.')
        return

    animal, day, epoch = epoch_key
    data_type, dim = 'clusterless', '1D'

    logging.info('Loading data...')
    data = load_data(epoch_key)

    position_info = data['position_info']
    is_training = position_info.speed > speed_threshold
    position = position_info.loc[:, 'linear_position']
    track_labels = position_info.arm_name

    results_path = os.path.join(
        PROCESSED_DATA_DIR, f'{animal}_{day:02}_{epoch:02}.nc')
    # Open the saved decode before fitting so a missing file fails fast.
    # The context manager releases the file handle when plotting is done.
    with xr.open_dataset(
            results_path,
            group=f'/{data_type}/{dim}/classifier/ripples/') as results:
        ripple_times = data['ripple_times'].loc[:, ['start_time', 'end_time']]
        # 1.0 where a tetrode's summed mark features are positive in a time
        # bin, else 0.0. Presumably "a spike occurred"; confirm that marks
        # are NaN/zero when no spike is present.
        spikes = (((data['multiunit'].sum('features') > 0) * 1.0)
                  .to_dataframe(name='spikes').unstack())
        spikes.columns = data['tetrode_info'].tetrode_id
        ripple_spikes = reshape_to_segments(spikes, ripple_times)
        classifier = ClusterlessClassifier(
            place_bin_size=place_bin_size, movement_var=movement_var,
            replay_speed=replay_speed,
            discrete_transition_diag=discrete_diag,
            continuous_transition_types=continuous_transition_types,
            model=model, model_kwargs=model_kwargs).fit(
                position, data['multiunit'], is_training=is_training,
                track_labels=track_labels)
        logging.info(classifier)

        place_field_max = get_place_field_max(classifier)
        # argsort of the place-field maxima; passed to the plot as the
        # spike-raster ordering.
        linear_position_order = place_field_max.argsort(axis=0).squeeze()
        ripple_position = reshape_to_segments(position, ripple_times)

        figure_dir = os.path.join(FIGURE_DIR, 'ripple_classifications')
        os.makedirs(figure_dir, exist_ok=True)

        for ripple_number in tqdm(ripple_list, desc='ripple figures'):
            posterior = (
                results
                .acausal_posterior
                .sel(ripple_number=ripple_number)
                .dropna('time')
                # timedelta time coordinate -> milliseconds
                .assign_coords(
                    time=lambda ds: 1000 * ds.time / np.timedelta64(1, 's')))
            plot_ripple_decode_1D(
                posterior, ripple_position.loc[ripple_number],
                ripple_spikes.loc[ripple_number], linear_position_order,
                position_info, spike_label='Tetrodes')
            plt.suptitle(
                f'ripple number = {animal}_{day:02d}_{epoch:02d}_'
                f'{ripple_number:04d}')
            # NOTE(review): 'acasual' is a typo for 'acausal'. It is kept so
            # that existing figure file names stay stable.
            fig_name = (
                f'{animal}_{day:02d}_{epoch:02d}_{ripple_number:04d}_'
                f'{data_type}_{dim}_acasual_classification.{figure_format}')
            plt.savefig(os.path.join(figure_dir, fig_name), transparent=True,
                        dpi=300, bbox_inches='tight')
            plt.close(plt.gcf())

In [26]:
# Clusterless 1D figures for remy, day 35, epoch 4.
epoch_key = ('remy', 35, 4)
ripple_list = [303, 269, 57, 267]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=3, style=ProgressStyle(description_width…




In [25]:
# Clusterless 1D figures for remy, day 36, epoch 2.
epoch_key = ('remy', 36, 2)
ripple_list = [20, 137, 4, 33, 151, 163, 252, 269, 122, 262, 97, 200, 142]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=13, style=ProgressStyle(description_widt…




In [27]:
# Clusterless 1D figures for remy, day 37, epoch 2.
epoch_key = ('remy', 37, 2)
ripple_list = [314, 7, 9, 278]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…




In [28]:
# Clusterless 1D figures for remy, day 35, epoch 2.
epoch_key = ('remy', 35, 2)
ripple_list = [157, 240]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=2, style=ProgressStyle(description_width…




In [29]:
# Re-plot only ripple 122 of remy, day 36, epoch 2 (clusterless 1D).
epoch_key = ('remy', 36, 2)
ripple_list = [122]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=1, style=ProgressStyle(description_width…




In [30]:
# Clusterless 1D figures for bon, day 3, epoch 2.
epoch_key = ('bon', 3, 2)
ripple_list = [250, 21, 25, 61, 111, 87]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=6, style=ProgressStyle(description_width…




In [31]:
# Clusterless 1D figures for remy, day 37, epoch 4.
epoch_key = ('remy', 37, 4)
ripple_list = [50, 51, 45]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=3, style=ProgressStyle(description_width…




In [32]:
# Re-plot only ripple 267 of remy, day 35, epoch 4 (clusterless 1D).
epoch_key = ('remy', 35, 4)
ripple_list = [267]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=1, style=ProgressStyle(description_width…




In [36]:
# Clusterless 1D figures for bon, day 3, epoch 6.
epoch_key = ('bon', 3, 6)
ripple_list = [93, 144, 169, 181, 186]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=5, style=ProgressStyle(description_width…




In [46]:
# Clusterless 1D figure for ripple 75 of bon, day 5, epoch 4.
epoch_key = ('bon', 5, 4)
ripple_list = [75]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=1, style=ProgressStyle(description_width…




In [19]:
def replot_sorted_spikes_ripple_decode_1D(epoch_key, ripple_list,
                                          figure_format='png',
                                          speed_threshold=4):
    """Re-plot saved 1D sorted-spikes ripple decodes for selected ripples.

    Loads the epoch with ``load_data`` and refits a ``SortedSpikesClassifier``
    on linearized position. The fit is only used to log the model and to get
    the cell ordering for the spike raster. Each ripple's acausal posterior is
    read from the processed ``{animal}_{day}_{epoch}.nc`` file (group
    ``/sorted_spikes/1D/classifier/ripples/``). One figure per ripple is saved
    to ``FIGURE_DIR/ripple_classifications``.

    Parameters
    ----------
    epoch_key : tuple
        ``(animal, day, epoch)``, e.g. ``('bon', 3, 2)``.
    ripple_list : iterable of int
        Values of the ``ripple_number`` coordinate to plot.
    figure_format : str, optional
        Extension of the saved figure files, e.g. ``'png'`` or ``'pdf'``.
    speed_threshold : float, optional
        Only samples with ``speed > speed_threshold`` are used to fit the
        classifier. Defaults to 4, the value previously hard-coded.
    """
    ripple_list = list(ripple_list)
    if not ripple_list:
        # Nothing to plot: skip the expensive data load and model fit.
        logging.info('No ripples to plot.')
        return

    animal, day, epoch = epoch_key
    data_type, dim = 'sorted_spikes', '1D'

    logging.info('Loading data...')
    data = load_data(epoch_key)

    position_info = data['position_info']
    is_training = position_info.speed > speed_threshold
    position = position_info.loc[:, 'linear_position']
    track_labels = position_info.arm_name

    results_path = os.path.join(
        PROCESSED_DATA_DIR, f'{animal}_{day:02}_{epoch:02}.nc')
    # Open the saved decode before fitting so a missing file fails fast.
    # The context manager releases the file handle when plotting is done.
    with xr.open_dataset(
            results_path,
            group=f'/{data_type}/{dim}/classifier/ripples/') as results:
        ripple_times = data['ripple_times'].loc[:, ['start_time', 'end_time']]
        ripple_spikes = reshape_to_segments(data['spikes'], ripple_times)
        classifier = SortedSpikesClassifier(
            place_bin_size=place_bin_size, movement_var=movement_var,
            replay_speed=replay_speed,
            discrete_transition_diag=discrete_diag,
            spike_model_penalty=spike_model_penalty, knot_spacing=knot_spacing,
            continuous_transition_types=continuous_transition_types).fit(
                position, data['spikes'], is_training=is_training,
                track_labels=track_labels)
        logging.info(classifier)

        place_field_max = get_place_field_max(classifier)
        # argsort of the place-field maxima; passed to the plot as the
        # spike-raster ordering.
        linear_position_order = place_field_max.argsort(axis=0).squeeze()
        ripple_position = reshape_to_segments(position, ripple_times)

        figure_dir = os.path.join(FIGURE_DIR, 'ripple_classifications')
        os.makedirs(figure_dir, exist_ok=True)

        for ripple_number in tqdm(ripple_list, desc='ripple figures'):
            posterior = (
                results
                .acausal_posterior
                .sel(ripple_number=ripple_number)
                .dropna('time')
                # timedelta time coordinate -> milliseconds
                .assign_coords(
                    time=lambda ds: 1000 * ds.time / np.timedelta64(1, 's')))
            plot_ripple_decode_1D(
                posterior, ripple_position.loc[ripple_number],
                ripple_spikes.loc[ripple_number], linear_position_order,
                position_info)
            plt.suptitle(
                f'ripple number = {animal}_{day:02d}_{epoch:02d}_'
                f'{ripple_number:04d}')
            # NOTE(review): 'acasual' is a typo for 'acausal'. It is kept so
            # that existing figure file names stay stable.
            fig_name = (
                f'{animal}_{day:02d}_{epoch:02d}_{ripple_number:04d}_'
                f'{data_type}_{dim}_acasual_classification.{figure_format}')
            plt.savefig(os.path.join(figure_dir, fig_name), transparent=True,
                        dpi=300, bbox_inches='tight')
            plt.close(plt.gcf())

In [35]:
# Sorted-spikes 1D figures for bon, day 3, epoch 2.
epoch_key = ('bon', 3, 2)
ripple_list = [61, 111, 87]

replot_sorted_spikes_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting place fields...
INFO:root:SortedSpikesClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                     'w_track_1D_inverse_random_walk',
                                                     'identity'],
                                                    ['uniform',
                                                     'w_track_1D_inverse_random_walk',
                                                     'uniform'],
                                                    ['w_track_1D_random_walk_minus_identity',
                                                     'w_track_1D_inverse_random_walk',
                                                

HBox(children=(IntProgress(value=0, description='ripple figures', max=3, style=ProgressStyle(description_width…




In [18]:
def clusterless_analysis_2D(epoch_key, ripple_list, figure_format='png',
                            speed_threshold=4):
    """Fit a 2D clusterless classifier, decode selected ripples, and plot them.

    Unlike the 1D replot functions, nothing is read from the processed files.
    The classifier is fit on (x, y) position, and each requested ripple's
    multiunit window is decoded here, with time measured from the ripple
    start. State labels are then mapped through ``TRANSITION_TO_CATEGORY``.
    One figure per ripple is saved to ``FIGURE_DIR/ripple_classifications``.

    Parameters
    ----------
    epoch_key : tuple
        ``(animal, day, epoch)``, e.g. ``('bon', 4, 2)``.
    ripple_list : iterable of int
        Ripple numbers (index labels of ``data['ripple_times']``) to decode.
    figure_format : str, optional
        Extension of the saved figure files, e.g. ``'png'`` or ``'pdf'``.
    speed_threshold : float, optional
        Only samples with ``speed > speed_threshold`` are used to fit the
        classifier. Defaults to 4, the value previously hard-coded.
    """
    # Local override of the module-level transitions: the plain (not
    # track-specific) movement models used for 2D position.
    continuous_transition_types = (
        [['random_walk_minus_identity', 'inverse_random_walk', 'identity'],  # noqa
         ['uniform',                     'inverse_random_walk', 'uniform'],   # noqa
         ['random_walk_minus_identity', 'inverse_random_walk', 'identity']])  # noqa
    ripple_list = list(ripple_list)
    if not ripple_list:
        # xr.concat cannot combine zero decodes; skip the expensive fit.
        logging.info('No ripples to plot.')
        return

    animal, day, epoch = epoch_key
    data_type, dim = 'clusterless', '2D'

    logging.info('Loading data...')
    data = load_data(epoch_key)
    position_info = data['position_info']
    position = position_info.loc[:, ['x_position', 'y_position']]
    is_training = position_info.speed > speed_threshold
    logging.info('Fitting classifier...')
    classifier = ClusterlessClassifier(
        place_bin_size=place_bin_size, movement_var=movement_var,
        replay_speed=replay_speed,
        discrete_transition_diag=discrete_diag,
        continuous_transition_types=continuous_transition_types,
        model=model, model_kwargs=model_kwargs).fit(
        position, data['multiunit'], is_training=is_training)
    logging.info(classifier)

    # Decode
    ripple_times = data['ripple_times'].loc[:, ['start_time', 'end_time']]
    # 1.0 where a tetrode's summed mark features are positive in a time bin,
    # else 0.0 (used only for the spike raster in the figures).
    spikes = (((data['multiunit'].sum('features') > 0) * 1.0)
              .to_dataframe(name='spikes').unstack())
    spikes.columns = data['tetrode_info'].tetrode_id
    ripple_spikes = reshape_to_segments(spikes, ripple_times)

    decodes = []
    for ripple_number in tqdm(ripple_list, desc='ripple'):
        time_slice = slice(
            *ripple_times.loc[ripple_number, ['start_time', 'end_time']])
        ripple_multiunit = data['multiunit'].sel(time=time_slice)
        # Time relative to the first sample of the ripple.
        decodes.append(classifier.predict(
            ripple_multiunit,
            ripple_multiunit.time - ripple_multiunit.time[0]))
    results = (
        xr.concat(decodes, dim=ripple_times.loc[ripple_list].index)
        .assign_coords(
            state=lambda ds: ds.state.to_index().map(TRANSITION_TO_CATEGORY)))

    logging.info('Plotting ripple figures...')

    place_field_max = get_place_field_max(classifier)
    linear_position_order, _ = get_linear_position_order(
        position_info, place_field_max)

    ripple_position = reshape_to_segments(position, ripple_times)

    figure_dir = os.path.join(FIGURE_DIR, 'ripple_classifications')
    os.makedirs(figure_dir, exist_ok=True)

    for ripple_number in tqdm(ripple_list, desc='ripple figures'):
        posterior = (
            results
            .acausal_posterior
            .sel(ripple_number=ripple_number)
            .dropna('time')
            # timedelta time coordinate -> milliseconds
            .assign_coords(
                time=lambda ds: 1000 * ds.time / np.timedelta64(1, 's')))
        plot_ripple_decode_2D(
            posterior, ripple_position.loc[ripple_number],
            ripple_spikes.loc[ripple_number], position, linear_position_order,
            spike_label='Tetrodes')
        plt.suptitle(
            f'ripple number = {animal}_{day:02d}_{epoch:02d}_'
            f'{ripple_number:04d}')
        # NOTE(review): 'acasual' is a typo for 'acausal'. It is kept so that
        # existing figure file names stay stable.
        fig_name = (f'{animal}_{day:02d}_{epoch:02d}_{ripple_number:04d}_'
                    f'{data_type}_{dim}_acasual_classification.{figure_format}')
        plt.savefig(os.path.join(figure_dir, fig_name), transparent=True,
                    dpi=300, bbox_inches='tight')
        plt.close(plt.gcf())

    logging.info('Done...')

In [65]:
# Clusterless 2D decode and figures for bon, day 4, epoch 2.
epoch_key = ('bon', 4, 2)
ripple_list = [106, 121, 112, 265]

clusterless_analysis_2D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:root:Fitting classifier...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['random_walk_minus_identity',
                                                    'inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'inverse_random_walk',
                                                    'uniform'],
                                                   ['random_walk_minus_identity',
                                                    'inverse_random_walk',
                                                    'identity']],
               

HBox(children=(IntProgress(value=0, description='ripple', max=4, style=ProgressStyle(description_width='initia…

INFO:root:Plotting ripple figures...





HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…

INFO:root:Done...





In [61]:
# Clusterless 1D figures for the same bon, day 4, epoch 2 ripples.
epoch_key = ('bon', 4, 2)
ripple_list = [106, 121, 112, 265]

replot_clusterless_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'ident

HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…




In [66]:
# Sorted-spikes 1D figure for ripple 121 of bon, day 4, epoch 2.
epoch_key = ('bon', 4, 2)
ripple_list = [121]

replot_sorted_spikes_ripple_decode_1D(
    epoch_key=epoch_key, ripple_list=ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting place fields...
INFO:root:SortedSpikesClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                     'w_track_1D_inverse_random_walk',
                                                     'identity'],
                                                    ['uniform',
                                                     'w_track_1D_inverse_random_walk',
                                                     'uniform'],
                                                    ['w_track_1D_random_walk_minus_identity',
                                                     'w_track_1D_inverse_random_walk',
                                                

HBox(children=(IntProgress(value=0, description='ripple figures', max=1, style=ProgressStyle(description_width…




In [17]:
from replay_trajectory_classification import SortedSpikesClassifier

def sorted_spikes_analysis_2D(epoch_key, ripple_list, figure_format='png',
                              speed_threshold=4):
    """Fit a 2D sorted-spikes classifier, decode selected ripples, and plot them.

    Nothing is read from the processed files. The classifier is fit on (x, y)
    position, and each requested ripple's spikes (from ``reshape_to_segments``)
    are decoded here, with time measured from the ripple's first sample. State
    labels are then mapped through ``TRANSITION_TO_CATEGORY``. One figure per
    ripple is saved to ``FIGURE_DIR/ripple_classifications``.

    Parameters
    ----------
    epoch_key : tuple
        ``(animal, day, epoch)``, e.g. ``('bon', 4, 2)``.
    ripple_list : iterable of int
        Ripple numbers (index labels of ``data['ripple_times']``) to decode.
    figure_format : str, optional
        Extension of the saved figure files, e.g. ``'png'`` or ``'pdf'``.
    speed_threshold : float, optional
        Only samples with ``speed > speed_threshold`` are used to fit the
        classifier. Defaults to 4, the value previously hard-coded.
    """
    # Local override of the module-level transitions: the plain (not
    # track-specific) movement models used for 2D position.
    continuous_transition_types = (
        [['random_walk_minus_identity', 'inverse_random_walk', 'identity'],  # noqa
         ['uniform',                     'inverse_random_walk', 'uniform'],   # noqa
         ['random_walk_minus_identity', 'inverse_random_walk', 'identity']])  # noqa
    ripple_list = list(ripple_list)
    if not ripple_list:
        # xr.concat cannot combine zero decodes; skip the expensive fit.
        logging.info('No ripples to plot.')
        return

    animal, day, epoch = epoch_key
    data_type, dim = 'sorted_spikes', '2D'

    logging.info('Loading data...')
    data = load_data(epoch_key)
    position_info = data['position_info']
    position = position_info.loc[:, ['x_position', 'y_position']]
    is_training = position_info.speed > speed_threshold
    logging.info('Fitting classifier...')
    classifier = SortedSpikesClassifier(
        place_bin_size=place_bin_size, movement_var=movement_var,
        replay_speed=replay_speed,
        discrete_transition_diag=discrete_diag,
        spike_model_penalty=spike_model_penalty, knot_spacing=knot_spacing,
        continuous_transition_types=continuous_transition_types).fit(
        position, data['spikes'], is_training=is_training)
    logging.info(classifier)

    # Decode
    ripple_times = data['ripple_times'].loc[:, ['start_time', 'end_time']]
    ripple_spikes = reshape_to_segments(data['spikes'], ripple_times)

    decodes = []
    for ripple_number in tqdm(ripple_list, desc='ripple'):
        spikes_in_ripple = ripple_spikes.loc[ripple_number]
        # Time relative to the first sample of the ripple.
        ripple_time = spikes_in_ripple.index - spikes_in_ripple.index[0]
        decodes.append(classifier.predict(spikes_in_ripple, time=ripple_time))
    results = (
        xr.concat(decodes, dim=ripple_times.loc[ripple_list].index)
        .assign_coords(
            state=lambda ds: ds.state.to_index().map(TRANSITION_TO_CATEGORY)))

    logging.info('Plotting ripple figures...')

    place_field_max = get_place_field_max(classifier)
    linear_position_order, _ = get_linear_position_order(
        position_info, place_field_max)

    ripple_position = reshape_to_segments(position, ripple_times)

    figure_dir = os.path.join(FIGURE_DIR, 'ripple_classifications')
    os.makedirs(figure_dir, exist_ok=True)

    for ripple_number in tqdm(ripple_list, desc='ripple figures'):
        posterior = (
            results
            .acausal_posterior
            .sel(ripple_number=ripple_number)
            .dropna('time')
            # timedelta time coordinate -> milliseconds
            .assign_coords(
                time=lambda ds: 1000 * ds.time / np.timedelta64(1, 's')))
        plot_ripple_decode_2D(
            posterior, ripple_position.loc[ripple_number],
            ripple_spikes.loc[ripple_number], position, linear_position_order,
            spike_label='Cells')
        plt.suptitle(
            f'ripple number = {animal}_{day:02d}_{epoch:02d}_'
            f'{ripple_number:04d}')
        # NOTE(review): 'acasual' is a typo for 'acausal'. It is kept so that
        # existing figure file names stay stable.
        fig_name = (f'{animal}_{day:02d}_{epoch:02d}_{ripple_number:04d}_'
                    f'{data_type}_{dim}_acasual_classification.{figure_format}')
        plt.savefig(os.path.join(figure_dir, fig_name), transparent=True,
                    dpi=300, bbox_inches='tight')
        plt.close(plt.gcf())

    logging.info('Done...')

In [5]:
from dask.distributed import Client

# Local dask cluster: 7 worker processes with 8 threads each and a 25GB
# memory_limit apiece (the rendered summary reports 56 cores / 175 GB total).
client_params = {
    'n_workers': 7,
    'threads_per_worker': 8,
    'processes': True,
    'memory_limit': '25GB',
}
client = Client(**client_params)

In [6]:
# Display the cluster summary (scheduler address, dashboard link, workers, cores, memory)
client

0,1
Client  Scheduler: tcp://127.0.0.1:45025  Dashboard: http://127.0.0.1:39318/status,Cluster  Workers: 7  Cores: 56  Memory: 175.00 GB


In [15]:
# Ripple #121 from (animal, day, epoch) = ('bon', 4, 2), run through the
# 2D sorted-spikes analysis with figure_format set to PDF.
epoch_key = ('bon', 4, 2)
ripple_list = [121]

sorted_spikes_analysis_2D(epoch_key, ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:root:Fitting classifier...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting place fields...
INFO:root:SortedSpikesClassifier(continuous_transition_types=[['random_walk_minus_identity',
                                                     'inverse_random_walk',
                                                     'identity'],
                                                    ['uniform',
                                                     'inverse_random_walk',
                                                     'uniform'],
                                                    ['random_walk_minus_identity',
                                                     'inverse_random_walk',
                                                     'identity']],
    

HBox(children=(IntProgress(value=0, description='ripple', max=1, style=ProgressStyle(description_width='initia…

INFO:root:Plotting ripple figures...





HBox(children=(IntProgress(value=0, description='ripple figures', max=1, style=ProgressStyle(description_width…

INFO:root:Done...





In [22]:
# Four ripples from (animal, day, epoch) = ('bon', 3, 2). Each routine below is
# called in turn (1D replots first, then the 2D analyses) with the same epoch,
# ripple numbers, and PDF figure format.
epoch_key = ('bon', 3, 2)
ripple_list = [145, 190, 192, 204]

for plot_ripples in (replot_sorted_spikes_ripple_decode_1D,
                     replot_clusterless_ripple_decode_1D,
                     sorted_spikes_analysis_2D,
                     clusterless_analysis_2D):
    plot_ripples(epoch_key, ripple_list, figure_format='pdf')

INFO:root:Loading data...
INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting place fields...
INFO:root:SortedSpikesClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                     'w_track_1D_inverse_random_walk',
                                                     'identity'],
                                                    ['uniform',
                                                     'w_track_1D_inverse_random_walk',
                                                     'uniform'],
                                                    ['w_track_1D_random_walk_minus_identity',
                                                     'w_track_1D_inverse_random_walk',
                                                

HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…

INFO:root:Loading data...





INFO:src.load_data:Finding ripple times...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'w_track_1D_inverse_random_walk',
                                                    'uniform'],
                                                   ['w_track_1D_random_walk_minus_identity',
                                                    'w_track_1D_inverse_random_walk',
                                                    'identity']],
                  

HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…

INFO:root:Loading data...





INFO:src.load_data:Finding ripple times...
INFO:root:Fitting classifier...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting place fields...
INFO:root:SortedSpikesClassifier(continuous_transition_types=[['random_walk_minus_identity',
                                                     'inverse_random_walk',
                                                     'identity'],
                                                    ['uniform',
                                                     'inverse_random_walk',
                                                     'uniform'],
                                                    ['random_walk_minus_identity',
                                                     'inverse_random_walk',
                                                     'identity']],
                       discret

HBox(children=(IntProgress(value=0, description='ripple', max=4, style=ProgressStyle(description_width='initia…

INFO:root:Plotting ripple figures...





HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…

INFO:root:Done...
INFO:root:Loading data...





INFO:src.load_data:Finding ripple times...
INFO:root:Fitting classifier...
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
INFO:root:ClusterlessClassifier(continuous_transition_types=[['random_walk_minus_identity',
                                                    'inverse_random_walk',
                                                    'identity'],
                                                   ['uniform',
                                                    'inverse_random_walk',
                                                    'uniform'],
                                                   ['random_walk_minus_identity',
                                                    'inverse_random_walk',
                                                    'identity']],
                      discrete_transition

HBox(children=(IntProgress(value=0, description='ripple', max=4, style=ProgressStyle(description_width='initia…

INFO:root:Plotting ripple figures...





HBox(children=(IntProgress(value=0, description='ripple figures', max=4, style=ProgressStyle(description_width…

INFO:root:Done...



