In [None]:
import cv2

# Print the (width, height) of the sample image. cv2.imread returns None instead of
# raising when the file is missing or unreadable, so check explicitly for a clear error.
_sample_image = cv2.imread("1.jpg")
if _sample_image is None:
    raise FileNotFoundError("Could not read image: 1.jpg")
print(_sample_image.shape[:2][::-1])


In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
image_name = "field0600"
methods = ["EigenGradCAM", "GradCAM", "EigenCAM", "XGradCAM", "RandomCAM", "LayerCAM", "KPCA_CAM", "HiResCAM", "GradCAMPlusPlus", "GradCAMElementWise"]
layers = ["12", "15", "17", "21", "15_17", "15_17_21"]

for method in methods:
    print(f"{method}:")
    for layer in layers:
        csv_path = f'/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/YOLOv8_Explainer/heatmaps/All_Layers/{method}/{layer}/{image_name}.csv'
        if os.path.exists(csv_path):
            # Load and process saliency map
            saliency_map = np.loadtxt(csv_path, delimiter=',')
            saliency_mask = (saliency_map >= 0.3).astype(int)
            
            # Create GT mask
            gt_mask = np.zeros_like(saliency_mask)
            with open(f'/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/{image_name}.txt', 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        center_x, center_y, width, height = [float(x) for x in parts[1:5]]
                        h, w = saliency_map.shape
                        x1, y1 = int((center_x - width/2) * w), int((center_y - height/2) * h)
                        x2, y2 = int((center_x + width/2) * w), int((center_y + height/2) * h)
                        gt_mask[max(0,y1):min(h,y2), max(0,x1):min(w,x2)] = 1
            
            # Calculate IoU
            intersection = np.sum(saliency_mask * gt_mask)
            union = np.sum((saliency_mask + gt_mask) > 0)
            iou = intersection / union if union > 0 else 0
            print(f"  Layer {layer}: {iou:.4f}")
        else:
            print(f"  Layer {layer}: Not found")
    print()

In [None]:
import numpy as np
import os

methods = ["EigenGradCAM", "GradCAM", "EigenCAM", "XGradCAM", "RandomCAM", "LayerCAM", "KPCA_CAM", "HiResCAM", "GradCAMPlusPlus", "GradCAMElementWise"]
layers = ["12", "15", "17", "21", "15_17", "15_17_21"]
labels_folder = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/'
heatmaps_base = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/YOLOv8_Explainer/heatmaps/All_Layers'

# For every method/layer, average over all labelled images the IoU between the
# thresholded heatmap (>= 0.3) and the union of the ground-truth YOLO boxes.
for method in methods:
    print(f"{method}:")
    for layer in layers:
        heatmap_dir = f'{heatmaps_base}/{method}/{layer}'
        if not os.path.exists(heatmap_dir):
            print(f"  Layer {layer}: Directory not found")
            continue

        ious = []
        for entry in os.listdir(heatmap_dir):
            if not entry.endswith('.csv'):
                continue
            stem = os.path.splitext(entry)[0]
            label_file = os.path.join(labels_folder, f"{stem}.txt")
            if not os.path.exists(label_file):
                continue  # heatmaps without a label file are not scored

            heat = np.loadtxt(os.path.join(heatmap_dir, entry), delimiter=',')
            pred = (heat >= 0.3).astype(int)

            # Paint every normalized (cx, cy, w, h) box onto a mask the size of the heatmap.
            truth = np.zeros_like(pred)
            with open(label_file, 'r') as fh:
                for row in fh:
                    tokens = row.strip().split()
                    if len(tokens) < 5:
                        continue
                    cx, cy, bw, bh = map(float, tokens[1:5])
                    h, w = heat.shape
                    left, top = int((cx - bw/2) * w), int((cy - bh/2) * h)
                    right, bottom = int((cx + bw/2) * w), int((cy + bh/2) * h)
                    truth[max(0, top):min(h, bottom), max(0, left):min(w, right)] = 1

            overlap = np.sum(pred * truth)
            combined = np.sum((pred + truth) > 0)
            ious.append(overlap / combined if combined > 0 else 0)

        mean_iou = np.mean(ious) if ious else 0
        print(f"  Layer {layer}: {mean_iou:.4f} (images: {len(ious)})")
    print()

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def energy_based_pointing_game(saliency_map, bbox):
    """
    Energy-Based Pointing Game (EBPG) for one or more bounding boxes.

    A box's score is the saliency energy inside the box divided by the total
    saliency energy of the map.

    Parameters:
        - saliency_map: 2D array (rows = y, columns = x). If its maximum exceeds 1.0 it
          is min-max normalised to [0, 1] first; a constant map then scores 0.
          NOTE(review): maps with negative values and max <= 1 are used as-is, so
          scores can fall outside [0, 1].
        - bbox: a single box (1-D, e.g. torch.Size([7])) or several (2-D, e.g.
          torch.Size([N, 7])). Only the first four entries, [x_min, y_min, x_max, y_max]
          in pixel coordinates, are used; negative coordinates are clamped to 0.
    Returns:
        A single score when exactly one box is given, otherwise a list of scores
        (an empty list when no boxes are given).
    """
    if len(bbox) == 0:
        return []
    if len(bbox.shape) == 1:  # single box -> shape (1, K)
        bbox = bbox.reshape(1, -1)

    # Normalise once, before scoring any box (the energy ratio is per-map, not per-box).
    if saliency_map.max() > 1.0:
        lowest = saliency_map.min()
        span = saliency_map.max() - lowest
        if span > 0:
            saliency_map = (saliency_map - lowest) / span
        else:
            # Constant map: min-max normalisation would divide 0 by 0; it carries no energy.
            saliency_map = np.zeros(saliency_map.shape)

    total_energy = np.sum(saliency_map)

    scores = []
    for box in bbox:  # every ground-truth / detected object in the image
        x_min, y_min, x_max, y_max = (max(int(v), 0) for v in box[:4])
        box_energy = np.sum(saliency_map[y_min:y_max, x_min:x_max])  # y = rows, x = columns
        scores.append(box_energy / total_energy if total_energy > 0 else 0)

    return scores if len(scores) > 1 else scores[0]

# Set your image name
image_name = "field0600"
# Set True to display each heatmap with its ground-truth boxes (debugging aid).
show_overlay = False

methods = ["EigenGradCAM", "GradCAM", "EigenCAM", "XGradCAM", "RandomCAM", "LayerCAM", "KPCA_CAM", "HiResCAM", "GradCAMPlusPlus", "GradCAMElementWise"]
layers = ["12", "15", "17", "21", "15_17", "15_17_21"]
labels_folder = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/'
heatmaps_base = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/YOLOv8_Explainer/heatmaps/All_Layers'

# Load labels once for the image. Raise rather than call exit(): inside a Jupyter
# kernel exit() shuts the kernel down instead of just stopping this cell.
label_path = os.path.join(labels_folder, f"{image_name}.txt")
if not os.path.exists(label_path):
    raise FileNotFoundError(f"Label file not found for {image_name}")

# Normalized YOLO boxes (center_x, center_y, width, height), parsed once and rescaled
# per heatmap below because heatmap sizes may differ between layers.
yolo_boxes = []
with open(label_path, 'r') as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) >= 5:
            yolo_boxes.append([float(x) for x in parts[1:5]])

print(f"Processing image: {image_name}")
print("-" * 50)

for method in methods:
    print(f"{method}:")
    for layer in layers:
        csv_path = f'{heatmaps_base}/{method}/{layer}/{image_name}.csv'
        if not os.path.exists(csv_path):
            print(f"  Layer {layer}: CSV not found")
            continue

        # ndmin=2 keeps a single-row CSV two-dimensional so (h, w) unpacking works.
        saliency_map = np.loadtxt(csv_path, delimiter=',', ndmin=2)
        h, w = saliency_map.shape

        # Pixel boxes [x_min, y_min, x_max, y_max] at this heatmap's resolution.
        bboxes = [
            [int((cx - bw / 2) * w), int((cy - bh / 2) * h),
             int((cx + bw / 2) * w), int((cy + bh / 2) * h)]
            for cx, cy, bw, bh in yolo_boxes
        ]

        if show_overlay:
            fig, ax = plt.subplots()
            ax.imshow(saliency_map, cmap='hot')
            for x_min, y_min, x_max, y_max in bboxes:
                ax.add_patch(patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                               linewidth=2, edgecolor='cyan', facecolor='none'))
            ax.set_title(f"{method} / layer {layer}")
            plt.show()
            plt.close(fig)

        if not bboxes:
            print(f"  Layer {layer}: No bounding boxes found")
            continue

        # The EBPG helper returns a scalar for one box and a list otherwise; treat both uniformly.
        box_scores = np.atleast_1d(energy_based_pointing_game(saliency_map, np.array(bboxes)))
        print(f"  Layer {layer}: {np.mean(box_scores):.4f} (boxes: {len(box_scores)})")
    print()

In [None]:
import numpy as np
import os

