# EPINR Diffusion MRI Data Collation/Listing

In [None]:
!cd /home/tas6hh/Projects/mr-inr/notebooks/epinr/dmri/

In [None]:
%autoreload 2
# Imports
from pathlib import Path

# Computation & ML libraries.
import numpy as np
import pandas as pd

from data_utils import (
    PE_DIR_ALIASES,
    EMPTY_RUN_TOKEN,
    EMPTY_SESSION_TOKEN,
    dataset_table_cols,
)

# Show the column names imported from data_utils; the table-building cells
# create exactly one list per entry of `dataset_table_cols`.
print("Columns for dataset tables:\n", dataset_table_cols)

In [None]:
def parse_acqparams(f: Path) -> tuple[dict, ...]:
    """Parse an FSL-style acqparams.txt file into one dictionary per data row.

    Blank lines and lines whose first non-blank character is ``#`` are skipped.
    Every other line is split on whitespace: the first three fields are read as
    the integer ``pe_dir`` vector and the fourth as the total readout time.
    Any fields after the fourth are ignored.

    Parameters
    ----------
    f : Path
        Path to the acqparams file.

    Returns
    -------
    tuple[dict, ...]
        One dict per data row, in file order, with the keys ``"pe_dir"`` (the
        ``PE_DIR_ALIASES`` entry for the row's integer vector) and
        ``"total_readout_time_sec"`` (float). The tuple is empty when the file
        contains no data rows.

    Raises
    ------
    ValueError
        If a data row has fewer than four fields, or a field is not numeric.
    KeyError
        If a row's vector has no entry in ``PE_DIR_ALIASES``.
    """
    acqparams = []
    with open(f, "r") as f_in:
        # Stream the file line by line; the line number is tracked only so that
        # a malformed row can be reported with its location.
        for line_num, raw_line in enumerate(f_in, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            fields = line.split()
            if len(fields) < 4:
                raise ValueError(
                    f"{f}:{line_num}: expected at least 4 fields, "
                    f"got {len(fields)}: {line!r}"
                )
            # Plain Python ints keep the lookup key free of NumPy scalar types.
            pe_dir = tuple(int(v) for v in fields[:3])
            acqparams.append(
                {
                    "pe_dir": PE_DIR_ALIASES[pe_dir],
                    "total_readout_time_sec": float(fields[3]),
                }
            )
    return tuple(acqparams)

In [None]:
# Root directory of each dataset's `epinr_fmap_learning` derivatives, keyed by
# dataset name. `~` is expanded to the current user's home directory and the
# result is resolved to an absolute path.
dataset_root_dirs = {
    name: Path(location).expanduser().resolve()
    for name, location in {
        "mica_mics": "~/mnt/magpie/outputs/mica_mics/derivatives/epinr_fmap_learning",
        "vcu_ms": "~/mnt/magpie/outputs/vcu_ms_epinr/derivatives/epinr_fmap_learning",
    }.items()
}

In [None]:
# MICA-MICS dataset
# Builds one table row per <root>/sub-*/ses-*/dwi-<idx>/ directory and writes the
# table to data_tables/mica_mics_dmri_dataset_table.csv. File columns are stored
# as paths relative to the dataset root; unavailable columns hold "NULL".
dataset_name = "mica_mics"
data_table_f = Path(".") / "data_tables" / "mica_mics_dmri_dataset_table.csv"
data_table_f.parent.mkdir(parents=True, exist_ok=True)
# Column-oriented accumulator: one list per table column, appended in lockstep.
rows = {k: list() for k in dataset_table_cols}

dataset_root = dataset_root_dirs[dataset_name]
# Fail loudly instead of silently overwriting the CSV with an empty table when
# the dataset root is missing (e.g. the remote mount is not attached).
if not dataset_root.is_dir():
    raise FileNotFoundError(f"Dataset root directory not found: {dataset_root}")

# Every glob is sorted so that the row order of the written table is
# reproducible; glob results come back in arbitrary, filesystem-dependent order.
subj_dirs = sorted(dataset_root.glob("sub-*"))
for subj_dir in subj_dirs:
    subj_id = subj_dir.name
    for ses_dir in sorted(subj_dir.glob("ses-*")):
        ses_id = ses_dir.name
        run_id = EMPTY_RUN_TOKEN
        for dwi_dir in sorted(ses_dir.glob("dwi-*")):
            dwi_idx = int(dwi_dir.name.replace("dwi-", ""))
            # Only the first data row of acqparams.txt is used for this entry.
            acqparams = parse_acqparams(dwi_dir / "acqparams.txt")[0]

            rows["dataset_name"].append(dataset_name)
            rows["subj_id"].append(subj_id)
            rows["session_id"].append(ses_id)
            rows["run_id"].append(run_id)
            rows["dwi_idx"].append(dwi_idx)
            rows["dwi"].append(dwi_dir / "b0.nii.gz")
            rows["dwi_mask"].append(dwi_dir / "b0_mask.nii.gz")
            rows["pe_dir"].append(acqparams["pe_dir"])
            rows["total_readout_time_s"].append(acqparams["total_readout_time_sec"])
            rows["t1w_reg_dwi"].append(dwi_dir / "t1w.nii.gz")
            rows["t1w_mask"].append(dwi_dir / "t1w_mask.nii.gz")
            rows["mni2t1w_warp"].append(dwi_dir / "mni2t1w_ants_composite_warp.h5")
            rows["topup_displacement_hz"].append(
                dwi_dir / "topup_suscept_field_hz.nii.gz"
            )
            rows["topup_corrected_dwi"].append(
                dwi_dir / "applytopup_lin-jac_corrected_b0.nii.gz"
            )

            # Columns that are not populated by this cell.
            for k in (
                "fs_label",
                "t1w_wm_mask",
                "t1w_gm_mask",
                "t1w_csf_mask",
                "t1w2acpc_affine",
            ):
                rows[k].append("NULL")

            # The susceptibility atlas warped to this subject is optional:
            # record it when the file exists, otherwise store "NULL".
            suscept_atlas_f = (
                dwi_dir / "suscept-atlas-warped_suscept-field_dir-ap_mm.nii.gz"
            )
            rows["suscept_atlas_mm_dir_ap"].append(
                suscept_atlas_f if suscept_atlas_f.exists() else "NULL"
            )

            # Validate the row just appended. The length check runs first so a
            # column that was never filled is reported by name. Path entries must
            # exist on disk and are then stored relative to the dataset root.
            n_rows = len(rows["dataset_name"])
            for k, col in rows.items():
                assert (
                    len(col) == n_rows
                ), f"Column '{k}' has {len(col)} entries, expected {n_rows}"
                if isinstance(col[-1], Path):
                    assert col[-1].exists(), f"Missing file for '{k}': {col[-1]}"
                    col[-1] = str(col[-1].relative_to(dataset_root))
data_table = pd.DataFrame(rows)
data_table.to_csv(data_table_f, index=False)
print(f"Wrote dataset table to: {data_table_f}")

In [None]:
# VCU-MS adult dataset
# Builds one table row per <root>/{P_*,HC_*}/dwi-<idx>/ directory and writes the
# table to data_tables/vcu_ms_dmri_dataset_table.csv. File columns are stored as
# paths relative to the dataset root; unavailable columns hold "NULL".
dataset_name = "vcu_ms"
data_table_f = Path(".") / "data_tables" / "vcu_ms_dmri_dataset_table.csv"
data_table_f.parent.mkdir(parents=True, exist_ok=True)
# Column-oriented accumulator: one list per table column, appended in lockstep.
rows = {k: list() for k in dataset_table_cols}

dataset_root = dataset_root_dirs[dataset_name]
# Fail loudly instead of silently overwriting the CSV with an empty table when
# the dataset root is missing (e.g. the remote mount is not attached).
if not dataset_root.is_dir():
    raise FileNotFoundError(f"Dataset root directory not found: {dataset_root}")

subj_dirs = sorted(list(dataset_root.glob("P_*")) + list(dataset_root.glob("HC_*")))
for subj_dir in subj_dirs:
    subj_id = subj_dir.name
    # This dataset has no session or run level in its directory layout.
    ses_id = EMPTY_SESSION_TOKEN
    run_id = EMPTY_RUN_TOKEN
    # Sorted so that the row order of the written table is reproducible; glob
    # results come back in arbitrary, filesystem-dependent order.
    for dwi_dir in sorted(subj_dir.glob("dwi-*")):
        dwi_idx = int(dwi_dir.name.replace("dwi-", ""))
        # Entries without an MNI-to-T1w warp are skipped rather than treated as
        # an error.
        mni2t1w_warp = dwi_dir / "mni2t1w_ants_composite_warp.h5"
        if not mni2t1w_warp.exists():
            print(f"Skipping {dwi_dir} since mni2t1w warp not found at: {mni2t1w_warp}")
            continue
        # Only the first data row of acqparams.txt is used for this entry.
        acqparams = parse_acqparams(dwi_dir / "acqparams.txt")[0]

        rows["dataset_name"].append(dataset_name)
        rows["subj_id"].append(subj_id)
        rows["session_id"].append(ses_id)
        rows["run_id"].append(run_id)
        rows["dwi_idx"].append(dwi_idx)
        rows["dwi"].append(dwi_dir / "b0.nii.gz")
        rows["dwi_mask"].append(dwi_dir / "b0_mask.nii.gz")
        rows["pe_dir"].append(acqparams["pe_dir"])
        rows["total_readout_time_s"].append(acqparams["total_readout_time_sec"])
        # NOTE: an experimental variant pointed this column at
        # dwi_dir / "t2w.nii.gz" instead; the T1w image is the active choice.
        rows["t1w_reg_dwi"].append(dwi_dir / "t1w.nii.gz")
        rows["t1w_mask"].append(dwi_dir / "t1w_mask.nii.gz")
        rows["mni2t1w_warp"].append(mni2t1w_warp)
        rows["topup_displacement_hz"].append(dwi_dir / "topup_suscept_field_hz.nii.gz")
        rows["topup_corrected_dwi"].append(
            dwi_dir / "applytopup_lin-jac_corrected_b0.nii.gz"
        )

        # Columns that are not populated by this cell.
        for k in (
            "fs_label",
            "t1w_wm_mask",
            "t1w_gm_mask",
            "t1w_csf_mask",
            "t1w2acpc_affine",
        ):
            rows[k].append("NULL")

        # The susceptibility atlas warped to this subject is optional: record it
        # when the file exists, otherwise store "NULL".
        suscept_atlas_f = dwi_dir / "suscept-atlas-warped_suscept-field_dir-ap_mm.nii.gz"
        rows["suscept_atlas_mm_dir_ap"].append(
            suscept_atlas_f if suscept_atlas_f.exists() else "NULL"
        )

        # Validate the row just appended. The length check runs first so a column
        # that was never filled is reported by name. Path entries must exist on
        # disk and are then stored relative to the dataset root.
        n_rows = len(rows["dataset_name"])
        for k, col in rows.items():
            assert (
                len(col) == n_rows
            ), f"Column '{k}' has {len(col)} entries, expected {n_rows}"
            if isinstance(col[-1], Path):
                assert col[-1].exists(), f"Missing file for '{k}': {col[-1]}"
                col[-1] = str(col[-1].relative_to(dataset_root))
data_table = pd.DataFrame(rows)
data_table.to_csv(data_table_f, index=False)
print(f"Wrote dataset table to: {data_table_f}")

In [None]:
# Merge dataset tables
# Concatenates the per-dataset CSVs written above into a single table.
#
# `keep_default_na=False` is required: pandas' default NA markers include the
# literal string "NULL", which the per-dataset tables use as a placeholder.
# Without it those cells are read back as NaN and written to the merged CSV as
# empty fields. (The same would apply to the empty session/run tokens if their
# values are NA-like strings; those values are not visible in this notebook.)
datset_tables = [
    pd.read_csv(
        Path(".") / "data_tables" / "mica_mics_dmri_dataset_table.csv",
        keep_default_na=False,
    ),
    pd.read_csv(
        Path(".") / "data_tables" / "vcu_ms_dmri_dataset_table.csv",
        keep_default_na=False,
    ),
]
merged_dataset_table = pd.concat(datset_tables, ignore_index=True)
merged_data_table_f = Path(".") / "data_tables" / "merged_dmri_dataset_table.csv"
merged_dataset_table.to_csv(merged_data_table_f, index=False)
print(f"Wrote merged dataset table to: {merged_data_table_f}")