In [1]:
# from whole_brain_model import WholeBrainModel, ModelParams
# from model_fitting import ModelFitting
# from costs import Costs

In [1]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import scipy.io

In [3]:
class Utils:
    """Small debugging helpers shared by the model code."""

    @staticmethod
    def check_for_nans(tensor, name="Tensor"):
        """
        Report whether a tensor contains NaNs.

        Prints an [ERROR] line if any element is NaN. Otherwise it prints a [DEBUG] line
        with the tensor's min / max, or a note that the tensor is empty (min / max are
        undefined there and would raise).

        Parameters:
            tensor: torch.Tensor to inspect
            name: label used in the printed message

        Returns:
            bool: True if the tensor contains at least one NaN, else False.
        """
        has_nans = bool(torch.isnan(tensor).any())
        if has_nans:
            print(f"[ERROR] {name} contains NaNs!")
        elif tensor.numel() == 0:
            print(f"[DEBUG] {name} OK. (empty tensor)")
        else:
            print(f"[DEBUG] {name} OK. Min: {tensor.min().item():.4f}, Max: {tensor.max().item():.4f}")
        return has_nans


### Data Loader

In [None]:
class DataLoader:
    """
    Loads per-subject BOLD time series and DTI structural connectivity (SC) from .mat
    files, cuts the BOLD series into fixed-length chunks and serves random minibatches.
    """

    def __init__(self, fmri_filename: str, dti_filename: str, chunk_length: int = 50):
        """
        Parameters:
            fmri_filename: .mat file containing a "BOLD_timeseries_HCP" cell array (one entry per subject)
            dti_filename: .mat file containing a "DTI_fibers_HCP" cell array (one entry per subject)
            chunk_length: number of TRs per BOLD chunk; trailing TRs that do not fill a chunk are dropped
        """
        self.fmri_filename = fmri_filename
        self.dti_filename = dti_filename
        self.chunk_length = chunk_length
        self.all_bold = []     # list of BOLD arrays, each (node_size, num_TRs)
        self.all_SC = []       # list of processed SC matrices, each (node_size, node_size)
        self.bold_chunks = []  # list of dicts: {'subject': int, 'bold': array (node_size, chunk_length)}

        self.load_data()
        self.split_into_chunks()

    def load_data(self):
        """
        Read both .mat files and (re)fill all_bold / all_SC with one entry per subject.

        SC is symmetrised, log1p-transformed and scaled to unit Frobenius norm.
        Calling this again replaces previously loaded data instead of appending to it.
        """
        fmri_mat = scipy.io.loadmat(self.fmri_filename)
        dti_mat = scipy.io.loadmat(self.dti_filename)
        bold_data = fmri_mat["BOLD_timeseries_HCP"]    # (num_subjects, 1) cell array
        dti_data = dti_mat["DTI_fibers_HCP"]           # (num_subjects, 1) cell array
        num_subjects = bold_data.shape[0]

        self.all_bold = []
        self.all_SC = []
        for subject in range(num_subjects):
            bold_subject = bold_data[subject, 0]  # (node_size, num_TRs)
            dti_subject = dti_data[subject, 0]    # (node_size, node_size)
            self.all_bold.append(bold_subject)

            # Process SC: symmetrise, log-transform, normalise
            SC = np.log1p(0.5 * (dti_subject.T + dti_subject))
            norm = np.linalg.norm(SC)
            # An all-zero connectome would otherwise turn into NaNs via 0 / 0
            self.all_SC.append(SC / norm if norm > 0 else SC)
        print(f"[DataLoader] Loaded {num_subjects} subjects.")

    def split_into_chunks(self):
        """Cut every subject's BOLD series into non-overlapping chunks of chunk_length TRs."""
        self.bold_chunks = []
        for subject, bold_subject in enumerate(self.all_bold):
            num_chunks = bold_subject.shape[1] // self.chunk_length
            for i in range(num_chunks):
                start = i * self.chunk_length
                chunk = bold_subject[:, start:start + self.chunk_length]
                self.bold_chunks.append({"subject": subject, "bold": chunk})
        print(f"[DataLoader] Created {len(self.bold_chunks)} chunks (chunk length = {self.chunk_length}).")

    def sample_minibatch(self, batch_size: int):
        """
        Sample `batch_size` distinct chunks (without replacement), plus each chunk's subject coupling matrix.

        The coupling matrix is SC - diag(rowsum(SC)), i.e. the NEGATIVE graph Laplacian
        (A - D). WholeBrainModel adds g * (L @ E) to the excitatory input, so with this
        sign each node is pulled towards its neighbours' activity (diffusive coupling).
        The opposite sign (D - A) would instead make the coupling anti-diffusive.
        NOTE(review): this mirrors the A - D convention of the Griffiths et al. 2022 RWW code;
        confirm against the reference implementation.

        Returns:
            batched_bold: float32 tensor (node_size, chunk_length, batch_size)
            batched_SC: float32 tensor (batch_size, node_size, node_size) of coupling matrices
            sampled: the sampled chunk dicts ({'subject', 'bold'})

        Raises:
            ValueError: (from random.sample) if batch_size exceeds the number of chunks.
        """
        sampled = random.sample(self.bold_chunks, batch_size)
        batched_bold = []
        batched_SC = []
        for batch_element in sampled:
            batched_bold.append(batch_element["bold"])  # (node_size, chunk_length)
            SC = self.all_SC[batch_element["subject"]]

            # NOTE: Test with non-Laplacian SC
            D = np.diag(np.sum(SC, axis=1))
            L = SC - D
            batched_SC.append(L)

        # Stack BOLD along a trailing batch axis: (node_size, chunk_length, batch_size)
        batched_bold = torch.tensor(np.stack(batched_bold, axis=-1), dtype=torch.float32)

        # Stack coupling matrices along a leading batch axis: (batch_size, node_size, node_size)
        batched_SC = torch.tensor(np.stack(batched_SC, axis=0), dtype=torch.float32)

        return batched_bold, batched_SC, sampled

    def plot_laplacian(self, subject: int, L: np.ndarray):
        """Show a coupling / Laplacian matrix as a heatmap."""
        plt.figure(figsize=(6,5))
        sns.heatmap(L, cmap='viridis')
        plt.title(f"Laplacian Heatmap (Subject {subject})")
        plt.xlabel("ROI index")
        plt.ylabel("ROI index")
        plt.show()