def energy_based_pointing_game(saliency_map, bbox):
    """
    Energy-Based Pointing Game (EBPG) for one or more bounding boxes.

    A box's score is the saliency energy inside the box divided by the total
    saliency energy of the map.

    Parameters:
        - saliency_map: 2D array (rows = y, columns = x). If its maximum exceeds 1.0 it
          is min-max normalised to [0, 1] first; a constant map then scores 0.
          NOTE(review): maps with negative values and max <= 1 are used as-is, so
          scores can fall outside [0, 1].
        - bbox: a single box (1-D, e.g. torch.Size([7])) or several (2-D, e.g.
          torch.Size([N, 7])). Only the first four entries, [x_min, y_min, x_max, y_max]
          in pixel coordinates, are used; negative coordinates are clamped to 0.
    Returns:
        A single score when exactly one box is given, otherwise a list of scores
        (an empty list when no boxes are given).
    """
    if len(bbox) == 0:
        return []
    if len(bbox.shape) == 1:  # single box -> shape (1, K)
        bbox = bbox.reshape(1, -1)

    # Normalise once, before scoring any box (the energy ratio is per-map, not per-box).
    if saliency_map.max() > 1.0:
        lowest = saliency_map.min()
        span = saliency_map.max() - lowest
        if span > 0:
            saliency_map = (saliency_map - lowest) / span
        else:
            # Constant map: min-max normalisation would divide 0 by 0; it carries no energy.
            saliency_map = np.zeros(saliency_map.shape)

    total_energy = np.sum(saliency_map)

    scores = []
    for box in bbox:  # every ground-truth / detected object in the image
        x_min, y_min, x_max, y_max = (max(int(v), 0) for v in box[:4])
        box_energy = np.sum(saliency_map[y_min:y_max, x_min:x_max])  # y = rows, x = columns
        scores.append(box_energy / total_energy if total_energy > 0 else 0)

    return scores if len(scores) > 1 else scores[0]

methods = ["EigenGradCAM", "GradCAM", "EigenCAM", "XGradCAM", "RandomCAM", "LayerCAM", "KPCA_CAM", "HiResCAM", "GradCAMPlusPlus", "GradCAMElementWise"]
layers = ["12", "15", "17", "21", "15_17", "15_17_21"]
labels_folder = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/'
heatmaps_base = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/YOLOv8_Explainer/heatmaps/All_Layers'

for method in methods:
    print(f"{method}:")
    for layer in layers:
        heatmap_dir = f'{heatmaps_base}/{method}/{layer}'
        
        if os.path.exists(heatmap_dir):
            per_image_scores = []  # Store average score per image
            
            for csv_file in os.listdir(heatmap_dir):
                if csv_file.endswith('.csv'):
                    image_name = os.path.splitext(csv_file)[0]
                    csv_path = os.path.join(heatmap_dir, csv_file)
                    label_path = os.path.join(labels_folder, f"{image_name}.txt")
                    
                    if os.path.exists(label_path):
                        # Load saliency map
                        saliency_map = np.loadtxt(csv_path, delimiter=',')
                        h, w = saliency_map.shape
                        
                        # Convert YOLO labels to bbox format [x_min, y_min, x_max, y_max]
                        bboxes = []
                        with open(label_path, 'r') as f:
                            for line in f:
                                parts = line.strip().split()
                                if len(parts) >= 5:
                                    center_x, center_y, width, height = [float(x) for x in parts[1:5]]
                                    x_min = int((center_x - width/2) * w)
                                    y_min = int((center_y - height/2) * h)
                                    x_max = int((center_x + width/2) * w)
                                    y_max = int((center_y + height/2) * h)
                                    bboxes.append([x_min, y_min, x_max, y_max])
                        
                        if bboxes:
                            bbox_array = np.array(bboxes)
                            scores = energy_based_pointing_game(saliency_map, bbox_array)
                            
                            # Calculate average score for THIS IMAGE
                            if isinstance(scores, list):
                                image_avg_score = np.sum(scores)
                            else:
                                image_avg_score = scores
                            
                            per_image_scores.append(image_avg_score)
            
            # Calculate average across all images
            overall_avg = np.mean(per_image_scores) if per_image_scores else 0
            print(f"  Layer {layer}: {overall_avg:.4f} (images: {len(per_image_scores)})")
        else:
            print(f"  Layer {layer}: Directory not found")
    print()

In [None]:
import numpy as np
import torch
import torchvision
import math
import cv2
from YOLOX.yolox.utils import postprocess
import YOLOX.yolox.data.data_augment as data_augment
from scipy import spatial
from tqdm import tqdm

# YOLOX validation-time preprocessing (legacy=False).
# NOTE(review): `transform` is not used anywhere in the visible code — del_ins feeds the
# raw image array to the model; confirm whether this preprocessing was meant to be applied.
transform = data_augment.ValTransform(legacy=False)
# Device that del_ins moves each model input to (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def auc(arr):
    """Normalized trapezoidal area under ``arr`` sampled at unit spacing.

    The two endpoints get half weight, and dividing by the number of intervals
    (``len - 1``) rescales the x-axis to [0, 1].
    """
    n_intervals = arr.shape[0] - 1
    total = arr.sum()
    return (total - arr[0] / 2 - arr[-1] / 2) / n_intervals


def del_ins(model, img, bbox, saliency_map, mode='del', step=2000, kernel_width=0.25):
    """Deletion/insertion AUC for a YOLOX detector, one curve per target box.

    For each map ``saliency_map[idx]`` (paired with target ``bbox[idx]``), pixels are copied
    from ``finish`` into ``start`` in order of decreasing saliency, ``step`` pixels at a time,
    and the detector is re-run before every step. The per-step score is ``iou * weights`` of
    the prediction with the highest IoU against ``bbox[idx]``, where
    ``weights = sqrt(exp(-d**2 / kernel_width**2))`` and ``d`` is the cosine distance between
    the prediction's ``b[5:-1]`` slice and ``bbox[idx][5:-1]``. Predictions whose last column
    (class) differs from the target's get IoU 0. The curve is reduced with ``auc``.

    Args:
        model: detector whose output is passed to YOLOX ``postprocess``.
        img: H x W x C numpy image (``start.transpose(2, 0, 1)`` requires channels last).
        bbox: per-target rows where ``[:4]`` is the box, ``[-1]`` the class and ``[5:-1]``
            the score vector. Presumably a torch tensor (rows use ``.unsqueeze``/``.cpu``)
            — TODO confirm the exact column layout.
        saliency_map: array of shape (num_targets, H, W), one map per target.
            NOTE(review): pixel coordinates are derived with ``img.shape[1]``, so the maps
            are assumed to match the image resolution — confirm.
        mode: 'del' (start = img, finish = zeros) or 'ins' (start = blurred img,
            finish = img). Any other value leaves ``start`` unbound and raises.
        step: number of pixels changed between successive model evaluations.
        kernel_width: width of the similarity kernel applied to the cosine distance.

    Returns:
        (del_ins, count): length-1 arrays holding the summed AUC and the number of curves,
        indexed by ``int(target_cls)`` — as written only class 0 fits.
    """
    # Accumulators indexed by class id (length 1: single-class detector).
    del_ins = np.zeros(1)
    count = np.zeros(1)
    HW = saliency_map.shape[1] * saliency_map.shape[2]
    # ceil(HW / step) modification steps; the loop below evaluates n_steps + 1 images.
    n_steps = (HW + step - 1) // step
    for idx in tqdm(range(saliency_map.shape[0]), desc="DEL/INS", leave=False):
        target_cls = bbox[idx][-1]
        if mode == 'del':
            start = img.copy()
            finish = np.zeros_like(start)
        elif mode == 'ins':
            start = cv2.GaussianBlur(img, (51, 51), 0)
            finish = img.copy()
        # Flat pixel indices sorted by descending saliency, shape (HW, 1).
        salient_order = np.flip(np.argsort(saliency_map[idx].reshape(HW, -1), axis=0), axis=0)
        y = salient_order // img.shape[1]
        x = salient_order - y * img.shape[1]
        scores = np.zeros(n_steps + 1)
        with torch.no_grad():
            for i in range(n_steps + 1):
                temp_ious = []
                temp_score = []
                # HWC -> 1xCxHxW float tensor; no normalisation/resizing is applied here.
                torch_start = torch.from_numpy(start.transpose(2, 0, 1)).unsqueeze(0).float()
                out = model(torch_start.to(device))
                # NOTE(review): unpacking two return values assumes this YOLOX fork's
                # postprocess returns (detections, extra) — confirm against the fork.
                p_box, _ = postprocess(out, num_classes=1, conf_thre=0.25, nms_thre=0.45, class_agnostic=True)
                p_box = p_box[0]
                if p_box is None:
                    # No detections at this step.
                    scores[i] = 0
                else:
                    for b in p_box:
                        sample_cls = b[-1]
                        sample_box = b[:4]
                        sample_score = b[5:-1]
                        iou = torchvision.ops.box_iou(sample_box[:4].unsqueeze(0),
                                                    bbox[idx][:4].unsqueeze(0)).cpu().item()
                        distances = spatial.distance.cosine(sample_score.cpu(), bbox[idx][5:-1].cpu())
                        weights = math.sqrt(math.exp(-(distances ** 2) / kernel_width ** 2))
                        if target_cls != sample_cls:
                            iou = 0
                            # NOTE(review): this value is never read afterwards.
                            sample_score = torch.tensor(0.)
                        temp_ious.append(iou)
                        s_score = iou * weights
                        temp_score.append(s_score)
                    # Score of the prediction that best overlaps the target box.
                    max_score = temp_score[np.argmax(temp_ious)]
                    scores[i] = max_score
                # Apply the next batch of `step` most-salient pixels (empty slice on the last pass).
                x_coords = x[step * i:step * (i + 1), :]
                y_coords = y[step * i:step * (i + 1), :]
                start[y_coords, x_coords, :] = finish[y_coords, x_coords, :]
        del_ins[int(target_cls)] += auc(scores)
        count[int(target_cls)] += 1
    return del_ins, count
    

In [None]:
import numpy as np
import cv2
import torch
from ultralytics import YOLO
from torchvision import transforms

def auc(arr):
    """Normalized trapezoidal area under ``arr`` sampled at unit spacing (x-axis mapped to [0, 1])."""
    n_intervals = arr.shape[0] - 1
    total = arr.sum()
    return (total - arr[0] / 2 - arr[-1] / 2) / n_intervals

