In [None]:
# Automatically reload edited Python modules (e.g. armscan_env) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env import config
from armscan_env.clustering import TissueClusters
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.envs.state_action import ManipulatorAction
from armscan_env.util.visualizations import show_clusters
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.volumes.volumes import TransformedVolume

# NOTE: this rebinds ``config`` from the imported module to the object returned by
# its get_config(); re-running this cell is safe because the import above runs first.
config = config.get_config()
# Loaded once here; later cells pick individual volumes by index (e.g. volumes[4]).
volumes = load_sitk_volumes(normalize=True)

In [None]:
# Pick one volume and show a slice at numpy index 40 along its first (z) axis, with
# the in-plane trace of the volume's optimal action overlaid as a dashed red line.
volume = volumes[4]
volume_img = sitk.GetArrayFromImage(volume)  # SimpleITK -> numpy axis order is (z, y, x)

# Physical length of each axis (voxel count * spacing), so the image and the
# overlaid line are drawn in the same coordinate frame.
x_size, y_size, z_size = (
    sz * sp for sz, sp in zip(volume.GetSize(), volume.GetSpacing(), strict=True)
)
# imshow extent is (left, right, bottom, top): y = 0 is at the top of the plot.
extent_xy = (0, x_size, y_size, 0)

fig, ax = plt.subplots()
ax.imshow(volume_img[40, :, :], extent=extent_xy)

# Overlay y = x * tan(rotation[0]) + translation[1]; rotation is given in degrees.
# NOTE(review): assumes rotation[0] is the in-plane (x-y) angle and translation[1]
# the y offset of the slicing plane -- confirm against ManipulatorAction.
x_dash = np.arange(x_size)
y_intercept = volume.optimal_action.translation[1]
y_dash = x_dash * np.tan(np.deg2rad(volume.optimal_action.rotation[0])) + y_intercept
ax.plot(x_dash, y_dash, linestyle="--", color="red")
ax.set(title="Original volume: slice z=40 with optimal action", xlabel="x", ylabel="y")

plt.show()

In [None]:
# Cut the volume along its optimal action, cluster the resulting labelmap slice,
# and score it with the anatomy-based reward.
sliced_volume = volume.get_volume_slice(
    action=volume.optimal_action,
    # SimpleITK GetSize() is (x, y, z): the slice spans the x and z voxel counts.
    slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),
)
sliced_img = sitk.GetArrayFromImage(sliced_volume)
print(f"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}")

# NOTE(review): presumably the 2D slice has size (x, z), so GetArrayFromImage yields
# a (z, x) array and .T restores (x, z) order for clustering/plotting -- confirm.
cluster = TissueClusters.from_labelmap_slice(sliced_img.T)
show_clusters(cluster, sliced_img.T)
reward = anatomy_based_rwd(cluster)
print(f"Reward: {reward}")
plt.show()

In [None]:
# Apply a fixed perturbation to the volume so the original and transformed optimal
# actions can be compared. Rotation is in degrees (the plotting cells pass
# rotation[0] through np.deg2rad).
perturbation_rotation = (-7.213170270886784, 0.0)
perturbation_translation = (-7.31243280019082, 9.172539411055304)
volume_transformation = ManipulatorAction(
    translation=perturbation_translation,
    rotation=perturbation_rotation,
)

transformed_volume = TransformedVolume.create_transformed_volume(
    volume,
    volume_transformation,
)
transformed_action = transformed_volume.optimal_action
print(f"{volume.optimal_action=}\n{transformed_volume.optimal_action=}\n")

In [None]:
# Same view as for the original volume, now for the transformed volume and its
# optimal action. The extent is derived from the transformed image itself (not
# reused from the original volume) so the overlay stays aligned even if the
# transformed image's size or spacing differs.
transformed_img = sitk.GetArrayFromImage(transformed_volume)  # numpy axes (z, y, x)

t_x_size, t_y_size, _ = (
    sz * sp
    for sz, sp in zip(transformed_volume.GetSize(), transformed_volume.GetSpacing(), strict=True)
)
transformed_extent_xy = (0, t_x_size, t_y_size, 0)  # (left, right, bottom, top)

fig, ax = plt.subplots()
ax.imshow(transformed_img[40, :, :], extent=transformed_extent_xy)

# Overlay y = x * tan(rotation[0]) + translation[1]; rotation is given in degrees.
x_dash = np.arange(t_x_size)
y_intercept = transformed_action.translation[1]
y_dash = x_dash * np.tan(np.deg2rad(transformed_action.rotation[0])) + y_intercept
ax.plot(x_dash, y_dash, linestyle="--", color="red")
ax.set(title="Transformed volume: slice z=40 with optimal action", xlabel="x", ylabel="y")

plt.show()

In [None]:
# Repeat the slice / cluster / reward evaluation on the transformed volume. The
# slice shape reuses the ORIGINAL volume's (x, z) voxel counts.
source_size = volume.GetSize()
sliced_transformed_volume = transformed_volume.get_volume_slice(
    action=transformed_action,
    slice_shape=(source_size[0], source_size[2]),
)
sliced_transformed_img = sitk.GetArrayFromImage(sliced_transformed_volume)
value_min, value_max = np.min(sliced_transformed_img), np.max(sliced_transformed_img)
print(f"Slice value range: {value_min} - {value_max}")

# Transposed once and shared by clustering and plotting.
transposed_slice = sliced_transformed_img.T
cluster = TissueClusters.from_labelmap_slice(transposed_slice)
show_clusters(cluster, transposed_slice)
reward = anatomy_based_rwd(cluster)
print(f"Reward: {reward}")
plt.show()