In [1]:
import os
import numpy as np
import torch
from PIL import Image
import cv2
from ultralytics import YOLO
from matplotlib import pyplot as plt
import json
import pickle
import logging
from tqdm import tqdm
import time
import gc
import shutil

from sam2.build_sam import build_sam2_video_predictor
from ultralytics import YOLO

In [24]:
# !sudo mount -t drvfs E: /mnt/g
# mogrify -format jpg *.png && rm *.png

ERROR! Session/line number was not unique in database. History logging moved to new session 63


In [None]:
# --- Run configuration -------------------------------------------------------
dataset_path = "data/raw"
is_manual_annotation = False  # False: prompts are sampled from the ground-truth masks
perform_tta = True            # also run on the reversed video and merge both predictions

# Point prompts sampled per tool when annotating automatically.
num_pos_points_per_tool = 5
num_neg_points_per_tool = 5
override_annotations = True   # regenerate the annotation files before inference

save_img = True

# Annotation files live in a mode-specific folder ("manual" or "auto").
annotation_dir = "data/annotations/" + ("manual" if is_manual_annotation else "auto")

# --- Output locations --------------------------------------------------------
masks_dir = "data/masks"
results_dir = "data/results"
log_dir = "logs"

# --- Dataset layout ----------------------------------------------------------
val_path = f"{dataset_path}/SegSTRONGC_val/val"
test_path = f"{dataset_path}/SegSTRONGC_test/test"
train_path = f"{dataset_path}/SegSTRONGC_train"

# --- Model selection ---------------------------------------------------------
models = ["sam2.1_hiera_base_plus", "yolo11x-seg"]
model = models[0]

val_dirs = [f"{val_path}/{entry}" for entry in os.listdir(val_path)]
# test_dirs = [test_path + "/" + domain for domain in os.listdir(test_path)]
# train_dirs = [train_path + "/" + domain for domain in os.listdir(train_path)]


sub_dirs = [f"{parent}/{child}" for parent in val_dirs for child in os.listdir(parent)]
# test_sub_dirs = [path + "/" + sub_dir for path in test_dirs for sub_dir in os.listdir(path)]
# train_sub_dirs = [path + "/" + sub_dir for path in train_dirs for sub_dir in os.listdir(path)]

# ground truth is 'ground_truth'
domains = ["bg_change", "blood", "low_brightness", "regular", "smoke"]
test_domains = ["bg_change", "blood", "low_brightness", "regular", "smoke"]
train_domains = ["regular"]

# Weights file for each supported model.
checkpoints = {
    "sam2.1_hiera_base_plus": "checkpoints/sam2.1_hiera_base_plus.pt",
    "yolo11x-seg": "checkpoints/yolo11x-seg.pt",
}

# Model config file for each supported model (YOLO needs none).
model_cfgs = {
    "sam2.1_hiera_base_plus": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "yolo11x-seg": None,
}

num_images_per_domain = 300
checkpoint = checkpoints[model]
model_cfg = model_cfgs[model]

# logging.basicConfig opens the log file immediately, so the directory has to
# exist first; otherwise a fresh checkout fails here with FileNotFoundError.
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(filename=log_dir + f'/{model}.log', level=logging.INFO, format='%(asctime)s - %(message)s', filemode='w')
logger = logging.getLogger()

In [5]:
def get_image_paths(path, domain, is_left, num_images=300):
    """Build the frame paths ``<path>/<domain>/<left|right>/<i>.png``.

    Args:
        path: Root directory of the split.
        domain: Domain folder name (e.g. "bg_change").
        is_left: True for the left camera folder, False for the right one.
        num_images: Number of consecutive frame indices, starting at 0.

    Returns:
        List of ``num_images`` path strings in ascending frame order.
    """
    camera = "left" if is_left else "right"
    frame_dir = "/".join((path, domain, camera))
    return [frame_dir + "/" + str(index) + ".png" for index in range(num_images)]

# Sanity check: show the left-camera frame paths generated for the "bg_change" folder under val_path.
get_image_paths(val_path, "bg_change", True, num_images=300)

['data/raw/SegSTRONGC_val/val/bg_change/left/0.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/1.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/2.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/3.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/4.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/5.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/6.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/7.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/8.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/9.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/10.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/11.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/12.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/13.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/14.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/15.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/16.png',
 'data/raw/SegSTRONGC_val/val/bg_change/left/17.png',
 'data/raw/SegSTRONGC_val/val/bg_chang

In [27]:
def calculate_iou(TP, FP, FN):
    """Intersection over Union from pixel counts.

    Args:
        TP: Number of true-positive pixels.
        FP: Number of false-positive pixels.
        FN: Number of false-negative pixels.

    Returns:
        ``TP / (TP + FP + FN)``. When all three counts are zero (prediction and
        ground truth are both empty, i.e. they agree) the score is 1.0 instead
        of an undefined 0/0.
    """
    denominator = TP + FP + FN
    if denominator == 0:
        return 1.0
    return TP / denominator


def calculate_dsc(TP, FP, FN):
    """Dice similarity coefficient from pixel counts.

    Args:
        TP: Number of true-positive pixels.
        FP: Number of false-positive pixels.
        FN: Number of false-negative pixels.

    Returns:
        ``2 * TP / (2 * TP + FP + FN)``, or 1.0 when all counts are zero
        (both masks empty).
    """
    denominator = 2 * TP + FP + FN
    if denominator == 0:
        return 1.0
    return 2 * TP / denominator


def _confusion_counts(pred_mask, gt_mask):
    """Return the (TP, FP, FN) pixel counts for one prediction / ground-truth pair."""
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)
    tp = np.sum(pred & gt)
    fp = np.sum(pred & ~gt)
    fn = np.sum(~pred & gt)
    return tp, fp, fn


