In [None]:
# Cell 1: Imports and User-Defined Paths
import sys
import os
import torch
import numpy as np
import cv2 # For video writing and image manipulation
import json # For config display
from PIL import Image # For mask saving if needed, or consistency
from tqdm import tqdm

# Make the project importable no matter where the notebook was launched from.
# The notebook is expected to live either in aot_plus/ or in its parent directory.
if '.' not in sys.path:
    sys.path.insert(0, '.')

# Launched from inside aot_plus/ -> the project root is the parent directory;
# otherwise the launch directory itself is treated as the project root.
if os.path.basename(os.getcwd()) == 'aot_plus':
    project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
else:
    project_root = os.path.abspath('.')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Launched from the parent directory -> step into aot_plus/ so that the relative
# paths configured below resolve against it.
if os.path.exists('aot_plus') and os.path.basename(os.getcwd()) != 'aot_plus':
    print("Changing CWD to 'aot_plus'")
    os.chdir('aot_plus')

# --- User-configurable paths ---
# !!! IMPORTANT: point this at your own fine-tuned checkpoint (the value below is only an example) !!!
FINETUNED_MODEL_CKPT_PATH = "./results/finetune_extracted_notebook_R50_AOTL_Temp_pe_Slot_4/default/ckpt/save_step_10000.pth"
# Directory that receives the output videos and masks.
OUTPUT_VIDEO_DIR = "./evaluation_videos"
# Root of the evaluation data (e.g. extracted frames).
EVALUATION_DATA_ROOT = "./extracted_frames/"

# --- GPU Configuration ---
GPU_ID = 0

# --- Create output directory (no-op if it already exists) ---
os.makedirs(OUTPUT_VIDEO_DIR, exist_ok=True)

# Echo the effective settings so a mis-configured path is visible immediately.
print(f"Current working directory: {os.getcwd()}")
print(f"Finetuned model checkpoint: {FINETUNED_MODEL_CKPT_PATH}")
print(f"Output video directory: {OUTPUT_VIDEO_DIR}")
print(f"Evaluation data root: {EVALUATION_DATA_ROOT}")
checkpoint_is_present = os.path.exists(FINETUNED_MODEL_CKPT_PATH)
if not checkpoint_is_present:
    print(f"WARNING: Finetuned model checkpoint not found at {FINETUNED_MODEL_CKPT_PATH}")

In [None]:
# Cell 2: Load Configuration
from tools.get_config import get_config # Re-using get_config from tools
from utils.utils import Tee, make_log_dir # For potential logging setup (optional for notebook)

# --- Basic Configuration Parameters (can be adjusted) ---
# These should match the training configuration the model was fine-tuned under,
# especially the model-specific parts. For evaluation we load a base config and
# then override the test-specific fields below.
EXP_NAME_FOR_CONFIG = "default_eval"  # Can be generic for eval
MODEL_NAME_STR_FOR_CONFIG = "r50_aotl"  # Should match the fine-tuned model's architecture
STAGE_STR_FOR_CONFIG = "default"  # Or the stage used for training if it affects model structure

# Load base configuration
cfg = get_config(STAGE_STR_FOR_CONFIG, EXP_NAME_FOR_CONFIG, MODEL_NAME_STR_FOR_CONFIG)

# --- Override with Evaluation-Specific Settings ---
cfg.TEST_CKPT_PATH = FINETUNED_MODEL_CKPT_PATH
cfg.TEST_GPU_ID = GPU_ID
cfg.TEST_GPU_NUM = 1  # Single GPU for notebook evaluation
cfg.DIST_ENABLE = False  # Ensure non-distributed mode

# Settings for AOTInferEngine (can be adjusted if needed)
cfg.TEST_LONG_TERM_MEM_GAP = 9999
# cfg.MODEL_MAX_OBJ_NUM is deliberately left at the value loaded by get_config();
# it must match the value used during training.

