In [None]:
# Show the GPUs visible to this runtime before loading any model.
!nvidia-smi

In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub.
# NOTE(review): new_session=False presumably reuses an already-saved token
# instead of prompting again — confirm against the huggingface_hub docs.
login(new_session=False)

In [None]:
# Runtime dependencies: `decord` from PyPI and SAM 3 from its GitHub repository.
# NOTE(review): neither install is pinned to a version or commit, so a later
# re-run may pick up different code.
!pip install decord
!pip install git+https://github.com/facebookresearch/sam3.git

In [None]:
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
from IPython import display
import IPython.display as ipd
import os
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
import cv2
import spacy

from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import torch
from PIL import Image

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

import glob
import os

from sam3.visualization_utils import (
    load_frame,
    prepare_masks_for_visualization,
    visualize_formatted_frame_output,
)

import sam3

In [None]:
def as_gif(images, path="temp.gif"):
    """Write a sequence of PIL images to an animated GIF and return its bytes.

    Args:
        images: Non-empty sequence of PIL images. The first frame's ``save``
            method writes the animation; the remaining frames are appended.
        path: Destination file for the GIF (overwritten if it exists).

    Returns:
        The raw bytes of the GIF file written to ``path``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("as_gif() needs at least one frame")

    # Render the images as the gif (15Hz control frequency):
    images[0].save(path, save_all=True, append_images=images[1:], duration=int(1000/15), loop=0)

    # Read the file back inside a context manager so the handle is closed
    # right away instead of being left open until garbage collection.
    with open(path, "rb") as gif_file:
        return gif_file.read()

In [None]:
# OPTIONAL: dump one camera stream of a single episode to disk as JPEG frames.

# Camera stream to export. The output folder is named after the second
# exterior camera, so the frames are read from that same camera; previously
# the folder name and the exported stream disagreed.
CAMERA_KEY = "exterior_image_2_left"
OUTPUT_DIR = "/content/droid_frames_exterior_2_left"
os.makedirs(OUTPUT_DIR, exist_ok=True)

ds = tfds.load(
    "droid_100",
    data_dir="gs://gresearch/robotics",
    split="train"
)

# Running count of written frames; also used as the zero-padded file name.
frame_idx = 0

for episode in ds.shuffle(10, seed=6).take(1):
    for step in tqdm(episode["steps"]):
        frame = step["observation"][CAMERA_KEY].numpy()

        img = Image.fromarray(frame)
        img.save(
            os.path.join(OUTPUT_DIR, f"{frame_idx:05d}.jpg"),
            quality=95,
            subsampling=0
        )

        frame_idx += 1

print(f"Saved {frame_idx} frames to {OUTPUT_DIR}")


**Use spaCy to extract object phrases from the instruction**

In [None]:
# spaCy English pipeline; `extract_objects` reads it as a global to get noun chunks.
nlp = spacy.load("en_core_web_sm")

In [None]:
def extract_objects(instruction):
    """Return the unique noun phrases of an instruction, in order of appearance.

    The instruction is lower-cased and parsed with the module-level spaCy
    pipeline ``nlp``. Noun chunks that consist only of a determiner or pronoun
    are dropped, as are single-word chunks containing non-alphabetic characters.
    """
    ignored = ("a", "the", "it", "they", "them", "this", "that")
    phrases = []

    for noun_chunk in nlp(instruction.lower()).noun_chunks:
        phrase = noun_chunk.text.strip()

        if phrase in ignored or phrase in phrases:
            continue

        # Multi-word chunks are kept as they are; a single word must be purely alphabetic.
        if " " in phrase or phrase.isalpha():
            phrases.append(phrase)

    return phrases

**SAM 3-only method (text prompts)**

In [None]:
# Load the SAM 3 image model and wrap it in its processor.
# `processor` is read as a global by `main_pipeline`.
model = build_sam3_image_model()
processor = Sam3Processor(model)

In [None]:
ds = tfds.load("droid_100", data_dir="gs://gresearch/robotics", split="train")
SEED = 3  # selects the episode (alternative value noted by the author: 6)

# Walk the selected episode once, collecting both the language instruction
# (taken from its first step) and the three camera views of every step
# stitched side by side.
original_images = []
for episode in ds.shuffle(10, seed=SEED).take(1):
    for step_index, step in enumerate(episode["steps"]):
        if step_index == 0:
            instruction = step["language_instruction"].numpy().decode("utf-8")

        observation = step["observation"]
        original_images.append(
            Image.fromarray(
                np.concatenate((
                    observation["exterior_image_1_left"].numpy(),
                    observation["exterior_image_2_left"].numpy(),
                    observation["wrist_image_left"].numpy(),
                ), axis=1)
            )
        )

text_labels = extract_objects(instruction)
print("Instruction: ", instruction)
print("Objects: ", text_labels)
display.Image(as_gif(original_images))

In [None]:
# Read the "language_instruction_2" field from the first step of the same episode.
# NOTE: this rebinds the global `instruction`, so every cell executed after this
# one sees this value rather than the one from "language_instruction".
for episode in ds.shuffle(10, seed=SEED).take(1):
  # print(episode)
  steps = episode["steps"]
  first_step = next(iter(steps))
  instruction = first_step["language_instruction_2"].numpy().decode("utf-8")
  print(instruction)

In [None]:
def _overlay_masks(image, masks, colors, transparency):
    """Alpha-blend one flat colour per mask onto a PIL image.

    Args:
        image: PIL image to draw on; it is converted to RGBA first.
        masks: Tensor of masks indexed as (N, H, W). ``squeeze(0)`` is applied
            to 4-D input before use.
        colors: One (R, G, B) tuple per mask.
        transparency: Alpha value (0-255) applied wherever a mask is set.

    Returns:
        A new RGBA PIL image with every mask composited on top.
    """
    image = image.convert("RGBA")

    if masks.ndim == 4:
        masks = masks.squeeze(0)

    masks = masks.cpu().numpy().astype(np.uint8)

    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask * 255, mode="L")
        # Start from a fully transparent colour layer, then make it visible
        # only where the mask is set.
        overlay = Image.new("RGBA", image.size, color + (0,))
        overlay.putalpha(mask_img.point(lambda v: transparency if v > 0 else 0))
        image = Image.alpha_composite(image, overlay)

    return image


def _best_mask_from_prompt(processor, inference_state, prompt):
    """Run one text prompt and keep only the highest-scoring mask.

    Returns:
        ``(mask, score)`` where ``mask`` carries a leading axis of size one,
        or ``(None, None)`` when no mask is returned for the prompt.
    """
    with torch.inference_mode():
        output = processor.set_text_prompt(
            state=inference_state,
            prompt=prompt,
        )

    masks, scores = output["masks"], output["scores"]

    if len(masks) == 0:
        return None, None

    best_idx = torch.argmax(scores).item()

    if masks.ndim == 4:
        masks = masks.squeeze(0)
    best_mask = masks[best_idx]

    # Give the mask a leading axis so masks of several prompts can be concatenated.
    if best_mask.ndim == 2:
        best_mask = best_mask.unsqueeze(0)

    return best_mask, scores[best_idx].item()


def main_pipeline(ori_images, instr):
    """Overlay the objects named in an instruction on every frame.

    The noun phrases of ``instr`` (via ``extract_objects``) are used as text
    prompts; for each frame the best mask per prompt is tinted onto the image.
    Segmentation uses the module-level ``processor``.

    Args:
        ori_images: Iterable of PIL images, one per frame.
        instr: Natural-language instruction describing the objects of interest.

    Returns:
        List of RGBA PIL images, one per input frame. Frames on which no
        prompt produced a mask are returned unmodified (converted to RGBA).
    """
    # The prompts depend only on the instruction, so derive them once rather
    # than re-parsing the text for every frame. Use the `instr` argument: the
    # earlier version silently read the global `instruction` instead.
    prompts = extract_objects(instr)

    # Colour is chosen by prompt position, so a given object keeps its colour
    # across frames even when another prompt yields no mask.
    DEFAULT_COLORS = [
        (160, 32, 240),  # purple
        (255, 0, 0),     # red
        (255, 255, 0),   # yellow
        (0, 0, 255),     # blue
    ]

    masked_images = []

    for image in ori_images:
        inference_state = processor.set_image(image)

        all_masks = []
        all_colors = []

        for i, prompt in enumerate(prompts):
            mask, _score = _best_mask_from_prompt(
                processor,
                inference_state,
                prompt,
            )

            if mask is None:
                continue

            all_masks.append(mask)
            all_colors.append(DEFAULT_COLORS[i % len(DEFAULT_COLORS)])

        if len(all_masks) == 0:
            masked_images.append(image.convert("RGBA"))
        else:
            masked_images.append(
                _overlay_masks(
                    image,
                    torch.cat(all_masks, dim=0),
                    colors=all_colors,
                    transparency=80,
                )
            )

    return masked_images


text_labels = extract_objects(instruction)
print("Instruction: ", instruction)
print("Objects: ", text_labels)


# Camera streams to process, listed in left-to-right display order.
OBSERVATIONS = [
    "exterior_image_1_left",
    "exterior_image_2_left",
    "wrist_image_left",
]

for episode in ds.shuffle(10, seed=SEED).take(1):

    # Gather every frame of the episode, grouped by camera.
    camera_frames = {obs: [] for obs in OBSERVATIONS}
    for step in episode["steps"]:
        observation = step["observation"]
        for obs in OBSERVATIONS:
            camera_frames[obs].append(Image.fromarray(observation[obs].numpy()))

    # Segment each camera stream on its own.
    masked_camera_frames = {}
    for obs in OBSERVATIONS:
        masked_camera_frames[obs] = main_pipeline(camera_frames[obs], instruction)

    # Stitch the per-camera results side by side, one output image per time step.
    final_frames = []
    for per_camera in zip(*(masked_camera_frames[obs] for obs in OBSERVATIONS)):
        panels = [np.array(frame.convert("RGB")) for frame in per_camera]
        final_frames.append(Image.fromarray(np.concatenate(panels, axis=1)))

display.Image(as_gif(final_frames))

In [None]:
output_path = "/content/Videos/Masked_Video_2.mp4"
fps = 15  # adjust if needed
video = final_frames

# The writer does not create missing folders, so make sure the target exists.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Convert first frame to get size
first = video[0]
if not isinstance(first, np.ndarray):
    first = np.array(first)

h, w = first.shape[:2]

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

# A writer that failed to open drops every frame without raising, which would
# make the success message below misleading; fail loudly instead.
if not writer.isOpened():
    raise RuntimeError(f"Could not open video writer for {output_path}")

try:
    for img in video:
        if not isinstance(img, np.ndarray):
            img = np.array(img)

        # Ensure 3 channels
        if img.ndim == 2:
            img = np.stack([img]*3, axis=-1)

        # OpenCV expects BGR
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        writer.write(img)
finally:
    # Release even if a frame fails, so the file is finalised and the handle freed.
    writer.release()

print(f"Saved video to {output_path}")


**Potentially use GroundingDINO for more accurate masking (create a bounding box first, then segment inside the box)**

In [None]:
# Clone the GroundingDINO repository and make it the working directory.
# NOTE(review): `%cd` changes the kernel's working directory for all later
# cells, and re-running this cell attempts the clone again — confirm this is
# wanted, since the cells below load the detector through `transformers`.
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd GroundingDINO


In [None]:

# Hugging Face checkpoint used for text-conditioned box detection.
model_id = "IDEA-Research/grounding-dino-tiny"

# NOTE: this rebinds the global `processor`, which until now held the
# Sam3Processor read by `main_pipeline`; calling `main_pipeline` again after
# this cell would pick up the Grounding DINO processor instead.
processor = AutoProcessor.from_pretrained(model_id)
# The detector is placed on the GPU unconditionally ("cuda" is hard-coded).
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    model_id
).to("cuda")

In [None]:
# Run detection on the last stitched frame. `final_frames` already holds RGB
# PIL images; the `img` array left over from the video-export loop has been
# converted to BGR, so wrapping it with Image.fromarray swapped red and blue
# for both the detector and the plots below.
pil_img = final_frames[-1]

In [None]:

# Text queries for the detector: the noun phrases of the current `instruction`.
text_labels = extract_objects(instruction)

# Preprocess image and text together, then move the tensors to the model's device.
inputs = processor(images=pil_img, text=text_labels, return_tensors="pt").to(dino_model.device)

# Forward pass only; gradient tracking is disabled.
with torch.no_grad():
    outputs = dino_model(**inputs)

In [None]:
# Post-process the detector outputs for the single input image.
results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    threshold=0.4,
    text_threshold=0.3,
    # PIL's `size` is (width, height); it is reversed to (height, width) here.
    target_sizes=[pil_img.size[::-1]]
)

# One entry per input image; only one image was passed in.
result = results[0]
# Print each detection with its label, rounded confidence and rounded box.
for box, score, labels in zip(result["boxes"], result["scores"], result["labels"]):
    box = [round(x, 2) for x in box.tolist()]
    print(f"Detected {labels} with confidence {round(score.item(), 3)} at location {box}")

In [None]:
# Convert the PIL image to a BGR array for drawing with OpenCV.
draw_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

def visualize_detections(image, results):
    """Return a copy of ``image`` with every detection drawn as a labelled box.

    Args:
        image: Image array to annotate; it is copied, not modified.
        results: Mapping with parallel "boxes", "scores" and "labels" entries.
            Each box must provide ``tolist()`` yielding (x1, y1, x2, y2).

    Returns:
        The annotated copy.
    """
    green = (0, 255, 0)
    canvas = image.copy()

    detections = zip(results["boxes"], results["scores"], results["labels"])
    for box, score, label in detections:
        # Truncate the box corners to integer pixel positions.
        left, top, right, bottom = (int(value) for value in box.tolist())

        cv2.rectangle(canvas, (left, top), (right, bottom), green, 2)

        # The caption is anchored 5 px above the top-left corner of the box.
        caption = f"{label} {score:.2f}"
        cv2.putText(canvas, caption, (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.35, green, 1)

    return canvas

# Draw the detections on the BGR copy, then convert back to RGB for matplotlib.
vis = visualize_detections(draw_img, result)

plt.figure(figsize=(8, 8))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.axis("off")

In [None]:
# Font size for axes titles and figure titles in matplotlib figures.
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["figure.titlesize"] = 12


def propagate_in_video(predictor, session_id):
    """Run propagation for a session and collect the outputs of every frame.

    Args:
        predictor: Object exposing ``handle_stream_request(request=...)``.
        session_id: Identifier of the session to propagate.

    Returns:
        Dict mapping each streamed response's "frame_index" to its "outputs".
    """
    # Only the request type and session are sent; no start frame or direction
    # is specified (the original author's note: frame 0 to the end of the video).
    stream = predictor.handle_stream_request(
        request={"type": "propagate_in_video", "session_id": session_id}
    )
    return {response["frame_index"]: response["outputs"] for response in stream}


def abs_to_rel_coords(coords, IMG_WIDTH, IMG_HEIGHT, coord_type="point"):
    """Scale absolute coordinates by the image size.

    Args:
        coords: Iterable of points ``[x, y]`` or boxes ``[x, y, w, h]``.
        IMG_WIDTH: Divisor for x and w values.
        IMG_HEIGHT: Divisor for y and h values.
        coord_type: ``"point"`` or ``"box"``, selecting the layout of ``coords``.

    Returns:
        List of lists holding the scaled values, in input order.

    Raises:
        ValueError: If ``coord_type`` is neither ``"point"`` nor ``"box"``.
    """
    if coord_type not in ("point", "box"):
        raise ValueError(f"Unknown coord_type: {coord_type}")

    relative = []
    if coord_type == "point":
        for x, y in coords:
            relative.append([x / IMG_WIDTH, y / IMG_HEIGHT])
    else:
        for x, y, w, h in coords:
            relative.append(
                [x / IMG_WIDTH, y / IMG_HEIGHT, w / IMG_WIDTH, h / IMG_HEIGHT]
            )
    return relative

In [None]:


# `build_sam3_video_predictor` is not imported anywhere else in this notebook,
# so bring it into scope here; without it this cell raises a NameError.
from sam3.model_builder import build_sam3_video_predictor

# Hand every CUDA device reported by torch to the video predictor.
gpus_to_use = range(torch.cuda.device_count())

predictor = build_sam3_video_predictor(
    gpus_to_use=gpus_to_use
)


In [None]:
# Sanity check: prints whether the frame folder used by the session below exists.
print(os.path.exists("/content/droid_frames_exterior_2_left"))


In [None]:
# Open a predictor session on the folder of exported JPEG frames.
response = predictor.handle_request(
    request=dict(
        type="start_session",
        resource_path="/content/droid_frames_exterior_2_left",
    )
)

# The session id is needed by every later request for this video.
session_id = response["session_id"]
print("Session ID:", session_id)


In [None]:
video_path = "/content/droid_frames_exterior_2_left"

if isinstance(video_path, str) and video_path.endswith(".mp4"):
    # Video file: decode every frame and keep it as an RGB array.
    video_frames_for_vis = []
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    while ret:
        video_frames_for_vis.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        ret, frame = cap.read()
    cap.release()
else:
    # Frame folder: keep the JPEG paths, ordered by frame number.
    jpg_pattern = os.path.join(video_path, "*.jpg")
    video_frames_for_vis = glob.glob(jpg_pattern)

    def _frame_number(path):
        """Return the integer encoded in a "<frame_index>.jpg" file name."""
        stem, _extension = os.path.splitext(os.path.basename(path))
        return int(stem)

    try:
        # Numeric order, so that e.g. "2.jpg" comes before "11.jpg".
        video_frames_for_vis.sort(key=_frame_number)
    except ValueError:
        # File names are not plain integers: fall back to lexicographic order.
        print(
            f'frame names are not in "<frame_index>.jpg" format: {video_frames_for_vis[:5]=}, '
            f"falling back to lexicographic sort."
        )
        video_frames_for_vis.sort()


In [None]:
def dino_to_sam3_box(box, img_w, img_h):
    """Convert a corner-format box to a size-normalised ``[x, y, w, h]`` list.

    Args:
        box: Four values ``(x1, y1, x2, y2)``.
        img_w: Divisor for the x position and the width.
        img_h: Divisor for the y position and the height.

    Returns:
        ``[x1 / img_w, y1 / img_h, (x2 - x1) / img_w, (y2 - y1) / img_h]``.
    """
    left, top, right, bottom = box
    width = right - left
    height = bottom - top
    return [left / img_w, top / img_h, width / img_w, height / img_h]

In [None]:
# Inspect the boxes returned by the detector's post-processing step.
print(result["boxes"])

In [None]:
# Send a "reset_session" request for the current session before adding prompts;
# the response is not needed.
_ = predictor.handle_request(
    request=dict(
        type="reset_session",
        session_id=session_id,
    )
)


In [None]:
# NOTE(review): IMG_WIDTH / IMG_HEIGHT and `objects_ids` are assigned here but
# not read anywhere in this cell — confirm whether they are still needed.
IMG_WIDTH, IMG_HEIGHT = pil_img.size

text_prompts=["pen"]
objects_ids = [0]

# Prompts are attached to frame 0 of the session.
frame_idx = 0

# Register each text prompt; its position in `text_prompts` is used as the object id.
for obj_id, text_prompt in enumerate(text_prompts):
    response = predictor.handle_request(
        request=dict(
            type="add_prompt",
            session_id=session_id,
            frame_index=frame_idx,
            text=text_prompt,
            obj_id=obj_id,
        )
    )

    print(f"Added text prompt '{text_prompt}' as object {obj_id}")

In [None]:
# Collect the per-frame outputs for the session, then reformat them for plotting.
outputs_per_frame = propagate_in_video(predictor, session_id)
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

# Plot only every 60th frame.
vis_frame_stride = 60
# Close figures left open by earlier cells before creating new ones.
plt.close("all")

for frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    visualize_formatted_frame_output(
        frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Multi-Object Tracking"],
        figsize=(6, 4),
    )