def _mean_metric(pred_masks, gt_masks, metric):
    """Average ``metric(TP, FP, FN)`` over all frames, pairing masks by index."""
    scores = [
        metric(*_confusion_counts(pred_masks[i], gt_masks[i]))
        for i in range(len(pred_masks))
    ]
    return np.mean(scores)


def calculate_miou(pred_masks, gt_masks):
    """Mean per-frame IoU between predicted and ground-truth masks.

    Args:
        pred_masks: Sequence of predicted masks (non-zero means foreground).
        gt_masks: Sequence of ground-truth masks, index-aligned with ``pred_masks``.

    Returns:
        The mean of the per-frame IoU values.
    """
    return _mean_metric(pred_masks, gt_masks, calculate_iou)


def calculate_mdsc(pred_masks, gt_masks):
    """Mean per-frame Dice coefficient between predicted and ground-truth masks.

    Args:
        pred_masks: Sequence of predicted masks (non-zero means foreground).
        gt_masks: Sequence of ground-truth masks, index-aligned with ``pred_masks``.

    Returns:
        The mean of the per-frame Dice values.
    """
    return _mean_metric(pred_masks, gt_masks, calculate_dsc)

In [None]:
def manual_annotate(frame_path):
    """Interactively collect point prompts on one frame in an OpenCV window.

    Controls:
        left click -- add a point for the current tool
                      (drawn green in positive mode, red in negative mode)
        n -- start a new tool; the mode resets to positive
        p -- toggle between positive and negative mode
        c -- clear the points of the current tool
        s -- finish and return

    Args:
        frame_path: Path of the image to annotate.

    Returns:
        dict mapping the tool index (int, starting at 0) to a list of
        ``{"x": int, "y": int, "label": 1 | 0}`` points.

    Raises:
        FileNotFoundError: If the frame cannot be read.
    """
    # Load first: cv2.imread returns None (it does not raise) on a bad path,
    # and failing here avoids leaving an orphaned window behind.
    original_frame = cv2.imread(frame_path)
    if original_frame is None:
        raise FileNotFoundError(f"Could not read frame: {frame_path}")

    annotations = {0: []}
    current_tool = 0
    is_positive = True

    window_name = "Manual Annotation of Frame -" + str(frame_path)
    cv2.namedWindow(window_name)

    def handle_mouse_click(event, x, y, flags, params):
        # Only record the click; the main loop redraws every point on each
        # iteration, so no drawing is needed here.
        if event == cv2.EVENT_LBUTTONDOWN:
            annotations[current_tool].append({
                "x": x,
                "y": y,
                "label": 1 if is_positive else 0
            })

    cv2.setMouseCallback(window_name, handle_mouse_click)

    try:
        while True:
            frame = original_frame.copy()
            for tool in annotations:
                for annotation in annotations[tool]:
                    # BGR colours: green for positive, red for negative points.
                    color = (0, 255, 0) if annotation["label"] == 1 else (0, 0, 255)
                    cv2.circle(frame, (annotation["x"], annotation["y"]), 10, color, -1)

            display_text = f"Tool: {current_tool}, Mode: {'Positive' if is_positive else 'Negative'}"
            cv2.putText(frame, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            cv2.imshow(window_name, frame)

            key = cv2.waitKey(1) & 0xFF

            if key == ord("n"):
                current_tool += 1
                is_positive = True
                annotations[current_tool] = []
            elif key == ord("p"):
                is_positive = not is_positive
            elif key == ord("c"):
                annotations[current_tool] = []
            elif key == ord("s"):
                break
    finally:
        # Close the window even when the loop is interrupted.
        cv2.destroyWindow(window_name)

    return annotations

def _sample_point_prompts(ys, xs, count, point_label):
    """Pick up to ``count`` distinct pixels from (ys, xs) and wrap them as prompt dicts."""
    # np.random.choice(..., replace=False) raises when asked for more samples
    # than there are candidates, so clamp to what is available.
    count = min(count, len(ys))
    if count <= 0:
        return []
    chosen = np.random.choice(len(ys), count, replace=False)
    return [{"x": int(xs[i]), "y": int(ys[i]), "label": point_label} for i in chosen]


def auto_annotate(frame_path):
    """Sample point prompts for up to two tools from a frame's ground-truth mask.

    The ground-truth mask path is obtained by replacing the domain name found in
    ``frame_path`` with "ground_truth". The two largest connected foreground
    components are treated as tool 0 and tool 1. Each tool receives one
    centroid point, up to ``num_pos_points_per_tool - 1`` random positive points
    inside the component and up to ``num_neg_points_per_tool`` random negative
    points on the background.

    Args:
        frame_path: Path of the frame; must contain one of the names in ``domains``.

    Returns:
        ``{0: [...], 1: [...]}`` where each list holds
        ``{"x": int, "y": int, "label": 1 | 0}`` points. The list for tool 1 is
        empty when the mask contains a single component.

    Raises:
        ValueError: If no domain name occurs in ``frame_path`` or the mask has
            no foreground pixels.
        FileNotFoundError: If the ground-truth mask cannot be read.
    """
    ground_truth_mask_path = None
    for domain in domains:
        if domain in frame_path:
            ground_truth_mask_path = frame_path.replace(domain, "ground_truth")
            break

    if ground_truth_mask_path is None:
        raise ValueError("Ground truth path not found.")

    # cv2.imread returns None instead of raising when the file is unreadable.
    ground_truth_mask = cv2.imread(ground_truth_mask_path, cv2.IMREAD_GRAYSCALE)
    if ground_truth_mask is None:
        raise FileNotFoundError(f"Could not read ground truth mask: {ground_truth_mask_path}")

    foreground = (ground_truth_mask > 0).astype(np.uint8)
    _, labels = cv2.connectedComponents(foreground)

    # cv2.connectedComponents always assigns label 0 to the background, so the
    # tools are the largest components among labels 1..N. Ranking all labels by
    # size would pick the wrong background whenever a tool covers more pixels.
    foreground_sizes = np.bincount(labels.ravel())[1:]
    if foreground_sizes.size == 0:
        raise ValueError(f"Ground truth mask has no foreground pixels: {ground_truth_mask_path}")
    tool_labels = (np.argsort(-foreground_sizes, kind="stable")[:2] + 1).tolist()

    background_ys, background_xs = np.where(labels == 0)

    annotations = {0: [], 1: []}
    for obj_id, label in enumerate(tool_labels):
        component = (labels == label).astype(np.uint8)
        ys, xs = np.nonzero(component)

        contours, _ = cv2.findContours(component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        M = cv2.moments(contours[0])
        if M["m00"] > 0:
            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])
        else:
            # Degenerate contour (single pixel or a line) has zero area;
            # fall back to the mean pixel position.
            cx = int(xs.mean())
            cy = int(ys.mean())

        # The centroid of a concave tool can fall outside the tool itself.
        # A positive prompt must lie on the tool, so snap it to the nearest
        # pixel of the component in that case.
        if not component[cy, cx]:
            nearest = int(np.argmin((xs - cx) ** 2 + (ys - cy) ** 2))
            cx, cy = int(xs[nearest]), int(ys[nearest])

        annotations[obj_id].append({
            "x": cx,
            "y": cy,
            "label": 1
        })

        # Remaining positive prompts: random pixels inside the component.
        annotations[obj_id].extend(
            _sample_point_prompts(ys, xs, num_pos_points_per_tool - 1, 1)
        )
        # Negative prompts: random background pixels, drawn separately per tool.
        annotations[obj_id].extend(
            _sample_point_prompts(background_ys, background_xs, num_neg_points_per_tool, 0)
        )

    return annotations

