In [None]:
from foundational_ssm.utils import h5_to_dict 
import os
import multiprocessing as mp
import logging
import hydra
from dotenv import load_dotenv

# Third-party imports
import numpy as np
import wandb
from jax import random as jr
import equinox as eqx

# Foundational SSM imports
from omegaconf import OmegaConf

from foundational_ssm.utils.downstream_utils import (
    mse_loss_downstream,
    train_one_epoch,
    validate_one_epoch,
    create_model_and_state,
    create_dataloader,
    log_predictions_and_activations,
    get_rtt_datasets
)
from foundational_ssm.utils.training_utils import (
    create_optimizer_and_state
)
from foundational_ssm.utils.wandb_utils_jax import (
    save_checkpoint_wandb,
    add_alias_to_checkpoint
    )
from foundational_ssm.models import SSMDownstreamDecoder
from foundational_ssm.transform import smooth_spikes
import multiprocessing as mp
import joblib

def create_sliding_window_features(input_array, history_length=56):
    """
    Unroll a batch of time series into lagged feature vectors.

    Output timestep ``t`` holds the ``history_length`` consecutive input
    timesteps ``t, t+1, ..., t+history_length-1`` laid out back to back
    (oldest first, all features of one timestep contiguous).

    Parameters:
    -----------
    input_array : numpy.ndarray
        Array of shape (batch_size, total_timesteps, features)
    history_length : int, optional
        Number of consecutive timesteps packed into each output row, default is 56

    Returns:
    --------
    numpy.ndarray
        float64 array of shape
        (batch_size, total_timesteps - history_length, features * history_length)
    """
    n_batch, n_steps, n_feat = input_array.shape
    n_windows = n_steps - history_length
    windowed = np.zeros((n_batch, n_windows, n_feat * history_length))

    # Fill one lag at a time: lag k of every window is the input shifted by k
    # steps, and it occupies the k-th slab of `n_feat` columns in the output.
    for lag in range(history_length):
        first_col = lag * n_feat
        windowed[:, :, first_col:first_col + n_feat] = input_array[:, lag:lag + n_windows, :]

    return windowed

In [None]:
# Read the experiment config and build the RTT train / validation splits.
cfg = OmegaConf.load("../configs/rtt.yaml")
train_data, val_data, data = get_rtt_datasets(cfg.dataset, jr.PRNGKey(0))

# Number of consecutive timesteps stacked into every feature row.
history_length = 56

# Inputs: smooth the spikes, then unroll them into sliding windows.
# Targets: drop the first `history_length` steps so that target row t is the
# behaviour sample that immediately follows the window stored in input row t.
train_inputs = create_sliding_window_features(
    smooth_spikes(train_data['neural_input']), history_length
)
train_targets = train_data['behavior_input'][:, history_length:, :]
val_inputs = create_sliding_window_features(
    smooth_spikes(val_data['neural_input']), history_length
)
val_targets = val_data['behavior_input'][:, history_length:, :]

# Report the resulting array shapes.
for _label, _array in (
    ("train_inputs shape:", train_inputs),
    ("train_targets shape:", train_targets),
    ("val_inputs shape:", val_inputs),
    ("val_targets shape:", val_targets),
):
    print(_label, _array.shape)

In [None]:
from sklearn.linear_model import SGDRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Run switches
# ---------------------------------------------------------------------------
DO_CROSS_VALIDATION = False   # run the (slow) K-fold search over `alphas`
RETRAIN_FINAL_MODEL = False   # refit the final model instead of loading it from disk

# ---------------------------------------------------------------------------
# Hyper-parameters shared by the cross-validation search and the final refit.
# They are defined unconditionally so that RETRAIN_FINAL_MODEL=True also works
# when DO_CROSS_VALIDATION=False.
# ---------------------------------------------------------------------------
alphas = [1e-5, 1e-4]          # regularisation strengths to compare
n_epochs = 100                 # maximum passes over the training trials
batch_size = 64                # number of *trials* fed to each partial_fit call
cv_folds = 3
early_stopping_patience = 10   # stop once training MSE has failed to improve for more than this many epochs
n_trials = train_inputs.shape[0]

