In [None]:
#run on hyak
import os
import json
import pandas as pd
import re
import numpy as np

# --- Configuration ---
# Root of the IFOCUS project tree on Hyak scratch storage.
project_root = '/gscratch/scrubbed/fanglab/xiaoqian/IFOCUS'
# Directory that holds the NDA definitions CSV and receives the output CSV.
definitions_dir = '/mmfs1/home/xxqian/'
# NIfTI tree that is walked for images; output paths are made relative to it.
data_dir = os.path.join(project_root, 'sourcedata/nii')
# Input definitions CSV and output extraction CSV, both under definitions_dir.
definitions_csv_path, output_csv_path = (
    os.path.join(definitions_dir, csv_name)
    for csv_name in ('image03_definitions.csv', 'NDA_image03_extracted.csv')
)


# --- Helper Functions ---

def get_nested_val(data, key, default=None):
    """Safely retrieve ``key`` from the top level of ``data`` or from ``data["global"]["const"]``.

    Args:
        data: Parsed JSON object. Anything that is not a dict yields ``default``.
        key: Key to look up.
        default: Value returned when the key is found in neither location.

    Returns:
        The top-level value if ``key`` is present there (even if that value is
        None). Otherwise the value under ``global.const`` if that section
        exists and contains ``key``. Otherwise ``default``.
    """
    if not isinstance(data, dict):
        return default
    if key in data:
        return data[key]
    # "global"/"const" come from arbitrary JSON; only descend into them when
    # they really are mappings, otherwise `in`/indexing can raise TypeError.
    global_section = data.get("global")
    const_section = global_section.get("const") if isinstance(global_section, dict) else None
    if isinstance(const_section, dict) and key in const_section:
        return const_section[key]
    return default

def check_transformation(filename, json_data):
    """Report whether an image looks spatially transformed.

    The image counts as transformed when the filename carries a BIDS
    ``space-`` entity, or when ``json_data`` is a dict containing either a
    ``TemplateSpace`` or a ``SourceFile`` key.

    Returns:
        The string "Yes" or "No".
    """
    named_in_space = "space-" in filename
    has_derivative_keys = isinstance(json_data, dict) and any(
        marker in json_data for marker in ("TemplateSpace", "SourceFile")
    )
    return "Yes" if (named_in_space or has_derivative_keys) else "No"

def get_experiment_id(filename):
    """Return the NDA experiment_id for a file based on its task label.

    Matching is case-insensitive and checked in order: ``task-rest`` gives
    2820, then ``task-selfother`` gives 2821. Any other name gives None.
    """
    lowered = filename.lower()
    for task_marker, experiment in (("task-rest", 2820), ("task-selfother", 2821)):
        if task_marker in lowered:
            return experiment
    return None

def calculate_image_orientation(bids_data):
    """
    Calculate the dominant acquisition plane (Axial, Coronal, Sagittal) from DICOM vectors.

    Reads ``ImageOrientationPatientDICOM`` (six numbers: row direction cosines
    followed by column direction cosines). The slice normal is their cross
    product. The plane is named after the patient axis the normal is most
    aligned with: x -> Sagittal, y -> Coronal, z -> Axial.

    Returns:
        "Sagittal", "Coronal" or "Axial". Returns None when the value is
        missing, not a 6-element numeric sequence, or gives a degenerate
        (zero or non-finite) normal.
    """
    iop = get_nested_val(bids_data, "ImageOrientationPatientDICOM")
    try:
        # len() is inside the try: the JSON value may be a non-sized scalar.
        if not iop or len(iop) != 6:
            return None
        vectors = np.asarray(iop, dtype=float)
    except (TypeError, ValueError):
        return None

    slice_normal = np.cross(vectors[0:3], vectors[3:6])
    # Parallel or zero vectors (or NaNs) give no usable normal; without this
    # check argmax would silently return 0 and report "Sagittal".
    if not np.all(np.isfinite(slice_normal)) or not np.any(slice_normal):
        return None

    dominant_axis = int(np.argmax(np.abs(slice_normal)))
    return ("Sagittal", "Coronal", "Axial")[dominant_axis]