def _merge_annotations_into_file(annotation_file, new_entries):
    """Merge ``new_entries`` into the JSON dict stored at ``annotation_file`` (created if missing)."""
    if os.path.exists(annotation_file):
        with open(annotation_file, "r") as f:
            all_annotations = json.load(f)
    else:
        all_annotations = {}

    all_annotations.update(new_entries)

    with open(annotation_file, "w") as f:
        json.dump(all_annotations, f)


def manual_annotate_first_frames(sub_dirs, domains, split):
    """Annotate the first and last frame of every left/right video and persist the prompts.

    Despite its name, the function uses ``manual_annotate`` or ``auto_annotate``
    depending on the global ``is_manual_annotation``. Prompts for frame 0 go to
    ``<annotation_dir>/<split>.json`` and prompts for frame 299 (the first frame
    of the reversed video) go to ``<annotation_dir>/<split>_reverse.json``. Both
    files map ``<sub_dir>/<domain>/<left|right>`` to the annotation dict and are
    rewritten after every domain, so an interrupted run keeps its progress.

    Args:
        sub_dirs: Directories that contain one folder per domain.
        domains: Domain folder names to annotate in each sub-directory.
        split: Split name used for the annotation file names.
    """
    annotate = manual_annotate if is_manual_annotation else auto_annotate

    # The annotation folder is not created anywhere else; without this the
    # first write fails with FileNotFoundError.
    os.makedirs(annotation_dir, exist_ok=True)
    forward_file = annotation_dir + f"/{split}.json"
    reverse_file = annotation_dir + f"/{split}_reverse.json"

    print(f"Total number of subdirs for {split}: {len(sub_dirs)}")
    for sub_dir in sub_dirs:
        for domain in domains:
            left_video_frames_path = sub_dir + "/" + domain + "/left"
            right_video_frames_path = sub_dir + "/" + domain + "/right"

            first_left_frame = left_video_frames_path + "/0.png"
            first_right_frame = right_video_frames_path + "/0.png"

            # NOTE(review): the last frame index is hard-coded to 299, i.e. the
            # videos are assumed to have exactly 300 frames.
            last_left_frame = left_video_frames_path + "/299.png"
            last_right_frame = right_video_frames_path + "/299.png"

            # Keep this call order: with random sampling it determines which
            # points each frame receives for a given RNG state.
            left_annotations = annotate(first_left_frame)
            right_annotations = annotate(first_right_frame)
            left_reverse_annotations = annotate(last_left_frame)
            right_reverse_annotations = annotate(last_right_frame)

            _merge_annotations_into_file(forward_file, {
                left_video_frames_path: left_annotations,
                right_video_frames_path: right_annotations,
            })
            print(f"Domain annotated: {first_left_frame}")

            _merge_annotations_into_file(reverse_file, {
                left_video_frames_path: left_reverse_annotations,
                right_video_frames_path: right_reverse_annotations,
            })
            print(f"Reverse Domain annotated: {last_left_frame}")
        print(f"Subdir annotated: {sub_dir}")

# Regenerate the prompt files only when explicitly requested in the config cell.
if not override_annotations:
    print("Annotations already exist. Skipping annotation process.")
else:
    manual_annotate_first_frames(sub_dirs, domains, "val")
    print("Split annotated: val")
    # manual_annotate_first_frames(test_sub_dirs, test_domains, "test")
    # print("Split annotated: test")
    # manual_annotate_first_frames(train_sub_dirs, train_domains, "train")
    # print("Split annotated: train")

