In [1]:
# from whole_brain_model import WholeBrainModel, ModelParams
# from model_fitting import ModelFitting
# from costs import Costs

In [2]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import scipy.io

In [3]:
class Utils:
    """Small debugging helpers for inspecting tensors."""

    @staticmethod
    def check_for_nans(tensor, name="Tensor"):
        """
        Prints a one-line diagnostic about `tensor`.

        Parameters:
            tensor: torch.Tensor to inspect
            name: label used in the printed message

        Prints an [ERROR] line if any element is NaN; otherwise prints a [DEBUG] line
        with the minimum and maximum value. Returns None.
        """
        if torch.isnan(tensor).any():
            print(f"[ERROR] {name} contains NaNs!")
        elif tensor.numel() == 0:
            # min() / max() raise on empty tensors, so report the emptiness instead
            print(f"[DEBUG] {name} OK (empty tensor).")
        else:
            print(f"[DEBUG] {name} OK. Min: {tensor.min().item():.4f}, Max: {tensor.max().item():.4f}")


### Plotter

In [4]:
class Plotter:
    """
    Static plotting helpers (matplotlib / seaborn) for training curves, BOLD time series,
    hidden states, coupling parameters and connectivity matrices.

    Every public method draws its figure(s) and calls plt.show(); nothing is returned.
    """

    @staticmethod
    def _to_numpy(data):
        """Returns `data` as a NumPy array, detaching it from autograd first if it is a torch.Tensor."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _plot_curve(epoch_indices, values, title, ylabel):
        """Draws one per-epoch metric curve with circular markers."""
        plt.figure()
        plt.plot(epoch_indices, values, marker='o')
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.show()

    @staticmethod
    def plot_loss_curve(epoch_indices, loss_values):
        """Plots the training loss against the epoch index."""
        Plotter._plot_curve(epoch_indices, loss_values, "Training Loss", "Loss")

    @staticmethod
    def plot_rmse_curve(epoch_indices, rmse_values):
        """Plots the RMSE against the epoch index."""
        Plotter._plot_curve(epoch_indices, rmse_values, "RMSE over Epochs", "RMSE")

    @staticmethod
    def plot_roi_correlation_curve(epoch_indices, roi_corr_values):
        """Plots the average per-ROI Pearson correlation against the epoch index."""
        Plotter._plot_curve(epoch_indices, roi_corr_values,
                            "Average ROI Correlation over Epochs", "Average ROI Pearson r")

    @staticmethod
    def plot_fc_correlation_curve(epoch_indices, fc_corr_values):
        """Plots the average FC Pearson correlation against the epoch index."""
        Plotter._plot_curve(epoch_indices, fc_corr_values,
                            "Average FC Correlation over Epochs", "Average FC Pearson r")

    @staticmethod
    def plot_functional_connectivity_heatmaps(simulated_fc: np.ndarray, empirical_fc: np.ndarray):
        """
            Plots simulated and empirical Functional Connectivity heatmaps side by side.
            simulated_fc, empirical_fc: np.ndarray, 2-D matrices drawn on a fixed [-1, 1] colour scale
        """
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        sns.heatmap(simulated_fc, vmin=-1, vmax=1, cmap='coolwarm', ax=axes[0])
        axes[0].set_title("Simulated FC")
        sns.heatmap(empirical_fc, vmin=-1, vmax=1, cmap='coolwarm', ax=axes[1])
        axes[1].set_title("Empirical FC")
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_time_series(time_series: torch.Tensor, title: str, max_nodes: int = 6):
        """
            Plots batched BOLD time series, one figure per batch element.
            time_series: torch.Tensor or array, shape (N, T, B)
            title: figure title prefix; the batch index is appended
            max_nodes: at most this many nodes (the first ones) are drawn per figure
        """
        data = Plotter._to_numpy(time_series)
        N, T, B = data.shape
        for batch_idx in range(B):
            plt.figure(figsize=(10, 4))
            for node_idx in range(min(N, max_nodes)):
                plt.plot(np.arange(T), data[node_idx, :, batch_idx], label=f"Node {node_idx}")
            plt.title(f"{title} (Batch {batch_idx})")
            plt.xlabel("TR")
            plt.ylabel("BOLD signal")
            plt.legend()
            plt.show()

    @staticmethod
    def plot_hidden_states(hidden_state_logs, state_names=('E', 'I', 'x', 'f', 'v', 'q')):
        """
            hidden_state_logs: sequence with one element per epoch, each of shape (time_points, state_size)
            state_names: labels for the state columns (defaults to E, I, x, f, v, q)
            Plots each state variable over TRs in its own figure, with one line per epoch.
            Does nothing when no epochs were logged.
        """
        if len(hidden_state_logs) == 0:
            return
        state_size = hidden_state_logs[0].shape[1]

        for dim in range(state_size):
            # Fall back to a generic label if more columns were logged than names supplied
            label = state_names[dim] if dim < len(state_names) else f"state {dim}"
            plt.figure(figsize=(8, 3))
            for epoch_idx, epoch_log in enumerate(hidden_state_logs):
                plt.plot(np.arange(len(epoch_log)), epoch_log[:, dim], label=f'Epoch {epoch_idx+1}')
            plt.title(f"Hidden-state '{label}' over TRs")
            plt.xlabel("Time")
            plt.ylabel(label)
            plt.legend(loc='upper right', fontsize='small')
            plt.tight_layout()
            plt.show()

    # Alias under the name used by the training loop
    plot_hidden_state_evolution = plot_hidden_states

    @staticmethod
    def plot_coupling_parameters(*histories, **parameter_history):
        """
            Plots the per-epoch history of each coupling parameter (e.g. g, g_EE, g_EI, g_IE) on one figure.

            histories: optional dicts mapping parameter name -> list of per-epoch values
            parameter_history: the same mapping given as keyword arguments
            Both forms may be mixed; keyword entries override positional ones with the same name.
        """
        merged = {}
        for history in histories:
            merged.update(history)
        merged.update(parameter_history)

        plt.figure(figsize=(6, 4))
        for param_name, values in merged.items():
            # x-axis follows the number of recorded epochs, not the number of parameters
            epochs = range(1, len(values) + 1)
            plt.plot(epochs, values, marker='o', label=param_name)
        plt.title("Coupling parameters over epochs")
        plt.xlabel("Epoch")
        plt.ylabel("Parameter Value")
        if merged:
            plt.legend()
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_node_comparison(empirical_bold: torch.Tensor, simulated_bold: torch.Tensor, node_indices=None):
        """
            Overlays empirical and simulated BOLD for selected nodes, one subplot per node.
            empirical_bold, simulated_bold: shape (N, T, B); only batch element 0 is drawn
            node_indices: nodes to plot (defaults to the first min(6, N) nodes)
        """
        if node_indices is None:
            node_indices = list(range(min(6, empirical_bold.shape[0])))
        node_indices = list(node_indices)
        if not node_indices:
            return
        emp = Plotter._to_numpy(empirical_bold)
        sim = Plotter._to_numpy(simulated_bold)
        T = emp.shape[1]
        # squeeze=False keeps `axes` 2-D so a single requested node can still be indexed
        fig, axes = plt.subplots(len(node_indices), 1, figsize=(10, 2*len(node_indices)),
                                 sharex=True, squeeze=False)
        axes = axes[:, 0]
        for i, node in enumerate(node_indices):
            axes[i].plot(np.arange(T), emp[node, :, 0], label="Empirical")
            axes[i].plot(np.arange(T), sim[node, :, 0], label="Simulated")
            axes[i].set_ylabel(f"Node {node}")
            if i == 0:
                axes[i].legend()
        axes[-1].set_xlabel("TR")
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_laplacian(subject_index: int, laplacian_matrix: np.ndarray):
        """Draws a heatmap of one subject's Laplacian matrix."""
        plt.figure(figsize=(6, 5))
        sns.heatmap(laplacian_matrix, cmap='viridis')
        plt.title(f"Laplacian Heatmap (Subject {subject_index})")
        plt.xlabel("Node")
        plt.ylabel("Node")
        plt.show()

    @staticmethod
    def plot_distance_matrix(subject_index: int, distance_matrix: np.ndarray):
        """Draws a heatmap of one subject's distance matrix."""
        plt.figure(figsize=(6, 5))
        sns.heatmap(distance_matrix, cmap='magma')
        plt.title(f"Distance Matrix (Subject {subject_index})")
        plt.xlabel("Node")
        plt.ylabel("Node")
        plt.show()

