In [1]:
# from whole_brain_model import WholeBrainModel, ModelParams
# from model_fitting import ModelFitting
# from costs import Costs
from wbm.data_loader import BOLDDataLoader
from wbm.plotter import Plotter
from wbm.utils import load_encoder, load_discriminator, DEVICE
from discriminator.contrastive_discriminator import GraphEncoder, Discriminator
from discriminator.graph_builder import GraphBuilder

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Batch

import os, random
from glob import glob
import matplotlib.pyplot as plt

### Costs

In [3]:
class Costs:

    @staticmethod
    def ledoit_wolf_shrinkage_torch(X: torch.Tensor, shrinkage_value: float = 0.1) -> torch.Tensor:
        """
        Compute Ledoit-Wolf shrunk covariance matrix in PyTorch (with manual shrinkage)
        
        Parameters:
            X: torch.Tensor of shape (T, N) or (B, T, N)
            shrinkage_value: float in [0, 1], amount of shrinkage

        Returns:
            shrunk_cov: torch.Tensor of shape (N, N) or (B, N, N)
        """
        X = X.to(DEVICE)
        if X.dim() == 2:
            X = X.unsqueeze(0)  # (1, T, N)
        B, T, N = X.shape

        X_mean = X.mean(dim=1, keepdim=True)
        X_centered = X - X_mean  # (B, T, N)
        empirical_cov = torch.matmul(X_centered.transpose(1, 2), X_centered) / (T - 1)  # (B, N, N)

        # Target: identity scaled by average variance
        avg_var = torch.mean(torch.diagonal(empirical_cov, dim1=1, dim2=2), dim=1)  # (B,)
        target = torch.stack([torch.eye(N, device=X.device) * avg_var[i] for i in range(B)], dim=0)  # (B, N, N)

        shrunk_cov = (1 - shrinkage_value) * empirical_cov + shrinkage_value * target
        return shrunk_cov.squeeze(0) if shrunk_cov.shape[0] == 1 else shrunk_cov


    @staticmethod
    def compute(simulated_bold, empirical_bold):
        """
        Compare two BOLD time series and calcuate Pearson correlation between FC matrices

        Parameters:
            simulated_bold: torch.Tensor shape (N, T, B)
            empirical_bold: torch.Tensor shape (N, T, B)

        Returns:
            loss: torch scalar, Pearson's correlation loss between FC matrices
                calculated as -log(0.5 + 0.5 * global_corr)
            root mean squared error
            average node-wise Pearson correlation
            average functional connectivity Pearson correlation
        """
        if not isinstance(simulated_bold, torch.Tensor):
            simulated_bold = torch.tensor(simulated_bold, dtype=torch.float32)
        if not isinstance(empirical_bold, torch.Tensor):
            empirical_bold = torch.tensor(empirical_bold, dtype=torch.float32)
        
        assert simulated_bold.shape == empirical_bold.shape, f"Simulated and Empirical BOLD time series must have the same dimensions. Found EMP: {empirical_bold.shape}, SIM: {simulated_bold.shape}"
        # print(f"Simulated BOLD shape: ({simulated_bold.shape})")
        N, T, B = simulated_bold.shape

        rmse = torch.sqrt(torch.mean((simulated_bold - empirical_bold) ** 2))

        # Compute Pearon's correlation between per node
        rois_correlation = []
        for b in range(B):
            sim_batch = simulated_bold[:, :, b]
            emp_batch = empirical_bold[:, :, b]

            # Zero mean
            s_centered = sim_batch - torch.mean(sim_batch)
            e_centered = emp_batch - torch.mean(emp_batch)
            
            dot_product = (s_centered * e_centered).sum(dim=1)
            product = (s_centered.norm(dim=1) * e_centered.norm(dim=1) + 1e-8)
            rois_correlation.append((dot_product / product).mean().detach().item())

        average_rois_correlation = float(np.mean(rois_correlation))

        global_corrs = []
        
        for b in range(B):
            sim_b = simulated_bold[:, :, b].permute(1, 0) # (T, N)
            emp_b = empirical_bold[:, :, b].permute(1, 0) # (T, N)

            if torch.allclose(sim_b.std(dim=0), torch.zeros_like(sim_b[0]), atol=1e-5) or \
            torch.allclose(emp_b.std(dim=0), torch.zeros_like(emp_b[0]), atol=1e-5):
                print(f"[WARNING] Batch {b}: constant or near-zero signal in sim or emp BOLD.")
                global_corr_b = torch.tensor(0.0, device=sim_b.device, requires_grad=True)
                global_corrs.append(global_corr_b)
                continue


            # MARK: Ledoit-Wolf shrunk covariance estimation
            # cov_sim = Costs.ledoit_wolf_shrinkage_torch(sim_b, shrinkage_value=0.1)  # (N, N)
            # cov_emp = Costs.ledoit_wolf_shrinkage_torch(emp_b, shrinkage_value=0.1)
            # Compute global FC matrices
            sim_n = sim_b - torch.mean(sim_b, dim=1, keepdim=True)
            emp_n = emp_b - torch.mean(emp_b, dim=1, keepdim=True)
            cov_sim = sim_n @ sim_n.t()  # (N, N)
            cov_emp = emp_n @ emp_n.t()  # (N, N)
            std_sim = torch.sqrt(torch.diag(cov_sim) + 1e-8)
            std_emp = torch.sqrt(torch.diag(cov_emp) + 1e-8)
            FC_sim = cov_sim / (std_sim.unsqueeze(1) * std_sim.unsqueeze(0) + 1e-8)
            FC_emp = cov_emp / (std_emp.unsqueeze(1) * std_emp.unsqueeze(0) + 1e-8)
            
            # Extract lower triangular parts (excluding the diagonal)
            mask = torch.tril(torch.ones_like(FC_sim), diagonal=-1).bool()

            sim_vec = FC_sim[mask]
            emp_vec = FC_emp[mask]
            sim_vec = sim_vec - torch.mean(sim_vec)
            emp_vec = emp_vec - torch.mean(emp_vec)

            dot = torch.sum(sim_vec * emp_vec)
            norm_sim = torch.sqrt(torch.sum(sim_vec ** 2)).clamp(min=1e-6)
            norm_emp = torch.sqrt(torch.sum(emp_vec ** 2)).clamp(min=1e-6)
            l2_product = norm_sim * norm_emp

            global_corr_b = dot / l2_product

            # Skip NaNs or Infs
            if not torch.isfinite(global_corr_b):
                print(f"[WARNING] NaN or Inf in global_corr_b at batch {b}")
                global_corr_b = torch.tensor(0.0, device=sim_vec.device, requires_grad=True)

            global_corrs.append(global_corr_b)

        global_corr = torch.mean(torch.stack(global_corrs)).clamp(min=1e-8)
        global_corr = torch.nan_to_num(global_corr, nan=0.0, posinf=0.0, neginf=0.0) # NAN -> 0
        global_corr = torch.clamp(global_corr, min=-0.999, max=0.999) # Clamp to (-1, 1)

        correlation_loss = -torch.log(0.5 + 0.5 * global_corr + 1e-8)

        return {
            "loss": correlation_loss,
            "rmse": rmse.detach().cpu().numpy(),
            "average_rois_correlation": average_rois_correlation,
            "average_fc_correlation": global_corr.detach().cpu().numpy()
        }
        