# Make sure a dataset entry for the extracted frames exists and points at the
# evaluation data. The dataset itself is instantiated directly later on, so this
# entry mainly keeps the config self-consistent.
if 'EXTRACTED_FRAMES' not in cfg.DATASET_CONFIGS:
    cfg.DATASET_CONFIGS["EXTRACTED_FRAMES"] = {
        "TYPE": "ExtractedFramesTrain",
        "CONFIG": {
            "COMMON": {
                "DATA_IMG_DIR": EVALUATION_DATA_ROOT,
                "DATA_ANNO_DIR": EVALUATION_DATA_ROOT,
                # Same source of truth as the pre-existing-entry branch below, so
                # both branches end up with the same sequence length.
                "SEQ_LEN": cfg.DATA_SEQ_LEN,
                "MAX_OBJ_NUM": cfg.MODEL_MAX_OBJ_NUM,
                "OUTPUT_SIZE": cfg.DATA_RANDOMCROP,  # Example, might need a specific eval size
            },
            "TRAIN": {"RGB": True}  # Example
        }
    }
else:
    # Entry already exists: redirect it to the evaluation data and align SEQ_LEN.
    extracted_frames_common_cfg = cfg.DATASET_CONFIGS['EXTRACTED_FRAMES']['CONFIG']['COMMON']
    extracted_frames_common_cfg['DATA_IMG_DIR'] = EVALUATION_DATA_ROOT
    extracted_frames_common_cfg['DATA_ANNO_DIR'] = EVALUATION_DATA_ROOT
    extracted_frames_common_cfg['SEQ_LEN'] = cfg.DATA_SEQ_LEN


# Display some key config values
print(f"Using checkpoint: {cfg.TEST_CKPT_PATH}")
print(f"Evaluation on GPU: {cfg.TEST_GPU_ID}")
print(f"Model: {cfg.MODEL_NAME}")
# print(json.dumps(cfg.__dict__, indent=2, default=str)) # For full config display

In [None]:
# Cell 3: Model and Inference Engine Loading
from networks.models import build_vos_model
from utils.checkpoint import load_network
from networks.engines.aot_engine import AOTInferEngine # Ensure this is the correct inference engine

# Pick the compute device: the configured GPU when CUDA is available, otherwise CPU.
device = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)
print(f"Using device: {device}")

# Build model
model = build_vos_model(cfg.MODEL_VOS, cfg)

# Load fine-tuned weights. Fail with a real exception rather than `assert`,
# because assert statements are removed when Python runs with -O.
if not (cfg.TEST_CKPT_PATH and os.path.exists(cfg.TEST_CKPT_PATH)):
    raise FileNotFoundError(
        f"Checkpoint not found at {cfg.TEST_CKPT_PATH}. "
        "Please set FINETUNED_MODEL_CKPT_PATH correctly."
    )

# NOTE(review): a torch.device is passed as the third argument; confirm that this
# project's load_network accepts a device object rather than an integer GPU index.
model, removed_dict = load_network(model, cfg.TEST_CKPT_PATH, device)
if len(removed_dict) > 0:
    print(f"Removed {len(removed_dict)} keys from checkpoint: {removed_dict}")
print(f"Successfully loaded model weights from: {cfg.TEST_CKPT_PATH}")

model.eval()  # Set to evaluation mode
model = model.to(device)  # Ensure model is on the correct device

# Build inference engine.
# long_term_mem_gap comes from the config when present and falls back to 9999.
# NOTE(review): gpu_id is passed even when `device` is the CPU fallback; confirm
# how AOTInferEngine behaves without CUDA.
engine = AOTInferEngine(
    aot_model=model,
    gpu_id=GPU_ID,
    long_term_mem_gap=getattr(cfg, 'TEST_LONG_TERM_MEM_GAP', 9999),
)
engine.eval()

print("Model and Inference Engine loaded successfully.")
print(f"Model is on device: {next(model.parameters()).device}")

In [None]:
# Cell 4: Dataset and DataLoader for Evaluation
from dataloaders.train_datasets import ExtractedFramesTrain # Using the refactored version
from dataloaders.video_transforms import ToTensor as VideoToTensor # The custom ToTensor from video_transforms
from torchvision import transforms as tv_transforms # Standard torchvision transforms