# NOTE(review): trial shuffling below uses NumPy's global RNG and SGDRegressor
# is created without `random_state`, so repeated runs are not bit-reproducible.


def _flatten_trials(trial_array, trial_indices=None):
    """Collapse (trials, timesteps, channels) into a 2-D (samples, channels) array.

    If `trial_indices` is given, only those trials are kept, in that order.
    """
    if trial_indices is not None:
        trial_array = trial_array[trial_indices]
    # np.asarray so scikit-learn always receives a plain NumPy array.
    return np.asarray(trial_array).reshape(-1, trial_array.shape[-1])


def _make_sgd_model(alpha):
    """Return a fresh multi-output SGD regressor with regularisation strength `alpha`."""
    # Training goes through partial_fit (one pass per call); `max_iter` is
    # documented to affect only .fit(), so the epoch loops below control the
    # number of passes.
    return MultiOutputRegressor(SGDRegressor(loss='squared_error', alpha=alpha,
                                             learning_rate='invscaling', max_iter=1000, tol=1e-3))


def _partial_fit_epoch(model, inputs, targets, trial_order, trials_per_batch, show_progress=False):
    """Run one pass over `trial_order`, calling partial_fit once per group of trials."""
    batch_starts = range(0, len(trial_order), trials_per_batch)
    if show_progress:
        batch_starts = tqdm(batch_starts)
    for start in batch_starts:
        # Slicing past the end is safe, so the last batch is simply shorter.
        batch_trials = trial_order[start:start + trials_per_batch]
        model.partial_fit(
            _flatten_trials(inputs, batch_trials),
            _flatten_trials(targets, batch_trials),
        )


if DO_CROSS_VALIDATION:
    kf = KFold(n_splits=cv_folds, shuffle=True, random_state=42)

    # Per-alpha lists of fold scores.
    alpha_scores = {alpha: [] for alpha in alphas}
    alpha_r2_scores = {alpha: [] for alpha in alphas}

    # Folds are split by trial so that windows from one trial never appear in
    # both the training and the held-out part of a fold.
    for train_idx, val_idx in tqdm(kf.split(range(n_trials)), total=cv_folds):
        train_trials = train_idx
        val_trials = val_idx

        print(f"Cross-validation fold with {len(train_trials)} training trials and {len(val_trials)} validation trials")

        # Held-out data is the same for every alpha in this fold.
        fold_val_X = _flatten_trials(train_inputs, val_trials)
        fold_val_y = _flatten_trials(train_targets, val_trials)

        for alpha in alphas:
            print(f"Training model with alpha={alpha}")
            model = _make_sgd_model(alpha)

            best_train_mse = float('inf')
            epochs_no_improve = 0
            for epoch in range(n_epochs):
                # In-place shuffle of this fold's trial order.
                np.random.shuffle(train_trials)
                _partial_fit_epoch(model, train_inputs, train_targets, train_trials, batch_size)

                train_pred = model.predict(_flatten_trials(train_inputs, train_trials))
                train_mse = mean_squared_error(_flatten_trials(train_targets, train_trials), train_pred)
                print(f"    Training MSE after epoch {epoch+1}: {train_mse:.4f}")

                # Early stopping monitors the *training* MSE (not a held-out set).
                if best_train_mse > train_mse:
                    epochs_no_improve = 0
                    best_train_mse = train_mse
                else:
                    epochs_no_improve += 1

                if epochs_no_improve > early_stopping_patience:
                    print(f"    Training stopped early after {epoch+1} epochs")
                    break

            # Evaluate on the held-out trials of this fold.
            val_pred = model.predict(fold_val_X)
            val_mse = mean_squared_error(fold_val_y, val_pred)
            # r2_score is not symmetric: ground truth first, predictions second.
            val_r2 = r2_score(fold_val_y, val_pred)
            print(f"alpha={alpha}| Validation MSE : {val_mse:.4f} | R2 Score: {val_r2:.4f}")
            alpha_scores[alpha].append(val_mse)
            alpha_r2_scores[alpha].append(val_r2)

    # Pick the alpha with the lowest mean validation MSE across folds.
    avg_scores = {alpha: np.mean(scores) for alpha, scores in alpha_scores.items()}
    best_alpha = min(avg_scores.items(), key=lambda x: x[1])[0]

    print("\nCross-validation results:")
    for alpha, score in sorted(avg_scores.items()):
        print(f"Alpha: {alpha}, Average MSE: {score:.4f}")
    print(f"\nBest alpha: {best_alpha}, MSE: {avg_scores[best_alpha]:.4f}")
