In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'
%qtconsole

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import logging

# Emit INFO-level records (e.g. the classifier's "Fitting ..." progress
# messages) into the cell output; the level name resolves to logging.INFO.
logging.basicConfig(level="INFO")

## Load Data

In [3]:
from src.simulation import load_simulated_spikes_with_real_position

# Load the simulated data set (per the loader's name: simulated spikes paired
# with a real position trajectory) and unpack its five returned values.
# NOTE(review): the captured output of this cell shows `yaml.load(...)` calls
# without an explicit `Loader=` somewhere in the loading path; that form is
# deprecated/unsafe — consider switching those call sites to `yaml.safe_load`.
position, spikes, is_training, place_field_centers, position_info = (
    load_simulated_spikes_with_real_position())

  data = yaml.load(f.read()) or {}
  defaults = yaml.load(f)


## Fit Classifier

In [None]:
from replay_trajectory_classification import SortedSpikesClassifier
from src.parameters import SAMPLING_FREQUENCY

# Classifier hyperparameters, named so they can be found and tuned in one place.
# NOTE(review): a square root is passed to a parameter named `movement_var`;
# taking sqrt of something called a *variance* looks like a possible
# std-vs-variance mix-up — confirm the intended units against the library docs.
MOVEMENT_VAR = np.sqrt(15 / SAMPLING_FREQUENCY)
REPLAY_SPEED = 80
# Presumably one continuous-state transition per replay type exercised in the
# test cases below (continuous / fragmented / hover) — TODO confirm mapping.
CONTINUOUS_TRANSITION_TYPES = ['empirical_movement', 'uniform', 'identity']

# Build the classifier and fit it in one expression; `.fit(...)`'s return value
# is what gets bound to `classifier` (used for plotting and prediction below).
classifier = SortedSpikesClassifier(
    movement_var=MOVEMENT_VAR,
    replay_speed=REPLAY_SPEED,
    continuous_transition_types=CONTINUOUS_TRANSITION_TYPES,
).fit(position, spikes, is_training=is_training)

INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting state transition...
INFO:replay_trajectory_classification.classifier:Fitting place fields...


In [None]:
# Plot the place fields of the fitted classifier. Binding the return value to
# `g` keeps the handle and stops IPython from echoing its repr.
g = classifier.plot_place_fields(spikes, position, SAMPLING_FREQUENCY)

In [None]:
from src.analysis import get_linear_position_order, get_place_field_max
from src.visualization import plot_neuron_place_field_2D_1D_position

# Presumably the peak location of each neuron's fitted place field.
place_field_max = get_place_field_max(classifier)

# Project those peaks onto the linearized track. `linear_position_order` is
# reused by every `plot_ripple_decode` call in the test cases below.
(linear_position_order,
 linear_place_field_max) = get_linear_position_order(position_info,
                                                     place_field_max)

plot_neuron_place_field_2D_1D_position(
    position_info,
    place_field_max,
    linear_place_field_max,
    linear_position_order,
)

## Test Cases

### Continuous

In [None]:
from src.simulation import continuous_replay
from src.visualization import plot_ripple_decode

# TODO: the test-case cells in this section differ only in the simulator they
# call; consider factoring them into one helper function.

# Simulate an event with `continuous_replay` and decode it with the classifier.
test_spikes, time = continuous_replay(place_field_centers)
result = classifier.predict(test_spikes, time)

# NOTE(review): all-zero (n_time, 2) array, presumably a placeholder for the
# animal's position during the simulated event — confirm.
ripple_position = np.zeros(shape=(time.size, 2))
plot_ripple_decode(
    result.acausal_posterior,
    ripple_position,
    test_spikes,
    position,
    linear_position_order,
)


### Hover

In [None]:
from src.simulation import hover_replay

# Simulate an event with `hover_replay` and decode it with the classifier.
test_spikes, time = hover_replay(place_field_centers)
result = classifier.predict(test_spikes, time)

# NOTE(review): all-zero (n_time, 2) array, presumably a placeholder for the
# animal's position during the simulated event — confirm.
ripple_position = np.zeros(shape=(time.size, 2))
plot_ripple_decode(
    result.acausal_posterior,
    ripple_position,
    test_spikes,
    position,
    linear_position_order,
)

### Fragmented

In [None]:
from src.simulation import fragmented_replay

# Simulate an event with `fragmented_replay` and decode it with the classifier.
test_spikes, time = fragmented_replay(place_field_centers)
result = classifier.predict(test_spikes, time)

# NOTE(review): all-zero (n_time, 2) array, presumably a placeholder for the
# animal's position during the simulated event — confirm.
ripple_position = np.zeros(shape=(time.size, 2))
plot_ripple_decode(
    result.acausal_posterior,
    ripple_position,
    test_spikes,
    position,
    linear_position_order,
)

### Hover-Continuous-Hover

In [None]:
from src.simulation import hover_continuous_hover_replay

# Simulate an event with `hover_continuous_hover_replay` and decode it.
test_spikes, time = hover_continuous_hover_replay(place_field_centers)
result = classifier.predict(test_spikes, time)

# NOTE(review): all-zero (n_time, 2) array, presumably a placeholder for the
# animal's position during the simulated event — confirm.
ripple_position = np.zeros(shape=(time.size, 2))
plot_ripple_decode(
    result.acausal_posterior,
    ripple_position,
    test_spikes,
    position,
    linear_position_order,
)

### Continuous-Fragmented-Continuous

In [None]:
from src.simulation import continuous_fragmented_continuous_replay

# Simulate an event with `continuous_fragmented_continuous_replay` and decode it.
test_spikes, time = continuous_fragmented_continuous_replay(
    place_field_centers)
result = classifier.predict(test_spikes, time)

# NOTE(review): all-zero (n_time, 2) array, presumably a placeholder for the
# animal's position during the simulated event — confirm.
ripple_position = np.zeros(shape=(time.size, 2))
plot_ripple_decode(
    result.acausal_posterior,
    ripple_position,
    test_spikes,
    position,
    linear_position_order,
)

### Hover-Fragmented-Hover

In [None]:
from src.simulation import hover_fragmented_hover_replay

# Simulate an event with `hover_fragmented_hover_replay` and decode it.
test_spikes, time = hover_fragmented_hover_replay(place_field_centers)
result = classifier.predict(test_spikes, time)

# NOTE(review): all-zero (n_time, 2) array, presumably a placeholder for the
# animal's position during the simulated event — confirm.
ripple_position = np.zeros(shape=(time.size, 2))
plot_ripple_decode(
    result.acausal_posterior,
    ripple_position,
    test_spikes,
    position,
    linear_position_order,
)