In [3]:

import composable_mapping as cm
print(cm.__version__)

AttributeError: module 'composable_mapping' has no attribute '__version__'

In [1]:
# composable-mapping demo: multi-resolution dose + fine contour, 2D
# Deforms with translation / rotation / scaling / generic affine
# Robust to composable_mapping API differences

import math
import torch
import numpy as np
import matplotlib.pyplot as plt

import composable_mapping as cm
# We'll use: CoordinateSystem, SamplableVolume, (Affine/affine), Linear/Nearest interpolators, etc.

# --- composable-mapping compat helpers ---
def cs_grid(cs):
    """Return CoordinateSystem grid as a MappableTensor (handles property/method)."""
    g = getattr(cs, "grid")
    return g() if callable(g) else g

def cs_shape(cs):
    """Return (H, W[, ...]) regardless of property/method API."""
    s = getattr(cs, "spatial_shape")
    return tuple(s() if callable(s) else s)

def cs_spacing(cs):
    """Return spacing tensor regardless of CPU/GPU variant & property/method API."""
    f = getattr(cs, "grid_spacing", None) or getattr(cs, "grid_spacing_cpu", None)
    if f is None:
        raise AttributeError("CoordinateSystem has no grid_spacing[ _cpu ] accessor")
    val = f() if callable(f) else f
    return torch.as_tensor(val, dtype=torch.float32)

def make_affine(M: torch.Tensor):
    """
    Build a world->world affine mapping robustly across composable_mapping versions.
    Tries cm.affine(matrix) first, then class-based APIs.
    """
    if hasattr(cm, "affine") and callable(cm.affine):
        return cm.affine(M)
    if hasattr(cm, "Affine"):
        A = cm.Affine
        if hasattr(A, "from_matrix") and callable(A.from_matrix):
            # Newer API: single-arg from_matrix(matrix)
            return A.from_matrix(M)
        # Fallback: some versions may accept a direct constructor
        try:
            return A(M)
        except TypeError as e:
            raise RuntimeError(f"Could not construct Affine with available APIs: {e}")
    raise RuntimeError("Composable Mapping: no affine constructor found (missing affine/Affine).")

def channels_first_df_or_none():
    """Return DataFormat.CHANNELS_FIRST if available, else None."""
    if hasattr(cm, "DataFormat"):
        DF = cm.DataFormat
        if hasattr(DF, "CHANNELS_FIRST"):
            return DF.CHANNELS_FIRST
    return None


# ------------ helpers: coordinate systems & synthetic data ------------

def make_coordinate_systems(H=128, W=160, img_spacing_mm=1.0, dose_spacing_mm=0.5):
    """
    Two grids with identical physical FOV:
      - image grid: (H, W), spacing = img_spacing_mm
      - dose grid:  (Hd, Wd), spacing = dose_spacing_mm
    """
    img_cs = cm.CoordinateSystem.voxel((H, W)).multiply_world(
        torch.tensor([img_spacing_mm, img_spacing_mm], dtype=torch.float32)
    )

    Hd = int(round(H * img_spacing_mm / dose_spacing_mm))
    Wd = int(round(W * img_spacing_mm / dose_spacing_mm))

    dose_cs = cm.CoordinateSystem.voxel((Hd, Wd)).multiply_world(
        torch.tensor([dose_spacing_mm, dose_spacing_mm], dtype=torch.float32)
    )

    # (Optional) sanity check: same FOV in mm
    img_fov = torch.tensor([H * img_spacing_mm, W * img_spacing_mm])
    dose_fov = torch.tensor([Hd * dose_spacing_mm, Wd * dose_spacing_mm])
    if not torch.allclose(img_fov, dose_fov, atol=1e-3):
        print(f"[warn] FOV mismatch: img {img_fov.tolist()} mm vs dose {dose_fov.tolist()} mm")

    return img_cs, dose_cs


def make_synthetic_dose(dose_cs):
    """
    Create a smooth toy 'dose' on the dose grid (same physical FOV as image),
    using world-mm coordinates from the coordinate system.

    IMPORTANT: return shape [C, H, W] (no batch dim), C=1.
    """
    gyx_mt = cs_grid(dose_cs)
    # Materialize to Tensor (handle wrapper types)
    if hasattr(gyx_mt, "generate_values"):
        gyx = gyx_mt.generate_values()
    elif isinstance(gyx_mt, torch.Tensor):
        gyx = gyx_mt
    else:
        raise TypeError("Unexpected grid type returned by cs_grid(dose_cs)")

    # Expect shape [..., 2, H, W]; split to (y, x) in mm
    if gyx.shape[-3] != 2:
        raise RuntimeError(f"Expected coord channels=2 at dim -3, got shape {tuple(gyx.shape)}")
    y = gyx[..., 0, :, :]
    x = gyx[..., 1, :, :]

    H, W = y.shape[-2:]
    sy, sx = cs_spacing(dose_cs)
    cy, cx = (sy * H) / 2.0, (sx * W) / 2.0

    g1 = torch.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2.0 * 15.0 ** 2))
    g2 = 0.6 * torch.exp(-((y - (0.35 * cy * 2)) ** 2 + (x - (0.65 * cx * 2)) ** 2) / (2.0 * 10.0 ** 2))
    dose = (g1 + g2).unsqueeze(0)  # [1, H, W]  (C=1, NO batch dim)
    return dose

