# Neuro-TabPFN v0.3 - Pipeline (Colab-ready)

Objetivo: usar **solo** el pipeline del repo `high-dimensional/individualized_prescriptive_inference` con sus embeddings (PCA/NMF **y VAE propio del repo**), manteniendo la resolución nativa MNI **91×109×91** y usando **todas las 4119 máscaras**.

Flujo:

**Resumen del flujo:**
1. Setup de entorno y rutas
2. Instalación de dependencias + clonación repo
3. Descarga/lectura de 4119 máscaras (lesiones) 
4. Representación (genera embeddings AE/VAE/NMF/PCA en k-folds)
5. Deficit modelling: `deficit_modelling.py`
6. Prescriptive simulations: `prescription.py` con `--use_vae True`
7. TabICL Two-Stage + Do-Loss training + evaluación
8. Guardado de salidas

## 1. Setup de entorno y rutas

In [1]:
import os
# Si usas Drive, pon USE_DRIVE=True; de lo contrario change "H:/My Drive" por tu directorio local
USE_DRIVE = False
DRIVE_ROOT = "/content/drive/MyDrive" if USE_DRIVE else "H:/My Drive"

ROOT = os.path.join(DRIVE_ROOT)
ROOT_DIR = os.path.join(ROOT, "Debbuging Neuro")
DATA_DIR = os.path.join(ROOT_DIR, "Data")
RESULTS_DIR = os.path.join(ROOT_DIR, "Results")
REPO_DIR = os.path.join(ROOT, "individualized_prescriptive_inference")
REPO_RESULTS = os.path.join(REPO_DIR, "results")

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(REPO_RESULTS, exist_ok=True)

os.chdir(ROOT)
    
print(f"Working dir: {ROOT_DIR}")

Working dir: H:/My Drive\Debbuging Neuro


In [3]:
%%capture
!pip install -r requirements.txt
#!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 

In [5]:
# IMPORTS
import glob
import zipfile
from urllib.request import urlretrieve
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from scipy.ndimage import zoom
from datasets import load_dataset
from tqdm import tqdm
from nilearn import datasets, image
import urllib
import os
import shutil
import gzip
import sys
import json, random, warnings, textwrap, subprocess
from typing import Tuple
import subprocess
import sys
import textwrap
import subprocess

try:
    import torch_directml
    dml = torch_directml.device()
except ImportError:
    dml = None
    
plt.style.use('seaborn-v0_8-whitegrid')
warnings.filterwarnings('ignore')

In [6]:
# =============================================================================
# CONFIG 
# =============================================================================
# Si usas mac descomentar el primer device y comentar el segundo device

class Config:
    SEED = 42
    
    # Device detection
    # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    try:
        import torch_directml
        DEVICE = torch_directml.device()
    except ImportError:
        DEVICE = torch.device(
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )
    # Trial
    SAMPLES = 10
    
    # Pipeline params (usados por representation.py)
    TARGET_SHAPE = (91, 109, 91)
    LATENT_DIM = 50
    K_FOLDS = 10
    
    # TabICL params (solo si usas TabICL después)
    D_MODEL = 128
    N_HEAD = 4
    N_LAYERS_COL = 2
    N_LAYERS_ROW = 4
    DIM_FEEDFORWARD = 512
    DROPOUT = 0.1
    LR_TABICL = 5e-4
    SYN_BATCH = 64
    SYN_SEQ = 48
    DO_STEPS = 400
    
    # SCM params para simulación causal
    SCM_EFFECT = 5.0
    SCM_WEIGHT_Z = 0.5
    SCM_WEIGHT_T = 2.0

cfg = Config()
random.seed(cfg.SEED)
np.random.seed(cfg.SEED)
torch.manual_seed(cfg.SEED)

print(f"Device: {cfg.DEVICE}")
print(f"Latent dim: {cfg.LATENT_DIM}, K-folds: {cfg.K_FOLDS}")

