In [None]:
# Load the olive-disease query (test) set and the few-shot support (memory) set.
import os
import random
import sys

import numpy as np
import torch

# Make the project package importable. The notebook is assumed to live in
# <project_root>/notebooks/, so the project root is the parent of the cwd.
# insert(0) rather than append so this checkout takes precedence over any
# installed copy of `no_time_to_train`.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import the dataset classes
from no_time_to_train.dataset.coco_ref_dataset import (
    COCOMemoryFillCropDataset, COCORefOracleTestDataset)
from no_time_to_train.dataset.few_shot_sampling import sample_memory_dataset

# --- Configuration ---
SEED = 42
SHOTS = 10                       # support examples per class
DATASET_NAME = "olive_diseases"  # data folder name, class_split and sampler dataset key
IMAGE_SIZE = 1024
DATA_DIR = os.path.join(project_root, "data", DATASET_NAME)

# Seed Python, NumPy and torch RNGs before any sampling happens.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

query_root = os.path.join(DATA_DIR, "val2017")
query_json_file = os.path.join(DATA_DIR, "annotations", "instances_val2017.json")
support_root = os.path.join(DATA_DIR, "train2017")
support_json_file = os.path.join(DATA_DIR, "annotations", "instances_train2017.json")

# Fail fast with a clear message instead of a deep error inside the dataset classes.
for required_path in (query_root, query_json_file, support_root, support_json_file):
    if not os.path.exists(required_path):
        raise FileNotFoundError(f"Dataset path not found: {required_path}")

# --- 1. Test dataset (query set) configuration ---
query_set_config = {
    "root": query_root,
    "json_file": query_json_file,
    "image_size": IMAGE_SIZE,
    "norm_img": False,
    "class_split": DATASET_NAME,
    "with_query_points": False,
}

# --- 2. Support dataset (memory set) configuration ---
support_pkl_path = os.path.join(project_root, f"work_dirs/olive_results/olive_{SHOTS}shot_seed{SEED}.pkl")

# The few-shot split is cached on disk; generate it only on the first run.
if not os.path.exists(support_pkl_path):
    print(f"Generating few-shot split at {support_pkl_path}...")
    os.makedirs(os.path.dirname(support_pkl_path), exist_ok=True)
    sample_memory_dataset(
        json_file=support_json_file,
        out_path=support_pkl_path,
        memory_length=SHOTS,
        remove_bad=True,
        dataset=DATASET_NAME,
    )
else:
    print(f"Found existing few-shot split at {support_pkl_path}")

support_set_config = {
    "root": support_root,
    "json_file": support_json_file,
    "memory_pkl": support_pkl_path,
    "class_split": DATASET_NAME,
    "image_size": IMAGE_SIZE,
    "memory_length": SHOTS,
    "context_ratio": 0.2,
    "norm_img": False,
}

# --- 3. Build both datasets (support first, as before) ---
print("\n--- Loading Support Dataset (Reference Set) ---")
support_set = COCOMemoryFillCropDataset(**support_set_config)
print("--- Loading Test Dataset (Query Set) ---")
query_set = COCORefOracleTestDataset(**query_set_config)

print(f"Support set size: {len(support_set)}")
print(f"Query set size: {len(query_set)}")

In [None]:
# Collect every support image (len(support_set) == number of classes * SHOTS)
# together with the metadata needed later to prompt SAM3 with its mask.
all_support_images = []
all_support_metadata = []

for i in range(len(support_set)):
    refs = support_set[i]['refs_by_cat']
    # Each memory item is expected to hold exactly one category; later cells map
    # one video frame to one category, so make that assumption explicit.
    assert len(refs) == 1, f"Support item {i} has {len(refs)} categories; expected 1"
    cat_ind = next(iter(refs))  # internal category index
    ref = refs[cat_ind]

    # 'imgs' carries a leading dim that [0] drops (presumably a batch dim of 1),
    # leaving a (C, H, W) tensor that is permuted to (H, W, C) for video writing.
    img_np = ref['imgs'][0].permute(1, 2, 0).numpy()
    all_support_images.append(img_np)

    all_support_metadata.append({
        "index": i,
        "cat_ind": cat_ind,
        # Convert the internal index back to the real COCO category ID.
        "category_id": support_set.cat_inds_to_ids[cat_ind],
        "image_id": ref['img_info'][0]['id'],
    })

