In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Refactored DeepCorr CVAE Script for fMRI (Face vs Place task example)

This script loads fMRI data and anatomical masks, preprocesses the data, trains a 
Conditional Variational Autoencoder (CVAE) to separate signal of interest (face-related ROI) 
from background noise signals, and outputs results including a diagnostic dashboard plot.

User can customize file paths and key parameters in the section below.
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn import linear_model
import torch
from torch import nn
from torch.nn import functional as F

# ------------------------------
# *** User-Defined Parameters ***
# ------------------------------
# Input files -- point these at your own data before running.
epi_path = "/path/to/subject_bold.nii.gz"       # 4D BOLD EPI series (NIfTI)
anat_path = "/path/to/subject_T1w.nii.gz"       # T1w image; GM/WM/CSF probseg maps are looked up in the same folder
roi_mask_path = "/path/to/ffa_mask.nii.gz"      # binary ROI mask (e.g. FFA); must share the EPI voxel grid

# Preprocessing: number of leading dummy volumes. They are overwritten with the
# mean of the remaining volumes rather than dropped, so the series length is kept.
n_dummy = 8

# CVAE hyperparameters. Extra keys (gamma, delta, ...) may be added; the training
# function reads each one with dict.get() and a fallback default.
cvae_params = dict(
    latent_dim=(16, 16),   # (z_dim, s_dim): background / signal latent sizes
    epochs=50,             # passes over the training set
    batch_size=256,        # samples per optimisation step
    learning_rate=1e-3,    # optimiser step size
    beta=0.001,            # weight of the KL-divergence term
)

# Directory where results are written
ofn_root = "./deepcorr_output"

# ------------------------------
# Helper classes and functions
# ------------------------------
def safe_mkdir(path):
    """Create directory ``path`` (including missing parents) if it does not already exist.

    ``exist_ok=True`` makes the call idempotent and removes the race in a separate
    exists-check followed by makedirs (another process could create the directory in
    between). If ``path`` exists but is a regular file, ``FileExistsError`` is raised
    instead of being silently ignored.
    """
    os.makedirs(path, exist_ok=True)

class GradientReversalFunction(torch.autograd.Function):
    """
    Autograd function that is the identity on the forward pass and multiplies the
    incoming gradient by ``-lambda_`` on the backward pass (gradient reversal).
    """
    @staticmethod
    def forward(ctx, x, lambda_):
        # Remember the scale for backward (callers in this file pass a Python float).
        ctx.lambda_ = lambda_
        # view_as yields a new tensor object over the same data, so autograd records the op.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: the reversed gradient for x, None for lambda_.
        return -ctx.lambda_ * grad_output, None

class GradientReversalLayer(nn.Module):
    """
    Module wrapper around GradientReversalFunction.

    Acts as the identity in forward; during backpropagation the gradient passing through
    it is multiplied by ``-lambda_``.
    """
    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_  # gradient scale applied (with sign flip) on backward

    def forward(self, x):
        reversal = GradientReversalFunction.apply
        return reversal(x, self.lambda_)

class TrainDataset(torch.utils.data.Dataset):
    """
    Paired (ROI, RONI) voxel time-series dataset for CVAE training.

    Item ``i`` is the i-th ROI voxel time course together with the i-th noise voxel time
    course; the noise course is multiplied by a fresh random factor 2 * Beta(4, 4)
    (range [0, 2]) as data augmentation. Only the first min(n_roi, n_noise) voxels of
    each set are ever served.
    """

    def __init__(self, obs_data, noise_data):
        """
        obs_data : np.ndarray of shape (n_roi_voxels, n_time) -- ROI time series.
        noise_data : np.ndarray of shape (n_noise_voxels, n_time) -- RONI (noise) time series.
        """
        self.obs, self.noise = obs_data, noise_data
        n_obs, n_noise = self.obs.shape[0], self.noise.shape[0]
        # Pair index-by-index, so the dataset length is the smaller of the two sets.
        self.n_samples = n_obs if n_obs < n_noise else n_noise

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        # Only raw time series are returned; any spatial-coordinate channels are presumably
        # attached by the training code (not visible here).
        scale = 2 * np.random.beta(4, 4)  # draws from NumPy's global RNG
        roi_series = self.obs[index].astype(np.float32)
        noise_series = (scale * self.noise[index]).astype(np.float32)
        return roi_series, noise_series

