In [2]:
# Sanity check: pull a single batch with the loader's default settings and report shapes.
# NOTE(review): BB_DIR / T1CE_DIR are not defined in any visible cell; they must come from
# an earlier (unseen) cell. TODO define them in a config cell so Restart & Run All works.
loader = get_bb2t1ce_dataloader(BB_DIR, T1CE_DIR, batch_size=1)
batch = next(iter(loader))
bb, t1ce = (batch[key] for key in ("bb", "t1ce"))

print(f"[CHECK] BB shape: {bb.shape}")
print(f"[CHECK] T1CE shape: {t1ce.shape}")


[INFO] Found 100 BB–T1CE pairs.
[CHECK] BB shape: torch.Size([178, 1, 240, 240])
[CHECK] T1CE shape: torch.Size([178, 1, 240, 240])


In [4]:
# ✅ Dataset test: request full 3D volumes (as_2d=False) and inspect one batch.
loader = get_bb2t1ce_dataloader(
    BB_DIR,
    T1CE_DIR,
    batch_size=1,
    as_2d=False,
)
batch = next(iter(loader))
bb = batch["bb"]
t1ce = batch["t1ce"]
print(f"✅ Loaded batch: BB {bb.shape}, T1CE {t1ce.shape}")

# ✅ 3D → 2D flatten
def collapse_z_to_batch(tensor):
    """Fold the depth axis of a 5D volume batch into the batch axis.

    A 5D input shaped (B, C, Z, H, W) is returned as (B*Z, C, H, W), where the
    Z slices of sample ``b`` occupy rows ``b*Z`` .. ``b*Z + Z - 1``.
    Tensors of any other rank are returned unchanged (the very same object).
    """
    if tensor.ndim != 5:
        return tensor
    # (B, C, Z, H, W) -> (B, Z, C, H, W) -> (B*Z, C, H, W)
    return tensor.transpose(1, 2).flatten(0, 1)

# Rebind the same names so the transform cell below receives the 2D slice stacks.
bb, t1ce = map(collapse_z_to_batch, (bb, t1ce))
print(f"✅ Flattened: BB {bb.shape}, T1CE {t1ce.shape}")


[INFO] Found 100 BB–T1CE pairs.


✅ Loaded batch: BB torch.Size([1, 1, 178, 240, 240]), T1CE torch.Size([1, 1, 178, 240, 240])
✅ Flattened: BB torch.Size([178, 1, 240, 240]), T1CE torch.Size([178, 1, 240, 240])


In [5]:
from datasets.transforms import bb2t1ce_transform

# ✅ Apply the paired BB/T1CE transform (intended for 2D inputs only).
# NOTE(review): the recorded run of this cell raised
#   "IndexError: too many indices for tensor of dimension 4".
# bb / t1ce here are the 4D (B*Z, C, H, W) stacks produced by collapse_z_to_batch, so the
# transform presumably indexes more axes than it receives (e.g. it expects a 5D volume,
# or is meant to be called per slice). TODO confirm its expected input rank.
(bb_t, t1ce_t) = bb2t1ce_transform(
    bb,
    t1ce,
)
print(f"✅ Transform applied: BB {bb_t.shape}, T1CE {t1ce_t.shape}")


IndexError: too many indices for tensor of dimension 4

In [6]:
# ================================================================
# test.py
# ASAN_02_BB_T1CE - end-to-end integration test of all modules (MONAI-based)
# ================================================================

import os
import sys
import torch

# ------------------------------------------------------------
# 1️⃣ Project root path setup
# ------------------------------------------------------------
# The project root can be overridden via the ASAN_ROOT_DIR environment variable;
# when it is unset, the original hard-coded workspace path is used.
ROOT_DIR = os.environ.get(
    "ASAN_ROOT_DIR", "/workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE"
)
# Relative paths used later (e.g. the data / checkpoint dirs in CONFIG) resolve against this.
os.chdir(ROOT_DIR)
# Put the project root FIRST on sys.path. Appending would give the local `datasets`
# package the lowest precedence, so an installed package with the same top-level name
# (e.g. Hugging Face `datasets`) could shadow it. Removing existing entries first also
# keeps re-runs of this cell from growing sys.path with duplicates.
while ROOT_DIR in sys.path:
    sys.path.remove(ROOT_DIR)
sys.path.insert(0, ROOT_DIR)
print(f"현재 경로: {os.getcwd()}")

# ------------------------------------------------------------
# 2️⃣ 모듈 임포트
# ------------------------------------------------------------
from datasets.bb2t1ce_dataset import get_bb2t1ce_dataloader
from models.diffusion_unet_custom import DiffusionUNetCustom as ConditionalUNet
from runners.trainer import Trainer

# ------------------------------------------------------------
# 3️⃣ Basic settings
# ------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {DEVICE}")