print(f"Collected {len(all_support_images)} support images.")
print(f"First image shape: {all_support_images[0].shape}")
print(f"Metadata example: {all_support_metadata[0]}")


In [None]:
# Visualize a few support images with their category and image IDs.
import matplotlib.pyplot as plt


def visualize_support_set(images, metadata, num_images=10):
    """Show the first ``num_images`` support images side by side.

    Args:
        images: sequence of images accepted by ``imshow`` (e.g. (H, W, 3) arrays).
        metadata: sequence of dicts with 'category_id' and 'image_id', aligned
            with ``images``.
        num_images: how many images to show; clamped to the number available so
            asking for more than exist no longer raises IndexError.
    """
    num_images = min(num_images, len(images), len(metadata))
    if num_images <= 0:
        print("No support images to visualize.")
        return

    # squeeze=False keeps `axes` 2-D even for a single image.
    fig, axes = plt.subplots(1, num_images, figsize=(20, 4), squeeze=False)
    for ax, img, meta in zip(axes[0], images, metadata):
        ax.imshow(img)
        ax.set_title(f"Cat ID: {meta['category_id']}\nImage ID: {meta['image_id']}")
        ax.axis('off')
    fig.tight_layout()
    plt.show()


visualize_support_set(all_support_images, all_support_metadata, num_images=5)

In [None]:
import cv2
import numpy as np
import torch

# 1. Get a query image
query_idx = 0
query_item = query_set[query_idx]

# The query item layout depends on the dataset class; accept both known layouts.
if 'target_img' in query_item:
    query_img_tensor = query_item['target_img']
elif 'refs_by_cat' in query_item:
    q_refs = query_item['refs_by_cat']
    q_cat_ind = list(q_refs.keys())[0]
    query_img_tensor = q_refs[q_cat_ind]['imgs'][0]
else:
    raise ValueError(f"Unknown item structure. Keys: {query_item.keys()}")

# (C, H, W) tensor -> (H, W, C) numpy array
if isinstance(query_img_tensor, torch.Tensor):
    query_img_np = query_img_tensor.permute(1, 2, 0).numpy()
else:
    query_img_np = query_img_tensor


# 2. Helper to convert images to uint8 for video encoding
def to_uint8(img):
    """Return ``img`` as a uint8 array suitable for cv2.VideoWriter.

    uint8 input is returned unchanged (clipping it to [0, 1] would binarize it).
    Any other dtype is treated as a float image in [0, 1]: values are clipped to
    that range and scaled to [0, 255].
    """
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img
    return (np.clip(img, 0, 1) * 255).astype(np.uint8)


# 3. Frame sequence: support frames (repeated NUM_REPEATS times), query frame last.
NUM_REPEATS = 1
support_frames = [to_uint8(img) for img in all_support_images]
query_frame = to_uint8(query_img_np)
video_frames = support_frames * NUM_REPEATS + [query_frame]

# 4. Save video using OpenCV. Note: mp4v is a lossy codec, so the frames SAM3
# reads back from this file are slightly degraded re-encodings.
output_path = "support_query_sequence.mp4"
fps = 5

if len(video_frames) > 0:
    height, width = video_frames[0].shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not video.isOpened():
        raise RuntimeError(f"cv2.VideoWriter could not open {output_path} with codec 'mp4v'")

    try:
        for frame_i, frame in enumerate(video_frames):
            # VideoWriter drops frames whose size differs from (width, height)
            # without raising, which would silently lose e.g. the query frame.
            # Resize explicitly instead (overlays on the original-size image
            # will then not line up pixel-for-pixel).
            if frame.shape[:2] != (height, width):
                print(f"Warning: frame {frame_i} has size {frame.shape[:2]}, "
                      f"resizing to {(height, width)}")
                frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
            # cv2 expects BGR
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        video.release()

    print(f"Created video with {len(video_frames)} frames.")
    print(f" - Support images: {len(all_support_images)} (repeated {NUM_REPEATS} times)")
    print(f" - Query image: 1 (at the end)")
    print(f"Saved to: {output_path}")