def map_bids_to_nda_scan_type(bids_data, filename):
    """
    Map BIDS metadata and the filename to the STRICT NDA scan_type list.

    Ref: IFOCUS protocol PDF.

    Args:
        bids_data: Parsed JSON sidecar. ``ProtocolName`` is read from it via
            get_nested_val to refine the T1w/T2w subtypes.
        filename: Image path or name, matched case-insensitively.

    Returns:
        The first matching NDA scan_type string, checked in the order below.
        If nothing matches, the raw ``ProtocolName`` is returned; it may be
        None and is not guaranteed to be on the NDA list.
    """
    fname = filename.lower()
    # ProtocolName may be absent, JSON null, or non-string; coerce it so that
    # .upper() cannot raise AttributeError.
    protocol = str(get_nested_val(bids_data, "ProtocolName") or "").upper()
    # BIDS-style tokens of the path. Used for very short labels that would
    # otherwise match inside subject IDs or other words (e.g. "sub-pd01").
    tokens = set(re.split(r"[_./\\-]", fname))

    # 1. Functional
    if "bold" in fname or "func" in fname or "fmri" in fname:
        return "fMRI"

    # 2. Field Maps (Distortion Correction)
    if "fmap" in fname or "fieldmap" in fname or "epi" in fname:
        return "Field Map"

    # 3. Structural T1 (MP2RAGE does not contain the substring "MPRAGE")
    if "t1w" in fname:
        if "MEMPR" in protocol or "VNAV" in protocol or "MPRAGE" in protocol:
            return "MR structural (MPRAGE)"
        if "FLASH" in protocol:
            return "MR structural (FLASH)"
        if "MP2RAGE" in protocol:
            return "MR structural (MP2RAGE)"
        return "MR structural (T1)"

    # 4. Structural T2 / FLAIR / PD
    if "flair" in fname:
        return "MR: FLAIR"
    if "t2w" in fname:
        if "TSE" in protocol:
            return "MR structural (TSE)"
        return "MR structural (T2)"
    if "pd" in tokens or "pdw" in tokens:
        return "MR structural (PD)"

    # 5. Diffusion
    if "dwi" in fname or "diff" in fname:
        return "MR diffusion"

    # 6. Localizers / Scouts (should be filtered upstream, but mapped just in case)
    if "scout" in fname or "localizer" in fname or "setter" in fname:
        return "Localizer scan"

    # 7. Other specific modalities
    if "asl" in fname:
        return "ASL"
    if "pet" in fname:
        return "PET"
    if "mre" in fname:
        return "Magnetic Resonance Elastography (MRE)"
    if "mrs" in fname:
        return "Magnetic Resonance Spectroscopy(MRS)"

    return get_nested_val(bids_data, "ProtocolName")

def load_definitions(csv_path):
    """Load NDA element names and their JSON-key aliases from a definitions CSV.

    Args:
        csv_path: Path to a CSV with an ``ElementName`` column and an optional
            ``Aliases`` column of comma-separated alternative JSON keys.

    Returns:
        ``(required_fields, alias_map)``:
        - ``required_fields`` lists the element names in file order. Rows with
          a blank ElementName are skipped.
        - ``alias_map`` maps an element name to its non-empty aliases.
        If the file does not exist, an error is printed and ``([], {})`` is
        returned.

    Raises:
        KeyError: if the CSV has no ``ElementName`` column.
    """
    if not os.path.exists(csv_path):
        print(f"Error: Definitions file not found at {csv_path}")
        return [], {}
    df = pd.read_csv(csv_path)
    element_names = df['ElementName']
    raw_aliases = df['Aliases'] if 'Aliases' in df.columns else [None] * len(df)

    required_fields = []
    alias_map = {}
    for raw_name, raw_alias in zip(element_names, raw_aliases):
        # Rows of stray commas would otherwise become NaN "fields" and
        # NaN-named output columns.
        if pd.isna(raw_name) or not str(raw_name).strip():
            continue
        name = str(raw_name).strip()
        required_fields.append(name)
        if pd.notna(raw_alias):
            # Drop empty entries produced by trailing or doubled commas.
            aliases = [a.strip() for a in str(raw_alias).split(',') if a.strip()]
            if aliases:
                alias_map[name] = aliases
    print(f"Loaded {len(required_fields)} fields from definitions CSV.")
    return required_fields, alias_map