KeyboardInterrupt: 

KeyboardInterrupt: 

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context that stays active for the rest of the
    # notebook (it is entered here and never exited).
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 on GPUs with compute capability >= 8
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

using device: cuda


In [None]:
def run_inference(inference_model, frames_path, split, is_reverse, forward_pass_path=None, frame_shape=(1080, 1920)):
    """Segment every frame of one video with SAM2 (point-prompted) or YOLO.

    For "sam2.1_hiera_base_plus" the point prompts stored for ``forward_pass_path``
    in ``<annotation_dir>/<split>.json`` (or ``<split>_reverse.json`` when
    ``is_reverse`` is true) are added on frame 0 and propagated through the
    video. For "yolo11x-seg" every frame ``<frames_path>/<i>.png`` is segmented
    independently. In both cases all object masks of a frame are OR-ed into a
    single binary mask.

    Args:
        inference_model: "sam2.1_hiera_base_plus" or "yolo11x-seg".
        frames_path: Directory holding the frames to run on.
        split: Split name used to locate the annotation file (SAM2 only).
        is_reverse: Load the reverse annotation file instead of the forward one.
        forward_pass_path: Key of this video in the annotation file (SAM2 only).
        frame_shape: ``(height, width)`` of the per-frame output masks.
            Defaults to the previously hard-coded 1080x1920.

    Returns:
        ``(mask_storage_data, predicted_masks)``: a dict mapping frame index to
        its combined boolean mask, and a list of the same masks in processing
        order. ``(None, None)`` when the annotations cannot be loaded.

    Raises:
        KeyError: If the annotation file has no entry for ``forward_pass_path``.
        ValueError: If ``inference_model`` is not a supported model name.
    """
    mask_storage_data = {}
    predicted_masks = []

    if inference_model == "sam2.1_hiera_base_plus":
        print(f"Loading annotations for split: {split}")
        try:
            if not is_reverse:
                with open(annotation_dir + f"/{split}.json", "r") as f:
                    annotations = json.load(f)
                print(f"Successfully loaded annotations for {len(annotations)} items")
            else:
                with open(annotation_dir + f"/{split}_reverse.json", "r") as f:
                    annotations = json.load(f)
                print(f"Successfully loaded reverse annotations for {len(annotations)} items")
        except Exception as e:
            print(f"Error loading annotations: {e}")
            logger.error(f"Error loading annotations for split {split}: {e}")
            # Callers unpack two values; a bare `return` would surface as an
            # unrelated "cannot unpack NoneType" error at the call site.
            return None, None

        if annotations is None:
            print("No annotations found for split", split)
            logger.warning(f"No annotations found for split {split}")
            return None, None

        # Fail before the expensive model load when this video has no prompts.
        if forward_pass_path not in annotations:
            raise KeyError(f"No annotations found for video {forward_pass_path} in split {split}")

        start = time.time()
        print(f"Initializing SAM for video...")
        sam2_predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)

        inference_state = sam2_predictor.init_state(
            video_path = frames_path,
        )
        end = time.time()
        print(f"Initialization took {end - start:.2f} seconds.")
        logger.info(f"SAM initialization for {forward_pass_path} took {end - start:.2f} seconds.")

        current_annotations = annotations[forward_pass_path]
        print(f"Found {len(current_annotations)} objects with annotations.")
        logger.info(f"Processing {len(current_annotations)} objects with annotations for {forward_pass_path}")

        n_points = 0
        for obj_key in tqdm(current_annotations, desc=f"Processing annotations for objects"):
            obj_annotations = current_annotations[obj_key]
            n_points += len(obj_annotations)

            points = np.array(
                [[int(annotation["x"]), int(annotation["y"])] for annotation in obj_annotations],
                np.float32,
            )
            labels = np.array([annotation["label"] for annotation in obj_annotations], np.int32)

            # Objects without any prompt are not registered with the predictor.
            if points.shape[0] == 0:
                continue

            # All prompts are placed on frame 0 of the video being processed.
            sam2_predictor.add_new_points_or_box(
                inference_state = inference_state,
                frame_idx = 0,
                obj_id = int(obj_key),
                points = points,
                labels = labels,
            )
        print(f"Added {n_points} annotation points across all objects.")
        logger.info(f"Added {n_points} annotation points across all objects for {forward_pass_path}")

        print("Starting mask propagation...")
        start = time.time()
        video_segments = {}

        for out_frame_idx, out_obj_ids, out_mask_logits in sam2_predictor.propagate_in_video(inference_state):
            # Positive logits are foreground.
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        end = time.time()
        prop_time = end - start

        print(f"Mask propagation completed in {prop_time:.2f} seconds.")
        logger.info(f"Mask propagation for {forward_pass_path} took {prop_time:.2f} seconds.")

        print("Processing predicted masks...")
        for frame_idx, obj_dict in tqdm(video_segments.items(), desc="Processing video frames"):
            # Union of all object masks of this frame.
            overall_mask = np.zeros(frame_shape, dtype=bool)

            for obj_id, mask_array in obj_dict.items():
                overall_mask = np.logical_or(overall_mask, mask_array.squeeze())

            predicted_masks.append(overall_mask)
            mask_storage_data[frame_idx] = overall_mask

        sam2_predictor.reset_state(inference_state)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("Empied CUDA cache.")
        print("SAM state reset.")
    elif inference_model == "yolo11x-seg":
        start = time.time()
        print(f"Initializing Yolo for video...")
        yolo_model = YOLO(checkpoint)
        end = time.time()
        print(f"Initialization took {end - start:.2f} seconds.")
        logger.info(f"Yolo initialization for {frames_path} took {end - start:.2f} seconds.")

        print("Starting mask propagation...")
        start = time.time()

        # Frames are addressed as 0.png .. (N-1).png, so only .png files count;
        # any other file in the folder would otherwise produce a missing path.
        total_images = sum(1 for name in os.listdir(frames_path) if name.endswith(".png"))
        print(f"Total images in video: {total_images}")
        logger.info(f"Total images in video: {total_images}")

        video_result = []

        for i in tqdm(range(total_images), desc="Propagating video frames"):
            frame_path = frames_path + f"/{i}.png"
            video_result.append(yolo_model(frame_path))

        end = time.time()
        prop_time = end - start

        print(f"Mask propagation completed in {prop_time:.2f} seconds.")
        logger.info(f"Mask propagation for {frames_path} took {prop_time:.2f} seconds.")

        print("Processing predicted masks...")

        for frame_idx, frame_result in tqdm(enumerate(video_result), desc="Processing video frames"):
            # Union of all instance masks of this frame.
            overall_mask = np.zeros(frame_shape, dtype=bool)

            for result in frame_result:
                if result.masks is None:
                    continue

                for mask in result.masks.data:
                    mask_np = mask.cpu().numpy()
                    # cv2.resize expects (width, height).
                    reshaped_mask = cv2.resize(mask_np, (frame_shape[1], frame_shape[0]), interpolation=cv2.INTER_NEAREST)
                    overall_mask = np.logical_or(overall_mask, reshaped_mask)

            predicted_masks.append(overall_mask)
            mask_storage_data[frame_idx] = overall_mask
    else:
        raise ValueError("Invalid model name.")

    return mask_storage_data, predicted_masks