### Data Loader

In [5]:
class DataLoader:
    """
    Holds per-subject BOLD time series, structural connectivity (SC) and distance matrices,
    and serves random minibatches of fixed-length BOLD chunks.
    """

    def __init__(self, fmri_filename: str, dti_filename: str, distance_matrices_path: str, chunk_length: int = 50):
        """
        Reads all subjects from disk and cuts each BOLD recording into chunks.

        Parameters:
            fmri_filename: .mat file holding the "BOLD_timeseries_HCP" variable
            dti_filename: stored on the instance; not read by this class
            distance_matrices_path: directory containing sc_norm_subj{i}.npy and subj{i}.npy
            chunk_length: number of TRs per BOLD chunk
        """
        self.fmri_filename = fmri_filename
        self.dti_filename = dti_filename
        self.distance_matrices_path = distance_matrices_path
        self.chunk_length = chunk_length

        # Per-subject arrays, indexed by subject number
        self.all_bold = []       # BOLD arrays, (node_size, num_TRs)
        self.all_SC = []         # SC matrices, (node_size, node_size)
        self.all_distances = []  # distance matrices, (node_size, node_size)
        # Flat list of {'subject': int, 'bold': array (node_size, chunk_length)}
        self.bold_chunks = []

        self._load_data()
        self._split_into_chunks()

    def _load_data(self):
        """Fills all_bold / all_SC / all_distances, one entry per subject found in the fMRI file."""
        bold_data = scipy.io.loadmat(self.fmri_filename)["BOLD_timeseries_HCP"]
        num_subjects = bold_data.shape[0]

        for subject in range(num_subjects):
            # Each cell of the (num_subjects, 1) container holds one subject's BOLD matrix
            self.all_bold.append(bold_data[subject, 0])

            # SC is loaded already pre-processed (symmetric, log-transformed, normalised)
            sc_file = os.path.join(self.distance_matrices_path, f"sc_norm_subj{subject}.npy")
            self.all_SC.append(np.load(sc_file))

            dist_file = os.path.join(self.distance_matrices_path, f"subj{subject}.npy")
            self.all_distances.append(np.load(dist_file))

        print(f"[DataLoader] Loaded {num_subjects} subjects.")

    def _split_into_chunks(self):
        """Rebuilds bold_chunks from all_bold; trailing TRs that do not fill a chunk are dropped."""
        self.bold_chunks = []
        width = self.chunk_length
        for subject, recording in enumerate(self.all_bold):
            whole_chunks = recording.shape[1] // width
            for index in range(whole_chunks):
                start = index * width
                self.bold_chunks.append({"subject": subject, "bold": recording[:, start:start + width]})
        print(f"[DataLoader] Created {len(self.bold_chunks)} chunks (chunk length = {self.chunk_length}).")

    def sample_minibatch(self, batch_size: int):
        """
        Draws `batch_size` distinct chunks at random.

        Returns:
            batched_bold: float32 tensor (node_size, chunk_length, batch_size)
            batched_SC: float32 tensor (batch_size, node_size, node_size), graph Laplacian of each subject's SC
            batched_dist: float32 tensor (batch_size, node_size, node_size)
            batch_subjects: list of subject indices, one per batch element
        """
        picked = random.sample(self.bold_chunks, batch_size)

        bold_list, laplacian_list, dist_list, batch_subjects = [], [], [], []
        for entry in picked:
            subject = entry["subject"]
            bold_list.append(entry["bold"])
            batch_subjects.append(subject)

            # NOTE: Test with non-Laplacian SC
            # Laplacian = degree matrix (row sums on the diagonal) minus SC
            sc_norm = self.all_SC[subject]
            laplacian_list.append(np.diag(np.sum(sc_norm, axis=1)) - sc_norm)

            dist_list.append(self.all_distances[subject])

        # BOLD is stacked along a trailing batch axis; SC and distances along a leading one
        batched_bold = torch.tensor(np.stack(bold_list, axis=-1), dtype=torch.float32)
        batched_SC = torch.tensor(np.stack(laplacian_list, axis=0), dtype=torch.float32)
        batched_dist = torch.tensor(np.stack(dist_list, axis=0), dtype=torch.float32)

        return batched_bold, batched_SC, batched_dist, batch_subjects



