In [None]:
import torch
# Report whether a CUDA GPU is visible to PyTorch. The device name is only
# queried when CUDA is available, because torch.cuda.get_device_name raises
# on machines without a CUDA device (later cells also support MPS / CPU).
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device detected")

In [None]:
# Refresh the apt package index and install wget (assumes a Debian/Ubuntu-style
# environment where apt-get is available).
# NOTE(review): no cell visible in this notebook calls wget (downloads use curl
# and urllib) - confirm this step is still needed.
!apt-get update && apt-get install -y wget

In [None]:
import os

# Locations of the EchoNet-Dynamic inputs: the folder of videos and the
# VolumeTracings CSV.
video_folder = "/EchoNet-Dynamic/Videos"
csv_path = "/EchoNet-Dynamic/VolumeTracings.csv"

# Create the folder that will hold extracted-frame outputs (a directory, not
# the CSV file). exist_ok=True makes re-running this cell harmless.
output_root = "/EchoNet-Dynamic/dataset_frames"
os.makedirs(output_root, exist_ok=True)

# Sanity check: report whether the expected inputs are present on disk.
print("Video folder exists:", os.path.exists(video_folder))
print("CSV file exists:", os.path.isfile(csv_path))


In [None]:
!mkdir -p ./checkpoints/
!curl https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt --output checkpoints.pt #change to large when need be

In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Pick the compute device in order of preference: CUDA GPU, then Apple MPS,
# then CPU. The conditional expression short-circuits exactly like an
# if/elif/else chain, so MPS is only probed when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

In [None]:
# Device-specific numeric setup. The two branches are mutually exclusive, so
# their order does not matter; nothing happens on CPU.
if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and never exit it, so it stays active
    # for the rest of the notebook session.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 matmul/cuDNN kernels when device 0 reports compute capability
    # major >= 8 (Ampere or newer, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
import os
import urllib.request

# Create the local working folders. These are relative paths, so they are
# resolved against the notebook's current working directory.
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("videos", exist_ok=True)

# Download the SAM 2.1 "tiny" checkpoint once.
# The file is first written to a temporary ".part" path and only renamed to its
# final name after the download finished, so an interrupted download can never
# be mistaken for a complete checkpoint by the existence check on a later run.
model_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt"
model_path = "checkpoints/sam2.1_hiera_tiny.pt"
if not os.path.exists(model_path):
    partial_path = model_path + ".part"
    try:
        urllib.request.urlretrieve(model_url, partial_path)
        os.replace(partial_path, model_path)
    finally:
        # Only still present when the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Downloaded model checkpoint")
else:
    print(f"Model checkpoint already present at {model_path}")



In [None]:
from sam2.build_sam import build_sam2_video_predictor

# Checkpoint and config used to build the video predictor.
sam2_checkpoint = "/MedSAM2/checkpoints/MedSAM2_latest.pt"
model_cfg = "configs/sam2.1_hiera_t512.yaml"

# Build the predictor on the device selected earlier in the notebook. Without
# an explicit `device` the builder falls back to its own default, which ignores
# the CUDA / MPS / CPU choice made above.
# NOTE(review): vos_optimized=True is kept from the original call; confirm it is
# supported on non-CUDA devices before running there.
predictor = build_sam2_video_predictor(
    config_file=model_cfg,
    ckpt_path=sam2_checkpoint,
    device=device,
    apply_postprocessing=True,
    vos_optimized=True,
)

In [None]:
import cv2
import os
from PIL import Image
import matplotlib.pyplot as plt

video_folder = "/EchoNet-Dynamic/Videos"             # input: where .avi videos are
# Absolute path, matching the dataset_frames folder used by the other cells
# (the leading slash was missing, which silently wrote frames relative to the
# current working directory instead).
output_root = "/EchoNet-Dynamic/dataset_frames"      # output: where frames will go
os.makedirs(output_root, exist_ok=True)

# Sorted list of .avi files, so "the first video" is deterministic.
avi_files = sorted(f for f in os.listdir(video_folder) if f.endswith(".avi"))
if not avi_files:
    raise FileNotFoundError(f"No .avi files found in '{video_folder}'")

# Process only the first video
video_file = avi_files[0]
video_path = os.path.join(video_folder, video_file)
video_name = os.path.splitext(video_file)[0]

# Frames of this video are written to <output_root>/<video name without extension>/
video_dir = os.path.join(output_root, video_name)
os.makedirs(video_dir, exist_ok=True)

# Extract every frame as "<index>.jpg" (index starts at 0, no zero padding).
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f"Could not open video '{video_path}'")

frame_idx = 0
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_path = os.path.join(video_dir, f"{frame_idx}.jpg")
        cv2.imwrite(frame_path, frame)
        frame_idx += 1
finally:
    # Release the capture even if writing a frame fails.
    cap.release()
print(f"Extracted {frame_idx} frames from '{video_file}' into '{video_dir}'")

# List the JPEG frames in video_dir, ordered by their numeric file name.
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
if not frame_names:
    raise RuntimeError(f"No frames available in '{video_dir}'")

# Show the first frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
plt.axis('off')
plt.show()


In [None]:
# Create the predictor's inference state for the frames stored in `video_dir`
# (the folder of extracted .jpg frames), then reset that state.
# NOTE(review): what init_state loads and what reset_state clears is defined by
# the predictor implementation, which is not visible here - confirm there.
inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)

In [None]:
"""EchoNet-Dynamic Dataset.""" #Code from this cell obtained from https://github.com/echonet/dynamic/blob/master/echonet/datasets/echo.py
import os
import collections
import pandas

import numpy as np
import skimage.draw
import torchvision
import echonet


