In [None]:
# Reload edited modules (e.g. armscan_env) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env import config
from armscan_env.clustering import TissueClusters
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.envs.state_action import ManipulatorAction
from armscan_env.util.visualizations import show_clusters
from armscan_env.volumes.slicing import (
    create_transformed_volume,
    get_volume_slice,
)

# Load the project configuration (used below for data paths via get_labels_path).
# NOTE: this rebinds the name `config` from the imported `armscan_env.config` to
# the object returned by `config.get_config()`. Re-running this cell is safe only
# because the import at the top of the same cell restores the original binding.
config = config.get_config()

In [None]:
# Load label volume 1 and show one axial slice with the line corresponding to
# `action` overlaid. Physical extents (voxel count * spacing) are used so the
# image and the line share the same coordinate system.
volume = sitk.ReadImage(config.get_labels_path(1))
volume_img = sitk.GetArrayFromImage(volume)  # SimpleITK arrays are indexed [z, y, x]

x_size, y_size, z_size = (
    sz * sp for sz, sp in zip(volume.GetSize(), volume.GetSpacing(), strict=True)
)
# imshow extent is (left, right, bottom, top); top=0 keeps array row 0 at the top.
extent_xy = (0, x_size, y_size, 0)

action = ManipulatorAction(rotation=(19, 0), translation=(0, 140))

axial_idx = 40  # z index of the displayed axial slice
fig, ax = plt.subplots()
ax.imshow(volume_img[axial_idx, :, :], extent=extent_xy)

# Line y = x * tan(rotation[0]) + translation[1], drawn in the same units as the
# extent. The volume origin is not applied (the extent also starts at 0).
x_dash = np.arange(x_size)
b = action.translation[1]
y_dash = x_dash * np.tan(np.deg2rad(action.rotation[0])) + b
ax.plot(x_dash, y_dash, linestyle="--", color="red")
ax.set(title=f"Volume 1, axial slice z={axial_idx}", xlabel="x", ylabel="y")

plt.show()

In [None]:
# Extract the slice selected by `action` from volume 1, sampled on a grid as
# wide as the volume's x and z dimensions, and report its label-value range.
sliced_volume = get_volume_slice(
    volume=volume,
    action=action,
    slice_shape=tuple(volume.GetSize()[axis] for axis in (0, 2)),
)
sliced_img = sitk.GetArrayFromImage(sliced_volume)
print(f"Slice value range: {sliced_img.min()} - {sliced_img.max()}")

# (left, right, bottom, top) in physical units; reused by later cells.
extent_xz = (0, x_size, z_size, 0)
plt.imshow(sliced_img, extent=extent_xz)
plt.show()

In [None]:
# Cross-check with a plain SimpleITK rigid transform: 19 degrees about the z axis
# (centered on the volume origin) plus a shift of 10 along y. Nearest-neighbour
# interpolation keeps label values intact; 0.0 fills points outside the volume.
# NOTE: sitk.Resample maps *output* points into the input image, so the displayed
# content moves by the inverse of `transform`.
transform = sitk.Euler3DTransform()
transform.SetRotation(0, 0, np.deg2rad(19))
transform.SetTranslation((0, 10, 0))
transform.SetCenter(volume.GetOrigin())
resampled = sitk.Resample(volume, transform, sitk.sitkNearestNeighbor, 0.0, volume.GetPixelID())
plt.imshow(sitk.GetArrayFromImage(resampled)[40, :, :], extent=extent_xy)
plt.title("Volume 1 resampled with Euler3DTransform, axial slice z=40")
# Explicit show, like the other plotting cells, so no AxesImage repr is emitted.
plt.show()

In [None]:
# Zero rotation / zero translation: the relative step turned into
# `action_transform` in the next cell.
relative_action = ManipulatorAction(rotation=(0, 0), translation=(0, 0))
# Pose turned into `volume_transform` in the next cell (same values as `action`).
transformation_action = ManipulatorAction(rotation=(19, 0), translation=(0, 140))

In [None]:
# Build SimpleITK Euler transforms from the two ManipulatorActions. Rotations are
# treated as degrees (converted with deg2rad); the 2-element translation is
# lifted to 3-D with a zero z component.
volume_rotation = np.deg2rad(transformation_action.rotation)
volume_translation = transformation_action.translation