Device: privateuseone:0
Latent dim: 50, K-folds: 10


## 2. Instalación de dependencias + clonación repo

In [None]:
%%capture
# CLONAR REPO + REQS
%cd "$ROOT"
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/high-dimensional/individualized_prescriptive_inference.git "$REPO_DIR"
%cd $REPO_DIR

## 3. Descarga de 4119 máscaras (lesiones) + Otros dataset 

In [None]:
# DESCARGA Y PREPROCESADO DE MÁSCARAS
import os, zipfile, glob, shutil
from urllib.request import urlretrieve

LESIONS_ZIP = os.path.join(REPO_DIR, "lesions.zip")
URL = "https://github.com/high-dimensional/individualized_prescriptive_inference/raw/main/lesions.zip"

# Descargar
if not os.path.exists(LESIONS_ZIP):
    try:
        urlretrieve(URL, LESIONS_ZIP)
    except Exception:
        import subprocess
        subprocess.run(["wget", "-q", URL, "-O", LESIONS_ZIP], check=True)
        
# Descomprimir directamente en REPO_DIR
if not glob.glob(os.path.join(REPO_DIR, "*.nii*")):
    try:
        with zipfile.ZipFile(LESIONS_ZIP, "r") as zip_ref:
            zip_ref.extractall(REPO_DIR)
    except Exception:
        import subprocess
        subprocess.run(["unzip", "-q", LESIONS_ZIP, "-d", REPO_DIR], check=True)

# Verificar
LESIONES_PATH = os.path.join(REPO_DIR, "lesions")
DISCO_PATH = os.path.join(REPO_DIR, "disconnectomes")

lesion_files = sorted(glob.glob(os.path.join(LESIONES_PATH, "*.nii*")))

print("=" * 60)
print("Total máscaras:", len(lesion_files))

# Recorrer los primeros archivos para verificar su forma (shape)
for f in lesion_files[:3]: 
    img = nib.load(f)
    print(f"Archivo: {os.path.basename(f)}")
    print(f" - Shape: {img.shape}")
    print(f" - Affine (orientación):\n{img.affine}")
    print("-" * 30)

In [11]:
# Atlas Preprocessing for Stroke Lesion Analysis
# References: Liu et al. (2023). Digital 3D brain MRI arterial territories atlas. Scientific Data, 10(1), 1-12. https://doi.org/10.1038/s41597-022-01923-0

ATLAS_DIR = os.path.join(REPO_DIR, "atlases")
VASC_ATLAS_DIR = os.path.join(ATLAS_DIR, "vasc_atlas")
NITRC_ATLAS_DIR = os.path.join(DATA_DIR, "NITRC")

COMPRESSED_FILES = [
    "functional_parcellation_2mm.nii.gz",
    "icv_mask_2mm.nii.gz"
]

# =============================================================================
# Utility Functions
# =============================================================================

def load_and_resample_to_lesion_space(atlas_path: str, reference_path: str) -> Tuple[nib.Nifti1Image, np.ndarray, np.ndarray]:
    """
    Load NIfTI atlas and resample to match lesion mask space exactly.
    Handles 4D inputs by taking first volume.
    """
    atlas_img = nib.load(atlas_path)
    atlas_data = atlas_img.get_fdata()
    
    # Handle 4D: take first volume
    if atlas_data.ndim == 4:
        atlas_data = atlas_data[:, :, :, 0]
        atlas_img = nib.Nifti1Image(atlas_data, atlas_img.affine)
    
    # Load reference (lesion or ICV mask) for target space
    ref_img = nib.load(reference_path)
    
    # Resample atlas to reference space
    resampled = image.resample_to_img(
        atlas_img, 
        ref_img, 
        interpolation="nearest"
    )
    
    return resampled, resampled.get_fdata().astype(np.int16), resampled.affine


