In [None]:
import argparse, pydicom
import torch
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from typing import Any, Dict, Iterable, Tuple, List
from omegaconf import DictConfig, OmegaConf

import tg43.dicom_helper as dhelp
import tg43.contour_helper as chelp
import tg43.dose_calculation as dosecal
import tg43.utils as utils
import tg43.visualization as vis

import src.dataloader as dataloader

In [None]:
def load_hyperparams(config_path: Path) -> DictConfig:
    """Load MCO-IPSA hyperparameters from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The loaded ``DictConfig`` (a single object, not a tuple); callers index it
        like a mapping, e.g. ``cfg['run']['case']``.

    Raises:
        TypeError: If the YAML top level is a list rather than a mapping.
    """
    cfg = OmegaConf.load(config_path)
    if not isinstance(cfg, DictConfig):
        raise TypeError(f"Expected a mapping at the top level of {config_path}, got {type(cfg).__name__}.")
    # Resolve every interpolation once up front so a broken ``${...}`` reference
    # fails at load time; the plain-container copy itself is not needed.
    OmegaConf.to_container(cfg, resolve=True)
    return cfg

def resolve_single(path_iterable: Iterable[Path], description: str) -> Path:
    """Pick the lexicographically smallest path from ``path_iterable``.

    Args:
        path_iterable: Any iterable of paths (e.g. the result of ``Path.glob``).
        description: Human-readable name used in the error message.

    Raises:
        FileNotFoundError: When the iterable yields nothing.
    """
    ordered = list(path_iterable)
    ordered.sort()
    if ordered:
        return ordered[0]
    raise FileNotFoundError(f"No files found for {description}.")

def load_case_paths(root: Path, case_id: int) -> Dict[str, Path]:
    """Locate the CT, RTDOSE, RTPLAN and RTSTRUCT inputs of one case.

    The case folder is ``root / str(case_id)``. For each modality the first match
    (in sorted order) of its glob pattern is returned.

    Raises:
        FileNotFoundError: If the case folder is missing or a modality has no match.
    """
    case_dir = root / f"{case_id}"
    if not case_dir.exists():
        raise FileNotFoundError(f"Case directory not found: {case_dir}")

    # (result key, glob pattern relative to the case folder, label for error messages)
    lookup = (
        ("ct", "CT*/*", "CT series"),
        ("dose", "RTDOSE*/*/*.dcm", "RTDOSE file"),
        ("plan", "RTPLAN*/*/*.dcm", "RTPLAN file"),
        ("struct", "RTSTRUCT*/*/*.dcm", "RTSTRUCT file"),
    )
    return {key: resolve_single(case_dir.glob(pattern), label) for key, pattern, label in lookup}

In [None]:
# Load the run configuration (data root, case ID, ...) from config.yaml,
# resolved relative to the kernel's current working directory.
cfg = load_hyperparams(Path("config.yaml"))
cfg

In [None]:
data_root = Path(cfg['run']['data-root'])
# Keep the case ID as the raw config value (int or str): wrapping it in Path() raises
# TypeError when YAML parses it as an integer, and load_case_paths only needs its
# string form to build the case directory name.
case_id = cfg['run']['case']
print(f"Data root: {data_root}")
print(f"Case ID: {case_id}")
paths = load_case_paths(data_root, case_id)
paths

In [None]:
# Load the CT series for this case; the third value returned by load_ct_volume is unused here.
ct_image, ct_array, _ = dhelp.load_ct_volume(paths["ct"])
ct_array.shape

In [None]:
# Read the RTPLAN: per-channel dwell data plus auxiliary points (dose reference points etc.).
# While iterating on tg43 helpers, prefer `%load_ext autoreload` + `%autoreload 2` in the
# setup cell over ad-hoc module reloads, so Restart & Run All stays reproducible.
rt_channels, aux_points = dhelp.load_rtplan_by_channel(paths['plan'], all_points=True)
rt_channels, aux_points

In [None]:
# Inspect the auxiliary RTPLAN points; get_ref_point below searches their "dose_reference_points" entries.
aux_points

