# Create synthetic dataset

(To showcase the expected HDF5 file structure)

In [None]:
# Example: Creating a synthetic fMRI dataset for Brain-Semantoks model

# This cell demonstrates how to create a properly formatted HDF5 dataset 
# that can be used with the Brain-Semantoks model for pretraining or 
# downstream tasks (linear probe/finetuning).

# Dataset requirements:
# - TR = 2.0 seconds (0.5 Hz sampling rate)
# - ROI ordering: 7-network ordering for Schaefer400 parcellation (THIS IS DIFFERENT FROM SOME OTHER MODELS! The ROIs are the same for 7n and 17n, but the ordering is different.)
# See: https://github.com/ThomasYeoLab/CBIG/tree/v0.14.3-Update_Yeo2011_Schaefer2018_labelname/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/MNI/
# - Z-score normalization per ROI per subject
# - Bandpass filtering: 0.01-0.1 Hz (similar filters like 0.009-0.08 are fine too. Some small tests indicate unfiltered data works almost as well.)
# - HDF5 file structure matching the expected format

import numpy as np
import h5py
from scipy import signal

def bandpass_filter(data, lowcut=0.01, highcut=0.1, fs=0.5, order=5):
    """Zero-phase Butterworth band-pass filter applied along the last (time) axis.

    Args:
        data: Array whose last axis is time, e.g. (subjects, ROIs, time).
            Any number of leading axes is accepted.
        lowcut: Lower cutoff frequency in Hz.
        highcut: Upper cutoff frequency in Hz.
        fs: Sampling rate in Hz (1 / TR; TR = 2.0 s -> 0.5 Hz).
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        Filtered array with the same shape as ``data``. Floating-point inputs
        keep their dtype; integer inputs come back as float64 instead of
        being truncated to integers.

    Raises:
        ValueError: If the cutoffs do not satisfy 0 < lowcut < highcut < fs / 2.
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Cutoffs must satisfy 0 < lowcut < highcut < fs/2 = {nyquist} Hz; "
            f"got lowcut={lowcut}, highcut={highcut}"
        )
    b, a = signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')

    data = np.asarray(data)
    # filtfilt filters every 1-D slice along `axis` independently -- the same
    # result as looping over subjects and ROIs, but done in compiled code.
    filtered = signal.filtfilt(b, a, data, axis=-1)
    if np.issubdtype(data.dtype, np.floating):
        # filtfilt promotes to float64; keep the caller's float precision.
        filtered = filtered.astype(data.dtype, copy=False)
    return filtered

def zscore_per_roi(data):
    """Z-score every timeseries independently along the last (time) axis.

    For the (subjects, ROIs, time) arrays used here this normalizes each ROI
    of each subject to zero mean and unit (population) standard deviation.
    Any number of leading axes is accepted.

    Args:
        data: Array whose last axis is time.

    Returns:
        Array of the same shape. Floating-point inputs keep their dtype;
        integer inputs are promoted to float64 rather than truncated.
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)
    mean = data.mean(axis=-1, keepdims=True)
    std = data.std(axis=-1, keepdims=True)
    # The small epsilon keeps constant (zero-variance) ROIs finite: they map to 0.
    return (data - mean) / (std + 1e-8)

# Create synthetic dataset
n_subjects = 20
n_timepoints = 180  # TR=2.0, so 180 TRs = 6 minutes
n_rois_schaefer = 400
n_rois_tian = 50
n_rois_buckner = 7
tr = 2.0        # repetition time in seconds
fs = 1.0 / tr   # sampling rate in Hz (0.5 Hz for TR=2.0 s)

# Seeded generator so that re-running the notebook reproduces the same file.
rng = np.random.default_rng(seed=0)

# Generate random timeseries data N(0,1) for each atlas
print("Generating synthetic timeseries...")
schaefer_data = rng.standard_normal((n_subjects, n_rois_schaefer, n_timepoints))
tian_data = rng.standard_normal((n_subjects, n_rois_tian, n_timepoints))
buckner_data = rng.standard_normal((n_subjects, n_rois_buckner, n_timepoints))

# Apply bandpass filter (0.01-0.1 Hz, TR=2.0 -> fs=0.5 Hz)
print("Applying bandpass filter (0.01-0.1 Hz)...")
schaefer_data = bandpass_filter(schaefer_data, lowcut=0.01, highcut=0.1, fs=fs)
tian_data = bandpass_filter(tian_data, lowcut=0.01, highcut=0.1, fs=fs)
buckner_data = bandpass_filter(buckner_data, lowcut=0.01, highcut=0.1, fs=fs)

# Z-score per ROI per subject
print("Z-scoring per ROI per subject...")
schaefer_data = zscore_per_roi(schaefer_data)
tian_data = zscore_per_roi(tian_data)
buckner_data = zscore_per_roi(buckner_data)