def process_yolo_output(results, img_shape, conf_threshold=0.5):
    """
    Rasterize confident YOLOv8 detections into a boolean mask.

    Args:
        results: sequence returned by a YOLOv8 model call; only ``results[0].boxes``
            (its ``conf`` and ``xyxy`` tensors) is read
        img_shape: image shape whose first two entries are (height, width)
        conf_threshold: detections with confidence below this are ignored

    Returns:
        pred_mask: (height, width) bool array, True inside any kept box
    """
    pred_mask = np.zeros(img_shape[:2], dtype=bool)
    if len(results) == 0 or results[0].boxes is None:
        return pred_mask

    detections = results[0].boxes
    keep = detections.conf.cpu().numpy() >= conf_threshold
    corners = detections.xyxy.cpu().numpy()[keep].astype(int)

    def _clamp(value, upper):
        # Restrict a coordinate to [0, upper].
        return min(max(value, 0), upper)

    for left, top, right, bottom in corners:
        rows = slice(_clamp(top, img_shape[0]), _clamp(bottom, img_shape[0]))
        cols = slice(_clamp(left, img_shape[1]), _clamp(right, img_shape[1]))
        pred_mask[rows, cols] = True
    return pred_mask

def _yolo_labels_to_mask(label_path, h, w):
    """Rasterize every YOLO-format box (class cx cy w h, normalized) in ``label_path`` into a bool (h, w) mask."""
    mask = np.zeros((h, w), dtype=bool)
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            center_x, center_y, width, height = [float(x) for x in parts[1:5]]
            # Convert to pixel coordinates, clipped to the image.
            x1 = max(0, min(int((center_x - width/2) * w), w))
            y1 = max(0, min(int((center_y - height/2) * h), h))
            x2 = max(0, min(int((center_x + width/2) * w), w))
            y2 = max(0, min(int((center_y + height/2) * h), h))
            mask[y1:y2, x1:x2] = True
    return mask


def _mask_iou(pred_mask, gt_mask):
    """IoU of two boolean masks; 0 when both are empty."""
    union = np.sum(pred_mask | gt_mask)
    return np.sum(pred_mask & gt_mask) / union if union > 0 else 0


def del_ins_multiple_boxes(model_path, img_path, csv_path, label_path,
                           mode='del', step=200, conf_threshold=0.5, device='cuda'):
    """
    Compute the deletion or insertion curve and AUC for a YOLOv8 model on one image.

    Pixels are modified in order of decreasing saliency, ``step`` pixels at a time.
    Before the first and after every step the model is re-run, and the score is the IoU
    between the union of predicted boxes (conf >= ``conf_threshold``) and the union of
    ground-truth boxes.

    Args:
        model_path: path to YOLOv8 model (.pt file)
        img_path: path to input image
        csv_path: path to saliency map CSV (resized to the image size if needed)
        label_path: path to ground truth labels (YOLO format)
        mode: 'del' (start from the image, set pixels to 0) or
              'ins' (start from a Gaussian-blurred image, restore original pixels)
        step: number of pixels to modify at each step (must be > 0)
        conf_threshold: confidence threshold for model predictions
        device: 'cuda' or 'cpu'

    Returns:
        auc_score: Area under curve for the metric
        scores: IoU scores before any change and after each step

    Raises:
        ValueError: for an unknown ``mode`` or a non-positive ``step``
        FileNotFoundError: if the image cannot be read
    """
    if mode not in ('del', 'ins'):
        raise ValueError(f"mode must be 'del' or 'ins', got {mode!r}")
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    model = YOLO(model_path)
    model.to(device)

    # Keep the image in BGR: Ultralytics interprets numpy inputs as BGR (the cv2
    # convention), so converting to RGB would feed the detector swapped channels.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    h, w = img.shape[:2]

    # ndmin=2 keeps a single-row CSV two-dimensional.
    saliency_map = np.loadtxt(csv_path, delimiter=',', ndmin=2)
    if saliency_map.shape != (h, w):
        saliency_map = cv2.resize(saliency_map, (w, h))

    gt_mask = _yolo_labels_to_mask(label_path, h, w)

    if mode == 'del':
        current_img = img.copy()
        finish = np.zeros_like(img)  # deleted pixels become black
    else:
        current_img = cv2.GaussianBlur(img, (51, 51), 0)
        finish = img  # inserted pixels come from the original image (read-only here)

    # Flat pixel indices, most salient first, split into (row, column).
    salient_order = np.argsort(saliency_map.reshape(-1))[::-1]
    y_coords, x_coords = np.divmod(salient_order, w)
    n_steps = -(-salient_order.size // step)  # ceil(HW / step)

    def _score(image):
        # IoU between the model's confident detections and the ground truth.
        pred_mask = process_yolo_output(model(image, verbose=False), img.shape, conf_threshold)
        return _mask_iou(pred_mask, gt_mask)

    scores = [_score(current_img)]
    print(f"Baseline IoU: {scores[0]:.4f}")
    print(f"Processing {n_steps} steps...")

    for i in range(n_steps):
        batch = slice(step * i, step * (i + 1))
        batch_y, batch_x = y_coords[batch], x_coords[batch]
        current_img[batch_y, batch_x, :] = finish[batch_y, batch_x, :]

        iou_score = _score(current_img)
        scores.append(iou_score)
        if (i + 1) % 10 == 0:
            print(f"Step {i+1}/{n_steps}, IoU: {iou_score:.4f}")

    scores = np.array(scores)
    return auc(scores), scores

# Example usage
if __name__ == "__main__":
    model_path = 'best_chagas.pt'
    img_path = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/images/field0600.jpg'
    csv_path = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/YOLOv8_Explainer/heatmaps/All_Layers/EigenGradCAM/12/field0600.csv'
    label_path = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/field0600.txt'
    # Fall back to CPU so the example also runs on machines without a GPU.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Compute deletion metric
    del_auc, del_scores = del_ins_multiple_boxes(
        model_path, img_path, csv_path, label_path,
        mode='del', step=1000, conf_threshold=0.5, device=run_device
    )

    print(f"\nDeletion AUC: {del_auc:.4f}")

    # Compute insertion metric
    ins_auc, ins_scores = del_ins_multiple_boxes(
        model_path, img_path, csv_path, label_path,
        mode='ins', step=1000, conf_threshold=0.5, device=run_device
    )

    print(f"Insertion AUC: {ins_auc:.4f}")

In [None]:
import numpy as np
import cv2
import torch
from ultralytics import YOLO
from torchvision import transforms
import os
import glob

def auc(arr):
    """Normalized trapezoidal area under ``arr`` sampled at unit spacing (x-axis mapped to [0, 1])."""
    n_intervals = arr.shape[0] - 1
    total = arr.sum()
    return (total - arr[0] / 2 - arr[-1] / 2) / n_intervals

def process_yolo_output(results, img_shape, conf_threshold=0.5):
    """
    Rasterize confident YOLOv8 detections into a boolean mask.

    Args:
        results: sequence returned by a YOLOv8 model call; only ``results[0].boxes``
            (its ``conf`` and ``xyxy`` tensors) is read
        img_shape: image shape whose first two entries are (height, width)
        conf_threshold: detections with confidence below this are ignored

    Returns:
        pred_mask: (height, width) bool array, True inside any kept box
    """
    pred_mask = np.zeros(img_shape[:2], dtype=bool)
    if len(results) == 0 or results[0].boxes is None:
        return pred_mask

    detections = results[0].boxes
    keep = detections.conf.cpu().numpy() >= conf_threshold
    corners = detections.xyxy.cpu().numpy()[keep].astype(int)

    def _clamp(value, upper):
        # Restrict a coordinate to [0, upper].
        return min(max(value, 0), upper)

    for left, top, right, bottom in corners:
        rows = slice(_clamp(top, img_shape[0]), _clamp(bottom, img_shape[0]))
        cols = slice(_clamp(left, img_shape[1]), _clamp(right, img_shape[1]))
        pred_mask[rows, cols] = True
    return pred_mask

def _yolo_labels_to_mask(label_path, h, w):
    """Rasterize every YOLO-format box (class cx cy w h, normalized) in ``label_path`` into a bool (h, w) mask."""
    mask = np.zeros((h, w), dtype=bool)
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            center_x, center_y, width, height = [float(x) for x in parts[1:5]]
            # Convert to pixel coordinates, clipped to the image.
            x1 = max(0, min(int((center_x - width/2) * w), w))
            y1 = max(0, min(int((center_y - height/2) * h), h))
            x2 = max(0, min(int((center_x + width/2) * w), w))
            y2 = max(0, min(int((center_y + height/2) * h), h))
            mask[y1:y2, x1:x2] = True
    return mask


def _mask_iou(pred_mask, gt_mask):
    """IoU of two boolean masks; 0 when both are empty."""
    union = np.sum(pred_mask | gt_mask)
    return np.sum(pred_mask & gt_mask) / union if union > 0 else 0


def del_ins_multiple_boxes(model_path, img_path, csv_path, label_path,
                           mode='del', step=200, conf_threshold=0.5, device='cuda'):
    """
    Compute the deletion or insertion curve and AUC for a YOLOv8 model on one image.

    Pixels are modified in order of decreasing saliency, ``step`` pixels at a time.
    Before the first and after every step the model is re-run, and the score is the IoU
    between the union of predicted boxes (conf >= ``conf_threshold``) and the union of
    ground-truth boxes.

    Args:
        model_path: path to YOLOv8 model (.pt file)
        img_path: path to input image
        csv_path: path to saliency map CSV (resized to the image size if needed)
        label_path: path to ground truth labels (YOLO format)
        mode: 'del' (start from the image, set pixels to 0) or
              'ins' (start from a Gaussian-blurred image, restore original pixels)
        step: number of pixels to modify at each step (must be > 0)
        conf_threshold: confidence threshold for model predictions
        device: 'cuda' or 'cpu'

    Returns:
        auc_score: Area under curve for the metric
        scores: IoU scores before any change and after each step

    Raises:
        ValueError: for an unknown ``mode`` or a non-positive ``step``
        FileNotFoundError: if the image cannot be read
    """
    if mode not in ('del', 'ins'):
        raise ValueError(f"mode must be 'del' or 'ins', got {mode!r}")
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    model = YOLO(model_path)
    model.to(device)

    # Keep the image in BGR: Ultralytics interprets numpy inputs as BGR (the cv2
    # convention), so converting to RGB would feed the detector swapped channels.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    h, w = img.shape[:2]

    # ndmin=2 keeps a single-row CSV two-dimensional.
    saliency_map = np.loadtxt(csv_path, delimiter=',', ndmin=2)
    if saliency_map.shape != (h, w):
        saliency_map = cv2.resize(saliency_map, (w, h))

    gt_mask = _yolo_labels_to_mask(label_path, h, w)

    if mode == 'del':
        current_img = img.copy()
        finish = np.zeros_like(img)  # deleted pixels become black
    else:
        current_img = cv2.GaussianBlur(img, (51, 51), 0)
        finish = img  # inserted pixels come from the original image (read-only here)

    # Flat pixel indices, most salient first, split into (row, column).
    salient_order = np.argsort(saliency_map.reshape(-1))[::-1]
    y_coords, x_coords = np.divmod(salient_order, w)
    n_steps = -(-salient_order.size // step)  # ceil(HW / step)

    def _score(image):
        # IoU between the model's confident detections and the ground truth.
        pred_mask = process_yolo_output(model(image, verbose=False), img.shape, conf_threshold)
        return _mask_iou(pred_mask, gt_mask)

    scores = [_score(current_img)]
    print(f"Baseline IoU: {scores[0]:.4f}")
    print(f"Processing {n_steps} steps...")

    for i in range(n_steps):
        batch = slice(step * i, step * (i + 1))
        batch_y, batch_x = y_coords[batch], x_coords[batch]
        current_img[batch_y, batch_x, :] = finish[batch_y, batch_x, :]

        iou_score = _score(current_img)
        scores.append(iou_score)
        if (i + 1) % 10 == 0:
            print(f"Step {i+1}/{n_steps}, IoU: {iou_score:.4f}")

    scores = np.array(scores)
    return auc(scores), scores

# Example usage
if __name__ == "__main__":
    model_path = 'best_chagas.pt'
    images_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/images'
    labels_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels'
    csv_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/YOLOv8_Explainer/heatmaps/All_Layers/EigenGradCAM/12'
    # Fall back to CPU so the batch run also works on machines without a GPU.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Get all image files; glob order is arbitrary, so sort for reproducible runs/logs.
    image_files = sorted(glob.glob(os.path.join(images_dir, '*.jpg')))

    print(f"Found {len(image_files)} images to process\n")

    del_aucs = []
    ins_aucs = []

    for img_path in image_files:
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        csv_path = os.path.join(csv_dir, f"{img_name}.csv")
        label_path = os.path.join(labels_dir, f"{img_name}.txt")

        if not os.path.exists(csv_path) or not os.path.exists(label_path):
            print(f"Skipping {img_name} - missing files")
            continue

        print(f"Processing {img_name}...")

        # Compute deletion metric
        del_auc, del_scores = del_ins_multiple_boxes(
            model_path, img_path, csv_path, label_path,
            mode='del', step=1000, conf_threshold=0.5, device=run_device
        )

        # Compute insertion metric
        ins_auc, ins_scores = del_ins_multiple_boxes(
            model_path, img_path, csv_path, label_path,
            mode='ins', step=1000, conf_threshold=0.5, device=run_device
        )

        del_aucs.append(del_auc)
        ins_aucs.append(ins_auc)

        print(f"Deletion AUC: {del_auc:.4f}")
        print(f"Insertion AUC: {ins_auc:.4f}\n")

    # np.mean([]) would print nan with a RuntimeWarning, so report the empty case explicitly.
    if del_aucs:
        print(f"Mean Deletion AUC: {np.mean(del_aucs):.4f}")
        print(f"Mean Insertion AUC: {np.mean(ins_aucs):.4f}")
    else:
        print("No images had both a heatmap CSV and a label file; nothing to average.")

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob

# Overlay each saliency map (jet colormap) and its YOLO ground-truth boxes on the image.

# Paths
saliency_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/saliency_maps'
images_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/images/'
labels_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/'


def read_yolo_boxes_xyxy(label_file, img_width, img_height):
    """Read a YOLO label file and return its boxes as pixel ``(x1, y1, x2, y2)`` tuples.

    Each line is ``class cx cy w h`` with coordinates normalised to [0, 1]. Lines with
    fewer than five fields are ignored. A missing file yields an empty list. Coordinates
    are truncated with ``int`` and not clamped to the image.
    """
    boxes = []
    if not os.path.exists(label_file):
        return boxes
    with open(label_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            x_center = float(parts[1]) * img_width
            y_center = float(parts[2]) * img_height
            width = float(parts[3]) * img_width
            height = float(parts[4]) * img_height
            boxes.append((int(x_center - width / 2), int(y_center - height / 2),
                          int(x_center + width / 2), int(y_center + height / 2)))
    return boxes


# Sorted so the figures appear in a stable order.
saliency_files = sorted(glob.glob(os.path.join(saliency_dir, "*.csv")))

for saliency_file in saliency_files:
    print(saliency_file)
    image_name_base = os.path.splitext(os.path.basename(saliency_file))[0]

    image_path = os.path.join(images_dir, f"{image_name_base}.jpg")
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        print(f"Skipping {image_name_base} - could not read {image_path}")
        continue
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]

    # ndmin=2 keeps a single-row CSV 2-D.
    saliency_map = np.loadtxt(saliency_file, delimiter=',', ndmin=2)
    if saliency_map.shape != (img_height, img_width):
        # cv2.resize takes the target size as (width, height).
        saliency_map = cv2.resize(saliency_map, (img_width, img_height))

    boxes = read_yolo_boxes_xyxy(os.path.join(labels_dir, f"{image_name_base}.txt"),
                                 img_width, img_height)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(image)
    ax.imshow(saliency_map, cmap='jet', alpha=0.5)
    for (x1, y1, x2, y2) in boxes:
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, edgecolor='lime', linewidth=2))
    ax.set_title(f"{image_name_base}")
    ax.axis('off')
    plt.show()
    # Release the figure so long loops do not accumulate open figures.
    plt.close(fig)

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt


