In [43]:
import argparse, pydicom
import torch
import numpy as np
import src.dicom_helper as dhelp

from pathlib import Path
from typing import Any, Dict, Iterable, Tuple, List
from omegaconf import DictConfig, OmegaConf

## DICOMs

In [62]:
def load_hyperparams(config_path: Path) -> Tuple[DictConfig, Dict]:
    """Load MCO-IPSA hyperparameters from YAML and resolve runtime types."""

    cfg = OmegaConf.load(config_path)
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)  # type: ignore[arg-type]
    hyperparams = cfg_dict["hyperparams"]
    hyperparams["anisotropy_table"] = Path(hyperparams["anisotropy_table"])
    hyperparams["radial_table"] = Path(hyperparams["radial_table"])
    return cfg, hyperparams

def resolve_single(path_iterable: Iterable[Path], description: str) -> Path:
    """Return the first path from ``path_iterable`` or raise if empty."""

    candidates = sorted(path_iterable)
    if not candidates:
        raise FileNotFoundError(f"No files found for {description}.")
    return candidates[0]

def load_case_paths(root: Path, case_id: int) -> Dict[str, Path]:
    """Collect CT, RTDOSE, RTPLAN, and RTSTRUCT paths for a case."""

    case_dir = root / f"{case_id}"
    if not case_dir.exists():
        raise FileNotFoundError(f"Case directory not found: {case_dir}")

    return {
        "ct": resolve_single(case_dir.glob("CT*/*"), "CT series"),
        "dose": resolve_single(case_dir.glob("RTDOSE*/*/*.dcm"), "RTDOSE file"),
        "plan": resolve_single(case_dir.glob("RTPLAN*/*/*.dcm"), "RTPLAN file"),
        "struct": resolve_single(case_dir.glob("RTSTRUCT*/*/*.dcm"), "RTSTRUCT file"),
    }

def extract_dwells(rtplan_path: str) -> List[Dict]:
    """Extract dwell positions and channel directions from RTPLAN.
    Returns list of dicts: {positions: [N,3] mm (world LPS), direction: [3] (unit, if inferable)}
    """
    ds = pydicom.dcmread(rtplan_path)
    dwells = []
    # IEC: Brachy Application Setup Sequence (300A,00A0) -> Channel Sequence (300A,0280)
    if hasattr(ds, "BrachyApplicationSetupSequence"):
        setups = ds.BrachyApplicationSetupSequence
    elif hasattr(ds, "ApplicationSetupSequence"):
        setups = ds.ApplicationSetupSequence
    
    for setup in setups:
        setup_number = int(getattr(setup, "ApplicationSetupNumber", 0))
        channel_seq = getattr(setup, "ChannelSequence", [])
        for ch in channel_seq:
            cps = getattr(ch, 'BrachyControlPointSequence', None) or getattr(ch, 'ControlPointSequence', None)
            if cps is None:
                continue
            pts = []
            for cp in cps:
                # Preferred: Control Point 3D Position (300A,02D4)
                if hasattr(cp, 'ControlPoint3DPosition'):
                    p = np.array(cp.ControlPoint3DPosition, dtype=float)
                    pts.append(p)
                # Some vendors store in (300A,012C) Source Applicator Position or similarâ€”add fallbacks as needed.
                elif hasattr(cp, 'SourceApplicatorPosition') and hasattr(cp, 'TableTopLateralPosition'):
                    # This branch is vendor-specific; leave as TODO
                    pass
            if len(pts) >= 2:
                pts = np.stack(pts, axis=0)  # [N,3]
                # Infer channel direction from first-to-last control point
                ch_dir = pts[-1] - pts[0]
                n = np.linalg.norm(ch_dir) + 1e-8
                ch_dir = ch_dir / n
                dwells.append(dict(positions=pts, direction=ch_dir))
            elif len(pts) == 1:
                dwells.append(dict(positions=np.stack(pts), direction=np.array([0.0,0.0,1.0])))
    if not dwells:
        raise ValueError('No dwell control points found; check vendor tags or add fallbacks.')
    return dwells

