In [3]:
import os
import pickle
import random
import numpy as np
import torch
from PIL import Image
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import pycocotools.mask as mask_utils
import matplotlib.pyplot as plt
import cv2
from transformers import Sam3Processor, Sam3Model
from collections import defaultdict
from tqdm import tqdm
import json

# Initialize the SAM3 model/processor unless earlier notebook cells already did.
if 'model' in globals() and 'processor' in globals():
    print("Using existing model and processor.")
    # Follow the reused model's actual placement. A missing or stale `device`
    # global could otherwise differ from where the weights live, and the
    # `.to(device)` on processor outputs would then cause a device mismatch.
    device = next(model.parameters()).device
else:
    print("Model or processor not found. Initializing...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")

# --- Metrics Classes from no-time-to-train ---

def mask_to_rle(binary_mask):
    """Encode a 2-D mask as a COCO RLE whose ``counts`` is a ``str`` (JSON-safe).

    Args:
        binary_mask: 2-D array; any nonzero pixel is treated as foreground.

    Returns:
        The RLE dict produced by ``pycocotools.mask.encode``, with ``counts``
        decoded from bytes to ``str``.
    """
    # pycocotools' encode only accepts Fortran-ordered uint8 arrays (bool is
    # rejected). It run-length encodes raw byte values, so mixed nonzero values
    # would corrupt the runs. Normalising to 0/1 uint8 makes any mask safe.
    fortran_mask = np.asfortranarray(np.asarray(binary_mask) != 0, dtype=np.uint8)
    rle = mask_utils.encode(fortran_mask)
    rle['counts'] = rle['counts'].decode('utf-8')
    return rle

