# Background removal

In this notebook we show how to use the depth maps to remove most of the background from the images.
This matters because all data was recorded in front of the same background, while the models must be able to generalize to different backgrounds.
The background at ICRA 2024 will also differ from the one used during data collection.

In [None]:
import os
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from cloth_tools.dataset.format import load_competition_observation

# Paths are relative, so they resolve against the notebook's working directory.
data_dir = Path("data")
dataset_dir = data_dir.joinpath("cloth_competition_references_0001")

In [None]:
# Sanity check: the following cells list this path and load its sub-directories,
# so it must exist *and* be a directory. Fail fast with a readable message
# instead of letting a later cell crash with a less obvious error.
dataset_found = dataset_dir.is_dir()
assert dataset_found, f"Dataset directory not found: {dataset_dir.resolve()}"
dataset_found

In [None]:
# Each sub-directory of the dataset holds one reference observation.
# Plain files that may sit alongside them (e.g. OS metadata such as .DS_Store)
# are skipped so they are never handed to the observation loader.
observation_dirs = [
    dataset_dir / entry_name
    for entry_name in sorted(os.listdir(dataset_dir))
    if (dataset_dir / entry_name).is_dir()
]
len(observation_dirs)

In [None]:
# Build a mapping: cloth ID -> `image_left` of that cloth's reference observation.
# The ID is the part of the observation directory's *own* name after its last
# underscore. Using `.name` (instead of the full path string) keeps parent
# directories, which also contain underscores, from leaking into the ID.
# NOTE(review): a repeated cloth ID would silently overwrite an earlier entry.
reference_images = {}

for observation_dir in observation_dirs:
    cloth_id = observation_dir.name.split("_")[-1]
    observation = load_competition_observation(observation_dir)
    reference_images[cloth_id] = observation.image_left

In [None]:
import matplotlib.pyplot as plt

def display_reference_images(reference_images, cols=3):
    """Display reference images in a grid, one subplot per cloth.

    Args:
        reference_images (dict): Maps cloth IDs to the image shown for that
            cloth (anything ``plt.imshow`` accepts). Images are laid out in
            the dictionary's iteration order.
        cols (int): Number of grid columns. Defaults to 3.

    Returns:
        None. The figure is rendered with ``plt.show()``. Nothing is drawn
        when ``reference_images`` is empty.

    Raises:
        ValueError: If ``cols`` is smaller than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols!r}")

    n_images = len(reference_images)
    if n_images == 0:
        # plt.subplots() rejects nrows=0, so there is nothing to draw.
        return

    rows = (n_images + cols - 1) // cols  # ceiling division: enough rows for all images

    # squeeze=False keeps `axes` a 2-D array for any grid shape (even 1x1),
    # so iterating over `axes.flat` is always valid.
    fig, axes = plt.subplots(
        nrows=rows, ncols=cols, figsize=(cols * 4, rows * 3), squeeze=False
    )

    entries = list(reference_images.items())
    for i, ax in enumerate(axes.flat):
        if i < n_images:
            cloth_id, image = entries[i]
            ax.imshow(image)
            ax.set_title(f"Cloth ID: {cloth_id}")
        # Turn the axis off for every subplot. This also hides the unused
        # trailing cells of the last row, which would otherwise show as empty frames.
        ax.axis("off")

    fig.tight_layout()
    plt.show()

# Render all loaded reference images in a grid, each titled with its cloth ID.
display_reference_images(reference_images=reference_images)