# DFINE Tutorial
## Overview

DFINE, which stands for **Dynamical Flexible Inference for Nonlinear Embeddings**, is a neural network model developed to enable flexible inference: inference that can run causally, non-causally, or even when some neural observations are missing. To enable flexible inference, a model must support all of the following operations simultaneously, without retraining a new model or changing the inference structure:

1\) Causal inference (filtering) <br>
2\) Non-causal inference (smoothing) <br>
3\) Inference in the presence of missing observations, which can occur in wireless neural interfaces

DFINE supports all three of these operations with a single trained model. Its inference is also recursive and thus computationally efficient. Flexible inference is essential for developing neurotechnology such as brain-machine interfaces (BMIs).

### Model Architecture
To achieve flexible inference, DFINE separates the model into two sets of jointly trained latent factors, manifold and dynamic, such that the nonlinearity is captured by the manifold factors while the dynamics are modeled in a tractable linear form on this nonlinear manifold. Because this structure lets DFINE compute the future-step-ahead neural prediction error efficiently and recursively, DFINE can also use this error as its training loss.

Specifically, we define the two sets of latent factors as follows: 1) Manifold latent factors ${a}_t \in \mathbb{R}^{n_a \times 1}$ and 2) Dynamic latent factors ${x}_t \in \mathbb{R}^{n_x \times 1}$.

First, the dynamic latent factors evolve in time according to a linear Gaussian model: $\begin{equation}{x}_{t+1} = A{x}_t + {w}_t\tag{1}\end{equation}$ where $A \in \mathbb{R}^{n_x \times n_x}$ is the state transition matrix and ${w}_t \in \mathbb{R}^{n_x \times 1}$ is zero-mean Gaussian noise with covariance matrix $W \in \mathbb{R}^{n_x \times n_x}$. The manifold latent factors ${a}_t$ are related to the dynamic latent factors ${x}_t$ as: $\begin{equation}{a}_t = C{x}_t + {r}_t\tag{2}\end{equation}$ where $C \in \mathbb{R}^{n_a \times n_x}$ is the emission matrix and ${r}_t \in \mathbb{R}^{n_a \times 1}$ is white Gaussian noise with covariance matrix $R \in \mathbb{R}^{n_a \times n_a}$. Equations (1) and (2) form a linear dynamical model (LDM) with learnable parameters $\psi = \{ A, C, W, R, {\mu}_0, \Lambda_0 \}$, where ${\mu}_0$ and $\Lambda_0$ are the initial estimate and covariance of the dynamic latent factors, respectively.
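To make equations (1) and (2) concrete, here is a minimal NumPy sketch that samples from the LDM. The helper `simulate_ldm` and the example parameter values are hypothetical and for illustration only; in DFINE, $\psi$ is learned rather than chosen by hand.

```python
import numpy as np

def simulate_ldm(A, C, W, R, mu0, Lambda0, num_steps, rng=None):
    """Hypothetical helper: sample x_{1:T} and a_{1:T} from equations (1) and (2).

    Returns x with shape (num_steps, n_x) and a with shape (num_steps, n_a).
    """
    rng = np.random.default_rng() if rng is None else rng
    n_x, n_a = A.shape[0], C.shape[0]
    x = np.zeros((num_steps, n_x))
    a = np.zeros((num_steps, n_a))
    # Initial dynamic latent factor drawn from N(mu0, Lambda0)
    x[0] = rng.multivariate_normal(mu0, Lambda0)
    for t in range(num_steps):
        if t > 0:
            # Equation (1): x_{t+1} = A x_t + w_t, with w_t ~ N(0, W)
            x[t] = A @ x[t - 1] + rng.multivariate_normal(np.zeros(n_x), W)
        # Equation (2): a_t = C x_t + r_t, with r_t ~ N(0, R)
        a[t] = C @ x[t] + rng.multivariate_normal(np.zeros(n_a), R)
    return x, a

# Example: a slowly decaying 2D rotation observed through a 3D manifold layer
theta = 0.1
A = 0.98 * np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
C = np.random.default_rng(0).standard_normal((3, 2))
x, a = simulate_ldm(A, C, W=0.01 * np.eye(2), R=0.05 * np.eye(3),
                    mu0=np.ones(2), Lambda0=0.1 * np.eye(2), num_steps=100,
                    rng=np.random.default_rng(1))
print(x.shape, a.shape)  # (100, 2) (100, 3)
```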

Second, to model nonlinear mappings, we use multilayer perceptron (MLP) autoencoders to learn the mapping between neural observations ${y}_t$ and manifold latent factors ${a}_t$. We model the decoder as a nonlinear mapping $f_\theta(\cdot)$ from manifold latent factors to neural observations: $\begin{equation}{y}_t = f_\theta({a}_t) + {v}_t\tag{3}\end{equation}$ where $\theta$ denotes the decoder parameters and ${v}_t \in \mathbb{R}^{n_y \times 1}$ is white Gaussian noise with covariance $V \in \mathbb{R}^{n_y \times n_y}$. Equations (1)-(3) together form the generative model.

For inference, we also need the mapping from ${y}_t$ to ${a}_t$, which we characterize as: $\begin{equation}{a}_t = f_\phi ({y}_t)\tag{4}\end{equation}$ where $f_\phi(\cdot)$ represents the encoder of the autoencoder and is parameterized by another MLP. The parameters of all the equations above are trained together end-to-end, rather than separately. Further, the middle manifold layer in equation (2) explicitly incorporates a stochastic noise variable ${r}_t$ whose covariance is learned during training, which allows the nonlinearity with respect to the dynamic latent factors to be stochastic in DFINE. To make inference robust to noise and stochasticity, DFINE learns all the noise distribution parameters during training and explicitly accounts for them at inference.
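The sketch below illustrates the shapes involved in equations (3) and (4) using randomly initialized NumPy MLPs. It assumes one hidden layer of 32 tanh units and a linear output layer (hypothetical choices; in DFINE the architecture comes from config.model.hidden_layer_list), and uses $n_y = 57$ and $n_a = 15$ to mirror the dataset and grid used later in this notebook. The helpers `init_mlp` and `mlp_forward` are illustrative, not part of the DFINE codebase, and the weights are untrained.

```python
import numpy as np

def init_mlp(layer_sizes, rng):
    """Hypothetical helper: random (weight, bias) pairs for an MLP, e.g. [n_y, 32, n_a]."""
    return [(rng.standard_normal((n_out, n_in)) / np.sqrt(n_in), np.zeros(n_out))
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def mlp_forward(params, z):
    """Apply tanh hidden layers and a linear output layer to z of shape (..., n_in)."""
    for i, (weight, bias) in enumerate(params):
        z = z @ weight.T + bias
        if i < len(params) - 1:
            z = np.tanh(z)
    return z

rng = np.random.default_rng(0)
n_y, n_a, hidden = 57, 15, 32
encoder = init_mlp([n_y, hidden, n_a], rng)  # f_phi: y_t -> a_t, equation (4)
decoder = init_mlp([n_a, hidden, n_y], rng)  # f_theta: a_t -> y_t, equation (3) without v_t

y = rng.standard_normal((100, n_y))          # 100 time steps of z-scored-like observations
a_hat = mlp_forward(encoder, y)              # static estimate a_hat_t = f_phi(y_t)
y_from_a = mlp_forward(decoder, a_hat)       # f_theta(a_hat_t)
```

In DFINE both networks are trained jointly with the LDM parameters $\psi$, so this only shows how data flows between $y_t$ and $a_t$.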


### The Inference Problem
Using the equations above, we can infer both the manifold and dynamic latent factors from neural observations ${y}_{1:T}$, where $T$ is the total number of time steps for the observations. We use subscript $t|k$ to denote the inferred latent factors at time $t$ given observations up to time $k$, ${y}_{1:k}$. Thus, $t|t$ denotes filtering (causal) inference given ${y}_{1:t}$, $t+k|t$ denotes the $k$-step-ahead prediction given $y_{1:t}$, and $t|T$ denotes smoothing (non-causal) inference given ${y}_{1:T}$.

The inference method is shown in Figure 1b of the paper and proceeds as follows. We first obtain a static initial estimate of ${a}_t$ directly from ${y}_t$ as ${\hat{a}}_t = f_\phi({y}_t)$ using equation (4); these estimates ${\hat{a}}_t$ serve as the noisy observations of the dynamical model. Having obtained ${\hat{a}}_t$, we use the dynamical part of the model in equations (1) and (2) to infer ${x}_{t|t}$ with Kalman filtering from ${\hat{a}}_{1:t}$, and to infer ${x}_{t|T}$ with Kalman smoothing from ${\hat{a}}_{1:T}$. We then infer the manifold latent factors as ${a}_{t|t} = C{x}_{t|t}$ and ${a}_{t|T} = C{x}_{t|T}$ based on equation (2). Similarly, we obtain the filtered neural activity ${y}_{t|t}$ and smoothed neural activity ${y}_{t|T}$ from equation (3) as ${y}_{t|t} = f_{\theta}({a}_{t|t})$ and ${y}_{t|T} = f_{\theta}({a}_{t|T})$, respectively.
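The filtering and smoothing steps above can be sketched with a textbook Kalman filter and Rauch-Tung-Striebel (RTS) smoother operating on $\hat{a}_{1:T}$. This is a NumPy illustration under stated assumptions, not the DFINE implementation: the function name is hypothetical, and skipping the update step when a row of `a_hat` is NaN is one standard way to handle missing observations, assumed here for illustration.

```python
import numpy as np

def kalman_filter_smoother(a_hat, A, C, W, R, mu0, Lambda0):
    """Hypothetical helper: Kalman filter + RTS smoother for equations (1)-(2).

    a_hat: (T, n_a) static encoder outputs used as noisy observations; rows that
    contain NaN are treated as missing (assumption: the update step is skipped).
    Returns x_filt (x_{t|t}) and x_smooth (x_{t|T}), each of shape (T, n_x).
    """
    T, n_x = a_hat.shape[0], A.shape[0]
    x_pred, P_pred = np.zeros((T, n_x)), np.zeros((T, n_x, n_x))
    x_filt, P_filt = np.zeros((T, n_x)), np.zeros((T, n_x, n_x))
    for t in range(T):
        # Prior x_{t|t-1}: the initial estimate at t = 0, else propagate with equation (1)
        if t == 0:
            x_pred[t], P_pred[t] = mu0, Lambda0
        else:
            x_pred[t] = A @ x_filt[t - 1]
            P_pred[t] = A @ P_filt[t - 1] @ A.T + W
        if np.any(np.isnan(a_hat[t])):
            # Missing observation: keep the prediction, no update step
            x_filt[t], P_filt[t] = x_pred[t], P_pred[t]
            continue
        # Update with the noisy manifold observation a_hat_t through equation (2)
        S = C @ P_pred[t] @ C.T + R
        K = np.linalg.solve(S, C @ P_pred[t]).T  # Kalman gain P C^T S^{-1}
        x_filt[t] = x_pred[t] + K @ (a_hat[t] - C @ x_pred[t])
        P_filt[t] = P_pred[t] - K @ C @ P_pred[t]
    # Backward RTS pass for x_{t|T}; the last time step is already smoothed
    x_smooth, P_smooth = x_filt.copy(), P_filt.copy()
    for t in range(T - 2, -1, -1):
        J = np.linalg.solve(P_pred[t + 1], A @ P_filt[t]).T  # P_{t|t} A^T P_{t+1|t}^{-1}
        x_smooth[t] = x_filt[t] + J @ (x_smooth[t + 1] - x_pred[t + 1])
        P_smooth[t] = P_filt[t] + J @ (P_smooth[t + 1] - P_pred[t + 1]) @ J.T
    return x_filt, x_smooth

# Usage: noisy a_hat generated from a known latent trajectory
rng = np.random.default_rng(0)
A, C = 0.95 * np.eye(2), rng.standard_normal((3, 2))
W, R = 0.01 * np.eye(2), 0.5 * np.eye(3)
x_true = np.zeros((200, 2))
x_true[0] = [1.0, -1.0]
for t in range(1, 200):
    x_true[t] = A @ x_true[t - 1] + rng.multivariate_normal(np.zeros(2), W)
a_hat = x_true @ C.T + rng.multivariate_normal(np.zeros(3), R, size=200)
x_filt, x_smooth = kalman_filter_smoother(a_hat, A, C, W, R, np.zeros(2), np.eye(2))
a_filt, a_smooth = x_filt @ C.T, x_smooth @ C.T  # a_{t|t} = C x_{t|t}, a_{t|T} = C x_{t|T}
```

Passing `a_filt` or `a_smooth` through the decoder $f_\theta$ would then give ${y}_{t|t}$ or ${y}_{t|T}$.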

To obtain the $k$-step-ahead predicted neural activity ${y}_{t+k|t}$, we first recursively forward predict the dynamic latent factors $k$ time-steps with the Kalman predictor, and obtain ${x}_{t+k|t}$. Then, we can compute the $k$-step-ahead predictions of manifold latent factors and neural observations with ${a}_{t+k|t} = C{x}_{t+k|t}$ and ${y}_{t+k|t} = f_{\theta}({a}_{t+k|t})$, respectively.
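Because $w_t$ has zero mean, the Kalman predictor's mean is obtained by applying $A$ repeatedly to the filtered estimate, ${x}_{t+k|t} = A^k {x}_{t|t}$ (the predicted covariance would be propagated as $A P A^\top + W$ at each step, not shown). The hypothetical helper below returns ${x}_{t+k|t}$ and ${a}_{t+k|t}$ for every $t$; passing ${a}_{t+k|t}$ through the decoder $f_\theta$ would then give ${y}_{t+k|t}$.

```python
import numpy as np

def k_step_ahead(x_filt, A, C, k):
    """Hypothetical helper: x_{t+k|t} = A^k x_{t|t} and a_{t+k|t} = C x_{t+k|t}.

    x_filt: (T, n_x) filtered estimates x_{t|t}, with k >= 1. Row t of the outputs
    (T - k rows) is the prediction for time t + k made from data up to time t.
    """
    x_pred = x_filt[:-k]
    for _ in range(k):
        x_pred = x_pred @ A.T  # one Kalman predictor step of equation (1); E[w_t] = 0
    return x_pred, x_pred @ C.T

# Usage with a 90-degree rotation so the result is easy to check by hand
A = np.array([[0.0, 1.0], [-1.0, 0.0]])
C = np.eye(2)
x_filt = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
x_1, a_1 = k_step_ahead(x_filt, A, C, k=1)
x_2, a_2 = k_step_ahead(x_filt, A, C, k=2)
```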

### Training Loss Function
Having established the DFINE model and its inference, we learn the model parameters $\psi, \theta, \phi$ by minimizing: $\begin{equation}L(\psi, \theta, \phi) = \sum_{k=1}^K \sum_{t=1}^{T-k} e({y}_{t+k|t}, {y}_{t+k}) + \lambda_{reg} L_2 (\theta, \phi)\tag{5}\end{equation}$ where $K$ denotes the maximum horizon for future-step-ahead prediction, $e(\cdot, \cdot)$ denotes the error measure, taken as the mean-squared error (MSE), $L_2(\cdot)$ is the L2 regularization of the autoencoder parameters $\{\theta, \phi\}$ that helps prevent overfitting, and $\lambda_{reg}$ is the L2 regularization loss scale (see config_dfine.py).
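A minimal sketch of equation (5), assuming $e(\cdot,\cdot)$ is the per-time-step MSE averaged over the $n_y$ channels and $L_2$ is the sum of squared autoencoder weights. Both are assumptions; the DFINE code may normalize or average these terms differently. The function name `dfine_style_loss` is hypothetical.

```python
import numpy as np

def dfine_style_loss(y, y_pred_by_k, params, lambda_reg):
    """Hypothetical sketch of equation (5), not the DFINE loss function.

    y: (T, n_y) observations.
    y_pred_by_k: dict {k: array of shape (T - k, n_y)}, where row t (0-indexed) holds y_{t+k|t}.
    params: list of autoencoder weight arrays (theta and phi) for the L2 term.
    """
    pred_loss = 0.0
    for k, y_pred in y_pred_by_k.items():
        # e(y_{t+k|t}, y_{t+k}) as the MSE over channels, summed over t = 1..T-k
        pred_loss += np.sum(np.mean((y_pred - y[k:]) ** 2, axis=1))
    # Assumed L2 term: sum of squared autoencoder weights
    l2 = sum(np.sum(p ** 2) for p in params)
    return pred_loss + lambda_reg * l2

# Usage on a tiny example: K = 2, T = 4, n_y = 2
y = np.zeros((4, 2))
y_pred_by_k = {1: np.ones((3, 2)), 2: np.zeros((2, 2))}
loss = dfine_style_loss(y, y_pred_by_k, params=[np.ones(3)], lambda_reg=0.1)
# expected: 3 (prediction term) + 0.1 * 3 (L2 term) = 3.3
```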

### Training tips and hyperparameters
DFINE does not have many hyperparameters to tune. Yet **it may be necessary to search over a grid of the following hyperparameters to find the best performing ones for a given dataset (especially for L2 regularization loss scale)**:
- L2 regularization loss scale, config.loss.scale_l2. For hyperparameter search, you can use a small grid such as [1e-4, 5e-4, 1e-3, 2e-3] after z-scoring the signals (see below).
- $K$, future-step-ahead prediction horizon provided as a list, config.loss.steps_ahead
- Encoder/decoder architecture, i.e., number of hidden layers and units in each layer, config.model.hidden_layer_list
- Setting $n_a$ higher than $n_y$ may lead to overfitting, so it is recommended that $n_a \leq n_y$
- As we show in Extended Data Fig. 8, it's recommended to set $n_a = n_x$ to reduce the hyperparameter search complexity

For default values of all hyperparameters, please see config_dfine.py. For the future-step-ahead prediction horizon, we used $K=4$, or config.loss.steps_ahead = [1,2,3,4] for DFINE.
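The tips above can be turned into a small search grid. This hypothetical helper enumerates combinations with $n_a = n_x$, skips $n_a > n_y$, and uses $K = 4$; each dictionary would then be copied onto config.loss.scale_l2, config.model.dim_a, config.model.dim_x, and config.loss.steps_ahead. The latent-dimension values are illustrative.

```python
from itertools import product

def hyperparameter_grid(dim_y, scale_l2_grid, latent_dim_grid, steps_ahead=(1, 2, 3, 4)):
    """Hypothetical helper: enumerate configs with n_a = n_x and n_a <= n_y."""
    grid = []
    for scale_l2, n_latent in product(scale_l2_grid, latent_dim_grid):
        if n_latent > dim_y:
            continue  # skip n_a > n_y, which may lead to overfitting
        grid.append({
            "scale_l2": scale_l2,              # config.loss.scale_l2
            "dim_a": n_latent,                 # config.model.dim_a
            "dim_x": n_latent,                 # config.model.dim_x, kept equal to dim_a
            "steps_ahead": list(steps_ahead),  # config.loss.steps_ahead (K = 4)
        })
    return grid

grid = hyperparameter_grid(dim_y=57, scale_l2_grid=[1e-4, 5e-4, 1e-3, 2e-3],
                           latent_dim_grid=[7, 15, 30, 60])
print(len(grid))  # 12: 4 L2 scales x 3 latent sizes (60 > 57 is skipped)
```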

It is important to note that for neural signals we performed **z-scoring**, which is highly recommended; see train_test_split below and time_series_utils.z_score_tensor. The L2 regularization scales recommended above assume z-scored signals, because z-scoring changes the scale of the data and can therefore affect the appropriate choice of L2 regularization scale.
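For reference, here is a NumPy sketch of z-scoring fit on training trials and applied to test trials, the pattern that train_test_split follows below. It assumes per-channel statistics computed over both trials and time; this is an assumption about what time_series_utils.z_score_tensor does, so check that function for its exact behavior. The helper name is hypothetical.

```python
import numpy as np

def z_score_fit_apply(train_y, test_y, eps=1e-8):
    """Hypothetical helper: z-score each channel with training-set statistics.

    train_y, test_y: arrays of shape (num_seq, num_steps, dim_y).
    Returns (train_z, test_z, mean, std); mean and std have shape (1, 1, dim_y).
    """
    # Per-channel mean and std over trials and time (assumed convention)
    mean = train_y.mean(axis=(0, 1), keepdims=True)
    std = train_y.std(axis=(0, 1), keepdims=True)
    # eps guards against division by zero for constant channels
    return (train_y - mean) / (std + eps), (test_y - mean) / (std + eps), mean, std

# Usage: statistics come from training trials only, then are reused for test trials
rng = np.random.default_rng(0)
train_y = rng.normal(5.0, 3.0, size=(8, 50, 4))
test_y = rng.normal(5.0, 3.0, size=(2, 50, 4))
train_z, test_z, mean, std = z_score_fit_apply(train_y, test_y)
```

Fitting the statistics on training trials only keeps test-set information out of the normalization.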


**DFINE is currently implemented for continuous-valued signals**.

In [1]:
%matplotlib inline
import json
import os
import random
from itertools import product

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from config_dfine import get_default_config
from trainers.TrainerDFINE import TrainerDFINE
from datasets import DFINEDataset
from time_series_utils import z_score_tensor, get_nrmse_error

def set_seed(seed=0):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    
def train_test_split(y, train_ratio, batch_size=4, shuffleTrain=True):
    # Split trials into training and test sets: the first train_ratio fraction of trials is used for training
    num_trials = y.shape[0]
    num_train_trials = int(train_ratio * num_trials)
    train_y = y[:num_train_trials, ...]
    test_y = y[num_train_trials:, ...]

    # Z-score the observation tensors using training-set statistics only
    train_y_zsc, mean_y, std_y = z_score_tensor(train_y, fit=True)
    test_y_zsc, _, _ = z_score_tensor(test_y, mean=mean_y, std=std_y, fit=False)

    # Create DFINE dataset objects and torch dataloaders
    train_dataset = DFINEDataset(y=train_y_zsc)
    test_dataset = DFINEDataset(y=test_y_zsc)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffleTrain)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataset, test_dataset, train_loader, test_loader



In [None]:
# Define the hyperparameter grid
param_grid = {
    "scale_l2": [0.0001, 0.001],
    "latent_factors": [7, 15],
}
param_names = list(param_grid.keys())
param_combinations = list(product(*param_grid.values()))

for subj in range(1, 21):
    subj_num = f"{subj:02}"
    data_path = os.path.join('GTH_data', f'GTH_s{subj_num}_decision_power_struct_nobs.mat')
    with h5py.File(data_path, 'r') as f:
        highgamma = np.array(f["power_struct"]["highgamma"]["powspctrm"])  # (time, electrodes, trials)
        labels = np.array(f["power_struct"]["beh"]["gambles"]).squeeze()   # 1 = gamble, 0 = no gamble
    y = np.moveaxis(highgamma, 2, 0)  # (trials, time, channels) = (num_seq, num_steps, dim_y)
    print("dim are:", y.shape)

    # Split trials by behavior (for inspection only; training below uses all trials)
    gamble = y[labels == 1]
    no_gamble = y[labels == 0]
    print("Gamble shape:", gamble.shape)        # (num_gamble_trials, num_steps, dim_y)
    print("No Gamble shape:", no_gamble.shape)  # (num_no_gamble_trials, num_steps, dim_y)

    for params in param_combinations:
        # Reset the seed so each configuration is reproducible on its own
        seed = 0
        set_seed(seed)

        # Map parameters to config
        param_dict = dict(zip(param_names, params))
        latent_factors = param_dict["latent_factors"]
        scale_l2 = param_dict["scale_l2"]
        config = get_default_config()
        config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        config.model.activation = 'tanh'
        config.train.num_epochs = 20
        config.train.batch_size = 4
        config.lr.init = 0.01
        config.model.supervise_behv = False
        config.seed = seed
        config.model.dim_y = y.shape[2]
        config.model.dim_a = latent_factors  # manifold latent factors
        config.model.dim_x = latent_factors  # dynamic latent factors (kept equal to dim_a)
        config.loss.scale_l2 = scale_l2
        config.model.save_dir = f'./results/neural/subj_{subj_num}_l2_{scale_l2}_nlatent_{latent_factors}'
        trainer = TrainerDFINE(config=config)

        # Train on the first 80% of trials and validate on the remaining 20%
        train_dataset, test_dataset, train_loader, test_loader = train_test_split(y, 0.8, batch_size=config.train.batch_size)
        trainer.train(train_loader=train_loader, valid_loader=test_loader)

In [2]:
def avg_latent_factor_plot(f, savedir, prefix='gamble', feat_name='x_smooth'):
    '''
    Plots trial-averaged dynamic or manifold latent factors and saves the figure

    Parameters:
    ------------
    - f: torch.Tensor or np.ndarray, shape: (num_steps, dim_x/dim_a), Trial-averaged inferred dynamic/manifold latent factors; smoothed/filtered/predicted factors can be provided
    - savedir: str, Model save directory; the plot is saved under savedir/plots
    - prefix: str, Plot name prefix (e.g., 'gamble' or 'no gamble')
    - feat_name: str, Latent factor key (e.g., 'x_smooth', 'a_filter'), used in the plot title and plot name
    '''
    # Matplotlib needs CPU arrays without autograd history
    if torch.is_tensor(f):
        f = f.detach().cpu().numpy()

    # From feat_name, get whether it's manifold or dynamic latent factors
    factor_type = 'Dynamic' if feat_name[0].lower() == 'x' else 'Manifold'

    # Create the figure and colormap (color encodes the time step)
    fig = plt.figure(figsize=(10, 8))
    num_steps, dim_f = f.shape
    color_index = range(num_steps)
    color_map = plt.get_cmap('viridis')

    if dim_f > 2:
        # Scatter the first 3 dimensions of the latent factors in 3D
        ax = fig.add_subplot(221, projection='3d')
        ax_m = ax.scatter(f[:, 0], f[:, 1], f[:, 2], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
        ax.set_xlabel('Dim 0')
        ax.set_ylabel('Dim 1')
        ax.set_zlabel('Dim 2')
        ax.set_title(f'{factor_type} latent factors in 3D')
        fig.colorbar(ax_m)
        top_view_pos, time_pos = 222, [223, 224]
    elif dim_f == 2:
        top_view_pos, time_pos = 221, [222, 223]
    else:
        top_view_pos, time_pos = None, [111]

    if top_view_pos is not None:
        # Scatter the first 2 dimensions of the latent factors, top view
        ax = fig.add_subplot(top_view_pos)
        ax_m = ax.scatter(f[:, 0], f[:, 1], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
        ax.set_xlabel('Dim 0')
        ax.set_ylabel('Dim 1')
        ax.set_title(f'{factor_type} latent factors from top')
        fig.colorbar(ax_m)

    # Plot the first (and, if available, second) latent dimension over time
    for dim, pos in enumerate(time_pos):
        ax = fig.add_subplot(pos)
        ax.plot(range(num_steps), f[:, dim])
        ax.set_xlabel('Time')
        ax.set_ylabel(f'Dim {dim}')

    fig.suptitle(f'{factor_type} latent factors info ({feat_name})', fontsize=16)

    # Save under savedir/plots; feat_name is kept in the file name so different keys do not overwrite each other
    plot_dir = os.path.join(savedir, 'plots')
    os.makedirs(plot_dir, exist_ok=True)
    plt.savefig(os.path.join(plot_dir, f'{prefix}_{feat_name}_avg.png'))
    plt.close(fig)


In [3]:
 
def findEuclidianDist(series1, series2):
    '''
    Distance between two latent trajectories of shape (num_steps, dim):
    the Euclidean distance at each time step, summed over all time steps.
    '''
    # Euclidean distance at each time step (sum of squares over latent dimensions)
    distances = torch.sqrt(torch.sum((series1 - series2) ** 2, dim=1))

    # Aggregate over time steps
    total_distance = distances.sum()
    return total_distance

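The evaluation cell below classifies each trial as gamble or no gamble by comparing its latent trajectory to the two class-mean trajectories, removing the held-out trial from its own class mean (leave-one-out). A self-contained NumPy sketch of that rule on synthetic data follows; the function name `loo_nearest_centroid` is hypothetical, and the distance matches findEuclidianDist (per-step Euclidean distance summed over time).

```python
import numpy as np

def loo_nearest_centroid(latents, labels):
    """Hypothetical helper: leave-one-out nearest-class-mean classification.

    latents: (num_trials, num_steps, dim); labels: (num_trials,) with values 0/1.
    Each class needs at least 2 trials so the leave-one-out mean is defined.
    Returns predicted labels of shape (num_trials,).
    """
    labels = np.asarray(labels).astype(int)
    sums = {c: latents[labels == c].sum(axis=0) for c in (0, 1)}
    counts = {c: int(np.sum(labels == c)) for c in (0, 1)}
    preds = np.zeros(len(labels), dtype=int)
    for i, (trial, c) in enumerate(zip(latents, labels)):
        dists = {}
        for cls in (0, 1):
            if cls == c:
                # Remove the held-out trial from its own class mean
                mean = (sums[cls] - trial) / (counts[cls] - 1)
            else:
                mean = sums[cls] / counts[cls]
            # Euclidean distance at each time step, summed over time
            dists[cls] = np.linalg.norm(trial - mean, axis=1).sum()
        preds[i] = 1 if dists[0] > dists[1] else 0
    return preds

# Usage on synthetic, well-separated latent trajectories
rng = np.random.default_rng(0)
labels = np.array([0] * 5 + [1] * 5)
latents = 0.1 * rng.standard_normal((10, 20, 3))
latents[labels == 1] += 1.0
preds = loo_nearest_centroid(latents, labels)
```

Leaving the held-out trial out of its own class mean prevents the trial from biasing the comparison toward its true label.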

In [24]:
import json
import os

# Define the hyperparameter grid (these configurations must already be trained)
param_grid = {
    "scale_l2": [0.0001, 0.001],
    "latent_factors": [15, 30],
}
param_names = list(param_grid.keys())
param_combinations = list(product(*param_grid.values()))

# Time steps of each trial used for classification (indices 2000 to 3000, inclusive)
window = slice(2000, 3001)

for subj in range(1, 21):
    subj_num = f"{subj:02}"
    data_path = os.path.join('GTH_data', f'GTH_s{subj_num}_decision_power_struct_nobs.mat')
    with h5py.File(data_path, 'r') as f:
        highgamma = np.array(f["power_struct"]["highgamma"]["powspctrm"])  # (time, electrodes, trials)
        labels = np.array(f["power_struct"]["beh"]["gambles"]).squeeze().astype(int)  # 1 = gamble, 0 = no gamble
    y = np.moveaxis(highgamma, 2, 0)  # (trials, time, channels) = (num_seq, num_steps, dim_y)
    print("dim are:", y.shape)

    num_g_trials = int(labels.sum())
    num_ng_trials = len(labels) - num_g_trials

    for params in param_combinations:
        seed = 0
        set_seed(seed)

        # Map parameters to config (must match the training configuration)
        param_dict = dict(zip(param_names, params))
        latent_factors = param_dict["latent_factors"]
        scale_l2 = param_dict["scale_l2"]
        config = get_default_config()
        config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        config.model.activation = 'tanh'
        config.train.num_epochs = 40
        config.train.batch_size = 4
        config.lr.init = 0.01
        config.model.supervise_behv = False
        config.seed = seed
        config.model.dim_y = y.shape[2]
        config.model.dim_a = latent_factors  # manifold latent factors
        config.model.dim_x = latent_factors  # dynamic latent factors (kept equal to dim_a)
        config.loss.scale_l2 = scale_l2
        config.model.save_dir = f'./results/neural/subj_{subj_num}_l2_{scale_l2}_nlatent_{latent_factors}'
        config.load.ckpt = 'best_loss'  # load the best-loss checkpoint saved during training
        trainer = TrainerDFINE(config=config)

        # Infer latents for all trials with a single z-scoring, so that gamble, no-gamble,
        # and held-out trial latents share the same normalization (cached to disk)
        all_file_path = os.path.join(config.model.save_dir, 'batchwise_latents.pt')
        if os.path.exists(all_file_path):
            all_latents = torch.load(all_file_path)
            print("Latents loaded from:", all_file_path)
        else:
            print("No saved latents found, computing them.")
            y_zsc, _, _ = z_score_tensor(y, fit=True)
            all_loader = DataLoader(DFINEDataset(y=y_zsc), batch_size=1, shuffle=False)
            all_latents = trainer.compute_latents(train_loader=all_loader)
            torch.save(all_latents, all_file_path)

        # Model loss on the held-out 20% of trials (same split as in training), shared by all latent keys
        _, _, _, test_loader = train_test_split(y, 0.8)
        _, loss_dict = trainer.model_eval(test_loader)
        model_loss = loss_dict['model_loss'].item()

        save_dir = os.path.join(config.model.save_dir, 'confusion_matrices')
        os.makedirs(save_dir, exist_ok=True)

        for key, trial_latents in all_latents['train'].items():
            if key == 'mask':
                continue
            # Stack per-trial latents into (num_trials, num_steps, dim), restricted to the window
            lat = torch.stack([t.squeeze() for t in trial_latents], dim=0)[:, window]
            g_mask = torch.as_tensor(labels == 1, device=lat.device)
            total_g = lat[g_mask].sum(dim=0)
            total_ng = lat[~g_mask].sum(dim=0)

            avg_latent_factor_plot(total_g / num_g_trials, config.model.save_dir, feat_name=key)
            avg_latent_factor_plot(total_ng / num_ng_trials, config.model.save_dir, prefix="no gamble", feat_name=key)

            # Leave-one-out nearest-class-mean classification of each trial
            predictions = []
            for trial in range(len(labels)):
                test_data = lat[trial]
                if labels[trial] == 1:
                    gamble_avg = (total_g - test_data) / (num_g_trials - 1)
                    no_gamble_avg = total_ng / num_ng_trials
                else:
                    gamble_avg = total_g / num_g_trials
                    no_gamble_avg = (total_ng - test_data) / (num_ng_trials - 1)
                g_distance = findEuclidianDist(gamble_avg, test_data)
                ng_distance = findEuclidianDist(no_gamble_avg, test_data)
                predictions.append(1 if ng_distance > g_distance else 0)

            # labels=[0, 1] keeps the confusion matrix 2x2 even if one class is never predicted
            cm = confusion_matrix(labels, predictions, labels=[0, 1])
            tn, fp, fn, tp = cm.ravel()
            accuracy = (tn + tp) / cm.sum()
            precision = tp / (tp + fp) if tp + fp > 0 else 0.0
            recall = tp / (tp + fn) if tp + fn > 0 else 0.0
            F1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

            # Save metrics for this latent key as JSON
            eval_results = {
                "accuracy": float(accuracy),
                "precision": float(precision),
                "recall": float(recall),
                "F1_score": float(F1),
                "model_loss": model_loss,
            }
            save_path = os.path.join(config.model.save_dir, f'{key}_eval_results.json')
            with open(save_path, "w") as json_file:
                json.dump(eval_results, json_file, indent=4)

            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["No Gamble", "Gamble"])
            disp.plot(cmap="Blues")
            plt.title(f"Confusion Matrix for {key}")
            plt.text(1.7, 1.75, f'F1 score: {F1:.3f}\nAccuracy: {accuracy:.3f}\nPrecision: {precision:.3f}\nRecall: {recall:.3f}\nModel loss: {model_loss:.4f}', fontsize=5)
            plt.savefig(os.path.join(save_dir, f"{key}_confusion_matrix.png"))
            plt.close()



dim are:  (180, 5001, 57)


02/03/2025 01:10:59 AM - DFINE Logger - INFO - Loading model from: ./results/neural/subj_01_l2_0.0001_nlatent_15\ckpts\best_loss_ckpt.pth...
02/03/2025 01:10:59 AM - DFINE Logger - INFO - Checkpoint succesfully loaded from ./results/neural/subj_01_l2_0.0001_nlatent_15\ckpts\best_loss_ckpt.pth!


TrainerDFINE loaded from: C:\Users\angel\Desktop\Code\torchDFINE\trainers\TrainerDFINE.py
File does not exist.