volume_transform = sitk.Euler3DTransform()
# SetRotation(angleX, angleY, angleZ): rotation[1] -> about x, rotation[0] -> about z.
volume_transform.SetRotation(volume_rotation[1], 0, volume_rotation[0])
volume_transform.SetTranslation((*volume_translation, 0))

inverse_volume_transform = volume_transform.GetInverse()
# 4x4 homogeneous form of the inverse: 3x3 rotation block plus translation column.
# NOTE(review): SetCenter is never called on volume_transform, so (assuming the
# SimpleITK default center of zero) the translation alone is the homogeneous
# offset; a non-zero center would need c + t - M c instead. TODO confirm.
# This matrix is not used by the cells visible here.
inverse_volume_transform_matrix = np.eye(4)
inverse_volume_transform_matrix[:3, :3] = np.array(inverse_volume_transform.GetMatrix()).reshape(
    3,
    3,
)
inverse_volume_transform_matrix[:3, 3] = inverse_volume_transform.GetTranslation()

# Same construction for the relative action (same angle order as above).
action_rotation = np.deg2rad(relative_action.rotation)
action_translation = relative_action.translation
action_transform = sitk.Euler3DTransform()
action_transform.SetRotation(action_rotation[1], 0, action_rotation[0])
action_transform.SetTranslation((*action_translation, 0))

In [None]:
# Chain the inverse volume transform with the relative action transform.
# NOTE(review): ITK/SimpleITK composite transforms apply the most recently added
# transform first, i.e. composite(p) = inverse_volume_transform(action_transform(p));
# confirm this is the intended order. `composite` is not used by the cells shown here.
composite = sitk.CompositeTransform(3)
composite.AddTransform(inverse_volume_transform)
composite.AddTransform(action_transform)

In [None]:
# Build a transformed copy of volume 1 from `volume_transformation` and express
# `action` in that transformed volume's frame.
# NOTE(review): the translation (-9.74, -4.31) is an unexplained magic value;
# document where it comes from.
volume_transformation = ManipulatorAction(rotation=(19, 0), translation=(-9.74, -4.31))
transformed_volume = create_transformed_volume(volume, volume_transformation)
# Presumably slicing `transformed_volume` with `transformed_action` should
# reproduce the slice of `volume` taken with `action`; compare the plots below.
transformed_action = transformed_volume.transform_action(action)

In [None]:
# Show the original action next to its counterpart in the transformed frame.
print(
    "action=" + repr(action),
    "transformed_action=" + repr(transformed_action),
    sep="\n",
    end="\n\n",
)

In [None]:
# Axial slice of the transformed volume with `transformed_action` overlaid, for
# visual comparison with the corresponding plot of the original volume.
# NOTE(review): reuses volume 1's extent, which assumes the transformed volume
# keeps the same size and spacing.
transformed_img = sitk.GetArrayFromImage(transformed_volume)

axial_idx = 40  # z index of the displayed axial slice
fig, ax = plt.subplots()
ax.imshow(transformed_img[axial_idx, :, :], extent=extent_xy)

# Line y = x * tan(rotation[0]) + translation[1] in the extent's units.
x_dash = np.arange(x_size)
b = transformed_action.translation[1]
y_dash = x_dash * np.tan(np.deg2rad(transformed_action.rotation[0])) + b
ax.plot(x_dash, y_dash, linestyle="--", color="red")
ax.set(title=f"Transformed volume 1, axial slice z={axial_idx}", xlabel="x", ylabel="y")

plt.show()

In [None]:
# Slice the transformed volume with the transformed action (same output grid as
# the original slice) and report its label-value range.
sliced_transformed_volume = get_volume_slice(
    volume=transformed_volume,
    action=transformed_action,
    slice_shape=tuple(volume.GetSize()[axis] for axis in (0, 2)),
)
sliced_transformed_img = sitk.GetArrayFromImage(sliced_transformed_volume)
print(
    f"Slice value range: {sliced_transformed_img.min()} - {sliced_transformed_img.max()}"
)

plt.imshow(sliced_transformed_img, extent=extent_xz)
plt.show()

In [None]:
# Cluster the tissue labels of the transformed slice and compute its reward.
# The display aspect ratio is derived from the voxel spacing with the same
# convention the volume 2 cells use, so both volumes are shown consistently.
volume_spacing = volume.GetSpacing()
cluster = TissueClusters.from_labelmap_slice(sliced_transformed_img.T)
show_clusters(
    cluster,
    sliced_transformed_img.T,
    aspect=volume_spacing[2] / volume_spacing[0],
)
reward = anatomy_based_rwd(cluster)
print(f"Reward: {reward}")
plt.show()