# Seed torch's global RNG so torch-based randomness in the following cells
# (e.g. the model's weight initialisation) is repeatable on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

CONFIG = {
    "data": {
        # Relative to the working directory set by os.chdir(ROOT_DIR) above.
        "bb_dir": "./data/0_RAW_DATA/meta-bb",
        "t1ce_dir": "./data/0_RAW_DATA/meta-t1ce",
        "batch_size": 1,
        "num_workers": 2,
    },
    "diffusion": {
        "timesteps": 10,               # small value for a quick smoke test
        "beta_start": 1e-4,
        "beta_end": 0.02,
    },
    "optimizer": {
        "lr": 1e-4,
        "weight_decay": 0.0,
    },
    "training": {
        "n_epochs": 2,                 # quick smoke test
        "loss_type": "l2",
        "ckpt_dir": "./checkpoints_test",
    },
}

# ------------------------------------------------------------
# 4️⃣ Dataset / DataLoader
# ------------------------------------------------------------
# NOTE(review): the recorded run of this cell failed in a DataLoader worker inside MONAI
# `Resized`: "len(spatial_size) must be greater or equal to img spatial dimensions,
# got spatial_size=2 img=3". With as_2d=True the dataset pipeline appears to resize a
# still-3D volume using a 2-element spatial_size. The fix belongs in
# datasets/bb2t1ce_dataset.py (e.g. slice before resizing); TODO confirm.
# The as_2d=False path loaded successfully in the earlier dataset-test cell.
train_loader = get_bb2t1ce_dataloader(
    bb_dir=CONFIG["data"]["bb_dir"],
    t1ce_dir=CONFIG["data"]["t1ce_dir"],
    batch_size=CONFIG["data"]["batch_size"],
    num_workers=CONFIG["data"]["num_workers"],
    as_2d=True,  # convert 3D volumes into 2D slices (for diffusion training)
)

sample_batch = next(iter(train_loader))
print(f"✅ Sample batch loaded: BB {sample_batch['bb'].shape}, T1CE {sample_batch['t1ce'].shape}")
# BB is the conditioning input for T1CE, and the earlier cells show identical BB/T1CE
# shapes; fail fast here on a mismatch instead of deep inside training.
assert sample_batch["bb"].shape == sample_batch["t1ce"].shape, (
    f"BB/T1CE shape mismatch: {tuple(sample_batch['bb'].shape)} "
    f"vs {tuple(sample_batch['t1ce'].shape)}"
)

# ------------------------------------------------------------
# 5️⃣ Model definition
# ------------------------------------------------------------
# The conditional UNet takes the noisy sample x_t, the timestep t and the
# conditioning BB image as inputs.
model = ConditionalUNet(
    in_channels=1,             # channels of the image being denoised
    cond_channels=1,           # channels of the conditioning (BB) image
    base_channels=64,
    channel_mults=[1, 2, 4],
    num_res_blocks=2,
    dropout=0.1,
)
model = model.to(DEVICE)

# ------------------------------------------------------------
# 6️⃣ Trainer initialization
# ------------------------------------------------------------
# Hand the model, the training DataLoader, the nested CONFIG dict and the device to the trainer.
trainer = Trainer(model=model, dataloader=train_loader, config=CONFIG, device=DEVICE)

# ------------------------------------------------------------
# 7️⃣ Run training
# ------------------------------------------------------------
def _run_training(trainer_to_run):
    """Call ``trainer_to_run.train()`` bracketed by start / finish status messages."""
    print("🚀 Starting BB→T1CE diffusion model training...")
    trainer_to_run.train()
    print("✅ Training completed successfully.")


# Inside a Jupyter kernel __name__ is "__main__", so this always runs in the notebook;
# the guard only matters if this code is saved as test.py and imported.
if __name__ == "__main__":
    _run_training(trainer)


현재 경로: /workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE
✅ Using device: cuda




RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/transform.py", line 150, in apply_transform
    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/transform.py", line 98, in _apply_transform
    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/spatial/dictionary.py", line 868, in __call__
    d[key] = self.resizer(
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/spatial/array.py", line 858, in __call__
    raise ValueError(
ValueError: len(spatial_size) must be greater or equal to img spatial dimensions, got spatial_size=2 img=3.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/monai/data/dataset.py", line 109, in __getitem__
    return self._transform(index)
  File "/opt/conda/lib/python3.10/site-packages/monai/data/dataset.py", line 95, in _transform
    return self.transform(data_i)
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/compose.py", line 346, in __call__
    result = execute_compose(
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/compose.py", line 116, in execute_compose
    data = apply_transform(
  File "/opt/conda/lib/python3.10/site-packages/monai/transforms/transform.py", line 180, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.spatial.dictionary.Resized object at 0x7fa0e0e810f0>
