In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.base import EnvRollout
from armscan_env.envs.labelmaps_navigation import (
    LabelmapClusteringBasedReward,
    LabelmapEnv,
    LabelmapEnvTerminationCriterion,
)
from armscan_env.envs.observations import (
    ActionRewardObservation,
)
from armscan_env.wrapper import ArmscanEnvFactory
from tqdm import tqdm

from tianshou.highlevel.env import EnvMode

# Project configuration object; used below to resolve label-map file paths via config.get_labels_path(...).
config = get_config()

# The scanning sub-problem in fewer dimensions

In [None]:
def walk_through_env(
    env: LabelmapEnv,
    n_steps: int = 10,
    reset: bool = True,
    show_pbar: bool = True,
    render_title: str = "Labelmap slice",
) -> EnvRollout:
    """Sweep the environment along the y-axis and record every transition.

    Args:
        env: environment to drive. If its action space has shape ``(1,)`` it is treated
            as the projected (y-only) variant and raw actions in ``[-1, 1]`` are fed to
            ``env.step``. Otherwise the optimal action is taken and only its y-translation
            is replaced, sweeping across ``env.translation_bounds``.
        n_steps: number of evenly spaced y positions to visit.
        reset: if True, reset the env first and record the initial state (with its
            current reward and termination/truncation flags) in the rollout.
        show_pbar: wrap the sweep in a tqdm progress bar.
        render_title: title passed to every ``env.render`` call.

    Returns:
        The recorded ``EnvRollout`` (the reset entry, if any, followed by one entry per step).
    """
    rollout = EnvRollout()

    if reset:
        reset_obs, reset_info = env.reset()
        env.render(title=render_title)
        # Record the starting state as well, so the rollout also covers "step 0".
        rollout.append_reset(
            reset_obs,
            reset_info,
            reward=env.compute_cur_reward(),
            terminated=env.should_terminate(),
            truncated=env.should_truncate(),
        )

    is_projected = env.action_space.shape == (1,)
    if is_projected:
        # Projected env: the single action component is stepped directly
        # (presumably already in normalized units).
        start, stop = -1, 1
    else:
        bounds = env.translation_bounds
        start, stop = bounds[0], bounds[1]

    sweep = np.linspace(start, stop, n_steps)
    positions = tqdm(sweep, desc="Step:") if show_pbar else sweep

    print(f"Walking through y-axis from {start} to {stop} in {n_steps} steps")
    for y in positions:
        if is_projected:
            action = np.array([y])
        else:
            # Start from the optimal action and override only the second translation component.
            optimal = env.get_optimal_action()
            optimal.translation = (optimal.translation[0], y)
            action = optimal.to_normalized_array(
                rotation_bounds=env.rotation_bounds,
                translation_bounds=env.translation_bounds,
            )
        obs, reward, terminated, truncated, info = env.step(action)
        rollout.append_step(action, obs, reward, terminated, truncated, info)
        env.render(title=render_title)
    return rollout


def plot_rollout_rewards(
    env_rollout: EnvRollout,
    show: bool = True,
    ax: "plt.Axes | None" = None,
) -> None:
    """Plot the per-step rewards of a rollout and shade terminated steps in red.

    Args:
        env_rollout: rollout whose ``rewards`` and ``terminated`` sequences are plotted
            against their step index.
        show: if True, call ``plt.show()`` after drawing.
        ax: axes to draw on. Defaults to the current pyplot axes (``plt.gca()``),
            which is where the plot was always drawn before this parameter existed.
    """
    if ax is None:
        ax = plt.gca()

    ax.plot(env_rollout.rewards, label="Reward")

    # Mark each terminated step with a red transparent band. Labels are attached to the
    # artists themselves (instead of passing a positional label list to legend()), so the
    # legend stays correct even when the axes already contain other artists. Only the
    # first band is labelled so "Terminated" appears once in the legend.
    steps_where_terminated = np.where(env_rollout.terminated)[0]
    for i, step in enumerate(steps_where_terminated):
        ax.axvspan(
            step - 0.5,
            step + 0.5,
            color="red",
            alpha=0.5,
            label="Terminated" if i == 0 else "_nolegend_",
        )

    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")
    ax.legend()

    if show:
        plt.show()

In [None]:
# Load the label maps of volumes 1 and 2, then extract their voxel data as NumPy arrays.
volume_1, volume_2 = (sitk.ReadImage(config.get_labels_path(idx)) for idx in (1, 2))
img_array_1, img_array_2 = (sitk.GetArrayFromImage(vol) for vol in (volume_1, volume_2))

In [None]:
# Build the full (non-projected) environment on volume "1" for interactive watching.
# SimpleITK's GetSize() returns sizes in (x, y, z) order.
volume_size = volume_1.GetSize()

env = ArmscanEnvFactory(
    name2volume={"1": volume_1},
    # action_shape=(4,) presumably has to match env.action_space of the full env — TODO confirm
    observation=ActionRewardObservation(action_shape=(4,)).to_array_observation(),
    # slices span size axes 0 and 2 of the volume
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    # NOTE(review): the None upper bound is presumably resolved by the env (e.g. from the
    # volume extent); walk_through_env reads env.translation_bounds as sweep bounds — confirm
    translation_bounds=(0.0, None),
    # "animation" lets get_cur_animation_as_html() replay the rendered frames
    render_mode="animation",
    n_stack=2,
).create_env(EnvMode.WATCH)

In [None]:
# Sweep the full environment along y in 10 steps and plot the per-step rewards.
env_rollout = walk_through_env(env, n_steps=10)
plot_rollout_rewards(env_rollout, show=True)

In [None]:
# Last expression of the cell, so Jupyter displays the returned value; presumably an HTML
# animation of the frames rendered during the sweep above (render_mode="animation").
env.get_cur_animation_as_html()

In [None]:
# Build the projected environment: same setup as the full env, but actions are projected
# onto the y-axis and the volume transformation is enabled.
volume_size = volume_1.GetSize()

projected_env = ArmscanEnvFactory(
    name2volume={"1": volume_1},
    # action_shape=(1,): walk_through_env treats an env with action_space.shape == (1,) as projected
    observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    translation_bounds=(0.0, None),
    render_mode="animation",
    n_stack=2,
    # presumably reduces the action space to the y-translation only — TODO confirm
    project_actions_to="y",
    apply_volume_transformation=True,
).create_env(EnvMode.WATCH)

In [None]:
# Sweep the projected environment along y in 10 steps and plot its per-step rewards.
projected_env_rollout = walk_through_env(
    projected_env, n_steps=10, render_title="Projected labelmap slice"
)
plot_rollout_rewards(projected_env_rollout, show=True)

In [None]:
# Compare the reward carried inside each observation with the reward returned by the env.
# NOTE(review): obs[1][-1] assumes obs[1] is the most recent stacked frame (n_stack=2) and its
# last entry is the reward component of ActionRewardObservation — confirm.
# Values are cast to built-in float before rounding: NumPy >= 2 reprs scalars as e.g.
# np.float32(0.5), which would clutter the printed lists.
print(
    "Observed 'rewards': \n",
    [round(float(obs[1][-1]), 4) for obs in projected_env_rollout.observations],
)
print("Env rewards: \n", [round(float(r), 4) for r in projected_env_rollout.rewards])

In [None]:
# Last expression of the cell, so Jupyter displays the returned value; presumably an HTML
# animation of the frames rendered during the projected sweep (render_mode="animation").
projected_env.get_cur_animation_as_html()