# NOTE(review): later cells in this notebook run `from echonet.datasets import Echo`,
# which rebinds the name `Echo` and therefore shadows this definition.
class Echo(torchvision.datasets.VisionDataset):
    """EchoNet-Dynamic Dataset.

    Args:
        root (string): Root directory of dataset. If ``None'', the hard-coded
            path assigned at the top of ``__init__'' is used.
        split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''}
        target_type (string or list, optional): Type of target to use,
            ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'',
            ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'',
            or ``SmallTrace''
            Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``Filename'' (string): filename of video
                ``EF'' (float): ejection fraction
                ``EDV'' (float): end-diastolic volume
                ``ESV'' (float): end-systolic volume
                ``LargeIndex'' (int): index of large (diastolic) frame in video
                ``SmallIndex'' (int): index of small (systolic) frame in video
                ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame
                ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame
                ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation
                    value of 0 indicates pixel is outside left ventricle
                             1 indicates pixel is inside left ventricle
                ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation
                    value of 0 indicates pixel is outside left ventricle
                             1 indicates pixel is inside left ventricle
            Defaults to ``EF''.
        mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel.
            Used for normalizing the video. Defaults to 0 (video is not shifted).
        std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel.
            Used for normalizing the video. Defaults to 1 (video is not scaled).
        length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned.
            Defaults to 16.
        period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken)
            Defaults to 2.
        max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively
            long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video.
            Defaults to 250.
        clips (int or ``all'', optional): Number of random clips to sample, or ``all'' to take every possible clip.
            Main use is for test-time augmentation with random clips.
            Defaults to 1.
        pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation).
            and a window of the original size is taken. If ``None'', no padding occurs.
            Defaults to ``None''.
        noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added.
            Defaults to ``None''.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        external_test_location (string): Path to videos to use for external testing.

    Raises:
        FileNotFoundError: if a video listed in ``FileList.csv'' is missing from ``<root>/Videos''.
    """

    def __init__(self, root=None,
                 split="train", target_type="EF",
                 mean=0., std=1.,
                 length=16, period=2,
                 max_length=250,
                 clips=1,
                 pad=None,
                 noise=None,
                 target_transform=None,
                 external_test_location=None):
        if root is None:
            root = "/shared_data/p_vidalr/iraj/EchoNet-Dynamic"

        super().__init__(root, target_transform=target_transform)

        self.split = split.upper()
        if not isinstance(target_type, list):
            target_type = [target_type]
        self.target_type = target_type
        self.mean = mean
        self.std = std
        self.length = length
        self.max_length = max_length
        self.period = period
        self.clips = clips
        self.pad = pad
        self.noise = noise
        self.target_transform = target_transform
        self.external_test_location = external_test_location

        self.fnames, self.outcome = [], []

        if self.split == "EXTERNAL_TEST":
            self.fnames = sorted(os.listdir(self.external_test_location))
        else:
            # Load video-level labels
            with open(os.path.join(self.root, "FileList.csv")) as f:
                data = pandas.read_csv(f)
            # Normalise the split labels so the comparison with self.split
            # (already upper-cased) is case-insensitive. The result must be
            # assigned back; Series.map does not modify the column in place.
            data["Split"] = data["Split"].map(lambda x: x.upper())

            if self.split != "ALL":
                data = data[data["Split"] == self.split]

            self.header = data.columns.tolist()
            # Assume .avi if a name has no suffix; names that already carry a
            # suffix are kept unchanged so fnames stays aligned with outcome.
            self.fnames = [
                fn if os.path.splitext(fn)[1] else fn + ".avi"
                for fn in data["FileName"].tolist()
            ]
            self.outcome = data.values.tolist()

            # Check that files are present
            missing = set(self.fnames) - set(os.listdir(os.path.join(self.root, "Videos")))
            if len(missing) != 0:
                print("{} videos could not be found in {}:".format(len(missing), os.path.join(self.root, "Videos")))
                for f in sorted(missing):
                    print("\t", f)
                raise FileNotFoundError(os.path.join(self.root, "Videos", sorted(missing)[0]))

            # Load tracings:
            #   self.frames[filename]       -> traced frame numbers, in file order
            #   self.trace[filename][frame] -> array of (x1, y1, x2, y2) rows
            self.frames = collections.defaultdict(list)
            self.trace = collections.defaultdict(_defaultdict_of_lists)

            with open(os.path.join(self.root, "VolumeTracings.csv")) as f:
                header = f.readline().strip().split(",")
                assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]

                for line in f:
                    filename, x1, y1, x2, y2, frame = line.strip().split(',')
                    x1 = float(x1)
                    y1 = float(y1)
                    x2 = float(x2)
                    y2 = float(y2)
                    frame = int(frame)
                    # Record each frame number once, the first time it is seen
                    if frame not in self.trace[filename]:
                        self.frames[filename].append(frame)
                    self.trace[filename][frame].append((x1, y1, x2, y2))
            for filename in self.frames:
                for frame in self.frames[filename]:
                    self.trace[filename][frame] = np.array(self.trace[filename][frame])

            # A small number of videos are missing traces; remove these videos
            keep = [len(self.frames[f]) >= 2 for f in self.fnames]
            self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
            self.outcome = [f for (f, k) in zip(self.outcome, keep) if k]

    def __getitem__(self, index):
        """Return ``(video, target)`` for the video at position ``index``.

        ``video`` is the sampled clip (or a stack of clips when ``clips'' is
        not 1); ``target`` is a single value, or a tuple when several target
        types were requested.
        """
        # Find filename of video
        if self.split == "EXTERNAL_TEST":
            video = os.path.join(self.external_test_location, self.fnames[index])
        elif self.split == "CLINICAL_TEST":
            video = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index])
        else:
            video = os.path.join(self.root, "Videos", self.fnames[index])

        # Load video into np.array
        video = echonet.utils.loadvideo(video).astype(np.float32)

        # Add simulated noise (black out random pixels)
        # 0 represents black at this point (video has not been normalized yet)
        if self.noise is not None:
            n = video.shape[1] * video.shape[2] * video.shape[3]
            ind = np.random.choice(n, round(self.noise * n), replace=False)
            f = ind % video.shape[1]
            ind //= video.shape[1]
            i = ind % video.shape[2]
            ind //= video.shape[2]
            j = ind
            video[:, f, i, j] = 0

        # Apply normalization
        if isinstance(self.mean, (float, int)):
            video -= self.mean
        else:
            video -= self.mean.reshape(3, 1, 1, 1)

        if isinstance(self.std, (float, int)):
            video /= self.std
        else:
            video /= self.std.reshape(3, 1, 1, 1)

        # Set number of frames
        c, f, h, w = video.shape
        if self.length is None:
            # Take as many frames as possible
            length = f // self.period
        else:
            # Take specified number of frames
            length = self.length

        if self.max_length is not None:
            # Shorten videos to max_length
            length = min(length, self.max_length)

        if f < length * self.period:
            # Pad video with frames filled with zeros if too short
            # 0 represents the mean color (dark grey), since this is after normalization
            video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
            c, f, h, w = video.shape  # pylint: disable=E0633

        if self.clips == "all":
            # Take all possible clips of desired length
            start = np.arange(f - (length - 1) * self.period)
        else:
            # Take random clips from video
            start = np.random.choice(f - (length - 1) * self.period, self.clips)

        # Gather targets
        target = []
        for t in self.target_type:
            key = self.fnames[index]
            if t == "Filename":
                target.append(self.fnames[index])
            elif t == "LargeIndex":
                # Traces are sorted by cross-sectional area
                # Largest (diastolic) frame is last
                # (built-in int: the np.int alias no longer exists in current NumPy)
                target.append(int(self.frames[key][-1]))
            elif t == "SmallIndex":
                # Smallest (systolic) frame is first
                target.append(int(self.frames[key][0]))
            elif t == "LargeFrame":
                target.append(video[:, self.frames[key][-1], :, :])
            elif t == "SmallFrame":
                target.append(video[:, self.frames[key][0], :, :])
            elif t in ["LargeTrace", "SmallTrace"]:
                if t == "LargeTrace":
                    t = self.trace[key][self.frames[key][-1]]
                else:
                    t = self.trace[key][self.frames[key][0]]
                x1, y1, x2, y2 = t[:, 0], t[:, 1], t[:, 2], t[:, 3]
                # Build a closed outline: down the (x1, y1) side and back up the
                # (x2, y2) side; the first row of the trace is skipped.
                x = np.concatenate((x1[1:], np.flip(x2[1:])))
                y = np.concatenate((y1[1:], np.flip(y2[1:])))

                r, c = skimage.draw.polygon(np.rint(y).astype(int), np.rint(x).astype(int), (video.shape[2], video.shape[3]))

                mask = np.zeros((video.shape[2], video.shape[3]), np.float32)
                mask[r, c] = 1
                target.append(mask)
            else:
                if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST":
                    target.append(np.float32(0))
                else:
                    target.append(np.float32(self.outcome[index][self.header.index(t)]))

        if target != []:
            target = tuple(target) if len(target) > 1 else target[0]
            if self.target_transform is not None:
                target = self.target_transform(target)

        # Select clips from video
        video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
        if self.clips == 1:
            video = video[0]
        else:
            video = np.stack(video)

        if self.pad is not None:
            # Add padding of zeros (mean color of videos)
            # Crop of original size is taken out
            # (Used as augmentation)
            c, l, h, w = video.shape
            temp = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype)
            temp[:, :, self.pad:-self.pad, self.pad:-self.pad] = video  # pylint: disable=E1130
            i, j = np.random.randint(0, 2 * self.pad, 2)
            video = temp[:, :, i:(i + h), j:(j + w)]

        return video, target

    def __len__(self):
        """Number of videos kept for this split."""
        return len(self.fnames)

    def extra_repr(self) -> str:
        """Additional information to add at end of __repr__."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)


def _defaultdict_of_lists():
    """Returns a defaultdict of lists.

    This is used to avoid issues with Windows (if this function is anonymous,
    the Echo dataset cannot be used in a dataloader).
    """

    return collections.defaultdict(list)

In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from echonet.datasets import Echo

#manual_mask means ground truth mask
def save_manual_masks(filename, dataset_root="/EchoNet-Dynamic",
                      output_dir="/EchoNet-Dynamic/manual_masks",
                      show_plot=False, dataset=None):
    """
    Save ED and ES binary masks for a given EchoNet-Dynamic video.

    The masks come from the ``LargeTrace`` (ED) and ``SmallTrace`` (ES) targets
    of the ``Echo`` dataset and are written as 0/255 PNG images named
    ``frame_<index:03d>_manual_mask.png`` inside ``<output_dir>/<filename>/``.

    Parameters
    ----------
    filename : str
        Name of the video file (e.g. "0XF0B7D7CD42C001E.avi")
    dataset_root : str
        Path to the EchoNet-Dynamic dataset root
    output_dir : str
        Directory where masks will be saved (one subfolder per video)
    show_plot : bool
        If True, displays ED/ES overlays
    dataset : Echo or None
        Optional, already constructed ``Echo`` dataset to reuse. When None
        (default) the test split is loaded from ``dataset_root``; passing a
        dataset avoids re-reading the CSV files on every call in a batch run.
        It must yield the targets in the order
        LargeFrame, LargeTrace, SmallFrame, SmallTrace.

    Raises
    ------
    ValueError
        If ``filename`` is not in the dataset's file list.
    """

    # --- Load dataset ---
    if dataset is None:
        dataset = Echo(
            root=dataset_root,
            split="test",
            target_type=["LargeFrame", "LargeTrace", "SmallFrame", "SmallTrace"],
            length=None,
            clips=1
        )

    try:
        index = dataset.fnames.index(filename)
    except ValueError:
        raise ValueError(
            f"'{filename}' was not found in the dataset's file list"
        ) from None
    # The clip itself is not needed here, only the frame/trace targets.
    _, (large_frame, large_mask, small_frame, small_mask) = dataset[index]

    # --- Get ED/ES frame indices ---
    # The dataset treats the first traced frame as the small (ES) one and the
    # last traced frame as the large (ED) one.
    frame_list = dataset.frames[filename]
    es_index = int(frame_list[0])   # ES
    ed_index = int(frame_list[-1])  # ED

    ed_mask = (large_mask > 0).astype(np.uint8)
    es_mask = (small_mask > 0).astype(np.uint8)

    video_out_dir = os.path.join(output_dir, filename)
    os.makedirs(video_out_dir, exist_ok=True)

    ed_path = os.path.join(video_out_dir, f"frame_{ed_index:03d}_manual_mask.png")
    es_path = os.path.join(video_out_dir, f"frame_{es_index:03d}_manual_mask.png")

    # Masks are 0/1 uint8; scale to 0/255 so they are visible as images.
    Image.fromarray(ed_mask * 255).save(ed_path)
    Image.fromarray(es_mask * 255).save(es_path)

    print(f"[Saved] ED → {ed_path}\n        ES → {es_path}")

    # --- Plot if requested ---
    if show_plot:
        def prep_frame(frame_chw):
            # Channels-first -> channels-last, then min-max scale for display
            f = np.transpose(frame_chw, (1, 2, 0))
            f = (f - f.min()) / (f.max() - f.min() + 1e-8)
            return f

        fig = plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(prep_frame(large_frame), cmap="gray")
        plt.imshow(ed_mask, cmap="Reds", alpha=0.4)
        plt.title(f"ED overlay (frame {ed_index})"); plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(prep_frame(small_frame), cmap="gray")
        plt.imshow(es_mask, cmap="Blues", alpha=0.4)
        plt.title(f"ES overlay (frame {es_index})"); plt.axis("off")

        plt.tight_layout(); plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


# Forward Propagation

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

def dice_score(mask1, mask2):
    """Dice coefficient of two masks, each binarised as ``mask > 0``.

    Returns 1.0 when both masks are empty, otherwise
    2 * |A and B| / (|A| + |B|).
    """
    fg_a = (mask1 > 0).astype(np.bool_)
    fg_b = (mask2 > 0).astype(np.bool_)
    overlap = np.logical_and(fg_a, fg_b).sum()
    denom = fg_a.sum() + fg_b.sum()
    if denom == 0:
        return 1.0
    return 2.0 * overlap / denom
def show_box(box, ax):
    """Draw ``box`` (x_min, y_min, x_max, y_max) on ``ax`` as an unfilled green rectangle."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay ``mask`` on ``ax`` as an RGBA image with alpha 0.6.

    The colour is random when ``random_color`` is true; otherwise it is entry
    ``obj_id`` of the "tab10" colormap (entry 0 when ``obj_id`` is None).
    The last two dimensions of ``mask`` are used as height and width.
    """
    if not random_color:
        palette = plt.get_cmap("tab10")
        slot = obj_id if obj_id is not None else 0
        rgba = np.array([*palette(slot)[:3], 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def _load_yolo_box(label_path, padded_size, original_size, shrink=3):
    """Read the first YOLO detection in ``label_path`` and map it to original-frame pixels.

    The label is expected to be ``class x_center y_center width height`` with
    coordinates normalised to a ``padded_size`` square that contains the
    ``original_size`` frame centred inside it. Returns a float32 array
    ``[xmin, ymin, xmax, ymax]`` shrunk by ``shrink`` pixels on every side, or
    None when the first line does not hold five numeric fields.
    """
    with open(label_path, "r") as f:
        parts = f.readline().split()
    if len(parts) != 5:
        return None
    try:
        _, x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts)
    except ValueError:
        return None

    pad = (padded_size - original_size) // 2

    # Centre in padded pixels, shifted back by the padding offset; width and
    # height are unaffected by padding because padding does not rescale.
    x_center = x_center_norm * padded_size - pad
    y_center = y_center_norm * padded_size - pad
    half_w = width_norm * padded_size / 2
    half_h = height_norm * padded_size / 2

    return np.array(
        [x_center - half_w + shrink, y_center - half_h + shrink,
         x_center + half_w - shrink, y_center + half_h - shrink],
        dtype=np.float32,
    )


def _extract_video_frames(video_path, frames_dir):
    """Write every frame of ``video_path`` to ``frames_dir`` as ``<index>.jpg``; return the frame count."""
    cap = cv2.VideoCapture(video_path)
    count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imwrite(os.path.join(frames_dir, f"{count}.jpg"), frame)
            count += 1
    finally:
        cap.release()
    return count


def evaluate_medsam2_on_video(video_filename, predictor,
                               tracings_csv="/EchoNet-Dynamic/VolumeTracings.csv",
                               video_dir="/EchoNet-Dynamic/Videos",
                               frames_root="/EchoNet-Dynamic/dataset_frames",
                               masks_root="/EchoNet-Dynamic/where you want to save the masks from this experiment",
                               gt_mask_root="/EchoNet-Dynamic/manual_masks",
                               yolo_labels_dir="/runs/detect/predict/labels",
                               image_size=128):
    """Prompt MedSAM2 with a YOLO box on the first annotated frame and score the second.

    Steps: read the two annotated frame numbers of the video from
    ``tracings_csv``; convert the YOLO detection for the first annotated frame
    into a box prompt; extract the video frames if needed; add the box to the
    predictor and save the mask of every frame yielded by
    ``predictor.propagate_in_video``; compare the mask at the second annotated
    frame with the ground-truth mask using the Dice score; plot the results.

    Relies on ``save_manual_masks``, ``dice_score``, ``show_box`` and
    ``show_mask`` being defined in the notebook. Frames are assumed to be
    112x112 pixels (``original_size``).

    Parameters
    ----------
    video_filename : str
        Video file name as listed in the tracings CSV (e.g. "XXXX.avi").
    predictor :
        Video predictor offering ``init_state``, ``reset_state``,
        ``add_new_points_or_box`` and ``propagate_in_video``.
    tracings_csv : str
        Path to VolumeTracings.csv.
    video_dir : str
        Folder containing the .avi videos.
    frames_root : str
        Root folder of extracted frames; this video's frames live in
        ``<frames_root>/<video id>.avi/`` and are extracted on demand when
        that folder is empty.
    masks_root : str
        Root folder for predicted masks (placeholder default - set your own).
    gt_mask_root : str
        Root folder where ground-truth masks are stored / generated.
    yolo_labels_dir : str
        Folder with YOLO label files named ``<video id>_frame<idx>.txt``
        (produced by the Sono_YOLO notebook).
    image_size : int
        Side length of the padded square image the YOLO labels refer to.

    Returns
    -------
    float or None
        Dice score at the second annotated frame, or None when the video could
        not be evaluated (fewer than two annotated frames, missing or invalid
        YOLO label, no frames, or a missing mask file).
    """
    original_size = 112
    padded_size = image_size

    print(f"\n========== Processing Video: {video_filename} ==========")

    video_id = os.path.splitext(video_filename)[0]
    base_name = video_id
    video_path = os.path.join(video_dir, video_filename)
    frames_dir = os.path.join(frames_root, video_id + '.avi')
    masks_dir = os.path.join(masks_root, base_name)

    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    print("[INFO] Loading VolumeTracings.csv...")
    df = pd.read_csv(tracings_csv)
    df_video = df[df['FileName'] == video_filename]
    unique_frames = sorted(df_video['Frame'].unique())

    if len(unique_frames) < 2:
        print(f"[WARNING] Not enough annotated frames in {base_name}")
        return

    # Plain Python ints (not NumPy scalars) for file names and predictor calls
    first_idx, second_idx = int(unique_frames[0]), int(unique_frames[1])
    print(f"[INFO] First annotated frame: {first_idx}, Second annotated frame: {second_idx}")

    # =========== Unpad YOLO Prediction to Original Frame Size ============
    label_filename = f"{video_id}_frame{first_idx}.txt"
    label_path = os.path.join(yolo_labels_dir, label_filename)

    if not os.path.exists(label_path):
        print(f"[ERROR] YOLO label file not found: {label_path}")
        return

    box = _load_yolo_box(label_path, padded_size, original_size)
    if box is None:
        print(f"[ERROR] Invalid YOLO label format in {label_path}")
        return

    print(f"[INFO] Unpadded YOLO bbox for MedSAM2: {box}")
    # =====================================================================

    # Step 1: Extract frames
    if len(os.listdir(frames_dir)) == 0:
        print(f"[INFO] Extracting frames from {video_path}...")
        n_frames = _extract_video_frames(video_path, frames_dir)
        print(f"[INFO] Extracted {n_frames} frames.")
        if n_frames == 0:
            print(f"[ERROR] No frames could be read from {video_path}")
            return
    else:
        print(f"[INFO] Frames already present in {frames_dir}")

    # Step 2: Initialize and run MedSAM2
    print(f"[INFO] Initializing MedSAM2 predictor...")
    inference_state = predictor.init_state(video_path=frames_dir)
    predictor.reset_state(inference_state)

    print(f"[INFO] Running MedSAM2 at frame {first_idx}...")
    # Keep the prompt-frame outputs under their own names: the propagation loop
    # below produces per-frame outputs that would otherwise overwrite them.
    _, prompt_obj_ids, prompt_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=first_idx,
        obj_id=1,
        box=box,
    )

    # Step 3: Propagate masks
    print("[INFO] Propagating masks through the video...")
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        for i, out_obj_id in enumerate(out_obj_ids):
            # Logits > 0 are foreground; store as a 0/255 image
            mask = (out_mask_logits[i] > 0).cpu().numpy().squeeze().astype(np.uint8) * 255
            mask_path = os.path.join(masks_dir, f"frame_{out_frame_idx}_sam_mask.png")
            Image.fromarray(mask).save(mask_path)
            if out_frame_idx % 30 == 0:
                print(f"[DEBUG] Saved MedSAM2 mask at frame {out_frame_idx}")

    # Step 4: Make sure the ground truth mask for the second annotated frame
    # exists under gt_mask_root; it is only generated when missing, because
    # save_manual_masks has to load the dataset to do so.
    gt_mask_path = os.path.join(gt_mask_root, video_filename, f"frame_{second_idx:03d}_manual_mask.png")
    if not os.path.exists(gt_mask_path):
        print(f"[INFO] Generating ground truth mask(s) for {video_filename}...")
        save_manual_masks(video_filename, output_dir=gt_mask_root)

    if not os.path.exists(gt_mask_path):
        print(f"[ERROR] Ground truth mask not found at {gt_mask_path}")
        return
    gt_mask = np.array(Image.open(gt_mask_path).convert('L'))
    print(f"[INFO] Loaded ground truth mask from {gt_mask_path}")

    # Step 5: Load MedSAM2 mask
    pred_mask_path = os.path.join(masks_dir, f"frame_{second_idx}_sam_mask.png")
    if not os.path.exists(pred_mask_path):
        print(f"[ERROR] MedSAM2 propagated mask not found at {pred_mask_path}")
        return

    pred_mask = np.array(Image.open(pred_mask_path).convert('L'))
    print(f"[INFO] Loaded MedSAM2 mask from {pred_mask_path}")

    # Step 6: Compute Dice Score
    dice = dice_score(pred_mask, gt_mask)
    print(f"[RESULT] Dice Score at frame {second_idx}: {dice:.4f}")

    # Step 7: Visualize ground truth vs prediction on the second annotated
    # frame. Frames are stored as "<index>.jpg" without zero padding.
    frame_img = np.array(Image.open(os.path.join(frames_dir, f"{second_idx}.jpg")))
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    axs[0].imshow(frame_img)
    axs[0].imshow(gt_mask, alpha=0.4, cmap="Reds")
    axs[0].set_title("Ground Truth")
    axs[0].axis("off")

    axs[1].imshow(frame_img)
    axs[1].imshow(pred_mask, alpha=0.4, cmap="Greens")
    axs[1].set_title("SonoYolo + MedSAM2 Prediction")
    axs[1].axis("off")

    plt.suptitle(f"Dice Score: {dice:.4f}", fontsize=14)
    plt.tight_layout()
    plt.show()
    plt.close(fig)

    # Step 8: Show the prompt itself - the box and the mask returned for the
    # first annotated frame, drawn on that same frame.
    prompt_img = np.array(Image.open(os.path.join(frames_dir, f"{first_idx}.jpg")))
    prompt_fig = plt.figure(figsize=(9, 6))
    plt.title(f"{video_filename} - Frame {first_idx}")
    plt.imshow(prompt_img)
    show_box(box, plt.gca())
    show_mask((prompt_mask_logits[0] > 0).cpu().numpy(), plt.gca(), obj_id=prompt_obj_ids[0])
    plt.axis('off')
    plt.show()
    plt.close(prompt_fig)

    print(f"[DONE] Finished processing {video_filename}")
    return dice


In [None]:
# Test Split Batch Evaluation

from echonet.datasets import Echo
import pandas as pd
import numpy as np
import traceback
import time


# 1) Load the EchoNet test split; only its list of video file names is used
#    below, so the cheapest target type ("Filename") is requested.
dataset = Echo(
    root="/EchoNet-Dynamic",
    split="test",
    target_type="Filename",   # we only need filenames here
    length=None,
    clips=1
)

# File names kept by the dataset for this split (with the ".avi" suffix).
test_fnames = dataset.fnames   # list of 'XXXXXXXXXXXX.avi'

print(f"[INFO] Found {len(test_fnames)} test videos.")

# 2) Optional: suppress the per-video plt.show() to avoid popping windows / slowing down
import contextlib, matplotlib.pyplot as plt
@contextlib.contextmanager
def suppress_plots():
    """Temporarily replace ``plt.show`` with a no-op.

    The original ``plt.show`` is put back when the ``with`` block ends, even
    if the body raises.
    """
    saved_show = plt.show

    def _noop(*args, **kwargs):
        return None

    plt.show = _noop
    try:
        yield
    finally:
        plt.show = saved_show

# 3) Run evaluation for each test video
import os  # local to this cell: needed to create the results folder below

results = []
start = time.time()

with suppress_plots():   # remove this context manager if you DO want the plots
    for i, fname in enumerate(test_fnames, 1):
        try:
            print(f"\n=== [{i}/{len(test_fnames)}] {fname} ===")
            dice = evaluate_medsam2_on_video(
                video_filename=fname,
                predictor=predictor,
                # keep your defaults, or override here if needed:
                tracings_csv="/EchoNet-Dynamic/VolumeTracings.csv",
                video_dir="/EchoNet-Dynamic/Videos",
                frames_root="/EchoNet-Dynamic/dataset_frames",
                masks_root="/EchoNet-Dynamic/yournameforthemasks", #you must modify this
                gt_mask_root="/EchoNet-Dynamic/manual_masks", #your ground truth masks
            )
            # A None result means the video was skipped; record it as NaN
            results.append({"filename": fname, "dice": float(dice) if dice is not None else np.nan})
        except Exception as e:
            # Best effort: log the failure and continue with the next video
            print(f"[ERROR] {fname}: {e}")
            traceback.print_exc()
            results.append({"filename": fname, "dice": np.nan})
        finally:
            # plt.show() is suppressed here, so figures created during the
            # evaluation would otherwise stay open and pile up in memory.
            plt.close("all")

elapsed = time.time() - start
print(f"\n[SUMMARY] Finished {len(test_fnames)} videos in {elapsed/60:.1f} min.")

# 4) Save CSV + print mean Dice
# Explicit columns keep this working when no video was processed at all.
df_res = pd.DataFrame(results, columns=["filename", "dice"]).sort_values("filename")
csv_path = "/home/iraj/EchoNet-Dynamic/_____.csv" #edit the csv path
# Make sure the target folder exists so a long run is not lost at the very end.
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
df_res.to_csv(csv_path, index=False)
print(f"[SUMMARY] Results saved to: {csv_path}")
print(f"[SUMMARY] Videos without a Dice score: {int(df_res['dice'].isna().sum())}")
print(f"[SUMMARY] Mean Dice (ignoring NaN): {df_res['dice'].mean(skipna=True):.4f}")


# Backward Propagation

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

def dice_score(mask1, mask2):
    """Dice coefficient of two masks, each binarised as ``mask > 0``.

    Returns 1.0 when both masks are empty, otherwise
    2 * |A and B| / (|A| + |B|).
    """
    fg_a = (mask1 > 0).astype(np.bool_)
    fg_b = (mask2 > 0).astype(np.bool_)
    overlap = np.logical_and(fg_a, fg_b).sum()
    denom = fg_a.sum() + fg_b.sum()
    if denom == 0:
        return 1.0
    return 2.0 * overlap / denom
def show_box(box, ax):
    """Draw ``box`` (x_min, y_min, x_max, y_max) on ``ax`` as an unfilled green rectangle."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay ``mask`` on ``ax`` as an RGBA image with alpha 0.6.

    The colour is random when ``random_color`` is true; otherwise it is entry
    ``obj_id`` of the "tab10" colormap (entry 0 when ``obj_id`` is None).
    The last two dimensions of ``mask`` are used as height and width.
    """
    if not random_color:
        palette = plt.get_cmap("tab10")
        slot = obj_id if obj_id is not None else 0
        rgba = np.array([*palette(slot)[:3], 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def evaluate_medsam2_on_video(video_filename, predictor,
                               tracings_csv="/EchoNet-Dynamic/VolumeTracings.csv",
                               video_dir="/EchoNet-Dynamic/Videos",
                               frames_root="/EchoNet-Dynamic/dataset_frames_new",
                               masks_root="/EchoNet-Dynamic/yolomodmedsam2/backward_masks",
                               gt_mask_root="/EchoNet-Dynamic/manual_masks",
                               yolo_labels_dir="/runs/detect/predict/labels",
                               image_size=128):
    """Backward-propagation evaluation of MedSAM2 on a single video.

    The YOLO box predicted for the second annotated frame seeds the predictor,
    masks are propagated in reverse, and the mask saved for the first annotated
    frame is scored (Dice) against the manual mask of that frame.

    Args:
        video_filename: Video file name (e.g. "<id>.avi"); matched against the
            "FileName" column of `tracings_csv`.
        predictor: Video predictor providing `init_state`, `reset_state`,
            `add_new_points_or_box` and `propagate_in_video`.
        tracings_csv: Path to VolumeTracings.csv.
        video_dir: Directory containing the video file.
        frames_root: Root for extracted frames; this video's frames live in
            "<frames_root>/<video_id>.avi/".
        masks_root: Root for predicted masks; written to "<masks_root>/<video_id>/".
        gt_mask_root: Root of manual masks; read from "<gt_mask_root>/<video_filename>/".
        yolo_labels_dir: Directory with YOLO labels named
            "<video_id>_frame<frame>.txt".
        image_size: Side length of the padded square image YOLO was run on; the
            original frame is taken to be 112x112.

    Returns:
        The Dice score at the first annotated frame, or None when the video is
        skipped (fewer than two annotated frames, missing/invalid YOLO label,
        missing ground-truth mask, or missing propagated mask).
    """
    original_size = 112
    padded_size = image_size
    pad = (padded_size - original_size) // 2
    obj_id = 1

    print(f"\n========== Processing Video: {video_filename} ==========")

    video_id = os.path.splitext(video_filename)[0]
    base_name = video_id
    j_name = video_id + '.avi'
    video_path = os.path.join(video_dir, video_filename)
    frames_dir = os.path.join(frames_root, j_name)
    masks_dir = os.path.join(masks_root, base_name)

    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    print("[INFO] Loading VolumeTracings.csv...")
    df = pd.read_csv(tracings_csv)
    df_video = df[df['FileName'] == video_filename]
    unique_frames = sorted(df_video['Frame'].unique())

    if len(unique_frames) < 2:
        print(f"[WARNING] Not enough annotated frames in {base_name}")
        return

    # Plain ints (not numpy scalars) for file names and predictor frame indices.
    first_idx, second_idx = int(unique_frames[0]), int(unique_frames[1])
    print(f"[INFO] First annotated frame: {first_idx}, Second annotated frame: {second_idx}")

    # =========== Unpad YOLO Prediction to Original Frame Size ============
    label_filename = f"{video_id}_frame{second_idx}.txt"
    label_path = os.path.join(yolo_labels_dir, label_filename)

    if not os.path.exists(label_path):
        print(f"[ERROR] YOLO label file not found: {label_path}")
        return

    with open(label_path, "r") as f:
        # Only the first detection in the label file is used.
        line = f.readline().strip()
        parts = line.split()
        if len(parts) != 5:
            print(f"[ERROR] Invalid YOLO label format in {label_path}")
            return
        _, x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts)

    # Convert normalized YOLO values to padded-image pixel coordinates
    x_center_128 = x_center_norm * padded_size
    y_center_128 = y_center_norm * padded_size
    width_128 = width_norm * padded_size
    height_128 = height_norm * padded_size

    # Shift center back by -pad to align with 112x112 original frame
    x_center_112 = x_center_128 - pad
    y_center_112 = y_center_128 - pad

    # Final box in original 112x112 space, shrunk by 3 px on every side
    xmin = x_center_112 - width_128 / 2
    ymin = y_center_112 - height_128 / 2
    xmax = x_center_112 + width_128 / 2
    ymax = y_center_112 + height_128 / 2
    box = np.array([xmin+3, ymin+3, xmax-3, ymax-3], dtype=np.float32)

    print(f"[INFO] Unpadded YOLO bbox for MedSAM2: {box}")
    # =====================================================================

    # Step 1: Extract frames
    if len(os.listdir(frames_dir)) == 0:
        print(f"[INFO] Extracting frames from {video_path}...")
        cap = cv2.VideoCapture(video_path)
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Zero-padded names match what the visualization step below (and the
            # frame-wise cell sharing this frames_root) reads back.
            frame_path = os.path.join(frames_dir, f"{idx:03d}.jpg")
            cv2.imwrite(frame_path, frame)
            idx += 1
        cap.release()
        print(f"[INFO] Extracted {idx} frames.")
    else:
        print(f"[INFO] Frames already present in {frames_dir}")

    # Step 2: Initialize and run MedSAM2
    print(f"[INFO] Initializing MedSAM2 predictor...")
    inference_state = predictor.init_state(video_path=frames_dir)
    predictor.reset_state(inference_state)

    print(f"[INFO] Running MedSAM2 at frame {second_idx}...")
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=second_idx,
        obj_id=obj_id,
        box=box,
    )

    # Step 3: Propagate masks
    print("[INFO] Propagating masks backward...")
    # unique_frames is sorted and unique, so second_idx > first_idx here.
    steps = second_idx - first_idx + 1
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state,
            start_frame_idx=second_idx,                 # start where we seeded
            max_frame_num_to_track=steps,               # just enough to reach first_idx
            reverse=True):                              #go backward
        for i, out_obj_id in enumerate(out_obj_ids):
            mask = (out_mask_logits[i] > 0).cpu().numpy().squeeze().astype(np.uint8) * 255
            mask_path = os.path.join(masks_dir, f"frame_{out_frame_idx}_sam_mask.png")
            Image.fromarray(mask).save(mask_path)
            if out_frame_idx % 30 == 0:
                print(f"[DEBUG] Saved MedSAM2 mask at frame {out_frame_idx}")

    # Step 4: Generate ground truth mask(s) for this video
    # NOTE(review): save_manual_masks is defined in another cell; presumably it
    # writes the manual masks under gt_mask_root - confirm.
    print(f"[INFO] Generating ground truth mask(s) for {video_filename}...")
    save_manual_masks(video_filename)

    # Load ground truth mask for the FIRST annotated frame (the propagation target)
    gt_mask_path = os.path.join(gt_mask_root, video_filename, f"frame_{first_idx:03d}_manual_mask.png")
    if not os.path.exists(gt_mask_path):
        print(f"[ERROR] Ground truth mask not found at {gt_mask_path}")
        return
    gt_mask = np.array(Image.open(gt_mask_path).convert('L'))
    print(f"[INFO] Loaded ground truth mask from {gt_mask_path}")

    # Step 5: Load MedSAM2 mask
    pred_mask_path = os.path.join(masks_dir, f"frame_{first_idx}_sam_mask.png")
    if not os.path.exists(pred_mask_path):
        print(f"[ERROR] MedSAM2 propagated mask not found at {pred_mask_path}")
        return

    pred_mask = np.array(Image.open(pred_mask_path).convert('L'))
    print(f"[INFO] Loaded MedSAM2 mask from {pred_mask_path}")

    # Step 6: Compute Dice Score
    dice = dice_score(pred_mask, gt_mask)
    print(f"[RESULT] Dice Score at frame {first_idx}: {dice:.4f}")

    # Step 7: Visualize (best-effort: a plotting failure must not discard the Dice)
    try:
        frame_img_path = os.path.join(frames_dir, f"{first_idx:03d}.jpg")
        if not os.path.exists(frame_img_path):
            # Frames extracted by earlier runs of this cell were not zero-padded.
            frame_img_path = os.path.join(frames_dir, f"{first_idx}.jpg")
        frame_img = np.array(Image.open(frame_img_path))

        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        axs[0].imshow(frame_img)
        axs[0].imshow(gt_mask, alpha=0.4, cmap="Reds")
        axs[0].set_title("Ground Truth")
        axs[0].axis("off")

        axs[1].imshow(frame_img)
        axs[1].imshow(pred_mask, alpha=0.4, cmap="Greens")
        axs[1].set_title("SonoYolo + MedSAM2 Prediction")
        axs[1].axis("off")

        plt.suptitle(f"{video_filename}  |  Backward Dice @ frame {first_idx}: {dice:.4f}", fontsize=14)
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across a batch run

        # Overlay on the first annotated frame. The saved first_idx mask is used
        # rather than the last yielded logits, which may belong to another frame
        # depending on how the predictor counts max_frame_num_to_track.
        # The box drawn is the YOLO prompt from frame second_idx.
        overlay_fig = plt.figure(figsize=(9, 6))
        plt.title(f"{video_filename} - Frame {first_idx}")
        plt.imshow(frame_img)
        show_box(box, plt.gca())
        show_mask(pred_mask > 0, plt.gca(), obj_id=obj_id)
        plt.axis('off')
        plt.show()
        plt.close(overlay_fig)
    except Exception as e:
        print(f"[WARN] Visualization skipped for {video_filename}: {e}")

    print(f"[DONE] Finished processing {video_filename}")
    return dice


In [None]:
#Test Split Batch Evaluation
from echonet.datasets import Echo
import pandas as pd
import numpy as np
import traceback
import time

# 1) EchoNet test split: only the list of .avi filenames is needed
dataset = Echo(
    split="test",
    root="/shared_data/p_vidalr/iraj/EchoNet-Dynamic",
    clips=1,
    length=None,
    target_type="Filename",   # filenames are the only target used here
)

# list of 'XXXXXXXXXXXX.avi'
test_fnames = dataset.fnames
print(f"[INFO] Found {len(test_fnames)} test videos.")

# 2) Optional: suppress the per-video plt.show() to avoid popping windows / slowing down
import contextlib, matplotlib.pyplot as plt
@contextlib.contextmanager
def suppress_plots():
    """Temporarily replace ``plt.show`` so batch runs display nothing.

    While the context is active, ``plt.show(...)`` closes all open figures
    instead of rendering them; a plain no-op would leave every figure created
    during the batch open and let memory grow with the number of videos.
    The original ``plt.show`` is restored on exit, including on exceptions.
    """
    _orig_show = plt.show

    def _close_instead_of_show(*args, **kwargs):
        plt.close("all")

    plt.show = _close_instead_of_show
    try:
        yield
    finally:
        plt.show = _orig_show

# 3) Run evaluation for each test video
results = []
start = time.time()

with suppress_plots():
    for i, fname in enumerate(test_fnames, 1):
        try:
            print(f"\n=== [{i}/{len(test_fnames)}] {fname} ===")
            dice = evaluate_medsam2_on_video(
                video_filename=fname,
                predictor=predictor,
                tracings_csv="/EchoNet-Dynamic/VolumeTracings.csv",
                video_dir="/EchoNet-Dynamic/Videos",
                frames_root="/EchoNet-Dynamic/dataset_frames_new",
                masks_root="/EchoNet-Dynamic/yolomodmedsam2/backward_masks",
                gt_mask_root="/EchoNet-Dynamic/manual_masks",
            )
            # The evaluator returns None for skipped videos; record those as NaN.
            results.append({"filename": fname, "dice": float(dice) if dice is not None else np.nan})
        except Exception as e:
            # Keep the batch going; the failure is logged and recorded as NaN.
            print(f"[ERROR] {fname}: {e}")
            traceback.print_exc()
            results.append({"filename": fname, "dice": np.nan})

elapsed = time.time() - start
print(f"\n[SUMMARY] Finished {len(test_fnames)} videos in {elapsed/60:.1f} min.")

# 4) Save CSV + print mean Dice
# Explicit columns/dtype keep sort_values() and mean() valid even when no video
# was evaluated (an empty `results` list would otherwise raise KeyError).
df_res = (
    pd.DataFrame(results, columns=["filename", "dice"])
    .astype({"dice": "float64"})
    .sort_values("filename")
)
# Use a dedicated name: `csv_path` is set to VolumeTracings.csv in the setup cell
# at the top of the notebook and must not be overwritten by the results path.
results_csv_path = "/EchoNet-Dynamic/yolomodmedsam2_backward_masks1.csv"
df_res.to_csv(results_csv_path, index=False)
print(f"[SUMMARY] Results saved to: {results_csv_path}")
# Surface skipped/failed videos; the mean below silently ignores them.
n_missing = int(df_res["dice"].isna().sum())
print(f"[SUMMARY] Videos without a Dice score (NaN): {n_missing}/{len(df_res)}")
print(f"[SUMMARY] Mean Dice (ignoring NaN): {df_res['dice'].mean(skipna=True):.4f}")


# Framewise segmentation

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

def dice_score(mask1, mask2):
    """Dice coefficient of two masks, treating every value > 0 as foreground.

    Returns 1.0 when neither mask contains any foreground.
    """
    binary_1 = (mask1 > 0).astype(np.bool_)
    binary_2 = (mask2 > 0).astype(np.bool_)
    size_sum = np.sum(binary_1) + np.sum(binary_2)
    if size_sum == 0:
        # Two empty masks agree perfectly by convention.
        return 1.0
    shared = np.sum(binary_1 & binary_2)
    return 2.0 * shared / size_sum

def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw `mask` on `ax` as an RGBA overlay with alpha 0.6.

    With `random_color` the RGB part is drawn at random; otherwise the "tab10"
    colormap is indexed by `obj_id` (index 0 when `obj_id` is None). Height and
    width are taken from the last two dimensions of `mask`.
    """
    if not random_color:
        tab10 = plt.get_cmap("tab10")
        color_idx = 0 if obj_id is None else obj_id
        rgba = np.array([*tab10(color_idx)[:3], 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    rows, cols = mask.shape[-2:]
    # (H, W, 1) mask times (1, 1, 4) colour -> (H, W, 4) image.
    rgba_image = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(rgba_image)

def evaluate_medsam2_on_video(video_filename, predictor,
                              tracings_csv="/shared_data/p_vidalr/iraj/EchoNet-Dynamic/VolumeTracings.csv",
                              video_dir="/shared_data/p_vidalr/iraj/EchoNet-Dynamic/Videos",
                              frames_root="/shared_data/p_vidalr/iraj/EchoNet-Dynamic/dataset_frames_new",
                              masks_root="/shared_data/p_vidalr/iraj/EchoNet-Dynamic/yolomodmedsam2/framewise",
                              gt_mask_root="/shared_data/p_vidalr/iraj/EchoNet-Dynamic/manual_masks",
                              yolo_labels_dir="/home/iraj/runs/detect/predict4/labels",
                              image_size=128):
    """Frame-wise (no propagation) evaluation of MedSAM2 on a single video.

    Each of the two annotated frames is segmented independently, prompted with
    the YOLO box predicted for that same frame. YOLO coordinates (normalized on
    the padded `image_size` x `image_size` image) are mapped back to the
    112x112 frame. Predicted masks are saved per frame and scored (Dice)
    against the manual masks.

    Args:
        video_filename: Video file name (e.g. "<id>.avi"); matched against the
            "FileName" column of `tracings_csv`.
        predictor: Video predictor providing `init_state`, `reset_state` and
            `add_new_points_or_box`.
        tracings_csv: Path to VolumeTracings.csv.
        video_dir: Directory containing the video file.
        frames_root: Root for extracted frames ("<frames_root>/<video_filename>/").
        masks_root: Root for predicted masks ("<masks_root>/<video_id>/").
        gt_mask_root: Root of manual masks ("<gt_mask_root>/<video_filename>/").
        yolo_labels_dir: Directory with YOLO labels named
            "<video_id>_frame<frame>.txt".
        image_size: Side length of the padded square image YOLO was run on.

    Returns:
        The mean Dice over the two annotated frames. If the prediction failed
        for one frame, the Dice of the other frame alone. `np.nan` when there
        are fewer than two annotated frames, a ground-truth mask is missing,
        or both predictions failed.
    """

    print(f"\n========== Processing Video (frame-wise): {video_filename} ==========")

    base_name  = os.path.splitext(video_filename)[0]
    video_path = os.path.join(video_dir, video_filename)
    frames_dir = os.path.join(frames_root, video_filename)   # directory is named after the full file name
    masks_dir  = os.path.join(masks_root, base_name)

    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir,  exist_ok=True)

    # --- helpers ---
    def _yolo_label_path(video_id, frame_idx):
        # Expect: {video_id}_frame{frame_idx}.txt  (e.g., 0XABC..._frame154.txt)
        return os.path.join(yolo_labels_dir, f"{video_id}_frame{frame_idx}.txt")

    def _yolo_box_to_112_coords(label_path, padded_size=128, original_size=112):
        """Read first bbox from YOLO txt and convert from padded_size->original_size coords."""
        pad = (padded_size - original_size) // 2
        with open(label_path, "r") as f:
            line = f.readline().strip()
            parts = line.split()
            if len(parts) < 5:
                raise ValueError(f"Invalid YOLO label format in {label_path}")
            # YOLO format: cls cx cy w h (normalized); extra fields are ignored
            _, x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts[:5])

        x_center_128 = x_center_norm * padded_size
        y_center_128 = y_center_norm * padded_size
        width_128    = width_norm    * padded_size
        height_128   = height_norm   * padded_size

        # shift center back by -pad to align with 112x112 content
        x_center_112 = x_center_128 - pad
        y_center_112 = y_center_128 - pad

        xmin = x_center_112 - width_128  / 2.0
        ymin = y_center_112 - height_128 / 2.0
        xmax = x_center_112 + width_128  / 2.0
        ymax = y_center_112 + height_128 / 2.0

        # clip bounds
        xmin = max(0, min(xmin, original_size - 1))
        ymin = max(0, min(ymin, original_size - 1))
        xmax = max(0, min(xmax, original_size - 1))
        ymax = max(0, min(ymax, original_size - 1))

        # shrink the clipped box by 3 px on every side
        return np.array([xmin+3, ymin+3, xmax-3, ymax-3], dtype=np.float32)

    # 1) Load annotated frames (two) from VolumeTracings.csv
    print("[INFO] Loading VolumeTracings.csv...")
    df = pd.read_csv(tracings_csv)
    df_video = df[df['FileName'] == video_filename]
    unique_frames = sorted(df_video['Frame'].unique())

    if len(unique_frames) < 2:
        print(f"[WARNING] Not enough annotated frames in {base_name}")
        return np.nan  # keep return type numeric

    first_idx, second_idx = int(unique_frames[0]), int(unique_frames[1])
    print(f"[INFO] Annotated frames: [{first_idx}, {second_idx}]")

    # 2) Extract frames if needed
    if len([f for f in os.listdir(frames_dir) if f.endswith(".jpg")]) == 0:
        print(f"[INFO] Extracting frames from {video_path}...")
        cap = cv2.VideoCapture(video_path)
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imwrite(os.path.join(frames_dir, f"{idx:03d}.jpg"), frame)
            idx += 1
        cap.release()
        print(f"[INFO] Extracted {idx} frames.")
    else:
        print(f"[INFO] Frames already present in {frames_dir}")

    # 3) Ensure GT masks exist (best-effort: the helper lives in another cell)
    print(f"[INFO] Ensuring ground-truth masks exist for {video_filename}...")
    try:
        save_manual_masks(video_filename)
    except NameError:
        print("[WARN] save_manual_masks() not found; assuming GT masks already exist.")

    # Single-frame inference (NO propagation/history), using YOLO bbox for THIS frame
    def run_single_frame_with_yolo(frame_idx, obj_id=1):
        """Segment one frame; return (mask_path, mask) or (None, None) on failure."""
        label_path = _yolo_label_path(base_name, frame_idx)
        if not os.path.exists(label_path):
            print(f"[ERROR] YOLO label not found for frame {frame_idx}: {label_path}")
            # None (not NaN) so the caller's `is None` checks detect the failure.
            return None, None

        try:
            box = _yolo_box_to_112_coords(label_path, padded_size=image_size, original_size=112)
        except Exception as e:
            print(f"[ERROR] Failed to parse YOLO bbox for frame {frame_idx}: {e}")
            return None, None

        print(f"[INFO] (Frame {frame_idx}) YOLO BBox for MedSAM2 (112x112): {box}")

        # Initialize/reset predictor state so nothing propagates
        inference_state = predictor.init_state(video_path=frames_dir)
        predictor.reset_state(inference_state)

        _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=int(frame_idx),
            obj_id=obj_id,
            box=box,
        )

        # pick logits for obj_id if multiple
        if hasattr(out_mask_logits, "__len__") and len(out_mask_logits) > 1:
            try:
                idx = list(out_obj_ids).index(obj_id)
                logits = out_mask_logits[idx]
            except Exception:
                logits = out_mask_logits[0]
        else:
            logits = out_mask_logits[0] if hasattr(out_mask_logits, "__getitem__") else out_mask_logits

        mask = (logits > 0).detach().cpu().numpy().squeeze().astype(np.uint8) * 255
        pred_mask_path = os.path.join(masks_dir, f"frame_{frame_idx}_sam_mask.png")
        Image.fromarray(mask).save(pred_mask_path)
        print(f"[SAVED] MedSAM2 mask → {pred_mask_path}")
        return pred_mask_path, mask

    # 4) Run on both annotated frames (independently)
    pred1_path, pm1 = run_single_frame_with_yolo(first_idx,  obj_id=1)
    pred2_path, pm2 = run_single_frame_with_yolo(second_idx, obj_id=1)

    # 5) Load GT and compute Dice per frame
    gt1_path = os.path.join(gt_mask_root, video_filename, f"frame_{first_idx:03d}_manual_mask.png")
    gt2_path = os.path.join(gt_mask_root, video_filename, f"frame_{second_idx:03d}_manual_mask.png")

    if not (os.path.exists(gt1_path) and os.path.exists(gt2_path)):
        print(f"[ERROR] Missing GT mask(s). Expected:\n  {gt1_path}\n  {gt2_path}")
        return np.nan

    gt1 = np.array(Image.open(gt1_path).convert('L'))
    gt2 = np.array(Image.open(gt2_path).convert('L'))

    # Handle a failed prediction on either frame: score whatever is available
    if pm1 is None or pm2 is None:
        dice_vals = []
        if pm1 is not None:
            dice_vals.append(dice_score(pm1, gt1))
        if pm2 is not None:
            dice_vals.append(dice_score(pm2, gt2))
        avg_dice = float(np.mean(dice_vals)) if len(dice_vals) else np.nan
        print(f"[RESULT] Partial/NaN dice due to missing predictions. Avg: {avg_dice}")
        return avg_dice

    dice1 = dice_score(pm1, gt1)
    dice2 = dice_score(pm2, gt2)

    print(f"[RESULT] Dice (frame {first_idx:03d}):  {dice1:.4f}")
    print(f"[RESULT] Dice (frame {second_idx:03d}): {dice2:.4f}")

    avg_dice = float((dice1 + dice2) / 2.0)
    print(f"[RESULT] Average Dice (two annotated frames): {avg_dice:.4f}")

    # 6) Visualization (best-effort; figures are closed so batch runs do not
    #    accumulate open figures)
    try:
        frame_img1 = np.array(Image.open(os.path.join(frames_dir, f"{first_idx:03d}.jpg")))
        fig1, axs1 = plt.subplots(1, 2, figsize=(12, 5))
        axs1[0].imshow(frame_img1); axs1[0].imshow(gt1, alpha=0.5, cmap="Reds")
        axs1[0].set_title(f"GT (frame {first_idx:03d})"); axs1[0].axis("off")
        axs1[1].imshow(frame_img1); axs1[1].imshow(pm1, alpha=0.5, cmap="Greens")
        axs1[1].set_title("MedSAM2 Prediction"); axs1[1].axis("off")
        plt.suptitle(f"{video_filename} | Dice (frame {first_idx:03d}) = {dice1:.4f}", fontsize=14)
        plt.tight_layout(); plt.show()
        plt.close(fig1)
    except Exception as e:
        print(f"[WARN] Visualization skipped for frame {first_idx}: {e}")

    try:
        frame_img2 = np.array(Image.open(os.path.join(frames_dir, f"{second_idx:03d}.jpg")))
        fig2, axs2 = plt.subplots(1, 2, figsize=(12, 5))
        axs2[0].imshow(frame_img2); axs2[0].imshow(gt2, alpha=0.5, cmap="Reds")
        axs2[0].set_title(f"GT (frame {second_idx:03d})"); axs2[0].axis("off")
        axs2[1].imshow(frame_img2); axs2[1].imshow(pm2, alpha=0.5, cmap="Greens")
        axs2[1].set_title("MedSAM2 Prediction"); axs2[1].axis("off")
        plt.suptitle(f"{video_filename} | Dice {first_idx:03d}: {dice1:.4f} | {second_idx:03d}: {dice2:.4f} | Avg: {avg_dice:.4f}", fontsize=14)
        plt.tight_layout(); plt.show()
        plt.close(fig2)
    except Exception as e:
        print(f"[WARN] Visualization skipped: {e}")

    print(f"[DONE] Finished (frame-wise) {video_filename}")
    return avg_dice