def make_circle_contour(center_yx_mm, radius_mm, step_mm=0.1):
    """
    Return Nx2 world coordinates (in mm) for a circle with arc-length spacing ~step_mm.
    center_yx_mm: (y, x) in mm
    """
    circumference = 2 * math.pi * radius_mm
    n = max(16, int(round(circumference / step_mm)))
    thetas = torch.linspace(0.0, 2.0 * math.pi, n + 1, dtype=torch.float32)[:-1]
    y = center_yx_mm[0] + radius_mm * torch.sin(thetas)
    x = center_yx_mm[1] + radius_mm * torch.cos(thetas)
    return torch.stack([y, x], dim=-1)  # [N, 2] (world coords in mm)

# ------------ helpers: building mappings (world->world) ------------

def affine_about_point(R_2x2, t_2, pivot_yx):
    """
    Build a 3x3 homogeneous transform acting in world (mm) that rotates/scales about pivot.
    x' = T * x_hom, where x=[y,x,1]^T
    R_2x2 acts on [y,x]; t_2 is extra translation (world mm).
    """
    T = torch.eye(3, dtype=torch.float32)
    T[:2, :2] = R_2x2
    T[:2, 2] = t_2
    # Conjugate to rotate/scale about pivot (translate to origin -> R -> translate back)
    P = torch.eye(3, dtype=torch.float32)
    P[:2, 2] = -pivot_yx
    Pinv = torch.eye(3, dtype=torch.float32)
    Pinv[:2, 2] = pivot_yx
    return Pinv @ T @ P  # 3x3

def mapping_translation(ty_mm, tx_mm, cs):
    M = torch.eye(3, dtype=torch.float32)
    M[0, 2] = ty_mm
    M[1, 2] = tx_mm
    return make_affine(M)  # world->world

def mapping_rotation(angle_deg, cs):
    theta = math.radians(angle_deg)
    R = torch.tensor([[ math.cos(theta), -math.sin(theta)],
                      [ math.sin(theta),  math.cos(theta)]], dtype=torch.float32)
    H, W = cs_shape(cs)
    sy, sx = cs_spacing(cs)
    pivot = torch.tensor([H*sy/2.0, W*sx/2.0], dtype=torch.float32)
    M = affine_about_point(R, torch.tensor([0.0, 0.0]), pivot)
    return make_affine(M)

def mapping_scaling(sy_scale, sx_scale, cs):
    S = torch.tensor([[sy_scale, 0.0], [0.0, sx_scale]], dtype=torch.float32)
    H, W = cs_shape(cs)
    sy, sx = cs_spacing(cs)
    pivot = torch.tensor([H*sy/2.0, W*sx/2.0], dtype=torch.float32)
    M = affine_about_point(S, torch.tensor([0.0, 0.0]), pivot)
    return make_affine(M)

def mapping_generic_affine(A_2x2, t_2, cs):
    H, W = cs_shape(cs)
    sy, sx = cs_spacing(cs)
    pivot = torch.tensor([H*sy/2.0, W*sx/2.0], dtype=torch.float32)
    M = affine_about_point(A_2x2, t_2, pivot)
    return make_affine(M)

# ------------ helpers: sampling & plotting ------------

def warp_dose_samplable(dose_tensor, dose_cs, mapping, mode="linear"):
    """
    Keep output on the *original* dose grid (no resolution change).
    We 'pull' the source dose with the coordinate mapping acting in world mm.

    dose_tensor MUST be [C, H, W] (no batch dim), C=1.
    """
    data_format = channels_first_df_or_none()
    kwargs = {}
    if data_format is not None:
        kwargs["data_format"] = data_format

    vol = cm.SamplableVolume.from_tensor(dose_tensor, dose_cs, **kwargs)
    if mode == "nearest":
        if hasattr(cm, "NearestInterpolator"):
            vol = vol.modify_sampler(cm.NearestInterpolator())
    elif mode == "linear":
        if hasattr(cm, "LinearInterpolator"):
            vol = vol.modify_sampler(cm.LinearInterpolator())

    warped_mt = vol(mapping).sample_to(dose_cs)
    # Materialize to a real tensor robustly (expect [C, H, W])
    if isinstance(warped_mt, torch.Tensor):
        return warped_mt
    if hasattr(warped_mt, "generate_values"):
        return warped_mt.generate_values()
    if hasattr(warped_mt, "tensor"):
        return warped_mt.tensor
    raise TypeError("Unexpected return type from sample_to(); cannot materialize Tensor.")