else:
    print("No frames to save.")

In [None]:
import importlib.util

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Patch
from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
from transformers.video_utils import load_video

# --- Configuration ---
# This notebook runs SAM3 on a CUDA GPU only; fail early with a clear message.
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA GPU is required to run SAM3 in this notebook.")
device = torch.device("cuda")

# bfloat16 where the GPU supports it (Ampere and newer), otherwise float16.
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# flash_attention_2 needs the optional `flash-attn` package. Without it, let
# transformers choose its default attention implementation instead of failing.
if importlib.util.find_spec("flash_attn") is not None:
    attn_kwargs = {"attn_implementation": "flash_attention_2"}
else:
    attn_kwargs = {}
    print("flash-attn not installed; using the default attention implementation.")

# 1. Initialize Model
print(f"Loading SAM3 model on {device} with {dtype}...")
model = Sam3TrackerVideoModel.from_pretrained(
    "facebook/sam3",
    torch_dtype=dtype,
    **attn_kwargs,
).to(device)

processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")

In [None]:
# 2. Load the support+query video written earlier back as a list of frames
# (the second return value of load_video is unused here).
video_path = "support_query_sequence.mp4"
video_frames, _ = load_video(video_path)
print(f"Loaded video with {len(video_frames)} frames.")

# 3. Init Session
# Pass the model's dtype (bfloat16 or float16, chosen when the model was loaded)
# so the session's tensors match the model's precision.
inference_session = processor.init_video_session(
    video=video_frames,
    inference_device=device,
    processing_device=device,
    dtype=dtype, 
)

# 4. Add Support Prompts (Masks)
# Every support frame is prompted with its ground-truth mask. Support frames were
# written in all_support_metadata order before the query frame, so the position
# in all_support_metadata is the frame index (first repetition only if NUM_REPEATS > 1).
print("Adding support prompts (masks)...")

# Target mask resolution: 1024x1024 unless the image processor exposes a
# height/width in its `size` config.
# NOTE(review): assumes `image_processor.size` supports dict-style `.get` and is
# not None — confirm for this processor version.
target_h, target_w = 1024, 1024
if hasattr(processor, "image_processor") and hasattr(processor.image_processor, "size"):
     size_conf = processor.image_processor.size
     target_h = size_conf.get("height", 1024)
     target_w = size_conf.get("width", 1024)

# inference_mode disables autograd tracking (lower overhead than no_grad)
with torch.inference_mode():
    for idx, meta in enumerate(all_support_metadata):
        frame_idx = idx
        
        # Re-fetch the support item to get its ground-truth mask.
        # NOTE(review): this indexes support_set a second time; if __getitem__ is
        # not deterministic, this mask may not match the image encoded in the video.
        ds_item = support_set[meta['index']]
        gt_mask = ds_item['refs_by_cat'][meta['cat_ind']]['masks']
        
        if isinstance(gt_mask, torch.Tensor):
            gt_mask = gt_mask.numpy()
        # A 3-D mask keeps only its first slice (presumably a singleton leading
        # dim — TODO confirm a support item never carries several instances).
        if gt_mask.ndim == 3:
            gt_mask = gt_mask[0]

        # --- Mask preprocessing on the GPU ---
        mask_tensor = torch.from_numpy(gt_mask).to(device)
        
        # Add batch/channel dims: (H, W) -> (1, 1, H, W)
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).float()

        # Nearest-neighbour resize to the target resolution (keeps values binary).
        # NOTE(review): the processor may also resize input masks itself; if so,
        # this step is redundant — confirm.
        if mask_tensor.shape[-2] != target_h or mask_tensor.shape[-1] != target_w:
            mask_tensor = torch.nn.functional.interpolate(
                mask_tensor, 
                size=(target_h, target_w), 
                mode='nearest'
            )
        
        # Binarize and cast to the session dtype
        mask_tensor = (mask_tensor > 0).to(dtype)
        # --- End of mask preprocessing ---

        # One tracked object per category: every support frame of the same class
        # prompts the same object id (category index offset by 1).
        obj_id = meta['cat_ind'] + 1
        
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=frame_idx,
            obj_ids=[obj_id],
            input_masks=mask_tensor
        )
        
        # Run the model on this frame so the new mask prompt is encoded into the session
        model(
            inference_session=inference_session,
            frame_idx=frame_idx,
        )
        
        if idx % 10 == 0:
            print(f"Processed support frame {idx}/{len(all_support_metadata)}")