# Mean IoU between the thresholded saliency map and the union of the YOLO ground-truth
# boxes, over every saliency CSV that has a matching label file.
labels_folder = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/'
heatmap_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/saliency_maps'
# Saliency values >= this count as "salient" pixels.
SALIENCY_THRESHOLD = 0.3


def yolo_gt_mask(label_path, shape):
    """Rasterise the YOLO boxes in ``label_path`` into an int 0/1 mask of ``shape`` (h, w).

    Each line is ``class cx cy w h`` (normalised). Lines with fewer than five fields are
    skipped, and box edges are clamped to the mask.
    """
    h, w = shape
    gt_mask = np.zeros((h, w), dtype=int)
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            center_x, center_y, width, height = (float(x) for x in parts[1:5])
            x1, y1 = int((center_x - width / 2) * w), int((center_y - height / 2) * h)
            x2, y2 = int((center_x + width / 2) * w), int((center_y + height / 2) * h)
            gt_mask[max(0, y1):min(h, y2), max(0, x1):min(w, x2)] = 1
    return gt_mask


def saliency_gt_iou(saliency_map, label_path, threshold=SALIENCY_THRESHOLD):
    """IoU between ``saliency_map >= threshold`` and the GT-box mask; 0 when both are empty."""
    saliency_mask = saliency_map >= threshold
    gt_mask = yolo_gt_mask(label_path, saliency_map.shape).astype(bool)
    intersection = np.logical_and(saliency_mask, gt_mask).sum()
    union = np.logical_or(saliency_mask, gt_mask).sum()
    return intersection / union if union > 0 else 0


if os.path.exists(heatmap_dir):
    layer_ious = []

    for csv_file in sorted(os.listdir(heatmap_dir)):
        if not csv_file.endswith('.csv'):
            continue
        image_name = os.path.splitext(csv_file)[0]  # strip ".csv"
        label_path = os.path.join(labels_folder, f"{image_name}.txt")
        if not os.path.exists(label_path):
            continue
        # ndmin=2 keeps a single-row CSV 2-D so the (h, w) unpacking works.
        saliency_map = np.loadtxt(os.path.join(heatmap_dir, csv_file), delimiter=',', ndmin=2)
        layer_ious.append(saliency_gt_iou(saliency_map, label_path))

    avg_iou = np.mean(layer_ious) if layer_ious else 0
    print(f"  avg iou : {avg_iou:.4f} (images: {len(layer_ious)})")
else:
    print("  avg iou: Directory not found")

print()

In [None]:
import numpy as np
import os