In [None]:
import pandas as pd
import os
import datetime

def _save_tta_visualizations(frames_path, ground_truth_masks_path, forward_masks, reverse_masks, save_dir):
    """Write one 2x2 grid image (frame, ground truth, forward mask, reverse mask) per frame."""
    os.makedirs(save_dir, exist_ok=True)

    for i in range(len(forward_masks)):
        original_img = cv2.imread(os.path.join(frames_path, f"{i}.png"))

        gt_mask = cv2.imread(os.path.join(ground_truth_masks_path, f"{i}.png"), cv2.IMREAD_GRAYSCALE)
        gt_mask = (gt_mask > 0).astype(np.uint8) * 255

        forward_mask = forward_masks[i].astype(np.uint8) * 255
        reverse_mask = reverse_masks[i].astype(np.uint8) * 255

        h, w = original_img.shape[:2]
        grid = np.zeros((h*2, w*2, 3), dtype=np.uint8)

        grid[:h, :w] = original_img  # Original
        grid[:h, w:] = cv2.cvtColor(gt_mask, cv2.COLOR_GRAY2BGR)  # Ground truth
        grid[h:, :w] = cv2.cvtColor(forward_mask, cv2.COLOR_GRAY2BGR)  # Forward mask
        grid[h:, w:] = cv2.cvtColor(reverse_mask, cv2.COLOR_GRAY2BGR)  # Reverse mask

        labels = ['Original', 'Ground Truth', 'Forward Mask', 'Reverse Mask']
        positions = [(10, 30), (w+10, 30), (10, h+30), (w+10, h+30)]
        for label, pos in zip(labels, positions):
            cv2.putText(grid, label, pos, cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2)

        cv2.imwrite(os.path.join(save_dir, f"frame_{i:04d}.png"), grid)

    # Report once, after all frames are written.
    print(f"Saved visualization images to {save_dir}")