else:
    # NOTE(review): hard-coded fallback used when the search is skipped —
    # confirm it matches the alpha selected by the last cross-validation run.
    best_alpha = 1e-5


if RETRAIN_FINAL_MODEL:
    print("\nTraining final model with best alpha...")
    final_model = _make_sgd_model(best_alpha)

    # Train on all trials, visiting them in a fresh random order every epoch.
    for epoch in range(n_epochs):
        all_trials = np.random.permutation(n_trials)
        _partial_fit_epoch(final_model, train_inputs, train_targets, all_trials, batch_size,
                           show_progress=True)

    # Final evaluation on the validation split.
    y_pred = final_model.predict(_flatten_trials(val_inputs))

    final_val_mse = mean_squared_error(_flatten_trials(val_targets), y_pred)
    final_val_r2 = r2_score(_flatten_trials(val_targets), y_pred)
    print(f"\nFinal model validation MSE: {final_val_mse:.4f} | R2 Score: {final_val_r2:.4f}")
else:
    # joblib files are pickle-based: only load a file you produced yourself.
    final_model = joblib.load("final_sgd_model.joblib")

0it [00:00, ?it/s]

Cross-validation fold with 252 training trials and 126 validation trials
Training model with alpha=1e-05
    Training MSE after epoch 1: 702.1558
    Training MSE after epoch 2: 656.4853
    Training MSE after epoch 3: 641.6966
    Training MSE after epoch 4: 632.8040
    Training MSE after epoch 5: 629.6564
    Training MSE after epoch 6: 625.1412
    Training MSE after epoch 7: 622.5171
    Training MSE after epoch 8: 622.6033
    Training MSE after epoch 9: 618.5331
    Training MSE after epoch 10: 613.5893
    Training MSE after epoch 11: 607.3859
    Training MSE after epoch 12: 606.9892
    Training MSE after epoch 13: 604.1966
    Training MSE after epoch 14: 604.2250
    Training MSE after epoch 15: 599.3320
    Training MSE after epoch 16: 599.5498
    Training MSE after epoch 17: 603.8016
    Training MSE after epoch 18: 597.8777
    Training MSE after epoch 19: 594.6839
    Training MSE after epoch 20: 595.8588
    Training MSE after epoch 21: 599.1164
    Training MSE after

1it [22:13, 1333.28s/it]

alpha=0.0001| Validation MSE : 740.0912 | R2 Score: 0.2066
Cross-validation fold with 252 training trials and 126 validation trials
Training model with alpha=1e-05
    Training MSE after epoch 1: 711.6710
    Training MSE after epoch 2: 675.7316
    Training MSE after epoch 3: 656.0654
    Training MSE after epoch 4: 647.5854
    Training MSE after epoch 5: 645.8840
    Training MSE after epoch 6: 634.1346
    Training MSE after epoch 7: 635.2355
    Training MSE after epoch 8: 634.9375
    Training MSE after epoch 9: 622.0853
    Training MSE after epoch 10: 626.2802
    Training MSE after epoch 11: 620.3481
    Training MSE after epoch 12: 623.9066
    Training MSE after epoch 13: 615.8639
    Training MSE after epoch 14: 623.1905
    Training MSE after epoch 15: 611.9682
    Training MSE after epoch 16: 622.1644
    Training MSE after epoch 17: 608.8086
    Training MSE after epoch 18: 609.0225
    Training MSE after epoch 19: 610.4157
    Training MSE after epoch 20: 609.1578
    T