# Evaluation transform: only the project's own ToTensor from video_transforms.
# NOTE(review): per the original author's note this transform also normalizes the
# sample dict; confirm against dataloaders/video_transforms.py.
eval_transforms = tv_transforms.Compose([
    VideoToTensor()
])

# The TRAIN sub-config may be missing on a pre-existing EXTRACTED_FRAMES entry;
# fall back to an empty dict so the .get() defaults below apply instead of a KeyError.
extracted_frames_train_cfg = cfg.DATASET_CONFIGS['EXTRACTED_FRAMES']['CONFIG'].get('TRAIN', {})

# Instantiate the dataset. Each sample is a sequence of cfg.DATA_SEQ_LEN frames
# that the inference loop later feeds to the engine one frame at a time.
eval_dataset = ExtractedFramesTrain(
    image_root=EVALUATION_DATA_ROOT,
    transform=eval_transforms,
    rgb=extracted_frames_train_cfg.get('RGB', True),
    seq_len=cfg.DATA_SEQ_LEN,
    max_obj_n=cfg.MODEL_MAX_OBJ_NUM,
    repeat_time=1,  # For evaluation, process each sequence once
    ignore_thresh=extracted_frames_train_cfg.get('IGNORE_THRESH', 0.0)
)

if len(eval_dataset) == 0:
    print(f"WARNING: Evaluation dataset is empty. Check EVALUATION_DATA_ROOT ({EVALUATION_DATA_ROOT}) and seq_len ({cfg.DATA_SEQ_LEN}).")
    # assert False, "Evaluation dataset is empty." # Optionally stop execution

eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=1,  # Process one sequence at a time
    shuffle=False,
    num_workers=0,  # For simplicity in notebook, can be > 0 if data loading is slow
    # Pinned host memory only helps host-to-GPU copies, so request it only when
    # CUDA is actually available (the notebook supports a CPU fallback).
    pin_memory=torch.cuda.is_available()
)

print(f"Evaluation dataset loaded: {len(eval_dataset)} sequences (or samples).")
# Each sample from this loader will be a sequence of frames

In [None]:
# Cell 5: Inference Loop and Mask Generation
# Iterates over every sequence from eval_dataloader, runs the inference engine
# frame by frame, and collects predicted / ground-truth masks for metric
# computation plus a few complete sequences for video generation.

# One stacked array per sequence, covering the frames after the reference frame.
all_gt_masks_for_metrics = []
all_pred_masks_for_metrics = []

# Up to MAX_EXAMPLE_SEQUENCES_FOR_VIDEO dicts of the form
# {'seq_name': name, 'frames': [...], 'pred_masks': [...], 'gt_masks': [...]}
example_sequence_data = []
MAX_EXAMPLE_SEQUENCES_FOR_VIDEO = 3

