# XNAT + xnatpy Workflow Template

This notebook demonstrates best practices for working with an XNAT instance using `xnatpy`:

1. Uploading data to XNAT
2. Exploring and understanding the XNAT data hierarchy
3. Filtering available data based on DICOM metadata
4. Downloading filtered data for local processing

The notebook assumes you have:

- A reachable XNAT server (URL)
- Credentials with appropriate project permissions
- `xnat` (xnatpy) and `pandas`, plus optionally `pydicom`, installed in your environment
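The code cells below import `xnat` and `pandas` unconditionally. A quick check before running them can save you from a confusing traceback. This is a minimal sketch that uses only the standard library. `check_packages` is a hypothetical helper, and the module names assume xnatpy is importable as `xnat`, as it is in the cells below.

```python
import importlib.util
from typing import List, Sequence, Tuple


def check_packages(
    required: Sequence[str], optional: Sequence[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing_required, missing_optional) top-level module names."""
    missing_required = [m for m in required if importlib.util.find_spec(m) is None]
    missing_optional = [m for m in optional if importlib.util.find_spec(m) is None]
    return missing_required, missing_optional


# xnatpy is imported as `xnat`; pandas is used for the summary tables.
missing_req, missing_opt = check_packages(["xnat", "pandas"], ["pydicom"])
print("Missing required:", missing_req or "none")
print("Missing optional:", missing_opt or "none")
```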


## 0. Session setup: connect to XNAT with xnatpy

We create a reusable helper to open a connection using either:

- Explicit arguments, or
- Environment variables:
  - `XNAT_HOST`
  - `XNAT_USER`
  - `XNAT_PASS`
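When any of these variables is missing, `get_xnat_session()` below raises one error that names all three. To find out exactly which one is unset, you can use a small check like the hypothetical `missing_xnat_env_vars()` below. It also treats empty strings as unset, which is stricter than the `is None` check in `get_xnat_session()`.

```python
import os
from typing import List, Mapping, Optional

XNAT_ENV_VARS = ("XNAT_HOST", "XNAT_USER", "XNAT_PASS")


def missing_xnat_env_vars(environ: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the XNAT_* variable names that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [name for name in XNAT_ENV_VARS if not environ.get(name)]


missing = missing_xnat_env_vars()
print("Not set:", ", ".join(missing) if missing else "none")
```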


In [None]:
import os
from pathlib import Path
from typing import Optional

import xnat  # xnatpy


def get_xnat_session(
    host: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
):
    """Create and return an XNAT session.

    Uses environment variables if explicit args are not supplied.
    """
    host = host or os.environ.get("XNAT_HOST")
    user = user or os.environ.get("XNAT_USER")
    password = password or os.environ.get("XNAT_PASS")

    if host is None or user is None or password is None:
        raise ValueError(
            "XNAT_HOST, XNAT_USER, XNAT_PASS must be set in the environment or "
            "passed explicitly to get_xnat_session()."
        )

    session = xnat.connect(host, user=user, password=password)
    return session

# Example usage (uncomment when ready):
# with get_xnat_session() as session:
#     print(session.projects.keys())


## 1. Uploading data to XNAT with xnatpy

The most robust way to upload image sessions from a script is to send XNAT a ZIP archive that contains either:

1. A valid XNAT image session layout, or
2. Raw DICOM files that the XNAT import service can interpret.

We use the `import_` service from xnatpy, which wraps XNAT's Image Session Import Service.
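If your data starts as a folder of DICOM files, you can build the second kind of archive (a ZIP of raw DICOM files) with Python's standard `zipfile` module. This is a minimal sketch that assumes every file under the source folder belongs in the upload. The `pattern` argument is a hypothetical convenience for narrowing that down. `zip_dicom_dir` and the `data/subject_001_dicom` folder in the usage comment are illustrative names. Keep the output archive outside the source folder so that a re-run does not pick it up.

```python
import zipfile
from pathlib import Path


def zip_dicom_dir(dicom_dir: Path, zip_path: Path, pattern: str = "*") -> int:
    """Archive every file under dicom_dir matching `pattern` into zip_path.

    Archive member names are relative to dicom_dir. Returns the file count.
    """
    dicom_dir, zip_path = Path(dicom_dir), Path(zip_path)
    zip_path.parent.mkdir(parents=True, exist_ok=True)
    # Collect the file list before opening the archive for writing.
    files = sorted(p for p in dicom_dir.rglob(pattern) if p.is_file())
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for f in files:
            zf.write(f, arcname=f.relative_to(dicom_dir).as_posix())
    return len(files)


# Example (hypothetical source folder):
# n_files = zip_dicom_dir(Path("data") / "subject_001_dicom",
#                         Path("data") / "subject_001_session.zip")
```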


In [None]:
# 1.1 Configure basic upload parameters

XNAT_PROJECT_ID = "MY_PROJECT_ID"   # XNAT project ID (not label)
XNAT_SUBJECT_LABEL = "subject_001"  # existing or new subject label
LOCAL_ZIP_PATH = Path("data") / "subject_001_session.zip"  # local zip of DICOMs

print("Local upload archive:", LOCAL_ZIP_PATH)

# 1.2 Upload ZIP to XNAT using the import service
#
# This will:
# - Upload the archive
# - Let XNAT parse the image session
# - Attach it under the given project/subject
#
# You can also target the prearchive first (safer, because sessions can be reviewed there before archiving) by passing destination='/prearchive'.

# with get_xnat_session() as session:
#     if not LOCAL_ZIP_PATH.exists():
#         raise FileNotFoundError(f"ZIP not found: {LOCAL_ZIP_PATH}")
#
#     print("Uploading to XNAT...")
#     imported = session.services.import_(
#         str(LOCAL_ZIP_PATH),
#         project=XNAT_PROJECT_ID,
#         subject=XNAT_SUBJECT_LABEL,
#         # destination='/prearchive',  # uncomment to use the prearchive
#     )
#     print("Import complete:", imported)

print("Upload code is ready. Configure paths, uncomment, and run when you are ready to upload.")
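To upload several sessions, you can run the same `import_` call once per archive. The sketch below assumes a hypothetical naming convention, `<subject_label>_session.zip`, and derives each subject label from the file name. `plan_batch_upload` is an illustrative helper, not part of xnatpy.

```python
from pathlib import Path
from typing import Dict


def plan_batch_upload(zip_dir: Path, suffix: str = "_session.zip") -> Dict[str, Path]:
    """Map subject label -> archive path for each '<label><suffix>' file in zip_dir.

    The '<subject_label>_session.zip' convention is an assumption; adapt the
    suffix (or the parsing) to your own file names.
    """
    plan: Dict[str, Path] = {}
    for zip_path in sorted(Path(zip_dir).glob(f"*{suffix}")):
        label = zip_path.name[: -len(suffix)]
        if label:  # skip a file named exactly like the suffix
            plan[label] = zip_path
    return plan


# Example (uncomment when connected); one import call per archive:
# with get_xnat_session() as session:
#     for subject_label, zip_path in plan_batch_upload(Path("data")).items():
#         session.services.import_(str(zip_path), project=XNAT_PROJECT_ID,
#                                  subject=subject_label)
```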


## 2. Organize and explore the XNAT data structure

XNAT organizes imaging data as:

- Project → Subjects → Experiments/Sessions → Scans → Resources → Files

This section shows how to:

- Navigate through projects, subjects, experiments, scans
- Summarize structure into a small table for inspection
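To see the shape of the table that `summarize_project_structure()` builds without a live server, the sketch below walks a plain nested dictionary. The dictionary is a hypothetical stand-in for `project.subjects` → `experiments` → `scans`. The labels and scan types are made up for illustration. The loop structure mirrors the xnatpy version and produces one record per scan.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for one project:
# subject label -> experiment label -> list of (scan id, scan type)
Hierarchy = Dict[str, Dict[str, List[Tuple[str, str]]]]


def flatten_hierarchy(project_id: str, tree: Hierarchy) -> List[dict]:
    """Return one record per scan, like the nested loops over an xnatpy project."""
    records = []
    for subject_label, experiments in tree.items():
        for experiment_label, scans in experiments.items():
            for scan_id, scan_type in scans:
                records.append(
                    {
                        "project_id": project_id,
                        "subject_label": subject_label,
                        "experiment_label": experiment_label,
                        "scan_id": scan_id,
                        "scan_type": scan_type,
                    }
                )
    return records


example_tree: Hierarchy = {
    "subject_001": {"subject_001_MR1": [("1", "localizer"), ("2", "T1w")]},
    "subject_002": {"subject_002_MR1": [("1", "T2w")]},
}
rows = flatten_hierarchy("MY_PROJECT_ID", example_tree)
print(len(rows))  # 3
```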


In [None]:
import pandas as pd

def summarize_project_structure(session: xnat.XNATSession, project_id: str) -> pd.DataFrame:
    """Summarize subjects, experiments, and scans for a project into a DataFrame."""
    if project_id not in session.projects:
        raise KeyError(f"Project '{project_id}' not found. Available: {list(session.projects.keys())}")

    project = session.projects[project_id]

    records = []
    for subj in project.subjects.values():
        for exp in subj.experiments.values():
            # Non-imaging experiments (e.g. subject assessors) have no scans
            for scan in getattr(exp, "scans", {}).values():
                records.append(
                    {
                        "project_id": project_id,
                        "subject_label": subj.label,
                        "experiment_label": exp.label,
                        "scan_id": scan.id,
                        "scan_type": getattr(scan, "type", None),
                        "series_description": getattr(scan, "series_description", None),
                    }
                )

    df = pd.DataFrame.from_records(records)
    return df

# Example usage (uncomment when connected):
# with get_xnat_session() as session:
#     df_structure = summarize_project_structure(session, XNAT_PROJECT_ID)
# df_structure.head()


In [None]:
# 2.1 Example: display basic stats for the project

# with get_xnat_session() as session:
#     df_structure = summarize_project_structure(session, XNAT_PROJECT_ID)
#
# display(df_structure.head())
# print("\nNumber of subjects:", df_structure["subject_label"].nunique())
# print("Number of experiments:", df_structure["experiment_label"].nunique())
# print("Number of scans:", len(df_structure))
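You can also compute the same counts from a list of per-scan records in plain Python, for example from `df_structure.to_dict("records")`. One caveat applies to both versions: the table has one row per scan, so subjects and experiments without any scans are not counted. `structure_stats` is a hypothetical helper sketch.

```python
from typing import Dict, Iterable, Mapping


def structure_stats(records: Iterable[Mapping[str, str]]) -> Dict[str, int]:
    """Count distinct subjects, experiments, and scans in per-scan records."""
    records = list(records)
    return {
        "subjects": len({r["subject_label"] for r in records}),
        # Key experiments by (subject, label) as a defensive choice.
        "experiments": len({(r["subject_label"], r["experiment_label"]) for r in records}),
        "scans": len(records),
    }


# Example (uncomment when df_structure exists):
# print(structure_stats(df_structure.to_dict("records")))
```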


## 3. Filter data based on DICOM metadata

XNAT exposes some DICOM-derived fields directly on scans (e.g. series description, modality),
and xnatpy also provides helpers to access DICOM headers without downloading all pixel data:

- `scan.dicom_dump()` – uses XNAT's DICOM dump service (headers truncated to 64 chars)
- `scan.read_dicom(read_pixel_data=False)` – reads the header of a single file in the series with `pydicom`, skipping pixel data
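`collect_scan_metadata()` below identifies header fields by `(group, element)` tuples. The DICOM standard and most viewers instead print tags as `(gggg,eeee)` in hexadecimal. The two small helpers below (hypothetical names) convert between the notations, which makes it easier to cross-check a tag against the standard. pydicom also accepts keywords, e.g. `dataset.get("SeriesDescription")`, which returns the value itself rather than a data element and avoids hexadecimal tags altogether.

```python
from typing import Tuple


def format_tag(tag: Tuple[int, int]) -> str:
    """Render a (group, element) tuple as '(gggg,eeee)' in upper-case hex."""
    group, element = tag
    return f"({group:04X},{element:04X})"


def parse_tag(text: str) -> Tuple[int, int]:
    """Parse '(gggg,eeee)' or 'gggg,eeee' back into a (group, element) tuple."""
    group_hex, element_hex = text.strip().strip("()").split(",")
    return int(group_hex, 16), int(element_hex, 16)


print(format_tag((0x0008, 0x103E)))  # (0008,103E)
print(parse_tag("(0018,0081)") == (0x0018, 0x0081))  # True
```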

Typical workflow:

1. Iterate scans in a project.
2. Inspect selected DICOM header tags (e.g. SeriesDescription, BodyPartExamined, Modality).
3. Build a filtered list of scans that match your criteria (e.g. only T1-weighted MR, only breast, etc.).
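Step 3 can be expressed as a set of field/substring criteria applied to each header. The sketch below is a plain-Python version of that idea under one assumed rule: a scan matches when every requested value occurs, case-insensitively, in the corresponding header field, and a missing field means no match. `matches`, `select_scans`, and the sample headers are hypothetical. Substring matching is deliberately loose (e.g. `"T1"` also matches `"T1 POST"`), so tighten it if your series descriptions are ambiguous.

```python
from typing import Iterable, List, Mapping


def matches(header: Mapping[str, str], criteria: Mapping[str, str]) -> bool:
    """True if every criterion occurs, case-insensitively, in its header field."""
    for field, wanted in criteria.items():
        value = header.get(field)
        if value is None or wanted.lower() not in str(value).lower():
            return False
    return True


def select_scans(
    headers: Iterable[Mapping[str, str]], criteria: Mapping[str, str]
) -> List[Mapping[str, str]]:
    """Keep only the headers that satisfy all criteria."""
    return [h for h in headers if matches(h, criteria)]


# Hypothetical headers for four scans
headers = [
    {"Modality": "MR", "SeriesDescription": "T1 MPRAGE", "BodyPartExamined": "BRAIN"},
    {"Modality": "MR", "SeriesDescription": "T2 FLAIR", "BodyPartExamined": "BRAIN"},
    {"Modality": "CT", "SeriesDescription": "T1 HEAD", "BodyPartExamined": "HEAD"},
    {"Modality": "MR", "BodyPartExamined": "BRAIN"},  # no SeriesDescription
]
selected = select_scans(headers, {"Modality": "MR", "SeriesDescription": "t1"})
print([h["SeriesDescription"] for h in selected])  # ['T1 MPRAGE']
```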


In [None]:
import json

try:
    import pydicom  # noqa: F401
except ImportError:
    pydicom = None
    print("pydicom is not installed; scan.read_dicom() will not be available.")

def collect_scan_metadata(
    session: xnat.XNATSession,
    project_id: str,
    max_scans: int = 100,
):
    """Collect basic metadata + selected DICOM header fields for scans in a project.

    This uses scan.read_dicom(read_pixel_data=False) where available, which pulls a header
    from a single file in the series (sufficient for many series-level filters).
    """
    if project_id not in session.projects:
        raise KeyError(f"Project '{project_id}' not found.")

    project = session.projects[project_id]
    rows = []
    count = 0

    for subj in project.subjects.values():
        for exp in subj.experiments.values():
            # Non-imaging experiments have no scans; skip them
            for scan in getattr(exp, "scans", {}).values():
                row = {
                    "project_id": project_id,
                    "subject_label": subj.label,
                    "experiment_label": exp.label,
                    "scan_id": scan.id,
                    "scan_type": getattr(scan, "type", None),
                    "series_description": getattr(scan, "series_description", None),
                    "modality": getattr(scan, "modality", None),
                }

                # Augment with selected DICOM header fields
                header_info = {}
                try:
                    # read_dicom() is provided by xnatpy and requires pydicom
                    dicom_header = scan.read_dicom(read_pixel_data=False)  # type: ignore[attr-defined]
                    # Example tags (existence may vary by modality/site):
                    for tag, name in [
                        ((0x0008, 0x103E), "SeriesDescription"),
                        ((0x0018, 0x0020), "ScanningSequence"),
                        ((0x0018, 0x0050), "SliceThickness"),
                        ((0x0018, 0x0080), "RepetitionTime"),
                        ((0x0018, 0x0081), "EchoTime"),
                        ((0x0018, 0x1030), "ProtocolName"),
                        ((0x0018, 0x5100), "PatientPosition"),
                        ((0x0018, 0x0015), "BodyPartExamined"),
                    ]:
                        element = dicom_header.get(tag)
                        if element is not None:
                            header_info[name] = str(element.value)
                except Exception as e:
                    header_info["dicom_error"] = str(e)

                row["dicom_header"] = json.dumps(header_info)
                rows.append(row)

                count += 1
                if count >= max_scans:
                    return pd.DataFrame(rows)

    return pd.DataFrame(rows)

# Example usage (uncomment when connected):
# with get_xnat_session() as session:
#     df_meta = collect_scan_metadata(session, XNAT_PROJECT_ID, max_scans=200)
# df_meta.head()


In [None]:
# 3.1 Example: filter scans based on DICOM metadata

import pandas as pd
import ast

def filter_scans_by_criteria(df_meta: pd.DataFrame):
    """Example filter: T1 MR brain scans (SeriesDescription contains 'T1', BodyPartExamined contains 'BRAIN', modality is 'MR')."""
    if df_meta.empty:
        return df_meta

    # Expand dicom_header column into a dict per row
    def parse_header(s):
        if isinstance(s, dict):
            return s
        if isinstance(s, str) and s:
            try:
                return json.loads(s)
            except Exception:
                try:
                    return ast.literal_eval(s)
                except Exception:
                    return {}
        return {}

    header_dicts = df_meta["dicom_header"].apply(parse_header)

    df_meta = df_meta.copy()
    df_meta["SeriesDescription_hdr"] = header_dicts.apply(lambda h: h.get("SeriesDescription"))
    df_meta["BodyPartExamined_hdr"] = header_dicts.apply(lambda h: h.get("BodyPartExamined"))
    df_meta["ProtocolName_hdr"] = header_dicts.apply(lambda h: h.get("ProtocolName"))

    # Example filter: modify to match your site/protocol
    mask = (
        df_meta["SeriesDescription_hdr"].str.contains("T1", case=False, na=False)
        & df_meta["BodyPartExamined_hdr"].str.contains("BRAIN", case=False, na=False)
        & df_meta["modality"].str.upper().eq("MR")
    )

    filtered = df_meta[mask].reset_index(drop=True)
    return filtered

# Example usage (uncomment when df_meta is populated):
# filtered_scans = filter_scans_by_criteria(df_meta)
# display(filtered_scans.head())
# print("Number of filtered scans:", len(filtered_scans))
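`collect_scan_metadata()` stores every header value with `str(...)`. Numeric criteria such as echo time or slice thickness therefore need a conversion before they can be compared. A minimal sketch follows. The helper names and the 0 to 10 ms range are illustrative only, not a protocol recommendation. Multi-valued elements (stored as text like `"['GR', 'IR']"`) do not parse as numbers and are treated as missing. In pandas, `pd.to_numeric(series, errors="coerce")` performs the same conversion on a whole column and turns unparseable values into `NaN`.

```python
from typing import Optional


def to_float(value: Optional[str]) -> Optional[float]:
    """Convert a header value stored as text to float; None if missing or non-numeric."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def in_range(value: Optional[str], low: float, high: float) -> bool:
    """True if value parses as a number within [low, high]."""
    number = to_float(value)
    return number is not None and low <= number <= high


# Hypothetical parsed header (values are strings, as collect_scan_metadata stores them)
header = {"EchoTime": "2.98", "RepetitionTime": "2300", "ScanningSequence": "['GR', 'IR']"}
print(in_range(header.get("EchoTime"), 0.0, 10.0))  # True
print(to_float(header.get("ScanningSequence")))  # None
```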


## 4. Download filtered data locally

Once you have a filtered subset of scans (by DICOM metadata or any other criteria),
you typically want to download them for local processing.

xnatpy provides helper methods:

- `scan.download(path_to_zip)` – download a single scan as a ZIP of all DICOMs
- `subject.download_dir(target_dir)` – download entire subject tree
- `experiment.download_dir(target_dir)` – download entire session

Here we show how to:

- Loop over a filtered list of scan rows
- Resolve the corresponding scan objects
- Download each scan into a per-scan directory
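If a download loop is interrupted, re-running it fetches every scan again. A simple remedy is to skip scans whose target directory already holds files. The sketch below uses the same `<base>/<subject>/<experiment>/<scan id>` layout as `download_filtered_scans()`. The helper names are hypothetical, and the "non-empty directory means done" rule is an assumption. It cannot tell a complete download from a partial one, so delete the folder of any scan whose transfer failed midway.

```python
from pathlib import Path
from typing import Union


def scan_target_dir(
    base: Path, subject_label: str, experiment_label: str, scan_id: Union[str, int]
) -> Path:
    """Per-scan folder: <base>/<subject>/<experiment>/<scan id>."""
    return Path(base) / subject_label / experiment_label / str(scan_id)


def already_downloaded(scan_dir: Path) -> bool:
    """Heuristic: the folder exists and contains at least one file."""
    scan_dir = Path(scan_dir)
    return scan_dir.is_dir() and any(p.is_file() for p in scan_dir.rglob("*"))


# Inside the download loop (hypothetical integration):
# scan_dir = scan_target_dir(base_download_dir, subj_label, exp_label, scan_id)
# if already_downloaded(scan_dir):
#     print(f"Skipping scan {scan_id}: {scan_dir} already has files")
#     continue
```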


In [None]:
def download_filtered_scans(
    session: xnat.XNATSession,
    filtered_scans_df: pd.DataFrame,
    base_download_dir: Path,
    project_id: str,
    as_zip: bool = False,
):
    """Download filtered scans to a local directory.

    Parameters
    ----------
    session : xnat.XNATSession
        Active XNAT session.
    filtered_scans_df : pd.DataFrame
        DataFrame produced by the filtering step. Expected columns:
        'subject_label', 'experiment_label', 'scan_id'.
    base_download_dir : Path
        Local directory where data will be saved.
    project_id : str
        XNAT project ID.
    as_zip : bool
        If True, use scan.download() (zip file). If False, use scan.download_dir().
    """
    base_download_dir.mkdir(parents=True, exist_ok=True)
    project = session.projects[project_id]

    for _, row in filtered_scans_df.iterrows():
        subj_label = row["subject_label"]
        exp_label = row["experiment_label"]
        scan_id = row["scan_id"]

        subject = project.subjects[subj_label]
        experiment = subject.experiments[exp_label]
        scan = experiment.scans[scan_id]

        scan_dir = base_download_dir / subj_label / exp_label / str(scan_id)
        scan_dir.mkdir(parents=True, exist_ok=True)

        if as_zip:
            zip_path = scan_dir / f"{scan_id}.zip"
            print(f"Downloading scan {scan_id} as ZIP → {zip_path}")
            scan.download(str(zip_path))
        else:
            print(f"Downloading scan {scan_id} into directory → {scan_dir}")
            scan.download_dir(str(scan_dir))

# Example usage (uncomment when connected and filtered_scans is defined):
# download_root = Path("downloads") / XNAT_PROJECT_ID
# with get_xnat_session() as session:
#     download_filtered_scans(session, filtered_scans, download_root, XNAT_PROJECT_ID, as_zip=False)


---

### Summary

This notebook provides a practical workflow for:

- Uploading DICOM data into XNAT using the import service
- Exploring the project/subject/experiment/scan hierarchy
- Building a metadata table that includes selected DICOM header fields
- Filtering scans based on DICOM metadata criteria
- Downloading only the filtered subset for local processing

You can now adapt the filtering logic to match specific clinical protocols
(e.g. BI-RADS screening MR, specific echo times, body parts, or sequences)
and wire this into your downstream ML training pipeline.