In [None]:
def get_ref_point(points, label):
    """Return the first ``dose_reference_points`` entry whose description matches ``label``.

    The stored description is stripped and lower-cased before comparison; ``label``
    is only lower-cased.

    Raises:
        ValueError: If no entry matches (or the key is absent).
    """
    match = next(
        (
            entry
            for entry in points.get("dose_reference_points", [])
            if entry["description"].strip().lower() == label.lower()
        ),
        None,
    )
    if match is None:
        raise ValueError(f"{label} point not found in RTPLAN.")
    return match

# "Art" / "Alt" dose reference points from the RTPLAN (used later as right / left Point A).
art_point, alt_point = (get_ref_point(aux_points, tag) for tag in ("Art", "Alt"))
art_point

In [None]:
# Map the RTPLAN channels' dwell positions onto the CT image, keeping unique positions only.
# (Use %autoreload in the setup cell instead of reloading tg43 modules here.)
dhelp.extract_dwell_positions(ct_image, rt_channels, unique=True)

In [None]:
def fliter_redundant_positions(positions: np.ndarray) -> np.ndarray:
    """Drop exact duplicate rows, keeping the first occurrence of each.

    No distance threshold is applied: two rows are redundant only when every
    coordinate compares equal, so near-duplicates are kept. (The "fliter" spelling
    is kept so existing callers continue to work.)

    Args:
        positions: Array-like of shape (N, D) (ndarray, list of arrays or list of
            lists) — presumably (N, 3) dwell positions in cm.

    Returns:
        Array of shape (M, D), M <= N, with rows in their original order. An empty
        input keeps its trailing dimension (e.g. (0, 3)).
    """
    arr = np.asarray(positions)
    seen = set()
    keep = []
    for idx, row in enumerate(arr):
        key = tuple(row.tolist())  # hashable surrogate for the row
        if key in seen:
            continue
        seen.add(key)
        keep.append(idx)
    # Integer-array indexing preserves dtype and the (0, D) shape for empty input.
    return arr[np.asarray(keep, dtype=np.intp)]