def process_video(frames_path, sub_dir, domain, split, is_left):
    """Run inference on one video, evaluate it against the ground truth and store the scores.

    When the global ``perform_tta`` is true, the frames are copied in reversed
    order to a temporary folder, inference is run a second time with the
    reverse annotations, and the forward and reverse masks are OR-ed per frame.
    When the global ``save_img`` is also true, a visualization grid per frame is
    written for this video and ``save_img`` is then switched off, so only the
    first processed video is visualized.

    The scores are merged into ``<results_dir>/<model>/<split>.json`` under the
    key ``frames_path``.

    Args:
        frames_path: Directory with the frames ``0.png .. N-1.png``.
        sub_dir: Directory containing the ``ground_truth`` folder for this video.
        domain: Domain name (used for logging only).
        split: Split name; selects the annotation and result files.
        is_left: True for the left camera, False for the right camera.

    Returns:
        ``(miou, mdsc)``: mean IoU and mean Dice over all frames.

    Raises:
        RuntimeError: If inference returns no masks.
        FileNotFoundError: If a ground-truth mask cannot be read.
    """
    # `save_img` is reassigned below. Without this declaration the assignment
    # makes it a local name and the `if save_img:` check raises
    # UnboundLocalError on every TTA run.
    global save_img

    print("="*50)
    print(f"Processing video: {frames_path}")
    print(f"Domain: {domain}, Split: {split}, Camera: {'left' if is_left else 'right'}")
    logger.info(f"Processing video: {frames_path} (Domain: {domain}, Split: {split}, Camera: {'left' if is_left else 'right'})")
    stereo_dir = "left" if is_left else "right"
    ground_truth_masks_path = sub_dir + "/ground_truth/" + stereo_dir

    overall_start = time.time()
    _, predicted_masks = run_inference(model, frames_path, split, False, frames_path)
    if predicted_masks is None:
        raise RuntimeError(f"Inference produced no masks for {frames_path}")

    if perform_tta:
        temp_video_frames_path = "data/temp"
        # Start from an empty folder so frames left over from an interrupted
        # run cannot end up in the reversed video.
        if os.path.exists(temp_video_frames_path):
            shutil.rmtree(temp_video_frames_path)
        os.makedirs(temp_video_frames_path)

        try:
            # Only .png files are frames; counting anything else would shift
            # the reversed indices.
            frame_files = [name for name in os.listdir(frames_path) if name.endswith(".png")]
            total_files = len(frame_files)
            for filename in frame_files:
                original_index = int(filename.split('.')[0])
                new_index = total_files - 1 - original_index
                new_filename = f"{new_index}.png"
                shutil.copy(os.path.join(frames_path, filename), os.path.join(temp_video_frames_path, new_filename))

            _, predicted_reverse_masks = run_inference(model, temp_video_frames_path, split, True, frames_path)
        finally:
            # Remove the temporary copy even when inference fails.
            shutil.rmtree(temp_video_frames_path, ignore_errors=True)

        if predicted_reverse_masks is None:
            raise RuntimeError(f"Reverse inference produced no masks for {frames_path}")

        # Bring the reverse-pass masks back into forward frame order.
        predicted_reverse_masks = predicted_reverse_masks[::-1]

        if save_img:
            _save_tta_visualizations(
                frames_path,
                ground_truth_masks_path,
                predicted_masks,
                predicted_reverse_masks,
                f"data/results/{model}/visualizations",
            )
        # Visualize at most one video per run.
        save_img = False

        for i in range(len(predicted_masks)):
            predicted_masks[i] = np.logical_or(predicted_masks[i], predicted_reverse_masks[i])

        print(f"Applied TTA to {len(predicted_masks)} masks.")

    print(f"Generated {len(predicted_masks)} masks for evaluation.")
    logger.info(f"Generated {len(predicted_masks)} masks for {frames_path}")

    # The per-split mask folder is still created; writing the masks themselves
    # to disk is currently disabled.
    masks_split_dir = masks_dir + f"/{model}" + f"/{split}"
    os.makedirs(masks_split_dir, exist_ok=True)

    print("Loading ground truth masks for evaluation...")
    ground_truth_masks = []
    for i in range(len(predicted_masks)):
        ground_truth_mask_path = ground_truth_masks_path + "/" + str(i) + ".png"
        ground_truth_mask = cv2.imread(ground_truth_mask_path, cv2.IMREAD_GRAYSCALE)
        if ground_truth_mask is None:
            raise FileNotFoundError(f"Could not read ground truth mask: {ground_truth_mask_path}")
        ground_truth_masks.append((ground_truth_mask > 0).astype(np.bool_))
    print(f"Loaded {len(ground_truth_masks)} ground truth masks.")

    print("Calculating evaluation metrics...")
    start = time.time()
    miou = calculate_miou(predicted_masks, ground_truth_masks)
    mdsc = calculate_mdsc(predicted_masks, ground_truth_masks)
    end = time.time()
    eval_time = end - start
    print(f"Time taken for metrics calculation: {eval_time:.2f} seconds.")
    logger.info(f"Time taken for metrics calculation: {eval_time:.2f} seconds.")

    print(f"Mean IoU for {sub_dir}/{domain}/{stereo_dir}: {miou:.4f}")
    print(f"Mean DSC for {sub_dir}/{domain}/{stereo_dir}: {mdsc:.4f}")

    logger.info(f"Mean IoU for {sub_dir}/{domain}/{stereo_dir}: {miou:.4f}")
    logger.info(f"Mean DSC for {sub_dir}/{domain}/{stereo_dir}: {mdsc:.4f}")

    results_file = results_dir + f"/{model}" + f"/{split}.json"
    # The results folder is not created anywhere else; without this the write
    # below fails on a fresh checkout.
    os.makedirs(os.path.dirname(results_file), exist_ok=True)
    if os.path.exists(results_file):
        print(f"Loading existing results file: {results_file}")
        with open(results_file, "r") as f:
            all_results = json.load(f)
    else:
        print(f"Creating new results file: {results_file}")
        all_results = {}

    all_results[frames_path] = {
        "miou": float(miou),
        "mdsc": float(mdsc)
    }

    with open(results_file, "w") as f:
        json.dump(all_results, f)
    print(f"Results saved to {results_file}")

    overall_end = time.time()
    total_time = overall_end - overall_start
    print(f"Processing video took {total_time:.2f} seconds.")
    logger.info(f"Results for {sub_dir}/{domain}/{stereo_dir} saved.")
    logger.info(f"Processing video took {total_time:.2f} seconds.")
    print("="*50)

    return miou, mdsc

# Summary metric keys paired with the labels used when printing/logging them.
# The order here is also the column order of the per-domain CSV fields.
_METRIC_LABELS = (
    ("left_miou", "Left Frame IoU"),
    ("right_miou", "Right Frame IoU"),
    ("overall_miou", "Overall IoU"),
    ("left_mdsc", "Left Frame DSC"),
    ("right_mdsc", "Right Frame DSC"),
    ("overall_mdsc", "Overall DSC"),
)
_METRIC_KEYS = tuple(key for key, _ in _METRIC_LABELS)


def _log_run_configuration():
    """Log the model and annotation settings (notebook globals) used for this run."""
    logger.info(f"Using Model: {model}")
    logger.info(f"Annotation mode: {'manual' if is_manual_annotation else 'auto'}")
    logger.info(f"Performing tta: {'yes' if perform_tta else 'no'}")
    if not is_manual_annotation:
        logger.info(f"Number of positive point annotations per tool: {num_pos_points_per_tool}")
        logger.info(f"Number of negative point annotations per tool: {num_neg_points_per_tool}")