class COCOInstToSegmEvaluator:
    """Score instance predictions as semantic segmentation (per-class and mean IoU).

    Predictions scoring at least ``confidence_threshold`` and non-crowd GT
    annotations are rasterised into one label map per image. Label 0 is
    background and labels 1..N follow the sorted category ids. Intersection
    and union are accumulated per category over all evaluated images.
    """

    def __init__(self, coco_gt, results, cat_ids=None, img_ids=None, confidence_threshold=0.5):
        """Pre-compute predicted and ground-truth label maps.

        Args:
            coco_gt: ground-truth ``pycocotools.coco.COCO`` object.
            results: list of COCO result dicts with ``image_id``,
                ``category_id``, ``segmentation`` and ``score``.
            cat_ids: categories to evaluate (default: all GT categories).
            img_ids: images to evaluate (default: all GT images). Ids missing
                from ``coco_gt`` are ignored.
            confidence_threshold: predictions scoring below this are dropped.
        """
        self.confidence_threshold = confidence_threshold
        self.coco_gt = coco_gt
        self.results = results
        self.cat_ids = sorted(coco_gt.getCatIds() if cat_ids is None else cat_ids)

        # Label 0 is reserved for background, so categories map to 1..N.
        self.cat_id_to_idx = {cat_id: idx + 1 for idx, cat_id in enumerate(self.cat_ids)}

        target_img_ids = coco_gt.imgs.keys() if img_ids is None else img_ids
        self.image_sizes = {
            img_id: (coco_gt.imgs[img_id]['height'], coco_gt.imgs[img_id]['width'])
            for img_id in target_img_ids
            if img_id in coco_gt.imgs
        }

        # Group predictions by image once instead of scanning the list per image.
        self.results_by_img = defaultdict(list)
        for r in self.results:
            self.results_by_img[r['image_id']].append(r)

        print("Converting instance predictions to semantic masks...")
        self.pred_semantic_masks = {}
        for img_id, (height, width) in tqdm(self.image_sizes.items()):
            self.pred_semantic_masks[img_id] = self._convert_pred_to_semantic(img_id, height, width)

        print("Converting ground truth instances to semantic masks...")
        self.gt_semantic_masks = {}
        for img_id, (height, width) in tqdm(self.image_sizes.items()):
            self.gt_semantic_masks[img_id] = self._convert_gt_to_semantic(img_id, height, width)

    def _label_dtype(self):
        """Return the smallest integer dtype used here that fits labels 0..N."""
        return np.uint8 if len(self.cat_id_to_idx) <= 255 else np.int32

    @staticmethod
    def _segm_to_mask(segm, height, width):
        """Decode a COCO segmentation to an (height, width) binary mask.

        Accepts polygon lists, uncompressed RLE (``counts`` is a list) and
        compressed RLE (``counts`` is str/bytes).
        """
        if isinstance(segm, list):
            if not segm:
                # No polygons: nothing to rasterise.
                return np.zeros((height, width), dtype=np.uint8)
            # Rasterise each polygon, then merge the parts into a single RLE.
            rle = mask_utils.merge(mask_utils.frPyObjects(segm, height, width))
        elif isinstance(segm, dict) and isinstance(segm.get('counts'), list):
            # Uncompressed RLE must be converted before decode() accepts it.
            rle = mask_utils.frPyObjects(segm, height, width)
        else:
            rle = segm
        return mask_utils.decode(rle)

    def _convert_pred_to_semantic(self, img_id, height, width):
        """Rasterise one image's confident predictions into a label map."""
        semantic_mask = np.zeros((height, width), dtype=self._label_dtype())

        img_preds = sorted(
            (
                p for p in self.results_by_img.get(img_id, [])
                if p['score'] >= self.confidence_threshold
                and p['category_id'] in self.cat_id_to_idx
            ),
            key=lambda p: p['score'],
            reverse=True,
        )

        for pred in img_preds:
            binary_mask = self._segm_to_mask(pred['segmentation'], height, width)
            # Visit predictions highest-score first and only claim pixels that
            # are still unlabelled, so overlaps resolve to the most confident one.
            claim = (binary_mask > 0) & (semantic_mask == 0)
            semantic_mask[claim] = self.cat_id_to_idx[pred['category_id']]

        return semantic_mask

    def _convert_gt_to_semantic(self, img_id, height, width):
        """Rasterise one image's non-crowd GT annotations into a label map.

        Overlapping annotations of different categories: the later one wins.
        """
        semantic_mask = np.zeros((height, width), dtype=self._label_dtype())

        ann_ids = self.coco_gt.getAnnIds(imgIds=img_id, iscrowd=0)
        for ann in self.coco_gt.loadAnns(ann_ids):
            if ann['category_id'] not in self.cat_id_to_idx:
                continue
            binary_mask = self._segm_to_mask(ann['segmentation'], height, width)
            semantic_mask[binary_mask > 0] = self.cat_id_to_idx[ann['category_id']]

        return semantic_mask

    def evaluate(self):
        """Compute dataset-level IoU per category and their mean.

        Categories with zero union (absent from both predictions and GT) are
        excluded from the mean.

        Returns:
            ``(miou, ious)``: the mean IoU (0.0 if no category qualifies) and
            a dict mapping category id to IoU.
        """
        total_inter = defaultdict(int)
        total_union = defaultdict(int)

        print("Computing IoU metrics...")
        for img_id in self.image_sizes:
            pred_mask = self.pred_semantic_masks[img_id]
            gt_mask = self.gt_semantic_masks[img_id]
            for cat_id, class_idx in self.cat_id_to_idx.items():
                pred_binary = pred_mask == class_idx
                gt_binary = gt_mask == class_idx
                total_inter[cat_id] += int(np.count_nonzero(pred_binary & gt_binary))
                total_union[cat_id] += int(np.count_nonzero(pred_binary | gt_binary))

        ids_to_names = {cat['id']: cat['name'] for cat in self.coco_gt.loadCats(self.cat_ids)}

        ious = {}
        print("Per-class IoU:")
        for cat_id in self.cat_ids:
            if total_union[cat_id] == 0:
                continue
            iou = total_inter[cat_id] / total_union[cat_id]
            ious[cat_id] = iou
            print(f"  {ids_to_names.get(cat_id, cat_id)}: {iou:.4f}")

        miou = sum(ious.values()) / len(ious) if ious else 0.0
        print(f"Mean IoU: {miou:.4f}")
        return miou, ious

def evaluate_coco_metrics(coco_gt, results_list, img_ids=None):
    """Report COCO instance-segmentation mAP and semantic-segmentation mIoU.

    Args:
        coco_gt: ground-truth ``COCO`` object.
        results_list: COCO result dicts (``image_id``, ``category_id``,
            RLE ``segmentation``, ``score``). ``COCO.loadRes`` annotates
            these dicts in place (e.g. ``area``/``bbox``/``id``).
        img_ids: image ids to restrict both evaluations to. ``None`` means
            all GT images.

    Returns:
        ``None`` if ``results_list`` is empty. Otherwise a dict with:
        ``"coco_stats"`` (``COCOeval.stats`` after ``summarize()``),
        ``"miou"`` and per-category ``"ious"``.
    """
    if not results_list:
        # COCO.loadRes raises on an empty result list.
        print("No predictions to evaluate.")
        return None

    # 1. Standard COCO instance segmentation evaluation.
    print("\n--- COCO Instance Segmentation Evaluation (mAP) ---")
    # loadRes wraps the in-memory result dicts in a COCO object that shares
    # the GT image/category tables, which is the form COCOeval expects.
    coco_dt = coco_gt.loadRes(results_list)
    coco_eval = COCOeval(coco_gt, coco_dt, 'segm')
    if img_ids is not None:
        coco_eval.params.imgIds = img_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # 2. Semantic segmentation evaluation (mIoU).
    print("\n--- Semantic Segmentation Evaluation (mIoU) ---")
    inst_seg_evaluator = COCOInstToSegmEvaluator(coco_gt, results_list, img_ids=img_ids)
    miou, ious = inst_seg_evaluator.evaluate()

    return {"coco_stats": coco_eval.stats, "miou": miou, "ious": ious}