print("Starting inference...")
# Pure inference: disable gradient tracking so no autograd state is built or kept
# alive across frames.
with torch.no_grad():
    for batch_idx, sample_sequence_dict in enumerate(tqdm(eval_dataloader)):

        # --- Move the current sequence (batch_size is 1) to the target device ---
        ref_img_tensor = sample_sequence_dict['ref_img'].to(device)
        ref_label_tensor = sample_sequence_dict['ref_label'].to(device).float()

        # 'prev_*' entries are treated as optional.
        prev_img_tensor = sample_sequence_dict['prev_img'].to(device) if 'prev_img' in sample_sequence_dict else None
        prev_label_tensor = sample_sequence_dict['prev_label'].to(device).float() if 'prev_label' in sample_sequence_dict else None

        curr_img_tensors = [img.to(device) for img in sample_sequence_dict['curr_img']]
        curr_label_tensors = [label.to(device).float() for label in sample_sequence_dict['curr_label']]

        seq_meta = sample_sequence_dict['meta']
        # Collation may wrap the name in a list/tuple of length batch_size (== 1).
        seq_name = seq_meta['seq_name'][0] if isinstance(seq_meta['seq_name'], (list, tuple)) else seq_meta['seq_name']

        # The engine is given obj_nums as a list of plain integers.
        obj_nums_list = [seq_meta['obj_num'].item()] if torch.is_tensor(seq_meta['obj_num']) else [seq_meta['obj_num']]

        # --- Assemble the ordered frame / label lists: ref, (prev), curr... ---
        # The previous frame is included only when both its image and label exist,
        # so the propagation loop below never receives None.
        frames_for_engine = [ref_img_tensor]
        gt_labels_for_engine = [ref_label_tensor]
        if prev_img_tensor is not None and prev_label_tensor is not None:
            frames_for_engine.append(prev_img_tensor)
            gt_labels_for_engine.append(prev_label_tensor)
        frames_for_engine.extend(curr_img_tensors)
        gt_labels_for_engine.extend(curr_label_tensors)

        engine.restart_engine()

        # --- Add Reference Frame ---
        # The label was cast to float above; it is converted to int for the engine.
        engine.add_reference_frame(
            frames_for_engine[0],
            gt_labels_for_engine[0].int(),
            obj_nums=obj_nums_list
        )

        # Masks for this sequence, excluding the reference frame.
        current_seq_pred_masks_np = []
        current_seq_gt_masks_np = []

        # Keep the full sequence for video generation for the first few sequences only.
        save_for_video = (len(example_sequence_data) < MAX_EXAMPLE_SEQUENCES_FOR_VIDEO)
        if save_for_video:
            video_data_item = {'seq_name': seq_name, 'frames': [], 'pred_masks': [], 'gt_masks': []}
            # Frames are stored exactly as the dataset produced them, minus the batch
            # dimension; no layout conversion is done here.
            # NOTE(review): presumably channel-first (C, H, W) - the video writer will
            # need to transpose to HWC for cv2; confirm.
            video_data_item['frames'].append(ref_img_tensor.cpu().squeeze(0).numpy())
            # The ground truth doubles as the "prediction" for the reference frame.
            video_data_item['pred_masks'].append(ref_label_tensor.cpu().squeeze(0).numpy())
            video_data_item['gt_masks'].append(ref_label_tensor.cpu().squeeze(0).numpy())

        # --- Propagate through the frames after the reference frame ---
        for frame_idx_in_seq in range(1, len(frames_for_engine)):
            current_frame_tensor = frames_for_engine[frame_idx_in_seq]
            current_gt_label_tensor = gt_labels_for_engine[frame_idx_in_seq]

            # Predict the mask for this frame at the input frame's spatial size.
            engine.match_propogate_one_frame(current_frame_tensor)
            pred_mask_tensor = engine.predict_current_mask(output_size=current_frame_tensor.shape[-2:])

            # Drop the batch dimension and move to CPU / uint8 for storage.
            pred_mask_np = pred_mask_tensor.squeeze(0).cpu().numpy().astype(np.uint8)
            current_seq_pred_masks_np.append(pred_mask_np)

            gt_mask_np = current_gt_label_tensor.squeeze(0).cpu().numpy().astype(np.uint8)
            current_seq_gt_masks_np.append(gt_mask_np)

            # Feed the prediction back into the engine's memory.
            engine.update_memory(pred_mask_tensor.float())

            if save_for_video:
                video_data_item['frames'].append(current_frame_tensor.cpu().squeeze(0).numpy())
                video_data_item['pred_masks'].append(pred_mask_np)
                video_data_item['gt_masks'].append(gt_mask_np)

        if current_seq_pred_masks_np:  # Only when there were non-reference frames
            # Stored as the raw per-frame label masks stacked along a new leading
            # frame axis; any per-object split is left to the metric computation.
            all_pred_masks_for_metrics.append(np.stack(current_seq_pred_masks_np, axis=0))
            all_gt_masks_for_metrics.append(np.stack(current_seq_gt_masks_np, axis=0))

        if save_for_video:
            example_sequence_data.append(video_data_item)

        # Optional: release cached CUDA memory between sequences.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print("Inference completed.")
if not all_pred_masks_for_metrics:
    print("WARNING: No predictions were generated. Check dataset and inference loop.")