def energy_based_pointing_game(saliency_map, bbox):
    """
    Energy-Based Pointing Game (EBPG): fraction of the total saliency energy that falls
    inside each bounding box.

    Parameters:
        - saliency_map: 2D array (H, W). If its maximum exceeds 1 it is min-max normalised
          to [0, 1] first (a constant map is left unscaled to avoid dividing by zero).
        - bbox: one box ``[x_min, y_min, x_max, y_max, ...]`` (1-D) or several boxes
          (N, >=4), in pixel coordinates of ``saliency_map``. Numpy arrays, torch tensors
          and plain nested lists are accepted; only the first four values per box are used
          and negative coordinates are clamped to 0.
    Returns:
        A scalar score for a single box, a list of scores for several boxes, and an empty
        list when ``bbox`` contains no boxes. A map with zero total energy scores 0.
    """
    if not hasattr(bbox, "shape"):
        bbox = np.asarray(bbox)
    if len(bbox.shape) == 1:  # single box -> (1, k)
        bbox = bbox.reshape(1, -1)
    if bbox.shape[0] == 0 or bbox.shape[1] == 0:
        return []

    # Normalisation is loop-invariant, so do it once.
    if saliency_map.max() > 1.0:
        value_range = saliency_map.max() - saliency_map.min()
        if value_range > 0:
            saliency_map = (saliency_map - saliency_map.min()) / value_range

    energy_whole = np.sum(saliency_map)

    scores = []
    for box in bbox:  # every box -> every object in the image
        x_min, y_min, x_max, y_max = (max(int(v), 0) for v in box[:4])
        # y indexes rows and x indexes columns; slicing clips at the array edge.
        energy_bbox = np.sum(saliency_map[y_min:y_max, x_min:x_max])
        scores.append(energy_bbox / energy_whole if energy_whole > 0 else 0)

    return scores if len(scores) > 1 else scores[0]


# Mean Energy-Based Pointing Game (EBPG) over every saliency CSV that has a YOLO label file.
labels_folder = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/labels/'
heatmap_dir = '/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/saliency_maps'

# How the per-box EBPG scores of one image are combined into one image score.
# np.sum (the original choice) is the total saliency fraction inside all GT boxes;
# overlapping boxes are counted twice. Use np.mean for a per-box average instead.
EBPG_IMAGE_AGGREGATE = np.sum

if os.path.exists(heatmap_dir):
    per_image_scores = []  # one aggregated score per image

    for csv_file in sorted(os.listdir(heatmap_dir)):
        if not csv_file.endswith('.csv'):
            continue
        image_name = os.path.splitext(csv_file)[0]
        csv_path = os.path.join(heatmap_dir, csv_file)
        label_path = os.path.join(labels_folder, f"{image_name}.txt")
        if not os.path.exists(label_path):
            continue

        # ndmin=2 keeps a single-row CSV 2-D so the (h, w) unpacking works.
        saliency_map = np.loadtxt(csv_path, delimiter=',', ndmin=2)
        h, w = saliency_map.shape

        # YOLO "class cx cy w h" (normalised) -> pixel [x_min, y_min, x_max, y_max].
        bboxes = []
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 5:
                    continue
                center_x, center_y, width, height = (float(x) for x in parts[1:5])
                bboxes.append([
                    int((center_x - width / 2) * w),
                    int((center_y - height / 2) * h),
                    int((center_x + width / 2) * w),
                    int((center_y + height / 2) * h),
                ])
        if not bboxes:
            continue

        # One box comes back as a scalar, several as a list; the aggregate handles both.
        scores = energy_based_pointing_game(saliency_map, np.array(bboxes))
        per_image_scores.append(EBPG_IMAGE_AGGREGATE(scores))

    # Average across all images.
    overall_avg = np.mean(per_image_scores) if per_image_scores else 0
    print(f"  avg ebpg : {overall_avg:.4f} (images: {len(per_image_scores)})")
else:
    print("  avg ebpg : Directory not found")
print()

In [None]:
from ultralytics import RTDETR
# Reload the best checkpoint of the training run, print its validation metrics, then
# run a quick single-image sanity check.
model = RTDETR('/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics/runs/detect/parasite_medical_rtdetr/weights/best.pt')

metrics = model.val()
for metric_label, metric_value in (
    ("mAP50", metrics.box.map50),
    ("Precision", metrics.box.mp),
    ("Recall", metrics.box.mr),
):
    print(f"{metric_label}: {metric_value}")

results = model('/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics/data_RT/images/val/field0600.jpg')
results[0].show()  # render the annotated prediction


In [None]:
import matplotlib.pyplot as plt
from ultralytics import RTDETR
import cv2
from pathlib import Path

# Load the trained checkpoint and show its predictions on every validation image,
# followed by a per-box confidence log.
model = RTDETR('/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics/runs/train/parasite_medical_rtdetr6/weights/best.pt')

val_images_path = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics/data_RT/images/val'

# Predict every image up front (nothing saved, no pop-up windows) at conf >= 0.25.
results = model.predict(source=val_images_path, conf=0.25, save=False, show=False)

n_results = len(results)
print(f"Found {n_results} images to process\n")

divider = "-" * 50
for position, result in enumerate(results, start=1):
    # result.plot() draws the boxes in BGR; matplotlib expects RGB.
    rgb_view = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
    img_name = Path(result.path).name
    detection_count = len(result.boxes) if result.boxes else 0

    plt.figure(figsize=(12, 8))
    plt.imshow(rgb_view)
    plt.title(f'Image: {img_name} | Parasites Detected: {detection_count}', fontsize=14, fontweight='bold')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    if result.boxes is None or len(result.boxes) == 0:
        print(f"Image {position}/{n_results}: {img_name} - No parasites detected")
    else:
        print(f"Image {position}/{n_results}: {img_name}")
        for rank, conf in enumerate(result.boxes.conf.cpu().numpy(), start=1):
            print(f"  Parasite {rank}: Confidence = {conf:.3f}")
    print(divider)

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import numpy as np
from ultralytics import RTDETR
from pathlib import Path

class RTDETRAttentionExtractor:
    """Capture inputs/outputs of every decoder cross-attention module of an RT-DETR model.

    Forward hooks are registered on construction and stay active until
    ``remove_hooks()`` is called.

    Attributes:
        model: the ultralytics ``RTDETR`` wrapper used for prediction.
        pytorch_model: the underlying module the hooks are attached to.
        attention_weights: filled by the hooks, keyed ``"<module>_query"``,
            ``"<module>_reference_points"``, ``"<module>_input_flatten"`` and
            ``"<module>_output"``; tensors are detached and moved to CPU.
        spatial_shapes: keyed by module name; holds the 4th positional input
            (converted to a CPU tensor when possible).
        hooks: the registered hook handles.
    """

    def __init__(self, model_path):
        self.model = RTDETR(model_path)
        self.pytorch_model = self.model.model
        self.attention_weights = {}
        self.spatial_shapes = {}
        self.hooks = []
        self._register_attention_hooks()

    @staticmethod
    def _shapes_to_cpu(spatial_shapes):
        """Return spatial shapes as a CPU tensor when possible, otherwise unchanged."""
        if isinstance(spatial_shapes, torch.Tensor):
            return spatial_shapes.detach().cpu()
        if isinstance(spatial_shapes, list):
            try:
                return torch.tensor(spatial_shapes)
            except (TypeError, ValueError, RuntimeError):
                # Ragged or non-numeric lists cannot become a tensor; keep the raw list.
                return spatial_shapes
        return spatial_shapes

    def _register_attention_hooks(self):
        """Hook every module whose name contains 'decoder' and ends with 'cross_attn'."""

        def attention_hook(name):
            def hook(module, inputs, output):
                # Positional inputs as used in this notebook: query [B, N_q, C],
                # reference points, flattened multi-scale features [B, sum(H_i*W_i), C],
                # and optionally the per-level spatial shapes.
                if len(inputs) >= 3:
                    query, reference_points, input_flatten = inputs[:3]
                    self.attention_weights[f"{name}_query"] = query.detach().cpu()
                    self.attention_weights[f"{name}_reference_points"] = reference_points.detach().cpu()
                    self.attention_weights[f"{name}_input_flatten"] = input_flatten.detach().cpu()
                    if len(inputs) >= 4:
                        self.spatial_shapes[name] = self._shapes_to_cpu(inputs[3])

                if isinstance(output, torch.Tensor):
                    self.attention_weights[f"{name}_output"] = output.detach().cpu()

            return hook

        for name, module in self.pytorch_model.named_modules():
            if 'decoder' in name and name.endswith('cross_attn'):
                self.hooks.append(module.register_forward_hook(attention_hook(name)))
                print(f"🎯 Registered attention hook: {name}")

    def remove_hooks(self):
        """Remove all registered hooks (safe to call more than once)."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def extract_attention_for_explainability(self, image_path, conf_threshold=0.25):
        """Run one prediction and return ``(results, attention_weights, spatial_shapes)``.

        Both dicts are reset first, so they only hold data captured during this call.
        """
        self.attention_weights = {}
        self.spatial_shapes = {}
        results = self.model.predict(image_path, conf=conf_threshold, save=False, verbose=False)
        return results, self.attention_weights, self.spatial_shapes

def compute_query_spatial_attention(query, reference_points, input_flatten, spatial_shapes, temperature=1.0):
    """
    Approximate per-query spatial attention as a softmax over cosine similarities between
    each query embedding and every flattened spatial feature.

    NOTE: this is a similarity proxy, not the attention weights the model itself computes.
    ``reference_points`` and ``spatial_shapes`` are accepted for call compatibility but are
    not used.

    Args:
        query: [B, N_q, C] tensor.
        reference_points: unused.
        input_flatten: [B, N_spatial, C] tensor.
        spatial_shapes: unused.
        temperature: softmax temperature. Cosine scores lie in [-1, 1], so the default 1.0
            (original behaviour) gives a fairly flat distribution; values < 1 sharpen it.

    Returns:
        [B, N_q, N_spatial] tensor whose last dimension sums to 1.

    Raises:
        ValueError: if ``temperature <= 0`` or the batch/channel dimensions differ.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if query.shape[0] != input_flatten.shape[0] or query.shape[-1] != input_flatten.shape[-1]:
        raise ValueError(
            f"query {tuple(query.shape)} and input_flatten {tuple(input_flatten.shape)} "
            "must share batch and channel dimensions"
        )

    query_norm = F.normalize(query, dim=-1)            # [B, N_q, C]
    spatial_norm = F.normalize(input_flatten, dim=-1)  # [B, N_spatial, C]

    # Cosine similarity of every query with every spatial token.
    attention_scores = torch.bmm(query_norm, spatial_norm.transpose(1, 2))  # [B, N_q, N_spatial]

    return F.softmax(attention_scores / temperature, dim=-1)