def reference_points(
    rt_channels,
    offset_ovoid=1.5,
    offset_tandem=2,
    pos_a_left=None,
    pos_a_right=None,
    threshold_distance=0.5,
) -> Dict[str, np.ndarray]:
    """Derive lateral reference points around the ovoid and tandem dwell positions.

    Channel roles come from list order: ``rt_channels[0]`` = left ovoid,
    ``rt_channels[1]`` = right ovoid, ``rt_channels[2]`` = tandem (TODO confirm
    against the RTPLAN channel numbering). Each channel must expose ``positions_cm``;
    coordinates are treated as cm.

    Args:
        rt_channels: Sequence of channel objects with a ``positions_cm`` attribute.
        offset_ovoid: Lateral shift (cm) of the ovoid reference points away from the
            ovoid dwell positions. Default 1.5 = 1 cm (ovoid thickness) + 0.5 cm
            (margin from ovoid surface).
        offset_tandem: Lateral shift (cm) of the tandem reference points on either
            side of the tandem dwell positions (default: 2 cm margin).
        pos_a_left: Left Point A coordinates (cm). If ``None``, falls back to the
            notebook global ``alt_point["positions_cm"]``.
        pos_a_right: Right Point A coordinates (cm). If ``None``, falls back to the
            notebook global ``art_point["positions_cm"]``.
        threshold_distance: Tandem dwell positions whose Euclidean distance to the
            midpoint of the two ovoid centroids is below this value (cm) are dropped.

    Returns:
        Dict with ``ref_ovoid_left`` / ``ref_ovoid_right`` (the 2nd and 3rd unique
        ovoid dwell positions per side, shifted outward) and ``ref_tandem_left`` /
        ``ref_tandem_right`` (laterally shifted tandem positions after filtering).

    Raises:
        ValueError: If the two ovoid centroids coincide, or the first and last unique
            tandem dwell positions coincide (axis direction undefined).
    """
    # Defaults keep the previous behaviour of reading Point A from notebook globals.
    if pos_a_left is None:
        pos_a_left = alt_point["positions_cm"]
    if pos_a_right is None:
        pos_a_right = art_point["positions_cm"]
    pos_a_left = np.asarray(pos_a_left, dtype=float)
    pos_a_right = np.asarray(pos_a_right, dtype=float)

    pos_ovoid_left = fliter_redundant_positions(rt_channels[0].positions_cm)
    pos_ovoid_right = fliter_redundant_positions(rt_channels[1].positions_cm)
    pos_tandem = fliter_redundant_positions(rt_channels[2].positions_cm)

    def _unit_vector(vec: np.ndarray) -> np.ndarray:
        """Normalize ``vec``; raise instead of silently producing NaNs."""
        norm = np.linalg.norm(vec)
        if norm == 0:
            raise ValueError("Encountered a zero-length vector while normalizing")
        return vec / norm

    centroid_left = np.mean(pos_ovoid_left, axis=0)
    centroid_right = np.mean(pos_ovoid_right, axis=0)
    ovoid_center = (centroid_right + centroid_left) / 2
    # Lateral axis, pointing from the left ovoid centroid towards the right one.
    ovoid_dir = _unit_vector(centroid_right - centroid_left)
    tandem_dir = _unit_vector(pos_tandem[-1] - pos_tandem[0])

    # Diagnostic only: report how close the lateral and tandem axes are to orthogonal.
    dot_val = np.dot(ovoid_dir, tandem_dir)
    angle_deg = np.degrees(np.arccos(np.clip(dot_val, -1.0, 1.0)))
    print(f"Angle between ovoid cluster and tandem axes: {angle_deg:.2f} deg")
    if np.isclose(angle_deg, 90.0, atol=5.0):
        print("Approximately orthogonal within +/- 5 deg tolerance.")
    else:
        print("Not orthogonal within +/- 5 deg tolerance.")

    # Ovoid references: push each side further outward along the lateral axis.
    ref_ovoid_left = pos_ovoid_left - ovoid_dir * offset_ovoid
    ref_ovoid_right = pos_ovoid_right + ovoid_dir * offset_ovoid

    # Drop tandem positions close to the ovoid midpoint. NOTE: this is the Euclidean
    # distance to the midpoint, not a distance measured along the tandem direction.
    far_from_ovoids = np.linalg.norm(pos_tandem - ovoid_center, axis=1) >= threshold_distance
    ref_tandem_left = (pos_tandem - ovoid_dir * offset_tandem)[far_from_ovoids]
    ref_tandem_right = (pos_tandem + ovoid_dir * offset_tandem)[far_from_ovoids]

    def _filter_by_pointa(ref_tandem: np.ndarray, pos_a: np.ndarray) -> np.ndarray:
        """Keep references whose signed offset from Point A, along the first->last
        axis of ``ref_tandem``, is at most 5 mm."""
        if len(ref_tandem) < 2:
            # An axis needs two points; return an empty selection rather than
            # indexing an empty array or projecting onto a NaN direction.
            return ref_tandem[:0]
        axis = ref_tandem[-1] - ref_tandem[0]
        axis_len = np.linalg.norm(axis)
        if axis_len == 0:
            return ref_tandem[:0]
        proj_mm = ((ref_tandem - pos_a) @ (axis / axis_len)) * 10.0  # cm -> mm
        return ref_tandem[proj_mm <= 5.0]

    # The first remaining tandem reference on each side is skipped before the
    # Point A filter (rationale not visible here — TODO document).
    ref_tandem_left = _filter_by_pointa(ref_tandem_left[1:], pos_a_left)
    ref_tandem_right = _filter_by_pointa(ref_tandem_right[1:], pos_a_right)

    return {
        # [1:3]: only the 2nd and 3rd unique ovoid positions per side are used (TODO document why).
        "ref_ovoid_left": ref_ovoid_left[1:3],
        "ref_ovoid_right": ref_ovoid_right[1:3],
        "ref_tandem_left": ref_tandem_left,
        "ref_tandem_right": ref_tandem_right,
    }

