In [None]:
import os
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from cloth_tools.dataset.format import load_competition_observation

# Root data folder (relative to the working directory) and the reference dataset inside it.
data_dir = Path("data")
dataset_dir = data_dir.joinpath("cloth_competition_references_0001")

In [None]:
# Sanity check: evaluates to True when the dataset directory is present on disk.
dataset_dir.exists()

In [None]:
# Keep only sub-directories: a plain directory listing also returns stray files
# (e.g. a README or hidden metadata files), which are not observations and would
# inflate the count and break loading by index. Sorting by name keeps the same
# ordering as sorting the raw directory listing.
observation_dirs = sorted(
    (entry for entry in dataset_dir.iterdir() if entry.is_dir()),
    key=lambda entry: entry.name,
)
len(observation_dirs)

In [None]:
# Choose which observation to inspect by its position in `observation_dirs`.
index = 0
observation_dir = observation_dirs[index]
observation = load_competition_observation(observation_dir)

In [None]:
# Show the observation's left image. `plt` is already imported in the first cell,
# so it is not re-imported here.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(observation.image_left)
ax.set_title("Left image")
plt.show()

In [None]:
from cloth_tools.dataset.format import CompetitionObservation
from airo_typing import Vector3DType
from airo_camera_toolkit.pinhole_operations.projection import project_points_to_image_plane
from airo_dataset_tools.data_parsers.camera_intrinsics import CameraIntrinsics
from airo_spatial_algebra import transform_points

# TODO change types
def get_bounding_box_between_grippers(
    arm_left_tcp_position: Vector3DType,
    arm_right_tcp_position: Vector3DType,
    intrinsics: np.ndarray,
    extrinsics: np.ndarray,
    y_padding: float = 0.1,
) -> tuple[float, float, float, float]:
    """Project a 3D rectangle spanned below the two grippers into the image and return its 2D bounds.

    The rectangle has its two upper corners at the gripper TCP positions and its two
    lower corners at the same x/y but at a fixed height of z = 0.05. The left corners
    are shifted by ``+y_padding`` along y and the right corners by ``-y_padding``.

    Args:
        arm_left_tcp_position: (x, y, z) position of the left arm TCP; singleton
            dimensions are squeezed away. Expressed in the same frame as ``extrinsics``
            (units presumably meters — confirm).
        arm_right_tcp_position: (x, y, z) position of the right arm TCP, same conventions.
        intrinsics: Camera intrinsics, forwarded to ``project_points_to_image_plane``.
        extrinsics: Camera pose; it is inverted and then applied to the corners before
            they are projected.
        y_padding: Offset applied along y to the rectangle corners (added on the left
            side, subtracted on the right side).

    Returns:
        A tuple ``(u_min, v_min, u_max, v_max)``: the axis-aligned bounds of the four
        projected corners in image coordinates.
    """
    x_left, y_left, z_left = np.asarray(arm_left_tcp_position).squeeze()
    x_right, y_right, z_right = np.asarray(arm_right_tcp_position).squeeze()

    # Fixed height of the rectangle's lower edge (presumably just above the
    # table/floor — confirm).
    bottom_z = 0.05

    # NOTE: `y_padding` is deliberately NOT reassigned here, so the value passed by
    # the caller is the one that is actually used.
    corners_3d = np.array(
        [
            [x_left, y_left + y_padding, z_left],
            [x_right, y_right - y_padding, z_right],
            [x_left, y_left + y_padding, bottom_z],
            [x_right, y_right - y_padding, bottom_z],
        ]
    )

    # Invert the camera pose to map the corners into the camera frame, then project.
    X_C_W = np.linalg.inv(extrinsics)
    projected_corners = project_points_to_image_plane(transform_points(X_C_W, corners_3d), intrinsics).squeeze()

    # One (u, v) row per corner: column-wise extrema give the 2D bounding box.
    u_min, v_min = projected_corners.min(axis=0)
    u_max, v_max = projected_corners.max(axis=0)

    return u_min, v_min, u_max, v_max


def get_heuristic_cloth_bounding_box(observation: CompetitionObservation) -> tuple[float, float, float, float]:
    """Estimate a 2D image bounding box around the cloth held between the two grippers.

    The heuristic assumes that the cloth is held by both robot arms and stretched out
    in front of the camera.

    Args:
        observation: Competition observation providing the camera intrinsics, the camera
            pose in the world frame and the world-frame TCP poses of both arms.

    Returns:
        A tuple ``(u_min, v_min, u_max, v_max)`` with the coordinates of the estimated
        bounding box within the image.
    """
    # Rows 0-2 of column 3 of each TCP pose, i.e. the position part when the pose is a
    # 4x4 homogeneous transform (presumably — confirm).
    left_tcp_xyz = observation.arm_left_tcp_pose_in_world[:3, 3]
    right_tcp_xyz = observation.arm_right_tcp_pose_in_world[:3, 3]

    return get_bounding_box_between_grippers(
        left_tcp_xyz,
        right_tcp_xyz,
        observation.camera_intrinsics,
        observation.camera_pose_in_world,
    )

In [None]:
u_min, v_min, u_max, v_max = get_heuristic_cloth_bounding_box(observation)

# Draw the heuristic bounding box as a red outline on top of the left image.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(observation.image_left)
ax.set_title("Left image")
box_patch = plt.Rectangle((u_min, v_min), u_max - u_min, v_max - v_min, edgecolor="r", facecolor="none")
ax.add_patch(box_patch)
plt.show()
 

In [None]:
sam_weights_dir = "/home/victor/cloth-competition/evaluation-service/weights"

!ls $sam_weights_dir

In [None]:
import torch
from segment_anything import SamPredictor, sam_model_registry

weights_name = "sam_vit_h_4b8939.pth"
# Use the GPU when one is available; otherwise fall back to the CPU so the notebook
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the "vit_h" SAM variant from the checkpoint and move it to the chosen device.
sam_weights = os.path.join(sam_weights_dir, weights_name)
sam = sam_model_registry["vit_h"](checkpoint=sam_weights)
sam.to(device=device)

predictor = SamPredictor(sam)

In [None]:
# Hand the left image to the predictor; the prediction call below runs on this image.
left_image = observation.image_left
predictor.set_image(left_image)

In [None]:

# Optional point prompts; both lists are left empty so the box is the only prompt.
input_point = []
input_label = []
input_box = np.array([u_min, v_min, u_max, v_max])

# Pass None rather than an empty array when there are no point prompts.
point_coords = np.array(input_point) if input_point else None
point_labels = np.array(input_label) if input_label else None

masks, _, _ = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    box=input_box[np.newaxis, :],  # add a leading axis in front of the 4 box values
    multimask_output=False,
)

# Keep the first returned mask.
mask = masks[0]

In [None]:
# Show the predicted mask with the prompt box outlined in red.
fig, ax = plt.subplots()
ax.imshow(mask)
ax.add_patch(plt.Rectangle((u_min, v_min), u_max - u_min, v_max - v_min, edgecolor="r", facecolor="none"))