2it [44:38, 1340.37s/it]

alpha=0.0001| Validation MSE : 748.9065 | R2 Score: 0.2395
Cross-validation fold with 252 training trials and 126 validation trials
Training model with alpha=1e-05
    Training MSE after epoch 1: 650.3189
    Training MSE after epoch 2: 618.9737
    Training MSE after epoch 3: 607.5296
    Training MSE after epoch 4: 597.3661
    Training MSE after epoch 5: 595.5118
    Training MSE after epoch 6: 591.8590
    Training MSE after epoch 7: 582.8819
    Training MSE after epoch 8: 580.0395
    Training MSE after epoch 9: 580.9488
    Training MSE after epoch 10: 583.3421
    Training MSE after epoch 11: 574.4964
    Training MSE after epoch 12: 573.1606
    Training MSE after epoch 13: 581.9182
    Training MSE after epoch 14: 569.5336
    Training MSE after epoch 15: 574.2349
    Training MSE after epoch 16: 567.5599
    Training MSE after epoch 17: 567.3766
    Training MSE after epoch 18: 564.8442
    Training MSE after epoch 19: 567.1135
    Training MSE after epoch 20: 565.9270
    T

3it [1:05:44, 1314.99s/it]


alpha=0.0001| Validation MSE : 838.6598 | R2 Score: 0.0797

Cross-validation results:
Alpha: 1e-05, Average MSE: 785.8174
Alpha: 0.0001, Average MSE: 775.8858

Best alpha: 0.0001, MSE: 775.8858

Training final model with best alpha...


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 1/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 2/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 3/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 4/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 5/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 6/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 7/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 8/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 9/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 10/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 11/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 12/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.01s/it]


Completed epoch 13/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 14/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 15/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 16/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 17/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 18/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 19/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 20/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 21/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 22/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 23/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 24/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 25/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 26/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 27/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 28/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.01s/it]


Completed epoch 29/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 30/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 31/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 32/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 33/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 34/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 35/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 36/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 37/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 38/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.01s/it]


Completed epoch 39/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 40/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 41/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 42/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 43/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 44/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 45/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 46/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 47/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 48/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 49/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 50/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 51/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 52/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 53/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 54/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 55/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 56/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 57/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 58/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 59/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 60/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 61/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 62/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 63/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 64/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 65/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 66/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 67/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 68/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 69/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 70/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 71/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 72/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 73/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 74/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 75/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 76/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 77/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 78/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 79/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 80/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 81/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 82/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 83/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 84/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.03s/it]


Completed epoch 85/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 86/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 87/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 88/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 89/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 90/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 91/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 92/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 93/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 94/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 95/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 96/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 97/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 98/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 99/100 for final model


100%|██████████| 6/6 [00:06<00:00,  1.02s/it]


Completed epoch 100/100 for final model

Final model validation MSE: 740.4038


In [None]:

# Flatten (trials, timesteps, dims) -> (samples, dims) before scoring.
val_targets_flat = val_targets.reshape(-1, val_targets.shape[-1])
y_pred = final_model.predict(val_inputs.reshape(-1, val_inputs.shape[-1]))

# Validation metrics of the final model (ground truth first, predictions second).
final_val_mse = mean_squared_error(val_targets_flat, y_pred)
final_val_r2 = r2_score(val_targets_flat, y_pred)
print(f"\nFinal model validation MSE: {final_val_mse:.4f} | R2 Score: {final_val_r2:.4f}")

# Write the model to disk, then read it back from that same file.
joblib.dump(final_model, "final_sgd_model.joblib")
print("Final model saved to final_sgd_model.joblib")
final_model = joblib.load("final_sgd_model.joblib")


Final model validation MSE: 740.4038 | R2 Score: 0.5513
Final model saved to final_sgd_model.joblib


In [None]:
# Bare last expression: show the final model's repr as this cell's output.
final_model