def find_best_spatial_dimensions(N_spatial):
    """Return the most square-like factor pair ``(h, w)`` with ``h <= w`` and ``h * w == N_spatial``.

    Because 1 always divides ``N_spatial``, every positive count has an exact
    factorisation (a prime gives ``(1, N_spatial)``). ``0`` returns ``(0, 0)``.

    Caveat: a flattened multi-scale token sequence (several feature levels concatenated)
    has no true single 2-D layout, so in that case the grid is only a display convenience.

    Raises:
        ValueError: if ``N_spatial`` is negative.
        TypeError: if ``N_spatial`` is not an integer.
    """
    import math

    if N_spatial == 0:
        return 0, 0
    # math.isqrt is exact for any int, unlike int(np.sqrt(...)). The first divisor found
    # scanning down from sqrt(N) is the one closest to sqrt(N).
    for h in range(math.isqrt(N_spatial), 0, -1):
        if N_spatial % h == 0:
            return h, N_spatial // h

def attn_layer_sort_key(layer_name):
    """Natural-sort key for module names, so 'layers.10' orders after 'layers.9'."""
    import re
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', layer_name)]


def attn_highest_resolution_level(layer_shapes, n_spatial):
    """Locate the largest feature level inside a flattened multi-scale token sequence.

    Args:
        layer_shapes: per-level shapes captured by the attention hook (list of pairs or an
            [L, 2] tensor), or None. Each pair is assumed to be (h, w) — TODO confirm the
            order against the model; it only matters for non-square levels.
        n_spatial: total number of flattened tokens.

    Returns:
        ``(offset, h, w)`` of the level with the most tokens, or ``None`` when the shapes
        are missing, malformed, or do not add up to ``n_spatial``.
    """
    if layer_shapes is None:
        return None
    try:
        shapes = [(int(s[0]), int(s[1])) for s in layer_shapes]
    except (TypeError, ValueError, IndexError):
        return None
    sizes = [h * w for h, w in shapes]
    if not sizes or sum(sizes) != n_spatial:
        return None
    best = max(range(len(sizes)), key=sizes.__getitem__)
    h, w = shapes[best]
    return sum(sizes[:best]), h, w


def attn_match_queries_to_boxes(boxes, ref_points, img_w, img_h, fallback_scores):
    """Pick one decoder query per detection box.

    With usable reference points, each box gets the query whose reference point is nearest
    to the box centre. The first two reference coordinates are assumed to be (x, y)
    normalised to [0, 1] over the image — TODO confirm. Otherwise the queries with the
    highest ``fallback_scores`` are returned; these are NOT tied to specific boxes.

    Args:
        boxes: (N, 4) xyxy pixel boxes.
        ref_points: [B, N_q, ...] tensor (batch 0 used) or None.
        img_w, img_h: image size used to normalise box centres.
        fallback_scores: [N_q] tensor used when ``ref_points`` is unusable.

    Returns:
        List of query indices (at most one per box).
    """
    n_boxes = len(boxes)
    if n_boxes == 0:
        return []
    if ref_points is not None and ref_points.dim() >= 3 and ref_points.shape[-1] >= 2:
        ref_xy = ref_points[0].reshape(ref_points.shape[1], -1)[:, :2].float()  # [N_q, 2]
        centres = torch.tensor(
            [[float(x1 + x2) / 2 / img_w, float(y1 + y2) / 2 / img_h] for x1, y1, x2, y2 in boxes],
            dtype=torch.float32,
        )
        distances = torch.norm(centres[:, None, :] - ref_xy[None, :, :], dim=-1)  # [N_boxes, N_q]
        return distances.argmin(dim=1).tolist()
    k = min(n_boxes, fallback_scores.numel())
    return torch.topk(fallback_scores, k=k).indices.tolist()


