In [None]:
# Re-import edited project modules (e.g. armscan_env) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import (
    LabelmapClusteringBasedReward,
    LabelmapEnv,
    LabelmapEnvTerminationCriterion,
)
from armscan_env.envs.observations import LabelmapSliceAsChannelsObservation
from armscan_env.wrapper import LinearSweepWrapper
from IPython.core.display import HTML

config = get_config()  # project configuration; used below to resolve label-map paths via get_labels_path

# The scanning sub-problem in fewer dimensions

In [None]:
# Read label maps 1 and 2 as SimpleITK images, then convert each to a numpy array.
volume_1, volume_2 = (sitk.ReadImage(config.get_labels_path(idx)) for idx in (1, 2))
img_array_1, img_array_2 = (sitk.GetArrayFromImage(vol) for vol in (volume_1, volume_2))

In [None]:
# Environment over a single label map (volume "1").
volume_size = volume_1.GetSize()
# Shape of the extracted slice: the extents along the volume's first and third
# size axes. Computed once so the observation and the environment always agree.
slice_shape = (volume_size[0], volume_size[2])

env = LabelmapEnv(
    name2volume={"1": volume_1},
    observation=LabelmapSliceAsChannelsObservation(
        slice_shape=slice_shape,
        action_shape=(4,),
    ),
    slice_shape=slice_shape,
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    # NOTE(review): the upper translation bound is None here but used as a number
    # in the sweep below, so presumably the env fills it in — confirm.
    translation_bounds=(0.0, None),
    render_mode="animation",
)

In [None]:
# Sweep the y-translation from 0 to its upper bound, keeping every other action
# component at the env's optimal action, and record the reward of each step.
N_Y_SAMPLES = 500  # number of y-translations probed by the sweep

y_slice_rewards = []

env.reset()
# Built after reset(), matching the original evaluation order of the bound.
y_translations = np.linspace(0, env.translation_bounds[1], N_Y_SAMPLES)
for y_translation in y_translations:
    sweep_action = env.get_optimal_action()
    # Keep the optimal first translation component; override only the y component.
    sweep_action.translation = (sweep_action.translation[0], y_translation)
    observation, reward, terminated, truncated, info = env.step(sweep_action)
    y_slice_rewards.append(reward)
    env.render()

    if terminated or truncated:
        # NOTE(review): with max_episode_len=10 and N_Y_SAMPLES a multiple of 10,
        # the final step may end an episode and land here. If reset_render=True
        # discards recorded frames, the animation fetched below would be nearly
        # empty — confirm the intended reset_render value.
        observation, info = env.reset(reset_render=True)
animation = env.get_cur_animation()
env.close()

In [None]:
HTML(data=animation.to_jshtml())  # embed the recorded animation as an interactive JS player

In [None]:
# Reward as a function of the y-translation. The x-values are rebuilt from the same
# bounds but sized to the number of recorded rewards, so x and y always have equal length.
fig, ax = plt.subplots()
ax.plot(np.linspace(0, env.translation_bounds[1], len(y_slice_rewards)), y_slice_rewards)
ax.set(title="Reward along a y-translation sweep", xlabel="Y translation", ylabel="Reward")
plt.show()

In [None]:
# Environment over two label maps ("1" and "2").
volume_size = volume_1.GetSize()
# Shape of the extracted slice: the extents along the volume's first and third
# size axes. Computed once so the observation and the environment always agree.
# NOTE(review): taken from volume_1 only although volume_2 is also in the env;
# presumably both share these extents (or the env resamples) — confirm.
slice_shape = (volume_size[0], volume_size[2])

env = LabelmapEnv(
    name2volume={"1": volume_1, "2": volume_2},
    observation=LabelmapSliceAsChannelsObservation(
        slice_shape=slice_shape,
        action_shape=(4,),
    ),
    slice_shape=slice_shape,
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    translation_bounds=(0.0, None),
    render_mode="animation",
)

In [None]:
# Wrap the environment with LinearSweepWrapper. Guarded so that re-running this
# cell does not wrap an already-wrapped environment a second time.
if not isinstance(env, LinearSweepWrapper):
    env = LinearSweepWrapper(env)

In [None]:
# Roll out randomly sampled actions in the wrapped environment and record an animation.
N_RANDOM_STEPS = 50
ACTION_SEED = 42  # seeds action sampling so the rollout is reproducible

env.reset()
# NOTE(review): only action sampling is seeded; whatever randomness reset() uses
# (e.g. choosing between volumes) is not — confirm whether reset() accepts a seed.
env.action_space.seed(ACTION_SEED)
for _step in range(N_RANDOM_STEPS):  # not `_`, which would shadow IPython's output cache
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    env.render()

    if terminated or truncated:
        # reset_render=False: presumably keeps frames rendered so far, so the
        # animation spans several episodes — confirm.
        observation, info = env.reset(reset_render=False)
animation = env.get_cur_animation()
env.close()

In [None]:
HTML(data=animation.to_jshtml())  # embed the random-rollout animation as an interactive JS player