In [1]:
# composable-mapping demo: multi-resolution dose + fine contour, 2D
# Deforms with translation / rotation / scaling / generic affine

import math
import torch
import numpy as np
import matplotlib.pyplot as plt

import composable_mapping as cm
# API we’ll use: CoordinateSystem, Affine/affine, SamplableVolume, DataFormat, etc.
# See docs index listing these symbols.  # (Docs: honkamj.github.io/composable-mapping) 

# --- composable-mapping compat helpers ---
def cs_grid(cs):
    """Return CoordinateSystem grid as a MappableTensor (handles property/method)."""
    g = getattr(cs, "grid")
    return g() if callable(g) else g

def cs_shape(cs):
    """Return (H, W[, ...]) regardless of property/method API."""
    s = getattr(cs, "spatial_shape")
    return s() if callable(s) else s

def cs_spacing(cs):
    """Return spacing tensor regardless of CPU/GPU variant & property/method API."""
    f = getattr(cs, "grid_spacing", None) or getattr(cs, "grid_spacing_cpu", None)
    if f is None:
        raise AttributeError("CoordinateSystem has no grid_spacing[ _cpu ] accessor")
    return f() if callable(f) else f


# ------------ helpers: coordinate systems & synthetic data ------------

def make_coordinate_systems(H=128, W=160, img_spacing_mm=1.0, dose_spacing_mm=0.5):
    """
    Two grids with identical physical FOV:
      - image grid: (H, W), spacing = img_spacing_mm
      - dose grid:  (Hd, Wd), spacing = dose_spacing_mm
    """
    # Base image CS: voxel grid scaled to world mm (origin 0, identity orientation)
    img_cs = cm.CoordinateSystem.voxel((H, W)).multiply_world(
        torch.tensor([img_spacing_mm, img_spacing_mm], dtype=torch.float32)
    )

    # Choose dose shape so that physical sizes match
    Hd = int(round(H * img_spacing_mm / dose_spacing_mm))
    Wd = int(round(W * img_spacing_mm / dose_spacing_mm))

    dose_cs = cm.CoordinateSystem.voxel((Hd, Wd)).multiply_world(
        torch.tensor([dose_spacing_mm, dose_spacing_mm], dtype=torch.float32)
    )

    # (Optional) sanity check: same FOV in mm
    img_fov = torch.tensor([H * img_spacing_mm, W * img_spacing_mm])
    dose_fov = torch.tensor([Hd * dose_spacing_mm, Wd * dose_spacing_mm])
    if not torch.allclose(img_fov, dose_fov, atol=1e-3):
        print(f"[warn] FOV mismatch: img {img_fov.tolist()} mm vs dose {dose_fov.tolist()} mm")

    return img_cs, dose_cs


def gaussian2d(grid_yx, mu_yx, sigma_yx):
    y, x = grid_yx[..., 0], grid_yx[..., 1]
    my, mx = mu_yx
    sy, sx = sigma_yx
    return torch.exp(-0.5 * (((y - my)/sy)**2 + ((x - mx)/sx)**2))

def make_synthetic_dose(dose_cs):
    """
    Create a smooth toy 'dose' on the dose grid (same physical FOV as image),
    using world-mm coordinates from the coordinate system.
    """
    # World grid as a MappableTensor
    gyx_mt = cs_grid(dose_cs)                          # <-- NO parentheses
    gyx = gyx_mt.generate_values()                      # materialize to Tensor

    # Expect shape [*, 2, H, W]; split to (y, x) in mm
    if gyx.shape[-3] != 2:
        raise RuntimeError(f"Expected coord channels=2 at dim -3, got shape {gyx.shape}")
    y = gyx[..., 0, :, :]
    x = gyx[..., 1, :, :]

    # A couple of Gaussians summed
    H, W = y.shape[-2:]
    cy, cx = (cs_spacing(dose_cs)[0] * H) / 2.0, (cs_spacing(dose_cs)[1] * W) / 2.0
    g1 = torch.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2.0 * 15.0 ** 2))
    g2 = 0.6 * torch.exp(-((y - (0.35 * cy * 2)) ** 2 + (x - (0.65 * cx * 2)) ** 2) / (2.0 * 10.0 ** 2))
    dose = (g1 + g2).unsqueeze(0).unsqueeze(0)         # [1,1,H,W]
    return dose