def _mean_metrics(rows):
    """Return the per-key mean of a list of metric dicts (keys: ``_METRIC_KEYS``)."""
    return {key: np.mean([row[key] for row in rows]) for key in _METRIC_KEYS}


def process_split(sub_dirs, domains, split):
    """Run ``process_video`` on every (sub-directory, domain, camera) of a split and report.

    For each entry of ``sub_dirs`` and each entry of ``domains`` the left and
    right frame folders (``<sub_dir>/<domain>/left`` and ``.../right``) are
    processed. The returned IoU/DSC values are averaged per domain and over
    the whole split, then printed, logged, and appended as one row to
    ``<results_dir>/results.csv``.

    Args:
        sub_dirs: Sized sequence of sub-directory paths (strings).
        domains: Sized sequence of domain folder names inside each sub-directory.
        split: Split name; used in messages, forwarded to ``process_video``
            and written to the CSV row.

    Returns:
        None. Results are reported through stdout, the logger and the CSV file.
    """
    # Imported locally: the notebook's top import cell does not import these
    # two modules, so relying on globals would fail on a fresh kernel unless
    # another cell happens to import them.
    import datetime
    import pandas as pd

    print("="*80)
    print(f"Running inference for split: {split}")
    _log_run_configuration()

    print("="*80)
    logger.info(f"----------------Running inference for split {split}-------------")
    overall_start = time.time()

    print(f"Processing {len(sub_dirs)} sub-directories and {len(domains)} domains")
    logger.info(f"Processing {len(sub_dirs)} sub-directories and {len(domains)} domains for split {split}")

    # sub_dir -> domain -> metric dict keyed by _METRIC_KEYS
    sub_dir_results = {}
    for sub_dir in tqdm(sub_dirs, desc="Processing sub-directories"):
        print("\n" + "-"*60)
        print(f"Processing sub-directory: {sub_dir}")
        logger.info(f"Processing sub-directory: {sub_dir}")
        domain_results = {}
        for domain in tqdm(domains, desc="Processing domains"):
            print(f"\nProcessing domain: {domain}")
            logger.info(f"Processing domain: {domain} in {sub_dir}")
            domain_path = sub_dir + "/" + domain

            # The last positional flag is True for the left camera and False
            # for the right one (presumably a left/right indicator - confirm
            # against process_video's signature).
            print("Processing left camera video...")
            left_miou, left_mdsc = process_video(domain_path + "/left", sub_dir, domain, split, True)

            print("Processing right camera video...")
            right_miou, right_mdsc = process_video(domain_path + "/right", sub_dir, domain, split, False)

            overall_miou = (left_miou + right_miou) / 2
            overall_mdsc = (left_mdsc + right_mdsc) / 2

            print(f"\nResults for {sub_dir}/{domain}:")
            print(f"  Left: IoU={left_miou:.4f}, DSC={left_mdsc:.4f}")
            print(f"  Right: IoU={right_miou:.4f}, DSC={right_mdsc:.4f}")
            print(f"  Overall: IoU={overall_miou:.4f}, DSC={overall_mdsc:.4f}")

            logger.info(f"Results for {sub_dir}/{domain}: Left IoU={left_miou:.4f}, Right IoU={right_miou:.4f}, Overall IoU={overall_miou:.4f}")

            domain_results[domain] = {
                "left_miou": left_miou,
                "right_miou": right_miou,
                "overall_miou": overall_miou,
                "left_mdsc": left_mdsc,
                "right_mdsc": right_mdsc,
                "overall_mdsc": overall_mdsc,
            }

        sub_dir_results[sub_dir] = domain_results

    print("\n" + "="*60)
    print(f"SUMMARY RESULTS FOR SPLIT: {split}")
    print("="*60)
    logger.info(f"SUMMARY RESULTS FOR SPLIT: {split}")
    _log_run_configuration()

    # Domain-wise results: average each metric over all sub-directories.
    print("\nDomain-wise Results:")
    logger.info("Domain-wise Results:")
    domain_results_data = {}
    for domain in domains:
        means = _mean_metrics([sub_dir_results[sub_dir][domain] for sub_dir in sub_dirs])

        print(f"\nDomain: {domain}")
        for key, label in _METRIC_LABELS:
            print(f"  {label}: {means[key]:.4f}")

        logger.info(f"Domain {domain} - Left IoU: {means['left_miou']:.4f}, Right IoU: {means['right_miou']:.4f}, Overall IoU: {means['overall_miou']:.4f}")
        logger.info(f"Domain {domain} - Left DSC: {means['left_mdsc']:.4f}, Right DSC: {means['right_mdsc']:.4f}, Overall DSC: {means['overall_mdsc']:.4f}")

        domain_results_data[domain] = means

    # Overall results: first average over domains within each sub-directory,
    # then average those per-sub-directory means.
    per_sub_dir_means = [
        _mean_metrics([sub_dir_results[sub_dir][domain] for domain in domains])
        for sub_dir in sub_dirs
    ]
    final_means = _mean_metrics(per_sub_dir_means)

    print("\n" + "-"*60)
    print("FINAL RESULTS ACROSS ALL DOMAINS AND SUB-DIRECTORIES:")
    for key, label in _METRIC_LABELS:
        print(f"  {label}: {final_means[key]:.4f}")

    logger.info("FINAL RESULTS ACROSS ALL DOMAINS AND SUB-DIRECTORIES:")
    for key, label in _METRIC_LABELS:
        logger.info(f"{label}: {final_means[key]:.4f}")

    overall_end = time.time()
    total_time = overall_end - overall_start
    print(f"\nTotal time taken for split {split}: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
    logger.info(f"Total time taken for split {split}: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
    print("="*80)

    # Save results to CSV, next to the per-model JSON results under results_dir.
    csv_file = os.path.join(results_dir, "results.csv")
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    data = {
        'timestamp': timestamp,
        'model': model,
        'split': split,
        'annotation_mode': 'manual' if is_manual_annotation else 'auto',
        'num_pos_points': num_pos_points_per_tool if not is_manual_annotation else 'N/A',
        'num_neg_points': num_neg_points_per_tool if not is_manual_annotation else 'N/A',
        'tta': 'yes' if perform_tta else 'no',
        'overall_left_miou': final_means['left_miou'],
        'overall_right_miou': final_means['right_miou'],
        'overall_miou': final_means['overall_miou'],
        'overall_left_mdsc': final_means['left_mdsc'],
        'overall_right_mdsc': final_means['right_mdsc'],
        'overall_mdsc': final_means['overall_mdsc'],
        'total_time_seconds': total_time,
        'total_time_minutes': total_time/60
    }

    # Add domain-specific results
    for domain in domains:
        for key in _METRIC_KEYS:
            data[f'{domain}_{key}'] = domain_results_data[domain][key]

    # Single-row frame for this run.
    df_new = pd.DataFrame([data])

    # Read + concat (rather than appending text) so that runs with a different
    # set of domain columns still produce a consistent header.
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)
    if os.path.exists(csv_file):
        df_existing = pd.read_csv(csv_file)
        df_out = pd.concat([df_existing, df_new], ignore_index=True)
    else:
        df_out = df_new
    df_out.to_csv(csv_file, index=False)

    print(f"Results saved to CSV file: {csv_file}")
    logger.info(f"Results saved to CSV file: {csv_file}")


# Run the evaluation on the validation split only.
# The test/train invocations below stay disabled: they need `test_sub_dirs` /
# `train_sub_dirs`, which are commented out in the configuration cell.
process_split(sub_dirs, domains, "val")
# process_split(test_sub_dirs, test_domains, "test")
# process_split(train_sub_dirs, train_domains, "train")

Running inference for split: val
Processing 3 sub-directories and 5 domains


Processing sub-directories:   0%|          | 0/3 [00:00<?, ?it/s]


------------------------------------------------------------
Processing sub-directory: data/raw/SegSTRONGC_val/val/1/2





Processing domain: bg_change
Processing left camera video...
Processing video: data/raw/SegSTRONGC_val/val/1/2/bg_change/left
Domain: bg_change, Split: val, Camera: left
Loading annotations for split: val
Successfully loaded annotations for 30 items
Initializing SAM for video...



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
frame loading (JPEG): 100%|██████████| 300/300 [00:30<00:00,  9.71it/s]


Initialization took 58.92 seconds.
Found 2 objects with annotations.



[A
Processing annotations for objects: 100%|██████████| 2/2 [00:02<00:00,  1.15s/it]


Added 20 annotation points across all objects.
Starting mask propagation...



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
propagate in video:  15%|█▌        | 45/300 [06:47<38:30,  9.06s/it]
Processing domains:   0%|          | 0/5 [07:48<?, ?it/s]
Processing sub-directories:   0%|          | 0/3 [07:48<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# domains = ['bg_change', 'blood', 'low_brightness', 'regular', 'smoke']
# annotations = None
# with open('data/annotations/auto/val.json', 'r') as f:
#     annotations = json.load(f)

# path = "data/masks/sam2.1_hiera_base_plus/val/data-raw-SegSTRONGC_val-val-1-2-bg_change-right.pkl"
# with open(path, 'rb') as f:
#     mass = pickle.load(f)
#     for video_path, video_data in mass.items():
#         # print(video_path, video_data)
#         for frame_id, frame_data in video_data.items():
#             overall_mask = np.zeros((1080, 1920), dtype=bool)
#             for data in frame_data:
#                 for object_id, mask in data.items():
#                     overall_mask = np.logical_or(overall_mask, mask[0])

#             ground_truth_masks_path = video_path
#             for domain in domains:
#                 if domain in video_path:
#                     ground_truth_masks_path = video_path.replace(domain, 'ground_truth')
#                     break
#             ground_truth_masks_path = ground_truth_masks_path + "/" + str(frame_id) + ".jpg"
#             ground_truth_mask = cv2.imread(ground_truth_masks_path, cv2.IMREAD_GRAYSCALE)
#             ground_truth_mask = (ground_truth_mask > 0).astype(np.bool_)

#             original_image = cv2.imread(video_path + "/" + str(frame_id) + ".jpg")
#             original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

#             # frame_annotations = annotations[video_path.replace("/left", "").replace("/right", "")]
#             frame_annotations = annotations[video_path]

#             # place dots in the original image for each annotation
#             for object_id, object_annotations in frame_annotations.items():
#                 # print(object_annotations)
#                 for annotation in object_annotations:
#                     x = annotation['x']
#                     y = annotation['y']
#                     label = annotation['label']
#                     if object_id == "0":
#                         original_image = cv2.circle(original_image, (x, y), 10, (0, 255, 0), -1)
#                     else:
#                         original_image = cv2.circle(original_image, (x, y), 10, (255, 0, 0), -1)

#             #show the masks and the original image
#             fig, axs = plt.subplots(1, 3, figsize=(30, 15))
#             axs[0].imshow(original_image)
#             axs[0].set_title("Original Image")
#             axs[1].imshow(overall_mask)
#             axs[1].set_title("Overall Mask")
#             axs[2].imshow(ground_truth_mask)
#             axs[2].set_title("Ground Truth Mask")
#             plt.show()
#             break