# Create subject IDs (stored as UTF-8 bytes)
subject_ids = np.array([f'sub-SYNTHETIC{i:05d}'.encode('utf-8') for i in range(n_subjects)])

# Create labels for downstream tasks
# Continuous age
age = rng.uniform(18, 80, size=n_subjects)

# Age: 5 classes, obtained by binning the continuous age at its quintiles so
# the two age labels describe the same subjects consistently (classes 0..4).
age_quintile_edges = np.quantile(age, [0.2, 0.4, 0.6, 0.8])
age_5c = np.digitize(age, age_quintile_edges)

# Sex: binary
sex_bi = rng.integers(0, 2, size=n_subjects)

print("\nDataset shapes:")
print(f"  Schaefer400: {schaefer_data.shape}")
print(f"  Tian subcortical: {tian_data.shape}")
print(f"  Buckner7: {buckner_data.shape}")
print(f"  Subject IDs: {subject_ids.shape}")
print(f"  Labels (age_5c): {age_5c.shape}")
print(f"  Labels (sex_bi): {sex_bi.shape}")
print(f"  Labels (age): {age.shape}")

# Save to HDF5 file.
# Use a path relative to the working directory (not a user-specific absolute
# path) so the example runs on any machine; the inference cell further down
# opens the same relative path, "synthetic_dataset_example.h5".
output_path = "synthetic_dataset_example.h5"
print(f"\nSaving to: {output_path}")

atlas_arrays = (
    ('schaefer400', schaefer_data),  # Schaefer400 in 7-network ROI ordering
    ('tian3', tian_data),
    ('buckner7', buckner_data),
)
label_arrays = (
    ('age_5c', age_5c),
    ('sex_bi', sex_bi),
    ('age', age),
)

with h5py.File(output_path, 'w') as f:
    # One dataset per atlas inside the 'timeseries' group.
    # float16 is adequate because the data are z-scored (values of order 1).
    ts_group = f.create_group('timeseries')
    for atlas_name, atlas_data in atlas_arrays:
        ts_group.create_dataset(atlas_name, data=atlas_data.astype(np.float16), compression='gzip')

    # Subject identifiers
    f.create_dataset('long_subject_id', data=subject_ids)

    # Labels for downstream tasks, stored at the top level
    for label_name, label_data in label_arrays:
        f.create_dataset(label_name, data=label_data)

# Verify the saved file
with h5py.File(output_path, 'r') as f:
    print(f"Top-level keys: {list(f.keys())}")
    print(f"Timeseries keys: {list(f['timeseries'].keys())}")
    print("\nShapes:")
    for atlas_name, _ in atlas_arrays:
        print(f"  {atlas_name}: {f['timeseries'][atlas_name].shape}")
    print(f"  long_subject_id: {f['long_subject_id'].shape}")
    print(f"\nFirst 5 subject IDs: {f['long_subject_id'][:5]}")

# To use this dataset in config YAML files:

# For pretraining (specify in data.datasets):
#   - name: synthetic
#     data_path: /path/to/synthetic_dataset_example.h5
#     raw_signal_length: 180
#     train_subject_ids_path: /path/to/train_subject_ids.npy  # optional

# For downstream tasks (linear_probe.probe_datasets or finetune.probe_datasets):
#   - name: synthetic_age
#     data_path: /path/to/synthetic_dataset_example.h5
#     label_names: ['age_5c']
#     n_class: [5]
#     raw_signal_length: 180
#     probe_train_subject_ids_path: /path/to/subject_ids.npy  # optional
#     split_ratio: [0.7, 0.15, 0.15]  # train, val, test
#     stratify: ['age_5c']

#   - name: synthetic_sex  
#     data_path: /path/to/synthetic_dataset_example.h5
#     label_names: ['sex_bi']
#     n_class: [2]
#     raw_signal_length: 180
#     split_ratio: [0.7, 0.15, 0.15]
#     stratify: ['sex_bi']

# The model expects concatenated timeseries in this order:
#   [schaefer400, tian3, buckner7] -> total 457 ROIs
  
# With target_signal_length=100, the model processes 100-timepoint crops.

# Use model to obtain data representations

In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
import torch
import yaml
import numpy as np
import h5py

from simdino import SimDINOModel