def make_circle_contour(center_yx_mm, radius_mm, step_mm=0.1):
    """
    Return Nx2 world coordinates (in mm) for a circle with arc-length spacing ~step_mm.
    center_yx_mm: (y, x) in mm
    """
    circumference = 2 * math.pi * radius_mm
    n = max(16, int(round(circumference / step_mm)))
    # torch.linspace on some versions doesn't support endpoint=. Emulate [0, 2π) by slicing off last sample.
    thetas = torch.linspace(0.0, 2.0 * math.pi, n + 1, dtype=torch.float32)[:-1]
    y = center_yx_mm[0] + radius_mm * torch.sin(thetas)
    x = center_yx_mm[1] + radius_mm * torch.cos(thetas)
    return torch.stack([y, x], dim=-1)  # [N, 2] (world coords in mm)

# ------------ helpers: building mappings (world->world) ------------

def affine_about_point(R_2x2, t_2, pivot_yx):
    """
    Build a 3x3 homogeneous transform acting in world (mm) that rotates/scales about pivot.
    x' = T * x_hom, where x=[y,x,1]^T
    R_2x2 acts on [y,x]; t_2 is extra translation (world mm).
    """
    T = torch.eye(3, dtype=torch.float32)
    T[:2, :2] = R_2x2
    T[:2, 2] = t_2
    # Conjugate to rotate/scale about pivot (translate to origin -> R -> translate back)
    P = torch.eye(3, dtype=torch.float32)
    P[:2, 2] = -pivot_yx
    Pinv = torch.eye(3, dtype=torch.float32)
    Pinv[:2, 2] = pivot_yx
    return Pinv @ T @ P  # 3x3

def mapping_translation(ty_mm, tx_mm, cs):
    M = torch.eye(3, dtype=torch.float32)
    M[0, 2] = ty_mm
    M[1, 2] = tx_mm
    return cm.Affine.from_matrix(M, cs)  # world->world

def mapping_rotation(angle_deg, cs):
    theta = math.radians(angle_deg)
    R = torch.tensor([[ math.cos(theta), -math.sin(theta)],
                      [ math.sin(theta),  math.cos(theta)]], dtype=torch.float32)
    # rotate about image center in world units
    H, W = cs.spatial_shape()
    spacing_y, spacing_x = cs.grid_spacing()
    pivot = torch.tensor([H*spacing_y/2.0, W*spacing_x/2.0], dtype=torch.float32)
    M = affine_about_point(R, torch.tensor([0.0, 0.0]), pivot)
    return cm.Affine.from_matrix(M, cs)

def mapping_scaling(sy, sx, cs):
    S = torch.tensor([[sy, 0.0], [0.0, sx]], dtype=torch.float32)
    H, W = cs.spatial_shape()
    spacing_y, spacing_x = cs.grid_spacing()
    pivot = torch.tensor([H*spacing_y/2.0, W*spacing_x/2.0], dtype=torch.float32)
    M = affine_about_point(S, torch.tensor([0.0, 0.0]), pivot)
    return cm.Affine.from_matrix(M, cs)

def mapping_generic_affine(A_2x2, t_2, cs):
    H, W = cs.spatial_shape()
    spacing_y, spacing_x = cs.grid_spacing()
    pivot = torch.tensor([H*spacing_y/2.0, W*spacing_x/2.0], dtype=torch.float32)
    M = affine_about_point(A_2x2, t_2, pivot)
    return cm.Affine.from_matrix(M, cs)

# ------------ helpers: sampling & plotting ------------

def warp_dose_samplable(dose_tensor, dose_cs, mapping, mode="linear"):
    """
    Keep output on the *original* dose grid (no resolution change).
    We 'pull' the source dose with the coordinate mapping acting in world mm.
    """
    # Wrap as a samplable volume to use the library’s sampling path
    vol = cm.SamplableVolume.from_tensor(dose_tensor, dose_cs)
    if mode == "nearest":
        vol = vol.modify_sampler(cm.NearestInterpolator())
    elif mode == "linear":
        vol = vol.modify_sampler(cm.LinearInterpolator())
    # Pull to the same grid: evaluate source at mapped coords
    warped = vol(mapping).sample_to(dose_cs)   # (docs list __call__/sample_to on SamplableVolume/ComposableMapping)
    return warped  # torch tensor [1,1,H,W]

def deform_points_world(pts_yx, mapping):
    """Apply mapping to N×2 world-mm coordinates."""
    # ComposableMapping.__call__ supports sampling given coordinates (docs list __call__)
    return mapping(pts_yx)