def extract_metadata_from_json(json_path, record, required_fields, alias_map):
    """Parse a BIDS JSON sidecar and fill the still-empty NDA fields of ``record``.

    Args:
        json_path: Path to the JSON sidecar to read.
        record: Dict of NDA field -> value for one image. It must contain
            'image_file', whose lowercased path is used to classify the scan.
            Fields that are already non-None are left untouched. The dict is
            mutated in place and also returned.
        required_fields: NDA element names to try to fill, processed in order.
        alias_map: Maps an NDA element name to alternative JSON keys. These are
            tried (via get_nested_val) when no dedicated rule produced a value.

    Returns:
        The same ``record`` dict. If the JSON cannot be opened or parsed, a
        message is printed and ``record`` is returned unchanged.

    Per-field resolution order:
    1. The field-specific rule in the if/elif chain below (several rules are
       hard-coded protocol constants), or the ``bids_map`` key lookup.
    2. Each alias in ``alias_map``.
    3. A direct lookup of the field name itself in the JSON.
    Only a non-None result is written.
    """
    try:
        with open(json_path, 'r') as f:
            bids_data = json.load(f)
    except Exception as e:
        # Best-effort: any read/parse failure is reported, and the record is
        # returned without JSON-derived values.
        print(f"  [!] Error reading JSON {os.path.basename(json_path)}: {e}")
        return record

    # NDA element name -> BIDS JSON key, read verbatim via get_nested_val.
    bids_map = {
        "mri_repetition_time_pd": "RepetitionTime",
        "mri_echo_time_pd":       "EchoTime",
        "flip_angle":             "FlipAngle",
        "magnetic_field_strength": "MagneticFieldStrength",
        "image_slice_thickness":   "SliceThickness",
        "scanner_manufacturer_pd": "Manufacturer",
        "scanner_type_pd":         "ManufacturersModelName",
        "scanner_software_versions_pd": "SoftwareVersions",
        "deviceserialnumber":      "DeviceSerialNumber",
        "patient_position":        "PatientPosition",
        "slice_timing":            "SliceTiming",
        "pixel_bandwidth":         "PixelBandwidth"
    }

    # Identify Scan Category from substrings of the image path. The first
    # matching category wins: functional, then structural, then fieldmap.
    fname = record['image_file'].lower()
    is_functional = False
    is_structural = False
    is_fieldmap = False
    
    if "bold" in fname or "func" in fname or "fmri" in fname:
        is_functional = True
    elif "t1w" in fname or "t2w" in fname or "anat" in fname or "flair" in fname:
        is_structural = True
    elif "fmap" in fname or "fieldmap" in fname or "epi" in fname:
        is_fieldmap = True

    # --- Pre-calculate Dimension ---
    # Functional and fieldmap scans are treated as 4D and structural scans as
    # 3D regardless of the JSON. Otherwise the rank of dcmmeta_shape is used
    # if present.
    dims = None
    if is_functional or is_fieldmap:
        dims = 4
    elif is_structural:
        dims = 3
    else:
        shape = get_nested_val(bids_data, "dcmmeta_shape")
        if shape: dims = len(shape)

    for field in required_fields:
        # Never overwrite a value that was already filled in.
        if field in record and record[field] is not None:
            continue
            
        val = None
        
        # --- DIMENSION LOGIC ---
        if field == "image_num_dimensions":
            val = dims

        # --- EXTENT 4 TYPE LOGIC ---
        elif field == "extent4_type":
            if dims is not None and dims > 3:
                val = "time"
            else:
                val = None

        # --- UNIT LOGIC ---
        # Spatial axes are always reported in millimetres; the 4th axis is
        # reported in seconds only for 4D images.
        elif field in ["image_unit1", "image_unit2", "image_unit3"]:
            val = "Millimeters"
        elif field == "image_unit4":
            if dims is not None and dims > 3:
                val = "Seconds"
            else:
                val = None

        # --- EXTENT LOGIC ---
        # NOTE(review): structural extents are fixed at 176x256x256 instead of
        # being read from the data; presumably the IFOCUS T1w matrix. Confirm.
        # Other scans use dcmmeta_shape, falling back to AcquisitionMatrix for
        # the first two axes.
        elif field == "image_extent1":
            if is_structural:
                val = 176 
            else:
                shape = get_nested_val(bids_data, "dcmmeta_shape")
                acq_mat = get_nested_val(bids_data, "AcquisitionMatrix")
                if shape and len(shape) > 0: val = shape[0]
                elif acq_mat and len(acq_mat) > 0: val = acq_mat[0]
            
        elif field == "image_extent2":
            if is_structural:
                val = 256 
            else:
                shape = get_nested_val(bids_data, "dcmmeta_shape")
                acq_mat = get_nested_val(bids_data, "AcquisitionMatrix")
                if shape and len(shape) > 1: val = shape[1]
                elif acq_mat and len(acq_mat) > 1: val = acq_mat[1]
            
        elif field == "image_extent3":
            if is_structural:
                val = 256 
            else:
                shape = get_nested_val(bids_data, "dcmmeta_shape")
                if shape and len(shape) > 2: val = shape[2]
            
        elif field == "image_extent4":
            if dims is not None and dims > 3:
                shape = get_nested_val(bids_data, "dcmmeta_shape")
                if shape and len(shape) > 3: 
                    val = shape[3]
            else:
                val = None 

        # --- RESOLUTION LOGIC ---
        # In-plane voxel size comes from PixelSpacing when available.
        # Otherwise protocol defaults are used: 1.0 for structural, 3.0 for
        # functional and fieldmap scans.
        elif field == "image_resolution1": # Voxel X
            spacing = get_nested_val(bids_data, "PixelSpacing")
            if spacing and len(spacing) > 0: 
                val = spacing[0]
            else:
                if is_structural: val = 1.0
                elif is_functional: val = 3.0
                elif is_fieldmap: val = 3.0

        elif field == "image_resolution2": # Voxel Y
            spacing = get_nested_val(bids_data, "PixelSpacing")
            if spacing and len(spacing) > 1: 
                val = spacing[1]
            else:
                if is_structural: val = 1.0
                elif is_functional: val = 3.0
                elif is_fieldmap: val = 3.0
                
        # The third axis reads SliceThickness, with the same defaults.
        elif field == "image_resolution3": # Voxel Z
            val = get_nested_val(bids_data, "SliceThickness")
            if val is None:
                if is_structural: val = 1.0
                elif is_functional: val = 3.0
                elif is_fieldmap: val = 3.0

        # NOTE(review): fixed at 1 for 4D images rather than taken from
        # RepetitionTime. Confirm this matches the protocol TR.
        elif field == "image_resolution4": 
            if dims is not None and dims > 3:
                val = 1  
            else:
                val = None 

        # --- SPECIAL HANDLING ---
        elif field == "image_orientation": 
            val = calculate_image_orientation(bids_data)
            
        # Hard-coded protocol field-of-view strings per scan category.
        elif field == "mri_field_of_view_pd":
            if is_functional or is_fieldmap: 
                val = "222.0x222.0x156.0 mm" 
            elif is_structural: 
                val = "176.0x256.0x256.0 mm"
                
        # Protocol-wide constants and values derived from the image path.
        elif field == "receive_coil": val = "HeadNeck_64"
        elif field == "photomet_interpret": val = "MONOCHROME2"
        elif field == "image_file_format": val = "NIFTI"
        elif field == "experiment_id": val = get_experiment_id(record['image_file'])
        elif field == "transformation_performed": val = check_transformation(record['image_file'], bids_data)
        elif field == "scan_type": val = map_bids_to_nda_scan_type(bids_data, record['image_file'])

        elif field in bids_map:
            val = get_nested_val(bids_data, bids_map[field])

        # Prefer AcquisitionMatrix as stored. Otherwise build a square
        # "NxN" string from AcquisitionMatrixPE.
        elif field == "acquisition_matrix":
             mat = get_nested_val(bids_data, "AcquisitionMatrix")
             if not mat:
                 pe = get_nested_val(bids_data, "AcquisitionMatrixPE")
                 if pe: mat = f"{pe}x{pe}" 
             val = mat

        # Aliases & Exact Match: fallbacks when no rule above produced a value.
        if val is None and field in alias_map:
            for alias in alias_map[field]:
                val = get_nested_val(bids_data, alias)
                if val is not None: break
        if val is None:
            val = get_nested_val(bids_data, field)

        if val is not None:
            record[field] = val
            
    return record