def lateralize_atlas(data: np.ndarray, n_labels: int = 4) -> np.ndarray:
    """Split atlas by hemisphere: left=[1,n], right=[n+1,2n]."""
    mid = data.shape[0] // 2
    out = np.zeros_like(data)
    out[:mid] = data[:mid]
    out[mid:] = np.where(data[mid:] > 0, data[mid:] + n_labels, 0)
    return out


def classify_circulation(data: np.ndarray) -> np.ndarray:
    """Map territories to anterior (1) vs posterior (2) circulation."""
    out = np.zeros_like(data)
    out[np.isin(data, [1, 2])] = 1  # ACA + MCA
    out[np.isin(data, [3, 4])] = 2  # PCA + VB
    return out

# =============================================================================
# Step 1: Sync Atlas Directory from Repository
# =============================================================================

source_atlas = os.path.join(REPO_DIR, "atlases")

if os.path.exists(source_atlas) and not os.path.exists(ATLAS_DIR):
    shutil.copytree(source_atlas, ATLAS_DIR)
    print(f"Synced atlas directory to {ATLAS_DIR}")
elif not os.path.exists(source_atlas):
    os.makedirs(ATLAS_DIR, exist_ok=True)
    print(f"Created empty atlas directory at {ATLAS_DIR}")
else:
    print(f"Atlas directory exists at {ATLAS_DIR}")

# =============================================================================
# Step 2: Decompress Required Files
# =============================================================================