In [3]:
cfg, hyperparams = load_hyperparams(Path("config.yaml"))
cfg, hyperparams

({'hyperparams': {'anisotropy_table': './ESTRO/nucletron_mhdr-v2_F.xlsx', 'radial_table': './ESTRO/nucletron_mhdr-v2_gL.xlsx'}, 'run': {'data-root': '/mnt/d/PRV/GYN_Geneva_wo_Needles', 'case': 'Case6'}},
 {'anisotropy_table': PosixPath('ESTRO/nucletron_mhdr-v2_F.xlsx'),
  'radial_table': PosixPath('ESTRO/nucletron_mhdr-v2_gL.xlsx')})

In [4]:
data_root = Path(cfg['run']['data-root'])
case_id = Path(cfg['run']['case'])
print(f"Data root: {data_root}")
print(f"Case ID: {case_id}")
paths = load_case_paths(data_root, case_id)
paths

Data root: /mnt/d/PRV/GYN_Geneva_wo_Needles
Case ID: Case6


{'ct': PosixPath('/mnt/d/PRV/GYN_Geneva_wo_Needles/Case6/CT_0_20240112/series_3_1.3.6.1.4.1.2452.6.1088874395.1296847118.3536313259.1264601617'),
 'dose': PosixPath('/mnt/d/PRV/GYN_Geneva_wo_Needles/Case6/RTDOSE_0_20240112/series_1_1.3.6.1.4.1.2452.6.2542691449.1298947939.3638676387.3704138358/ima_unknown_uid_1.3.6.1.4.1.2452.6.1444471091.1123319702.4073412502.1826252067.dcm'),
 'plan': PosixPath('/mnt/d/PRV/GYN_Geneva_wo_Needles/Case6/RTPLAN_0_20240112/series_1_1.3.6.1.4.1.2452.6.1356188019.1134003628.1522809736.1180775174/ima_empty_uid_1.3.6.1.4.1.2452.6.3162210145.1093853173.1113596587.2422958571.dcm'),
 'struct': PosixPath('/mnt/d/PRV/GYN_Geneva_wo_Needles/Case6/RTSTRUCT_0_20240112/series_1_1.3.6.1.4.1.2452.6.3775943779.1152257382.1983088276.1409971453/ima_empty_uid_1.3.6.1.4.1.2452.6.3039571248.1233624635.2885846168.1640366842.dcm')}

In [5]:
ct_image, ct_array, ct_metadata = dhelp.load_ct_volume(paths['ct'])
ct_array.shape

(160, 512, 512)

In [6]:
dose_image, dose_array, dose_metadata = dhelp.load_rtdose_volume(paths['dose'])
dose_array.shape, dose_metadata['direction']

((241, 261, 251), (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0))

In [52]:
rt_channels = dhelp.load_rtplan_by_channel(paths['plan'])
dwells = dhelp.extract_dwell_positions(ct_image, rt_channels)
rt_channels, dwells

