In [None]:
import os
import re
import numpy as np
import pandas as pd
import nibabel as nib
from glob import glob
from scipy.ndimage import binary_closing, binary_fill_holes
from skimage.transform import resize
from PIL import Image

# ─── CONFIG ───────────────────────────────────────────────────────────
# Root of the preprocessed dataset; each site lives in its own sub-directory.
BASE_IN = "/Volumes/Samsung_PSSD_T7_Shield/Athena PreProccessd"
SITES = [
    "KKI",
    "NeuroIMAGE",
    "NYU",
    "OHSU",
    "Peking_1",
    "Peking_2",
    "WashU",
]
# Per-site phenotypic table: <BASE_IN>/<site>/<site>_phenotypic.csv
CSV_MAP = {
    site_name: os.path.join(BASE_IN, site_name, site_name + "_phenotypic.csv")
    for site_name in SITES
}

# Generated PNGs are written below <BASE_IN>/Slices/<site>/<subject>/.
BASE_OUT = os.path.join(BASE_IN, "Slices")
N_TIME = 50  # how many (middle) timepoints are exported per subject
N_SLICE = 50  # how many slice positions are sampled along each axis
OUT_SIZE = (224, 224)  # (rows, cols) every slice is resized to
# One PNG per (timepoint, slice position) pair, i.e. 2500 per subject.
EXPECTED_SLICES = N_TIME * N_SLICE

# ─── HELPERS ───────────────────────────────────────────────────────────
def pick_latest_run(files):
    """Return the NIfTI path with the highest (session, run) numbers.

    The numbers are parsed from a ``session_<n>_rest_<m>`` fragment in each
    file's base name and compared numerically (session first, then run).
    Paths without such a fragment rank lowest, as ``(-1, -1)``. When several
    paths share the top rank, the one appearing last in ``files`` wins.

    Args:
        files: Sequence of file paths.

    Returns:
        The selected path, or ``None`` when ``files`` is empty.
    """
    def key(f):
        # Match against the base name only: a directory component that happens
        # to contain the same pattern would otherwise be found first and give
        # every file in that directory an identical key.
        m = re.search(r"session_(\d+)_rest_(\d+)", os.path.basename(f))
        return (int(m.group(1)), int(m.group(2))) if m else (-1, -1)
    return sorted(files, key=key)[-1] if files else None

def make_brain_mask(vol4d):
    """Derive a 3D boolean mask from a 4D volume.

    A voxel seeds the mask when its mean along the fourth axis is strictly
    positive. The seed mask is then smoothed with two iterations of binary
    closing and any enclosed holes are filled.
    """
    positive_mean = np.mean(vol4d, axis=3) > 0
    closed = binary_closing(positive_mean, iterations=2)
    filled = binary_fill_holes(closed)
    return filled

def process_slice(slice_data, out_size):
    """Resize a 2D slice to ``out_size`` and zero everything outside its foreground.

    The foreground is the set of strictly positive pixels of ``slice_data``.
    It is resized separately with order-0 interpolation and thresholded at
    0.5, then used to blank the resized intensity image.
    """
    # Foreground of the input slice, resized on its own with order=0.
    foreground = (slice_data > 0).astype(float)
    keep = resize(foreground, out_size, order=0, preserve_range=True) > 0.5

    # Intensity image at the target size, keeping the original value range.
    image = resize(slice_data, out_size, preserve_range=True)

    # Pixels outside the resized foreground are forced to zero.
    image[~keep] = 0
    return image