def plot_grid(ax, cs, mapping=None, every=16, alpha=0.8, lw=0.8):
    """Draw a deformed lattice in world mm coordinates."""
    H, W = cs.spatial_shape()
    sy, sx = cs.grid_spacing()
    ny = max(2, H // every)
    nx = max(2, W // every)
    ys = torch.linspace(0, H*sy, ny)
    xs = torch.linspace(0, W*sx, nx)

    # verticals
    for x in xs:
        y = torch.linspace(0, H*sy, 200)
        line = torch.stack([y, torch.full_like(y, x)], dim=-1)
        if mapping is not None:
            line = mapping(line)
        ax.plot(line[:,1].numpy(), line[:,0].numpy(), linewidth=lw, alpha=alpha)
    # horizontals
    for y in ys:
        x = torch.linspace(0, W*sx, 200)
        line = torch.stack([torch.full_like(x, y), x], dim=-1)
        if mapping is not None:
            line = mapping(line)
        ax.plot(line[:,1].numpy(), line[:,0].numpy(), linewidth=lw, alpha=alpha)

def show_one(mapping_name, mapping, img_cs, dose_cs, dose_tensor, contour_world):
    warped = warp_dose_samplable(dose_tensor, dose_cs, mapping, mode="linear")
    contour_def = deform_points_world(contour_world, mapping)

    fig, axs = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)
    # original
    axs[0].imshow(dose_tensor[0,0].numpy(), extent=[0, dose_cs.spatial_shape()[1]*dose_cs.grid_spacing()[1],
                                                     dose_cs.spatial_shape()[0]*dose_cs.grid_spacing()[0], 0])
    axs[0].plot(contour_world[:,1].numpy(), contour_world[:,0].numpy(), lw=2)
    plot_grid(axs[0], img_cs, mapping=None, every=16, alpha=0.4)
    axs[0].set_title("Original dose/contour/grid"); axs[0].set_aspect('equal')

    # deformed
    axs[1].imshow(warped[0,0].numpy(), extent=[0, dose_cs.spatial_shape()[1]*dose_cs.grid_spacing()[1],
                                               dose_cs.spatial_shape()[0]*dose_cs.grid_spacing()[0], 0])
    axs[1].plot(contour_def[:,1].numpy(), contour_def[:,0].numpy(), lw=2)
    plot_grid(axs[1], img_cs, mapping=mapping, every=16, alpha=0.4)
    axs[1].set_title(f"Deformed by {mapping_name}"); axs[1].set_aspect('equal')

    # deformation grid alone
    plot_grid(axs[2], img_cs, mapping=mapping, every=16, alpha=0.9)
    axs[2].set_xlim(0, dose_cs.spatial_shape()[1]*dose_cs.grid_spacing()[1])
    axs[2].set_ylim(dose_cs.spatial_shape()[0]*dose_cs.grid_spacing()[0], 0)
    axs[2].set_title("Deformation grid"); axs[2].set_aspect('equal')
    plt.show()

# ------------ main demo ------------

if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)

    H, W = 128, 160
    img_cs, dose_cs = make_coordinate_systems(H, W, img_spacing_mm=1.0, dose_spacing_mm=0.5)

    # synthetic dose (high-res grid, same physical FOV)
    dose = make_synthetic_dose(dose_cs)  # [1,1,Hdose,Wdose]

    # fine circle contour (0.1 mm arc step) in world mm
    center_world = torch.tensor([H*1.0/2.0, W*1.0/2.0])  # since image spacing is 1 mm
    contour = make_circle_contour(center_world, radius_mm=min(H, W)*0.25, step_mm=0.1)

    # mappings to test
    maps = {
        "translation (y:+8mm, x:-12mm)": mapping_translation(8.0, -12.0, img_cs),
        "rotation (15 deg)":              mapping_rotation(15.0, img_cs),
        "scaling (sy=0.9, sx=1.1)":       mapping_scaling(0.9, 1.1, img_cs),
        "generic affine":                 mapping_generic_affine(
                                              torch.tensor([[0.95, -0.05],
                                                            [0.08,  1.02]]),
                                              torch.tensor([5.0, -7.0]), img_cs),
    }

    for name, m in maps.items():
        show_one(name, m, img_cs, dose_cs, dose, contour)


TypeError: Affine.from_matrix() takes 2 positional arguments but 3 were given