In [None]:
# Load label volume 2 and overlay `action_2` on an axial slice, mirroring the
# volume 1 cell above.
volume_2 = sitk.ReadImage(config.get_labels_path(2))
volume_2_img = sitk.GetArrayFromImage(volume_2)  # SimpleITK arrays are indexed [z, y, x]
x_size_2, y_size_2, z_size_2 = (
    sz * sp for sz, sp in zip(volume_2.GetSize(), volume_2.GetSpacing(), strict=True)
)
extent_xy_2 = (0, x_size_2, y_size_2, 0)

# Voxel spacing of volume 2; later cells use it for the cluster-plot aspect ratio.
spacing = volume_2.GetSpacing()
action_2 = ManipulatorAction(rotation=(5, 0), translation=(0, 112))

axial_idx_2 = 51  # z index of the displayed axial slice
fig, ax = plt.subplots()
ax.imshow(volume_2_img[axial_idx_2, :, :], extent=extent_xy_2)

# Line y = x * tan(rotation[0]) + translation[1] in the extent's units.
x_dash = np.arange(x_size_2)
b = action_2.translation[1]
y_dash = x_dash * np.tan(np.deg2rad(action_2.rotation[0])) + b
ax.plot(x_dash, y_dash, linestyle="--", color="red")
ax.set(title=f"Volume 2, axial slice z={axial_idx_2}", xlabel="x", ylabel="y")

plt.show()

In [None]:
# Slice volume 2 with `action_2` and display its tissue clusters, with the
# aspect ratio derived from the voxel spacing.
sliced_volume_2 = get_volume_slice(
    volume=volume_2,
    action=action_2,
    slice_shape=tuple(volume_2.GetSize()[axis] for axis in (0, 2)),
)
sliced_img_2 = sitk.GetArrayFromImage(sliced_volume_2)

# Clustering and plotting both operate on the transposed slice.
cluster = TissueClusters.from_labelmap_slice(sliced_img_2.T)
show_clusters(
    cluster,
    sliced_img_2.T,
    aspect=spacing[2] / spacing[0],
)

plt.show()

In [None]:
# Same experiment for volume 2, this time with a 10 degree rotation.
# NOTE(review): the translation (-9.74, -4.31) is identical to the volume 1
# experiment; confirm it is meant to apply to volume 2 as well.
volume_transformation_2 = ManipulatorAction(rotation=(10, 0), translation=(-9.74, -4.31))
transformed_volume_2 = create_transformed_volume(volume_2, volume_transformation_2)
# Express `action_2` in the frame of the transformed volume.
transformed_action_2 = transformed_volume_2.transform_action(action_2)

In [None]:
# Axial slice of transformed volume 2 with `transformed_action_2` overlaid.
# NOTE(review): reuses volume 2's extent, which assumes the transformed volume
# keeps the same size and spacing.
transformed_img_2 = sitk.GetArrayFromImage(transformed_volume_2)

fig, ax = plt.subplots()
ax.imshow(transformed_img_2[51, :, :], extent=extent_xy_2)

# Line y = tan(rotation[0]) * x + translation[1] in the extent's units.
x_dash = np.arange(x_size_2)
b = transformed_action_2.translation[1]
y_dash = np.tan(np.deg2rad(transformed_action_2.rotation[0])) * x_dash + b
ax.plot(x_dash, y_dash, linestyle="--", color="red")

plt.show()

In [None]:
# Slice transformed volume 2 with `transformed_action_2`, display its tissue
# clusters, and report the anatomy-based reward for that slice.
sliced_transformed_volume_2 = get_volume_slice(
    volume=transformed_volume_2,
    action=transformed_action_2,
    slice_shape=tuple(volume_2.GetSize()[axis] for axis in (0, 2)),
)
sliced_transformed_img_2 = sitk.GetArrayFromImage(sliced_transformed_volume_2)

cluster = TissueClusters.from_labelmap_slice(sliced_transformed_img_2.T)
show_clusters(
    cluster,
    sliced_transformed_img_2.T,
    aspect=spacing[2] / spacing[0],
)
reward = anatomy_based_rwd(cluster)
print(f"Reward: {reward}")

plt.show()