### Costs

In [6]:
class Costs:
    """Loss and evaluation metrics comparing simulated and empirical BOLD time series."""

    @staticmethod
    def compute(simulated_bold, empirical_bold):
        """
        Compare two BOLD time series and calculate the Pearson correlation between their FC matrices

        Parameters:
            simulated_bold: torch.Tensor (or array-like) of shape (N, T, B)
            empirical_bold: torch.Tensor (or array-like) of shape (N, T, B)

        Returns a dict with:
            "loss": torch scalar, -log(0.5 + 0.5 * global_corr + 1e-8), where global_corr is the
                batch mean of the Pearson correlation between the lower triangles of the two FC matrices
            "rmse": float, root mean squared error over all elements
            "average_rois_correlation": float, Pearson correlation between simulated and empirical
                time series of each node, averaged over nodes and batch elements
            "average_fc_correlation": 0-d numpy array holding global_corr (detached)

        Raises:
            AssertionError if the two inputs do not have the same shape
        """
        if not isinstance(simulated_bold, torch.Tensor):
            simulated_bold = torch.tensor(simulated_bold, dtype=torch.float32)
        if not isinstance(empirical_bold, torch.Tensor):
            empirical_bold = torch.tensor(empirical_bold, dtype=torch.float32)

        assert simulated_bold.shape == empirical_bold.shape, f"Simulated and Empirical BOLD time series must have the same dimensions. Found EMP: {empirical_bold.shape}, SIM: {simulated_bold.shape}"
        # Unpacking also enforces that the inputs are 3-D
        _num_nodes, _num_trs, B = simulated_bold.shape

        rmse = torch.sqrt(torch.mean((simulated_bold - empirical_bold) ** 2)).item()

        rois_correlation = []
        global_corrs = []

        for b in range(B):
            sim_b = simulated_bold[:, :, b]
            emp_b = empirical_bold[:, :, b]

            # Centre every node on its OWN temporal mean; subtracting the global mean of the
            # whole (N, T) matrix would not yield a Pearson correlation per node
            sim_n = sim_b - torch.mean(sim_b, dim=1, keepdim=True)
            emp_n = emp_b - torch.mean(emp_b, dim=1, keepdim=True)

            # Node-wise Pearson correlation between simulated and empirical time series
            dot_product = (sim_n * emp_n).sum(dim=1)
            product = (sim_n.norm(dim=1) * emp_n.norm(dim=1) + 1e-8)
            rois_correlation.append((dot_product / product).mean().item())

            # FC matrices: node-by-node correlation of the centred series
            cov_sim = sim_n @ sim_n.t()  # (N, N)
            cov_emp = emp_n @ emp_n.t()  # (N, N)
            std_sim = torch.sqrt(torch.diag(cov_sim) + 1e-8)
            std_emp = torch.sqrt(torch.diag(cov_emp) + 1e-8)
            FC_sim = cov_sim / (std_sim.unsqueeze(1) * std_sim.unsqueeze(0) + 1e-8)
            FC_emp = cov_emp / (std_emp.unsqueeze(1) * std_emp.unsqueeze(0) + 1e-8)

            # Extract lower triangular parts (excluding the diagonal)
            mask = torch.tril(torch.ones_like(FC_sim), diagonal=-1).bool()

            sim_vec = FC_sim[mask]
            emp_vec = FC_emp[mask]
            sim_vec = sim_vec - torch.mean(sim_vec)
            emp_vec = emp_vec - torch.mean(emp_vec)

            global_corr_b = torch.sum(sim_vec * emp_vec) / (torch.sqrt(torch.sum(sim_vec**2)) * torch.sqrt(torch.sum(emp_vec**2)) + 1e-8)
            global_corrs.append(global_corr_b)

        average_rois_correlation = float(np.mean(rois_correlation))
        global_corr = torch.mean(torch.stack(global_corrs))

        # Maps correlation in [-1, 1] to a loss that is 0 at r = 1 and grows as r -> -1
        correlation_loss = -torch.log(0.5 + 0.5 * global_corr + 1e-8)
        return {
            "loss": correlation_loss,
            "rmse": rmse,
            "average_rois_correlation": average_rois_correlation,
            "average_fc_correlation": global_corr.detach().cpu().numpy()
        }
        