print(f"Added prompts for {len(all_support_metadata)} support frames. Propagating...")

# 5. Propagate
# frame_idx -> pred_masks exactly as returned by the model; they are
# post-processed to video resolution later, only for the frame that is inspected.
video_segments = {}

# Run propagation in inference mode
with torch.inference_mode():
    for output in model.propagate_in_video_iterator(inference_session):
        video_segments[output.frame_idx] = output.pred_masks

In [None]:
# 6. Visualize the prediction on the query frame (the last video frame) next to
# the ground truth, using one colour per category in both panels.
query_frame_idx = len(video_frames) - 1
print(f"--- Result for Query Frame {query_frame_idx} ---")


def _merge_instance_masks(masks):
    """Union of per-instance masks as a boolean array.

    ``masks`` is a torch tensor or array of shape (H, W) or (N_inst, H, W); a
    3-D input is collapsed over its first axis. Works for both numpy and torch.
    """
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    masks = np.asarray(masks) > 0
    return masks.any(axis=0) if masks.ndim == 3 else masks


def _overlay_mask(ax, mask, color, alpha=0.5):
    """Draw boolean (H, W) ``mask`` on ``ax`` as a semi-transparent ``color`` layer."""
    rgba = np.zeros((mask.shape[0], mask.shape[1], 4))
    rgba[mask] = color
    rgba[..., 3] = mask * alpha
    ax.imshow(rgba)


def _pred_cat_inds(session, num_preds):
    """Category index for each prediction row.

    Prompts used obj_id = cat_ind + 1. Prediction rows presumably follow the
    session's object order, which need not be sorted by category, so use
    ``session.obj_ids`` when it lists exactly one id per row. Otherwise fall
    back to assuming row i is obj_id i + 1 (category i).
    """
    obj_ids = list(getattr(session, "obj_ids", None) or [])
    if len(obj_ids) != num_preds:
        obj_ids = list(range(1, num_preds + 1))
    return [obj_id - 1 for obj_id in obj_ids]


def _class_label(cat_ind):
    """Legend label 'Class <COCO id> (<internal index>)' for a category index."""
    real_cat_id = support_set.cat_inds_to_ids[cat_ind] if hasattr(support_set, 'cat_inds_to_ids') else f"ind_{cat_ind}"
    return f"Class {real_cat_id} ({cat_ind})"