def process_nifti_file(nifti_path, required_fields, alias_map):
    filename = os.path.basename(nifti_path)
    directory = os.path.dirname(nifti_path)
    fname_lower = filename.lower()
    
    # --- FILTERS ---
    
    # 0. Global Filter: Exclude ALL Scouts / Localizers / Setters
    # Ref: PDF "anat-scout_acq-aa" and "anat-setter_acq-MEMPRSetter"
    if "scout" in fname_lower or "localizer" in fname_lower or "setter" in fname_lower:
        return None 
    
    # 1. Structural Filter: Only process if it contains "desc-defaced"
    if "t1w" in fname_lower:
        if "desc-defaced" not in filename:
            return None # Skip raw T1w
            
        # Logic to find the ORIGINAL JSON for the defaced file
        json_filename = filename.replace("_desc-defaced", "").replace(".nii.gz", ".json")
        
    else:
        # Standard logic for Functional / other files
        if filename.endswith('.nii.gz'):
            json_filename = filename.replace(".nii.gz", ".json")
        elif filename.endswith('.nii'):
            json_filename = filename.replace(".nii", ".json")
        else:
            json_filename = None

    # --- INITIALIZE RECORD ---
    record = {field: None for field in required_fields}
    
    rel_path = os.path.relpath(nifti_path, start=data_dir)
    formatted_path = rel_path.replace(os.sep, '/')
    
    record['ID'] = None
    record['visit'] = None
    record['image_file'] = formatted_path 

    sub_match = re.search(r'sub-([a-zA-Z0-9]+)', filename)
    if sub_match:
        extracted_id = sub_match.group(1)
        record['ID'] = extracted_id
        if 'src_subject_id' in record: record['src_subject_id'] = extracted_id
        if 'subjectkey' in record: record['subjectkey'] = extracted_id

    ses_match = re.search(r'ses-([a-zA-Z0-9]+)', filename)
    if ses_match:
        record['visit'] = ses_match.group(1)

    # --- PROCESS JSON ---
    if json_filename:
        json_path = os.path.join(directory, json_filename)
        if os.path.exists(json_path):
            record = extract_metadata_from_json(json_path, record, required_fields, alias_map)
        else:
            print(f"Warning: JSON not found for {filename} at {json_path}")
            
    return record