for gz_name in COMPRESSED_FILES:
    gz_path = os.path.join(ATLAS_DIR, gz_name)
    nii_path = os.path.join(ATLAS_DIR, gz_name.replace(".gz", ""))
    
    if os.path.exists(gz_path) and not os.path.exists(nii_path):
        with gzip.open(gz_path, 'rb') as f_in, open(nii_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        print(f"Decompressed: {gz_name}")
    elif os.path.exists(nii_path):
        print(f"Already exists: {os.path.basename(nii_path)}")
    else:
        print(f"WARNING: Missing source file {gz_name}")

# =============================================================================
# Step 3: Build Vascular Territory Atlases (resampled to lesion space)
# =============================================================================

os.makedirs(VASC_ATLAS_DIR, exist_ok=True)

# Reference image for target space (ICV mask has correct shape)
REFERENCE_PATH = os.path.join(ATLAS_DIR, "icv_mask_2mm.nii")

if not os.path.exists(REFERENCE_PATH):
    print(f"ERROR: Reference image not found: {REFERENCE_PATH}")
else:
    # 3.1 Full parcellation (30 territories)
    full_atlas_path = os.path.join(NITRC_ATLAS_DIR, "ArterialAtlas.nii")
    if os.path.exists(full_atlas_path):
        full_img, _, _ = load_and_resample_to_lesion_space(full_atlas_path, REFERENCE_PATH)
        nib.save(full_img, os.path.join(VASC_ATLAS_DIR, "all_territories.nii"))
        print(f"all_territories.nii: {full_img.shape}")
    else:
        print(f"WARNING: {full_atlas_path} not found")

    # 3.2 Major territories (ACA, MCA, PCA, VB)
    level2_path = os.path.join(NITRC_ATLAS_DIR, "ArterialAtlas_level2.nii")
    if os.path.exists(level2_path):
        level2_img, level2_data, affine = load_and_resample_to_lesion_space(level2_path, REFERENCE_PATH)
        nib.save(level2_img, os.path.join(VASC_ATLAS_DIR, "major_arterial_territory.nii"))
        print(f"major_arterial_territory.nii: {level2_img.shape}")
        
        # 3.3 Lateralized version (L/R hemisphere split)
        lat_data = lateralize_atlas(level2_data, n_labels=4)
        nib.save(nib.Nifti1Image(lat_data, affine), os.path.join(VASC_ATLAS_DIR, "major_arterial_territory_lat.nii"))
        print(f"major_arterial_territory_lat.nii: labels {np.unique(lat_data)}")
        
        # 3.4 Anterior vs posterior circulation
        circ_data = classify_circulation(level2_data)
        nib.save(nib.Nifti1Image(circ_data, affine), os.path.join(VASC_ATLAS_DIR, "major_territories.nii"))
        print(f"major_territories.nii: anterior={np.sum(circ_data==1)}, posterior={np.sum(circ_data==2)} voxels")
    else:
        print(f"WARNING: {level2_path} not found")

print("\nAtlas preprocessing complete.")

# =============================================================================
# Step 4: Verify all shapes match
# =============================================================================

print(f"\n{'FILE':<45} | {'SHAPE':<15} | {'STATE'}")
print("-" * 75)

archivos_a_verificar = [
    os.path.join(VASC_ATLAS_DIR, "all_territories.nii"),
    os.path.join(VASC_ATLAS_DIR, "major_arterial_territory.nii"),
    os.path.join(VASC_ATLAS_DIR, "major_arterial_territory_lat.nii"),
    os.path.join(VASC_ATLAS_DIR, "major_territories.nii"),
    os.path.join(ATLAS_DIR, "functional_parcellation_2mm.nii"),
    os.path.join(ATLAS_DIR, "icv_mask_2mm.nii")
]

for path in archivos_a_verificar:
    nombre = os.path.basename(path)
    if os.path.exists(path):
        img = nib.load(path)
        shape = img.shape
        es_compatible = "OK" if shape == cfg.TARGET_SHAPE else "DISCREPANCIA"
        print(f"{nombre:<45} | {str(shape):<15} | {es_compatible}")
    else:
        print(f"{nombre:<45} | {'NO ENCONTRADO':<15} | REVISAR")

Atlas directory exists at H:/My Drive\individualized_prescriptive_inference\atlases
Already exists: functional_parcellation_2mm.nii
Already exists: icv_mask_2mm.nii
all_territories.nii: (91, 109, 91)
major_arterial_territory.nii: (91, 109, 91)
major_arterial_territory_lat.nii: labels [ 0  1  3  5  7  9 11 13]
major_territories.nii: anterior=50671, posterior=116476 voxels

Atlas preprocessing complete.

FILE                                          | SHAPE           | STATE
---------------------------------------------------------------------------
all_territories.nii                           | (91, 109, 91)   | OK
major_arterial_territory.nii                  | (91, 109, 91)   | OK
major_arterial_territory_lat.nii              | (91, 109, 91)   | OK
major_territories.nii                         | (91, 109, 91)   | OK
functional_parcellation_2mm.nii               | (91, 109, 91)   | OK
icv_mask_2mm.nii                              | (91, 109, 91)   | OK


## 4. Representación (genera embeddings NMF/PCA en k-folds)

In [None]:
"""
Representation Learning Module Execution
========================================
Executes dimensionality reduction pipeline (PCA, NMF, VAE) on lesion masks.

This step generates latent embeddings for downstream causal inference.
See Giles et al. (2025), Section 5.2 for methodology details.
"""
%cd {REPO_DIR}

# =============================================================================
# Configuration
# =============================================================================

REPRESENTATION_SCRIPT = os.path.join(REPO_DIR, "software", "representation.py")
OUTPUT_DIR = os.path.join(REPO_RESULTS, "representations")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# =============================================================================
# Build Command 
# =============================================================================
cmd_args = [
    sys.executable,
    REPRESENTATION_SCRIPT,
    "--lesionpath", LESIONES_PATH,
    "--discopath", DISCO_PATH,  
    "--savepath", OUTPUT_DIR,
    "--kfolds", "10",
    "--latent_components", "50",
    "--batch_size", "32",
    "--min_epoch", "16",
    "--max_epoch", "32",
    "--early_stopping_epochs", "4",
    "--n_samples", "100",
    "--run_vae", "True",   
    "--run_ae", "False",   
    "--run_nmf", "False",
    "--run_pca", "False",
    "--verbose", "True"
]

print(f"Executing: {' '.join(cmd_args)}")
print("-" * 60)

# =============================================================================
# Execute
# =============================================================================
try:
    result = subprocess.run(
        cmd_args,
        check=True,
        capture_output=False,
        text=True,
        env={**os.environ, "PYTHONUNBUFFERED": "1"}
    )
    print("Execution completed successfully.")
    if result.stdout:
        print(result.stdout)
        
except subprocess.CalledProcessError as e:
    print(f"\nProcess failed with exit code {e.returncode}")
    print("-" * 60)
    error_output = e.stderr or e.stdout
    if error_output:
        lines = error_output.strip().split('\n')
        traceback_start = next((i for i, l in enumerate(lines) if 'Traceback' in l), 0)
        print('\n'.join(lines[traceback_start:]))
    print("-" * 60)

H:\My Drive\individualized_prescriptive_inference
Executing: C:\ProgramData\anaconda3\envs\debugging\python.exe H:/My Drive\individualized_prescriptive_inference\software\representation.py --lesionpath H:/My Drive\individualized_prescriptive_inference\lesions --discopath H:/My Drive\individualized_prescriptive_inference\disconnectomes --savepath H:/My Drive\individualized_prescriptive_inference\results\representations --kfolds 10 --latent_components 50 --batch_size 32 --min_epoch 16 --max_epoch 32 --early_stopping_epochs 4 --n_samples 100 --run_vae True --run_ae False --run_nmf False --run_pca False --verbose True
------------------------------------------------------------


## 5. Deficit modelling: `deficit_modelling.py`

In [None]:
"""
Deficit Modelling Execution
===========================
Maps lesion representations to functional network disruptions.

Computes overlap between lesion masks and 16 NeuroQuery functional networks,
generating binary deficit labels for downstream prescriptive inference.

References:
    Giles et al. (2025), Section: Deficit Modelling
    Dockès et al. (2020), NeuroQuery functional parcellation
"""

%cd {REPO_DIR}

import subprocess
import sys
import os

# =============================================================================
# Configuration
# =============================================================================

DEFICIT_SCRIPT = os.path.join(REPO_DIR, "software", "deficit_modelling.py")
OUTPUT_DIR = os.path.join(REPO_RESULTS, "representations")

# Cross-validation and threshold parameters

ROI_THRESHOLD = 0.05  # Minimum overlap fraction to mark network as affected
thresh_str = str(ROI_THRESHOLD) if not isinstance(ROI_THRESHOLD, list) else str(ROI_THRESHOLD[0])


# =============================================================================
# Build Command
# =============================================================================

cmd_args = [
    sys.executable,
    DEFICIT_SCRIPT,
    "--path", OUTPUT_DIR,
    "--lesionpath", LESIONES_PATH,
    "--discopath", DISCO_PATH,  
    "--latent_list", "50",
    "--kfold_deficits", "10",
    "--roi_threshs", thresh_str,
    "--run_vae", "True",   
    "--run_ae", "False",   
    "--run_nmf", "False",
    "--run_pca", "False",
    #"--n_samples", "10", 
    "--verbose", "True" 
]


print(f"Executing deficit modelling with {len(LATENT_DIMS)} latent dimensions")
print(f"Parameters: K={K_FOLDS}, threshold={ROI_THRESHOLD}")

# =============================================================================
# Execute with Full Error Capture
# =============================================================================

try:
    result = subprocess.run(
        cmd_args,
        check=True,
        capture_output=False,
        text=True,
        env={**os.environ, "PYTHONUNBUFFERED": "1"}
    )
    print("Deficit modelling completed successfully.")
    if result.stdout:
        print(result.stdout)

except subprocess.CalledProcessError as e:
    error_output = e.stderr or e.stdout
    if error_output:
        lines = error_output.strip().split('\n')
        traceback_start = next((i for i, l in enumerate(lines) if 'Traceback' in l), 0)
        print('\n'.join(lines[traceback_start:]))
    else:
        print("No error output captured.")
    
    print("-" * 60)

## 6. Prescriptive simulations: `prescription.py` con `--use_vae True`

In [None]:
"""
Prescriptive Inference Simulations
==================================
Executes virtual clinical trials to evaluate treatment effect estimation
under varying confounding scenarios.

This module implements the InterSynth evaluation framework, testing
prescriptive models across 22,528 discrete DGP configurations.

References:
    Giles et al. (2025), Nature Communications - Sections: Virtual Trials, Prescriptive Inference
    
Parameters:
    - biasdegree (gamma): Confounding strength [0, 1]
    - te (beta): Treatment effect magnitude
    - re (alpha): Spontaneous recovery rate
    - deficits: 16 functional networks from NeuroQuery
"""

import subprocess
import sys
import os

# =============================================================================
# Configuration
# =============================================================================

PRESCRIPTION_SCRIPT = os.path.join(REPO_DIR, "software", "prescription.py")
PRESCRIPTION_DIR = os.path.join(REPO_RESULTS, "prescription")
REPRESENTATIONS_DIR = os.path.join(REPO_RESULTS, "representations")

thresh_str = str(ROI_THRESHOLD) if not isinstance(ROI_THRESHOLD, list) else str(ROI_THRESHOLD[0])

# =============================================================================
# Build Command
# =============================================================================

cmd_args = [
    sys.executable,
    PRESCRIPTION_SCRIPT,
    "--savepath", PRESCRIPTION_DIR,     # Where to save final results
    "--loadpath", REPRESENTATIONS_DIR,  # Where to load data (Deficit Modelling Output)
    
    # Experiment Configuration
    "--k", *[str(k) for k in range(10)],
    "--gene_or_receptor", "genetics", "receptor",
    "--lesion_or_disconnectome", "lesion", "disco", 
    "--lesion_deficit_thresh", thresh_str, thresh_str,
    
    # Simulation Parameters (Reduced for debug)
    "--deficits", "10", "11",         
    "--biasdegree", "0", "0.5",     
    "--biastype", "observed",        
    "--te", "0.5",                  
    "--re", "0.25",                 
    
    # Models and Representations
    "--bottlenecks", "0", "50",     
    "--simpleatlases", "major_territories", 
    "--simpleatlas_argmaxs", "True",
    "--vols", "True",
    "--centroids", "False",
    "--ml_models", "xgb", 
    
    # Which embeddings to use (Matches your previous steps)
    "--use_vae", "True",    
    "--use_ae", "False",    
    "--use_nmf", "False",
    "--use_pca", "False"
]


n_configs = (
    10 * 2 * 2 * # Folds * Deficits * Subdivisions
    2 * 2 * 1 * 1 # Inputs * Bias * TE * RE
)
print(f" Ejecutando Prescription para los 10 folds...")
print(f"   Inputs: lesion + disco")
print(f"   Input Dir: {REPRESENTATIONS_DIR}")
print(f"   Output Dir: {PRESCRIPTION_DIR}")
print(f"   Thresholds: {thresh_str}, {thresh_str}")

# =============================================================================
# Execute with Full Error Capture
# =============================================================================

try:
    result = subprocess.run(
        cmd_args,
        check=True,
        capture_output=True,
        text=True,
        env={**os.environ, "PYTHONUNBUFFERED": "1"}
    )
    print("Prescriptive simulations completed successfully.")
    if result.stdout:
        print(result.stdout[-2000:])  # Last 2000 chars to avoid overflow

except subprocess.CalledProcessError as e:
    print(f"\nProcess failed with exit code {e.returncode}")
    print("-" * 60)
    
    error_output = e.stderr or e.stdout
    if error_output:
        lines = error_output.strip().split('\n')
        traceback_start = next((i for i, l in enumerate(lines) if 'Traceback' in l), 0)
        print('\n'.join(lines[traceback_start:]))
    else:
        print("No error output captured.")
    
    print("-" * 60)

## 7. TabICL Two-Stage + Do-Loss training + evaluación

## 8. Guardado de salidas