def create_explainability_heatmaps(image_path, model_path, max_heatmaps=3):
    """Plot per-detection attention heatmaps for one image from an RT-DETR checkpoint.

    Steps:
      1. Run the model with cross-attention hooks and take the last decoder layer
         (natural sort on module name).
      2. Compute the cosine-similarity attention proxy (``compute_query_spatial_attention``).
      3. Lay it out on the highest-resolution feature level when the captured spatial
         shapes allow it; otherwise fall back to a near-square grid of all tokens.
      4. Pair each shown detection with the query whose reference point is nearest to the
         box centre, falling back to the queries with the highest peak attention.

    The figure is saved as ``explainability_attention_<stem>.png`` and shown. Hooks are
    always removed, even if a step raises.

    Args:
        image_path: image file to explain.
        model_path: RT-DETR weights passed to ``RTDETRAttentionExtractor``.
        max_heatmaps: maximum number of detections to visualise (default 3).
    """
    print("🔍 Extracting Cross-Attention Weights for Explainability...")

    extractor = RTDETRAttentionExtractor(model_path)
    try:
        results, attention_data, spatial_shapes = extractor.extract_attention_for_explainability(image_path)
        if not attention_data:
            print("❌ No attention weights captured")
            return
        print(f"✅ Captured attention data: {list(attention_data.keys())}")

        img = cv2.imread(image_path)
        if img is None:  # cv2.imread signals failure with None
            print(f"❌ Could not read image: {image_path}")
            return
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_h, img_w = img_rgb.shape[:2]

        detections = results[0].boxes
        if detections is None or len(detections) == 0:
            print("❌ No parasites detected for explainability")
            return
        num_detections = len(detections)
        boxes = detections.xyxy.cpu().numpy()
        confidences = detections.conf.cpu().numpy()
        print(f"🦠 Found {num_detections} parasites to explain")

        # Hook keys look like "<module name>_query"; recover the module names.
        layer_names = {key[:-len("_query")] for key in attention_data if key.endswith("_query")}
        if not layer_names:
            print("❌ Could not find suitable attention data")
            return
        # Last decoder layer = most refined queries.
        layer_name = max(layer_names, key=attn_layer_sort_key)
        print(f"📊 Using layer: {layer_name}")

        input_flatten_key = f"{layer_name}_input_flatten"
        if input_flatten_key not in attention_data:
            print("❌ Missing required attention components")
            print(f"Available keys: {list(attention_data.keys())}")
            return

        query = attention_data[f"{layer_name}_query"]        # [B, N_queries, C]
        input_flatten = attention_data[input_flatten_key]    # [B, N_spatial, C]
        ref_points = attention_data.get(f"{layer_name}_reference_points")
        print(f"📊 Query shape: {query.shape}")
        print(f"📊 Spatial features shape: {input_flatten.shape}")

        attention_weights = compute_query_spatial_attention(query, ref_points, input_flatten, spatial_shapes)
        print(f"📊 Attention weights shape: {attention_weights.shape}")
        _, N_q, N_spatial = attention_weights.shape

        # The flattened tokens are several feature levels concatenated, so they cannot be
        # reshaped into one grid. Prefer the slice of the highest-resolution level.
        layer_shapes = spatial_shapes.get(layer_name) if isinstance(spatial_shapes, dict) else None
        level = attn_highest_resolution_level(layer_shapes, N_spatial)
        if level is not None:
            offset, spatial_h, spatial_w = level
            level_weights = attention_weights[0, :, offset:offset + spatial_h * spatial_w]
            # reshape, not view: the slice is not contiguous.
            spatial_attention = level_weights.reshape(N_q, spatial_h, spatial_w)
            print(f"📏 Using highest-resolution level {spatial_h} x {spatial_w} "
                  f"(tokens {offset}-{offset + spatial_h * spatial_w - 1} of {N_spatial})")
        else:
            # Fallback: treat all tokens as one near-square grid, padding/truncating if needed.
            spatial_h, spatial_w = find_best_spatial_dimensions(N_spatial)
            n_cells = spatial_h * spatial_w
            flat = attention_weights[0]
            flat = F.pad(flat, (0, n_cells - N_spatial)) if n_cells > N_spatial else flat[:, :n_cells]
            spatial_attention = flat.reshape(N_q, spatial_h, spatial_w)
            print(f"⚠️ No usable spatial shapes; treating {N_spatial} tokens as a "
                  f"{spatial_h} x {spatial_w} grid")
        print(f"📊 Final spatial attention shape: {spatial_attention.shape}")

        # Every softmax row sums to ~1, so summed energy cannot rank queries; use the peak.
        peak_attention = spatial_attention.reshape(N_q, -1).amax(dim=1)
        n_show = min(num_detections, max_heatmaps)
        query_indices = attn_match_queries_to_boxes(boxes[:n_show], ref_points, img_w, img_h, peak_attention)
        print(f"🎯 Query per detection: {query_indices}")

        ncols = n_show + 1
        fig, axes = plt.subplots(2, ncols, figsize=(16, 10), squeeze=False)

        # Panel (0, 0): every detection with its confidence.
        ax_all = axes[0, 0]
        ax_all.imshow(img_rgb)
        ax_all.set_title('🦠 Parasite Detections', fontsize=14, fontweight='bold')
        ax_all.axis('off')
        for i, ((x1, y1, x2, y2), conf) in enumerate(zip(boxes, confidences)):
            ax_all.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           fill=False, color='red', linewidth=2))
            ax_all.text(x1, y1 - 5, f'P{i+1}: {conf:.2f}', color='red', fontweight='bold',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.9))

        # One column per shown detection: its box on top, its query's heatmap below.
        for col, (box, query_idx) in enumerate(zip(boxes[:n_show], query_indices), start=1):
            query_attention = spatial_attention[query_idx].numpy().astype(np.float32)  # [H, W]
            heatmap = cv2.resize(query_attention, (img_w, img_h))
            value_range = heatmap.max() - heatmap.min()
            # A constant map would divide by zero; show it as all-zero instead.
            heatmap = (heatmap - heatmap.min()) / value_range if value_range > 0 else np.zeros_like(heatmap)

            x1, y1, x2, y2 = box
            ax_box = axes[0, col]
            ax_box.imshow(img_rgb)
            ax_box.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           fill=False, color='red', linewidth=3))
            ax_box.set_title(f'Parasite {col}', fontsize=12, fontweight='bold')
            ax_box.axis('off')

            ax_heat = axes[1, col]
            ax_heat.imshow(img_rgb, alpha=0.5)
            im = ax_heat.imshow(heatmap, cmap='jet', alpha=0.8, vmin=0, vmax=1)
            ax_heat.set_title(f'🎯 Query {query_idx} Attention\n(Explainability Heatmap)',
                              fontsize=11, fontweight='bold')
            ax_heat.axis('off')
            fig.colorbar(im, ax=ax_heat, fraction=0.046, pad=0.04)

        # Columns left without a query (fallback returned fewer queries than boxes).
        for col in range(len(query_indices) + 1, ncols):
            axes[0, col].axis('off')
            axes[1, col].axis('off')

        summary_lines = [
            "🧠 EXPLAINABILITY ANALYSIS",
            "=" * 35,
            "",
            f"📁 Image: {Path(image_path).name}",
            f"🦠 Parasites: {num_detections}",
            f"🔍 Total Queries: {N_q}",
            f"📏 Spatial Features: {N_spatial}",
            f"🗺️ Feature Map: {spatial_h}×{spatial_w}",
            "",
            "🎯 QUERY PEAK ATTENTION:",
        ]
        summary_lines += [f"   Query {idx}: {peak_attention[idx].item():.3e}" for idx in query_indices]
        summary_lines += [
            "",
            "💡 INTERPRETATION:",
            "Heatmaps show WHERE each",
            "query focuses to detect",
            "parasites. Red = high",
            "attention, Blue = low",
            "attention.",
        ]
        axes[1, 0].axis('off')
        axes[1, 0].text(0.05, 0.95, "\n".join(summary_lines), transform=axes[1, 0].transAxes,
                        fontsize=10, verticalalignment='top', fontfamily='monospace',
                        bbox=dict(boxstyle='round,pad=0.5', facecolor='lightcyan', alpha=0.9))

        fig.tight_layout()
        fig.savefig(f'explainability_attention_{Path(image_path).stem}.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

        print("✅ Explainability heatmaps generated successfully!")
    finally:
        # Runs on every exit path, including the early returns and exceptions above.
        extractor.remove_hooks()

# Run the attention-explainability visualisation on one validation image.
ultralytics_root = '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics'
model_path = f"{ultralytics_root}/runs/train/parasite_medical_rtdetr6/weights/best.pt"
image_path = f"{ultralytics_root}/data_RT/images/val/field0600.jpg"

create_explainability_heatmaps(image_path=image_path, model_path=model_path)





In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import numpy as np
from ultralytics import RTDETR
from pathlib import Path

class RTDETRSharpAttentionExtractor:
    """Record what each RT-DETR decoder cross-attention module receives on a forward pass.

    For every hooked module ``name`` the hooks store the following in ``attention_weights``:
      * ``<name>_query``, ``<name>_input_flatten`` and ``<name>_reference_points``, when
        the module receives at least three positional inputs;
      * ``<name>_output``, when the output is a tensor.
    A fourth positional input is stored in ``spatial_info`` as ``<name>_shapes``.
    Stored tensors are detached and moved to CPU.
    """

    def __init__(self, model_path):
        self.model = RTDETR(model_path)
        self.pytorch_model = self.model.model
        self.attention_weights = {}
        self.spatial_info = {}
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to every '...decoder...cross_attn' module."""

        def make_hook(layer):
            def capture(_module, args, result):
                if len(args) >= 3:
                    store = self.attention_weights
                    store[f"{layer}_query"] = args[0].detach().cpu()
                    store[f"{layer}_input_flatten"] = args[2].detach().cpu()
                    store[f"{layer}_reference_points"] = args[1].detach().cpu()
                    if len(args) >= 4:
                        shapes = args[3]
                        self.spatial_info[f"{layer}_shapes"] = (
                            shapes.detach().cpu() if isinstance(shapes, torch.Tensor) else shapes
                        )
                if isinstance(result, torch.Tensor):
                    self.attention_weights[f"{layer}_output"] = result.detach().cpu()
            return capture

        targets = [
            (layer, mod)
            for layer, mod in self.pytorch_model.named_modules()
            if 'decoder' in layer and 'cross_attn' in layer and layer.endswith('cross_attn')
        ]
        for layer, mod in targets:
            self.hooks.append(mod.register_forward_hook(make_hook(layer)))
            print(f"🎯 Registered hook: {layer}")

    def remove_hooks(self):
        """Detach every registered hook and forget the handles."""
        for handle in list(self.hooks):
            handle.remove()
        self.hooks = []

    def extract_attention(self, image_path, conf_threshold=0.25):
        """Run one prediction and return ``(results, attention_weights, spatial_info)``."""
        self.attention_weights, self.spatial_info = {}, {}
        results = self.model.predict(image_path, conf=conf_threshold, save=False, verbose=False)
        return results, self.attention_weights, self.spatial_info

def _sharp_level_coordinates(spatial_shapes, n_spatial):
    """Normalised (x, y) pixel-centre coordinates for each flattened multi-scale token.

    Returns an [n_spatial, 2] tensor, or None when ``spatial_shapes`` is missing,
    malformed, or does not account for exactly ``n_spatial`` tokens. Each entry is assumed
    to be (h, w) — TODO confirm against the model. Every level spans the same [0, 1] frame.
    """
    if spatial_shapes is None:
        return None
    try:
        shapes = [(int(s[0]), int(s[1])) for s in spatial_shapes]
    except (TypeError, ValueError, IndexError):
        return None
    if not shapes or sum(h * w for h, w in shapes) != n_spatial:
        return None
    per_level = []
    for h, w in shapes:
        ys = (torch.arange(h, dtype=torch.float32) + 0.5) / h
        xs = (torch.arange(w, dtype=torch.float32) + 0.5) / w
        # Row-major token order: y varies slowest, x fastest.
        per_level.append(torch.stack([xs.repeat(h), ys.repeat_interleave(w)], dim=1))
    return torch.cat(per_level, dim=0)


def _sharp_single_grid_coordinates(n_spatial):
    """Fallback: treat all tokens as one near-square h x w grid (h <= w) with linspace coords."""
    import math

    h = math.isqrt(n_spatial) if n_spatial > 0 else 0
    while h > 1 and n_spatial % h:  # largest divisor <= sqrt(n); 1 always divides
        h -= 1
    w = n_spatial // h if h else 0
    print(f"🗺️ Spatial grid: {h}x{w} = {h*w}, available: {n_spatial}")
    y_coords = torch.linspace(0, 1, h).unsqueeze(1).repeat(1, w).flatten()
    x_coords = torch.linspace(0, 1, w).unsqueeze(0).repeat(h, 1).flatten()
    return torch.stack([x_coords, y_coords], dim=1)  # [n_spatial, 2]


def compute_sharp_attention_maps(query, input_flatten, reference_points, spatial_shapes,
                                 sharpness=8.0, temperature_scale=1.5):
    """Cosine-similarity attention proxy, sharpened around each query's reference point.

    Scores are cosine similarities between every query and every flattened spatial token.
    For batch element 0 only, each query's scores are multiplied by
    ``exp(-sharpness * distance)``. The distance is measured between each token's
    normalised (x, y) position and the query's reference point (its first two coordinates,
    assumed normalised to [0, 1] — TODO confirm). A softmax of
    ``scores * temperature_scale`` follows.

    Token positions come from ``spatial_shapes`` when it describes the flattened
    multi-scale sequence exactly: pixel-centre coordinates per level, with all levels
    sharing the [0, 1] image frame. Otherwise the tokens are treated as one near-square
    grid with ``linspace(0, 1)`` coordinates, which is only an approximation for
    multi-level features.

    NOTE(review): the weighting is multiplicative, so distant tokens are pulled towards a
    score of 0. For negative cosine similarities that *raises* them. It is kept for
    compatibility; an additive log-prior (``scores - sharpness * distance``) would avoid it.

    Args:
        query: [B, N_q, C] tensor.
        input_flatten: [B, N_spatial, C] tensor.
        reference_points: tensor with ``dim() >= 3`` and leading dims (B, N_q), or None to
            skip the spatial weighting.
        spatial_shapes: per-level shapes (list of pairs or [L, 2] tensor), or None.
        sharpness: decay rate of the distance weighting (default 8.0).
        temperature_scale: multiplier applied to the scores before softmax (default 1.5).

    Returns:
        [B, N_q, N_spatial] tensor whose last dimension sums to 1.
    """
    B, N_q, _ = query.shape
    _, N_spatial, _ = input_flatten.shape

    print(f"🔧 Computing attention: N_q={N_q}, N_spatial={N_spatial}")

    query_norm = F.normalize(query, dim=-1)
    input_norm = F.normalize(input_flatten, dim=-1)
    attention_scores = torch.bmm(query_norm, input_norm.transpose(1, 2))  # [B, N_q, N_spatial]

    if reference_points is not None and reference_points.dim() >= 3:
        # [B, N_q, n_levels, 2|4] or [B, N_q, 2|4] -> [B, N_q, k]
        ref_points = reference_points.reshape(B, N_q, -1)
        if ref_points.shape[-1] >= 2:
            spatial_coords = _sharp_level_coordinates(spatial_shapes, N_spatial)
            if spatial_coords is None:
                spatial_coords = _sharp_single_grid_coordinates(N_spatial)
            else:
                print(f"🗺️ Using {len(spatial_shapes)} feature levels, {N_spatial} tokens")
            print(f"📍 Created spatial coordinates: {spatial_coords.shape}")

            ref_xy = ref_points[0, :, :2]  # [N_q, 2]
            # Distance from every query's reference point to every token: [N_q, N_spatial].
            distances = torch.norm(spatial_coords.unsqueeze(0) - ref_xy.unsqueeze(1), dim=-1)
            attention_scores[0] = attention_scores[0] * torch.exp(-distances * sharpness)

    # Temperature scaling (> 1 sharpens the distribution).
    return F.softmax(attention_scores * temperature_scale, dim=-1)

def _dino_layer_sort_key(name):
    """Natural-sort key: split ``name`` into text / integer chunks so "layer_10" > "layer_2"."""
    import re

    # re.split with a capturing group alternates non-digit text (even indices) and digit
    # runs (odd indices), so two keys never compare an int against a str.
    return [int(tok) if idx % 2 else tok for idx, tok in enumerate(re.split(r'(\d+)', name))]


def _dino_attention_to_heatmap(query_attention, spatial_size, img_w, img_h, power=6, threshold=0.3):
    """Turn one query's flat attention vector into a sharpened heatmap of size (img_h, img_w).

    Args:
        query_attention: 1-D numpy array of attention weights for a single query.
        spatial_size: side of the square grid the vector is reshaped into; only the first
            ``spatial_size**2`` entries are used, zero-padded if the vector is shorter.
        img_w, img_h: output size passed to ``cv2.resize``.
        power: exponent applied after max-normalisation (higher = sharper spots).
        threshold: normalised values below this are zeroed after the power step.

    Returns:
        The resized heatmap (bicubic interpolation, so values may slightly overshoot [0, 1]).
    """
    grid_len = spatial_size * spatial_size
    attention_subset = query_attention[:grid_len]
    if len(attention_subset) < grid_len:
        attention_subset = np.concatenate([attention_subset, np.zeros(grid_len - len(attention_subset))])
    attention_spatial = attention_subset.reshape(spatial_size, spatial_size)

    peak = attention_spatial.max()
    if peak > 0:
        # Normalise to [0, 1], raise to a high power and cut weak responses so only
        # sharp spots remain. Division creates a new array, so the caller's data is untouched.
        attention_spatial = np.power(attention_spatial / peak, power)
        attention_spatial[attention_spatial < threshold] = 0

    return cv2.resize(attention_spatial, (img_w, img_h), interpolation=cv2.INTER_CUBIC)


def _dino_render_heatmaps(extractor, image_path):
    """Extract attention for ``image_path`` and draw/save the DINO-style figure (no hook cleanup)."""
    results, attention_data, spatial_info = extractor.extract_attention(image_path)

    if not attention_data:
        print("❌ No attention data captured")
        return

    print(f"✅ Captured data: {list(attention_data.keys())}")

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports an unreadable / missing file by returning None, not by raising.
        print(f"❌ Could not read image: {image_path}")
        return
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_h, img_w = img_rgb.shape[:2]

    detections = results[0].boxes
    if detections is None or len(detections) == 0:
        print("❌ No parasites detected")
        return

    num_detections = len(detections)
    boxes = detections.xyxy.cpu().numpy()

    print(f"🦠 Found {num_detections} parasites")

    # Captured keys are "<layer>_query", "<layer>_input_flatten", ...; drop the trailing
    # "_query" token to recover the layer names.
    layer_names = {'_'.join(key.split('_')[:-1]) for key in attention_data if 'query' in key}
    if not layer_names:
        print("❌ No suitable attention data found")
        return

    # Natural sort so numbered layers order numerically (plain sorted() puts "10" before "9");
    # the raw name breaks ties deterministically.
    layer_name = max(layer_names, key=lambda name: (_dino_layer_sort_key(name), name))

    query = attention_data[f"{layer_name}_query"]
    input_flatten = attention_data[f"{layer_name}_input_flatten"]
    reference_points = attention_data[f"{layer_name}_reference_points"]
    spatial_shapes = spatial_info.get(f"{layer_name}_shapes", None)

    print(f"📊 Query: {query.shape}")
    print(f"📊 Input: {input_flatten.shape}")
    print(f"📊 Ref points: {reference_points.shape}")

    attention_weights = compute_sharp_attention_maps(query, input_flatten, reference_points, spatial_shapes)
    # Detach and move to CPU once so the per-query .numpy() calls below work for grad/GPU tensors.
    attention_weights = attention_weights.detach().cpu()
    _, N_q, N_spatial = attention_weights.shape

    # Largest square grid that fits into N_spatial; trailing features are dropped.
    # NOTE(review): if input_flatten concatenates several feature levels, one square reshape
    # mixes them; spatial_shapes may hold the real (h, w) per level — confirm its format.
    spatial_size = int(np.sqrt(N_spatial))
    if spatial_size == 0:
        print("❌ No spatial features to visualize")
        return

    print(f"📏 Using spatial size: {spatial_size}x{spatial_size}")

    # compute_sharp_attention_maps ends with a softmax over the last dim, so every row sums to 1
    # and a row sum cannot rank queries. Rank by the peak weight instead (most focused first).
    query_peaks = attention_weights[0].max(dim=1).values
    top_query_indices = torch.topk(query_peaks, k=min(num_detections, 6, N_q)).indices

    print(f"🎯 Top queries: {top_query_indices.tolist()}")

    n_cols = min(num_detections + 1, 4)
    n_rows = 2
    n_shown = min(num_detections, n_cols - 1, len(top_query_indices))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 10), squeeze=False)
    fig.patch.set_facecolor('midnightblue')
    # Blank every panel first; any cell not filled below (e.g. bottom-left) shows only background.
    for ax in axes.flat:
        ax.set_facecolor('midnightblue')
        ax.axis('off')

    # Top-left: original image with every detection box labelled P1, P2, ...
    ax_orig = axes[0, 0]
    ax_orig.imshow(img_rgb)
    ax_orig.set_title('🦠 Detected Parasites', color='white', fontsize=12, fontweight='bold')
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        ax_orig.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='cyan', linewidth=2))
        ax_orig.text(x1, y1 - 5, f'P{i+1}', color='cyan', fontweight='bold', fontsize=10)

    for i in range(n_shown):
        query_idx = top_query_indices[i].item()
        attention_resized = _dino_attention_to_heatmap(
            attention_weights[0, query_idx].numpy(), spatial_size, img_w, img_h
        )

        # Top row: the image with this detection's box.
        # NOTE(review): box i and the i-th ranked query are paired by position only; nothing
        # here establishes that the query produced that detection.
        ax_det = axes[0, i + 1]
        ax_det.imshow(img_rgb)
        x1, y1, x2, y2 = boxes[i]
        ax_det.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='cyan', linewidth=2))
        ax_det.set_title(f'Parasite {i+1}', color='white', fontsize=11)

        # Bottom row: sharp attention over a darkened copy of the image.
        ax_att = axes[1, i + 1]
        ax_att.imshow((img_rgb * 0.2).astype(np.uint8))
        if attention_resized.max() > 0:
            masked_attention = np.ma.masked_where(attention_resized < 0.05, attention_resized)
            ax_att.imshow(masked_attention, cmap='hot', alpha=0.9, vmin=0, vmax=1)
            # Star markers on the strongest responses.
            peak_ys, peak_xs = np.nonzero(attention_resized > 0.7)
            if peak_xs.size:
                ax_att.scatter(peak_xs, peak_ys, c='cyan', s=30, alpha=1.0, marker='*')
        ax_att.set_title(f'Query {query_idx} Focus', color='white', fontsize=10)

    fig.tight_layout()
    fig.savefig(f'dino_style_attention_{Path(image_path).stem}.png',
                dpi=300, bbox_inches='tight', facecolor='midnightblue')
    plt.show()

    print("✅ DINO-style sharp attention maps created!")


def create_dino_style_heatmaps(image_path, model_path):
    """Create DINO-DETR style sharp attention heatmaps for one image.

    Builds an ``RTDETRSharpAttentionExtractor`` for ``model_path`` and runs it on
    ``image_path``. It takes the captured query / input_flatten / reference_points
    tensors of the last (naturally sorted) layer and turns them into per-query maps
    with ``compute_sharp_attention_maps``. It then draws a 2-row figure:

    * top row: the image with all detections, then one panel per shown detection box;
    * bottom row: the matching sharpened query attention over a darkened image.

    The figure is saved as ``dino_style_attention_<image stem>.png`` (relative to the
    current working directory) and shown. When there is no attention data, no readable
    image, no detections or no spatial features, a message is printed and the function
    returns early.

    The extractor's hooks are always removed, even if an exception is raised.

    Returns:
        None.
    """
    print("🔍 Creating DINO-DETR Style Sharp Attention Maps...")

    extractor = RTDETRSharpAttentionExtractor(model_path)
    try:
        _dino_render_heatmaps(extractor, image_path)
    finally:
        # Single cleanup point: hooks must not outlive this call on any path, including errors.
        extractor.remove_hooks()

# Demo run: render DINO-style attention maps for one validation image
# using the trained RT-DETR parasite checkpoint.
model_path, image_path = (
    '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics/runs/train/parasite_medical_rtdetr6/weights/best.pt',
    '/home/axy9651/Tryp/trypanosome_parasite_detection/tryp/ultralytics/data_RT/images/val/field0600.jpg',
)

create_dino_style_heatmaps(image_path=image_path, model_path=model_path)