# --- Execution ---

required_fields, alias_map = load_definitions(definitions_csv_path)

if not required_fields:
    print("Warning: No fields loaded.")
else:
    # Make sure the identifier columns exist in every record. Each insert goes
    # to the front, so the leading order becomes image_file, visit, ID.
    for col in ['ID', 'visit', 'image_file']:
        if col not in required_fields:
            required_fields.insert(0, col)

    print(f"Scanning data directory: {data_dir}")
    all_extracted_data = []
    file_count = 0

    for root, dirs, files in os.walk(data_dir):
        # Sort in place so traversal, and therefore output row order, is
        # deterministic across filesystems and runs.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(('.nii', '.nii.gz')):
                continue
            nifti_path = os.path.join(root, file)
            record = process_nifti_file(nifti_path, required_fields, alias_map)
            if record is not None:
                file_count += 1
                all_extracted_data.append(record)

    if all_extracted_data:
        df_final = pd.DataFrame(all_extracted_data)

        # --- NEW HARDCODED COLUMNS ---
        # NOTE(review): image_description is set to 'fMRI' for every row,
        # including structural and field-map images. Confirm this is intended.
        df_final['image_description'] = 'fMRI'
        df_final['scan_object'] = 'Live'
        df_final['image_modality'] = 'MRI'

        # Priority columns first (when present), then the rest in existing order.
        priority = ['ID', 'image_file', 'scan_type', 'image_resolution1', 'image_unit4']
        cols = [c for c in priority if c in df_final.columns] + [c for c in df_final.columns if c not in priority]
        df_final = df_final[cols]

        df_final.to_csv(output_csv_path, index=False)
        print("\nProcessing Complete.")
        print(f"  - Processed {file_count} valid NIfTI image files.")
        print(f"  - Data saved to: {output_csv_path}")
        print("\nPreview:")
        # scan_type / image_resolution1 exist only if the definitions CSV lists
        # them; selecting missing columns would raise KeyError.
        preview_cols = [c for c in ['image_file', 'scan_type', 'image_resolution1'] if c in df_final.columns]
        print(df_final[preview_cols].head(10))
    else:
        print(f"No valid .nii or .nii.gz files found (or all were filtered out) in {data_dir}")