### Model

In [7]:
class ModelParams:
    """
    Plain container of initial model parameter values.
    Each value is stored as an instance attribute and can also be read with params["name"].
    """

    def __init__(self):
        defaults = {}

        ## DMF parameters
        #  starting states taken from Griffiths et al. 2022
        defaults.update(
            W_E=1.0,                    # Scale for external input to excitatory population
            W_I=0.7,                    # Scale for external input to inhibitory population
            I_0=0.32,                   # Constant external input
            tau_E=100.0,                # Decay time (ms) for excitatory synapses
            tau_I=10.0,                 # Decay time for inhibitory synapses
            gamma_E=0.641 / 1000.0,     # Kinetic parameter for excitatory dynamics
        )

        # Sigmoid parameters for conversion of current to firing rate
        defaults.update(aE=310.0, bE=125.0, dE=0.16, aI=615.0, bI=177.0, dI=0.087)

        # Connectivity parameters
        defaults.update(
            g=20.0,                     # Global coupling (long-range)
            g_EE=0.1,                   # Local excitatory self-feedback
            g_IE=0.1,                   # Inhibitory-to-excitatory coupling
            g_EI=0.1,                   # Excitatory-to-inhibitory coupling
        )

        ## Balloon (haemodynamic) parameters
        defaults.update(
            tau_s=0.65,
            tau_f=0.41,
            tau_0=0.98,
            alpha=0.32,
            rho=0.34,
            k1=2.38,
            k2=2.0,
            k3=0.48,
            V=0.02,                     # V0 in the BOLD equation
            E0=0.34,
        )

        # Materialise every entry as an instance attribute, preserving insertion order
        for attribute, value in defaults.items():
            setattr(self, attribute, value)

    def __getitem__(self, key):
        """Dictionary-style read access: params[key] is the attribute named `key`."""
        return getattr(self, key)

In [19]:

# Whole Brain Model integrating both DMF and Balloon
class WholeBrainModel(nn.Module):
    def __init__(self, params: ModelParams, input_size: int, node_size: int, batch_size: int, 
                 step_size: float, tr: float):
        """
        Parameters:
            params: ModelParams container (attributes W_E, tau_E, gamma_E, ...)
            input_size: Number of inptu channels (e.g. noise channels) per integration step
            node_size: Number of nodes (ROIs)
            batch_size: Batch size (number of parallel simulations)
            step_size: Integration time steps (0.05s)
            tr: TR duation (0.75s); hidden_size = tr / step_size

        """
        super(WholeBrainModel, self).__init__()
        self.node_size = node_size
        self.batch_size = batch_size
        self.step_size = step_size
        self.tr = tr
        self.hidden_size = int(tr / step_size)  # number of integration steps per TR
        self.input_size = input_size  # noise input dimension

        self.state_size = 6  # [E, I, x, f, v, q]

        # DMF parameters
        self.W_E     = nn.Parameter(torch.tensor(params["W_E"], dtype=torch.float32))
        self.W_I     = nn.Parameter(torch.tensor(params["W_I"], dtype=torch.float32))
        self.I_0     = nn.Parameter(torch.tensor(params["I_0"], dtype=torch.float32))
        self.tau_E   = nn.Parameter(torch.tensor(params["tau_E"], dtype=torch.float32))
        self.tau_I   = nn.Parameter(torch.tensor(params["tau_I"], dtype=torch.float32))
        self.gamma_E = nn.Parameter(torch.tensor(params["gamma_E"], dtype=torch.float32))
        self.aE      = nn.Parameter(torch.tensor(params["aE"], dtype=torch.float32))
        self.bE      = nn.Parameter(torch.tensor(params["bE"], dtype=torch.float32))
        self.dE      = nn.Parameter(torch.tensor(params["dE"], dtype=torch.float32))
        self.aI      = nn.Parameter(torch.tensor(params["aI"], dtype=torch.float32))
        self.bI      = nn.Parameter(torch.tensor(params["bI"], dtype=torch.float32))
        self.dI      = nn.Parameter(torch.tensor(params["dI"], dtype=torch.float32))
        self.g       = nn.Parameter(torch.tensor(params["g"], dtype=torch.float32))
        self.g_EE    = nn.Parameter(torch.tensor(params["g_EE"], dtype=torch.float32))
        self.g_IE    = nn.Parameter(torch.tensor(params["g_IE"], dtype=torch.float32))
        self.g_EI    = nn.Parameter(torch.tensor(params["g_EI"], dtype=torch.float32))

        # Balloon (hemodynamic) parameters
        self.tau_s   = nn.Parameter(torch.tensor(params["tau_s"], dtype=torch.float32))
        self.tau_f   = nn.Parameter(torch.tensor(params["tau_f"], dtype=torch.float32))
        self.tau_0   = nn.Parameter(torch.tensor(params["tau_0"], dtype=torch.float32))
        self.alpha   = nn.Parameter(torch.tensor(params["alpha"], dtype=torch.float32))
        self.rho     = nn.Parameter(torch.tensor(params["rho"], dtype=torch.float32))
        self.k1      = nn.Parameter(torch.tensor(params["k1"], dtype=torch.float32))
        self.k2      = nn.Parameter(torch.tensor(params["k2"], dtype=torch.float32))
        self.k3      = nn.Parameter(torch.tensor(params["k3"], dtype=torch.float32))
        self.V       = nn.Parameter(torch.tensor(params["V"], dtype=torch.float32))
        self.E0      = nn.Parameter(torch.tensor(params["E0"], dtype=torch.float32))

        print(f"[DEBUG] Model initialized with {self.state_size} states, {len(params.__dict__)} learnable parameters, and {self.hidden_size} hidden step size")

    def generate_initial_states(self):
        """
        Generates the initial state for RWW (DMF) foward function. Uses same initial states as in the Griffiths et al. code

        Returns:
            initial_state: torch.Tensor of shape (node_size, input_size, batch_size)
        """
        initial_state = 0.45 * np.random.uniform(0, 1, (self.node_size, self.input_size, self.batch_size))
        baseline = np.array([0, 0, 0, 1.0, 1.0, 1.0]).reshape(1, self.input_size, 1)
        initial_state = initial_state + baseline
        return torch.tensor(initial_state, dtype=torch.float32)

    def h_tf(self, a, b, d, current):
        """
        Transformation for firing rates of excitatory and inhibitory pools
        Takes variables a, b, current and convert into a linear equation (a * current - b) while adding a small
        amount of noise (1e-5) while dividing that term to an exponential of itself multiplied by constant d for
        the appropriate dimensions
        """
        num = 1e-5 + torch.abs(a * current - b)
        den = 1e-5 * d + torch.abs(1.000 - torch.exp(-d * (a * current - b)))
        return torch.divide(num, den + 1e-8)
    

    def forward(self, hx: torch.Tensor, noise_in: torch.Tensor, noise_out: torch.Tensor, delays: torch.Tensor, \
                 batched_laplacian: torch.Tensor, dist_matrices: torch.Tensor):
        """
        Simulate on TR chunk
        
        Parameters:
            hx: Current state input, shape (node_size, 6, batch_size)
            noise_in: Noise tensor for state updates, shape (node_size, hidden_size, batch_size, input_size)
            noise_out: Noise tensor for BOLD output, shape (node_size, batch_size)
            delays: Delay buffer for E, shape (node_size, delays_max, batch_size)
            batched_laplacian: batched Laplacian tensor, shape (batch_size, node_size, node_size)
            dist_matrices: batched distance tensor, representing tract lengths for excitatory delays, shape (batch_size, node_size, node_size)
        
        Returns:
            state: Updated state (node_size, 6, batch_size)
            bold: Simulated BOLD signal (node_size, batch_size)
            delays: Updated delay buffer (node_size, delays_max, batch_size)
        """
        dt = self.step_size
        state = hx

        # Loop over hidden integration steps (one TR)
        for i in range(self.hidden_size):
            noise_step = noise_in[:, i, :, :] # (node_size, input_size, batch_size)

            # --- DMF update ---
            E = state[:, 0:1, :] # (node_size, 1, batch_size)
            I = state[:, 1:2, :] # (node_size, 1, batch_size)

            # Delayed excitatory input
            # Compute delay indices in integration steps
            speed = 1.5 # m/s
            delay_seconds = dist_matrices * 0.001 / speed
            delay_steps = (delay_seconds / self.step_size).floor().long()

            # Gather from delay buffer
            hE = delays.permute(2, 0, 1) # (batch_size, node_size, delays_max)
            E_delayed = hE.gather(dim=2, index=delay_steps) # gather along delays axis, (batch_size, node_size, node_size)
            E_delayed_post = E_delayed.permute(0, 2, 1) # transpose to j->i indexing

            # Apply Laplacian
            connectivity_b = torch.bmm(batched_laplacian, E_delayed_post) # (batch_size, node_size, 1)
            connectivity_effect = connectivity_b.permute(1, 2, 0)

            # for b, g in enumerate(connectivity_b):
            #     Plotter.plot_distance_matrix(b, g)


            I_E = torch.nn.functional.relu(self.W_E * self.I_0 + self.g_EE * E + self.g * connectivity_effect - self.g_IE * I)
            I_I = torch.nn.functional.relu(self.W_I * self.I_0 + self.g_EI * E - I)

            R_E = self.h_tf(self.aE, self.bE, self.dE, I_E)
            R_I = self.h_tf(self.aI, self.bI, self.dI, I_I)

            dE = -E / self.tau_E + (1 - E) * self.gamma_E * R_E
            dI = -I / self.tau_I + R_I

            E_new = torch.tanh(torch.nn.functional.relu(E + dt * dE + noise_step[:, 0:1, :]))  # use first channel of noise for E
            I_new = torch.tanh(torch.nn.functional.relu(I + dt * dI + noise_step[:, 1:2, :]))  # second channel for I

            # --- Balloon Update ---
            x = state[:, 2:3, :]
            f = state[:, 3:4, :]
            v = state[:, 4:5, :]
            q = state[:, 5:6, :]

            dx = E_new - torch.reciprocal(self.tau_s) * x - torch.reciprocal(self.tau_f) * (f - 1)
            df = x
            dv = (f - torch.pow(v, torch.reciprocal(self.alpha))) * torch.reciprocal(self.tau_0)
            dq = (f * (1 - torch.pow(1 - self.rho, torch.reciprocal(f))) * torch.reciprocal(self.rho) \
                   - q * torch.pow(v, torch.reciprocal(self.alpha)) * torch.reciprocal(v+1e-8)) \
                     * torch.reciprocal(self.tau_0)
            
            x_new = x + dt * dx + noise_step[:, 2:3, :]
            f_new = f + dt * df + noise_step[:, 3:4, :]
            v_new = v + dt * dv + noise_step[:, 4:5, :]
            q_new = q + dt * dq + noise_step[:, 5:6, :]

            state = torch.cat([E_new, I_new, x_new, f_new, v_new, q_new], dim=1)

            # Discard oldest delay value. Shape (node_size, delays_max, batch_size)
            delays = torch.cat([E_new, delays[:, :-1, :]], dim=1)

        BOLD = self.V * (self.k1 * (1 - q_new) + \
                        (self.k2 * (1 - q_new * torch.reciprocal(v_new))) + \
                        (self.k3 * (1 - v_new)))
        BOLD = BOLD.squeeze(1)
        BOLD = BOLD + noise_out

        # print(f"[DEBUG] BOLD shape: {BOLD.shape} (expected ({self.node_size}, {self.batch_size}))")

        return state, BOLD, delays