def deform_points_world(pts_yx, mapping):
    """Apply mapping to NÃ—2 world-mm coordinates."""
    out = mapping(pts_yx)
    # Some versions may return a wrapper; ensure Tensor
    if isinstance(out, torch.Tensor):
        return out
    if hasattr(out, "generate_values"):
        return out.generate_values()
    if hasattr(out, "tensor"):
        return out.tensor
    raise TypeError("Unexpected return type from mapping(points); cannot materialize Tensor.")

def plot_grid(ax, cs, mapping=None, every=16, alpha=0.8, lw=0.8):
    """Draw a (possibly) deformed lattice in world mm coordinates."""
    H, W = cs_shape(cs)
    sy, sx = cs_spacing(cs)
    ny = max(2, H // every)
    nx = max(2, W // every)
    ys = torch.linspace(0, H*sy, ny)
    xs = torch.linspace(0, W*sx, nx)

    # verticals
    for x in xs:
        y = torch.linspace(0, H*sy, 200)
        line = torch.stack([y, torch.full_like(y, x)], dim=-1)
        if mapping is not None:
            line = deform_points_world(line, mapping)
        ax.plot(line[:,1].numpy(), line[:,0].numpy(), linewidth=lw, alpha=alpha)
    # horizontals
    for y in ys:
        x = torch.linspace(0, W*sx, 200)
        line = torch.stack([torch.full_like(x, y), x], dim=-1)
        if mapping is not None:
            line = deform_points_world(line, mapping)
        ax.plot(line[:,1].numpy(), line[:,0].numpy(), linewidth=lw, alpha=alpha)

def show_one(mapping_name, mapping, img_cs, dose_cs, dose_tensor, contour_world):
    warped = warp_dose_samplable(dose_tensor, dose_cs, mapping, mode="linear")
    contour_def = deform_points_world(contour_world, mapping)

    Hd, Wd = cs_shape(dose_cs)
    sy, sx = cs_spacing(dose_cs)
    extent = [0, Wd*sx, Hd*sy, 0]

    fig, axs = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)

    # original
    axs[0].imshow(dose_tensor[0].numpy(), extent=extent)  # dose_tensor: [1, H, W]
    axs[0].plot(contour_world[:,1].numpy(), contour_world[:,0].numpy(), lw=2)
    plot_grid(axs[0], img_cs, mapping=None, every=16, alpha=0.4)
    axs[0].set_title("Original dose/contour/grid")
    axs[0].set_aspect('equal')

    # deformed
    axs[1].imshow(warped[0].numpy(), extent=extent)       # warped: [1, H, W]
    axs[1].plot(contour_def[:,1].numpy(), contour_def[:,0].numpy(), lw=2)
    plot_grid(axs[1], img_cs, mapping=mapping, every=16, alpha=0.4)
    axs[1].set_title(f"Deformed by {mapping_name}")
    axs[1].set_aspect('equal')

    # deformation grid alone
    plot_grid(axs[2], img_cs, mapping=mapping, every=16, alpha=0.9)
    axs[2].set_xlim(0, Wd*sx)
    axs[2].set_ylim(Hd*sy, 0)
    axs[2].set_title("Deformation grid")
    axs[2].set_aspect('equal')

    plt.show()

# ------------ main demo ------------

if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)

    H, W = 128, 160
    img_cs, dose_cs = make_coordinate_systems(H, W, img_spacing_mm=1.0, dose_spacing_mm=0.5)

    # synthetic dose (high-res grid, same physical FOV); shape [1, H_dose, W_dose]
    dose = make_synthetic_dose(dose_cs)

    # fine circle contour (0.1 mm arc step) in world mm
    center_world = torch.tensor([H*1.0/2.0, W*1.0/2.0])  # image spacing is 1 mm => center in mm
    contour = make_circle_contour(center_world, radius_mm=min(H, W)*0.25, step_mm=0.1)

    # mappings to test
    maps = {
        "translation (y:+8mm, x:-12mm)": mapping_translation(8.0, -12.0, img_cs),
        "rotation (15 deg)":              mapping_rotation(15.0, img_cs),
        "scaling (sy=0.9, sx=1.1)":       mapping_scaling(0.9, 1.1, img_cs),
        "generic affine":                 mapping_generic_affine(
                                              torch.tensor([[0.95, -0.05],
                                                            [0.08,  1.02]], dtype=torch.float32),
                                              torch.tensor([5.0, -7.0], dtype=torch.float32),
                                              img_cs),
    }

    for name, m in maps.items():
        show_one(name, m, img_cs, dose_cs, dose, contour)


AttributeError: 'Affine' object has no attribute 'transform'