# Neural Decoding Analysis with LSTM Networks

## Overview

This notebook demonstrates how to perform neural decoding analysis using Long Short-Term Memory (LSTM) networks to predict behavioral variables from neural activity. The analysis includes:

1. **Data Preprocessing**: Load and prepare spike train data and behavioral variables
2. **Cross-Validation Setup**: Create training/testing splits built from temporal blocks
3. **LSTM Model Training**: Train recurrent neural networks for decoding
4. **Statistical Validation**: Compare model performance against null distributions
5. **Results Visualization**: Generate plots and performance metrics

## Key Concepts

- **Neural Decoding**: The process of inferring behavioral or cognitive states from neural activity patterns
- **LSTM Networks**: Recurrent neural networks capable of learning long-term temporal dependencies
- **Cross-Validation**: Statistical method to assess model generalizability
- **Null Distribution**: Baseline obtained by circularly shifting the neural data in time relative to behavior, used to test statistical significance

## Expected Outputs

- Model performance metrics (RMSE) for each behavioral variable
- Visualization of predicted vs. actual behavioral trajectories  
- Statistical significance tests comparing true vs. shuffled data performance

## Requirements and Setup

### Dependencies
- Python 3.7+
- PyTorch
- NumPy
- Pandas
- Matplotlib
- Seaborn
- SciPy

### Data Structure Requirements
Your data directory should contain preprocessed spike train data and behavioral measurements organized by mouse and session. The `core.py` module must contain:
- `preprocess()` function for data loading
- `LSTMDecoder` class
- `SequenceDataset` class  
- `train_LSTM()` function
- `cv_split()` function
- `plot_training()` function

### Important Notes
- **Update file paths** in the configuration section to match your local directory structure
- A GPU is optional but recommended for faster training
- This notebook writes many output files (plots and `.npz` results for every fold and shift), so make sure enough disk space is available

In [2]:
# =============================================================================
# IMPORTS AND SETUP
# =============================================================================

# Standard scientific computing libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch for neural network implementation
import torch
from torch.utils.data import DataLoader

# Custom modules for neural decoding (preprocessing, cross-validation, and LSTM model definitions)
from core import preprocess, LSTMDecoder, SequenceDataset, train_LSTM, cv_split, plot_training

# System and utility libraries
import time
import os

# Fix for OpenMP duplicate library issue on some systems
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Set plotting style for better visualization
plt.style.use('default')
sns.set_palette("husl")

# Configure data directories
# NOTE: Update these paths to match your local directory structure
dir = r"D:\clickbait-mmz"                    # Raw data directory
save_dir = r"C:\Users\smearlab\analysis_code\EncodingModels\outputs"  # Output directory

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")

Libraries imported successfully!
PyTorch version: 2.6.0+cu124
GPU available: False


---

# Setup and Initialization

## Library Imports and Environment Configuration

The cell above imports the required libraries and configures the computational environment.

# Configuration

## Experimental Parameters

The following parameters define the experimental session and analysis settings. These should be adjusted based on your specific dataset and research objectives.

In [None]:
# =============================================================================
# EXPERIMENTAL PARAMETERS
# =============================================================================

# Subject and session identification
mouse = '6002'        # Subject ID
session = '6'         # Session number

# Time window parameters for binning neural data
window_size = 2.0     # Time window size in seconds for spike binning
step_size = 2         # Step size in seconds (equal to window_size = non-overlapping windows)

# Sampling frequencies
fs = 30000            # Sampling frequency for spike data (Hz)
sfs = 1000            # Sampling frequency for behavioral/sniff data (Hz)

# Unit selection criteria
use_units = 'good'    # Options: 'all', 'good', or specific unit number

# Cross-validation parameters
k_CV = 3              # Number of cross-validation folds
n_blocks = 3          # Number of temporal blocks for cross-validation

# Null distribution parameters
shift = 2             # Null-distribution index (0 = true data; >0 = spike rates are circularly shifted by a random offset)

# Model configuration
model_input = 'neural'  # Input type: 'neural' (spikes->behavior) or 'behavioral' (behavior->spikes)

# Behavioral variables to decode
use_behaviors = [
    'position_x',     # X-coordinate position
    'position_y',     # Y-coordinate position  
    'velocity_x',     # X-component of velocity
    'velocity_y',     # Y-component of velocity
    'sns'             # Sniff rate
]

# Create output directory structure
save_dir = os.path.join(save_dir, mouse, session, f'shift_{shift}')
os.makedirs(save_dir, exist_ok=True)

print(f"Configuration complete for Mouse {mouse}, Session {session}")
print(f"Output directory: {save_dir}")
print(f"Behavioral variables: {use_behaviors}")
print(f"Cross-validation: {k_CV} folds, {n_blocks} blocks per fold")

# Data Preprocessing

## Loading and Preparing Neural Data

This section preprocesses the raw data to extract:
- Spike counts binned at specified temporal resolution
- Behavioral variables and their names
- Trial identifications for temporal organization
- Neuron names and metadata from manual curation

The `preprocess()` function handles the complex data loading and alignment procedures.
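
`preprocess()` lives in `core.py` and is not shown here, so the following is only a minimal sketch of what binning spike times into windows could look like. The function name `bin_spike_times` and the toy spike times are hypothetical; only `window_size` and `step_size` mirror the configuration above.

```python
import numpy as np

def bin_spike_times(spike_times, t_start, t_stop, window_size, step_size):
    """Count spikes in windows of `window_size` seconds, advanced by `step_size` seconds.

    Hypothetical helper for illustration; not the notebook's `preprocess()`.
    """
    spike_times = np.sort(np.asarray(spike_times, dtype=float))
    # Left edge of every window that fits entirely inside [t_start, t_stop]
    n_windows = int(np.floor((t_stop - t_start - window_size) / step_size)) + 1
    starts = t_start + step_size * np.arange(max(n_windows, 0))
    # searchsorted counts the spikes before each edge, so the difference
    # is the number of spikes in [start, start + window_size)
    left = np.searchsorted(spike_times, starts, side="left")
    right = np.searchsorted(spike_times, starts + window_size, side="left")
    return right - left

toy_spikes = [0.1, 0.5, 1.9, 2.0, 3.7, 5.2, 5.9]  # made-up spike times in seconds
counts_toy = bin_spike_times(toy_spikes, t_start=0.0, t_stop=6.0, window_size=2.0, step_size=2.0)
print(counts_toy)  # [3 2 2]
```

Choosing `step_size` smaller than `window_size` produces overlapping windows, which gives more time bins at the cost of correlated neighbouring samples.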

In [None]:
# =============================================================================
# DATA PREPROCESSING
# =============================================================================

# Preprocessing the data to get spike counts, behavioral variables and names, 
# trial IDs, neuron names, and neuron info from manual curation
# 
# Returns:
# - counts: Neural spike counts binned in time windows [time_bins x neurons]
# - variables: List of behavioral variables aligned to neural data
# - variable_names: Names of behavioral variables
# - trial_ids: Trial identification for temporal organization
# - neu_names: Names of all recorded neurons
# - neu_info: Metadata for each neuron (including brain region)

counts, variables, variable_names, trial_ids, neu_names, neu_info = preprocess(
    dir, save_dir, mouse, session, window_size, step_size, use_units
)

print(f"Data preprocessing complete:")
print(f"  - Neural data shape: {counts.shape} (time_bins x neurons)")
print(f"  - Behavioral variables: {len(variable_names)}")
print(f"  - Total neurons: {len(neu_names)}")

## Brain Region Selection

**Choose which brain region to analyze: 'OB' (olfactory bulb) or 'HC' (hippocampus).**

The cell below keeps only neurons recorded in the chosen region (set to `'HC'`; change the string in the filter to analyze `'OB'`). When `shift > 0`, it also applies a random circular shift to the spike rates along the time axis. The shift preserves each neuron's own temporal structure but misaligns neural activity and behavior, which yields a null distribution for statistical hypothesis testing.
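
To make the circular shift concrete, here is a small self-contained sketch on synthetic data. The arrays `rates_toy` and `behavior_toy`, the seed, and the noise level are assumptions made for illustration only. It shows that `np.roll` along the time axis keeps every neuron's values intact and only changes their alignment in time.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible

# Toy data: 200 time bins, 3 "neurons" whose rates track a slow behavioral signal
n_bins = 200
behavior_toy = np.sin(np.linspace(0, 4 * np.pi, n_bins))
rates_toy = behavior_toy[:, None] + 0.1 * rng.standard_normal((n_bins, 3))

# Draw the offset from 1..n_bins-1 so that a zero shift can never occur
offset = int(rng.integers(1, n_bins))
rates_null = np.roll(rates_toy, offset, axis=0)

# Rolling along axis 0 reorders rows only: each neuron keeps exactly the same values
same_values = np.array_equal(np.sort(rates_toy, axis=0), np.sort(rates_null, axis=0))

# Rolling back by the same offset recovers the original alignment
restored = np.array_equal(np.roll(rates_null, -offset, axis=0), rates_toy)
print(same_values, restored)  # True True
```

Because all neurons are rolled by the same offset, correlations between neurons are preserved as well; only the relationship to behavior is disturbed.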

In [None]:
# =============================================================================
# BRAIN REGION SELECTION AND TEMPORAL SHUFFLING
# =============================================================================

# Filter neurons based on brain region (HC = Hippocampus, OB = Olfactory Bulb)
use_units = [key for key, value in neu_info.items() if value['area'] == 'HC']
spike_rates = counts[:, np.isin(neu_names, use_units)]

# Apply a random circular shift for the null distribution if shift > 0
# This breaks the temporal alignment between neural activity and behavior
if shift > 0:
    # Draw from 1..n_bins-1: a roll of 0 would leave the data unshifted
    random_roll = np.random.randint(1, spike_rates.shape[0])
    print(f'Rolling spike rates by {random_roll} time bins for null distribution')
    spike_rates = np.roll(spike_rates, random_roll, axis=0)
    
print(f'Final spike rates shape (time bins × neurons): {spike_rates.shape}')
print(f'Using {len(use_units)} neurons from hippocampus (HC)')

rolling spike rates by 761 time bins
spike rates shape (# time bins, # neurons): (917, 29)


## Extracting Behavioral Variables

This section extracts and organizes the behavioral variables specified in the configuration. The variables are stacked into a single matrix for model training.

In [None]:
# =============================================================================
# BEHAVIORAL VARIABLE EXTRACTION
# =============================================================================

print(f'Available behavioral variables: {variable_names}')

# Extract each behavioral component named in use_behaviors, in the order listed there
behavior_components = []

for name in use_behaviors:
    component = np.array(variables[variable_names.index(name)])
    behavior_components.append(component)
    print(f"Added {name} with shape: {component.shape}")

# Stack all behavioral components into a single matrix
behavior = np.stack(behavior_components, axis=1)
print(f'Final behavior matrix shape (time bins × behavioral dimensions): {behavior.shape}')

variables: ['position_x', 'position_y', 'velocity_x', 'velocity_y', 'sns', 'latency', 'phase', 'speed', 'click']
behavior shape (# time bins, # dims): (917, 5)


## Model Input/Output Configuration

**Building list of arguments to pass to the model**

This is where the input $\mathbf{X}$ and output $y$ variables for the model are defined.

$$y = f(\mathbf{X})$$

where $f(\cdot)$ is a neural network learned from the data to minimize some cost function.

**For Neural Decoding:**
- Input $\mathbf{X}$: Neural spike rates (time bins × neurons)
- Output $y$: Behavioral variables (time bins × behavioral dimensions)

The model learns to predict behavioral states from neural activity patterns.

In [None]:
# =============================================================================
# CROSS-VALIDATION SETUP AND MODEL INPUT/OUTPUT CONFIGURATION
# =============================================================================

# Initialize list to store arguments for each cross-validation fold
arg_list = []

# Loop through each cross-validation fold to build datasets
for k in range(k_CV):
    # Split spike rates data using temporal cross-validation
    rates_train, rates_test, train_switch_ind, test_switch_ind = cv_split(
        spike_rates, k, k_CV, n_blocks
    )
    
    # Split behavioral data using the same temporal structure
    behavior_train, behavior_test, _, _ = cv_split(
        behavior, k, k_CV, n_blocks
    )

    # Create save directory for model outputs
    current_save_path = os.path.join(save_dir, "model fits")
    os.makedirs(current_save_path, exist_ok=True)

    # Configure input (X) and output (y) based on model type
    if model_input == 'neural':
        # Neural decoding: Neural activity -> Behavior
        y_train = behavior_train
        y_test = behavior_test
        X_train = rates_train
        X_test = rates_test
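    elif model_input == 'behavioral':
        # Behavioral encoding: Behavior -> Neural activity
        y_train = rates_train
        y_test = rates_test
        X_train = behavior_train
        X_test = behavior_test
    else:
        # Fail early instead of hitting a NameError on X_train further down
        raise ValueError(f"model_input must be 'neural' or 'behavioral', got {model_input!r}")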

    # Store all arguments needed for this fold
    arg_list.append((
        X_train, X_test, y_train, y_test, 
        train_switch_ind, test_switch_ind, 
        current_save_path, None, shift
    ))

print(f"Cross-validation setup complete: {k_CV} folds prepared")
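
`cv_split()` is defined in `core.py` and is not shown in this notebook. As a rough illustration of a blocked split with the same call pattern, the sketch below cuts the recording into `n_blocks` contiguous blocks, cuts each block into `k_cv` chunks, and holds out chunk `k` of every block. The function `blocked_split` and its exact splitting rule are assumptions; the real `cv_split()` may behave differently.

```python
import numpy as np

def blocked_split(data, k, k_cv, n_blocks):
    """Hypothetical blocked split: each of `n_blocks` contiguous blocks is cut into
    `k_cv` contiguous chunks, and chunk `k` of every block is held out for testing."""
    block_edges = np.linspace(0, len(data), n_blocks + 1, dtype=int)
    train_parts, test_parts = [], []
    for b_start, b_stop in zip(block_edges[:-1], block_edges[1:]):
        chunk_edges = np.linspace(b_start, b_stop, k_cv + 1, dtype=int)
        for i in range(k_cv):
            chunk = data[chunk_edges[i]:chunk_edges[i + 1]]
            (test_parts if i == k else train_parts).append(chunk)
    # Start/stop positions of each piece inside the concatenated arrays
    train_switch = np.cumsum([0] + [len(p) for p in train_parts]).tolist()
    test_switch = np.cumsum([0] + [len(p) for p in test_parts]).tolist()
    return np.concatenate(train_parts), np.concatenate(test_parts), train_switch, test_switch

toy = np.arange(18)  # 18 time bins labelled by their index
train, test, train_switch, test_switch = blocked_split(toy, k=1, k_cv=3, n_blocks=3)
print(test)         # [ 2  3  8  9 14 15]
print(test_switch)  # [0, 2, 4, 6]
```

Holding out a chunk from every block means each test set samples the beginning, middle, and end of the session, which helps when behavior drifts over time.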

# LSTM Model Functions

## Visualization and Training Functions

This section defines helper functions for:
1. **Visualization**: Plotting model predictions against true values
2. **Training**: Complete pipeline for training and evaluating LSTM models
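
The training function defined below wraps each fold's data in `SequenceDataset`, which lives in `core.py` and is not shown in this notebook. As a rough idea of what such a dataset might do, the following sketch builds fixed-length windows inside each temporal block and never across a block boundary. `make_sequences` and the toy arrays are hypothetical, and the real class may differ in its details.

```python
import numpy as np

def make_sequences(X, y, blocks, sequence_length, target_index=-1):
    """Hypothetical stand-in for `SequenceDataset`: slide a window of
    `sequence_length` bins over each (start, stop) block without crossing
    block boundaries, and pair it with the target at `target_index` in the window."""
    X_seq, y_seq = [], []
    for start, stop in blocks:
        for t in range(start, stop - sequence_length + 1):
            window = slice(t, t + sequence_length)
            X_seq.append(X[window])
            y_seq.append(y[window][target_index])
    return np.stack(X_seq), np.stack(y_seq)

X_toy = np.arange(20, dtype=float).reshape(10, 2)  # 10 time bins, 2 "neurons"
y_toy = np.arange(10, dtype=float).reshape(10, 1)  # 10 time bins, 1 behavioral dimension
switch_ind = [0, 5, 10]                            # two blocks of 5 bins each
blocks = [(switch_ind[i], switch_ind[i + 1]) for i in range(len(switch_ind) - 1)]

X_seq, y_seq = make_sequences(X_toy, y_toy, blocks, sequence_length=3)
print(X_seq.shape, y_seq.shape)  # (6, 3, 2) (6, 1)
print(y_seq.ravel())             # [2. 3. 4. 7. 8. 9.]
```

In this sketch each block of 5 bins yields 3 windows, so the number of predictions is smaller than the number of time bins; plots of predictions against targets need to account for that offset.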

In [None]:
def plot_preds(targets, predictions, test_switch_ind, sequence_length, save_path, k, shift):
    """
    Plot model predictions against true values for visualization.
    
    Parameters:
    -----------
    targets : numpy.ndarray
        True behavioral values (time_steps x behavioral_dimensions)
    predictions : numpy.ndarray  
        Model predictions (time_steps x behavioral_dimensions)
    test_switch_ind : list
        Indices where test data switches between blocks/trials
    sequence_length : int
        Length of sequences used in LSTM training
    save_path : str
        Directory path to save the plot
    k : int
        Cross-validation fold number
    shift : int
        Temporal shift parameter for null distribution
    """
    # Adjust test switch indices to account for sequence length
    adjusted_test_switch_ind = [ind - sequence_length * block_idx for block_idx, ind in enumerate(test_switch_ind)]
    
    # Get number of behavioral dimensions
    behavior_dim = targets.shape[1]
    
    # Create subplots for each behavioral dimension
    _, ax = plt.subplots(behavior_dim, 1, figsize=(20, 10), sharex=True)
    if behavior_dim == 1:
        ax = [ax]
    
    # Plot true vs predicted for each dimension
    for i in range(behavior_dim):
        ax[i].plot(targets[:, i], label='True', color='crimson')
        ax[i].plot(predictions[:, i], label='Predicted')
        
        # Add vertical lines at test block boundaries
        for ind in adjusted_test_switch_ind:
            ax[i].axvline(ind, color='grey', linestyle='--', alpha=0.5)

    # Remove y-axis ticks if too many dimensions
    if behavior_dim > 4:
        for a in ax:
            a.set_yticks([])
    
    plt.xlabel('Time')
    ax[0].legend(loc='upper right')
    sns.despine()
    
    # Save the plot
    plt.savefig(os.path.join(save_path, f'lstm_predictions_k_{k}_shift_{shift}.png'), dpi=300)
    plt.close()

In [None]:
def process_fold(X_train, X_test, y_train, y_test, train_switch_ind, test_switch_ind, current_save_path,
        device=None, shift=0, hidden_dim=8, num_layers=2, dropout=0.1, sequence_length=3, target_index=-1, 
        batch_size=64, lr=0.01, num_epochs=300, patience=10, min_delta=0.01, factor=0.5, plot_predictions=True):
    """
    Train and evaluate LSTM model for a single cross-validation fold.
    
    Parameters:
    -----------
    X_train, X_test : numpy.ndarray
        Training and testing input data (neural activity or behavior)
    y_train, y_test : numpy.ndarray
        Training and testing target data (behavior or neural activity)
    train_switch_ind, test_switch_ind : list
        Indices marking boundaries between temporal blocks
    current_save_path : str
        Directory to save model outputs and plots
    device : str, optional
        GPU device to use ('0', '1', etc.), None for auto-detection
    shift : int
        Temporal shift for null distribution (0 = no shift)
    hidden_dim : int
        Number of hidden units in LSTM layers
    num_layers : int
        Number of LSTM layers
    dropout : float
        Dropout probability for regularization
    sequence_length : int
        Length of input sequences for LSTM
    target_index : int
        Which time step to predict (-1 = last time step)
    batch_size : int
        Training batch size
    lr : float
        Learning rate for optimizer
    num_epochs : int
        Maximum number of training epochs
    patience : int
        Early stopping patience (epochs without improvement)
    min_delta : float
        Minimum improvement threshold for early stopping
    factor : float
        Learning rate reduction factor
    plot_predictions : bool
        Whether to generate prediction plots

    Notes:
    ------
    The fold index `k` used in log messages and output file names is not a
    parameter of this function. It is read from the notebook's global scope,
    so call this function inside a `for k in range(k_CV)` loop.
    
    Returns:
    --------
    rmse : numpy.ndarray
        Root mean squared error for each output dimension
    targets : numpy.ndarray
        True target values from test set
    predictions : numpy.ndarray
        Model predictions on test set
    """
    
    # =============================================================================
    # DEVICE SETUP AND MODEL INITIALIZATION
    # =============================================================================
    
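    # Set the device for computation (GPU if available, otherwise CPU).
    # NOTE: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
    # initialized in this process; changing it afterwards has no effect.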
    if device:
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        device = torch.device('cuda:0')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create the LSTM model with specified architecture
    lstm_model = LSTMDecoder(
        input_dim=X_train.shape[1], 
        hidden_dim=hidden_dim, 
        output_dim=y_train.shape[1], 
        num_layers=num_layers, 
        dropout=dropout
    ).to(device)

    # =============================================================================
    # TRAINING DATA PREPARATION
    # =============================================================================
    
    # Prepare the training data for LSTM with proper temporal structure
    blocks = [(train_switch_ind[i], train_switch_ind[i + 1]) for i in range(len(train_switch_ind) - 1)]
    train_dataset = SequenceDataset(X_train, y_train, blocks, sequence_length, target_index)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=min(batch_size, len(train_dataset)), 
        shuffle=False, 
        num_workers=0, 
        pin_memory=(device.type == 'cuda'),  # pinned memory only helps host-to-GPU copies
        prefetch_factor=None
    )

    # =============================================================================
    # MODEL TRAINING
    # =============================================================================
    
    # Train the LSTM model
    start_train = time.time()
    trained_lstm_model, lstm_history = train_LSTM(
        lstm_model, train_loader, device, 
        lr=lr, epochs=num_epochs, patience=patience, 
        min_delta=min_delta, factor=factor, verbose=False
    )
    print(f"\nTraining time: {time.time() - start_train:.2f}s fold={k} shift={shift}", flush=True)
    
    # Free up memory
    del lstm_model, train_dataset, train_loader

    # Plot training history if requested
    if plot_predictions:
        plot_training(lstm_history, current_save_path, shift, k)

    # =============================================================================
    # TEST DATA PREPARATION AND EVALUATION
    # =============================================================================
    
    # Prepare the test data for LSTM evaluation
    test_blocks = [(test_switch_ind[i], test_switch_ind[i + 1]) for i in range(len(test_switch_ind) - 1)]
    test_dataset = SequenceDataset(X_test, y_test, test_blocks, sequence_length, target_index)
    test_loader = DataLoader(
        test_dataset, 
        batch_size=min(batch_size, len(test_dataset)), 
        num_workers=0, 
        pin_memory=(device.type == 'cuda')  # pinned memory only helps host-to-GPU copies
    )

    # Generate predictions on the test set
    trained_lstm_model.eval()
    predictions = []
    targets = []
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch = X_batch.to(device)
            preds = trained_lstm_model(X_batch)
            predictions.append(preds.cpu().numpy())
            targets.append(y_batch.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0)
    targets = np.concatenate(targets, axis=0)

    # Clean up memory
    del trained_lstm_model, test_dataset, test_loader

    # Generate visualization plots if requested
    if plot_predictions:
        plot_preds(targets, predictions, test_switch_ind, sequence_length, current_save_path, k, shift)

    # Calculate RMSE for each output dimension
    diffs = predictions - targets
    rmse = np.sqrt(np.mean(diffs ** 2, axis=0))

    # Final cleanup
    torch.cuda.empty_cache()
    print(f"Total time {time.time() - start_train:.2f}s for fold={k} shift={shift}\n\n\n", flush=True)

    return rmse, targets, predictions
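
`process_fold` forwards `patience` and `min_delta` to `train_LSTM()`, whose implementation is in `core.py`. A common way to use these two parameters is sketched below in plain Python. `early_stopping_epoch` and the toy loss values are hypothetical, and the actual training loop may apply the rule differently (for example on a validation loss rather than the training loss).

```python
def early_stopping_epoch(losses, patience=10, min_delta=0.01):
    """Return the index of the epoch at which training would stop, or None.

    Hypothetical rule: an epoch counts as an improvement only if the loss drops
    by more than `min_delta` below the best loss so far; training stops after
    `patience` consecutive epochs without such an improvement.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None

toy_losses = [1.0, 0.8, 0.7, 0.695, 0.694, 0.693, 0.5]  # made-up loss curve
stop_epoch = early_stopping_epoch(toy_losses, patience=3, min_delta=0.01)
print(stop_epoch)  # 5
```

In the toy curve the loss keeps falling, but by less than `min_delta`, so three epochs in a row count as "no improvement" and training stops before the later drop to 0.5. A larger `patience` trades training time for robustness against such plateaus.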

# Model Training and Evaluation

## Running the model across all folds and saving results

This section executes the LSTM training pipeline across all cross-validation folds and saves the results.
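
Each fold is saved as a compressed `.npz` archive containing `rmse`, `targets`, and `predictions`. The sketch below uses made-up arrays and a temporary directory (both assumptions for illustration) to show the round trip: saving with `np.savez_compressed`, loading with `np.load`, and checking that the stored RMSE matches the value recomputed from the stored predictions and targets.

```python
import os
import tempfile
import numpy as np

# Toy stand-ins for one fold's outputs (values are made up for illustration)
targets_toy = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
predictions_toy = np.array([[1.0, 1.0], [3.0, 6.0], [8.0, 6.0]])
rmse_toy = np.sqrt(np.mean((predictions_toy - targets_toy) ** 2, axis=0))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "results_shift0_fold0.npz")
    np.savez_compressed(path, rmse=rmse_toy, targets=targets_toy, predictions=predictions_toy)

    # np.load returns a lazy NpzFile; the context manager closes the file handle
    with np.load(path) as saved:
        loaded_rmse = saved["rmse"]
        recomputed = np.sqrt(np.mean((saved["predictions"] - saved["targets"]) ** 2, axis=0))

print(np.allclose(loaded_rmse, recomputed))  # True
```

Storing targets and predictions alongside the RMSE makes it possible to compute other metrics later without retraining.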

In [None]:
# =============================================================================
# MODEL TRAINING AND RESULT SAVING
# =============================================================================

# Initialize results storage
results = []
results_save_dir = os.path.join(save_dir, "results")
os.makedirs(results_save_dir, exist_ok=True)

# Process each cross-validation fold
for k in range(k_CV):
    print(f"Processing fold {k + 1}/{k_CV}...")
    
    # Train LSTM model and get results
    rmse, targets, predictions = process_fold(*arg_list[k])

    # Define file paths for saving results
    results_file = f'results_shift{shift}_fold{k}.npz'
    results_path = os.path.join(results_save_dir, results_file)
    
    # Save results as compressed .npz file
    # Contains: RMSE values, true targets, and model predictions
    np.savez_compressed(
        results_path, 
        rmse=rmse, 
        targets=targets, 
        predictions=predictions
    )

    # Store metadata in results table
    results.append({
        'mouse': mouse,
        'session': session,
        'shift': shift,
        'fold': k,
        'results_file': results_file,
    })

# Convert to DataFrame and save as CSV for easy loading
df = pd.DataFrame(results)
df.to_csv(os.path.join(results_save_dir, 'results.csv'), index=False)

print(f"\nAll {k_CV} folds complete!")
print(f"Results saved to: {results_save_dir}")

Processing fold 1/3...
Using device: cuda

Training time: 2.46s fold=0 shift=2
Total time 3.69s for fold=0 shift=2



Processing fold 2/3...
Using device: cuda

Training time: 3.38s fold=1 shift=2
Total time 5.19s for fold=1 shift=2



Processing fold 3/3...
Using device: cuda

Training time: 2.99s fold=2 shift=2
Total time 4.11s for fold=2 shift=2





# Statistical Analysis and Model Comparison

## Comparing True vs. Null Model Performance

This section performs statistical tests to determine if the model's performance on real data is significantly better than chance. We compare RMSE values from the true data (shift=0) against the null distribution (shift>0).

**Note:** RMSE is a list with `n_folds` elements, where each element contains RMSE values for each behavioral dimension. The index corresponds to the target dimension being decoded.
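
To give a feel for the test used in the next cell, the following sketch runs `scipy.stats.ranksums` on made-up RMSE values with the same sample sizes as this session (3 folds of true data against 2 null shifts × 3 folds). The numbers are illustrative assumptions, and the one-sided call uses the `alternative` argument, which requires SciPy 1.7 or newer.

```python
import numpy as np
from scipy.stats import ranksums

# Made-up RMSE values: 3 folds of true data vs. 2 null shifts x 3 folds
true_rmse_toy = np.array([1.0, 2.0, 3.0])
null_rmse_toy = np.array([4.0, 5.0, 6.0, 7.0, 8.0, 9.0])

# Two-sided test, the default used by ranksums(x, y)
stat, p_two_sided = ranksums(true_rmse_toy, null_rmse_toy)

# One-sided alternative: is RMSE on true data lower than under the null?
_, p_less = ranksums(true_rmse_toy, null_rmse_toy, alternative="less")

print(f"statistic = {stat:.4f}, two-sided p = {p_two_sided:.4f}, one-sided p = {p_less:.4f}")
# statistic = -2.3238, two-sided p = 0.0201, one-sided p = 0.0101
```

Here every true value lies below every null value, which is the most extreme outcome possible for 3 versus 6 samples. With this design the two-sided p-value therefore cannot fall below roughly 0.02, so adding more null shifts or folds is the only way to resolve smaller p-values.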

In [None]:
# =============================================================================
# STATISTICAL ANALYSIS: TRUE vs. NULL PERFORMANCE
# =============================================================================

from scipy.stats import ranksums

# Parameters for statistical analysis
n_dims = 5  # Number of behavioral dimensions to analyze

# Path to session results directory (adjust as needed)
session_dir = r"C:\Users\smearlab\analysis_code\EncodingModels\outputs\6002\6"

# Get list of all shift conditions (shift_0, shift_1, etc.), ignoring any other files or folders
shifts = [d for d in os.listdir(session_dir) if d.startswith('shift_')]
n_shifts = len(shifts)

print(f'Number of shifts: {n_shifts - 1}')  # -1 because shift_0 is true data

# Compare performance for each behavioral dimension
for dim in range(n_dims):
    true_rmse = []   # RMSE values for true data (shift=0)
    null_rmse = []   # RMSE values for null distribution (shift>0)
    
    # Loop through all shift conditions
    # (the loop variable is shift_idx so the global `shift` setting is not overwritten)
    for shift_idx in range(n_shifts):
        shift_dir = os.path.join(session_dir, f'shift_{shift_idx}', 'results')
        
        # Count number of cross-validation folds
        n_folds = len([f for f in os.listdir(shift_dir) 
                      if f.startswith('results') and f.endswith('.npz')])
        
        # Load RMSE for each fold
        for fold in range(n_folds):
            results_file = os.path.join(shift_dir, f'results_shift{shift_idx}_fold{fold}.npz')
            rmse = np.load(results_file)['rmse'][dim]

            # Separate true vs null distributions
            if shift_idx == 0:
                true_rmse.append(rmse)  # True data
            else:
                null_rmse.append(rmse)  # Null distribution

    # Perform statistical test using Wilcoxon rank-sum test
    # This is a non-parametric test for comparing two independent samples
    stat, p = ranksums(true_rmse, null_rmse)
    
    # Unstandardized effect size: difference in mean RMSE (negative = true data decoded better)
    mean_true = np.mean(true_rmse)
    mean_null = np.mean(null_rmse)
    effect_size = mean_true - mean_null
    
    print(f'Dimension {dim}: true - null RMSE = {effect_size:.4f}, '
          f'p-value = {p:.4f}, statistic = {stat:.4f}')

Number of shifts: 2
dim 0: true - null RMSE = -14.6328, p-value = 0.0201, statistic = -2.3238
dim 1: true - null RMSE = -6.1701, p-value = 0.1967, statistic = -1.2910
dim 2: true - null RMSE = -0.0024, p-value = 0.7963, statistic = -0.2582
dim 3: true - null RMSE = 0.0127, p-value = 0.4386, statistic = 0.7746
dim 4: true - null RMSE = 0.5347, p-value = 0.1213, statistic = 1.5492


# Summary and Conclusions

## What This Analysis Tells Us

This notebook implements a complete pipeline for neural decoding using LSTM networks. The key insights from this analysis:

### Model Performance
- **RMSE values** quantify how well the model predicts behavioral variables from neural activity
- **Lower RMSE** indicates better prediction accuracy
- **Cross-validation** estimates how well the model generalizes beyond the training data

### Statistical Significance
- **Null distribution** (circularly shifted spike rates) provides a baseline for comparison
- **Wilcoxon rank-sum test** (two-sided, as called here) tests whether RMSE on true data differs from RMSE under the null
- **Effect size** is reported as the difference in mean RMSE (true minus null); negative values mean the true data are decoded better

### Behavioral Variables Decoded
1. **Position (X, Y)**: Spatial location of the animal
2. **Velocity (X, Y)**: Movement speed and direction
3. **Sniff Rate**: Olfactory sampling behavior

## Interpretation Guidelines

### Strong Decoding Performance
- **Low RMSE** and **significant p-values** indicate the neural population contains information about the behavior
- **Large effect sizes** suggest the relationship is not only significant but also meaningful

### Poor Decoding Performance
- **High RMSE** or **non-significant p-values** may indicate:
  - Neural population doesn't encode this behavior
  - Temporal resolution is inappropriate
  - Model architecture needs adjustment
  - Insufficient data for reliable estimation

## Next Steps

1. **Parameter tuning**: Adjust LSTM architecture, learning rate, or sequence length
2. **Feature engineering**: Consider different spike binning windows or normalization
3. **Model comparison**: Test against linear decoders or other architectures
4. **Population analysis**: Examine individual neuron contributions
5. **Temporal dynamics**: Analyze how decoding accuracy changes over time
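
For the model-comparison step, a linear decoder is a natural first baseline. The sketch below fits a closed-form ridge regression on synthetic data and reports per-dimension RMSE in the same way as the LSTM pipeline. The helper `fit_ridge`, the penalty `alpha`, and the toy data are all assumptions for illustration, not part of `core.py`.

```python
import numpy as np

def fit_ridge(X, Y, alpha=1.0):
    """Closed-form ridge regression with an unpenalized intercept.

    Returns (weights, intercept) such that Y_hat = X @ weights + intercept.
    """
    x_mean, y_mean = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - x_mean, Y - y_mean
    gram = Xc.T @ Xc + alpha * np.eye(X.shape[1])
    weights = np.linalg.solve(gram, Xc.T @ Yc)
    return weights, y_mean - x_mean @ weights

rng = np.random.default_rng(0)
X_toy = rng.poisson(3.0, size=(300, 10)).astype(float)        # toy "spike counts"
W_true = rng.standard_normal((10, 2))
Y_toy = X_toy @ W_true + 0.1 * rng.standard_normal((300, 2))  # toy "behavior"

# Contiguous split (first 200 bins train, last 100 test) to respect temporal order
weights, intercept = fit_ridge(X_toy[:200], Y_toy[:200], alpha=1.0)
Y_pred = X_toy[200:] @ weights + intercept
rmse_linear = np.sqrt(np.mean((Y_pred - Y_toy[200:]) ** 2, axis=0))
print(rmse_linear)
```

If the LSTM does not clearly beat such a baseline on the same folds, the extra model capacity is probably not buying much for that behavioral variable.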