# **Visualize a prediction of the model**

#### Matplotlib backend

In [None]:
# Select the Matplotlib backend: keep exactly one of the magics below active.
# For interactive plots (needs ipympl and ipywidgets)
# %matplotlib widget

# For static plots rendered as images in the cell output
%matplotlib inline

#### Imports

In [None]:
# Standard libraries
import pathlib
import sys
import warnings

# Add the src directory to the system path
# (to avoid having to install project as a package)
sys.path.append("../src/")

# Third-party libraries
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Custom modules
from toolbox.modules.object_segmentation_prediction_module import (
    ObjectSegmentationPredictionModule,
    BatchInferenceData,
)
from toolbox.modules.probabilistic_segmentation_lookup import (
    ProbabilisticSegmentationLookup,
)
from toolbox.modules.probabilistic_segmentation_mlp import (
    ProbabilisticSegmentationMLP,
)

## Instantiate the model and load the pre-trained weights

In [None]:
# Load the state_dict of the model from the checkpoint
runs_dir = pathlib.Path("../logs/train/runs")

# Name or glob pattern of the run directory to load (None selects the latest run)
checkpoint = None
# checkpoint = "2024-05-14_05-04-08*"

# `runs_dir / checkpoint` would not expand wildcards such as the trailing `*` in the
# example above, so the run directory is resolved with glob (an exact name matches
# itself). Only directories are considered, so stray files in `runs_dir` are ignored.
run_pattern = "*" if checkpoint is None else checkpoint
candidate_dirs = sorted(path for path in runs_dir.glob(run_pattern) if path.is_dir())

if not candidate_dirs:
    raise FileNotFoundError(f"No run directory matching {run_pattern!r} in {runs_dir}")

# Run directories appear to be named by timestamp (see the example above), so the
# lexicographically last match is the most recent run
logs_dir = candidate_dirs[-1]

print(f"Loading checkpoint from {logs_dir}")

checkpoint_path = logs_dir / "checkpoints" / "last.ckpt"

# map_location="cpu" allows loading a checkpoint saved from GPU tensors on a CPU-only
# machine; the module is moved to the inference device in a later cell.
# NOTE(review): recent PyTorch versions default to weights_only=True, which may reject
# checkpoints containing non-tensor objects — confirm with the installed version.
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

# Fail here with a clear message instead of propagating None to the weight loading
if "state_dict" not in checkpoint_dict:
    raise KeyError(f"Checkpoint {checkpoint_path} has no 'state_dict' entry")
train_module_state_dict = checkpoint_dict["state_dict"]


# TODO: Reorganize
# Probabilistic segmentation model. Its type and hyperparameters should match the
# ones used to train the loaded checkpoint: model parameters whose keys are not found
# in the checkpoint are not overwritten when the pretrained weights are loaded.
#
# Alternative model:
# probabilistic_segmentation_model = ProbabilisticSegmentationLookup(
#     use_histograms=False,
#     output_logits=False,
# )
mlp_hparams = {"patch_size": 5, "output_logits": False}
probabilistic_segmentation_model = ProbabilisticSegmentationMLP(**mlp_hparams)

# Wrap the probabilistic segmentation model into the module used for prediction
prediction_module = ObjectSegmentationPredictionModule(
    probabilistic_segmentation_model=probabilistic_segmentation_model,
)

def match_state_dict(state_dict: dict, model: torch.nn.Module) -> dict:
    """Build a state_dict for `model` from the matching entries of another state_dict.

    The result starts from the model's own state_dict, and every entry whose key also
    appears in `state_dict` is overwritten with the value from `state_dict`. Keys of
    `state_dict` that the model does not have are ignored. Model entries absent from
    `state_dict` keep the model's current values. A warning lists them, because this
    usually means the checkpoint and the model do not correspond, so those parameters
    would silently stay untrained. Tensor shapes are not checked here;
    `model.load_state_dict` reports shape mismatches.

    Args:
        state_dict (dict): The state_dict from which to extract the model's state_dict
            (e.g. the state_dict of a training module).
        model (torch.nn.Module): The model for which to extract the state_dict.

    Returns:
        dict: A state_dict with exactly the model's keys, suitable for
            `model.load_state_dict`.

    Warns:
        UserWarning: If some entries of the model's state_dict are not found in
            `state_dict`.
    """
    model_state_dict = model.state_dict()

    matched_state_dict = {
        key: value for key, value in state_dict.items() if key in model_state_dict
    }

    # Loading with strict=True cannot catch these, since the returned dict always has
    # every model key; report them explicitly instead.
    missing_keys = [key for key in model_state_dict if key not in matched_state_dict]
    if missing_keys:
        warnings.warn(
            f"{len(missing_keys)} of {len(model_state_dict)} entries of the model's "
            f"state_dict were not found in the given state_dict and keep their "
            f"current values: {missing_keys}",
            stacklevel=2,
        )

    model_state_dict.update(matched_state_dict)

    return model_state_dict

