두개골이 제거된 뇌 MRI를 MNI152 표준 템플릿에 정합하여 뇌의 위치를 정규화(공간 정규화)하고, 정규화된 뇌에 해마 mask를 적용하여 해마 영역을 추출

---




/drive/MyDrive/  
-hd_bet/ : 두개골 제거가 진행된 데이터셋  
-template/ : 공간 정렬 및 마스킹에 필요한 템플릿 (https://git.fmrib.ox.ac.uk/fsl/data_standard)  
-normalize_mri/ : 최종 공간 정렬 및 마스킹이 진행된 데이터




In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
pip install antspyx nibabel numpy scipy

Collecting antspyx
  Downloading antspyx-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.1 kB)
Collecting scipy
  Downloading scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
Downloading antspyx-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.4/22.4 MB[0m [31m84.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.3/37.3 MB[0m [31m70.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: scipy, antspyx
  Attempting uninstall: scipy
    Found existing installation: scipy 1.16.3
    Uninstalling scipy-1.16.3:
      Successfully uninstalled scipy-1.16.3


In [3]:
import os
import re
import json
import logging
from pathlib import Path
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context

import numpy as np
import nibabel as nib
import scipy.ndimage as ndi

# =============================================================================
# Paths / parameters
# =============================================================================
ROOT = Path("/content/drive/MyDrive")
HD_BET_DIR   = ROOT / "hd_bet"
TEMPLATE_DIR = ROOT / "template"
OUT_ROOT     = ROOT / "registered_brains" # single directory that receives every output (images, log, summary)

TEMPLATE_IMG    = TEMPLATE_DIR / "MNI152_T1_1mm_brain.nii.gz"
# HIPPO_MASK_MNI  = TEMPLATE_DIR / "MNI152_T1_1mm_Hipp_mask_dil8.nii.gz" # removed: no longer used

REG_MODE = "down2mm_syn"   # 'affine_only' | 'quick_syn' | 'down2mm_syn'
# NOTE(review): THRESH_MASK, MARGIN_VOX and CLOSE_ITER are not referenced by the
# code visible in this cell — presumably leftovers from the mask/crop pipeline.
THRESH_MASK = 0.5
MARGIN_VOX  = 10
ROI_TARGET  = (64, 64, 64) # Used for final brain resampling dimensions
CLOSE_ITER  = 1

# Ask the native numeric libraries for one thread each; parallelism comes from
# the N_JOBS worker processes instead.
# NOTE(review): numpy/scipy are imported before these assignments, so thread
# pools that are sized at import time may not pick the values up — confirm.
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

N_JOBS = 10      # worker processes; recommended for Colab High-RAM
CHUNK  = 200     # subjects handed to each freshly created process pool

LOG_PATH  = OUT_ROOT / "batch_log.txt"

# =============================================================================
# 로깅/유틸
# =============================================================================
def setup_logger():
    """Send root-logger output to LOG_PATH, replacing any existing handlers.

    Creates OUT_ROOT if needed, detaches every handler currently attached to
    the root logger and installs a single INFO-level FileHandler whose records
    carry the timestamp, level and process name.

    The function is called once per subject inside the workers, so detached
    handlers are closed explicitly; otherwise each call would leave the
    previous log file handle open.
    """
    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
        # removeHandler() only detaches; close() releases the underlying file.
        handler.close()
    root.setLevel(logging.INFO)
    # The log messages contain Korean text, so pin the encoding instead of
    # relying on the locale default.
    fh = logging.FileHandler(LOG_PATH, encoding="utf-8")
    fh.setLevel(logging.INFO)
    fh.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] [%(processName)s] %(message)s"))
    root.addHandler(fh)

def ensure_dir(p: Path):
    """Create directory *p*, including missing parents; a no-op if it exists."""
    os.makedirs(p, exist_ok=True)

def save_np_as_nii_like(vol_np: np.ndarray, like_path: Path, out_path: Path):
    """Save *vol_np* as a float32 NIfTI that reuses a reference image's geometry.

    Args:
        vol_np: Volume to store; it is cast to float32.
        like_path: Existing NIfTI file whose affine and header are reused.
        out_path: Destination file path.
    """
    like = nib.load(str(like_path))
    img = nib.Nifti1Image(vol_np.astype(np.float32), like.affine, like.header)
    # When a header is supplied, nibabel keeps that header's on-disk dtype, so a
    # reference stored as e.g. int16 would make the float data be recast on
    # save. Declare float32 explicitly so the file holds what was passed in.
    img.set_data_dtype(np.float32)
    nib.save(img, str(out_path))

def crop_bbox_with_margin(vol: np.ndarray, mask: np.ndarray, margin=10):
    """Crop *vol* and *mask* to the mask's bounding box grown by *margin* voxels.

    The box spans every voxel where ``mask > 0`` along the first three axes,
    is widened by *margin* on each side and clipped to the extent of *vol*.

    Returns:
        ``(vol_roi, mask_roi)`` cut with the same window.

    Raises:
        ValueError: if the mask has no voxel greater than zero.
    """
    hits = np.argwhere(mask > 0)
    if hits.size == 0:
        raise ValueError("마스크가 비어있습니다(>0 voxels 없음). 정합/임계값 확인 필요.")

    z_lo, y_lo, x_lo = hits.min(0)
    z_hi, y_hi, x_hi = hits.max(0)

    window = []
    for axis, (lo, hi) in enumerate(((z_lo, z_hi), (y_lo, y_hi), (x_lo, x_hi))):
        first = max(lo - margin, 0)
        last = min(hi + margin, vol.shape[axis] - 1)  # inclusive upper index
        window.append(slice(first, last + 1))
    window = tuple(window)

    return vol[window], mask[window]

def resize3d_linear(vol: np.ndarray, out_shape):
    """Resample *vol* towards *out_shape* with first-order (linear) interpolation.

    Each axis is scaled by ``out_shape[axis] / vol.shape[axis]``.
    """
    factors = []
    for target, current in zip(out_shape, vol.shape):
        factors.append(target / current)
    return ndi.zoom(vol, factors, order=1)

def zscore(vol: np.ndarray):
    """Standardize *vol* using its global mean and standard deviation.

    A small epsilon (1e-8) is added to the standard deviation so that a
    constant volume does not divide by zero.
    """
    center = float(vol.mean())
    spread = float(vol.std()) + 1e-8
    shifted = vol - center
    return shifted / spread

# "mask" counts only as a separate token: at the start of the stem or after
# "_" / "-", and followed by ".", "_" or the end of the stem.
_MASK_TOKEN_RE = re.compile(r"(?:^|[_\-])mask(?:\.|_|$)", re.IGNORECASE)


def is_mask_file(path: Path):
    """Return True when the file stem of *path* contains a standalone "mask" token."""
    return _MASK_TOKEN_RE.search(path.stem) is not None

def subject_id_from_name(path: Path):
    """Derive a subject identifier from a scan's file name.

    The identifier is the leading run of digits, optionally preceded by one
    letter (e.g. ``"002"`` or ``"S1234"``). When the name has no such prefix,
    the file name without its extension is returned.

    ``Path.stem`` removes only the last suffix, so ``x.nii.gz`` would fall back
    to ``"x.nii"``. The NIfTI extensions are therefore stripped as a whole so
    the fallback id does not carry a stray ``.nii``.
    """
    name = path.name
    lowered = name.lower()
    for ext in (".nii.gz", ".nii"):
        if lowered.endswith(ext) and len(name) > len(ext):
            name = name[: -len(ext)]
            break
    else:
        # Not a NIfTI name: keep the generic "drop the last suffix" behaviour.
        name = path.stem
    m = re.match(r"^([A-Za-z]?\d+)", name)
    return m.group(1) if m else name

# =============================================================================
# Parent process: thread limits + early import of ants
# =============================================================================
# NOTE(review): these five variables are already assigned "1" unconditionally
# earlier in this file, so the setdefault() calls below leave them unchanged and
# the value "2" is never applied. Drop one of the two groups once the intended
# thread count is decided.
os.environ.setdefault("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS", "2")
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")

# Imported after the thread-limit variables are set and before any worker
# process is created.
import ants  # antspyx

# =============================================================================
# 빠른 정합
# =============================================================================
def _register_fast(sub_img, tpl_img):
    """Register *sub_img* (moving) to *tpl_img* (fixed) according to REG_MODE.

    Modes:
        ``affine_only``  affine registration only.
        ``quick_syn``    SyN with the reduced iteration schedule below.
        ``down2mm_syn``  same SyN settings, run on copies of both images
                         resampled to 2 mm spacing.
        anything else    SyN with the library defaults.

    Returns:
        Whatever ``ants.registration`` returns for the selected mode.
    """
    if REG_MODE == "affine_only":
        return ants.registration(
            fixed=tpl_img,
            moving=sub_img,
            type_of_transform="Affine",
            reg_iterations=(40, 20, 0),
            aff_metric="mattes",
            verbose=False,
            random_seed=0,
        )

    if REG_MODE not in ("quick_syn", "down2mm_syn"):
        # Unknown mode: fall back to SyN with default settings.
        return ants.registration(fixed=tpl_img, moving=sub_img, type_of_transform="SyN")

    # NOTE(review): confirm that ants.registration honours shrink_factors /
    # smoothing_sigmas under these names in the installed version.
    syn_options = {
        "type_of_transform": "SyN",
        "reg_iterations": (40, 20, 0),
        "shrink_factors": (8, 4, 2),
        "smoothing_sigmas": (3, 2, 1),
        "aff_metric": "mattes",
        "verbose": False,
        "random_seed": 0,
    }

    fixed_img, moving_img = tpl_img, sub_img
    if REG_MODE == "down2mm_syn":
        # Per the ANTsPy documentation interp_type=0 selects linear
        # interpolation (1 would be nearest neighbour).
        spacing_mm = (2.0, 2.0, 2.0)
        fixed_img = ants.resample_image(tpl_img, spacing_mm, use_voxels=False, interp_type=0)
        moving_img = ants.resample_image(sub_img, spacing_mm, use_voxels=False, interp_type=0)

    return ants.registration(fixed=fixed_img, moving=moving_img, **syn_options)

# =============================================================================
# 워커
# =============================================================================
def _remove_transform_files(reg):
    """Delete the transform files a registration result points at (best effort)."""
    paths = set()
    for key in ("fwdtransforms", "invtransforms"):
        paths.update(t for t in (reg.get(key) or []) if isinstance(t, str))
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            # Cleanup only; a file that is already gone must not fail the subject.
            pass


def process_one_subject(sub_path_str: str):
    """Register one scan to the template and save it resampled to ROI_TARGET.

    Worker entry point: reads the scan and the template, registers the scan
    with ``_register_fast``, resamples the warped image to ``ROI_TARGET``
    voxels and writes it to ``OUT_ROOT / "<subject_id>.nii.gz"``.

    Args:
        sub_path_str: Path of the input NIfTI file, as a string.

    Returns:
        ``(sub_path_str, True, None)`` on success, or
        ``(sub_path_str, False, "<ExceptionType>: <message>")`` on failure.
        Ordinary exceptions are logged and reported through the tuple;
        KeyboardInterrupt / SystemExit propagate so the batch can be stopped.
    """
    setup_logger()
    p = Path(sub_path_str)
    try:
        sub_id  = subject_id_from_name(p)
        ensure_dir(OUT_ROOT)

        logging.info(f"[{sub_id}] 시작: {p.name}")
        sub_img = ants.image_read(str(p))
        tpl_img = ants.image_read(str(TEMPLATE_IMG))

        # 1) Registration
        reg = _register_fast(sub_img, tpl_img)
        try:
            warped_img = reg["warpedmovout"]

            # 2) Resample the warped brain to ROI_TARGET voxels.
            #    ANTsPy documents interp_type 0 as linear and 1 as nearest
            #    neighbour; the intensity image needs linear interpolation.
            resampled_brain = ants.resample_image(
                warped_img,
                ROI_TARGET,
                use_voxels=True,
                interp_type=0,
            )
            # NOTE(review): the file name depends on the subject id only, so two
            # scans that map to the same id overwrite each other — confirm the
            # input set has one scan per subject.
            resampled_brain_path = OUT_ROOT / f"{sub_id}.nii.gz"
            ants.image_write(resampled_brain, str(resampled_brain_path))
        finally:
            # The registration leaves its transform files on disk; remove them
            # so a large batch does not fill the temporary directory.
            _remove_transform_files(reg)

        logging.info(f"[{sub_id}] 완료")
        return (sub_path_str, True, None)
    except Exception as e:
        logging.exception(f"[실패] {p.name}: {e}")
        return (sub_path_str, False, f"{type(e).__name__}: {e}")

# =============================================================================
# 배치 실행 (콘솔 진행률 + 시각)
# =============================================================================
def run_in_chunks(paths_str):
    """Process every path with a pool of worker processes, CHUNK paths per pool.

    A fresh ProcessPoolExecutor with N_JOBS workers is created for each chunk.
    One progress line (``done/total  [HH:MM:SS]``) is printed per finished
    subject, in completion order.

    A worker that dies (for example when it runs out of memory) or a task that
    raises is counted as a failure and logged; it does not abort the batch.
    Remaining tasks of a broken pool fail as well, and the next chunk starts
    with a new pool.

    Args:
        paths_str: Input file paths as strings.

    Returns:
        ``(ok_total, fail_total)``.
    """
    from concurrent.futures import as_completed

    total = len(paths_str)
    done = 0
    ok_total, fail_total = 0, 0

    # Prefer "fork" where the platform offers it; otherwise use "spawn".
    try:
        ctx = get_context("fork")
    except ValueError:
        ctx = get_context("spawn")

    for start in range(0, total, CHUNK):
        batch = paths_str[start:start + CHUNK]
        with ProcessPoolExecutor(max_workers=N_JOBS, mp_context=ctx) as ex:
            pending = {ex.submit(process_one_subject, path): path for path in batch}
            for fut in as_completed(pending):
                try:
                    _, ok, _ = fut.result()
                except Exception as exc:
                    # Includes BrokenProcessPool raised for a crashed worker.
                    ok = False
                    logging.error(
                        f"[실패] {Path(pending[fut]).name}: {type(exc).__name__}: {exc}"
                    )
                done += 1
                if ok:
                    ok_total += 1
                else:
                    fail_total += 1
                now = datetime.now().strftime("%H:%M:%S")
                print(f"{done}/{total}  [{now}]", flush=True)
    return ok_total, fail_total

# =============================================================================
# 메인
# =============================================================================
def main():
    """Run the batch: collect the inputs, register them all, write a summary.

    Stops early with a console message when the template image is missing or
    when no non-mask NIfTI file is found under HD_BET_DIR. Otherwise processes
    every file through ``run_in_chunks`` and writes the counts to
    ``OUT_ROOT / "batch_summary.json"``.
    """
    setup_logger()

    if not TEMPLATE_IMG.exists():
        print(f"템플릿 없음: {TEMPLATE_IMG}")
        return

    # Every regular file matching *.nii* below HD_BET_DIR, mask files excluded.
    scans = [
        candidate
        for candidate in HD_BET_DIR.glob("**/*.nii*")
        if candidate.is_file() and not is_mask_file(candidate)
    ]
    scans.sort()
    if not scans:
        print(f"입력 NIfTI를 찾지 못함: {HD_BET_DIR}")
        return

    # Plain strings are handed to the worker processes.
    inputs = list(map(str, scans))
    print(f"총 {len(inputs)}개, 워커 {N_JOBS}, 모드 {REG_MODE}")

    success, fail = run_in_chunks(inputs)
    print(f"완료: 성공 {success}, 실패 {fail}")

    summary = {"total": len(inputs), "success": success, "fail": fail}
    summary_path = OUT_ROOT / "batch_summary.json"
    summary_path.write_text(
        json.dumps(summary, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


if __name__ == "__main__":
    main()

총 1167개, 워커 10, 모드 down2mm_syn
1/1167  [11:36:57]
2/1167  [11:36:57]
3/1167  [11:36:57]
4/1167  [11:36:57]
5/1167  [11:36:57]
6/1167  [11:36:57]
7/1167  [11:36:57]
8/1167  [11:36:57]
9/1167  [11:36:57]
10/1167  [11:36:57]
11/1167  [11:37:22]
12/1167  [11:37:22]
13/1167  [11:37:23]
14/1167  [11:37:24]
15/1167  [11:37:24]
16/1167  [11:37:24]
17/1167  [11:37:24]
18/1167  [11:37:26]
19/1167  [11:37:26]
20/1167  [11:37:26]
21/1167  [11:37:26]
22/1167  [11:37:43]
23/1167  [11:37:45]
24/1167  [11:37:46]
25/1167  [11:37:46]
26/1167  [11:37:48]
27/1167  [11:37:49]
28/1167  [11:37:49]
29/1167  [11:37:49]
30/1167  [11:37:49]
31/1167  [11:37:49]
32/1167  [11:38:06]
33/1167  [11:38:06]
34/1167  [11:38:08]
35/1167  [11:38:08]
36/1167  [11:38:08]
37/1167  [11:38:08]
38/1167  [11:38:12]
39/1167  [11:38:12]
40/1167  [11:38:12]
41/1167  [11:38:14]
42/1167  [11:38:32]
43/1167  [11:38:32]
44/1167  [11:38:32]
45/1167  [11:38:34]
46/1167  [11:38:34]
47/1167  [11:38:34]
48/1167  [11:38:34]
49/1167  [11:38:37