# ─── MAIN PROCESSING ──────────────────────────────────────────────────
# For every site and every subject listed in the site's phenotypic CSV:
# pick one NIfTI run, mask and min-max normalise it, then export
# N_TIME x N_SLICE RGB PNGs whose channels are the axial, coronal and
# sagittal slices taken at the same sampling index.
for site in SITES:
    print(f"Processing site: {site}")
    csv_path = CSV_MAP[site]

    # Without its phenotypic table a site cannot be enumerated; skip it.
    if not os.path.exists(csv_path):
        print(f"⚠ CSV not found: {csv_path}")
        continue

    df = pd.read_csv(csv_path)
    subj_col = df.columns[0]  # First column contains subject IDs
    # NOTE(review): the IDs are parsed by pandas before being cast to str, so
    # leading zeros present in the CSV would be lost — confirm that the
    # resulting strings match the subject directory names for every site.

    for subj in df[subj_col].astype(str):
        subj_path = os.path.join(BASE_IN, site, subj)
        output_dir = os.path.join(BASE_OUT, site, subj)

        # Resume support: a subject that already has a full set of PNGs is skipped.
        if os.path.exists(output_dir):
            existing_files = glob(os.path.join(output_dir, "*.png"))
            if len(existing_files) >= EXPECTED_SLICES:
                print(f"✅ Skipping {site}/{subj} (already has {len(existing_files)} slices)")
                continue

        if not os.path.isdir(subj_path):
            print(f"⚠ Subject dir missing: {subj_path}")
            continue

        # Find and select the best run
        runs = glob(os.path.join(subj_path, "*.nii.gz"))
        if not runs:
            print(f"⚠ No NIFTI files for {site}/{subj}")
            continue

        chosen_file = pick_latest_run(runs)
        print(f"Processing: {site}/{subj} with {os.path.basename(chosen_file)}")

        # Inspect the image shape before reading voxel data, so subjects that
        # will be skipped do not pay for a full volume load.
        img = nib.load(chosen_file)
        img_shape = img.shape
        if len(img_shape) != 4:
            print(f"⚠ Expected a 4D image, got shape {img_shape} for {site}/{subj}")
            continue

        T = img_shape[3]
        if T < N_TIME:
            print(f"⚠ Insufficient timepoints ({T} < {N_TIME}) for {site}/{subj}")
            continue

        data = img.get_fdata()

        # Create and apply brain mask
        mask_3d = make_brain_mask(data)
        if not mask_3d.any():
            print(f"⚠ Empty brain mask for {site}/{subj}")
            continue
        data = data * mask_3d[..., None]  # Apply mask to all timepoints

        # Subject-level min-max normalisation over in-mask voxels only.
        brain_values = data[mask_3d]
        min_val, max_val = np.min(brain_values), np.max(brain_values)
        if max_val > min_val:
            data = (data - min_val) / (max_val - min_val)
            # The shift moves the zero background to -min_val / (max - min);
            # reset it so process_slice's "> 0" foreground test cannot pick
            # up background when the in-mask minimum is negative.
            data[~mask_3d] = 0

        # Select middle time frames
        start_t = (T - N_TIME) // 2
        time_points = range(start_t, start_t + N_TIME)

        # Sample N_SLICE positions per axis across the mask's extent. The mask
        # is non-empty here, so each axis has at least one occupied index.
        xs = np.where(mask_3d.any(axis=(1, 2)))[0]
        ys = np.where(mask_3d.any(axis=(0, 2)))[0]
        zs = np.where(mask_3d.any(axis=(0, 1)))[0]

        # astype(int) truncates, so positions may repeat when the extent
        # along an axis is smaller than N_SLICE.
        x_slices = np.linspace(xs.min(), xs.max(), N_SLICE).astype(int)
        y_slices = np.linspace(ys.min(), ys.max(), N_SLICE).astype(int)
        z_slices = np.linspace(zs.min(), zs.max(), N_SLICE).astype(int)

        # Prepare output directory
        os.makedirs(output_dir, exist_ok=True)

        # Process each timepoint and slice
        for t_idx, t in enumerate(time_points):
            for s_idx in range(N_SLICE):
                # Extract orthogonal slices
                xy_slice = data[:, :, z_slices[s_idx], t]  # Axial (XY)
                xz_slice = data[:, y_slices[s_idx], :, t]  # Coronal (XZ)
                yz_slice = data[x_slices[s_idx], :, :, t]  # Sagittal (YZ)

                # Process each slice with masking
                r = process_slice(xy_slice, OUT_SIZE)  # Red channel
                g = process_slice(xz_slice, OUT_SIZE)  # Green channel
                b = process_slice(yz_slice, OUT_SIZE)  # Blue channel

                # Combine into RGB and convert to 8-bit
                rgb = np.stack([r, g, b], axis=-1)
                rgb = (rgb * 255).clip(0, 255).astype(np.uint8)

                # Save as PNG
                fname = f"tp{t_idx:02d}_sl{s_idx:02d}.png"
                Image.fromarray(rgb).save(os.path.join(output_dir, fname))

        # Verify slice count after processing
        created_files = glob(os.path.join(output_dir, "*.png"))
        print(f"✅ Created {len(created_files)} slices for {site}/{subj}")

print("=== PROCESSING COMPLETE ===")

Processing site: KKI
✅ Skipping KKI/1018959 (already has 2500 slices)
✅ Skipping KKI/1019436 (already has 2500 slices)
✅ Skipping KKI/1043241 (already has 2500 slices)
✅ Skipping KKI/1266183 (already has 2500 slices)
✅ Skipping KKI/1535233 (already has 2500 slices)
✅ Skipping KKI/1541812 (already has 2500 slices)
✅ Skipping KKI/1577042 (already has 2500 slices)
✅ Skipping KKI/1594156 (already has 2500 slices)
✅ Skipping KKI/1623716 (already has 2500 slices)
✅ Skipping KKI/1638334 (already has 2500 slices)
✅ Skipping KKI/1652369 (already has 2500 slices)
✅ Skipping KKI/1686265 (already has 2500 slices)
✅ Skipping KKI/1692275 (already has 2500 slices)
✅ Skipping KKI/1735881 (already has 2500 slices)
✅ Skipping KKI/1779922 (already has 2500 slices)
✅ Skipping KKI/1842819 (already has 2500 slices)
✅ Skipping KKI/1846346 (already has 2500 slices)
✅ Skipping KKI/1873761 (already has 2500 slices)
✅ Skipping KKI/1962503 (already has 2500 slices)
✅ Skipping KKI/1988015 (already has 2500 slices)