### ModelFitting

In [9]:
# Anish Kochhar, Imperial College London, March 2025

from tqdm import tqdm

class ModelFitting:
    def __init__(self, model: WholeBrainModel, data_loader: DataLoader, num_epochs: int, lr: float, cost_function: Costs, smoothing_window: int = 1, log_state: bool = False):
        """
        Parameters:
            model: WholeBrainModel instance
            data_loader: DataLoader instance providing sample_minibatch()
            num_epochs: Number of training epochs
            lr: Learning rate
            cost_function: object exposing compute(simulated, empirical) for metrics comparison between simulated and empirical BOLD
            smoothing_window: size of moving-average window (1 = no smoothing)
            log_state: If True, logs the evolution of state variables over TR chunks
        """
        self.model = model
        self.loader = data_loader
        self.num_epochs = num_epochs
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.cost_function = cost_function # Costs.compute
        self.smoothing_window = smoothing_window
        self.log_state = log_state

        # Per-epoch metric histories; they keep growing if train() is called more than once
        self.logs = { "losses": [], "fc_corr": [], "rmse": [], "roi_corr": [], "hidden_states": [] }
        self.parameters_history = { name: [] for name in ["g", "g_EE", "g_EI", "g_IE"] }

    def smooth(self, bold: torch.Tensor):
        """
        Applies moving average along time dimension (dim = 1)

        bold: tensor of shape (N, T, B). Returned unchanged when smoothing_window <= 1;
        otherwise a tensor of the same shape is returned.
        """
        if self.smoothing_window <= 1:
            return bold
        N, T, B = bold.shape
        x = bold.permute(0, 2, 1).reshape(-1, 1, T) # (N * B, 1, T): one channel per node/batch series
        # Asymmetric pad (edge values replicated) so the pooled output keeps length T
        left_pad = (self.smoothing_window - 1) // 2
        right_pad = self.smoothing_window - 1 - left_pad
        x = F.pad(x, (left_pad, right_pad), mode='replicate')
        smoothed = F.avg_pool1d(x, kernel_size=self.smoothing_window, stride=1)
        smoothed = smoothed.reshape(N, B, T).permute(0, 2, 1)
        return smoothed

    def compute_fc(self, matrix: torch.Tensor):
        """
        Builds the FC matrix: correlation between the rows of `matrix` (shape (N, T)).
        Returns a detached numpy array of shape (N, N).
        """
        zero_centered = matrix - matrix.mean(dim=1, keepdim=True)
        covariance = zero_centered @ zero_centered.T
        std = torch.sqrt(torch.diag(covariance)).unsqueeze(0)
        return (covariance / (std.T * std + 1e-8)).detach().cpu().numpy()

    def train(self, delays_max: int = 500, batch_size: int = 20, noise_std: float = 0.0):
        """
        Train the model over multiple minibatches, iterating over TR chunks for each sample

        Parameters:
            delays_max: Length of the delay buffer (in integration steps) used for delayed coupling
            batch_size: Minibatch size; must equal the batch size the model was built with
            noise_std: Standard deviation of the Gaussian noise added to the state updates and to the
                BOLD output. The default 0.0 runs the simulation without noise.

        Raises:
            ValueError if batch_size differs from the model's batch size
        """
        model_batch_size = getattr(self.model, "batch_size", batch_size)
        if batch_size != model_batch_size:
            # The initial state is shaped by the model's batch size; a mismatch would otherwise
            # only surface as an opaque broadcasting error inside forward()
            raise ValueError(f"batch_size ({batch_size}) must match the model's batch size ({model_batch_size})")

        state = self.model.generate_initial_states()
        delays = torch.zeros(self.model.node_size, delays_max, batch_size)

        empirical_bold = None
        smoothed_simulated_bold_epoch = None

        for epoch in range(1, self.num_epochs + 1):
            self.optimizer.zero_grad()

            # State and delay buffer are carried over between epochs as values only. Without detaching,
            # backward() in this epoch would try to traverse the previous epoch's already-freed graph.
            state = state.detach()
            delays = delays.detach()
            
            empirical_bold, laplacians, dist_matrices, _sampled = self.loader.sample_minibatch(batch_size)
            num_TRs = empirical_bold.shape[1]
            device = empirical_bold.device

            # noise_in shape: (node_size, hidden_size, input_size, batch_size) with input_size = 6
            noise_in_shape = (self.model.node_size, self.model.hidden_size, self.model.input_size, batch_size)
            # noise_out shape: (node_size, batch_size)
            noise_out_shape = (self.model.node_size, batch_size)

            simulated_bold_chunks = []
            epoch_state_log = []

            for _tr in range(num_TRs):
                # -- Generate noise for integration steps --
                if noise_std > 0:
                    noise_in = torch.randn(noise_in_shape, device=device) * noise_std
                    noise_out = torch.randn(noise_out_shape, device=device) * noise_std
                else:
                    noise_in = torch.zeros(noise_in_shape, device=device)
                    noise_out = torch.zeros(noise_out_shape, device=device)

                state, bold_chunk, delays = self.model(state, noise_in, noise_out, delays, laplacians, dist_matrices)
                simulated_bold_chunks.append(bold_chunk)

                if self.log_state:
                    # Mean of each state variable over nodes and batch, one row per TR
                    epoch_state_log.append(state.mean(dim=(0, 2)).detach().cpu().numpy())

            # Stack TR chunks to form a time series: (node_size, num_TRs, batch_size)
            simulated_bold_epoch = torch.stack(simulated_bold_chunks, dim=1)
            smoothed_simulated_bold_epoch = self.smooth(simulated_bold_epoch)

            # Compute cost 
            metrics = self.cost_function.compute(smoothed_simulated_bold_epoch, empirical_bold)
            metrics["loss"].backward()

            # Gradient norm of every parameter that received a gradient this epoch
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    print(name, param.grad.norm().item())

            self.optimizer.step()

            for parameter_name in self.parameters_history:
                self.parameters_history[parameter_name].append(getattr(self.model, parameter_name).item())
            
            self.logs["losses"].append(metrics["loss"].detach().cpu().numpy())
            self.logs["fc_corr"].append(metrics["average_fc_correlation"])
            self.logs["roi_corr"].append(metrics["average_rois_correlation"])
            self.logs["rmse"].append(metrics["rmse"])

            if self.log_state:
                self.logs["hidden_states"].append(np.stack(epoch_state_log, axis=0))

            print(
                f"Epoch {epoch}/{self.num_epochs} | "
                f"Loss: {metrics['loss'].item():.4f} | "
                f"RMSE: {metrics['rmse']:.4f} | "
                f"Avg ROI Corr: {metrics['average_rois_correlation']:.4f} | "
                f"Avg FC Corr: {metrics['average_fc_correlation']:.4f}"
            )

        if smoothed_simulated_bold_epoch is None:
            # No epoch ran (num_epochs < 1), so there is nothing to visualise
            return

        # Final epoch visualisations. The x-axis follows the logged history so it stays
        # aligned with the logs even when train() has been called before.
        epochs = list(range(1, len(self.logs["losses"]) + 1))
        Plotter.plot_loss_curve(epochs, self.logs["losses"])
        Plotter.plot_fc_correlation_curve(epochs, self.logs["fc_corr"])
        Plotter.plot_roi_correlation_curve(epochs, self.logs["roi_corr"])
        Plotter.plot_rmse_curve(epochs, self.logs["rmse"])

        if self.log_state:
            Plotter.plot_hidden_states(self.logs["hidden_states"])

        # The plotter takes one keyword argument per parameter name
        Plotter.plot_coupling_parameters(**self.parameters_history)

        # Plot FC matrix heatmaps for final epoch for batch 0
        simulated_fc = self.compute_fc(smoothed_simulated_bold_epoch[:, :, 0])
        empirical_fc = self.compute_fc(empirical_bold[:, :, 0])
        
        Plotter.plot_functional_connectivity_heatmaps(simulated_fc, empirical_fc)

        Plotter.plot_node_comparison(
            empirical_bold[:, :, 0].unsqueeze(-1),
            smoothed_simulated_bold_epoch[:, :, 0].unsqueeze(-1),
            node_indices=list(range(min(6, empirical_bold.shape[0])))
        )

