# Environment, all in one

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.clustering import find_DBSCAN_clusters
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.slicing import slice_volume
from armscan_env.util.visualizations import show_clusters

We can now put everything together in a single environment loop. For each action of a hand-picked trajectory, we use the `slice_volume` function to cut a 2D slice out of the 3D volume. We then use the `find_DBSCAN_clusters` function to find the clusters of pixels that belong to the different tissues. Finally, we use the `anatomy_based_rwd` function to compute a reward from the arm anatomy visible in the slice; it is shown as "Loss" in the plot. Each step is recorded as one frame of an animation.

In [None]:
# Label file 00001_labels.nii, resolved relative to the notebook's working directory.
path_to_labels = os.path.join("../..", "data", "labels", "00001_labels.nii")
volume = sitk.ReadImage(path_to_labels)

# NumPy copy of the voxel data; SimpleITK orders the array axes as (z, y, x).
img_array = sitk.GetArrayFromImage(volume)

In [None]:
# Label value for each tissue class. Each value is passed, together with a slice
# image, to find_DBSCAN_clusters (presumably the voxel label to cluster — confirm).
tissues = {
    "bones": 1,
    "tendons": 2,
    "ulnar": 3,
}

In [None]:
from celluloid import Camera

# Hand-picked trajectory of slicing actions, one entry per animation frame.
# t[i]: translation along the second axis, passed to slice_volume and added to
#       the origin's second coordinate when drawing the cut line.
# z[i]: rotation about z, in degrees.
t = [
    160, 155, 150, 148, 146, 142, 140, 140, 115, 120,
    125, 125, 130, 130, 135, 138, 140, 140, 140,
]
z = [
    0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0,
    0, 5, -8, 8, 0, -10, -10, 10, 19.3,
]
# Physical origin of the volume, used to convert translations into pixel rows.
o = volume.GetOrigin()


# Straight line used to draw the slice cut on the overview image.
def linear_function(x: np.ndarray, m: float, b: float) -> np.ndarray:
    """Evaluate the line ``y = m * x + b`` at every point of ``x``.

    Args:
        x: Points at which the line is evaluated.
        m: Slope of the line.
        b: Intercept of the line.

    Returns:
        The y-values, ``m * x + b`` broadcast against ``x``.
    """
    scaled = m * x
    return scaled + b


# Figure layout: a tall panel on the left shows a fixed cross-section of the
# label volume with the current cut drawn on it; the two panels on the right
# show the resulting slice (top) and its tissue clusters (bottom).
fig = plt.figure(constrained_layout=True, figsize=(8, 6))
gs = fig.add_gridspec(2, 2)
camera = Camera(fig)

ax1 = fig.add_subplot(gs[:, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])

# celluloid only records artists added to an axes between snaps (images, lines,
# texts, ...). An axes title is one persistent artist that is not recorded, so
# constant titles are set once here and per-frame labels are drawn with ax.text.
ax1.set_title("Slice cut")

# Pixel columns along which the cut line is evaluated; identical for every frame.
x_dash = np.arange(img_array.shape[2])

if len(t) != len(z):
    raise ValueError(f"t and z must have the same length, got {len(t)} and {len(z)}")

for i, (t_i, z_i) in enumerate(zip(t, z)):
    # Subplot 1: fixed cross-section (index 40 along array axis 0, i.e. the image
    # z-axis) with the cut as a dashed line. The image is redrawn every frame
    # because celluloid hides the artists that belong to other frames.
    ax1.imshow(img_array[40, :, :])
    # Row index of the volume origin shifted by t_i along the second physical
    # axis; it is the intercept of the cut line.
    b = volume.TransformPhysicalPointToIndex([o[0], o[1] + t_i, o[2]])[1]
    y_dash = linear_function(x_dash, np.tan(np.deg2rad(z_i)), b)
    ax1.plot(x_dash, y_dash, linestyle="--", color="red")

    # ACTION: slice rotated by z_i degrees about z and translated by t_i.
    sliced_volume = slice_volume(
        z_rotation=z_i,
        x_rotation=0.0,
        translation=np.array([0, t_i, 0]),
        volume=volume,
    )
    # NOTE(review): assumes the sliced volume is one voxel thick along the middle
    # array axis, so index 0 yields the 2D slice — confirm against slice_volume.
    sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]
    ax2.imshow(sliced_img, origin="lower", aspect=6)
    ax2.text(0.5, 1.01, f"Slice {i}", transform=ax2.transAxes, ha="center", va="bottom")

    # OBSERVATION: DBSCAN clusters of each tissue label in the slice.
    clusters = {
        "bones": find_DBSCAN_clusters(tissues["bones"], sliced_img, eps=4.1, min_samples=46),
        "tendons": find_DBSCAN_clusters(tissues["tendons"], sliced_img, eps=4.1, min_samples=46),
        "ulnar": find_DBSCAN_clusters(tissues["ulnar"], sliced_img, eps=2.5, min_samples=18),
    }
    ax3 = show_clusters(clusters, sliced_img, ax3)

    # REWARD: draw on ax3 explicitly. plt.text would target pyplot's *current*
    # figure, which is not guaranteed to be `fig`.
    loss = anatomy_based_rwd(clusters)
    ax3.text(0, 0, f"Loss: {loss:.2f}", fontsize=12, color="red")

    camera.snap()

# Close the figure only after all frames are captured. This keeps the notebook
# from also rendering the static figure; camera.animate() can still use `fig`.
plt.close(fig)

In [None]:
# Render animations as an interactive JavaScript player in the notebook output.
plt.rcParams["animation.html"] = "jshtml"
# Assemble the captured frames; the bare expression below makes Jupyter display it.
animation = camera.animate()
animation  # noqa: B018