# --- End Metrics Classes ---

# --- Dataset / experiment configuration ---
DATA_DIR = "./data/olive_diseases"
TRAIN_IMG_DIR, VAL_IMG_DIR = (
    os.path.join(DATA_DIR, split) for split in ("train2017", "val2017")
)
TRAIN_ANN_FILE, VAL_ANN_FILE = (
    os.path.join(DATA_DIR, f"annotations/instances_{split}.json")
    for split in ("train2017", "val2017")
)

# Support-set version name -> directory holding its pickled few-shot episodes.
SUPPORT_SETS_DIRS = dict(v1="support_sets_olive", v2="support_sets_olive_v2")
# Shot counts to sweep. Every query image of each category is evaluated.
K_VALUES = [1, 2, 3, 5, 10]

# Annotation indices: supports come from train, queries and GT from val.
print("Loading COCO annotations...")
coco_train, coco_val = COCO(TRAIN_ANN_FILE), COCO(VAL_ANN_FILE)
cats = coco_val.loadCats(coco_val.getCatIds())
cat_id_to_name = dict((cat['id'], cat['name']) for cat in cats)

def load_support_set(pkl_path):
    """Unpickle and return the support-set object stored at ``pkl_path``.

    NOTE: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(pkl_path, mode='rb') as handle:
        support = pickle.load(handle)
    return support

def get_image(coco, img_dir, img_id):
    """Load image ``img_id`` of ``coco`` from ``img_dir`` as an RGB PIL image.

    Args:
        coco: ``COCO`` index whose image record provides ``file_name``.
        img_dir: directory containing the image files.
        img_id: COCO image id.

    Returns:
        A fully loaded RGB ``PIL.Image.Image`` that does not depend on the
        file staying open.
    """
    img_info = coco.loadImgs([img_id])[0]
    img_path = os.path.join(img_dir, img_info['file_name'])
    # Open in a context manager so the file handle is released right away.
    # convert() returns a new, fully decoded image that outlives the handle.
    with Image.open(img_path) as img:
        return img.convert("RGB")

def get_ann_bboxes(coco, ann_ids):
    """Return the boxes of ``ann_ids`` as ``[x1, y1, x2, y2]``.

    COCO stores boxes as ``[x, y, width, height]``; corners are derived by
    adding the size to the top-left point.
    """
    return [
        [left, top, left + width, top + height]
        for left, top, width, height in (ann['bbox'] for ann in coco.loadAnns(ann_ids))
    ]

def concat_images_and_boxes(support_data_list, query_image, target_h=512):
    """Tile the support images and the query image side by side at one height.

    Args:
        support_data_list: iterable of ``(PIL image, bboxes)`` pairs. Each bbox
            is ``[x1, y1, x2, y2]`` in that support image's pixel coordinates.
        query_image: PIL image placed to the right of all support tiles.
        target_h: height every tile is resized to; widths keep aspect ratio
            (truncated to int).

    Returns:
        ``(canvas, boxes, query_region)``:
        - ``canvas``: the concatenated RGB image.
        - ``boxes``: every support bbox mapped into canvas coordinates.
        - ``query_region``: ``(x1, y1, x2, y2)``, the query tile's extent on
          the canvas.
    """
    def fit_height(img):
        # Scale factor plus a LANCZOS-resized copy whose height is target_h.
        src_w, src_h = img.size
        factor = target_h / src_h
        resized = img.resize((int(src_w * factor), target_h), Image.Resampling.LANCZOS)
        return factor, resized

    tiles = []
    canvas_boxes = []
    offset = 0
    for support_img, support_boxes in support_data_list:
        factor, tile = fit_height(support_img)
        # Scale both axes, then shift x by the tiles already placed to the left.
        canvas_boxes.extend(
            [x1 * factor + offset, y1 * factor, x2 * factor + offset, y2 * factor]
            for x1, y1, x2, y2 in support_boxes
        )
        tiles.append(tile)
        offset += tile.size[0]

    _, query_tile = fit_height(query_image)
    query_w = query_tile.size[0]

    canvas = Image.new("RGB", (offset + query_w, target_h))
    cursor = 0
    for tile in tiles:
        canvas.paste(tile, (cursor, 0))
        cursor += tile.size[0]
    canvas.paste(query_tile, (cursor, 0))

    return canvas, canvas_boxes, (cursor, 0, cursor + query_w, target_h)

def _predict_query_annotations(support_items, q_img, q_img_id, cat_id):
    """Run SAM3 on one support+query mosaic; return COCO results for the query.

    The support boxes are passed as positive box prompts. Predicted masks are
    cropped to the query tile and resized back to the query image's size.
    """
    concat_img, input_bboxes, query_bbox = concat_images_and_boxes(support_items, q_img)
    q_x1, q_y1, q_x2, q_y2 = query_bbox

    inputs = processor(
        images=concat_img,
        input_boxes=[input_bboxes],
        input_boxes_labels=[[1] * len(input_bboxes)],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_instance_segmentation(
        outputs, threshold=0.4, mask_threshold=0.5,
        target_sizes=inputs.get("original_sizes").tolist(),
    )[0]

    masks = results['masks'].cpu().numpy()
    scores = results['scores'].cpu().numpy()
    orig_w, orig_h = q_img.size  # PIL reports (width, height)

    annotations = []
    for mask, score in zip(masks, scores):
        # Keep only the part of the mask that falls on the query tile.
        mask_crop = mask[q_y1:q_y2, q_x1:q_x2]
        if not mask_crop.any():
            continue
        # cv2.resize takes dsize as (width, height); nearest keeps it binary.
        mask_orig = cv2.resize(
            mask_crop.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST
        )
        annotations.append({
            "image_id": q_img_id,
            "category_id": cat_id,
            "segmentation": mask_to_rle(mask_orig),
            "score": float(score),
        })
    return annotations


def evaluate_support_set_version(version_name, support_dir, k_shot):
    """Run few-shot SAM3 inference for one support-set version and report metrics.

    For every category present in ``olive_diseases_{k_shot}shot.pkl``, all val
    images containing that category are segmented, prompted by the category's
    support boxes. COCO mAP and mIoU are then computed over those images.

    Returns:
        The list of COCO-style prediction dicts, or ``None`` if the pickle is
        missing. pycocotools' ``loadRes`` (used for mAP) may add fields such
        as ``area``/``bbox``/``id`` to these dicts in place.
    """
    print(f"\n--- Evaluating Support Set {version_name} ({k_shot}-shot) ---")
    pkl_path = os.path.join(support_dir, f"olive_diseases_{k_shot}shot.pkl")
    if not os.path.exists(pkl_path):
        print(f"File not found: {pkl_path}")
        return None

    support_data = load_support_set(pkl_path)

    all_results = []
    evaluated_img_ids = set()

    for cat_id, cat_name in cat_id_to_name.items():
        if cat_id not in support_data:
            continue

        # Each support entry is assumed to carry 'img_id' and 'ann_ids'.
        support_items = [
            (get_image(coco_train, TRAIN_IMG_DIR, item['img_id']),
             get_ann_bboxes(coco_train, item['ann_ids']))
            for item in support_data[cat_id]
        ]

        # Query on every val image containing this category.
        query_img_ids = coco_val.getImgIds(catIds=[cat_id])
        print(f"  {cat_name}: {len(query_img_ids)} query images")

        # Count these images as evaluated even without prompts, so an
        # unpromptable category scores as misses rather than being dropped.
        evaluated_img_ids.update(query_img_ids)

        if not any(bboxes for _, bboxes in support_items):
            print(f"  {cat_name}: support set has no boxes; skipping inference")
            continue

        for q_img_id in tqdm(query_img_ids, desc=f"Evaluating {cat_name}"):
            q_img = get_image(coco_val, VAL_IMG_DIR, q_img_id)
            all_results.extend(
                _predict_query_annotations(support_items, q_img, q_img_id, cat_id)
            )

    if all_results:
        print(f"Running metrics for {k_shot}-shot on {len(evaluated_img_ids)} images...")
        evaluate_coco_metrics(coco_val, all_results, sorted(evaluated_img_ids))
    else:
        print("No predictions generated.")

    return all_results

# Run Evaluations
# Sweep shot counts (outer loop) across every support-set version (inner loop).
# Each call looks for `olive_diseases_{k}shot.pkl` in that version's directory
# and prints its own metrics (or a "File not found" message).
for k in K_VALUES:
    for v_name, v_dir in SUPPORT_SETS_DIRS.items():
        evaluate_support_set_version(v_name, v_dir, k)


Using existing model and processor.
Loading COCO annotations...
loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!

--- Evaluating Support Set v1 (1-shot) ---
  Anthracnose: 38 query images


Evaluating Anthracnose:   0%|          | 0/38 [00:02<?, ?it/s]


KeyboardInterrupt: 