In [30]:
import torch
import plotly.express as px
import plotly.io as pio
# Render Plotly figures with the classic-notebook renderer for every cell below.
pio.renderers.default = "notebook"

def master_plot_function(
    volume: torch.Tensor, 
    gray_min: float | None = None, 
    gray_max: float | None = None,
    gray_mid: float | None = None,
):
    gray_min = volume.min().item() if gray_min is None else gray_min
    gray_max = volume.max().item() if gray_max is None else gray_max
    gray_mid = (gray_min + gray_max) / 2 if gray_mid is None else gray_mid

    max_idx = torch.nonzero(volume == volume.max())
    print("Location(s) of maximum value in result:", max_idx.tolist())
    z_max = max_idx[0, 0].item() if len(max_idx) > 0 else 0
    px.imshow(volume[z_max, :, :],
              origin="lower",
              color_continuous_scale="viridis",
              zmin=gray_min,
              zmax=gray_max,
              title=f"Slice {z_max} with threshold zmin={gray_min:.1f}",
    ).show()


from torch_affine_utils.transforms_3d import Rx, Ry, Rz, T
from torch_grid_utils import dft_center
from torch_transform_image import affine_transform_image_3d

In [35]:
# Synthetic test volume: a single unit-intensity voxel that the next cell moves around.
VOLUME_SHAPE = (28, 28, 28)

# Created as float32 directly, so no extra .float() conversion is needed.
image = torch.zeros(VOLUME_SHAPE, dtype=torch.float32)
image[14, 7, 14] = 1  # tracked voxel at z=14, y=7, x=14

# Marker voxel at z=y=x=14. Presumably the same point dft_center() returns for this
# shape in the next cell (the rotation pivot). TODO confirm.
center_dot = torch.zeros(VOLUME_SHAPE, dtype=torch.float32)
center_dot[14, 14, 14] = 1

# Scale the marker to 0.1 so it stays visible but never becomes the maximum.
master_plot_function(image + center_dot * 0.1, gray_mid=0)

Location(s) of maximum value in result: [[14, 7, 14]]


In [133]:
# Forward (point -> point) transform: move `center` to the origin, shift, rotate,
# then move `center` back. The matrices act on homogeneous [z, y, x, 1] column
# vectors (`m_forward @ point`), so the right-most factor is applied first.
center = dft_center(image.shape, fftshift=True, rfft=False)
rotate_zyx = [90, 0, 0]  # rotation angles (degrees) about z, y, x
shifts_zyx = [0, 0, 5]   # translation (voxels) along z, y, x

# shift then rotate, both about `center`
m_forward = (
    T(center)
    @ Rz(rotate_zyx[0], zyx=True)
    @ Ry(rotate_zyx[1], zyx=True)
    @ Rx(rotate_zyx[2], zyx=True)
    @ T(shifts_zyx)
    @ T(-center)
)

# Track the bright voxel of `image` at z=14, y=7, x=14.
dot_zyx1 = torch.tensor([14, 7, 14, 1]).float()
moved_zyx1 = m_forward @ dot_zyx1
print(moved_zyx1)
# We expect to end up at z=14, y=19, x=21. Relative to `center` (presumably
# (14, 14, 14)) the voxel sits at (0, -7, 0), and the x-shift gives (0, -7, 5).
# The 90-degree rotation about z then maps (y, x) = (-7, 5) to (5, 7).
assert torch.allclose(moved_zyx1, torch.tensor([14.0, 19.0, 21.0, 1.0]), atol=1e-4)

# The resampler is handed the *inverse* matrix. The check below confirms that the
# bright voxel then lands where `m_forward` sends it.
# NOTE(review): this suggests affine_transform_image_3d expects an output->input
# (pull-back) mapping. Confirm against the torch_transform_image docs.
# `m` still ends up holding the inverse matrix, so later cells that read it are unaffected.
m = torch.linalg.inv(m_forward)

transformed = affine_transform_image_3d(
    image, m, interpolation='trilinear', zyx_matrices=True
)
peak_z, peak_y, peak_x = moved_zyx1[:3].round().long().tolist()
assert transformed[peak_z, peak_y, peak_x] == transformed.max(), \
    "resampled peak did not land where m_forward maps the tracked voxel"

# Overlay the faint 0.1 marker from `center_dot` (at z=y=x=14) for orientation.
master_plot_function(transformed + center_dot * 0.1, gray_mid=0)

tensor([14., 19., 21.,  1.])
Location(s) of maximum value in result: [[14, 19, 21]]
