# SEGMENTATION : Sam 2

# Import & setup

In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
#os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
# from codecarbon import EmissionsTracker
import time

# Report CUDA availability. The device index and name are only queried when
# CUDA is usable, because those calls raise on machines without a CUDA runtime
# and would abort the cell before the MPS/CPU fallback below is reached.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))


# select the device for computation: CUDA first, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook: the autocast context is entered
    # on purpose and never exited, so it stays active for all later cells
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )


### Initialisation de Sam

In [None]:
# Build the SAM 2 video predictor used by every later cell.
from sam2.build_sam import build_sam2_video_predictor

# Weights file (relative path) and the model configuration name passed to the
# builder; both refer to the "large" Hiera variant of SAM 2.1.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# `device` is the torch device selected in the setup cell.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

### Fonctions d'affichage des points, des masques et des images

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw a semi-transparent coloured overlay of ``mask`` on ``ax``.

    Args:
        mask: array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being coloured.
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        obj_id: index into the "tab10" colormap; ``None`` selects index 0.
        random_color: when true, use a random RGB colour instead of the
            colormap entry.
    """
    alpha = 0.6
    if random_color:
        rgb = np.random.random(3)
    else:
        palette = plt.get_cmap("tab10")
        rgb = palette(0 if obj_id is None else obj_id)[:3]
    rgba = np.array([*rgb, alpha])

    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) yields an RGBA image that is
    # fully transparent black wherever the mask is false.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=200):
    """Plot prompt points on ``ax`` as star markers.

    Points whose label equals 1 are drawn in green, then points whose label
    equals 0 are drawn in red. ``coords`` is indexed as ``coords[:, 0]`` for x
    and ``coords[:, 1]`` for y, and ``labels`` selects rows of ``coords``.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Add a green, unfilled rectangle for ``box`` to ``ax``.

    ``box`` is read as (x0, y0, x1, y1): the first two entries give the
    anchor corner and the last two the opposite corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

### Step 1 : Récupération du dossier contenant le dataset (images au format .jpeg)

In [None]:
# `video_dir` is a directory of JPEG frames named `<frame_index>.jpg`.
video_dir = "../ai4industry/dataset_light/dataset/part000/part28"
# Name under which the mask array is saved at the end of the notebook.
save_mask = "partie28"
# Only every `vis_frame_stride`-th frame is rendered in the visualisation cells.
vis_frame_stride = 10

# Keep the JPEG files only and order them by the integer value of their stem,
# so that position i in `frame_names` is frame i of the sequence.
frame_names = sorted(
    [
        entry
        for entry in os.listdir(video_dir)
        if os.path.splitext(entry)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ],
    key=lambda entry: int(os.path.splitext(entry)[0]),
)
print(frame_names)



### Affichage de la première frame du dataset

In [None]:
# Take a look at the first video frame (index 0 in the sorted `frame_names`).
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
# The image is opened from disk with PIL and displayed as-is.
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

### Step 2 : Lancement de Sam sur le dataset

In [None]:
# Imported in this cell because the top-of-notebook codecarbon import is
# commented out, which otherwise leaves `EmissionsTracker` undefined here.
from codecarbon import EmissionsTracker

# Measure the estimated CO2 emissions and the wall-clock time of initialising
# the predictor state from `video_dir`.
tracker = EmissionsTracker()

tracker.start()
start_time = time.time()

inference_state = predictor.init_state(video_path=video_dir)
# Start from a clean state (no prompts registered yet).
predictor.reset_state(inference_state)

emissions = tracker.stop()
end_time = time.time()

# Compute the elapsed time
elapsed_time = end_time - start_time

# Report the results
print(f"Les émissions de CO₂ générées sont estimées à {emissions:.6f} kg")
print(f"Le temps d'exécution est de {elapsed_time:.2f} secondes")

### Affichage de l'image d'origine avec le masque obtenu et le point d'ancrage

In [None]:
ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

# Let's add a positive click at (x, y) = (550, 200) to get started
points = np.array([[550, 200]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
# Register the click on the annotated frame; the call returns the object ids
# and the mask logits for that frame (the first returned value is unused).
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
# Logits are thresholded at 0 to obtain a boolean mask before plotting.
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

#### Step 3: Propagation du prompt

Pour obtenir le masklet tout au long de la vidéo, nous propageons les invites en utilisant l'API `propagate_in_video`.

In [None]:
# Imported locally so this cell does not rely on a top-of-notebook import.
from codecarbon import EmissionsTracker

# Use a fresh tracker for this step. The tracker from the initialisation cell
# has already been started and stopped; reusing it is not a reliable way to
# get a separate measurement for this step (NOTE(review): confirm against the
# installed codecarbon version).
tracker = EmissionsTracker()
tracker.start()
start_time = time.time()

# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    # Threshold the logits at 0 and store one boolean mask per object id.
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames (stride set by `vis_frame_stride`)
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

# NOTE: the measured interval covers both the propagation and the rendering above.
emissions = tracker.stop()
end_time = time.time()

# Compute the elapsed time
elapsed_time = end_time - start_time

# Report the results
print(f"Les émissions de CO₂ générées sont estimées à {emissions:.6f} kg")
print(f"Le temps d'exécution est de {elapsed_time:.2f} secondes")

### Step 4 : Export des données

In [None]:
# Draw the masks alone (the underlying frame is not shown) for every
# `vis_frame_stride`-th frame. Nothing is written to disk in this cell.
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    # One overlay per tracked object id stored for this frame.
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

In [None]:
# Shape of the mask of object 1 on frame 0 (kept for inspection).
value = video_segments[0][1].shape

# Collect the [row, column] coordinates of every true pixel in the first
# channel of that mask. `np.argwhere` scans in row-major order, which is the
# same order as a nested loop over rows then columns, and runs in native code
# instead of one Python iteration per pixel.
road_points = [np.argwhere(video_segments[0][1][0]).tolist()]


## Fonction de sauvegarde du masque en un tableau numpy

In [None]:

# Stack the mask of object 1 for every propagated frame into one boolean array
# of shape (num_frames, height, width) and save it with `np.save`.
if not video_segments:
    raise ValueError("video_segments is empty: run the propagation step first")

# Derive the array size from the data instead of hard-coding (50, 480, 848).
mask_height, mask_width = next(iter(video_segments.values()))[1].shape[-2:]
mask_arrays = np.zeros(
    (max(video_segments) + 1, mask_height, mask_width), dtype=bool
)

for frame_index, masks_by_object in video_segments.items():
    # Frame indices are 0-based, so frame k goes to row k. Indexing with
    # `frame_index - 1` would send frame 0 to the last row and shift the rest.
    mask_arrays[frame_index] = masks_by_object[1].reshape(mask_height, mask_width)

np.save(save_mask, mask_arrays)