In [10]:

# Input locations, relative to the notebook's working directory
distance_matrices_path = "../HCP Data/distance_matrices/"
dti_filename = "../HCP Data/DTI Fibers HCP.mat"
fmri_filename = "../HCP Data/BOLD Timeseries HCP.mat"

# Load every subject and cut the BOLD recordings into 50-TR chunks
data_loader = DataLoader(
    fmri_filename=fmri_filename,
    dti_filename=dti_filename,
    distance_matrices_path=distance_matrices_path,
    chunk_length=50,
)

[DataLoader] Loaded 100 subjects.
[DataLoader] Created 2300 chunks (chunk length = 50).


In [11]:
# Seed every RNG the notebook draws from so that minibatch sampling (random),
# initial states (numpy) and any torch noise are reproducible across runs
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 8                         # Minibatch size
# Draw one minibatch only to inspect the tensor shapes needed for the model settings
_, sc, dist, _ = data_loader.sample_minibatch(batch_size)

In [12]:
## Model settings
tr = 0.75                               # TR duration
step_size = 0.05                        # Integration time step within a TR
input_size = 6                          # Number of noise channels
delays_max = 20                         # Length of the delay buffer
node_size = sc.shape[1]                 # Number of nodes, read off the sampled SC batch (100)

In [20]:
# Initial parameter values and the model built from them
params = ModelParams()
model = WholeBrainModel(
    params=params,
    input_size=input_size,
    node_size=node_size,
    batch_size=batch_size,
    step_size=step_size,
    tr=tr,
)

# Cost object (its compute() supplies the loss) and the trainer that drives the fit
costs = Costs()
trainer = ModelFitting(
    model=model,
    data_loader=data_loader,
    num_epochs=10,
    lr=0.001,
    cost_function=costs,
    smoothing_window=10,
    log_state=True,
)

[DEBUG] Model initialized with 6 states, 26 learnable parameters, and 15 hidden step size


In [21]:
# Run the fit with the delay-buffer length and minibatch size configured above
trainer.train(delays_max=delays_max, batch_size=batch_size)

print("[Main] Training complete")


: 