if query_frame_idx in video_segments:
    masks_logits = video_segments[query_frame_idx]

    # Post-process the stored model masks to the video resolution and binarize.
    video_res_masks = processor.post_process_masks(
        [masks_logits],
        original_sizes=[[inference_session.video_height, inference_session.video_width]],
        binarize=True
    )[0]
    # Indexed as [object, 0, H, W]; kept as a module-level name for the IoU cell.
    video_res_masks_np = video_res_masks.cpu().numpy()
    pred_cat_inds = _pred_cat_inds(inference_session, video_res_masks_np.shape[0])

    gt_anns = query_item.get('tar_anns_by_cat', None)
    if gt_anns:
        fig, (ax_pred, ax_gt) = plt.subplots(1, 2, figsize=(20, 10))
    else:
        fig, ax_pred = plt.subplots(1, 1, figsize=(10, 10))
        ax_gt = None

    cmap = plt.get_cmap('tab10')

    # --- Plot Prediction (colour keyed by category so it matches the GT panel) ---
    ax_pred.imshow(query_img_np)
    pred_handles = []
    for i, cat_ind in enumerate(pred_cat_inds):
        mask = video_res_masks_np[i, 0] > 0
        if mask.any():
            color = np.array(cmap(cat_ind % 10))
            _overlay_mask(ax_pred, mask, color)
            pred_handles.append(Patch(color=color, label=_class_label(cat_ind)))

    if pred_handles:
        ax_pred.legend(handles=pred_handles, loc="upper right", frameon=True)
    else:
        print("No disease detected on query frame.")

    ax_pred.set_title(f"Predicted Disease on Query Image\n(Frame {query_frame_idx})")
    ax_pred.axis('off')

    # --- Plot Ground Truth ---
    if ax_gt is not None:
        ax_gt.imshow(query_img_np)
        gt_handles = []
        for cat_ind, ann_data in gt_anns.items():
            gt_mask = _merge_instance_masks(ann_data['masks'])
            if gt_mask.any():
                color = np.array(cmap(cat_ind % 10))
                _overlay_mask(ax_gt, gt_mask, color)
                gt_handles.append(Patch(color=color, label=_class_label(cat_ind)))

        if gt_handles:
            ax_gt.legend(handles=gt_handles, loc="upper right", frameon=True)

        ax_gt.set_title("Ground Truth Annotations")
        ax_gt.axis('off')

    plt.tight_layout()
    plt.show()

else:
    print("Error: No prediction propagated to the last frame.")

In [None]:
# Evaluate IoU against Ground Truth
if 'video_res_masks_np' in locals():
    print(f"\n--- Quantitative Evaluation ---")
    
    # 1. Helper function
    def calculate_iou(pred_mask, gt_mask):
        intersection = np.logical_and(pred_mask, gt_mask).sum()
        union = np.logical_or(pred_mask, gt_mask).sum()
        if union == 0:
            return 1.0 if intersection == 0 else 0.0 
        return intersection / union

    # 2. Get GT
    # query_item was loaded earlier
    gt_anns = query_item.get('tar_anns_by_cat', None)
    
    if gt_anns:
        ious = []
        
        # video_res_masks_np has shape (N_tracked, 1, H, W)
        # N_tracked should match the number of unique categories we prompted (or at least queried)
        print(f"Evaluating {video_res_masks_np.shape[0]} predictions...")
        
        for i in range(video_res_masks_np.shape[0]):
            pred_mask = video_res_masks_np[i, 0] > 0
            
            # Assuming implicit mapping: prediction i corresponds to obj_id i+1 -> category index i
            # This relies on objects being tracked with IDs 1..N and returned in order
            cat_ind = i 
            
            gt_mask = np.zeros_like(pred_mask, dtype=bool)
            
            if cat_ind in gt_anns:
                # gt_anns[cat_ind]['masks'] is typically a tensor (N_inst, H, W)
                gt_tensor = gt_anns[cat_ind]['masks']
                
                # Convert to numpy and merge instances (semantic segmentation style)
                if gt_tensor.ndim == 3:
                     gt_encoded = gt_tensor.sum(dim=0).cpu().numpy() > 0
                else:
                     gt_encoded = gt_tensor.cpu().numpy() > 0
                gt_mask = gt_encoded
                
            iou = calculate_iou(pred_mask, gt_mask)
            ious.append(iou)
            
            # Label
            cat_id = support_set.cat_inds_to_ids[cat_ind] if hasattr(support_set, 'cat_inds_to_ids') else f"ind_{cat_ind}"
            print(f"Class {cat_ind} (ID {cat_id}): IoU = {iou:.4f}")
            
        print(f"Mean IoU: {np.mean(ious):.4f}")
        
    else:
        print("Ground Truth annotations (tar_anns_by_cat) not found in 'query_item'.")
        print("Available keys:", query_item.keys())
else:
    print("No predictions found (video_res_masks_np not defined). Run previous cell first.")