# Segmentation des bboxes objet/sujet de SpatialSense+ avec SAM pour retourner des images binaires

In [None]:
# Mount Google Drive so that the files referenced below under /content/drive
# (SAM checkpoint, annotations, images, output masks) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-r77_6sjf
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-r77_6sjf
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=d59b72a2357a272823179eaeb528ef828bee63c3da16605ce65d3716b533acbf
  Stored in directory: /tmp/pip-ephem-wheel-cache-cvd4wjya/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99d
Successfully built segment_anything
Installing collected packages: segment_anything
Successfully 

In [None]:
!pip install opencv-python pycocotools matplotlib onnxruntime onnx tqdm

Collecting onnxruntime
  Downloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting onnx
  Downloading onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (13.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m97.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m93.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 k

### Chargement du modèle SAM vit_h

In [None]:
import sys
sys.path.append("..")  # adds the parent directory to the module search path
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Checkpoint stored on the mounted Drive and the registry key selecting the
# matching SAM architecture.
sam_checkpoint = "/content/drive/MyDrive/Colab Notebooks/Mod_sys_vis/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Use the GPU when one is visible; otherwise fall back to the CPU instead of
# failing on `sam.to(...)` when the runtime has no CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint and move its weights to the chosen device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic (prompt-free) mask generator built on the same model.
mask_generator = SamAutomaticMaskGenerator(sam)

  state_dict = torch.load(f)


### Fonctions pour manipuler le dataset SpatialSense+ et les bboxes

In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw a mask on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being coloured.
        ax: Object exposing ``imshow`` (e.g. a matplotlib axis).
        random_color: When true, pick a random RGB colour; otherwise use a
            fixed blue. The alpha channel is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    if not random_color:
        rgb = np.array([30/255, 144/255, 255/255])
    else:
        rgb = np.random.random(3)
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour to get an
    # (h, w, 4) image that is zero wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``, coloured by their label.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; points labelled 1 are drawn in
            green and points labelled 0 in red.
        ax: Object exposing ``scatter`` (e.g. a matplotlib axis).
        marker_size: Marker size forwarded as ``s``.
    """
    # Label 1 is drawn first, then label 0.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax, color='green', label=None):
    """Draw an unfilled rectangle for ``box`` on ``ax``, optionally with a caption.

    Args:
        box: Sequence read as ``[x0, y0, x1, y1]`` (corner coordinates).
        ax: Object exposing ``add_patch`` and ``text`` (e.g. a matplotlib axis).
        color: Edge colour of the rectangle and colour of the caption.
        label: Optional caption written 5 units above the (x0, y0) corner;
            skipped when empty or ``None``.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor=color,
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

    if not label:
        return
    ax.text(left, top - 5, label, color=color, fontsize=10, weight='bold', va='bottom')

def load_image_from_url(url):
    """Download an image and return it as an RGB array, or ``None`` on failure.

    The function relies on ``requests`` and ``cv2`` being available in the
    notebook namespace when it is called. Every failure path (non-200 status,
    undecodable payload, any raised exception) prints a message and returns
    ``None`` rather than raising.

    Args:
        url: Address of the image to fetch (10 second timeout).

    Returns:
        The decoded image converted from BGR to RGB, or ``None``.
    """
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            print(f"Erreur : Requête HTTP échouée pour {url}, Code {response.status_code}")
            return None

        # Decode the raw bytes of the response body with OpenCV.
        raw_bytes = np.frombuffer(response.content, np.uint8)
        decoded = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
        if decoded is None:
            print(f"Erreur : Impossible de charger l'image depuis l'URL {url}")
            return None

        return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Erreur : Exception levée pour l'URL {url}. Détails : {e}")
        return None


### Début de la segmentation de toutes les images de SpatialSense+

In [None]:
import json

# SpatialSense+ annotation file stored on the mounted Drive.
annotations_path = "/content/drive/MyDrive/Colab Notebooks/Mod_sys_vis/annots_spatialsenseplus.json"

# JSON text is UTF-8 by specification, so read it as such instead of relying
# on the platform's default encoding.
with open(annotations_path, "r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
import requests
from io import BytesIO
import os

# Prompt-driven predictor wrapping the SAM model loaded earlier.
predictor = SamPredictor(sam)

# Root folder of the local image copies; the segmentation loop looks for each
# file in its "flickr" or "nyu" sub-folder.
folder = "/content/drive/MyDrive/Colab Notebooks/Mod_sys_vis/spatialsense/images/images"

# Destination for the generated masks and the relations CSV; created on first
# run and left untouched if it already exists.
output_folder = "/content/drive/MyDrive/Colab Notebooks/Mod_sys_vis/masks"
os.makedirs(output_folder, exist_ok=True)


In [None]:
import csv

# CSV file, stored next to the masks, that lists the relations written by the
# segmentation loop.
relations_file = os.path.join(output_folder, "relations.csv")

# Mode "w" truncates any existing file, so running this cell resets the CSV
# to the header row only.
with open(relations_file, mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["Image", "Subject", "Relation", "Object", "Annotation_Index", "Label"])


In [None]:
from tqdm import tqdm
import time

# Spatial relations kept for the RLM experiments; annotations with any other
# predicate are ignored.
valid_relations = ["above", "under", "to the left of", "to the right of"]

total_images = len(data["sample_annots"])


def _to_xyxy(bbox):
    """Reorder an annotation bbox into the [x_min, y_min, x_max, y_max] box passed to SAM.

    NOTE(review): this reordering implies the annotations store boxes as
    [y_min, y_max, x_min, x_max] -- confirm against the dataset description.
    """
    return np.array([bbox[2], bbox[0], bbox[3], bbox[1]])


def _predict_binary_mask(box_xyxy):
    """Prompt SAM with one box on the image currently set in `predictor`.

    Returns the single predicted mask as a uint8 array holding 255 inside the
    mask and 0 elsewhere.
    """
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_xyxy[None, :],
        multimask_output=False,
    )
    return (masks[0] * 255).astype(np.uint8)


missing_images = []  # image paths that cv2.imread could not load
failed_writes = []   # mask paths for which cv2.imwrite reported failure

# The CSV is opened once for the whole run instead of once per annotation and
# flushed after every image, so an interrupted run keeps the rows of all
# completed images.
with open(relations_file, mode="a", newline="", encoding="utf-8") as relations_csv:
    writer = csv.writer(relations_csv)

    with tqdm(total=total_images, desc="Processing Images", unit="image") as pbar:
        for sample in data["sample_annots"]:
            start_time = time.time()

            # Keep the 1-based position of each annotation within the image:
            # it is used both in the CSV and in the mask file names.
            kept = [
                (index, annotation)
                for index, annotation in enumerate(sample["annotations"], start=1)
                if annotation["predicate"] in valid_relations
            ]
            if not kept:
                # Nothing to segment: skip loading the image and the costly
                # predictor.set_image call altogether.
                pbar.update(1)
                continue

            url = sample["url"]
            filename = os.path.basename(url)
            stem = os.path.splitext(filename)[0]
            subset = "flickr" if "flickr" in url else "nyu"
            image_path = os.path.join(folder, subset, filename)

            image = cv2.imread(image_path)
            if image is None:
                missing_images.append(image_path)
                pbar.update(1)
                continue

            # OpenCV loads BGR; the predictor is given the RGB version.
            predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

            for index, annotation in kept:
                subject = annotation["subject"]
                object_ = annotation["object"]

                writer.writerow([
                    filename,                 # image name
                    subject["name"],          # subject name
                    annotation["predicate"],  # relation between subject and object
                    object_["name"],          # object name
                    index,                    # annotation index within the image
                    annotation["label"],      # label
                ])

                # One white-on-black mask per role, named after image, role and index.
                for role, entity in (("subject", subject), ("object", object_)):
                    mask_path = os.path.join(output_folder, f"{stem}_{role}_mask_{index}.png")
                    white_mask = _predict_binary_mask(_to_xyxy(entity["bbox"]))
                    if not cv2.imwrite(mask_path, white_mask):
                        failed_writes.append(mask_path)

            relations_csv.flush()

            elapsed_time = time.time() - start_time
            pbar.set_postfix({"Time per image (s)": f"{elapsed_time:.2f}"})
            pbar.update(1)

# Report problems that would otherwise go unnoticed.
if missing_images:
    print(f"{len(missing_images)} image(s) could not be read, e.g. {missing_images[0]}")
if failed_writes:
    print(f"{len(failed_writes)} mask(s) could not be written, e.g. {failed_writes[0]}")

Processing Images: 100%|██████████| 4418/4418 [2:53:57<00:00,  2.36s/image, Time per image (s)=2.14]


In [None]:
import pandas as pd

# Reload the relations CSV written by the segmentation loop and preview its
# first rows as a sanity check.
df = pd.read_csv(relations_file)
print(df.head())

                        Image  Subject Relation    Object  Annotation_Index  \
0   5040395364_fd039b9687.jpg  dustbin    above      bike                 1   
1  13041248553_100dac28c8.jpg      man    above  notebook                 1   
2   9624046088_9bc73c3a43.jpg      man    above      sign                 1   
3   9624046088_9bc73c3a43.jpg      man    under      sign                 2   
4   9624046088_9bc73c3a43.jpg     sign    under       man                 3   

   Label  
0  False  
1  False  
2   True  
3  False  
4   True  


In [None]:
# Display the relations table (pandas abbreviates the middle rows of long frames).
df

Unnamed: 0,Image,Subject,Relation,Object,Annotation_Index,Label
0,5040395364_fd039b9687.jpg,dustbin,above,bike,1,False
1,13041248553_100dac28c8.jpg,man,above,notebook,1,False
2,9624046088_9bc73c3a43.jpg,man,above,sign,1,True
3,9624046088_9bc73c3a43.jpg,man,under,sign,2,False
4,9624046088_9bc73c3a43.jpg,sign,under,man,3,True
...,...,...,...,...,...,...
3199,9717381717_004e647f0f.jpg,table,under,goat,1,True
3200,nyu_bedroom_0135_r-1316568991.488913-837150099...,tissues,under,map,1,True
3201,3297717414_ea36748def.jpg,bread,under,filling,1,False
3202,4349352076_f754130983.jpg,shoe,under,shoe,1,False
