# Environment, all in one

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.clustering import TissueClusters
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.slicing import slice_volume
from armscan_env.util.visualizations import show_clusters

We can now put everything together in a single environment. We use the `slice_volume` function to cut a 2D slice out of the 3D volume, and then the `TissueClusters.from_labelmap_slice` method to find the clusters of pixels that correspond to the different tissues. Finally, we use the `anatomy_based_rwd` function to compute the reward based on the anatomy of the arm.

In [3]:
# Read the `00001` labelmap volume and keep a NumPy copy of its voxel labels.
labels_path_parts = ("../..", "data", "labels", "00001_labels.nii")
path_to_labels = os.path.join(*labels_path_parts)
volume = sitk.ReadImage(path_to_labels)
img_array = sitk.GetArrayFromImage(volume)

In [4]:
from celluloid import Camera

# Per-frame parameters of the animated cut; frame i uses (t[i], z[i]).
# t: values passed to `slice_volume` as `y_trans` (and added to the origin's y).
# z: values passed as `z_rotation`; also converted with np.deg2rad to draw the cut,
#    so they are in degrees.
t = [160, 155, 150, 148, 146, 142, 140, 140, 115, 120, 125, 125, 130, 130, 135, 138, 140, 140, 140]
z = [0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0, 0, 5, -8, 8, 0, -10, -10, 10, 19.3]
# The frame loop iterates over range(len(t)) and indexes z with the same i,
# so the two lists must stay in lockstep.
assert len(t) == len(z), f"t and z must have the same length, got {len(t)} and {len(z)}"
o = volume.GetOrigin()


# Sample functions for demonstration
def linear_function(x: np.ndarray, m: float, b: float) -> np.ndarray:
    return m * x + b


# Figure layout: a tall panel on the left (reference image with the cut drawn
# on it) and two stacked panels on the right (the slice and its clusters).
fig = plt.figure(constrained_layout=True, figsize=(8, 6))
gs = fig.add_gridspec(2, 2)
camera = Camera(fig)

ax1 = fig.add_subplot(gs[:, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])

# Index into the first axis of `img_array` for the reference image shown in ax1.
REFERENCE_SLICE_IDX = 40
# Pixel columns of the reference image; identical for every frame.
x_dash = np.arange(img_array.shape[2])

for i in range(len(t)):
    # Subplot 1: reference image with the cut drawn as a dashed red line.
    ax1.imshow(img_array[REFERENCE_SLICE_IDX, :, :])
    # Index component [1] of the physical point at the origin shifted by t[i]
    # along y; used as the intercept of the drawn line.
    b = volume.TransformPhysicalPointToIndex([o[0], o[1] + t[i], o[2]])[1]
    y_dash = linear_function(x_dash, np.tan(np.deg2rad(z[i])), b)
    ax1.plot(x_dash, y_dash, linestyle="--", color="red")
    ax1.set_title("Slice cut")

    # ACTION: cut the volume with the current rotation / translation.
    sliced_volume = slice_volume(volume=volume, z_rotation=z[i], x_rotation=0.0, y_trans=t[i])
    sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]
    ax2.imshow(sliced_img, origin="lower", aspect=6)
    # celluloid only records artists per snapshot, so a per-frame label must be a
    # text artist; ax2.set_title would show the last frame's title on all frames.
    ax2.text(0.5, 1.02, f"Slice {i}", transform=ax2.transAxes, ha="center", va="bottom")

    # OBSERVATION
    clusters = TissueClusters.from_labelmap_slice(sliced_img)
    ax3 = show_clusters(clusters, sliced_img, ax3)

    # REWARD: draw on ax3 explicitly instead of through the pyplot state machine.
    loss = anatomy_based_rwd(clusters)
    ax3.text(0, 0, f"Loss: {loss:.2f}", fontsize=12, color="red")

    camera.snap()

# Close the figure once, after every frame is captured, so it is not also
# rendered as a static image. Closing it inside the loop detached `fig` from
# pyplot, and `plt.text` on later frames then drew into a new, unrelated figure.
plt.close(fig)

In [5]:
# Render the animation as an interactive JavaScript player in the notebook output.
plt.rcParams["animation.html"] = "jshtml"
animation = camera.animate()
animation  # noqa: B018

In [6]:
from armscan_env.envs.labelmaps_navigation import (
    LabelmapClusteringBasedReward,
    LabelmapEnv,
)

# Seed for the randomly sampled actions, so the printed rollout is reproducible.
SEED = 42

reward_metric = LabelmapClusteringBasedReward()
volume_dict = {"1": volume}
volume_size = volume.GetSize()

env = LabelmapEnv(
    name2volume=volume_dict,
    # SimpleITK sizes are ordered (x, y, z); the slice shape is (z, x).
    slice_shape=(volume_size[2], volume_size[0]),
    reward_metric=reward_metric,
    termination_criterion=None,
    max_episode_len=100,
    angle_bounds=(90.0, 45.0),
    translation_bounds=(10.0, 10.0),
)
env.action_space.seed(SEED)

# Run a short random rollout; always release the environment, even if a step fails.
try:
    observation, info = env.reset()
    print(f"{observation.shape=},\n {info=}")
    for _ in range(10):
        action = env.action_space.sample()
        print(f"{action=}")
        observation, reward, terminated, truncated, info = env.step(action)
        print(f"{observation.shape=},\n {info=},\n {reward=},\n {terminated=},\n {truncated=} ")

        if terminated or truncated:
            observation, info = env.reset()
            print(f"{observation.shape=}, {info=}")
finally:
    env.close()

observation.shape=(61, 606)
, info={}
action=array([-0.07910593,  0.17333129,  0.8294553 ,  0.28365666], dtype=float32)
observation.shape=(61, 606),
 info={},
 reward=0.9333333333333332,
 terminated=False,
 truncated=False 
action=array([0.28816175, 0.40485442, 0.66519904, 0.7015066 ], dtype=float32)
observation.shape=(61, 606),
 info={},
 reward=0.9333333333333332,
 terminated=False,
 truncated=False 
action=array([-0.14100526, -0.96696305,  0.60765755, -0.22683713], dtype=float32)
observation.shape=(61, 606),
 info={},
 reward=0.9333333333333332,
 terminated=False,
 truncated=False 
action=array([-0.76914096, -0.85776454, -0.38026178,  0.98030406], dtype=float32)
observation.shape=(61, 606),
 info={},
 reward=0.9333333333333332,
 terminated=False,
 truncated=False 
action=array([-0.03853408, -0.6840921 ,  0.02260824, -0.5352145 ], dtype=float32)
observation.shape=(61, 606),
 info={},
 reward=0.9333333333333332,
 terminated=False,
 truncated=False 
action=array([ 0.21744339, -0.757844