# Keep only the checkpoint entries that belong to the prediction module. The result
# has exactly the module's keys, so the default strict key check of
# `load_state_dict` is always satisfied.
prediction_module_state_dict = match_state_dict(
    state_dict=train_module_state_dict,
    model=prediction_module,
)

# Copy the pretrained weights into the prediction module
prediction_module.load_state_dict(prediction_module_state_dict)

## Load a sample image and set a bounding box

In [None]:
image = plt.imread("images/gso_sample.png")

# plt.imread returns PNG images as floats in [0, 1] (other formats are returned as
# integer arrays), so only floating-point images are converted to uint8 in [0, 255].
# Values are rounded before the cast, since truncation would turn e.g. 254.99 into 254.
if np.issubdtype(image.dtype, np.floating):
    image = np.clip(np.round(image * 255), 0, 255).astype(np.uint8)

# PNG files may carry an alpha channel: keep only the RGB channels, since the model
# input built from this image is a batch of RGB images (`rgbs`)
if image.ndim == 3 and image.shape[-1] == 4:
    image = image[..., :3]

fig, ax = plt.subplots()
ax.imshow(image)
ax.axis("off")
ax.set_title("Input image")
fig.tight_layout()
plt.show()

In [None]:
# Set the bounding box of the object to segment, as
# [x_min, y_min, x_max, y_max] in pixel coordinates
# bbox = [265, 190, 420, 335]  # cat (cat_rbot.png)
# bbox = [275, 80, 410, 190]  # horse (cat_rbot.png)
# bbox = [0, 260, 170, 400]  # glue gun (cat_rbot.png)

bbox = [99, 247, 225, 397]  # turquoise bowl (gso_sample.png)
# bbox = [250, 120, 639, 479]  # epson box (gso_sample.png)

# Overlay the bounding box (red outline, no fill) on the input image
x_min, y_min, x_max, y_max = bbox
rect = patches.Rectangle(
    (x_min, y_min),
    width=x_max - x_min,
    height=y_max - y_min,
    linewidth=1,
    edgecolor="r",
    facecolor="none",
)

fig, ax = plt.subplots()
ax.imshow(image)
ax.add_patch(rect)
ax.axis("off")
ax.set_title("Input image with bounding box")
fig.tight_layout()
plt.show()

## Construct the input data for the model

In [None]:
# Batch of RGB images: `image` is (H, W, C) and is rearranged to (1, C, H, W)
rgbs = torch.tensor(image).permute(2, 0, 1)[None]

# The bounding box [x_min, y_min, x_max, y_max] is given to the model as two
# contour points: (x_min, y_min) and (x_max, y_max)
bbox_contour_points = np.array(bbox).reshape(-1, 2)

# NOTE: `input` shadows the built-in input(); the name is kept because the
# prediction cell refers to it.
input = BatchInferenceData(
    rgbs=rgbs,
    contour_points_list=[
        [bbox_contour_points],  # contour points of the first (only) example
        # Further examples of the batch would be appended here...
    ],
)

## Model prediction

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the input batch stays on the CPU (the line below is disabled).
# Confirm that the prediction module moves it to `device` itself; otherwise
# running on CUDA may fail with a device-mismatch error.
# input = input.to(device)

# Move the module to the device and switch it to evaluation mode. A freshly
# constructed module is in training mode, in which layers such as dropout or
# batch norm behave differently than at inference.
prediction_module = prediction_module.to(device).eval()

In [None]:
# Inference only: disable autograd so that no computation graph is recorded and the
# outputs can be converted to NumPy arrays directly
with torch.no_grad():
    predicted_probabilistic_masks, sam_masks = prediction_module(input)

## Visualize the prediction

In [None]:
# MobileSAM mask: drop size-1 dimensions (e.g. the batch dimension) for display
sam_mask = sam_masks.squeeze().cpu().numpy()

fig, ax = plt.subplots()
ax.imshow(sam_mask, cmap="gray")
ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Module mask: drop size-1 dimensions (e.g. the batch dimension) and convert to NumPy.
# `.detach()` is needed when the output still tracks gradients, because `.numpy()`
# refuses tensors that require grad.
predicted_probabilistic_mask = predicted_probabilistic_masks.detach().squeeze().cpu().numpy()

In [None]:
# Display the predicted mask with a colorbar attached to the right of the image
fig, ax = plt.subplots()
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

# When the values are probabilities, pin the color scale to [0, 1] so that gray levels
# reflect absolute probabilities rather than the min/max of this particular mask;
# otherwise (e.g. logits) fall back to autoscaling
is_probability = (
    predicted_probabilistic_mask.min() >= 0.0
    and predicted_probabilistic_mask.max() <= 1.0
)
vmin, vmax = (0.0, 1.0) if is_probability else (None, None)

# Mask
img = ax.imshow(predicted_probabilistic_mask, cmap="gray", vmin=vmin, vmax=vmax)
ax.set_title("Predicted probabilistic mask")
ax.axis("off")

# Colorbar (its colormap and limits come from the image)
fig.colorbar(img, cax=cax)

fig.tight_layout()
plt.show()