# Mitral Valve (MV) segmentation in echocardiography videos using DINOv3

This notebook adapts the DINOv3 segmentation-tracking propagation approach to mitral valve (MV) segmentation in echocardiography videos.

What it does:
- Loads training videos (train.pkl) and test videos (test.pkl / test list).
- Uses the 3 annotated frames per training video as context to propagate dense masks to every frame in the same video using DINOv3 patch features and a non-parametric propagation (top-k softmax weighting inside a local neighborhood).
- Demonstrates how to apply the bounding box, convert grayscale to RGB, handle resizing to the DINO patch grid, upsample predictions, map back to original resolution and save masks or pseudo-labels.
- (Optional) Provides a skeleton to train a segmentation model (UNet) on the pseudo-labels.
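
The top-k softmax weighting mentioned above can be illustrated on toy data. The following is a minimal NumPy sketch for a single query patch, not the notebook's batched implementation; the function name `topk_softmax_vote` and the toy similarities are made up for illustration.

```python
import numpy as np

def topk_softmax_vote(sims, context_probs, topk, temperature):
    """Weighted label vote for one query patch.

    sims:          [N] similarities between the query and N context patches
    context_probs: [N, M] class probabilities (e.g. one-hot labels) of those patches
    Returns an [M] probability vector.
    """
    sims = np.asarray(sims, dtype=np.float64)
    k = min(topk, sims.shape[0])
    kth = np.sort(sims)[-k]                        # k-th largest similarity
    masked = np.where(sims >= kth, sims, -np.inf)  # discard everything below the top-k
    logits = masked / temperature
    logits = logits - logits.max()                 # stabilise the softmax
    weights = np.exp(logits)
    weights = weights / weights.sum()
    return weights @ np.asarray(context_probs, dtype=np.float64)

# Toy example: 5 context patches, 2 classes (background, valve)
sims = [0.9, 0.8, 0.1, 0.7, 0.2]
labels = np.eye(2)[[1, 1, 0, 0, 0]]  # one-hot labels of the 5 context patches
pred = topk_softmax_vote(sims, labels, topk=3, temperature=0.2)
print(pred)
```

With `topk=3` only the patches with similarity 0.9, 0.8 and 0.7 vote; two of them are labeled as valve, so the valve class receives most of the probability mass. A lower temperature concentrates the weight on the single best match.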

Run sections sequentially. Replace paths to train.pkl/test.pkl with your dataset paths when needed.

In [None]:
# Basic imports and configuration
import os
import math
import pickle
import shutil
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch import nn
import torchvision.transforms as TVT
import torchvision.transforms.functional as TVTF
import matplotlib.pyplot as plt
import time

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', DEVICE)

## Settings: models & hyperparameters
You can change these to fit your GPU and data. Start with smaller SHORT_SIDE if memory is tight.
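
To get a feel for why a smaller SHORT_SIDE helps, here is a rough back-of-the-envelope sketch of how the propagation tensors grow with the patch grid. It assumes a patch size of 16 (as the model name suggests), a square crop and 3 context frames; it ignores the memory used by the model forward pass, and the helper name `propagation_memory_mib` is hypothetical.

```python
def propagation_memory_mib(short_side, patch_size=16, aspect=1.0, n_context=3):
    """Estimate the size of the neighborhood mask and similarity tensor in MiB."""
    h = -(-short_side // patch_size)                # ceil division: patches along the short side
    w = -(-int(short_side * aspect) // patch_size)  # patches along the other side
    n = h * w
    mask_bytes = n * n                  # bool neighborhood mask [h, w, h, w], 1 byte per entry
    sim_bytes = n * n_context * n * 4   # float32 similarities [h, w, t, h, w]
    return h, w, mask_bytes / 2**20, sim_bytes / 2**20

for s in (320, 480, 640, 960):
    h, w, mask_mib, sim_mib = propagation_memory_mib(s)
    print(f"SHORT_SIDE={s}: grid {h}x{w}, mask ~{mask_mib:.1f} MiB, similarities ~{sim_mib:.1f} MiB")
```

Both tensors scale with the square of the number of patches, so doubling SHORT_SIDE multiplies them by roughly 16.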

In [None]:
# Configurable parameters
DINOV3_LOCATION = os.getenv('DINOV3_LOCATION', 'facebookresearch/dinov3')  # or local repo path
MODEL_NAME = 'dinov3_vitl16'  # try 'dinov3_vits16' if GPU constrained
SHORT_SIDE = 640  # try 640/800/960
TOPK = 10
TEMPERATURE = 0.2
NEIGHBORHOOD_PATCH_RADIUS = 16  # measured in patches; controls search radius
USE_BBOX_CROP = True  # apply provided bbox to reduce area
ASSUME_BINARY = True  # assume labels are background / mitral valve (0/1)
SAVE_PSEUDOLABELS_DIR = 'pseudo_labels'
os.makedirs(SAVE_PSEUDOLABELS_DIR, exist_ok=True)
print('Config: SHORT_SIDE=', SHORT_SIDE, 'TOPK=', TOPK, 'NEIGHBORHOOD_RADIUS=', NEIGHBORHOOD_PATCH_RADIUS)

## Load DINOv3 model
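The model is loaded with torch.hub from a local clone of the dinov3 repository (REPO_DIR) and a locally downloaded checkpoint (WEIGHTS_PATH). The cells below download the weights, mount Google Drive, and clone the repository. Note that DINOV3_LOCATION from the settings cell is not used by the loading cell.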

In [None]:
!pip install torchmetrics

In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:
#!git clone https://github.com/facebookresearch/dinov3.git
REPO_DIR = "/content/dinov3"

In [None]:
!wget -c "https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoia3Blb3kxMGhocDJ0NjduaHIxMm8wdXg5IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NjUwMTkzNDJ9fX1dfQ__&Signature=Jp9sdIhUsBzvm8MgUhYgTBE9T0Uo%7Ex21c3ZMOZzO8lWcD1NxNrG7%7Excg2tCU-O-jOyTyFoKetX0XksD3%7EKGwAS5bMRYCQYx-ifp7ahUttS0zWa3gY2UbyEAP6NZIHvQilgYi2ZRe4ypNIVCcFlkuNsEdSHCZxsdxYCIf8gyTvRmtWYG9w9dEHYNTQQXv8ybjTfMsxnItTnQXHW47Q0uCO177JfyjUUlUKPTORvFwtr8rKjUZODWElLC-31kMuGjvjyKZD2bn%7E0Okm1q7-kSYqnItGnqJB0uG5WrCSZTfgfvfJDyU9EGkrZXAZ99qEOmFhzr8PPXXn0Dz1QSpiehyug__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1498136187968196" \
      -O "/content/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"

In [None]:
!wget -c "https://dinov3.llamameta.net/dinov3_vitl16/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoia3Blb3kxMGhocDJ0NjduaHIxMm8wdXg5IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NjUwMTkzNDJ9fX1dfQ__&Signature=Jp9sdIhUsBzvm8MgUhYgTBE9T0Uo%7Ex21c3ZMOZzO8lWcD1NxNrG7%7Excg2tCU-O-jOyTyFoKetX0XksD3%7EKGwAS5bMRYCQYx-ifp7ahUttS0zWa3gY2UbyEAP6NZIHvQilgYi2ZRe4ypNIVCcFlkuNsEdSHCZxsdxYCIf8gyTvRmtWYG9w9dEHYNTQQXv8ybjTfMsxnItTnQXHW47Q0uCO177JfyjUUlUKPTORvFwtr8rKjUZODWElLC-31kMuGjvjyKZD2bn%7E0Okm1q7-kSYqnItGnqJB0uG5WrCSZTfgfvfJDyU9EGkrZXAZ99qEOmFhzr8PPXXn0Dz1QSpiehyug__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1498136187968196" \
      -O "/content/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"


In [None]:
!ls -lh /content/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth


In [None]:
!head -c 200 /content/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth


In [None]:
!unzip -t /content/dinov3_vitl16_pretrain_lvd1689m-*.pth

In [None]:
WEIGHTS_PATH = "/content/drive/MyDrive/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"

In [None]:
!git clone https://github.com/facebookresearch/dinov3.git

In [None]:
import time

print('Loading DINOv3 model from the local repository and local weights...')
try:
    model = torch.hub.load(
        REPO_DIR,              # path to the local dinov3 repository clone
        MODEL_NAME,            # hub entry point, e.g. 'dinov3_vitl16'; must match WEIGHTS_PATH
        source='local',        # load from the local repository instead of GitHub
        weights=WEIGHTS_PATH,  # path to the locally downloaded checkpoint
    )
except Exception as e:
    if "HTTP Error 403: Forbidden" in str(e) or "GatedRepoError" in str(e):
        print("Error: Failed to download DINOv3 model weights due to a 403 Forbidden error or Gated Repo access.")
        print("This usually means the server denied access to the download link or the model is gated.")
        print("Please ensure you have access to the model, or consider manually downloading the model weights.")
        print(f"The expected download URL was: https://dl.fbaipublicfiles.com/dinov3/{MODEL_NAME}/{MODEL_NAME}_pretrain_lvd1689m-8aa4cbdd.pth")
    else:
        print(f"An unexpected error occurred during model loading: {e}")
    raise # Re-raise the exception to stop execution if model loading is critical

model.to(DEVICE).eval()
patch_size = model.patch_size
embed_dim = model.embed_dim
print('Model:', MODEL_NAME, 'patch_size=', patch_size, 'embed_dim=', embed_dim)


## Helper transforms and small utilities
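ResizeToMultiple rounds both sides up to multiples of patch_size: the short side becomes SHORT_SIDE rounded up to the next multiple, and the long side is scaled to keep the aspect ratio before rounding. The same cell defines forward_feats (L2-normalized patch features), make_neighborhood_mask (local search window on the patch grid) and propagate (top-k softmax label propagation).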
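
The size arithmetic can be checked without loading any image. This is a small pure-Python sketch that mirrors the rounding rule of ResizeToMultiple; the helper name `resized_size` is hypothetical and the input sizes are arbitrary examples.

```python
import math

def resized_size(old_width, old_height, short_side, multiple):
    """Return (new_height, new_width) using the same rule as ResizeToMultiple."""
    def round_up(x):
        return math.ceil(x / multiple) * multiple

    if old_width > old_height:
        new_height = round_up(short_side)
        new_width = round_up(old_width * new_height / old_height)
    else:
        new_width = round_up(short_side)
        new_height = round_up(old_height * new_width / old_width)
    return new_height, new_width

# A 300x200 (width x height) crop with short side 640 and 16-pixel patches
new_h, new_w = resized_size(old_width=300, old_height=200, short_side=640, multiple=16)
print(new_h, new_w, "-> patch grid", new_h // 16, "x", new_w // 16)  # 640 960 -> patch grid 40 x 60

# A short side that is not a multiple of 16 is rounded up
print(resized_size(old_width=112, old_height=112, short_side=650, multiple=16))  # (656, 656)
```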

In [None]:
class ResizeToMultiple(nn.Module):
    def __init__(self, short_side: int, multiple: int):
        super().__init__()
        self.short_side = short_side
        self.multiple = multiple

    def _round_up(self, side: float) -> int:
        return math.ceil(side / self.multiple) * self.multiple

    def forward(self, img):
        old_width, old_height = TVTF.get_image_size(img)
        if old_width > old_height:
            new_height = self._round_up(self.short_side)
            new_width = self._round_up(old_width * new_height / old_height)
        else:
            new_width = self._round_up(self.short_side)
            new_height = self._round_up(old_height * new_width / old_width)
        return TVTF.resize(img, [new_height, new_width], interpolation=TVT.InterpolationMode.BICUBIC)

transform = TVT.Compose([
    ResizeToMultiple(short_side=SHORT_SIDE, multiple=patch_size),
    TVT.ToTensor(),
    TVT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def forward_feats(img_tensor: torch.Tensor) -> torch.Tensor:
    """img_tensor: [3, H, W] normalized, on DEVICE -> returns [h, w, D]"""
    feats = model.get_intermediate_layers(img_tensor.unsqueeze(0), n=1, reshape=True)[0]  # [1, D, h, w]
    feats = feats.movedim(-3, -1)  # [1, h, w, D]
    feats = F.normalize(feats, dim=-1, p=2)
    return feats.squeeze(0)

def make_neighborhood_mask(h: int, w: int, size: float, shape: str = 'circle') -> torch.Tensor:
    ij = torch.stack(
        torch.meshgrid(
            torch.arange(h, dtype=torch.float32, device=DEVICE),
            torch.arange(w, dtype=torch.float32, device=DEVICE),
            indexing='ij',
        ),
        dim=-1,
    )
    if shape == 'circle':
        ord = 2
    elif shape == 'square':
        ord = torch.inf
    else:
        raise ValueError(f'Invalid shape={shape}')
    norm = torch.linalg.vector_norm(
        ij[:, :, None, None, :] - ij[None, None, :, :, :],  # [h", w", h, w, 2]
        ord=ord,
        dim=-1,
    )
    mask = norm <= size
    return mask

@torch.no_grad()
def propagate(current_features: torch.Tensor,
              context_features: torch.Tensor,
              context_probs: torch.Tensor,
              neighborhood_mask: torch.Tensor,
              topk: int,
              temperature: float) -> torch.Tensor:
    """Propagate context_probs -> returns [h2, w2, M]"""
    t, h, w, M = context_probs.shape
    # similarity
    dot = torch.einsum('ijd, tuvd -> ijtuv', current_features, context_features)  # [h2, w2, t, h, w]
    dot = torch.where(neighborhood_mask[:, :, None, :, :], dot, -torch.inf)
    dotflat = dot.flatten(2, -1).flatten(0, 1)  # [h2*w2, t*h*w]
    # safe topk: if t*h*w < topk reduce
    k = min(topk, dotflat.shape[1])
    if k <= 0:
        raise RuntimeError('Empty context for propagation')
    kth = torch.topk(dotflat, k=k, dim=1).values[:, -1:]
    dotflat = torch.where(dotflat >= kth, dotflat, -torch.inf)
    weights = torch.softmax(dotflat / temperature, dim=1)
    context_flat = context_probs.flatten(0, 2)  # [t*h*w, M]
    pred = weights @ context_flat
    pred = pred / (pred.sum(dim=1, keepdim=True) + 1e-12)
    return pred.unflatten(0, (current_features.shape[0], current_features.shape[1]))  # [h2, w2, M]


## Load dataset (train.pkl) and inspect one sample
Adjust the path to your train.pkl file. The notebook assumes train.pkl is a gzip-compressed pickle containing a list of dicts with the keys 'name', 'video', 'box', 'label', 'frames' and 'dataset'.
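
If you want to test the pipeline without the real data, a synthetic entry with the same keys can be written and read back as follows. This is only a sketch: the array sizes, the boolean label dtype and the file name `dummy_train.pkl` are assumptions, not properties of the actual dataset.

```python
import gzip
import os
import pickle
import tempfile

import numpy as np

def make_dummy_video(name, H=64, W=80, T=10, frames=(1, 4, 8)):
    """Build one synthetic video dict with the keys the notebook reads."""
    rng = np.random.default_rng(0)
    video = rng.integers(0, 256, size=(H, W, T), dtype=np.uint8)  # grayscale frames
    box = np.zeros((H, W), dtype=bool)
    box[16:48, 20:60] = True                                      # bounding-box mask
    label = np.zeros((H, W, T), dtype=bool)
    for t in frames:
        label[28:36, 30:50, t] = True                             # annotated frames only
    return {"name": name, "video": video, "box": box, "label": label,
            "frames": list(frames), "dataset": "synthetic"}

# Write and re-read a gzip-compressed pickle, as the loading cell expects
path = os.path.join(tempfile.mkdtemp(), "dummy_train.pkl")
with gzip.open(path, "wb") as f:
    pickle.dump([make_dummy_video("dummy_0")], f)
with gzip.open(path, "rb") as f:
    loaded = pickle.load(f)

sample = loaded[0]
print(sorted(sample.keys()), sample["video"].shape, sample["frames"])
```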

In [None]:
# Path to data - change as needed
import gzip
TRAIN_PKL = '/content/drive/MyDrive/train.pkl'
if not os.path.exists(TRAIN_PKL):
    print(f'{TRAIN_PKL} not found. Please place train.pkl there or change the TRAIN_PKL path.')
else:
    try:
        with gzip.open(TRAIN_PKL, 'rb') as f:  # train.pkl is a gzip-compressed pickle
            train_data = pickle.load(f)
        print('Loaded train.pkl, number of videos:', len(train_data))
        # show structure of first entry
        sample = train_data[0]
        print('Keys in sample:', list(sample.keys()))
        print('Video shape (H,W,T):', sample['video'].shape)
        print('Frames with labels:', sample.get('frames'))
        print("Dataset source:", sample.get('dataset'))
    except Exception as e:
        print(f"Error loading train.pkl: {e}")
        print("It might not be a gzipped pickle file, or the file is corrupted.")


## Core processing: create dense masks for a single video
Function process_video_dict runs the pipeline for a single video dictionary from train.pkl.
- Applies bbox crop (if present and enabled)
- Converts grayscale->RGB by stacking
- Computes features for all frames once
- Uses the 3 annotated frames as context (one-hot) and propagates to all frames
- Upsamples to cropped resolution and reinserts into full frame
- Returns list of masks (H, W) per frame
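
The crop and reinsert steps in the list above can be tried in isolation. The sketch below uses a tiny synthetic frame and an all-ones stand-in for the predicted crop mask; the helper name `bbox_from_mask` is hypothetical.

```python
import numpy as np

def bbox_from_mask(box):
    """Tight bounding box (y0, y1, x0, x1) of a boolean mask, end-exclusive."""
    ys, xs = np.where(box)
    return int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1

H, W = 6, 8
box = np.zeros((H, W), dtype=bool)
box[1:5, 2:7] = True
y0, y1, x0, x1 = bbox_from_mask(box)

frame = np.arange(H * W, dtype=np.uint8).reshape(H, W)
crop = frame[y0:y1, x0:x1]                       # region that is passed to the model

mask_crop = np.ones(crop.shape, dtype=np.uint8)  # stand-in for a predicted crop mask
full_mask = np.zeros((H, W), dtype=np.uint8)     # everything outside the box stays background
full_mask[y0:y1, x0:x1] = mask_crop

print((y0, y1, x0, x1), crop.shape, int(full_mask.sum()))  # (1, 5, 2, 7) (4, 5) 20
```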

In [None]:
def process_video_dict(vd: dict,
                       topk=TOPK,
                       temp=TEMPERATURE,
                       radius=NEIGHBORHOOD_PATCH_RADIUS,
                       use_bbox=USE_BBOX_CROP):
    """Process one video dict from train.pkl.
    Returns (masks, confidences): a list of full-frame HxW uint8 masks and a list of
    crop-sized float confidence maps, one entry per frame.
    vd keys: 'name', 'video' (H,W,T), 'box' (H,W) bool, 'label' (H,W,T), 'frames' (annotated frame indices)
    """
    video = vd['video']  # shape H,W,T, dtype uint8
    H, W, T = video.shape
    # bbox crop
    if use_bbox and ('box' in vd) and (vd['box'] is not None):
        box = vd['box']
        ys, xs = np.where(box)
        y0, y1 = int(ys.min()), int(ys.max()) + 1
        x0, x1 = int(xs.min()), int(xs.max()) + 1
    else:
        y0, y1, x0, x1 = 0, H, 0, W
    cropped = video[y0:y1, x0:x1, :]
    # build PIL frames RGB from grayscale
    pil_frames = []
    for t in range(T):
        im = Image.fromarray(cropped[..., t]).convert('L')
        arr = np.array(im)
        rgb = np.stack([arr, arr, arr], axis=-1).astype(np.uint8)
        pil_frames.append(Image.fromarray(rgb))
    # compute features for all frames (may be slow)
    feats_list = []
    for t in range(T):
        inp = transform(pil_frames[t]).to(DEVICE)
        feats = forward_feats(inp)  # [h, w, D]
        feats_list.append(feats)
    feats = torch.stack(feats_list, dim=0)  # [T, h, w, D]
    h, w = feats.shape[1], feats.shape[2]
    # build context from annotated frames
    frames_annot = vd.get('frames', [])
    if len(frames_annot) == 0:
        raise RuntimeError('No annotated frames found in video dict')
    context_idx = frames_annot
    context_features = feats[context_idx]  # [t, h, w, D]
    # build context_probs from label masks
    labels = vd.get('label', None)
    if labels is None:
        raise RuntimeError('No label key in training datum')
    # labels expected shape: (H, W, T); only the frames listed in 'frames' are read
    context_probs_list = []
    M = 2 if ASSUME_BINARY else int(labels.max() + 1)
    for idx in context_idx:
        lab_full = labels[..., idx]
        lab_crop = lab_full[y0:y1, x0:x1].astype(np.int64)
        lab_t = torch.from_numpy(lab_crop).to(DEVICE)
        lab_grid = F.interpolate(lab_t[None, None].float(), (h, w), mode='nearest-exact')[0, 0].long()
        onehot = F.one_hot(lab_grid, num_classes=M).float()  # [h, w, M]
        context_probs_list.append(onehot)
    context_probs = torch.stack(context_probs_list, dim=0)  # [t, h, w, M]
    # neighborhood mask (same grid dims)
    neighborhood_mask = make_neighborhood_mask(h, w, radius, shape='circle')
    # propagate to all frames
    preds_per_frame = []
    confidences = []
    for t in range(T):
        cur_feats = feats[t]
        pred_grid = propagate(cur_feats, context_features, context_probs, neighborhood_mask, topk, temp)
        # pred_grid [h, w, M]
        # record confidence (max prob per patch)
        conf_patch = pred_grid.max(dim=-1).values.cpu().numpy()  # [h,w]
        conf = torch.from_numpy(conf_patch)[None, None].float()  # [1,1,h,w]
        conf_up = F.interpolate(conf, size=(y1-y0, x1-x0), mode='bilinear', align_corners=False)[0,0].numpy()
        confidences.append(conf_up)
        # upsample to crop resolution
        prob = pred_grid.permute(2, 0, 1).unsqueeze(0)  # [1, M, h, w]
        up = F.interpolate(prob, size=(y1 - y0, x1 - x0), mode='bilinear', align_corners=False)[0]  # [M, Hcrop, Wcrop]
        mask_crop = up.argmax(0).cpu().numpy().astype(np.uint8)
        # reinsert into full frame
        full_mask = np.zeros((H, W), dtype=np.uint8)
        full_mask[y0:y1, x0:x1] = mask_crop
        preds_per_frame.append(full_mask)
    # return per-frame masks and confidences (per-patch upsampled confidence can also be saved)
    return preds_per_frame, confidences


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

@torch.no_grad()
def dino_mask_feature_for_frame(vd, t, mask_full,
                                use_bbox=USE_BBOX_CROP) -> np.ndarray:
    """
    vd: one entry from train_data (with 'video' and optional 'box')
    t: frame index
    mask_full: [H, W] uint8 or bool mask (full-frame prediction)
    returns: 1D numpy vector of size embed_dim (mean DINO feature over mask)
    """
    video = vd["video"]  # [H, W, T], uint8
    H, W, T = video.shape
    assert 0 <= t < T

    # --- same crop as in process_video_dict ---
    if use_bbox and ("box" in vd) and (vd["box"] is not None):
        box = vd["box"]  # [H, W] bool
        ys, xs = np.where(box)
        y0, y1 = int(ys.min()), int(ys.max()) + 1
        x0, x1 = int(xs.min()), int(xs.max()) + 1
    else:
        y0, y1, x0, x1 = 0, H, 0, W

    frame_crop = video[y0:y1, x0:x1, t]       # [Hc, Wc]
    mask_crop  = mask_full[y0:y1, x0:x1]      # [Hc, Wc]

    # --- build RGB PIL (like process_video_dict) ---
    im = Image.fromarray(frame_crop).convert("L")
    arr = np.array(im)
    rgb = np.stack([arr, arr, arr], axis=-1).astype(np.uint8)
    pil = Image.fromarray(rgb)

    # --- DINO features: [h, w, D] ---
    inp = transform(pil).to(DEVICE)          # [3, H', W']
    feats = forward_feats(inp)               # [h, w, D]
    h, w, D = feats.shape

    # --- map mask to patch grid ---
    mask_t = torch.from_numpy(mask_crop.astype(np.float32))[None, None]  # [1,1,Hc,Wc]
    mask_small = F.interpolate(mask_t, size=(h, w), mode="nearest")[0, 0]  # [h, w]
    mask_bool = mask_small > 0.5

    # if mask is empty after downsampling, fall back to global average
    if mask_bool.sum() == 0:
        vec = feats.mean(dim=(0, 1))
    else:
        vec = feats[mask_bool].mean(dim=0)    # [D]

    return vec.cpu().numpy()


## Demo on a few training videos (if train.pkl is loaded)
The cells below run the pipeline on up to five training videos and display each result as an overlay GIF.
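
The overlay is a simple alpha blend between the grayscale frame and a copy in which the masked pixels are painted red. As a minimal sketch on a 2x2 toy frame (the helper name `blend_overlay` is hypothetical, and unlike the notebook cell it rounds instead of truncating when casting back to uint8):

```python
import numpy as np

def blend_overlay(gray, mask, alpha=0.4, color=(255, 0, 0)):
    """Blend a colored mask over a grayscale frame; returns an HxWx3 uint8 image."""
    rgb = np.stack([gray, gray, gray], axis=-1).astype(np.float64)
    overlay = rgb.copy()
    overlay[mask > 0] = color                      # paint masked pixels
    out = rgb * (1 - alpha) + overlay * alpha      # unmasked pixels stay unchanged
    return np.round(out).astype(np.uint8)

gray = np.full((2, 2), 100, dtype=np.uint8)
mask = np.array([[1, 0], [0, 0]], dtype=np.uint8)
out = blend_overlay(gray, mask)
print(out[0, 0], out[0, 1])  # masked pixel is tinted red, unmasked pixel stays gray
```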

In [None]:
import time
import numpy as np
from PIL import Image
from IPython.display import Image as IPyImage, display

def make_overlay_gif(vd, preds, gif_path, fps=20, alpha=0.4):
    """
    vd: dict from train_data (with 'video' key: H x W x T)
    preds: list/array of masks for each frame (len = T)
    gif_path: where to save the GIF
    """
    video = vd['video']  # H, W, T
    H, W, T = video.shape

    frames = []
    for t in range(T):
        img = video[..., t]          # H x W
        mask = preds[t]              # H x W

        # ensure uint8 grayscale
        img = img.astype(np.uint8)

        # binary mask
        mask_bin = (mask > 0).astype(np.uint8)

        # grayscale -> RGB
        img_rgb = np.stack([img, img, img], axis=-1).astype(np.uint8)

        # red overlay where mask == 1
        overlay = img_rgb.copy()
        overlay[mask_bin == 1] = [255, 0, 0]  # RGB red

        out_rgb = (img_rgb * (1 - alpha) + overlay * alpha).astype(np.uint8)
        frames.append(Image.fromarray(out_rgb))

    # save GIF
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=int(1000 / fps),  # ms per frame
        loop=0,
    )

    return gif_path

In [None]:
if 'train_data' in globals():
    n_samples = min(5, len(train_data))
    print(f"Processing {n_samples} videos from train_data")

    for idx in range(n_samples):
        vd = train_data[idx]
        name = vd.get('name', f'sample_{idx}')

        print(f"\n[{idx+1}/{n_samples}] Processing video '{name}'")
        start = time.time()
        preds, confs = process_video_dict(vd)
        print('  Done in', time.time() - start, 's, produced', len(preds), 'frames')

        gif_path = f"{name}_overlay.gif"
        make_overlay_gif(vd, preds, gif_path, fps=20, alpha=0.4)
        print("  Saved GIF to", gif_path)

        # show the GIF inline
        display(IPyImage(filename=gif_path))

else:
    print('No train.pkl loaded; skip demo.')

## Save pseudo-labels for all training videos
We save per-video per-frame masks as compressed numpy .npz files. You can then use these to train a segmentation model.
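
For reference, this is how such a file round-trips through NumPy. The arrays here are synthetic and the file name is made up; in the notebook the `conf` array has the spatial size of the bbox crop when cropping is enabled, whereas this sketch uses the same shape for both arrays for simplicity.

```python
import os
import tempfile

import numpy as np

H, W, T = 32, 48, 5
masks = np.zeros((H, W, T), dtype=np.uint8)
masks[10:20, 15:30, :] = 1                        # synthetic valve region
conf = np.full((H, W, T), 0.9, dtype=np.float32)  # synthetic confidence maps

path = os.path.join(tempfile.mkdtemp(), "dummy_video.npz")
np.savez_compressed(path, masks=masks, conf=conf)

# Arrays are stored under the keyword names used when saving
with np.load(path) as data:
    masks_loaded = data["masks"]
    conf_loaded = data["conf"]

print(masks_loaded.shape, masks_loaded.dtype, conf_loaded.dtype)
```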

In [None]:
if 'train_data' in globals():
    out_dir = Path(SAVE_PSEUDOLABELS_DIR)
    out_dir.mkdir(exist_ok=True)
    for i, vd in enumerate(train_data):
        name = vd.get('name', f'video_{i}')
        print(f'[{i+1}/{len(train_data)}] processing', name)
        try:
            preds, confs = process_video_dict(vd)
            # save masks as uint8 per-frame stack
            arr = np.stack(preds, axis=-1).astype(np.uint8)  # H,W,T
            # confidences as float32 stack
            conf_arr = np.stack(confs, axis=-1).astype(np.float32)
            np.savez_compressed(out_dir / f'{name}.npz', masks=arr, conf=conf_arr)
        except Exception as e:
            print('Error on', name, e)
    print('Saved pseudo-labels to', out_dir)
else:
    print('train.pkl not loaded. Skipping saving pseudo-labels.')


In [None]:
!cp -r /content/pseudo_labels /content/drive/MyDrive/

In [None]:
import os
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

pseudo_dir = SAVE_PSEUDOLABELS_DIR  # e.g. "data/pseudo_labels"
all_feats = []

# map name -> vd so we can recover the original video for each npz file
name_to_vd = {}
for vd in train_data:
    name = vd.get("name", None)
    if name is not None:
        name_to_vd[name] = vd

for fname in os.listdir(pseudo_dir):
    if not fname.endswith(".npz"):
        continue

    name = os.path.splitext(fname)[0]  # strip only the .npz suffix (names may contain dots)
    if name not in name_to_vd:
        print(f"Warning: no vd for {name}, skipping.")
        continue

    vd = name_to_vd[name]
    npz = np.load(os.path.join(pseudo_dir, fname))
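    masks = npz["masks"]  # [H, W, T], full-frame
    confs = npz["conf"]   # [Hc, Wc, T], crop-sized when the bbox crop is enabled
    H, W, T = masks.shape

    for t in range(T):
        # optional: only use reasonably confident frames for training the GMM.
        # Note: with two classes the max probability is always >= 0.5, so the
        # 0.2 threshold below never triggers; raise it to actually filter.
        if confs[..., t].mean() < 0.2:
            continue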

        vec = dino_mask_feature_for_frame(vd, t, masks[..., t])
        all_feats.append(vec)

all_feats = np.stack(all_feats, axis=0)  # [N_frames, D]
print("DINO mask features shape:", all_feats.shape)


In [None]:
# --- standardize ---
scaler = StandardScaler()
X_scaled = scaler.fit_transform(all_feats)

# --- PCA to reduce dimension (tune n_components) ---
pca = PCA(n_components=32, random_state=0)
X_pca = pca.fit_transform(X_scaled)

# --- GMM on PCA space ---
gmm = GaussianMixture(
    n_components=3,          # tune (2–5 usually fine)
    covariance_type="full",
    random_state=0,
)
gmm.fit(X_pca)

# log-likelihood for training points
logp = gmm.score_samples(X_pca)
print("logp range:", logp.min(), logp.max())

# choose an anomaly threshold, e.g. keep 95% most likely
threshold = np.percentile(logp, 5)
print("anomaly threshold:", threshold)


In [None]:
@torch.no_grad()
def dino_gmm_anomaly_score(vd, t, mask_full):
    vec = dino_mask_feature_for_frame(vd, t, mask_full)
    Xs = scaler.transform(vec[None, :])
    Xp = pca.transform(Xs)
    logp = gmm.score_samples(Xp)[0]
    return logp

def is_outlier(vd, t, mask_full):
    logp = dino_gmm_anomaly_score(vd, t, mask_full)
    return logp < threshold


In [None]:
# Collect (video dict, frame index) pairs that pass both filters,
# e.g. to populate the items of a training Dataset.
items = []
for vd in train_data:
    name = vd.get("name", None)
    if name is None:
        continue
    npz_path = os.path.join(SAVE_PSEUDOLABELS_DIR, f"{name}.npz")
    if not os.path.exists(npz_path):
        continue
    data = np.load(npz_path)
    masks = data["masks"]  # [H, W, T]
    confs = data["conf"]   # crop-sized confidence maps, one per frame

    H, W, T = masks.shape
    for t in range(T):
        # optional: basic confidence filter first
        if confs[..., t].mean() < 0.2:
            continue

        # GMM filter
        logp = dino_gmm_anomaly_score(vd, t, masks[..., t])
        if logp < threshold:
            continue  # skip anomalous frame

        items.append((vd, t))

print("Kept", len(items), "frames after filtering")


## (Optional) Training a segmentation network on pseudo-labels
Below is a compact skeleton for training a UNet-like model. It is intentionally short; replace it with your preferred architecture and training loop. Use strong data augmentation and validate on an expert-labeled subset.
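
One thing to get right when augmenting segmentation data is that geometric transforms must be applied identically to the image and its mask, while intensity changes apply to the image only. The following NumPy sketch shows this with a random crop plus a brightness gain; the function name `augment_pair`, the crop size and the gain range are illustrative assumptions, not tuned values.

```python
import numpy as np

def augment_pair(image, mask, crop_hw, rng):
    """Random crop applied to image and mask together, plus intensity jitter on the image."""
    H, W = image.shape
    ch, cw = crop_hw
    top = int(rng.integers(0, H - ch + 1))
    left = int(rng.integers(0, W - cw + 1))
    img = image[top:top + ch, left:left + cw].astype(np.float32)
    msk = mask[top:top + ch, left:left + cw].copy()     # same window as the image
    gain = rng.uniform(0.8, 1.2)                        # brightness jitter, image only
    img = np.clip(img * gain, 0, 255).astype(np.uint8)
    return img, msk

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(64, 80), dtype=np.uint8)
mask = np.zeros((64, 80), dtype=np.uint8)
mask[20:40, 30:50] = 1
img_aug, mask_aug = augment_pair(image, mask, crop_hw=(48, 48), rng=rng)
print(img_aug.shape, mask_aug.shape, int(mask_aug.sum()))
```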

In [None]:
class SmallUNet(nn.Module):
    # very small UNet skeleton for demonstration; replace with proven arch
    def __init__(self, in_ch=3, out_ch=2):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32,32,3,padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)
        self.enc2 = nn.Sequential(nn.Conv2d(32,64,3,padding=1), nn.ReLU(), nn.Conv2d(64,64,3,padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(64,32,2,stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(64,32,3,padding=1), nn.ReLU(), nn.Conv2d(32,32,3,padding=1), nn.ReLU())
        self.head = nn.Conv2d(32,out_ch,1)
    def forward(self,x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        u = self.up(e2)
        d = self.dec1(torch.cat([u,e1], dim=1))
        return self.head(d)

# NOTE: training loop is left minimal. For a real experiment, implement dataset loader, augmentation,
# optimizer, scheduler, proper validation and checkpointing. We intentionally do not run training here.

print('UNet skeleton ready (not training in notebook by default).')


## Notes, tips and next steps
- If DINOv3 features do not match well on echocardiography, try preprocessing: CLAHE, histogram equalization, contrast adjustments, or a small learned colorization network before DINOv3.
- If memory is an issue, reduce SHORT_SIDE or use a smaller DINOv3 model (dinov3_vits16).
- To improve robustness, use all 3 annotated frames as context (done above), and optionally perform forward/backward propagation and average results.
- When creating pseudo-labels for training, filter by confidence per-frame or per-pixel, and prefer expert-labeled videos or give them higher weight.
- For test inference, either: (A) run propagation with the provided 3 context frames (if test has context), or (B) use a trained segmentation model for speed and consistency.
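
The confidence filtering suggested above could look roughly like the sketch below. It drops whole frames with low mean confidence and marks low-confidence pixels with an ignore value, which a loss such as `torch.nn.functional.cross_entropy` can skip through its `ignore_index` argument. The thresholds, the value 255 and the function name `filter_by_confidence` are assumptions, and the mask and confidence map are assumed to have the same shape (crop both to the same region first).

```python
import numpy as np

IGNORE_INDEX = 255  # hypothetical label value for pixels excluded from the loss

def filter_by_confidence(mask, conf, pixel_thresh=0.8, frame_thresh=0.7):
    """Return a training target with uncertain pixels ignored, or None to drop the frame."""
    if float(conf.mean()) < frame_thresh:
        return None                             # per-frame filter
    target = mask.astype(np.uint8).copy()
    target[conf < pixel_thresh] = IGNORE_INDEX  # per-pixel filter
    return target

mask = np.array([[0, 1], [1, 0]], dtype=np.uint8)
conf = np.array([[0.95, 0.90], [0.60, 0.99]], dtype=np.float32)
target = filter_by_confidence(mask, conf)
print(target)

low_conf = np.full((2, 2), 0.55, dtype=np.float32)
print(filter_by_confidence(mask, low_conf))  # None: the whole frame is dropped
```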

Possible next steps:
- Adapt this notebook to run distributed or tiled propagation for very large crops to save memory.
- Implement the full training loop and data loader that consume the generated pseudo-labels and train a UNet with augmentation.
- Add detailed visualization utilities (top-k patch matches, per-patch similarity heatmaps) to debug propagation failures.