class ModelInferenceWrapper:
    """Inference-only wrapper around a pretrained Brain-Semantoks ``SimDINOModel``.

    Builds the model from a parsed config dict, loads the checkpoint weights,
    and exposes helpers to extract fixed-size embeddings from fMRI timeseries.
    """

    def __init__(self, cfg, device='cuda'):
        """
        Build the model from a parsed config and load its checkpoint.

        Args:
            cfg: Parsed YAML config dict (e.g. from ``yaml.safe_load``) with
                'model', 'data', 'dino', 'ssl' and 'training' sections;
                ``cfg['training']['resume_checkpoint']`` must be the path to
                the checkpoint file.
            device: 'cuda' or 'cpu'

        Raises:
            ValueError: If the checkpoint path is unset or does not exist.
        """
        self.device = torch.device(device)
        self.config = cfg

        model_cfg = self.config['model']
        data_cfg = self.config['data']
        dino_cfg = self.config['dino']
        ssl_cfg = self.config['ssl']

        checkpoint_path = self.config['training'].get('resume_checkpoint')
        if not checkpoint_path or not os.path.exists(checkpoint_path):
            raise ValueError(f"Checkpoint not found at: {checkpoint_path}")
        print(f"Checkpoint path: {checkpoint_path}")

        # Build atlas names from explicit config. Atlases are visited in the
        # fixed order schaefer -> tian -> buckner; an atlas whose
        # '<type>_atlas' entry is missing or None is skipped.
        atlas_names = []
        atlas_network_counts = []
        total_rois = 0
        for atlas_type in ['schaefer', 'tian', 'buckner']:
            atlas_name = data_cfg.get(f'{atlas_type}_atlas')
            if atlas_name is not None:
                atlas_names.append(atlas_name)
                atlas_network_counts.append(data_cfg[f'{atlas_type}_networks'])
                total_rois += data_cfg[f'{atlas_type}_rois']

        max_spatial = data_cfg.get('max_spatial', False)
        min_spatial = data_cfg.get('min_spatial', False)

        # MLP width is 4x the embedding dim; the head count is embedding_dim / 64
        # (i.e. 64-dim heads, assuming the transformer splits the dim evenly).
        mlp_dim = int(model_cfg['embedding_dim'] * 4)
        heads = int(model_cfg['embedding_dim'] / 64)

        self.model = SimDINOModel(
            patch_size=data_cfg['patch_size'],
            do_masking=True,
            target_time_length=data_cfg['target_signal_length'],
            embedding_dim=model_cfg['embedding_dim'],
            depth=model_cfg['depth'],
            mlp_dim=mlp_dim,
            heads=heads,
            global_pooling=model_cfg['global_pooling'],
            layer_scale_init_value=model_cfg.get('layer_scale_init_value', None),
            network_data_path=data_cfg['network_map_path'],
            atlas_names=atlas_names,
            projection_hidden_dim=model_cfg['projection_hidden_dim'],
            projection_bottleneck_dim=model_cfg['projection_bottleneck_dim'],
            projection_nlayers=model_cfg['projection_nlayers'],
            base_teacher_momentum=dino_cfg['base_teacher_momentum'],
            coeff=dino_cfg['coeff'],
            mask_loss_weight=ssl_cfg['mask_loss_weight'],
            network_loss_weight=ssl_cfg.get('network_loss_weight', 0.0),
            backbone_type=model_cfg.get('backbone_type', 'cnn_tf'),
            # semantoks_config=model_cfg.get('semantoks_config', None),
            max_spatial=max_spatial,
            min_spatial=min_spatial,
            total_rois=total_rois,
            atlas_network_counts=atlas_network_counts,
            tokenizer_config=model_cfg.get('tokenizer', {}).get('config'),
            tokenizer_final_norm=model_cfg.get('tokenizer', {}).get('final_norm', 'layer')
        )

        # Load checkpoint weights.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # obtained from a trusted source.
        print(f"Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        state_dict = self._strip_module_prefix(checkpoint['model_state_dict'])

        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        print("Model loaded successfully")
        print(f"  Embedding dimension: {model_cfg['embedding_dim']}")

    @staticmethod
    def _strip_module_prefix(state_dict, prefix='module.'):
        """Remove a leading ``prefix`` from state-dict keys.

        Such a prefix is typically added when a model is wrapped in
        ``DataParallel``/``DistributedDataParallel`` during training. Only a
        *leading* prefix is removed, so keys that contain 'module.' elsewhere
        (e.g. a submodule attribute named ``module``) stay intact. When no key
        carries the prefix, the input mapping is returned unchanged.
        """
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def extract_embeddings(self, data, atlas_idx=0, feature_type='cls_avg',
                          use_teacher=True, return_dict=False):
        """
        Extract embeddings from fMRI data.

        Args:
            data: numpy array or torch tensor of shape (batch, channels, time),
                  e.g. (n_subjects, 457, 100) for Schaefer400 + Tian + Buckner7
                  concatenated along the ROI axis. A 2-D (channels, time) input
                  is treated as a single sample. Any input dtype is cast to
                  float32 before the forward pass.
            atlas_idx: Which atlas to use (0 for first atlas in config)
            feature_type: Type of features to extract:
                - 'cls': output['global_cls']
                - 'avg': mean of output['tokens'][:, 1:] (all tokens except CLS)
                - 'cls_avg': concatenation [avg, cls] along the feature axis
                  (averaged tokens FIRST, then the CLS embedding)
            use_teacher: Use teacher encoder (True) or student encoder (False)
            return_dict: If True, return the encoder's full output dict; if
                False, return features only

        Returns:
            embeddings: Tensor of shape (batch, feature_dim) on ``self.device``
                - feature_dim = embedding_dim for 'cls' or 'avg'
                - feature_dim = embedding_dim * 2 for 'cls_avg'
            OR the encoder output dict (read here via keys 'global_cls' and
            'tokens') if return_dict=True

        Raises:
            ValueError: If ``feature_type`` is not one of the values above.
        """
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        # Cast numpy AND tensor inputs to float32 (e.g. float16 from the HDF5
        # file, or a float64 tensor) so they match the model's default dtype.
        data = data.to(device=self.device, dtype=torch.float32)
        if data.ndim == 2:
            data = data.unsqueeze(0)

        self.model.eval()
        with torch.no_grad():
            encoder = self.model.teacher_encoder if use_teacher else self.model.student_encoder
            output = encoder(data, atlas=atlas_idx, mask=None)

            if return_dict:
                return output

            # Extract features based on type
            if feature_type == 'cls':
                features = output['global_cls']
            elif feature_type == 'avg':
                features = output['tokens'][:, 1:].mean(dim=1)
            elif feature_type == 'cls_avg':
                avg_tokens = output['tokens'][:, 1:].mean(dim=1)
                features = torch.cat([avg_tokens, output['global_cls']], dim=1)
            else:
                raise ValueError(f"Unknown feature_type: {feature_type}. "
                               f"Use 'cls', 'avg', or 'cls_avg'")

            return features

    def extract_embeddings_batch(self, data_loader, **kwargs):
        """Run ``extract_embeddings`` over every batch of ``data_loader``.

        Args:
            data_loader: Iterable of batches. A dict batch is read as
                ``batch['signal'][0][0]``; any other batch is passed through as-is.
            **kwargs: Forwarded to ``extract_embeddings``.

        Returns:
            CPU tensor with all batch embeddings concatenated along dim 0.
        """
        all_embeddings = []

        for batch in data_loader:
            if isinstance(batch, dict):
                # NOTE(review): assumes the project's collate places the
                # (batch, channels, time) tensor at batch['signal'][0][0] --
                # confirm against the DataLoader's collate function.
                data = batch['signal'][0][0]
            else:
                data = batch

            embeddings = self.extract_embeddings(data, **kwargs)
            all_embeddings.append(embeddings.cpu())

        return torch.cat(all_embeddings, dim=0)

    def get_feature_dim(self, feature_type='cls_avg'):
        """Get the dimensionality of extracted features (2x embedding_dim for 'cls_avg')."""
        base_dim = self.config['model']['embedding_dim']
        if feature_type == 'cls_avg':
            return base_dim * 2
        else:
            return base_dim

Download the model from: https://huggingface.co/SamGijsen/Brain-Semantoks

Then, fill in the path to the config and checkpoint files below.

In [None]:
cfg_path = 'path/to/downloaded/config_used.yaml'
ckpt_path = 'path/to/downloaded/brainsemantoks_ckpt_epoch_100.pth'

# Parse the config inside a `with` block so the file handle is closed.
with open(cfg_path, 'r', encoding='utf-8') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
cfg["training"]["resume_checkpoint"] = ckpt_path
cfg["data"]["network_map_path"] = './network_mapping.npz'

model_wrapper = ModelInferenceWrapper(
    cfg=cfg,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

# Load and concatenate timeseries from synthetic dataset
with h5py.File("synthetic_dataset_example.h5", 'r') as f:
    data = np.concatenate([
        f['timeseries/schaefer400'][:],
        f['timeseries/tian3'][:],
        f['timeseries/buckner7'][:]
    ], axis=1)

print(f"Data shape: {data.shape}")  # (n_subjects, 457 ROIs, n_timepoints)

# Extract cls_avg embeddings from the first `target_signal_length` timepoints
# (100 for the released model). Reading the length from the config keeps the
# crop in sync with the checkpoint. See the task-based fMRI application in our
# paper if you'd like to use the model for shorter sequences, blocks, or trials!
target_len = cfg["data"]["target_signal_length"]
if data.shape[-1] < target_len:
    raise ValueError(
        f"Need at least {target_len} timepoints per scan, got {data.shape[-1]}"
    )
embeddings = model_wrapper.extract_embeddings(
    data[:, :, :target_len],
    feature_type='cls_avg',
    use_teacher=True
)

# Expected: (n_subjects, model_wrapper.get_feature_dim('cls_avg')), i.e. 2 * embedding_dim
print(f"Embeddings shape: {embeddings.shape}")