([ChannelInfo(setup_number=0, channel_number=1, channel_id='Channel_1', total_time_s=33.1780900468977, final_cumulative_weight=0.73333310879436, positions_cm=[array([  0.5379861, -13.3453959,  47.7999833]), array([  0.5379861, -13.3453959,  47.7999833]), array([  0.5523497, -13.8179753,  47.6373107]), array([  0.5523497, -13.8179753,  47.6373107]), array([  0.5643543, -14.2816131,  47.4528961]), array([  0.5643543, -14.2816131,  47.4528961]), array([  0.5593393, -14.6576535,  47.1279034]), array([  0.5593393, -14.6576535,  47.1279034])], cumulative_weights=array([0.        , 0.18333328, 0.18333328, 0.36666655, 0.36666655,
         0.54999983, 0.54999983, 0.73333311]), relative_positions=array([ 0.,  0.,  5.,  5., 10., 10., 15., 15.]), strengths_U=array([40688.78106796, 40688.78106796, 40688.78106796, 40688.78106796,
         40688.78106796, 40688.78106796, 40688.78106796, 40688.78106796])),
  ChannelInfo(setup_number=0, channel_number=2, channel_id='Channel_2', total_time_s=33.17809004

In [63]:
extract_dwells(paths['plan'])

AttributeError: 'FileDataset' object has no attribute 'BrachyApplicationSetupSequence'

## 3DGS

In [57]:
def make_view_grids(dose_arr: np.ndarray, dose_metadata: Dict, plane: str, step: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build world coordinates for a supervision plane from RTDOSE grid.
    plane in {'ax','co','sa'}. Returns (grid_xyz [H,W,3], H, W)
    step: sub-sampling stride for speed.
    """
    Z, Y, X = dose_arr.shape
    dx, dy, dz = dose_metadata['spacing']
    o = np.array(dose_metadata['origin'])
    rd = np.array(dose_metadata['direction'][0:3])
    cd = np.array(dose_metadata['direction'][3:6])
    sd = np.array(dose_metadata['direction'][6:9])
    
    # Build axes for each plane
    if plane == 'ax':
        # rows -> Y, cols -> X at a chosen slice (use mid-slice)
        zi = Z // 2
        rows = np.arange(0, Y, step)
        cols = np.arange(0, X, step)
        rr, cc = np.meshgrid(rows, cols, indexing='ij')
        # world = origin + cc*dx*cd + rr*dy*rd + offset_z*sd
        off = dose_metadata['offsets'][zi] if len(dose_metadata['offsets']) else zi*dz
        print(off)
        xyz = o + (cc*dx)[:, :, None]*cd + (rr*dy)[:, :, None]*rd + off*sd
    elif plane == 'co':
        yi = Y // 2
        zs = np.arange(0, Z, step)
        cols = np.arange(0, X, step)
        zz, cc = np.meshgrid(zs, cols, indexing='ij')
        off = (dose_metadata['offsets'][zz] if len(dose_metadata['offsets']) else (zz*dz))
        # Fix row index (yi)
        xyz = o + (cc*dx)[:, :, None]*cd + (yi*dy)*rd + (off[..., None])*sd
    elif plane == 'sa':
        xi = X // 2
        zs = np.arange(0, Z, step)
        rows = np.arange(0, Y, step)
        zz, rr = np.meshgrid(zs, rows, indexing='ij')
        off = (dose_metadata['offsets'][zz] if len(dose_metadata['offsets']) else (zz*dz))
        xyz = o + (xi*dx)*cd + (rr*dy)[:, :, None]*rd + (off[..., None])*sd
    else:
        raise ValueError('plane must be ax/co/sa')
    H, W = xyz.shape[:2]
    return xyz.reshape(-1, 3), H, W

def generate_gt_slices(dose_arr: np.ndarray):
    
    dose_ax = np.sum(ct_array, axis=0)
    dose_co = np.sum(ct_array, axis=1)
    dose_sa = np.sum(ct_array, axis=2)

    return dose_ax, dose_co, dose_sa


In [58]:
sup_stride = 1
xyz_ax, Hax, Wax = make_view_grids(dose_array, dose_metadata, 'ax', step=sup_stride)
xyz_co, Hco, Wco = make_view_grids(dose_array, dose_metadata, 'co', step=sup_stride)
xyz_sa, Hsa, Wsa = make_view_grids(dose_array, dose_metadata, 'sa', step=sup_stride)

120.0


In [59]:
xyz_ax.shape, Hax, Wax

((65511, 3), 261, 251)

In [60]:
dose_ax, dose_co, dose_sa = generate_gt_slices(dose_array)

In [61]:
num_gaussian_per_point = 1
sigma_mm = 2
jitter_mm = 1

mus, sig, wei, aniso = init_gaussians_from_dwells(
        dwells, per_cp=num_gaussian_per_point, sigma_mm=sigma_mm, jitter_mm=jitter_mm)

NameError: name 'ch' is not defined