In [None]:
# Test Split Batch Evaluation

from echonet.datasets import Echo
import pandas as pd
import numpy as np
import traceback
import time


# 1) EchoNet test split: only the list of .avi filenames is needed
dataset = Echo(
    split="test",
    root="/shared_data/p_vidalr/iraj/EchoNet-Dynamic",
    clips=1,
    length=None,
    target_type="Filename",   # filenames are the only target used here
)

# list of 'XXXXXXXXXXXX.avi'
test_fnames = dataset.fnames
print(f"[INFO] Found {len(test_fnames)} test videos.")

# 2) Optional: suppress the per-video plt.show() to avoid popping windows / slowing down
import contextlib, matplotlib.pyplot as plt
@contextlib.contextmanager
def suppress_plots():
    """Temporarily replace ``plt.show`` so batch runs display nothing.

    While the context is active, ``plt.show(...)`` closes all open figures
    instead of rendering them; a plain no-op would leave every figure created
    during the batch open and let memory grow with the number of videos.
    The original ``plt.show`` is restored on exit, including on exceptions.
    """
    _orig_show = plt.show

    def _close_instead_of_show(*args, **kwargs):
        plt.close("all")

    plt.show = _close_instead_of_show
    try:
        yield
    finally:
        plt.show = _orig_show