### Costs

In [4]:
class Costs:
    """
    Plotting helpers and the functional-connectivity (FC) based cost used to fit
    simulated BOLD to empirical BOLD.
    """

    def plot_time_series(self, time_series, title="Time Series"):
        """
        Plots batched BOLD time series, one figure per batch element.

        Only the first 6 nodes are drawn so each figure stays readable.

        time_series: torch.Tensor or np.ndarray, shape (N, T, B)
        """
        if isinstance(time_series, torch.Tensor):
            time_series = time_series.detach().cpu().numpy()
        N, T, B = time_series.shape
        for batch in range(B):
            fig, ax = plt.subplots(figsize=(12, 6))
            for i in range(min(6, N)):
                ax.plot(np.arange(T), time_series[i, :, batch], label=f'Node {i}')
            ax.set_title(f'{title} - Batch Element {batch}')
            ax.set_xlabel('Time Step')
            ax.set_ylabel('BOLD signal')
            ax.legend()
            plt.show()

    def plot_fc(self, sim_fc, emp_fc):
        """
        Plots simulated and empirical FC matrices side by side as heatmaps on a shared
        [-1, 1] colour scale.

        sim_fc, emp_fc: np.ndarray of shape (N, N)
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
        for ax, fc, label in ((ax1, sim_fc, "Simulated FC"), (ax2, emp_fc, "Empirical FC")):
            im = ax.imshow(fc, vmin=-1, vmax=1, cmap='coolwarm')
            ax.set_title(label)
            fig.colorbar(im, ax=ax)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _pearson(x, y, dim):
        """
        Pearson correlation of x and y along `dim`.

        The 1e-8 in the denominator makes a constant (zero-variance) input give 0 instead of NaN.
        """
        x_c = x - x.mean(dim=dim, keepdim=True)
        y_c = y - y.mean(dim=dim, keepdim=True)
        num = (x_c * y_c).sum(dim=dim)
        den = torch.sqrt((x_c ** 2).sum(dim=dim)) * torch.sqrt((y_c ** 2).sum(dim=dim)) + 1e-8
        return num / den

    @staticmethod
    def _functional_connectivity(bold):
        """
        Pearson FC matrix of one batch element.

        bold: torch.Tensor (N, T). Returns torch.Tensor (N, N).
        """
        centered = bold - torch.mean(bold, dim=1, keepdim=True)
        cov = centered @ centered.t()                # (N, N)
        std = torch.sqrt(torch.diag(cov) + 1e-8)
        return cov / (std.unsqueeze(1) * std.unsqueeze(0) + 1e-8)

    def compare_bold(self, simulated_bold, empirical_bold, plot=True, verbose=True):
        """
        Compare simulated and empirical BOLD and return an FC-correlation loss.

        Three metrics are computed. Only the last one feeds the returned loss:
          * RMSE between the raw time series,
          * per-ROI Pearson correlation over time (averaged over nodes, then batch),
          * Pearson correlation between the strict lower triangles of the simulated and
            empirical FC matrices, averaged over batch elements.

        Parameters:
            simulated_bold: torch.Tensor (or array-like) of shape (N, T, B)
            empirical_bold: torch.Tensor (or array-like) of shape (N, T, B)
            plot: if True, show the simulated / empirical FC heatmaps of the LAST batch element
            verbose: if True, print the input shape and the three metrics

        Returns:
            correlation_loss: torch scalar, -log(0.5 + 0.5 * global_corr + 1e-8).
                It is ~0 for perfectly correlated FC and grows as the correlation drops.

        Raises:
            AssertionError: if the two inputs do not have the same shape.
        """
        if not isinstance(simulated_bold, torch.Tensor):
            simulated_bold = torch.tensor(simulated_bold, dtype=torch.float32)
        if not isinstance(empirical_bold, torch.Tensor):
            empirical_bold = torch.tensor(empirical_bold, dtype=torch.float32)

        assert simulated_bold.shape == empirical_bold.shape, f"Simulated and Empirical BOLD time series must have the same dimensions. Found EMP: {empirical_bold.shape}, SIM: {simulated_bold.shape}"
        N, T, B = simulated_bold.shape
        if verbose:
            print(f"Simulated BOLD shape: ({simulated_bold.shape})")

        rmse = torch.sqrt(torch.mean((simulated_bold - empirical_bold) ** 2))

        # Per-ROI temporal correlation for every (node, batch) pair at once -> (N, B)
        average_rois_correlation = self._pearson(simulated_bold, empirical_bold, dim=1).mean(dim=0).mean()

        # Strict lower triangle: each ROI pair once, excluding the all-ones diagonal
        mask = torch.tril(torch.ones(N, N, device=simulated_bold.device), diagonal=-1).bool()

        global_corrs = []
        for b in range(B):
            FC_sim = self._functional_connectivity(simulated_bold[:, :, b])
            FC_emp = self._functional_connectivity(empirical_bold[:, :, b])
            global_corrs.append(self._pearson(FC_sim[mask], FC_emp[mask], dim=0))
        global_corr = torch.mean(torch.stack(global_corrs))

        if plot:
            # Only the last batch element's FC matrices are shown
            self.plot_fc(FC_sim.detach().cpu().numpy(), FC_emp.detach().cpu().numpy())

        if verbose:
            print(f"RMSE between BOLD time series: {rmse:.4f}")
            print(f"Average per-ROI Pearson correlation: {average_rois_correlation:.4f}")
            print(f"FC Pearson's correlation: {global_corr:.4f}")

        # Map correlation in [-1, 1] to (0, 1] before the log; 1e-8 avoids log(0) at corr = -1
        correlation_loss = -torch.log(0.5 + 0.5 * global_corr + 1e-8)
        return correlation_loss
        


### Model

In [5]:
class ModelParams:
    """
    Default constants for WholeBrainModel; starting values taken from Griffiths et al. 2022.

    Every constant is stored as an instance attribute and can also be read by key,
    e.g. params["tau_E"] is the same as params.tau_E.
    """

    def __init__(self):
        defaults = {
            # --- Dynamic mean-field (DMF) ---
            "W_E": 1.0,                 # scale of external input to the excitatory population
            "W_I": 0.7,                 # scale of external input to the inhibitory population
            "I_0": 0.32,                # constant external input
            "tau_E": 100.0,             # decay time (ms) of excitatory synapses
            "tau_I": 10.0,              # decay time of inhibitory synapses
            "gamma_E": 0.641 / 1000.0,  # kinetic parameter of the excitatory dynamics
            # current -> firing-rate sigmoid parameters (excitatory, then inhibitory)
            "aE": 310.0,
            "bE": 125.0,
            "dE": 0.16,
            "aI": 615.0,
            "bI": 177.0,
            "dI": 0.087,
            # coupling strengths
            "g": 20.0,                  # global (long-range) coupling
            "g_EE": 0.1,                # local excitatory self-feedback
            "g_IE": 0.1,                # inhibitory -> excitatory
            "g_EI": 0.1,                # excitatory -> inhibitory
            # --- Balloon (haemodynamic) model ---
            "tau_s": 0.65,
            "tau_f": 0.41,
            "tau_0": 0.98,
            "alpha": 0.32,
            "rho": 0.34,
            "k1": 2.38,
            "k2": 2.0,
            "k3": 0.48,
            "V": 0.02,                  # V0 in the BOLD equation
            "E0": 0.34,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

    def __getitem__(self, key):
        """Dictionary-style access: params[key] is getattr(params, key) (AttributeError if missing)."""
        return getattr(self, key)

In [19]:

# Whole Brain Model integrating both DMF and Balloon
class WholeBrainModel(nn.Module):
    """
    Dynamic mean-field (DMF) neural model coupled to a Balloon haemodynamic model,
    integrated with forward Euler steps. Every model constant is a trainable nn.Parameter.

    State layout along dim 1 (state_size = 6): [E, I, x, f, v, q].
    """

    def __init__(self, params: "ModelParams", input_size: int, node_size: int, batch_size: int,
                 step_size: float, tr: float):
        """
        Parameters:
            params: ModelParams container (or any object supporting params[name]) holding
                    W_E, W_I, I_0, tau_E, ..., E0
            input_size: number of noise channels per integration step; forward() reads
                        channels 0-5 (one per state variable), so this must be >= 6
            node_size: Number of nodes (ROIs)
            batch_size: Batch size (number of parallel simulations)
            step_size: Integration time step (e.g. 0.05)
            tr: TR duration (e.g. 0.75); hidden_size = tr / step_size integration steps per TR
        """
        super().__init__()
        self.node_size = node_size
        self.batch_size = batch_size
        self.step_size = step_size
        self.tr = tr
        # Floor of the ratio, with a tiny tolerance: float division can land just below the
        # intended integer (0.3 / 0.1 == 2.9999999999999996), which plain int() would truncate.
        self.hidden_size = int(tr / step_size + 1e-9)  # number of integration steps per TR
        self.input_size = input_size  # noise input dimension

        self.state_size = 6  # [E, I, x, f, v, q]

        dmf_names = ("W_E", "W_I", "I_0", "tau_E", "tau_I", "gamma_E",
                     "aE", "bE", "dE", "aI", "bI", "dI",
                     "g", "g_EE", "g_IE", "g_EI")
        balloon_names = ("tau_s", "tau_f", "tau_0", "alpha", "rho",
                         "k1", "k2", "k3", "V", "E0")
        # Assigning an nn.Parameter as a module attribute registers it; registration order
        # (and hence the order of model.parameters()) follows this tuple.
        for name in dmf_names + balloon_names:
            setattr(self, name, nn.Parameter(torch.tensor(params[name], dtype=torch.float32)))

        print("[DEBUG] Model initialized with state_size:", self.state_size, "hidden_size:", self.hidden_size)

    def generate_initial_states(self):
        """
        Random initial state (same scheme as the Griffiths et al. code): every variable is drawn
        from U(0, 0.45), and the Balloon variables f, v, q are offset by +1.

        At f = v = q = 1 the BOLD equation in forward() evaluates to 0.

        Returns:
            initial_state: float32 tensor of shape (node_size, state_size, batch_size)
        """
        initial_state = 0.45 * np.random.uniform(0, 1, (self.node_size, self.state_size, self.batch_size))
        baseline = np.array([0, 0, 0, 1.0, 1.0, 1.0]).reshape(1, self.state_size, 1)
        return torch.tensor(initial_state + baseline, dtype=torch.float32)

    def h_tf(self, a, b, d, current):
        """
        Current-to-firing-rate transfer function of the DMF model:

            h(I) = (a*I - b) / (1 - exp(-d * (a*I - b)))

        It is evaluated as |a*I - b| / |1 - exp(-d*(a*I - b))|. For d > 0 the numerator and
        denominator share a sign, so the absolute values do not change the result. The small
        constants (1e-5, 1e-5*d, 1e-8) are not noise: they keep the ratio finite at the removable
        singularity a*I == b, where the expression evaluates to ~1/d (its limit there).
        """
        num = 1e-5 + torch.abs(a * current - b)
        den = 1e-5 * d + torch.abs(1.000 - torch.exp(-d * (a * current - b)))
        return torch.divide(num, den + 1e-8)

    def forward(self, hx: torch.Tensor, noise_in: torch.Tensor, noise_out: torch.Tensor, delays: torch.Tensor, batched_laplacian: torch.Tensor):
        """
        Integrate the model over one TR (hidden_size Euler steps) and emit one BOLD sample.

        Parameters:
            hx: current state, shape (node_size, 6, batch_size), layout [E, I, x, f, v, q]
            noise_in: additive state noise, shape (node_size, hidden_size, input_size, batch_size);
                      channels 0..5 are added to E, I, x, f, v, q respectively
            noise_out: additive BOLD noise, shape (node_size, batch_size)
            delays: buffer of past E values, shape (node_size, delays_max, batch_size), newest first
            batched_laplacian: coupling matrices, shape (batch_size, node_size, node_size).
                      The network input is g * (batched_laplacian @ E), applied with the sign given.
                      NOTE(review): diffusive coupling requires A - D (the negative graph
                      Laplacian) here, as in Griffiths et al. 2022.

        Returns:
            state: updated state (node_size, 6, batch_size)
            bold: simulated BOLD signal (node_size, batch_size)
            delays: updated delay buffer (same shape as the input buffer)
        """
        dt = self.step_size
        state = hx

        # Loop over hidden integration steps (one TR)
        for i in range(self.hidden_size):
            noise_step = noise_in[:, i, :, :]  # (node_size, input_size, batch_size)

            # --- DMF update ---
            E = state[:, 0:1, :]  # (node_size, 1, batch_size)
            I = state[:, 1:2, :]  # (node_size, 1, batch_size)

            # Network input per batch element: bmm needs (batch, node, 1), so move the batch
            # axis to the front and back afterwards.
            # NOTE: the delay buffer is updated below but not read yet; coupling uses the
            # instantaneous E.
            E_b = E.permute(2, 0, 1)                                                  # (batch, node, 1)
            connectivity_effect = torch.bmm(batched_laplacian, E_b).permute(1, 2, 0)  # (node, 1, batch)

            I_E = self.W_E * self.I_0 + self.g_EE * E + self.g * connectivity_effect - self.g_IE * I
            I_I = self.W_I * self.I_0 + self.g_EI * E - I

            R_E = self.h_tf(self.aE, self.bE, self.dE, I_E)
            R_I = self.h_tf(self.aI, self.bI, self.dI, I_I)

            dE = -E / self.tau_E + (1 - E) * self.gamma_E * R_E
            dI = -I / self.tau_I + R_I

            # Euler step plus additive noise (channel 0 -> E, channel 1 -> I)
            E_new = E + dt * dE + noise_step[:, 0:1, :]
            I_new = I + dt * dI + noise_step[:, 1:2, :]

            # --- Balloon update (driven by the freshly updated E) ---
            x = state[:, 2:3, :]
            f = state[:, 3:4, :]
            v = state[:, 4:5, :]
            q = state[:, 5:6, :]

            dx = E_new - torch.reciprocal(self.tau_s) * x - torch.reciprocal(self.tau_f) * (f - 1)
            df = x
            dv = (f - torch.pow(v, torch.reciprocal(self.alpha))) * torch.reciprocal(self.tau_0)
            dq = (f * (1 - torch.pow(1 - self.rho, torch.reciprocal(f))) * torch.reciprocal(self.rho) \
                   - q * torch.pow(v, torch.reciprocal(self.alpha)) * torch.reciprocal(v+1e-8)) \
                     * torch.reciprocal(self.tau_0)

            x_new = x + dt * dx + noise_step[:, 2:3, :]
            f_new = f + dt * df + noise_step[:, 3:4, :]
            v_new = v + dt * dv + noise_step[:, 4:5, :]
            q_new = q + dt * dq + noise_step[:, 5:6, :]

            state = torch.cat([E_new, I_new, x_new, f_new, v_new, q_new], dim=1)

            # Push the newest E onto the front of the buffer and drop the oldest entry
            delays = torch.cat([E_new, delays[:, :-1, :]], dim=1)

        # BOLD from the state at the end of the TR. Reading it from `state` rather than
        # loop locals keeps this valid even when hidden_size == 0.
        v_end = state[:, 4:5, :]
        q_end = state[:, 5:6, :]
        BOLD = self.V * (self.k1 * (1 - q_end)
                         + (self.k2 * (1 - q_end * torch.reciprocal(v_end)))
                         + (self.k3 * (1 - v_end)))
        BOLD = BOLD.squeeze(1) + noise_out  # (node_size, batch_size)

        Utils.check_for_nans(BOLD, "BOLD time series")

        return state, BOLD, delays


### ModelFitting

In [7]:
# Anish Kochhar, Imperial College London, March 2025

from tqdm import tqdm

class ModelFitting:
    """
    Fits a WholeBrainModel to empirical BOLD. Each epoch simulates the whole empirical
    chunk TR by TR and then takes one Adam step on the supplied cost.
    """

    def __init__(self, model: "WholeBrainModel", empirical_bold: torch.Tensor, num_epochs: int, lr: float,
                 cost_function, log_state: bool = False, noise_std: float = 0.02):
        """
        Parameters:
            model: WholeBrainModel instance (node_size, hidden_size, input_size and batch_size are read from it)
            empirical_bold: Empirical BOLD as a float32 tensor of shape (node_size, num_TRs, batch_size)
            num_epochs: Number of training epochs (one optimiser step each)
            lr: Adam learning rate
            cost_function: callable(simulated_bold, empirical_bold) -> scalar loss tensor
                           (e.g. Costs().compare_bold)
            log_state: If True, records the mean of each state variable after every TR
            noise_std: standard deviation of the Gaussian state noise and BOLD noise
        """
        self.model = model
        self.empirical_bold = empirical_bold
        self.num_epochs = num_epochs
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.cost_function = cost_function
        self.log_state = log_state
        self.noise_std = noise_std

        self.logs = { "state_means": [], "losses": [] }

    def train(self, initial_state: torch.Tensor, initial_delays: torch.Tensor, batched_laplacian: torch.Tensor):
        """
        Train the model by simulating the empirical chunk TR by TR, once per epoch.

        The final state / delay buffer of one epoch is the starting point of the next.

        Parameters:
            initial_state: Initial state tensor of shape (node_size, 6, batch_size)
            initial_delays: Initial delay buffer of shape (node_size, delays_max, batch_size)
            batched_laplacian: Batched coupling tensor of shape (batch_size, node_size, node_size)

        Returns:
            loss_history: list of per-epoch loss values (empty when num_epochs == 0)
            simulated_bold: np.ndarray (node_size, num_TRs, batch_size) from the last epoch,
                            or None when num_epochs == 0
        """
        num_TRs = self.empirical_bold.shape[1]
        loss_history = []
        last_simulated_bold = None

        state = initial_state
        delays = initial_delays

        for epoch in range(self.num_epochs):
            self.optimizer.zero_grad()
            simulated_bold_chunks = []
            epoch_state_log = []

            for _ in range(num_TRs):
                # noise_in: (node_size, hidden_size, input_size, batch_size); the model takes slice [:, step]
                noise_in = torch.randn(self.model.node_size, self.model.hidden_size,
                                       self.model.input_size, self.model.batch_size) * self.noise_std
                # noise_out: (node_size, batch_size), added to the BOLD sample of this TR
                noise_out = torch.randn(self.model.node_size, self.model.batch_size) * self.noise_std

                state, bold_chunk, delays = self.model(state, noise_in, noise_out, delays, batched_laplacian)
                simulated_bold_chunks.append(bold_chunk)

                if self.log_state:
                    epoch_state_log.append(torch.mean(state, dim=(0, 2)).detach().cpu().numpy())

            # Stack TR samples into a time series: (node_size, num_TRs, batch_size)
            simulated_bold_epoch = torch.stack(simulated_bold_chunks, dim=1)

            # Carry the state into the next epoch but cut the autograd graph, so the next
            # backward() does not reach into this epoch's computations.
            state = state.detach()
            delays = delays.detach()

            loss = self.cost_function(simulated_bold_epoch, self.empirical_bold)
            loss.backward()
            self.optimizer.step()
            epoch_loss = loss.item()
            loss_history.append(epoch_loss)
            print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {epoch_loss:.6f}")

            if self.log_state:
                self.logs["state_means"].append(np.array(epoch_state_log))
            self.logs["losses"].append(epoch_loss)
            last_simulated_bold = simulated_bold_epoch

        if last_simulated_bold is None:
            return loss_history, None
        # Return time series from last epoch
        return loss_history, last_simulated_bold.detach().cpu().numpy()


In [9]:

# Input .mat files, resolved relative to the notebook's working directory
fmri_filename = "../HCP Data/BOLD Timeseries HCP.mat"   # per-subject BOLD time series
dti_filename = "../HCP Data/DTI Fibers HCP.mat"         # per-subject DTI connectivity

# Load every subject and cut each BOLD series into 50-TR chunks
data_loader = DataLoader(
    fmri_filename=fmri_filename,
    dti_filename=dti_filename,
    chunk_length=50,
)

[DataLoader] Loaded 100 subjects.
[DataLoader] Created 2300 chunks (chunk length = 50).


In [10]:
# Seed every RNG used downstream so Restart & Run All is reproducible:
# random.sample (chunk selection), np.random (initial model state), torch.randn (simulation noise)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 4  # mini-batch size
empirical_bold_batch, batched_SC, sampled_entries = data_loader.sample_minibatch(batch_size)
# empirical_bold_batch: (node_size, chunk_length, batch_size); batched_SC: (batch_size, node_size, node_size)
print(f"[Main] Empirical BOLD batch shape: {empirical_bold_batch.shape}")
print(f"[Main] Batched SC shape: {batched_SC.shape}")

[Main] Empirical BOLD batch shape: torch.Size([100, 50, 4])
[Main] Batched SC shape: torch.Size([4, 100, 100])


In [11]:
## Model Settings
node_size = batched_SC.shape[1]              # number of ROIs
batch_size = empirical_bold_batch.shape[2]   # derived from the sampled minibatch so the two cannot disagree
step_size = 0.05                             # Euler integration time step
tr = 0.75                                    # TR duration; the model takes tr / step_size steps per TR
input_size = 6                               # noise channels per step (one per state variable E, I, x, f, v, q)
delays_max = 100                             # length of the E delay buffer, in integration steps

In [20]:
params = ModelParams()
model = WholeBrainModel(
    params=params,
    input_size=input_size,
    node_size=node_size,
    batch_size=batch_size,
    step_size=step_size,
    tr=tr,
)
costs = Costs()

# Starting point for training: a random initial state and an all-zero delay buffer
initial_state = model.generate_initial_states()
initial_delays = torch.zeros(node_size, delays_max, batch_size)


[DEBUG] Model initialized with state_size: 6 hidden_size: 15


In [21]:
# NOTE: compare_bold is passed as-is, so it runs with its defaults (plot=True, verbose=True):
# it draws the FC heatmaps and prints the metrics after every epoch.
trainer = ModelFitting(
    model,
    empirical_bold=empirical_bold_batch,
    num_epochs=10,
    lr=0.001,
    cost_function=costs.compare_bold,
    log_state=True,
)

In [None]:
# One optimiser step per epoch; each epoch simulates the whole empirical chunk TR by TR
loss_history, simulated_bold = trainer.train(
    initial_state=initial_state,
    initial_delays=initial_delays,
    batched_laplacian=batched_SC,
)

print("[Main] Training complete")
print("[Main] Simulated BOLD shape:", simulated_bold.shape)
