# Environment, all in one

In [None]:
# Reload edited project modules (e.g. armscan_env) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.clustering import TissueClusters
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.slicing import slice_volume
from armscan_env.util.visualizations import show_clusters
from IPython.core.display import HTML

We can now put everything together in a single environment. We use the `slice_volume` function to create a 2D slice of the 3D volume, and then `TissueClusters.from_labelmap_slice` to find the clusters of pixels that correspond to the different tissues. Finally, we use the `anatomy_based_rwd` function to compute the reward based on the anatomy of the arm.

In [None]:
# Read the `00001` labelmap volume from disk and convert its voxels to a NumPy array.
path_to_labels = os.path.join(os.path.join("../..", "data", "labels"), "00001_labels.nii")
volume = sitk.ReadImage(path_to_labels)
img_array = sitk.GetArrayFromImage(volume)

In [None]:
from celluloid import Camera

# Per-frame slicing parameters for the demo animation; frame i uses t[i] and z[i].
# t: translation along y (mm), passed as `y_trans` to `slice_volume`.
t = [
    160, 155, 150, 148, 146, 142, 140, 140, 115, 120,
    125, 125, 130, 130, 135, 138, 140, 140, 140,
]
# z: rotation about z (degrees), passed as `z_rotation` to `slice_volume`.
z = [
    0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0,
    0, 5, -8, 8, 0, -10, -10, 10, 19.3,
]
# Physical origin of the volume; translations in `t` are offsets from it along y.
o = volume.GetOrigin()

# Sample functions for demonstration
def linear_function(x: np.ndarray, m: float, b: float) -> np.ndarray:
    """Evaluate the straight line ``y = m * x + b`` element-wise.

    Args:
        x: Abscissae at which to evaluate the line.
        m: Slope of the line.
        b: Intercept of the line.

    Returns:
        The ordinates ``m * x + b``, with the same shape as ``x``.
    """
    scaled = m * x
    return scaled + b


# Create a figure with a 2x2 gridspec: a tall panel on the left (volume section
# with the cut line) and two stacked panels on the right (slice and its clusters).
fig = plt.figure(constrained_layout=True, figsize=(8, 6))
gs = fig.add_gridspec(2, 2)
camera = Camera(fig)

# Add subplots
ax1 = fig.add_subplot(gs[:, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])

# Index along the first array axis of the section displayed in the left panel.
SECTION_IDX = 40
# Each frame pairs one translation with one rotation; catch mismatched edits early
# instead of failing mid-animation (or silently dropping trailing frames).
assert len(t) == len(z), "t and z must have the same number of frames"

for i, (y_trans, z_rot) in enumerate(zip(t, z)):
    # Subplot 1: section of the volume with the cut line dashed in red.
    ax1.imshow(img_array[SECTION_IDX, :, :])
    x_dash = np.arange(img_array.shape[2])
    # y index of the translated origin; used as the intercept of the cut line.
    b = volume.TransformPhysicalPointToIndex([o[0], o[1] + y_trans, o[2]])[1]
    # The slope of the cut line follows the z-rotation angle (degrees -> radians).
    y_dash = linear_function(x_dash, np.tan(np.deg2rad(z_rot)), b)
    ax1.plot(x_dash, y_dash, linestyle="--", color="red")
    ax1.set_title("Slice cut")

    # ACTION: slice the volume with the current rotation/translation.
    sliced_volume = slice_volume(volume=volume, z_rotation=z_rot, x_rotation=0.0, y_trans=y_trans)
    sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]
    ax2.imshow(sliced_img, origin="lower", aspect=6)
    ax2.set_title(f"Slice {i}")

    # OBSERVATION: tissue clusters found in the slice.
    clusters = TissueClusters.from_labelmap_slice(sliced_img)
    ax3 = show_clusters(clusters, sliced_img, ax3)
    ax3.set_title(f"Clusters {i}")

    # REWARD
    # NOTE(review): anatomy_based_rwd is named as a reward but shown as "Loss"; confirm the label.
    loss = anatomy_based_rwd(clusters)
    ax3.text(0, 0, f"Loss: {loss:.2f}", fontsize=12, color="red")

    camera.snap()

# Close the figure once, after all frames are captured, so the static figure is not
# displayed inline; the animation is rendered from `camera` in the next cell.
plt.close(fig)

In [None]:
# Assemble the frames captured by `camera.snap()` into an animation and display it
# inline as interactive JavaScript/HTML (the last expression is the cell output).
animation = camera.animate()
HTML(animation.to_jshtml())

Rotations are defined in degrees and translations in millimeters. For the agent to take meaningful actions, the action space must be bounded. Rotation bounds never need to exceed 180 degrees, since a larger angle can be reached by rotating in the opposite direction. Translation bounds are chosen so that the slice stays within the image.
The physical size of the volume is expressed in mm. It is computed as the difference between the physical coordinates of index 0 (the origin) and of index `size`, i.e. the position one voxel past the last voxel along each axis. For an axis-aligned volume this equals `size * spacing`.

In [None]:
# Geometry of the volume: physical origin, voxel spacing (mm) and size (voxels).
origin, spacing, size = volume.GetOrigin(), volume.GetSpacing(), volume.GetSize()
# Physical position of index `size`, i.e. one voxel past the last voxel on each axis.
end = volume.TransformIndexToPhysicalPoint(size)
print(f"{origin=},\n {spacing=},\n {end=}")

# Two ways of computing the physical extent; for an axis-aligned volume they agree.
dim = np.asarray(end) - np.asarray(origin)
physical_size = np.multiply(size, spacing)
# The extent converted back to voxels, to be compared against `size`.
index_dim = np.divide(dim, spacing)
print(f"{dim=} == {physical_size},\n {index_dim=} == {size=}")

In [None]:
from armscan_env.envs.labelmaps_navigation import (
    LabelmapClusteringBasedReward,
    LabelmapEnv,
)

# Number of random actions taken in this demo rollout.
N_RANDOM_STEPS = 10

reward_metric = LabelmapClusteringBasedReward()
volume_dict = {"1": volume}
volume_size = volume.GetSize()

env = LabelmapEnv(
    name2volume=volume_dict,
    # NOTE(review): presumably the (z, x) extent of a slice in voxels; confirm against LabelmapEnv.
    slice_shape=(volume_size[2], volume_size[0]),
    reward_metric=reward_metric,
    termination_criterion=None,
    max_episode_len=100,
    angle_bounds=(90.0, 45.0),
    # Translations range over the full physical extent of the volume along y (mm).
    translation_bounds=(0.0, physical_size[1]),
    render_mode="animation",
)
try:
    observation, info = env.reset(reset_translation_bounds=False)
    print(f"{env.get_translation_bounds()=}")
    print(f"{observation.shape=},\n {info=}")
    for _ in range(N_RANDOM_STEPS):
        action = env.action_space.sample()
        print(f"{action=}")
        observation, reward, terminated, truncated, info = env.step(action)
        print(f"{observation.shape=},\n {info=},\n {reward=},\n {terminated=},\n {truncated=} ")
        env.render()

        if terminated or truncated:
            observation, info = env.reset(reset_translation_bounds=False)
            print(f"{observation.shape=}, {info=}")
    # Fetch the rendered frames before the environment is closed.
    animation = env.get_cur_animation()
finally:
    # Always release the environment, even if a step or render call raises.
    env.close()

In [None]:
# Display the animation returned by `env.get_cur_animation()` inline as interactive HTML.
HTML(animation.to_jshtml())