In [None]:
# Point A coordinates (cm) for plotting: the "Alt" point is the left one, "Art" the right one.
pos_a_left, pos_a_right = (np.array(point["positions_cm"]) for point in (alt_point, art_point))

In [None]:
refs = reference_points(rt_channels)
ref_ovoid_left = refs['ref_ovoid_left']
ref_ovoid_right = refs['ref_ovoid_right']
ref_tandem_left = refs['ref_tandem_left']
ref_tandem_right = refs['ref_tandem_right']

fontsize = 12
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

for channel in rt_channels:
    positions = np.array(channel.positions_cm)
    # ax.scatter(Coronal, Sagittal, Axial)
    ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], s=10, label=f'Channel {channel.channel_number}')
ax.scatter(ref_ovoid_left[:, 0], ref_ovoid_left[:, 1], ref_ovoid_left[:, 2], s=10, marker='x', color='black', label='Reference Points')
ax.scatter(ref_ovoid_right[:, 0], ref_ovoid_right[:, 1], ref_ovoid_right[:, 2], s=10, marker='x', color='black')
ax.scatter(ref_tandem_left[:, 0], ref_tandem_left[:, 1], ref_tandem_left[:, 2], s=10, marker='x', color='black')
ax.scatter(ref_tandem_right[:, 0], ref_tandem_right[:, 1], ref_tandem_right[:, 2], s=10, marker='x', color='black')

ax.scatter(pos_a_left[0], pos_a_left[1], pos_a_left[2], s=15, marker='^', color="#e745ed", label="Point A")
ax.scatter(pos_a_right[0], pos_a_right[1], pos_a_right[2], s=15, marker='^', color="#e745ed")

ax.set_xlabel('Coronal (cm)', fontsize=fontsize)
ax.set_ylabel('Sagittal (cm)', fontsize=fontsize)
ax.set_zlabel('Axial (cm)', fontsize=fontsize)

ax.view_init(elev=90, azim=-75, roll=0)
plt.legend(bbox_to_anchor=(1.1, 0.5), loc='center left', fontsize=fontsize)
plt.tight_layout()
plt.show()

In [None]:
refs = reference_points(rt_channels)
ref_ovoid_left = refs['ref_ovoid_left']
ref_ovoid_right = refs['ref_ovoid_right']
ref_tandem_left = refs['ref_tandem_left']
ref_tandem_right = refs['ref_tandem_right']

fontsize = 12
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

for channel in rt_channels:
    positions = np.array(channel.positions_cm)
    # ax.scatter(Coronal, Sagittal, Axial)
    ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], s=10, label=f'Channel {channel.channel_number}')
ax.scatter(ref_ovoid_left[:, 0], ref_ovoid_left[:, 1], ref_ovoid_left[:, 2], s=10, marker='x', color='black', label='Reference Points')
ax.scatter(ref_ovoid_right[:, 0], ref_ovoid_right[:, 1], ref_ovoid_right[:, 2], s=10, marker='x', color='black')
ax.scatter(ref_tandem_left[:, 0], ref_tandem_left[:, 1], ref_tandem_left[:, 2], s=10, marker='x', color='black')
ax.scatter(ref_tandem_right[:, 0], ref_tandem_right[:, 1], ref_tandem_right[:, 2], s=10, marker='x', color='black')

ax.scatter(pos_a_left[0], pos_a_left[1], pos_a_left[2], s=15, marker='^', color="#e745ed", label="Point A")
ax.scatter(pos_a_right[0], pos_a_right[1], pos_a_right[2], s=15, marker='^', color="#e745ed")

ax.set_xlabel('Coronal (cm)', fontsize=fontsize)
ax.set_ylabel('Sagittal (cm)', fontsize=fontsize)
ax.set_zlabel('Axial (cm)', fontsize=fontsize)

ax.view_init(elev=0, azim=-75, roll=0)
plt.legend(bbox_to_anchor=(1.1, 0.5), loc='center left', fontsize=fontsize)
plt.tight_layout()
plt.show()