### Discriminator

In [4]:
class DiscriminatorHook:
    """
    Wraps the frozen discriminator + GraphBuilder. Encoder is frozen, Discriminator updates each epoch
    """
    def __init__(self, discriminator: Discriminator, builder: GraphBuilder, lambda_start: float = 0.1, lambda_end: float = 0.1,
                 warmup_epochs: int = 0, finetune_lr: float = 1e-4, device: str = DEVICE):
        self.discriminator = discriminator.to(device).eval()
        self.builder = builder
        self.lambda_start = lambda_start
        self.lambda_end = lambda_end
        self.warmup = max(1, warmup_epochs)
        self.device = device
        self.bce_target = torch.ones(1, device=device)
        self.bce = nn.BCELoss()
        self.finetune_optimizer = optim.Adam(self.discriminator.parameters(), lr=finetune_lr)

        os.makedirs("fooling_examples", exist_ok=True)

    def _check_prob(self, t, name):
        bad = (t < 0) | (t > 1) | torch.isnan(t)
        if bad.any():
            idx = bad.nonzero(as_tuple=False)[:10]  # first few offenders
            print(f"[DEBUG_BCE] {name} out of range at {idx}")
            print(f"  min={t.min().item():.4g}  max={t.max().item():.4g}")


    def _lambda_now(self, epoch: int):
        if epoch >= self.warmup:
            return self.lambda_end
        alpha = epoch / self.warmup
        return self.lambda_start * (1 - alpha) + self.lambda_end * alpha
    
    @torch.no_grad()
    def __call__(self, bold_chunk: torch.Tensor, sc_matrix: torch.Tensor, epoch: int):
        """ Returns lambda-scaled BCE loss """
        data = self.builder.build_graph(bold_chunk, sc_matrix, label=1.0)
        score = self.discriminator(data).view(-1) # [0, 1]
        self._check_prob(score, "forward")
        loss = self.bce(score, self.bce_target)
        return loss * self._lambda_now(epoch)
    
    def finetune_discriminator(self, steps: int, batch_size: int, bold_batch: torch.Tensor, sc_matrix: torch.Tensor):
        """ Finetunes discriminator network by sampling `batch_size` elements from fooling_examples/, along with latest empirical bold batch """
        fooled_paths = glob("fooling_examples/*.pt")[-100:] # latest 100
        if not fooled_paths:
            return 0.0
        self.discriminator.train()
        total_ft_loss = 0.0
        random.shuffle(fooled_paths)

        for step in range(min(steps, (len(fooled_paths) - 1) // batch_size + 1)):
            batch_paths = fooled_paths[step * batch_size : (step + 1) * batch_size]
            if not batch_paths: break

            fake_graphs = [self.builder.build_graph(torch.load(p)['bold'].to(self.device),
                                                    torch.load(p)['sc'].to(self.device), label=0.0)
                                                    for p in batch_paths]
            
            total_real = bold_batch.shape[2]
            count_real = min(len(fake_graphs), total_real)
            indices    = random.sample(range(total_real), count_real)

            real_graphs = [self.builder.build_graph(bold_batch[:, :, i],
                                                    sc_matrix[i], label=1.0)
                                                    for i in indices]
            
            batch = Batch.from_data_list(fake_graphs + real_graphs).to(DEVICE)
            target = torch.cat([torch.zeros(len(fake_graphs)), torch.ones(len(real_graphs))]).to(DEVICE)

            predictions = self.discriminator(batch)
            self._check_prob(predictions, "finetune")
            loss_d = self.bce(predictions, target)            

            self.finetune_optimizer.zero_grad(set_to_none=True)
            loss_d.backward()
            self.finetune_optimizer.step()

            total_ft_loss += loss_d.item()

        self.discriminator.eval()
        return total_ft_loss / max(1, steps)

### Model

In [5]:
class ModelParams:
    def __init__(self):
        ## DMF parameters
        #  starting states taken from Griffiths et al. 2022
        self.W_E        = 1.0               # Scale for external input to excitatory population
        self.W_I        = 0.7               # Scale for external input to inhibitory population
        self.I_0        = 0.382              # Constant external input
        self.tau_E      = 100.0             # Decay time (ms) for excitatory synapses
        self.tau_I      = 10.0              # Decay time for inhibitory synapses
        self.gamma_E    = 0.641 / 1000.0    # Kinetic parameter for excitatory dynamics
        self.gamma_I    = 1.0 / 1000.0      # Kinetic parameter for inhibitory dynamics
        self.sigma_E    = 0.02              # Std. of Gaussian noise for E
        self.sigma_I    = 0.02              # Std. of Gaussian noise for I
        self.sigma_BOLD = 0.0               # Std. of Gaussian noise for BOLD

        # Sigmoid parameters for conversion of current to firing rate:
        self.aE     = 310.0
        self.bE     = 125.0
        self.dE     = 0.16
        self.aI     = 615.0
        self.bI     = 177.0
        self.dI     = 0.087

        # Connectivity parameters
        self.g      = 40.0               # Global coupling (long-range)
        self.g_EE   = 2.5                # Local excitatory self-feedback
        self.g_IE   = 0.42               # Inhibitory-to-excitatory coupling
        self.g_EI   = 0.42               # Excitatory-to-inhibitory coupling

        ## Balloon (haemodynamic) parameters
        self.tau_s  = 0.65
        self.tau_f  = 0.41
        self.tau_0  = 0.98
        self.alpha  = 0.32
        self.rho    = 0.34
        self.k1     = 2.38
        self.k2     = 2.0
        self.k3     = 0.48
        self.V      = 0.02              # V0 in the BOLD equation
        self.E0     = 0.34

    def __getitem__(self, key):
        return getattr(self, key)

In [6]:

# Whole Brain Model integrating both DMF and Balloon
class WholeBrainModel(nn.Module):
    def __init__(self, params: ModelParams, input_size: int, node_size: int, batch_size: int, 
                 step_size: float, tr: float, delays_max: int):
        """
        Parameters:
            params: ModelParams container (attributes W_E, tau_E, gamma_E, ...)
            input_size: Number of inptu channels (e.g. noise channels) per integration step
            node_size: Number of nodes (ROIs)
            batch_size: Batch size (number of parallel simulations)
            step_size: Integration time steps (0.05s)
            tr: TR duation (0.75s); hidden_size = tr / step_size
            delays_max: Maximum size of delay buffer

        """
        super(WholeBrainModel, self).__init__()
        self.node_size = node_size
        self.batch_size = batch_size
        self.step_size = torch.tensor(step_size, dtype=torch.float32, device=DEVICE)
        self.tr = tr
        self.hidden_size = int(tr / step_size)  # number of integration steps per TR
        self.input_size = input_size  # noise input dimension
        self.delays_max = delays_max

        self.state_size = 6  # [E, I, x, f, v, q]

        # DMF parameters
        self.W_E        = (torch.tensor(params["W_E"], dtype=torch.float32, device=DEVICE))
        self.W_I        = (torch.tensor(params["W_I"], dtype=torch.float32, device=DEVICE))
        self.I_0        = (torch.tensor(params["I_0"], dtype=torch.float32, device=DEVICE))
        self.tau_E      = (torch.tensor(params["tau_E"], dtype=torch.float32, device=DEVICE))
        self.tau_I      = (torch.tensor(params["tau_I"], dtype=torch.float32, device=DEVICE))
        self.gamma_E    = (torch.tensor(params["gamma_E"], dtype=torch.float32, device=DEVICE))
        self.gamma_I    = (torch.tensor(params["gamma_I"], dtype=torch.float32, device=DEVICE))
        self.sigma_E    = (torch.tensor(params["sigma_E"], dtype=torch.float32, device=DEVICE))
        self.sigma_I    = (torch.tensor(params["sigma_I"], dtype=torch.float32, device=DEVICE))
        self.sigma_BOLD = (torch.tensor(params["sigma_BOLD"], dtype=torch.float32, device=DEVICE))
        
        self.aE      = (torch.tensor(params["aE"], dtype=torch.float32, device=DEVICE))
        self.bE      = (torch.tensor(params["bE"], dtype=torch.float32, device=DEVICE))
        self.dE      = (torch.tensor(params["dE"], dtype=torch.float32, device=DEVICE))
        self.aI      = (torch.tensor(params["aI"], dtype=torch.float32, device=DEVICE))
        self.bI      = (torch.tensor(params["bI"], dtype=torch.float32, device=DEVICE))
        self.dI      = (torch.tensor(params["dI"], dtype=torch.float32, device=DEVICE))
        self.g       = nn.Parameter(torch.tensor(params["g"], dtype=torch.float32, device=DEVICE))
        self.g_EE    = (torch.tensor(params["g_EE"], dtype=torch.float32, device=DEVICE))
        self.g_IE    = (torch.tensor(params["g_IE"], dtype=torch.float32, device=DEVICE))
        self.g_EI    = (torch.tensor(params["g_EI"], dtype=torch.float32, device=DEVICE))

        # Balloon (hemodynamic) parameters
        self.tau_s   = (torch.tensor(params["tau_s"], dtype=torch.float32, device=DEVICE))
        self.tau_f   = (torch.tensor(params["tau_f"], dtype=torch.float32, device=DEVICE))
        self.tau_0   = (torch.tensor(params["tau_0"], dtype=torch.float32, device=DEVICE))
        self.alpha   = (torch.tensor(params["alpha"], dtype=torch.float32, device=DEVICE))
        self.rho     = (torch.tensor(params["rho"], dtype=torch.float32, device=DEVICE))
        self.k1      = (torch.tensor(params["k1"], dtype=torch.float32, device=DEVICE))
        self.k2      = (torch.tensor(params["k2"], dtype=torch.float32, device=DEVICE))
        self.k3      = (torch.tensor(params["k3"], dtype=torch.float32, device=DEVICE))
        self.V       = (torch.tensor(params["V"], dtype=torch.float32, device=DEVICE))
        self.E0      = (torch.tensor(params["E0"], dtype=torch.float32, device=DEVICE))

        print(f"[DEBUG] Model initialized with {self.state_size} states, {len(params.__dict__)} learnable parameters, and {self.hidden_size} hidden step size")

    def generate_initial_states(self):
        """
        Generates the initial state for RWW (DMF) foward function. Uses same initial states as in the Griffiths et al. code

        Returns:
            initial_state: torch.Tensor of shape (node_size, input_size, batch_size)
        """
        initial_state = 0.1 * np.random.uniform(0, 1, (self.node_size, self.input_size, self.batch_size))
        baseline = np.array([0, 0, 0, 1.1, 1.0, 1.0]).reshape(1, self.input_size, 1)
        initial_state = initial_state + baseline
        # state_means = initial_state.mean(axis=(0, 2))
        # E_mean, I_mean, x_mean, f_mean, v_mean, q_mean = state_means
        # print(f"BASE | E={E_mean:.4f} I={I_mean:.4f} x={x_mean:.4f} f={f_mean:.4f} v={v_mean:.4f} q={q_mean:.4f}")
        return torch.tensor(initial_state, dtype=torch.float32, device=DEVICE)

    def firing_rate(self, a, b, d, current):
        """
        Transformation for firing rates of excitatory and inhibitory pools
        Takes variables a, b, current and convert into a linear equation (a * current - b) while adding a small
        amount of noise (1e-5) while dividing that term to an exponential of itself multiplied by constant d for
        the appropriate dimensions
        """
        x = a * current - b
        return x / (1.000 - torch.exp(-d * x) + 1e-8)
    

    def forward(self, hx: torch.Tensor, external_current: torch.Tensor, noise_in: torch.Tensor, noise_out: torch.Tensor, \
                delays: torch.Tensor, batched_laplacian: torch.Tensor, dist_matrices: torch.Tensor):
        """
        Simulate on TR chunk
        
        Parameters:
            hx: Current state input, shape (node_size, 6, batch_size)
            external_current: External current input for excitatory nodes (node_size, hidden_size, batch_size)
            noise_in: Noise tensor for state updates, shape (node_size, hidden_size, batch_size, input_size)
            noise_out: Noise tensor for BOLD output, shape (node_size, batch_size)
            delays: Delay buffer for E, shape (node_size, delays_max, batch_size)
            batched_laplacian: batched Laplacian tensor, shape (batch_size, node_size, node_size)
            dist_matrices: batched distance tensor, representing tract lengths for excitatory delays, shape (batch_size, node_size, node_size)
        
        Returns:
            state: Updated state (node_size, 6, batch_size)
            bold: Simulated BOLD signal (node_size, batch_size)
            delays: Updated delay buffer (node_size, delays_max, batch_size)
        """
        state = hx
        dt = self.step_size
        ones_tensor = torch.ones_like(dt, device=DEVICE)
        relu = torch.nn.ReLU() # ReLU module

        # Loop over hidden integration steps (one TR)
        for i in range(self.hidden_size):
            noise_step = noise_in[:, i, :, :] # (node_size, input_size, batch_size)
            input_current = external_current[:, i, :].unsqueeze(1) # (node_size, 1, batch_size)

            # --- DMF update ---
            E = state[:, 0:1, :] # (node_size, 1, batch_size)
            I = state[:, 1:2, :] # (node_size, 1, batch_size)

            # Delayed excitatory input
            # Compute delay indices in integration steps
            speed = 1.5 * ones_tensor # m/s
            delay_seconds = dist_matrices * 0.001 / speed
            delay_steps = (delay_seconds / self.step_size).floor().long().clamp(0, self.delays_max - 1) # (batch_size, node_size, node_size)

            # Gather from delay buffer
            hE = delays.permute(2, 1, 0) # (batch_size, delays_max, node_size)
            E_delayed = hE.gather(dim=1, index=delay_steps) # gather along delays axis, (batch_size, node_size, node_size)

            # Apply Laplacian
            weighted_delays = batched_laplacian * E_delayed
            summed_delays = weighted_delays.sum(dim=2) 
            connectivity_effect = summed_delays.permute(1, 0).unsqueeze(1) # (node_size, 1, batch_size)
            
            if i == self.hidden_size - 1:
                Plotter.plot()
            # print('std(connectivity) / std(E) =',connectivity_effect.std().item() / E.std().item())

            I_E = relu(self.W_E * self.I_0 + self.g_EE * E + self.g * connectivity_effect - self.g_IE * I) + input_current
            I_I = relu(self.W_I * self.I_0 + self.g_EI * E - I)

            R_E = self.firing_rate(self.aE, self.bE, self.dE, I_E)
            R_I = self.firing_rate(self.aI, self.bI, self.dI, I_I)

            print(f"R_E statistics mean {R_E.mean().item():.5f} max {R_E.max().item():.5f} min {R_E.min().item()}")
            print(f"R_I statistics mean {R_I.mean().item():.5f} max {R_I.max().item():.5f} min {R_I.min().item()}")

            dE = -E / self.tau_E + (ones_tensor - E) * self.gamma_E * R_E
            dI = -I / self.tau_I + self.gamma_I * R_I

            E_noise = self.sigma_E * noise_step[:, 0:1, :] * torch.sqrt(dt) # use first channel of noise for E
            I_noise = self.sigma_I * noise_step[:, 1:2, :] * torch.sqrt(dt) # second channel for I

            E_new = torch.tanh(E + dt * dE + E_noise)
            I_new = torch.tanh(I + dt * dI + I_noise)

            # --- Balloon Update ---
            x = state[:, 2:3, :]
            f = state[:, 3:4, :]
            v = state[:, 4:5, :]
            q = state[:, 5:6, :]

            dx = 1 * R_E - torch.reciprocal(self.tau_s) * x - torch.reciprocal(self.tau_f) * (f - ones_tensor)
            df = x
            dv = (f - torch.pow(v, torch.reciprocal(self.alpha))) * torch.reciprocal(self.tau_0)
            dq = (f * (ones_tensor - torch.pow(ones_tensor - self.rho, torch.reciprocal(f))) * torch.reciprocal(self.rho) \
                   - q * torch.pow(v, torch.reciprocal(self.alpha)) * torch.reciprocal(v+1e-8)) \
                     * torch.reciprocal(self.tau_0)
            
            x_new = x + dt * dx #+ noise_step[:, 2:3, :]
            f_new = f + dt * df #+ noise_step[:, 3:4, :]    
            v_new = v + dt * dv #+ noise_step[:, 4:5, :]
            q_new = q + dt * dq #+ noise_step[:, 5:6, :]

            state = torch.cat([E_new, I_new, x_new, f_new, v_new, q_new], dim=1)

            # Discard oldest delay value. Shape (node_size, delays_max, batch_size)
            delays = torch.cat([E_new, delays[:, :-1, :]], dim=1)

        BOLD = 100.0 * self.V * torch.reciprocal(self.E0) * (self.k1 * (ones_tensor - q_new) + \
                        (self.k2 * (ones_tensor - q_new * torch.reciprocal(v_new))) + \
                        (self.k3 * (ones_tensor - v_new)))
        BOLD = BOLD.squeeze(1)
        BOLD = BOLD + self.sigma_BOLD * noise_out # shape (node_size, batch_size)
        # print(f"BOLD avg: {BOLD.mean().item():.4f}")

        return state, BOLD, delays


### ModelFitting

In [7]:
# Anish Kochhar, Imperial College London, March 2025

from tqdm import tqdm

class ModelFitting:
    def __init__(self, model: WholeBrainModel, discriminator: DiscriminatorHook, data_loader: BOLDDataLoader, num_epochs: int, lr: float, cost_function: Costs, \
                 smoothing_window: int = 1, finetune_steps: int = 5, finetune_batch: int = 32, log_state: bool = False, device = DEVICE):
        """
        Parameters:
            model: WholeBrainModel instance
            discriminator: DiscriminatorHook instance, containing all functionality to get classification loss (real vs. fake)
            data_loader: BOLDDataLoader instance providing sample_minibatch()
            num_epochs: Number of training epochs
            lr: Learning rate
            cost_function: compute() function for metrics comparision between simulated and empirical BOLD
            smoothing_window: size of moving-average window (1 = no smoothing)
            log_state: If True, logs the evolution of state variables over TR chunks
        """
        self.model = model
        self.discriminator_hook = discriminator
        self.loader = data_loader
        self.num_epochs = num_epochs
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.cost_function = cost_function # Costs.compute
        self.smoothing_window = smoothing_window
        self.log_state = log_state

        self.finetune_steps = finetune_steps # per epoch
        self.finetune_batch = finetune_batch

        self.logs = { "losses": [], "fc_correlation": [], "rmse": [], "roi_correlation": [], "hidden_states": [], "adv_loss": [] }
        self.parameters_history = { name: [] for name in ["g", "g_EE", "g_EI", "g_IE"] }
        self.device = device


    def smooth(self, bold: torch.Tensor):
        """ Applies moving average along time dimension (dim = 1) """
        if self.smoothing_window <= 1:
            return bold

        N, T, B = bold.shape
        x = bold.permute(0, 2, 1).reshape(-1, 1, T) # (N * B, T)
        
        # Asymmetric pad
        left_pad = (self.smoothing_window - 1) // 2
        right_pad = self.smoothing_window - 1 - left_pad
        x = F.pad(x, (left_pad, right_pad), mode='replicate')
        smoothed = F.avg_pool1d(x, kernel_size=self.smoothing_window, stride=1)
        smoothed = smoothed.reshape(N, B, T).permute(0, 2, 1)
        return smoothed

    def compute_fc(self, matrix: torch.Tensor):
        """ Builds the FC matrix  """
        zero_centered = matrix - matrix.mean(dim=1, keepdim=True)
        covariance = zero_centered @ zero_centered.T
        std = torch.sqrt(torch.diag(covariance)).unsqueeze(0)
        return (covariance / (std.T * std + 1e-8)).detach().cpu().numpy()

    def train(self, delays_max: int = 500, batch_size: int = 20):
        """
        Train the model over multiple minibatches, iterating over TR chunks for each sample

        Parameters:
            delays_max: Maximum delay stored for residual connections
            batch_size: Minibatch size
        """
        torch.autograd.set_detect_anomaly(True)

        num_batches = self.loader.batched_dataset_length(batch_size) # Minibatches per epoch)

        for epoch in range(1, self.num_epochs + 1):
            # Initial state
            state = self.model.generate_initial_states().to(self.device)
            delays = torch.zeros(self.model.node_size, delays_max, batch_size, device=self.device)

            batch_losses = []
            batch_fc_corrs = []
            batch_roi_corrs = []
            batch_rmses = []
            batch_adv_losses = []
            epoch_state_log = []

            self.model.sigma_BOLD.data.fill_(0.)

            batch_iter = tqdm(range(num_batches), desc=f"Epochs [{epoch}/{self.num_epochs}]", unit="batch", leave=False)

            for batch_index in batch_iter:
                self.optimizer.zero_grad()
            
                empirical_bold, normalised_sc, laplacians, dist_matrices, sampled = self.loader.sample_minibatch(batch_size)

                num_TRs = empirical_bold.shape[1]   # chunk_length = 50

                simulated_bold_chunks = []

                for tr_index in range(num_TRs):
                    # noise_in shape: (node_size, hidden_size, batch_size, input_size) with input_size = 6
                    noise_in = torch.randn(self.model.node_size, self.model.hidden_size, self.model.input_size, batch_size, device=self.device)
                    # noise_in = torch.zeros_like(noise_in)

                    # noise_out shape: (node_size, batch_size)
                    noise_out = torch.randn(self.model.node_size, batch_size, device=self.device)
                    # noise_out = torch.zeros_like(noise_out)

                    external_current = torch.zeros(self.model.node_size, self.model.hidden_size, batch_size, device=self.device)

                    state, bold_chunk, delays = self.model(state, external_current, noise_in, noise_out, delays, laplacians, dist_matrices)
                    simulated_bold_chunks.append(bold_chunk)

                    if self.log_state and batch_index == 0:
                        state_means = state.mean(dim=(0, 2)).detach().cpu().numpy()
                        if tr_index % 10 == 0:
                            E_mean, I_mean, x_mean, f_mean, v_mean, q_mean = state_means
                            print(f"TR {tr_index:02d} | E={E_mean:.4f}  I={I_mean:.4f}  x={x_mean:.4f}  f={f_mean:.4f}  v={v_mean:.4f}  q={q_mean:.4f}" )
                        epoch_state_log.append(state_means)
                        
                    # if tr_index % 10 == 0:
                    #     print(f"SNR {((bold_chunk - noise_out).std()/noise_out.std()).item():.1f}",
                    #     f"|corr(E)| {torch.corrcoef(state[:,0,:]).abs().mean():.2f}",
                    #     f"|corr(BOLD)| {torch.corrcoef(bold_chunk).abs().mean():.2f}")


                # Stack TR chunks to form a time series: (node_size, num_TRs, batch_size)
                simulated_bold_epoch = torch.stack(simulated_bold_chunks, dim=1)
                smoothed_simulated_bold_epoch = self.smooth(simulated_bold_epoch)

                # Compute cost
                metrics = self.cost_function.compute(smoothed_simulated_bold_epoch, empirical_bold)
                loss = metrics["loss"]

                # Compute discriminator loss
                # total_adversarial_loss = 0.0
                # for b in range(batch_size):
                #     adversarial_loss = self.discriminator_hook(smoothed_simulated_bold_epoch[:, :, b],
                #                                                normalised_sc[b], epoch)
                #     print(f"Subj {sampled[b]:3d} Loss {adversarial_loss.item():.4f}")
                #     if adversarial_loss.item() < 1e-3: # MARK Fool threshold
                #         torch.save({'bold': smoothed_simulated_bold_epoch[:, :, b].cpu(), 'sc': normalised_sc[b].cpu()},
                #                     f"fooling_examples/epoch{epoch:03d}_batch{batch_index:03d}_idx{b}.pt")
                #         print(f"Saved to fooling_examples/epoch{epoch:03d}_batch{batch_index:03d}_idx{b}.pt")
                #         # Plotter.plot_time_series(smoothed_simulated_bold_epoch[:, :, b].unsqueeze(-1), title="FOOLED")
                #     # elif adversarial_loss.item() > 2.25:
                #     #     torch.save({'bold': smoothed_simulated_bold_epoch[:, :, b].cpu(), 'sc': normalised_sc[b].cpu()},
                #     #                 f"fooling_examples_n/epoch{epoch:03d}_batch{batch_index:03d}_idx{b}.pt")
                #     #     print(f"Saved to fooling_examples_n/epoch{epoch:03d}_batch{batch_index:03d}_idx{b}.pt")
                #     #     Plotter.plot_time_series(smoothed_simulated_bold_epoch[:, :, b].unsqueeze(-1), title="NOT FOOLED")
                        
                #     total_adversarial_loss += adversarial_loss.item()
                
                # loss += total_adversarial_loss
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                
                batch_losses.append(loss.item())
                batch_fc_corrs.append(metrics["average_fc_correlation"].item())
                batch_roi_corrs.append(metrics["average_rois_correlation"])
                batch_rmses.append(metrics["rmse"])
                # batch_adv_losses.append(total_adversarial_loss)

                batch_iter.set_postfix(
                    loss=f"{loss.item():.4f}",
                    # adv=f"{total_adversarial_loss:.4f}",
                    rmse=f"{metrics['rmse']:.4f}",
                    fc_corr=f"{metrics['average_fc_correlation'].item():.4f}"
                )

                delays = delays.detach()
                state = state.detach()
            
            # Small discriminator fine-tune (with latest batch of ground-truth signals)
            # ft_loss = self.discriminator_hook.finetune_discriminator(self.finetune_steps, self.finetune_batch, empirical_bold, normalised_sc)
            # print(f"[Fintune] d-loss = {ft_loss:.4f}")

            for parameter_name in self.parameters_history:
                self.parameters_history[parameter_name].append(getattr(self.model, parameter_name).item())
            
            self.logs["losses"].append(np.mean(batch_losses))
            self.logs["fc_correlation"].append(np.mean(batch_fc_corrs))
            self.logs["roi_correlation"].append(np.mean(batch_roi_corrs))
            self.logs["rmse"].append(np.mean(batch_rmses))
            self.logs["adv_loss"].append(np.mean(batch_adv_losses))

            if self.log_state:
                self.logs["hidden_states"].append(np.stack(epoch_state_log, axis=0))

            print(
                f"Epoch {epoch}/{self.num_epochs} | "
                f"Loss: {self.logs['losses'][-1]:.4f} | "
                f"RMSE: {self.logs['rmse'][-1]:.4f} | "
                f"ROI Corr: {self.logs['roi_correlation'][-1]:.4f} | "
                f"FC Corr: {self.logs['fc_correlation'][-1]:.4f}"
            )

            # Plot FC matrix heatmaps for final epoch for batch 0
            simulated_fc = self.compute_fc(smoothed_simulated_bold_epoch[:, :, 0])
            empirical_fc = self.compute_fc(empirical_bold[:, :, 0])
            sampled_id = sampled[0]
            
            Plotter.plot_functional_connectivity_heatmaps(simulated_fc, empirical_fc, sampled_id)

            Plotter.plot_node_comparison(
                empirical_bold[:, :, 0].unsqueeze(-1),
                smoothed_simulated_bold_epoch[:, :, 0].unsqueeze(-1),
                node_indices=list(np.random.choice(range(self.model.node_size), size=6, replace=False))
            )

            for name, param in self.model.named_parameters():
                print(name, param.item())


        # Final epoch visualisations
        epochs = list(range(1, self.num_epochs + 1))
        Plotter.plot_loss_curve(epochs, self.logs["losses"])
        Plotter.plot_fc_correlation_curve(epochs, self.logs["fc_correlation"])
        Plotter.plot_roi_correlation_curve(epochs, self.logs["roi_correlation"])
        Plotter.plot_rmse_curve(epochs, self.logs["rmse"])

        if self.log_state:
            Plotter.plot_hidden_states(self.logs["hidden_states"])

        Plotter.plot_coupling_parameters(self.parameters_history)

        # Plot FC matrix heatmaps for final epoch for batch 0
        simulated_fc = self.compute_fc(smoothed_simulated_bold_epoch[:, :, 0])
        empirical_fc = self.compute_fc(empirical_bold[:, :, 0])
        sampled_id = sampled[0]
        
        Plotter.plot_functional_connectivity_heatmaps(simulated_fc, empirical_fc, sampled_id)

        Plotter.plot_node_comparison(
            empirical_bold[:, :, 0].unsqueeze(-1),
            smoothed_simulated_bold_epoch[:, :, 0].unsqueeze(-1),
            node_indices=list(np.random.choice(range(self.model.node_size), size=6, replace=False))
        )

        for name, param in self.model.named_parameters():
            print(name, param.item())


In [8]:

fmri_filename = "./HCP Data/BOLD Timeseries HCP.mat"
dti_filename = "./HCP Data/DTI Fibers HCP.mat"
distance_matrices_path = "./HCP Data/distance_matrices/"
distance_matrix_path = "./HCP Data/schaefer100_dist.npy"
encoder_path = "checkpoints/encoder.pt"
discriminator_path = "checkpoints/discriminator.pt"

data_loader = BOLDDataLoader(fmri_filename, dti_filename, distance_matrices_path, chunk_length=50)

## Model Settings
batch_size = 16                         # Minibatch size
node_size = data_loader.get_node_size() # 100
step_size = 0.001                       # Integration time step
tr = 0.75                               # TR duration
input_size = 6                          # Number of noise channels
delays_max = 500                        # Maximum delay time

in_dim, hidden_dim, latent_dim = node_size, 64, 32 # From discriminator.ipynb

trainer_lr               = 1e-1
trainer_epochs           = 10
trainer_smoothing_window = 1
finetune_lr              = 1e-4
finetune_steps           = 5            # Number of steps taken in each finetune
finetune_batch           = 32

[DataLoader] Loaded 100 subjects.
[DataLoader] Created 2300 chunks (chunk length = 50).


In [9]:
params = ModelParams()

model = WholeBrainModel(params, input_size, node_size, batch_size, step_size, tr, delays_max).to(DEVICE)

costs = Costs()

graph_builder = GraphBuilder(node_dim=node_size, use_pca=False, device=DEVICE)

encoder = load_encoder(encoder_path, in_dim, hidden_dim, latent_dim).to(DEVICE).eval()

discriminator = load_discriminator(discriminator_path, encoder, latent_dim).to(DEVICE).eval()

discriminator_hook = DiscriminatorHook(discriminator, graph_builder, finetune_lr=finetune_lr)

trainer = ModelFitting(model, discriminator_hook, data_loader, num_epochs=trainer_epochs, lr=trainer_epochs, cost_function=costs, \
                       smoothing_window=trainer_smoothing_window, finetune_steps=finetune_steps, finetune_batch=finetune_batch, log_state=True)

[DEBUG] Model initialized with 6 states, 30 learnable parameters, and 10 hidden step size
[UTILS] GraphEncoder loaded with 37664 parameters
[UTILS] Discriminator loaded with 51298 parameters


In [None]:
trainer.train(delays_max, batch_size)

print("[Main] Training complete")

### Impulse Test

In [None]:
@torch.no_grad()
def coupling_impulse_test(model, g_val, n_nodes=8, n_TR=100):
    """
    Stimulate node 0 with a single-TR pulse and plot mean E over time
    for the other nodes.  Run with g=0 and a large g to see the effect.
    """

    old_g = model.g.data.clone()
    model.g.data = g_val           # set global coupling
    print(f"g_val: {g_val}")

    B = 1
    state  = model.generate_initial_states()[:, :model.state_size, :B]
    delays = torch.zeros(model.node_size, model.delays_max, B)

    # external pulse: 1 at first hidden step for node 0
    external = torch.zeros(model.node_size, model.hidden_size, B)
    external[0, 0] = 1.0

    E_traj = []
    for _ in range(n_TR):
        noise_in  = torch.zeros(model.node_size, model.hidden_size,
                                model.input_size, B)
        noise_out = torch.zeros(model.node_size, B)
        state, _, delays = model(state, external, noise_in,
                                 noise_out, delays,
                                 batched_laplacian=torch.zeros(B, n_nodes, n_nodes),
                                 dist_matrices=torch.zeros(B, n_nodes, n_nodes))
        E_traj.append(state[:,0,:].squeeze())   # store E only
        external.zero_()   # pulse only on the first TR
    
    model.g.data.copy_(old_g)          # restore original g
    E_traj = torch.stack(E_traj, 1)    # (node, time)
    plt.plot(E_traj[0].cpu(), label='stim node')
    plt.plot(E_traj[1:].mean(0).cpu(), label='mean of others')
    plt.title(f'Impulse response, g={g_val}')
    plt.legend(); plt.xlabel('TR'); plt.ylabel('E')
    plt.show()

model = WholeBrainModel(params, input_size, 8, 1, step_size, tr, 20)

coupling_impulse_test(model, torch.tensor(0, dtype=torch.float32))


### Coupling Experiments

In [31]:
import copy

torch.manual_seed(0)

test_batch_size = 1
test_delays_max = 50
params = ModelParams()
model = WholeBrainModel(params, input_size, node_size, test_batch_size, step_size, tr, test_delays_max)
device = next(model.parameters()).device

_, laplacians, dist_matrices, _ = data_loader.sample_minibatch(test_batch_size)

model_coup = copy.deepcopy(model)
model_coup.g.data.fill_(400.)
model_noc  = copy.deepcopy(model)
model_noc.g.data.fill_(0.)

def run(model, laplacians, dist_matrices, device):
    print(f"[STARTING RUN] - g = {model.g.item()}, g_EE = {model.g_EE.item()}, sigma_E = {model.sigma_E.item()}, sigma_I = {model.sigma_I.item()}, sigma_BOLD = {model.sigma_BOLD.item()}")
    laplacians     = laplacians.to(device)
    dist_matrices  = dist_matrices.to(device)

    
    state = model.generate_initial_states().to(device)
    delays = torch.zeros(model.node_size, test_delays_max, test_batch_size)
    bold = []
    e = []
    for _ in range(100):
        noise_in  = torch.randn(node_size, model.hidden_size, input_size, test_batch_size, device=device)
        noise_out = torch.randn(node_size, test_batch_size, device=device)

        ext = torch.zeros(node_size, model.hidden_size, test_batch_size, device=device)
        state, y, delays = model(state, ext, noise_in, noise_out, delays, laplacians, dist_matrices)
        # print('SNR =', ((y - noise_out).std() / noise_out.std()).item())
        e.append(state[:, 0, :].detach().squeeze(1))
        print(state[:, 0, :].mean().item())
        bold.append(y.mean(1))
    return torch.stack(bold), torch.stack(e)

bold_400, e_400 = run(model_coup, laplacians, dist_matrices, device)
bold_0, e_0 = run(model_noc, laplacians, dist_matrices, device)

print('std bold (g=400):', bold_400.std().item(), 'std bold (g=0):', bold_0.std().item())
print('CORR BOLD       :', torch.corrcoef(bold_400.T).abs().mean().item(),
                             torch.corrcoef(bold_0.T).abs().mean().item())
print('CORR E          :', torch.corrcoef(e_400).abs().mean().item(),
                             torch.corrcoef(e_0).abs().mean().item())




[DEBUG] Model initialized with 6 states, 29 learnable parameters, and 15 hidden step size
[STARTING RUN] - g = 400.0, g_EE = 2.5, sigma_E = 0.019999999552965164, sigma_I = 0.019999999552965164, sigma_BOLD = 0.019999999552965164
0.34676799178123474
0.33873772621154785
0.3310260474681854
0.326227605342865
0.3212602734565735
0.3191368877887726
0.31693708896636963
0.3129013180732727
0.31192463636398315
0.31018397212028503
0.30846157670021057
0.3064529299736023
0.3059980273246765
0.3047659993171692
0.30457109212875366
0.305608332157135
0.304277241230011
0.3027627766132355
0.3024885952472687
0.30187851190567017
0.3003745973110199
0.2987288236618042
0.2978981137275696
0.298061728477478
0.2971906065940857
0.2967502474784851
0.30005285143852234
0.29807692766189575
0.2991199195384979
0.2994721233844757
0.30013370513916016
0.30186471343040466
0.3001541793346405
0.29975783824920654
0.2995337247848511
0.29842454195022583
0.2970705032348633
0.29795894026756287
0.2970106601715088
0.29881367087364197