# 3) Run evaluation for each test video
results = []
start = time.time()

with suppress_plots():   # remove this context manager if you DO want the plots
    for i, fname in enumerate(test_fnames, 1):
        try:
            print(f"\n=== [{i}/{len(test_fnames)}] {fname} ===")
            dice = evaluate_medsam2_on_video(
                video_filename=fname,
                predictor=predictor,
                # these override the function's default paths:
                tracings_csv="/EchoNet-Dynamic/VolumeTracings.csv",
                video_dir="/EchoNet-Dynamic/Videos",
                frames_root="/EchoNet-Dynamic/dataset_frames_new",
                masks_root="/EchoNet-Dynamic/yolomodmedsam2/framewise",
                gt_mask_root="/EchoNet-Dynamic/manual_masks",
            )
            # Guard against a None return; record it as NaN.
            results.append({"filename": fname, "dice": float(dice) if dice is not None else np.nan})
        except Exception as e:
            # Keep the batch going; the failure is logged and recorded as NaN.
            print(f"[ERROR] {fname}: {e}")
            traceback.print_exc()
            results.append({"filename": fname, "dice": np.nan})

elapsed = time.time() - start
print(f"\n[SUMMARY] Finished {len(test_fnames)} videos in {elapsed/60:.1f} min.")

# 4) Save CSV + print mean Dice
# Explicit columns/dtype keep sort_values() and mean() valid even when no video
# was evaluated (an empty `results` list would otherwise raise KeyError).
df_res = (
    pd.DataFrame(results, columns=["filename", "dice"])
    .astype({"dice": "float64"})
    .sort_values("filename")
)
# Use a dedicated name: `csv_path` is set to VolumeTracings.csv in the setup cell
# at the top of the notebook and must not be overwritten by the results path.
results_csv_path = "/EchoNet-Dynamic/yolomodmedsam2_framewise.csv"
df_res.to_csv(results_csv_path, index=False)
print(f"[SUMMARY] Results saved to: {results_csv_path}")
# Surface skipped/failed videos; the mean below silently ignores them.
n_missing = int(df_res["dice"].isna().sum())
print(f"[SUMMARY] Videos without a Dice score (NaN): {n_missing}/{len(df_res)}")
print(f"[SUMMARY] Mean Dice (ignoring NaN): {df_res['dice'].mean(skipna=True):.4f}")