class cVAE(nn.Module):
    """
    Conditional Variational Autoencoder for disentangling target (signal-of-interest) and background noise.

    Inputs are multi-channel 1-D series of shape (batch, in_channels, in_dim). Two stride-2
    convolutional encoders produce a "z" latent (background / confound-related) and an "s"
    latent (signal-of-interest). A shared decoder maps a concatenated [z, s] code back to the
    input shape; foreground / background reconstructions are obtained by zeroing one half of
    the code. Two auxiliary decoders predict the confound time series from z and from s.
    """
    def __init__(self, conf, in_channels, in_dim, latent_dim, hidden_dims=None,
                 beta=1.0, gamma=1.0, delta=1.0,
                 scale_MSE_GM=1.0, scale_MSE_CF=1.0, scale_MSE_FG=1.0,
                 do_disentangle=True):
        """
        Initialize the CVAE model.
        Parameters:
        - conf (torch.Tensor): (batch_size, n_confounds, n_time) confound signals, repeated across
          the batch. The auxiliary confound decoders output n_confounds channels to match it.
        - in_channels (int): Number of input channels (1 for time series + 3 for spatial coordinates = 4).
        - in_dim (int): Length of the time dimension (number of time points).
        - latent_dim (tuple): (latent_dim_z, latent_dim_s) sizes for the two latent subspaces.
        - hidden_dims (list): Encoder conv channel sizes (default [64, 128, 256, 256]). Must have
          4 entries, because compute_padding() plans padding for exactly 4 stride-2 layers.
          The list is copied; the caller's list is never modified.
        - beta, gamma, delta (float): loss weights (beta scales the KLD; gamma/delta are stored only).
        - scale_MSE_GM, scale_MSE_CF, scale_MSE_FG (float): scales for ROI, RONI and FG reconstruction losses.
        - do_disentangle (bool): stored flag for the disentanglement losses.
        """
        super(cVAE, self).__init__()
        self.latent_dim = latent_dim
        self.latent_dim_z = latent_dim[0]  # "z": background / confound-related factor
        self.latent_dim_s = latent_dim[1]  # "s": signal-of-interest factor
        self.in_channels = in_channels
        self.in_dim = in_dim
        # Loss weights
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.scale_MSE_GM = scale_MSE_GM
        self.scale_MSE_CF = scale_MSE_CF
        self.scale_MSE_FG = scale_MSE_FG
        self.do_disentangle = do_disentangle
        # Plain tensor attribute (not a registered buffer), so Module.to() does not move it:
        # the caller must pass conf already on the training device.
        self.confounds = conf.float()
        n_confounds = self.confounds.shape[1]
        # Gradient reversal used for the adversarial confound losses
        self.grl = GradientReversalLayer(lambda_=1.0)
        # Auxiliary decoders: latent code -> (B, n_confounds, in_dim) confound prediction
        self.decoder_confounds_z = self._make_confound_decoder(self.latent_dim_z, self.in_dim, n_confounds)
        self.decoder_confounds_s = self._make_confound_decoder(self.latent_dim_s, self.in_dim, n_confounds)

        # Copy so the caller's list is not reversed in place.
        enc_dims = [64, 128, 256, 256] if hidden_dims is None else list(hidden_dims)
        dec_dims = enc_dims[::-1]
        self.hidden_dims = enc_dims
        self.decoder_dims = dec_dims

        # Per-layer padding bits so stride-2 convs cover the time axis exactly
        pad, final_size, pad_out = compute_padding(self.in_dim)
        self.pad = pad
        self.final_size = final_size
        self.pad_out = pad_out
        flat_features = enc_dims[-1] * int(self.final_size)

        # Encoder for z-latent
        self.encoder_z = self._make_encoder(in_channels, enc_dims)
        self.fc_mu_z = nn.Linear(flat_features, self.latent_dim_z)
        self.fc_var_z = nn.Linear(flat_features, self.latent_dim_z)
        # Encoder for s-latent (same architecture, separate weights)
        self.encoder_s = self._make_encoder(in_channels, enc_dims)
        self.fc_mu_s = nn.Linear(flat_features, self.latent_dim_s)
        self.fc_var_s = nn.Linear(flat_features, self.latent_dim_s)

        # Shared decoder for the input reconstruction
        self.decoder_input = nn.Linear(self.latent_dim_z + self.latent_dim_s,
                                       dec_dims[0] * int(self.final_size))
        n_layers = len(dec_dims)
        modules_dec = []
        for i in range(n_layers - 1):
            # Decoder layer i undoes encoder layer (n_layers - 1 - i), which used bit
            # pad[-(n_layers - i)]; using the same bit as padding and output_padding restores
            # that layer's input length.
            bit = int(pad_out[i - n_layers])
            modules_dec.append(nn.Sequential(
                nn.ConvTranspose1d(dec_dims[i], dec_dims[i + 1],
                                   kernel_size=3, stride=2,
                                   padding=bit, output_padding=bit),
                nn.LeakyReLU()
            ))
        self.decoder = nn.Sequential(*modules_dec)
        # Final upsampling mirrors encoder layer 0 (bit pad_out[-1]), then maps to in_channels.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(dec_dims[-1], dec_dims[-1],
                               kernel_size=3, stride=2,
                               padding=int(pad_out[-1]), output_padding=int(pad_out[-1])),
            nn.LeakyReLU(),
            nn.Conv1d(dec_dims[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    @staticmethod
    def _make_confound_decoder(latent_size, in_dim, n_out):
        """Build a (B, latent_size, 1) -> (B, n_out, in_dim) confound decoder with sigmoid output."""
        return nn.Sequential(
            nn.ConvTranspose1d(in_channels=latent_size, out_channels=128,
                               kernel_size=in_dim, stride=1, bias=False),
            nn.Conv1d(in_channels=128, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=n_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )

    def _make_encoder(self, in_channels, dims):
        """Build stacked stride-2 Conv1d+LeakyReLU layers; layer k uses padding bit self.pad[-(k + 1)]."""
        layers = []
        prev = in_channels
        for k, h_dim in enumerate(dims):
            layers.append(nn.Sequential(
                nn.Conv1d(prev, out_channels=h_dim, kernel_size=3,
                          stride=2, padding=int(self.pad[-(k + 1)])),
                nn.LeakyReLU()
            ))
            prev = h_dim
        return nn.Sequential(*layers)

    def encode_z(self, x: torch.Tensor):
        """Encode input into z-latent (confound-related) parameters (mu, log_var)."""
        result = torch.flatten(self.encoder_z(x), start_dim=1)
        return self.fc_mu_z(result), self.fc_var_z(result)

    def encode_s(self, x: torch.Tensor):
        """Encode input into s-latent (signal-related) parameters (mu, log_var)."""
        result = torch.flatten(self.encoder_s(x), start_dim=1)
        return self.fc_mu_s(result), self.fc_var_s(result)

    def decode(self, latent: torch.Tensor):
        """Decode a concatenated [z, s] code of shape (B, latent_dim_z + latent_dim_s) to (B, in_channels, T)."""
        result = self.decoder_input(latent)
        result = result.view(-1, self.decoder_dims[0], int(self.final_size))
        result = self.decoder(result)
        return self.final_layer(result)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
        """Sample from Gaussian distribution via reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward_tg(self, x: torch.Tensor):
        """
        Forward pass treating input as target (ROI) data.
        Returns [reconstruction, input, tg_mu_z, tg_log_var_z, tg_mu_s, tg_log_var_s, tg_z, tg_s]
        """
        tg_mu_z, tg_log_var_z = self.encode_z(x)
        tg_mu_s, tg_log_var_s = self.encode_s(x)
        tg_z = self.reparameterize(tg_mu_z, tg_log_var_z)
        tg_s = self.reparameterize(tg_mu_s, tg_log_var_s)
        # Reconstruction = background decode([z, 0]) + foreground decode([0, s]), built from the
        # same samples that are returned, without re-running the encoders.
        bg_part = self.decode(torch.cat((tg_z, torch.zeros_like(tg_s)), dim=1))
        fg_part = self.decode(torch.cat((torch.zeros_like(tg_z), tg_s), dim=1))
        recon = bg_part + fg_part
        return recon, x, tg_mu_z, tg_log_var_z, tg_mu_s, tg_log_var_s, tg_z, tg_s

    def forward_fg(self, x: torch.Tensor):
        """
        Forward pass for foreground (signal-of-interest) only.
        Returns [foreground_reconstruction, input, tg_mu_s, tg_log_var_s]
        """
        tg_mu_s, tg_log_var_s = self.encode_s(x)
        tg_s = self.reparameterize(tg_mu_s, tg_log_var_s)
        # Zero-out z latent to reconstruct only signal-related components
        zeros_z = torch.zeros(tg_s.size(0), self.latent_dim_z, device=x.device)
        fg_out = self.decode(torch.cat((zeros_z, tg_s), dim=1))
        return fg_out, x, tg_mu_s, tg_log_var_s

    def forward_bg(self, x: torch.Tensor):
        """
        Forward pass for background (noise) only.
        Returns [background_reconstruction, input, bg_mu_z, bg_log_var_z]
        """
        bg_mu_z, bg_log_var_z = self.encode_z(x)
        bg_z = self.reparameterize(bg_mu_z, bg_log_var_z)
        # Zero-out s latent to reconstruct only confound/background components
        zeros_s = torch.zeros(bg_z.size(0), self.latent_dim_s, device=x.device)
        bg_out = self.decode(torch.cat((bg_z, zeros_s), dim=1))
        return bg_out, x, bg_mu_z, bg_log_var_z

    def ncc(self, x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8):
        """
        Compute 1 - Pearson correlation per sample (a distance: 0 = perfectly correlated).

        Everything from dim 1 onward is flattened, so for (B, C, T) inputs all channels are
        pooled into one correlation per sample. Mean and standard deviation both use the
        population (1/N) normalisation, so identical inputs give a distance of ~0.
        """
        x_c = x.flatten(start_dim=1)
        y_c = y.flatten(start_dim=1)
        x_c = x_c - x_c.mean(dim=1, keepdim=True)
        y_c = y_c - y_c.mean(dim=1, keepdim=True)
        x_std = x_c.pow(2).mean(dim=1, keepdim=True).sqrt() + eps
        y_std = y_c.pow(2).mean(dim=1, keepdim=True).sqrt() + eps
        ncc_val = (x_c * y_c / (x_std * y_std)).mean(dim=1)
        return 1 - ncc_val

    def loss_function(self, *args):
        """
        Compute CVAE loss components and return them as a dictionary.
        Expects args in order:
         [recons_tg, input_tg, tg_mu_z, tg_log_var_z, tg_mu_s, tg_log_var_s, tg_z, tg_s,
          recons_bg, input_bg, bg_mu_z, bg_log_var_z]
        i.e. the outputs of forward_tg(roi_batch) followed by forward_bg(roni_batch).
        """
        recons_tg, input_tg = args[0], args[1]
        tg_mu_z, tg_log_var_z = args[2], args[3]
        tg_mu_s, tg_log_var_s = args[4], args[5]
        tg_z, tg_s = args[6], args[7]
        recons_bg, input_bg = args[8], args[9]
        bg_mu_z, bg_log_var_z = args[10], args[11]
        # Reconstruction losses on channel 0 (the time-series value)
        recons_loss_roi = F.mse_loss(recons_tg[:, 0, :], input_tg[:, 0, :]) * self.scale_MSE_GM
        recons_loss_roni = F.mse_loss(recons_bg[:, 0, :], input_bg[:, 0, :]) * self.scale_MSE_CF
        recons_loss = recons_loss_roi + recons_loss_roni
        # Foreground-only reconstruction of the ROI input
        recons_fg = self.forward_fg(input_tg)[0]
        # Confound predictions from each latent
        conf_pred_z = self.decoder_confounds_z(tg_z.unsqueeze(2))  # (B, n_confounds, T)
        conf_pred_s = self.decoder_confounds_s(tg_s.unsqueeze(2))
        # GRL on the scalar loss negates its gradient for every upstream parameter, so both
        # encoder_s and decoder_confounds_s are pushed to INCREASE this MSE.
        # NOTE(review): a conventional adversarial setup applies self.grl to tg_s before
        # decoder_confounds_s (decoder minimises, encoder maximises) -- confirm intent.
        loss_recon_conf_s = self.grl(F.mse_loss(conf_pred_s, self.confounds)) * 1e2
        loss_recon_conf_z = F.mse_loss(conf_pred_z, self.confounds) * 1e2  # z should predict confounds
        # Reconstructions should correlate with their inputs (minimise 1 - corr)
        ncc_loss_tg = self.ncc(input_tg, recons_tg).mean() * 1.0
        ncc_loss_bg = self.ncc(input_bg, recons_bg).mean() * 1.0
        # Sum over confound channels of (1 - corr). Through the GRL the gradients push 1 - corr
        # UP for s, i.e. decorrelate s-predicted confounds from the true ones.
        ncc_loss_conf_s = 0.0
        for i in range(self.confounds.shape[1]):
            ncc_loss_conf_s += self.ncc(self.confounds[:, i, :], conf_pred_s[:, i, :]).mean() * 1e1
        ncc_loss_conf_s = self.grl(ncc_loss_conf_s)
        # For z the same sum is minimised directly (z-predicted confounds should correlate)
        ncc_loss_conf_z = 0.0
        for i in range(self.confounds.shape[1]):
            ncc_loss_conf_z += self.ncc(self.confounds[:, i, :], conf_pred_z[:, i, :]).mean() * 1e1
        # fg_bg_ncc is 1 - corr(BG, FG) for the ROI input, so 1 - fg_bg_ncc is the correlation
        # and MSE(0, corr) = corr**2: this drives FG and BG towards ZERO correlation.
        recond_bg = self.forward_bg(input_tg)[0]
        fg_bg_ncc = self.ncc(recond_bg[:, 0, :], recons_fg[:, 0, :]).mean()
        recons_loss_fg = F.mse_loss(torch.zeros_like(fg_bg_ncc), 1 - fg_bg_ncc) * 1e4
        # Small additional consistency penalties
        recons_loss += F.mse_loss(recons_fg[:, 0, :], input_tg[:, 0, :]) * 1e-4  # FG should match ROI
        recons_loss += F.mse_loss(recons_bg[:, 0, :], input_bg[:, 0, :]) * 1e-5  # BG should match RONI
        # Smoothness: squared first temporal differences of channel 0 of the FG and BG outputs
        smoothness_loss = torch.mean((recons_fg[:, 0, 1:] - recons_fg[:, 0, :-1]) ** 2)
        smoothness_loss = smoothness_loss + torch.mean((recond_bg[:, 0, 1:] - recond_bg[:, 0, :-1]) ** 2)
        smoothness_loss = smoothness_loss * 0.01
        # KLD terms (mean over batch and latent dims), averaged and weighted by beta
        kld_loss_z = -0.5 * torch.mean(1 + tg_log_var_z - tg_mu_z**2 - torch.exp(tg_log_var_z))
        kld_loss_s = -0.5 * torch.mean(1 + tg_log_var_s - tg_mu_s**2 - torch.exp(tg_log_var_s))
        kld_loss_bg = -0.5 * torch.mean(1 + bg_log_var_z - bg_mu_z**2 - torch.exp(bg_log_var_z))
        kld_loss = (kld_loss_z + kld_loss_s + kld_loss_bg) / 3.0
        kld_loss = kld_loss * self.beta
        total_loss = (recons_loss + kld_loss + recons_loss_fg + ncc_loss_tg + ncc_loss_bg
                      + loss_recon_conf_s + loss_recon_conf_z + smoothness_loss
                      + ncc_loss_conf_z + ncc_loss_conf_s)
        return {
            'loss': total_loss,
            'kld_loss': kld_loss,
            'recons_loss_roi': recons_loss_roi,
            'recons_loss_roni': recons_loss_roni,
            'loss_recon_conf_s': loss_recon_conf_s,
            'loss_recon_conf_z': loss_recon_conf_z,
            'ncc_loss_tg': ncc_loss_tg,
            'ncc_loss_bg': ncc_loss_bg,
            # Reported as 0 for logging only; the GRL-wrapped term IS included in 'loss' above.
            'ncc_loss_conf_s': ncc_loss_conf_s * 0.0,
            'ncc_loss_conf_z': ncc_loss_conf_z,
            'smoothness_loss': smoothness_loss,
            'recons_loss_fg': recons_loss_fg
        }

def compute_in(x):
    """Length after one kernel-3, stride-2, unpadded conv, left fractional so the caller can detect missing padding."""
    half_span = (x - 3) / 2
    return 1 + half_span

def compute_in_size(x):
    """Apply compute_in four times, once per encoder conv layer, starting from length x."""
    size = x
    layers_left = 4
    while layers_left:
        size = compute_in(size)
        layers_left -= 1
    return size

def compute_out_size(x):
    """Inverse of compute_in_size: length after four `2 * L + 1` upsampling steps starting from x."""
    size = x
    for _ in range(4):  # one step per transposed-conv layer
        size = 2 * size + 1
    return size

def compute_padding(x):
    """
    Plan padding for 4 stride-2 conv layers and the 4 matching transposed-conv layers.

    The input length is conceptually extended by ``extra`` so that every conv layer
    divides evenly. ``extra`` (< 16) is written as 4 binary digits. cVAE reads the
    digit for encoder layer k from the right (``pad[-(k + 1)]``).

    Returns:
      pad (str): 4-character binary string of conv padding bits.
      final_size (float): feature length after the 4 conv layers on the padded length.
      pad_out (str): 4-character binary string of padding/output-padding bits that bring
        the transposed-conv stack back to length x.
    """
    shrunk = compute_in_size(x)
    extra = (np.ceil(shrunk) - shrunk) * 16  # 16 = 2**4, one doubling per layer
    final_size = compute_in_size(x + extra)
    pad = format(int(extra), "04b")
    pad_out = format(int(compute_out_size(final_size) - x), "04b")
    return pad, final_size, pad_out

def correlate_columns(arr1: np.ndarray, arr2: np.ndarray):
    """
    Pearson correlation of matching columns of two (n_time, n_features) arrays.

    Returns a 1-D array with one coefficient per column. A tiny constant (1e-8) in the
    denominator keeps constant columns from dividing by zero (they yield 0).
    """
    a = np.asarray(arr1)
    b = np.asarray(arr2)
    a_dev = a - a.mean(axis=0)
    b_dev = b - b.mean(axis=0)
    covariance = (a_dev * b_dev).sum(axis=0)
    norm = np.sqrt((a_dev ** 2).sum(axis=0) * (b_dev ** 2).sum(axis=0))
    return covariance / (norm + 1e-8)

def interpolate_outliers(data: np.ndarray, outlier_mask: np.ndarray):
    """
    Linearly interpolate over motion outlier frames in a (n_voxels x n_time) array.

    Parameters:
      data: array of shape (n_voxels, n_time).
      outlier_mask: length-n_time array whose truthy entries mark frames to replace. It is
        coerced to bool, so 0/1 integer masks work too.
    Returns:
      A copy of ``data``. Each outlier frame is replaced, for every voxel at once, by linear
      interpolation between the nearest preceding and following valid frames. Outliers
      before the first (after the last) valid frame take the first (last) valid value, the
      same end behaviour as ``np.interp``.
    Raises:
      ValueError: if every frame is flagged, so there is nothing to interpolate from.
    """
    # Coerce to bool: `~` on an int mask would give -1/-2 and fancy-index the wrong frames.
    bad = np.asarray(outlier_mask, dtype=bool)
    corrected = data.copy()
    if not bad.any():
        return corrected  # no outliers to interpolate
    t_good = np.flatnonzero(~bad)
    t_bad = np.flatnonzero(bad)
    if t_good.size == 0:
        raise ValueError("All frames are flagged as outliers; nothing to interpolate from.")
    last = t_good.size - 1
    right = np.searchsorted(t_good, t_bad)      # position of the first valid frame after each outlier
    t0 = t_good[np.clip(right - 1, 0, last)]    # nearest valid frame before (clamped at the start)
    t1 = t_good[np.clip(right, 0, last)]        # nearest valid frame after (clamped at the end)
    span = t1 - t0
    # Fractional position between t0 and t1; 0 where clamped (t0 == t1), i.e. hold the edge value.
    weight = np.divide(t_bad - t0, span, out=np.zeros(t_bad.shape), where=span > 0)
    corrected[:, t_bad] = (1.0 - weight) * data[:, t0] + weight * data[:, t1]
    return corrected

def _load_fmriprep_confounds(epi_path: str, n_time: int):
    """
    Load confound regressors and a motion-outlier mask from the fMRIPrep confounds TSV
    next to ``epi_path``.

    The TSV name is derived by replacing ``desc-preproc_bold.nii.gz`` with
    ``desc-confounds_timeseries.tsv``. If ``epi_path`` lacks that suffix (the replacement is
    a no-op) or the file is missing, no confounds are loaded.

    Returns:
      conf (np.ndarray): (n_confounds, n_rows) matrix -- the 6 motion parameters when present,
        otherwise every numeric column -- with NaNs replaced by 0. A (1, n_time) zero matrix
        when no TSV is available.
      fd_mask (np.ndarray): boolean array, True where framewise displacement > 0.25; all
        False (length n_time) when there is no TSV or no ``framewise_displacement`` column.
    """
    conf_fn = epi_path.replace("desc-preproc_bold.nii.gz", "desc-confounds_timeseries.tsv")
    no_outliers = np.zeros(n_time, dtype=bool)
    # If the suffix was absent, conf_fn IS the NIfTI file -- never parse it as a TSV.
    if conf_fn == epi_path or not os.path.exists(conf_fn):
        return np.zeros((1, n_time)), no_outliers
    df_conf = pd.read_csv(conf_fn, sep='\t')
    motion_cols = [c for c in ['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z'] if c in df_conf.columns]
    if motion_cols:
        conf = df_conf.loc[:, motion_cols].fillna(0.0).values.T  # (6, T)
    else:
        # Some columns (e.g. framewise_displacement) start with NaN; NaNs would poison training.
        conf = df_conf.select_dtypes(include=[np.number]).fillna(0.0).values.T
    if 'framewise_displacement' in df_conf.columns:
        fd_mask = df_conf['framewise_displacement'].fillna(0.0).values > 0.25
    else:
        fd_mask = no_outliers
    return conf, fd_mask


def load_data_and_preprocess(epi_path: str, anat_path: str, n_dummy: int, ffa_mask_path: str = None):
    """
    Load fMRI data and anatomical masks, apply preprocessing, and return numpy arrays for model training.

    Parameters:
      epi_path (str): 4D BOLD NIfTI; fMRIPrep confounds and BIDS events files are located from this name.
      anat_path (str): T1w image. GM/WM/CSF probability maps (``T1w`` in the file name replaced by
        ``GM_probseg`` / ``WM_probseg`` / ``CSF_probseg``) must exist in the same folder and share
        the EPI voxel grid.
      n_dummy (int): number of leading volumes overwritten with the mean of the remaining volumes.
      ffa_mask_path (str, optional): binary FFA mask on the EPI voxel grid. Defaults to the
        module-level ``roi_mask_path`` setting.
    Returns:
      obs_list (np.ndarray): ROI time series data (voxels x time)
      noi_list (np.ndarray): RONI (noise) time series data (voxels x time)
      conf (np.ndarray): Confound regressor matrix (n_confounds x time)
      ffa (np.ndarray): FFA ROI time series (voxels_in_FFA x time)
      face_reg (np.ndarray): Task regressor time course for "face" condition (length = n_time)
      place_reg (np.ndarray): Task regressor time course for "place" condition (length = n_time)
      ffa_compcorr (np.ndarray): FFA time series after CompCor (baseline noise removal using PCA of RONI)
    Raises:
      ImportError: ANTsPy is not installed.
      FileNotFoundError: a GM/WM/CSF probability map is missing.
      ValueError: a probability map or the FFA mask does not match the EPI spatial shape.
    """
    try:
        import ants  # ANTsPy for NIfTI I/O
    except ImportError as e:
        raise ImportError("ANTs library not found. Please install antspyx or use nibabel for image I/O.") from e
    # Read fMRI 4D image
    epi_img = ants.image_read(epi_path)
    epi_data = epi_img.numpy()  # (X, Y, Z, T)
    nx, ny, nz, nt = epi_data.shape
    spatial_shape = (nx, ny, nz)
    # Segmentation maps with fMRIPrep-style names are expected next to the structural image
    anat_dir = os.path.dirname(anat_path)
    base_name = os.path.basename(anat_path)
    gm_prob_path = os.path.join(anat_dir, base_name.replace("T1w", "GM_probseg"))
    wm_prob_path = os.path.join(anat_dir, base_name.replace("T1w", "WM_probseg"))
    csf_prob_path = os.path.join(anat_dir, base_name.replace("T1w", "CSF_probseg"))
    if not (os.path.exists(gm_prob_path) and os.path.exists(wm_prob_path) and os.path.exists(csf_prob_path)):
        raise FileNotFoundError("Grey matter/White matter/CSF probability maps not found. Please provide segmentation files.")
    gm_data = ants.image_read(gm_prob_path).numpy()
    wm_data = ants.image_read(wm_prob_path).numpy()
    csf_data = ants.image_read(csf_prob_path).numpy()
    # Masks are combined voxel-wise with the EPI, so they must share its grid.
    for label, arr in (("GM probability map", gm_data), ("WM probability map", wm_data), ("CSF probability map", csf_data)):
        if arr.shape != spatial_shape:
            raise ValueError(f"{label} has shape {arr.shape}, but the EPI spatial shape is {spatial_shape}; "
                             "resample it into BOLD space first.")
    # Threshold probability maps: ROI = GM, RONI ("cf") = WM + CSF
    gm_mask = gm_data > 0.5
    cf_mask = (wm_data + csf_data) > 0.5
    # Voxels claimed by both masks are dropped from both
    overlap = gm_mask & cf_mask
    if overlap.any():
        gm_mask[overlap] = False
        cf_mask[overlap] = False
    # Drop near-constant voxels (computed on the raw series, before dummy replacement)
    epi_time_std = epi_data.std(axis=-1)
    gm_mask = gm_mask & (epi_time_std > 1e-3)
    cf_mask = cf_mask & (epi_time_std > 1e-3)
    # (T, Nvoxels) view for dummy-scan handling
    epi_flat = epi_data.reshape(-1, nt).T
    if n_dummy > 0 and n_dummy < nt:
        mean_vol = epi_flat[n_dummy:, :].mean(axis=0)
        epi_flat[:n_dummy, :] = mean_vol
    epi_flat = epi_flat.T  # (Nvoxels, T)
    # C-order flattening matches epi_data.reshape above
    gm_flat = gm_mask.flatten()
    cf_flat = cf_mask.flatten()
    func_gm = epi_flat[gm_flat, :]  # ROI signals (n_gm_voxels x T)
    func_cf = epi_flat[cf_flat, :]  # noise signals (n_cf_voxels x T)
    # Confounds and FD-based outlier mask (all-False when no confounds file is found)
    conf, fd_mask = _load_fmriprep_confounds(epi_path, nt)
    func_gm = interpolate_outliers(func_gm, fd_mask)
    func_cf = interpolate_outliers(func_cf, fd_mask)
    # Remove voxels that became near-constant after interpolation
    gm_var_mask = func_gm.std(axis=1) > 1e-3
    cf_var_mask = func_cf.std(axis=1) > 1e-3
    func_gm = func_gm[gm_var_mask, :]
    func_cf = func_cf[cf_var_mask, :]
    # Task regressors from the BIDS events file, if present
    events_fn = epi_path.split('_desc-')[0] + '_events.tsv'
    face_reg = None
    place_reg = None
    if os.path.exists(events_fn):
        events_df = pd.read_csv(events_fn, sep='\t')
        from nilearn.glm.first_level import make_first_level_design_matrix
        # TR is the 4th voxel spacing of the 4D image; default to 0.8 s when unavailable
        spacing = tuple(getattr(epi_img, "spacing", ()))
        tr = spacing[3] if len(spacing) > 3 and spacing[3] else 0.8
        frame_times = np.arange(nt) * tr
        design_matrix = make_first_level_design_matrix(frame_times, events_df,
                                                       drift_model="polynomial", drift_order=3, hrf_model="SPM")
        # Sum all face-condition columns and all place-condition columns
        face_columns = [col for col in design_matrix.columns if 'face' in col.lower()]
        place_columns = [col for col in design_matrix.columns if 'place' in col.lower()]
        if face_columns:
            face_reg = design_matrix[face_columns].values.sum(axis=1)
        if place_columns:
            place_reg = design_matrix[place_columns].values.sum(axis=1)
    # Fill each regressor independently so a found place regressor is not discarded
    if face_reg is None:
        face_reg = np.zeros(nt)
    if place_reg is None:
        place_reg = np.zeros(nt)
    # FFA voxels = FFA mask & GM mask, restricted to ROI voxels that survived variance filtering
    mask_path = roi_mask_path if ffa_mask_path is None else ffa_mask_path
    ffa_mask_data = ants.image_read(mask_path).numpy()
    if ffa_mask_data.shape != spatial_shape:
        raise ValueError(f"FFA mask has shape {ffa_mask_data.shape}, but the EPI spatial shape is {spatial_shape}; "
                         "resample it into BOLD space first.")
    ffa_idx_all = (ffa_mask_data.flatten() == 1) & gm_flat
    gm_indices_all = np.where(gm_flat)[0]         # flattened-image indices of ROI voxels
    ffa_indices_all = np.where(ffa_idx_all)[0]    # flattened-image indices of FFA voxels
    ffa_survived = np.isin(gm_indices_all[gm_var_mask], ffa_indices_all)
    ffa_data = func_gm[ffa_survived, :]  # (n_ffa_voxels, T)
    # Baseline CompCor: regress the top (up to 5) noise PCs out of each FFA voxel
    n_components = min(5, func_cf.shape[0], func_cf.shape[1])
    if n_components < 1:
        ffa_comp = ffa_data.copy()
    else:
        pca = PCA(n_components=n_components)
        conf_pcs = pca.fit_transform(func_cf.T)  # (T x n_components)
        lin_reg = linear_model.LinearRegression()
        lin_reg.fit(conf_pcs, ffa_data.T)
        ffa_pred_noise = lin_reg.predict(conf_pcs)  # (T x n_ffa_voxels)
        ffa_comp = (ffa_data.T - ffa_pred_noise).T
    # NOTE(review): voxel coordinates (which the 4-channel model input appears to need) are
    # not returned from here; the previous unused coordinate computation was removed.
    return func_gm, func_cf, conf, ffa_data, face_reg, place_reg, ffa_comp

class _IndexedDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset so every item is returned as (index, obs, noise).

    The training loop needs the dataset index of each sample in a shuffled batch in
    order to look up that sample's voxel coordinates.
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        obs, noise = self.base[index]
        return index, obs, noise


# Loss-dict entries recorded once per epoch (values taken from the epoch's last batch).
_CVAE_LOSS_KEYS = (
    'kld_loss', 'recons_loss_roi', 'recons_loss_roni', 'loss_recon_conf_s', 'loss_recon_conf_z',
    'ncc_loss_tg', 'ncc_loss_bg', 'ncc_loss_conf_s', 'ncc_loss_conf_z', 'smoothness_loss',
    'recons_loss_fg',
)


def _resolve_global(name, value):
    """Return `value`, or the module-level variable `name` when `value` is None."""
    if value is not None:
        return value
    if name in globals():
        return globals()[name]
    raise NameError(f"{name!r} was not passed to train_cvae_model and is not defined at module level")


def _safe_corr(a, b) -> float:
    """Pearson correlation of two 1-D series, or 0.0 if either series is constant."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.std() == 0 or b.std() == 0:
        return 0.0
    return float(np.corrcoef(a, b)[0, 1])


def _as_float(value) -> float:
    """Convert a loss entry (torch tensor or plain Python number) to a float."""
    if isinstance(value, torch.Tensor):
        return float(value.detach().cpu().item())
    return float(value)


def _signal_channel(arr: np.ndarray) -> np.ndarray:
    """Reduce a model output to shape (n_vox, T).

    If the output still carries a channel axis (n_vox, channels, T), keep channel 0.
    NOTE(review): assumes channel 0 is the signal, mirroring the [signal, x, y, z]
    input layout -- confirm against cVAE's decoder.
    """
    return arr[:, 0, :] if arr.ndim == 3 else arr


def _with_coords(ts, coords) -> torch.Tensor:
    """Stack (n, T) time series with (n, 3) voxel coordinates into an (n, 4, T) float32 tensor.

    Each voxel's coordinates are repeated along time, giving channels [signal, x, y, z].
    """
    ts_t = torch.as_tensor(ts, dtype=torch.float32)
    coords_t = torch.as_tensor(np.asarray(coords), dtype=torch.float32)
    coord_times = coords_t.unsqueeze(2).repeat(1, 1, ts_t.shape[-1])
    return torch.cat([ts_t.unsqueeze(1), coord_times], dim=1)


def train_cvae_model(obs_list: np.ndarray, noi_list: np.ndarray, conf: np.ndarray, ffa: np.ndarray, cvae_params: dict,
                     *, gm_coords=None, cf_coords=None, ffa_survived=None, face_reg=None, ffa_compcorr=None):
    """
    Train the CVAE model on provided ROI (obs_list) and RONI (noi_list) data.
    Parameters:
      obs_list (np.ndarray): ROI data, shape (n_vox_roi, n_time)
      noi_list (np.ndarray): RONI data, shape (n_vox_noise, n_time)
      conf (np.ndarray): Confound matrix, shape (n_confounds, n_time)
      ffa (np.ndarray): FFA ROI data, shape (n_ffa_voxels, n_time)
      cvae_params (dict): CVAE hyperparameters (latent_dim, epochs, batch_size, learning_rate,
        beta, gamma, delta). Missing keys fall back to the defaults used below.
      gm_coords (array, optional): (n_vox_roi, 3) coordinates, row-aligned with obs_list.
        Defaults to the module-level `gm_coords`.
      cf_coords (array, optional): (n_vox_noise, 3) coordinates, row-aligned with noi_list.
        Defaults to the module-level `cf_coords`.
      ffa_survived (array, optional): index/boolean mask selecting the FFA rows of obs_list
        (and gm_coords). Defaults to the module-level `ffa_survived`.
      face_reg (array, optional): face task regressor of length n_time; defaults to the
        module-level `face_reg` (or None). Correlation metrics are 0.0 when it is missing or
        its length differs from n_time.
      ffa_compcorr (array, optional): CompCor-cleaned FFA data (same shape as `ffa`);
        defaults to the module-level `ffa_compcorr` (or None).
    Returns:
      model (cVAE): Trained CVAE model.
      track (dict): Training history with loss components and metrics per epoch.
      outputs (dict): Dictionary of outputs (reconstructed signals, separated signals, etc.) after training.
    Raises:
      NameError: if a coordinate array is neither passed nor defined at module level.
      ValueError: if epochs < 1 or batch_size is larger than the training set (no batches).
    """
    # Unpack hyperparameters
    latent_dim = cvae_params.get("latent_dim", (16, 16))
    epochs = cvae_params.get("epochs", 50)
    batch_size = cvae_params.get("batch_size", 256)
    lr = cvae_params.get("learning_rate", 1e-3)
    beta = cvae_params.get("beta", 1.0)
    gamma = cvae_params.get("gamma", 1.0)
    delta = cvae_params.get("delta", 1.0)
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Resolve inputs that earlier versions of this script only read from globals.
    gm_coords = np.asarray(_resolve_global("gm_coords", gm_coords))
    cf_coords = np.asarray(_resolve_global("cf_coords", cf_coords))
    ffa_survived = np.asarray(_resolve_global("ffa_survived", ffa_survived))
    if face_reg is None:
        face_reg = globals().get("face_reg")
    if ffa_compcorr is None:
        ffa_compcorr = globals().get("ffa_compcorr")

    n_time = ffa.shape[1]
    use_face_reg = face_reg is not None and np.shape(face_reg)[0] == n_time

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # The loader yields (dataset_index, obs, noise) so coordinates can be looked up per sample.
    # NOTE(review): assumes TrainDataset item i pairs obs_list[i] with noi_list[i] -- confirm.
    train_loader = torch.utils.data.DataLoader(_IndexedDataset(TrainDataset(obs_list, noi_list)),
                                               batch_size=batch_size, shuffle=True, drop_last=True)
    if len(train_loader) == 0:
        raise ValueError(f"batch_size={batch_size} exceeds the number of training samples "
                         f"({len(train_loader.dataset)}); with drop_last=True no batch is produced")

    # Confounds repeated along a leading batch axis: (batch_size, n_confounds, T).
    # drop_last=True guarantees every training batch has exactly batch_size samples.
    _, T = np.shape(conf)
    conf_batch = torch.tensor(np.repeat(np.asarray(conf)[np.newaxis], batch_size, axis=0),
                              dtype=torch.float32).to(device)

    model = cVAE(conf_batch, in_channels=4, in_dim=T, latent_dim=latent_dim,
                 beta=beta, gamma=gamma, delta=delta).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    track = {
        'l': [], 'kld_loss': [], 'recons_loss_roi': [], 'recons_loss_roni': [],
        'loss_recon_conf_s': [], 'loss_recon_conf_z': [], 'ncc_loss_tg': [], 'ncc_loss_bg': [],
        'ncc_loss_conf_s': [], 'ncc_loss_conf_z': [], 'smoothness_loss': [], 'recons_loss_fg': [],
        'r_ffa_reg': [], 'r_compcor_reg': [], 'r_TG_reg': [], 'r_FG_reg': [], 'r_BG_reg': [],
        'varexp': [], 'batch_varexp': [], 'ffa_io': [],
        'tg_mu_z': [], 'tg_log_var_z': [], 'tg_mu_s': [], 'tg_log_var_s': []
    }

    # Baseline correlations with the face regressor (constant across epochs).
    ffa_mean = ffa.mean(axis=0)
    if use_face_reg:
        if ffa_compcorr is not None and np.shape(ffa_compcorr) == ffa.shape:
            compcorr_mean = np.asarray(ffa_compcorr).mean(axis=0)
        else:
            compcorr_mean = ffa_mean  # no CompCor baseline available: treat as raw FFA
        base_r = _safe_corr(ffa_mean, face_reg)
        compcorr_r = _safe_corr(compcorr_mean, face_reg)
    else:
        base_r = compcorr_r = 0.0

    # FFA evaluation input does not change between epochs, so build it once.
    ffa_coords = gm_coords[ffa_survived]
    ffa_inputs_with_coords = _with_coords(ffa, ffa_coords).to(device)

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        for batch_idx, obs_batch, noise_batch in train_loader:
            idx = batch_idx.numpy()
            obs_with_coords = _with_coords(obs_batch, gm_coords[idx]).to(device)
            noise_with_coords = _with_coords(noise_batch, cf_coords[idx]).to(device)
            outputs_tg = model.forward_tg(obs_with_coords)
            outputs_bg = model.forward_bg(noise_with_coords)
            loss_dict = model.loss_function(*outputs_tg, *outputs_bg)
            loss = loss_dict['loss']
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Evaluate on the full FFA region after each epoch.
        # NOTE(review): cVAE was built with conf_batch of batch_size rows; whether forward_*
        # accepts n_ffa_voxels rows here depends on cVAE -- confirm.
        model.eval()
        with torch.no_grad():
            recon = _signal_channel(model.forward_tg(ffa_inputs_with_coords)[0].cpu().numpy())
            fg = _signal_channel(model.forward_fg(ffa_inputs_with_coords)[0].cpu().numpy())
            bg = _signal_channel(model.forward_bg(ffa_inputs_with_coords)[0].cpu().numpy())
        recon_mean = recon.mean(axis=0)
        fg_mean = fg.mean(axis=0)
        bg_mean = bg.mean(axis=0)
        if use_face_reg:
            r_tg = _safe_corr(recon_mean, face_reg)
            r_fg = _safe_corr(fg_mean, face_reg)
            r_bg = _safe_corr(bg_mean, face_reg)
        else:
            r_tg = r_fg = r_bg = 0.0
        # Variance explained by the reconstruction, relative to the FFA mean timecourse.
        sst = np.sum((ffa - ffa_mean) ** 2)
        sse = np.sum((ffa - recon) ** 2)
        varexp = float(1 - sse / (sst + 1e-8))
        c_io = _safe_corr(ffa_mean, recon_mean)

        track['l'].append(epoch_loss / len(train_loader))
        for key in _CVAE_LOSS_KEYS:
            track[key].append(_as_float(loss_dict[key]))
        track['r_ffa_reg'].append(base_r)
        track['r_compcor_reg'].append(compcorr_r)
        track['r_TG_reg'].append(r_tg)
        track['r_FG_reg'].append(r_fg)
        track['r_BG_reg'].append(r_bg)
        track['varexp'].append(varexp)
        track['batch_varexp'].append(varexp)  # kept for dashboard compatibility; same value as varexp
        track['ffa_io'].append(c_io)
        # TODO: latent summaries need the encoder's mu/log_var tensors, whose position in
        # forward_tg's outputs is not visible here; record NaN placeholders until wired up.
        for key in ('tg_mu_z', 'tg_log_var_z', 'tg_mu_s', 'tg_log_var_s'):
            track[key].append(float('nan'))
        print(f"Epoch {epoch}/{epochs} - Loss: {track['l'][-1]:.4f} - FFA vs reg: {track['r_ffa_reg'][-1]:.3f}, CompCor vs reg: {track['r_compcor_reg'][-1]:.3f}, FG vs reg: {r_fg:.3f}, BG vs reg: {r_bg:.3f}")

    outputs = {
        'recon': recon,             # reconstructed FFA signals, shape [n_ffa_voxels, T]
        'signal': fg,               # foreground (signal) component, shape [n_ffa_voxels, T]
        'noise': bg,                # background (noise) component, shape [n_ffa_voxels, T]
        'ffa': ffa,                 # original FFA signals (model input)
        'ffa_compcorr': ffa_compcorr if ffa_compcorr is not None else ffa,
        'face_reg': face_reg,
        'place_reg': globals().get('place_reg', None),
        'confounds': conf
    }
    return model, track, outputs

def plot_dashboard(track: dict, outputs: dict, output_dir: str = None):
    """
    Plot a multi-panel dashboard of training curves and FFA model outputs.

    Parameters:
      track (dict): Per-epoch history (lists keyed by metric name) as produced by train_cvae_model.
      outputs (dict): Post-training arrays; reads 'ffa', 'recon', 'signal', 'noise' and,
        when present, 'face_reg' and 'confounds'.
      output_dir (str, optional): If given, the figure is saved to <output_dir>/dashboard.png
        (the directory is created if it does not exist); otherwise it is shown with plt.show().
    """
    # 5 x 9 grid; unused panels are removed at the end.
    nrows, ncols = 5, 9
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))
    axes = axes.flatten()
    sp = 0

    def last(key):
        # Final value of a tracked series; NaN when the series is empty (e.g. zero epochs).
        values = track[key]
        return values[-1] if len(values) > 0 else float('nan')

    # Training-curve panels: (track key, title label) in display order.
    curve_panels = [
        ('l', "Total loss"),
        ('batch_varexp', "batch_varexp"),
        ('varexp', "FFA varexp"),
        ('ffa_io', "ffa_io (corr FFA vs recon)"),
        ('recons_loss_roi', "recons_loss_roi"),
        ('recons_loss_roni', "recons_loss_roni"),
        ('loss_recon_conf_s', "loss_recon_conf_s"),
        ('kld_loss', "kld_loss"),
        ('loss_recon_conf_z', "loss_recon_conf_z"),
        ('ncc_loss_tg', "ncc_loss_tg"),
        ('ncc_loss_bg', "ncc_loss_bg"),
        ('recons_loss_fg', "recons_loss_fg"),
        ('ncc_loss_conf_s', "ncc_loss_conf_s"),
        ('ncc_loss_conf_z', "ncc_loss_conf_z"),
        ('smoothness_loss', "smoothness_loss"),
    ]
    for key, label in curve_panels:
        axes[sp].plot(track[key])
        axes[sp].set_title(f"{label}: {last(key):.2f}")
        sp += 1

    # Example voxel: first FFA voxel vs its reconstruction (taken from the `outputs` argument).
    ffa = outputs['ffa']
    recon = outputs['recon']
    if ffa.shape[0] > 0 and recon.shape[0] > 0:
        axes[sp].plot(ffa[0], 'b-')
        axes[sp].plot(recon[0], 'g-')
        axes[sp].set_title("Batch example ROI vs recon")
    sp += 1

    # FFA region average timecourse vs model recon average
    axes[sp].plot(ffa.mean(axis=0), label='FFA (orig)')
    axes[sp].plot(recon.mean(axis=0), label='Recon')
    axes[sp].set_title("FFA AVG vs Recon AVG")
    axes[sp].legend(); sp += 1

    # FFA average vs model foreground (signal) and face task regressor
    face_reg = outputs.get('face_reg', None)
    axes[sp].plot(ffa.mean(axis=0), 'k-', label='FFA')
    axes[sp].plot(outputs['signal'].mean(axis=0), 'g-', label='Signal (FG)')
    if face_reg is not None:
        axes[sp].plot(face_reg, 'r--', label='Face_reg')
    axes[sp].set_title("FFA SIGNAL vs Face reg"); axes[sp].legend(); sp += 1

    # FFA average vs model background (noise) and face regressor
    axes[sp].plot(ffa.mean(axis=0), 'k-', label='FFA')
    axes[sp].plot(outputs['noise'].mean(axis=0), 'r-', label='Noise (BG)')
    if face_reg is not None:
        axes[sp].plot(face_reg, 'b--', label='Face_reg')
    axes[sp].set_title("FFA NOISE vs Face reg"); axes[sp].legend(); sp += 1

    # One example confound timecourse (the 3rd confound when there are at least three).
    conf = outputs.get('confounds', None)
    if conf is not None and conf.shape[0] >= 1:
        conf_idx = min(conf.shape[0] - 1, 2)
        axes[sp].plot(conf[conf_idx, :], 'k-', label='Confound actual')
        axes[sp].set_title("Confound example")
    sp += 1

    # Correlation of each model output with the face regressor, over epochs
    axes[sp].plot(track['r_ffa_reg'], 'k-', label='FFA raw')
    axes[sp].plot(track['r_TG_reg'], 'b-', label='Recon (TG)')
    axes[sp].set_title(f"R TG-REG final: {last('r_TG_reg'):.2f}")
    axes[sp].legend(); sp += 1
    axes[sp].plot(track['r_ffa_reg'], 'k-', label='FFA raw')
    axes[sp].plot(track['r_compcor_reg'], 'b-', label='FFA CompCor')
    axes[sp].plot(track['r_FG_reg'], 'g-', label='FG')
    axes[sp].set_title(f"R FG-REG final: {last('r_FG_reg'):.2f}")
    axes[sp].legend(); sp += 1
    axes[sp].plot(track['r_ffa_reg'], 'k-', label='FFA raw')
    axes[sp].plot(track['r_BG_reg'], 'r-', label='BG')
    axes[sp].set_title(f"R BG-REG final: {last('r_BG_reg'):.2f}")
    axes[sp].legend(); sp += 1

    # Remove the panels that were not filled.
    for j in range(sp, nrows * ncols):
        fig.delaxes(axes[j])
    fig.tight_layout()
    if output_dir:
        # The directory may not exist yet (e.g. on a fresh run, before outputs are saved).
        os.makedirs(output_dir, exist_ok=True)
        fig_path = os.path.join(output_dir, "dashboard.png")
        fig.savefig(fig_path)
        print(f"Dashboard plot saved to {fig_path}")
    else:
        plt.show()
    plt.close(fig)

def save_outputs(model: "cVAE", track: dict, outputs: dict, output_dir: str, *,
                 reference_img=None, voxel_indices=None):
    """
    Save model outputs and derivatives to files in the specified output directory.

    - Saves `track` and `outputs` as pickles (training_track.pkl, outputs.pkl).
    - Saves the model state dict (cvae_model_state.pth).
    - Best effort: writes outputs['signal'] as a 4D image (signal_denoised.nii.gz).

    Parameters:
      model: trained model; only `state_dict()` is used.
      track (dict): training history to pickle.
      outputs (dict): model outputs to pickle; 'signal' (n_vox, T) is written as an image.
      output_dir (str): destination directory (created if missing).
      reference_img (optional): 4D image providing the spatial grid and `new_image_like`
        (e.g. an ANTs image). Defaults to the module-level `epi_img` if defined.
      voxel_indices (optional): flat (C-order) indices into the 3D grid, one per row of
        outputs['signal']. Defaults to the True positions of a module-level `gm_flat`.

    The image step is skipped (with a printed message) when the reference image or voxel
    indices are unavailable, or when they do not match the signal array.
    """
    import pickle  # only needed here

    os.makedirs(output_dir, exist_ok=True)
    # Save training history and outputs as pickles
    track_path = os.path.join(output_dir, "training_track.pkl")
    outputs_path = os.path.join(output_dir, "outputs.pkl")
    with open(track_path, 'wb') as f:
        pickle.dump(track, f)
    with open(outputs_path, 'wb') as f:
        pickle.dump(outputs, f)
    print(f"Saved training track to {track_path}")
    print(f"Saved outputs to {outputs_path}")
    # Save model weights
    model_path = os.path.join(output_dir, "cvae_model_state.pth")
    torch.save({'model_state_dict': model.state_dict()}, model_path)
    print(f"Saved model state to {model_path}")

    # Save the denoised signal as a 4D image, placing each row at its own voxel.
    signal = outputs.get('signal')
    if reference_img is None:
        reference_img = globals().get('epi_img')
    if voxel_indices is None and globals().get('gm_flat') is not None:
        voxel_indices = np.flatnonzero(globals()['gm_flat'])
    if signal is None or reference_img is None or voxel_indices is None:
        print("Could not save NIfTI signal image: reference image or voxel indices unavailable "
              "(pass reference_img= and voxel_indices=)")
        return
    try:
        signal = np.asarray(signal, dtype=np.float32)
        if signal.ndim == 3:
            # NOTE(review): assumes channel 0 holds the signal when a channel axis remains.
            signal = signal[:, 0, :]
        voxel_indices = np.asarray(voxel_indices)
        if signal.shape[0] != voxel_indices.size:
            raise ValueError(f"signal has {signal.shape[0]} voxels but {voxel_indices.size} "
                             f"voxel indices were given")
        spatial_shape = tuple(reference_img.shape[:3])
        n_time = signal.shape[1]
        # (n_spatial_voxels, T) -> (X, Y, Z, T); C-order matches numpy's `.flatten()` of a 3D mask.
        volume = np.zeros((int(np.prod(spatial_shape)), n_time), dtype=np.float32)
        volume[voxel_indices] = signal
        volume = volume.reshape(*spatial_shape, n_time)
        signal_img = reference_img.new_image_like(volume)
        signal_img_path = os.path.join(output_dir, "signal_denoised.nii.gz")
        signal_img.to_file(signal_img_path)
        print(f"Saved denoised signal image to {signal_img_path}")
    except Exception as e:
        # Best effort: the pickles and weights above are already saved.
        print(f"Could not save NIfTI signal image: {e}")

In [None]:
# ===============================
# Main execution (if run as script)
# ===============================
if __name__ == "__main__":
    # Load and preprocess data
    obs_list, noi_list, conf, ffa, face_reg, place_reg, ffa_compcorr = load_data_and_preprocess(epi_path, anat_path, n_dummy)
    # face_reg, place_reg and ffa_compcorr are bound at module level here, which is where
    # train_cvae_model looks them up for its correlation metrics.
    # NOTE(review): train_cvae_model also reads gm_coords, cf_coords and ffa_survived, which
    # load_data_and_preprocess does not return -- TODO expose them so training can run.
    # Create the output directory up front: plot_dashboard saves into it before save_outputs runs.
    os.makedirs(ofn_root, exist_ok=True)
    # Train CVAE model
    model, track, outputs = train_cvae_model(obs_list, noi_list, conf, ffa, cvae_params)
    # Plot results dashboard and save it
    plot_dashboard(track, outputs, output_dir=ofn_root)
    # Save outputs and model
    save_outputs(model, track, outputs, output_dir=ofn_root)

In [3]:


# Subject-specific overrides of the configuration-cell parameters.
# Run this cell before the main-execution cell, otherwise the configuration-cell values are used.
epi_path = '../Data/020-fmriprepped/sub-NDARINV1H7JEJW1/ses-baselineYear1Arm1/func/sub-NDARINV1H7JEJW1_ses-baselineYear1Arm1_task-nback_run-02_space-MNI152NLin2009cAsym_res-2_desc-preproc_bold.nii.gz'
# TODO: set anat_path for this subject (left unset so the cell parses; the
# configuration-cell value is used until then).
# TODO: set n_dummy for this acquisition (left unset; the configuration-cell value is used).
