In [None]:
import pickle

# Open the pickle file in binary read mode ('rb')
with open('my_dict.pkl', 'rb') as file:
    batch = pickle.load(file)

# Now 'data' contains the Python object that was saved
print(batch.keys())

In [None]:
# Open the pickle file in binary read mode ('rb')
with open('t_list.pkl', 'rb') as file:
    t_list = pickle.load(file)

# Now 'data' contains the Python object that was saved
print(len(t_list))

In [None]:
t_list[0].shape, t_list[1].shape, t_list[2].shape, 

In [4]:
# Open the pickle file in binary read mode ('rb')
with open('s_list.pkl', 'rb') as file:
    s_list = pickle.load(file)

# Now 'data' contains the Python object that was saved
print(len(s_list))

3


In [5]:
import math
import torch

In [6]:
use_topk = True          # Set True to use top-k adaptive thresholding (Option 2)
topk_ratio = 1/28          # Keep top 20% of locations per image (only used if use_topk=True)
use_dfl_objectness = True  # Set True to modulate by DFL entropy (Option 4)

# ===== FIXED HYPERPARAMETERS =====
num_classes = 80
reg_max = 16
dfl_channels = 4 * reg_max
base_conf_thresh = 0.25   # Only used if use_topk=False
eps = 1e-8
max_ent = math.log(reg_max)  # ~2.7726 for reg_max=16

In [7]:
teacher_fg_masks = []

In [8]:
for pred in t_list:
    B, C, H, W = pred.shape
    # --- 1. Classification confidence ---
    cls_pred = pred[:, dfl_channels:, :, :]  # [B, 80, H, W]  
    # we know have probability and the equivalent classes
    cls_conf, cls_name = cls_pred.sigmoid().max(dim=1)  # [B, H, W]
    if use_dfl_objectness:
        dfl_pred = pred[:, :dfl_channels, :, :]  # [B, 64, H, W]
        dfl_probs = dfl_pred.view(B, 4, reg_max, H, W).softmax(dim=2)  # [B, 4, 16, H, W]
        entropy = -(dfl_probs * torch.log(dfl_probs + eps)).sum(dim=2)
        mean_entropy = entropy.mean(dim=1)  # [B, H, W]
        norm_entropy = torch.clamp(mean_entropy / max_ent, 0.0, 1.0)
        objectness_dfl = 1.0 - norm_entropy  # [B, H, W]
        joint_conf = cls_conf * objectness_dfl
    else:
        joint_conf = cls_conf
    # we have calculated the scores of every feature points and scaled them 
    # with dfl now we will proceed with the selection of the best candidates
    if use_topk:
        # Flatten spatial dimensions: [B, H*W]
        conf_flat = joint_conf.view(B, -1)
        num_keep = max(1, int(topk_ratio * H * W))
        # Get the k-th largest value per image (threshold)
        topk_vals, _ = torch.topk(conf_flat, num_keep, dim=1, sorted=False)
        thresh = topk_vals.min(dim=1, keepdim=True)[0]  # [B, 1]
        # Broadcast threshold to [B, H, W]
        thresh = thresh.view(B, 1, 1).expand(-1, H, W)
        fg_mask = joint_conf >= thresh
    else:
        fg_mask = joint_conf > base_conf_thresh
    teacher_fg_masks.append(fg_mask)

In [9]:
teacher_fg_masks[0].shape, teacher_fg_masks[1].shape, teacher_fg_masks[2].shape,

(torch.Size([4, 80, 80]), torch.Size([4, 40, 40]), torch.Size([4, 20, 20]))

In [10]:
batch["img"].shape

torch.Size([4, 3, 640, 640])

In [11]:
cls_name.shape

torch.Size([4, 20, 20])

In [12]:
teacher_fg_masks[0][0].min()

tensor(False)

In [None]:
import torch
import numpy as np
from torchvision.utils import make_grid
from torch.nn.functional import interpolate
from PIL import Image
from IPython.display import display

def overlay_masks_on_images(images, masks, alpha=0.5):
    """
    Overlay masks on images while preserving original colors
    """
    # Ensure images and masks are on CPU and detached
    images = images.cpu().detach()
    masks = masks.cpu().detach()
    
    # Normalize images to [0, 1] if they're not already
    if images.max() > 1:
        images = images / 255.0
    
    # Convert images to numpy and change from CHW to HWC
    images_np = images.numpy().transpose(0, 2, 3, 1)  # [4, H, W, 3]
    
    # Prepare masks (already upscaled to 640x640)
    masks_np = masks.squeeze(1).numpy()  # Remove channel dim → [4, H, W]
    
    # Create overlays
    overlays = []
    for img, mask in zip(images_np, masks_np):
        # Create RGB mask (red color) with same shape as image
        mask_rgb = np.zeros_like(img)
        mask_rgb[mask > 0] = [1.0, 0.0, 0.0]  # Red color
        
        # Blend: keep original image, only add mask color where mask exists
        # This preserves the original image colors completely
        overlay = img.copy()
        mask_area = mask > 0
        overlay[mask_area] = img[mask_area] * (1 - alpha) + mask_rgb[mask_area] * alpha
        
        overlays.append(overlay)
    
    # Convert back to tensor and CHW format
    overlays_tensor = torch.from_numpy(np.array(overlays)).permute(0, 3, 1, 2).float()
    return overlays_tensor

def show_images_with_masks(batch_images, teacher_fg_masks, alpha=0.5):
    """
    Display original images and images with masks overlaid for each mask level
    """
    batch_images = batch_images.cpu().detach()
    
    print(f"Number of mask groups: {len(teacher_fg_masks)}")
    print(f"Image batch shape: {batch_images.shape}")
    
    # First, display original images
    print("\n--- Original Images ---")
    show_images_1x4(batch_images)
    
    # Process each mask group separately
    for i, mask_group in enumerate(teacher_fg_masks):
        print(f"\n--- Mask Group {i}: {mask_group.shape} -> Overlay on Images ---")
        
        # Prepare masks
        if mask_group.ndim == 3 and mask_group.shape[0] == 4:
            masks = mask_group.float().unsqueeze(1)  # → [4, 1, H, W]
        elif mask_group.ndim == 4 and mask_group.shape[1] == 1:
            masks = mask_group.float()
        else:
            raise ValueError(f"Unexpected mask shape: {mask_group.shape}")
        
        masks = masks.detach().cpu()
        
        # Upscale masks to match image size (640x640)
        upscaled_masks = interpolate(masks, size=(640, 640), mode='nearest')
        print(f"Upscaled masks shape: {upscaled_masks.shape}")
        
        # Create overlays
        overlays = overlay_masks_on_images(batch_images, upscaled_masks, alpha=alpha)
        print(f"Overlays shape: {overlays.shape}")
        
        # Display overlayed images
        original_res = mask_group.shape[1:]  # Get H, W from original mask
        print(f"Original mask resolution: {original_res} -> Upscaled to 640x640")
        
        show_images_1x4(overlays)

def show_images_1x4(batch_tensor):
    """
    Display images in a 1×4 grid layout
    """
    batch_tensor = batch_tensor.cpu().detach()
    
    # Create a grid with 4 images per row (1 row × 4 columns)
    grid = make_grid(batch_tensor, nrow=4, padding=2, normalize=True)
    
    # Convert to numpy array and change from CHW to HWC format
    grid_np = grid.numpy().transpose((1, 2, 0))
    
    # Convert to PIL Image and display
    grid_img = Image.fromarray((grid_np * 255).astype(np.uint8))
    display(grid_img)

# Usage:
show_images_with_masks(batch["img"], teacher_fg_masks, alpha=1)

plotting ground truth labels

In [None]:
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display

def plot_batch_detections_pil_jupyter(batch, class_names=None):
    """
    Display images with bounding boxes in Jupyter using only PIL
    """
    if class_names is None:
        class_names = {
            22: 'class_22', 23: 'class_23', 45: 'class_45', 
            49: 'class_49', 50: 'class_50'
        }
    
    batch_size = batch['img'].shape[0]
    batch_indices_np = batch['batch_idx'].cpu().numpy()
    bboxes_np = batch['bboxes'].cpu().numpy()
    classes_np = batch['cls'].cpu().numpy().flatten()
    
    for img_idx in range(batch_size):
        img_mask = batch_indices_np == img_idx
        img_bboxes = bboxes_np[img_mask]
        img_classes = classes_np[img_mask]
        
        # Get and prepare the image
        img_tensor = batch['img'][img_idx]
        img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        
        if img.max() <= 1.0:
            img = (img * 255).astype(np.uint8)
        
        # Convert to PIL Image
        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)
        
        # Try to use a font
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            # Fallback to default font
            font = ImageFont.load_default()
        
        # Colors for different classes
        colors = ['red', 'blue', 'green', 'orange', 'purple']
        
        # Draw bounding boxes
        for bbox, cls_idx in zip(img_bboxes, img_classes):
            # Convert normalized coordinates to pixel coordinates
            x_center, y_center, width, height = bbox
            x1 = (x_center - width/2) * img.shape[1]
            y1 = (y_center - height/2) * img.shape[0]
            x2 = (x_center + width/2) * img.shape[1]
            y2 = (y_center + height/2) * img.shape[0]
            
            # Choose color
            color = colors[int(cls_idx) % len(colors)]
            
            # Draw rectangle
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            
            # Add class label
            class_name = class_names.get(int(cls_idx), f'class_{int(cls_idx)}')
            label = f'{class_name}'
            
            # Draw text background and label
            text_bbox = draw.textbbox((x1, y1), label, font=font)
            draw.rectangle(text_bbox, fill=color)
            draw.text((x1, y1), label, fill='white', font=font)
        
        print(f"Image {img_idx} - {len(img_bboxes)} detections")
        display(pil_img)
        print("\n" + "="*50 + "\n")

# Usage
plot_batch_detections_pil_jupyter(batch)

In [15]:
# Your three output scales
outputs = [
    torch.randn([4, 144, 80, 80]),  # Scale 1: 80x80
    torch.randn([4, 144, 40, 40]),  # Scale 2: 40x40  
    torch.randn([4, 144, 20, 20])   # Scale 3: 20x20
]

# Define strides for each scale (image_size / grid_size)
strides = [8, 16, 32]  # 640/80=8, 640/40=16, 640/20=32
image_size = 640
conf_threshold = 0.5

all_detections = [[] for _ in range(4)]  # 4 batch items

for scale_idx, output in enumerate(outputs):
    print(f"\n--- Processing scale {scale_idx} ({output.shape[2]}x{output.shape[3]}) ---")
    
    # Step 1: Separate DFL and classification
    dfl_output = output[:, :64, :, :]  # [4, 64, grid_h, grid_w]
    cls_output = output[:, 64:, :, :]  # [4, 80, grid_h, grid_w]
    
    # Step 2: Reshape DFL to [4, 4, 16, grid_h, grid_w]
    grid_h, grid_w = output.shape[2], output.shape[3]
    dfl_reshaped = dfl_output.view(4, 4, 16, grid_h, grid_w)
    
    # Step 3: Convert DFL to coordinate values
    bins = torch.arange(16, device=dfl_reshaped.device).view(1, 1, 16, 1, 1)
    dfl_probs = torch.softmax(dfl_reshaped, dim=2)
    
    coord_x = torch.sum(bins * dfl_probs[:, 0:1, :, :, :], dim=2)
    coord_y = torch.sum(bins * dfl_probs[:, 1:2, :, :, :], dim=2)
    coord_w = torch.sum(bins * dfl_probs[:, 2:3, :, :, :], dim=2)
    coord_h = torch.sum(bins * dfl_probs[:, 3:4, :, :, :], dim=2)
    
    bbox_coords = torch.cat([coord_x, coord_y, coord_w, coord_h], dim=1)
    
    # Step 4: Decode coordinates with correct stride
    grid_y, grid_x = torch.meshgrid(
        torch.arange(grid_h, device=bbox_coords.device),
        torch.arange(grid_w, device=bbox_coords.device),
        indexing='ij'
    )
    grid_x = grid_x.unsqueeze(0).unsqueeze(0)
    grid_y = grid_y.unsqueeze(0).unsqueeze(0)
    
    stride = strides[scale_idx]
    decoded_x = (bbox_coords[:, 0:1, :, :] + grid_x) / grid_w
    decoded_y = (bbox_coords[:, 1:2, :, :] + grid_y) / grid_h
    decoded_w = torch.exp(bbox_coords[:, 2:3, :, :]) * stride / image_size
    decoded_h = torch.exp(bbox_coords[:, 3:4, :, :]) * stride / image_size
    
    decoded_bboxes = torch.cat([decoded_x, decoded_y, decoded_w, decoded_h], dim=1)
    
    # Step 5: Process classification
    cls_probs = torch.sigmoid(cls_output)
    max_cls_prob, max_cls_idx = torch.max(cls_probs, dim=1)
    confidence_scores = max_cls_prob
    
    # Step 6: Reshape and filter
    bboxes_flat = decoded_bboxes.permute(0, 2, 3, 1).contiguous().view(4, grid_h*grid_w, 4)
    confidences_flat = confidence_scores.view(4, grid_h*grid_w)
    class_indices_flat = max_cls_idx.view(4, grid_h*grid_w)
    
    conf_mask = confidences_flat > conf_threshold
    
    # Step 7: Collect detections for this scale
    for batch_idx in range(4):
        grid_indices = torch.nonzero(conf_mask[batch_idx], as_tuple=True)[0]
        
        for grid_idx in grid_indices:
            detection = {
                'bbox': bboxes_flat[batch_idx, grid_idx].tolist(),
                'confidence': confidences_flat[batch_idx, grid_idx].item(),
                'class_id': class_indices_flat[batch_idx, grid_idx].item(),
                'scale': scale_idx
            }
            all_detections[batch_idx].append(detection)

# Print summary
print(f"\n=== FINAL DETECTIONS SUMMARY ===")
for batch_idx, dets in enumerate(all_detections):
    print(f"Batch {batch_idx}: {len(dets)} total detections across all scales")


--- Processing scale 0 (80x80) ---

--- Processing scale 1 (40x40) ---

--- Processing scale 2 (20x20) ---

=== FINAL DETECTIONS SUMMARY ===
Batch 0: 8400 total detections across all scales
Batch 1: 8400 total detections across all scales
Batch 2: 8400 total detections across all scales
Batch 3: 8400 total detections across all scales


In [30]:
import torchvision

def apply_nms_to_batch(batch_detections, iou_threshold=0.5):
    """Apply NMS to detections from one batch"""
    if not batch_detections:
        return []
    
    # Convert to tensors
    boxes = torch.tensor([det['bbox'] for det in batch_detections])  # [x_center, y_center, w, h]
    scores = torch.tensor([det['confidence'] for det in batch_detections])
    
    # Convert from center format [x_center, y_center, w, h] to corner format [x1, y1, x2, y2]
    x_center, y_center, width, height = boxes.unbind(1)
    x1 = x_center - width / 2
    y1 = y_center - height / 2
    x2 = x_center + width / 2
    y2 = y_center + height / 2
    boxes_corners = torch.stack([x1, y1, x2, y2], dim=1)
    
    # Apply NMS
    keep_indices = torchvision.ops.nms(boxes_corners, scores, iou_threshold)
    
    # Return filtered detections
    return [batch_detections[i] for i in keep_indices]

# Apply NMS to each batch
final_detections = []
for batch_idx, batch_dets in enumerate(all_detections):
    nms_detections = apply_nms_to_batch(batch_dets, iou_threshold=0.1)
    final_detections.append(nms_detections)
    
    print(f"Batch {batch_idx}: {len(batch_dets)} before NMS -> {len(nms_detections)} after NMS")

# Print final results
print(f"\n=== FINAL RESULTS AFTER NMS ===")
for batch_idx, dets in enumerate(final_detections):
    print(f"\nBatch {batch_idx}: {len(dets)} detections")
    
    # Sort by confidence (highest first)
    dets_sorted = sorted(dets, key=lambda x: x['confidence'], reverse=True)
    
    for i, det in enumerate(dets_sorted[:5]):  # Show top 5 detections
        bbox = det['bbox']
        print(f"  Detection {i+1}:")
        print(f"    Class: {det['class_id']}, Confidence: {det['confidence']:.3f}")
        print(f"    BBox: x_center={bbox[0]:.3f}, y_center={bbox[1]:.3f}, width={bbox[2]:.3f}, height={bbox[3]:.3f}")
        print(f"    Scale: {det['scale']}")

# Optional: Convert normalized coordinates to pixel coordinates
print(f"\n=== PIXEL COORDINATES (for 640x640 image) ===")
for batch_idx, dets in enumerate(final_detections):
    print(f"\nBatch {batch_idx}:")
    for i, det in enumerate(dets[:3]):  # Show first 3
        bbox = det['bbox']
        x_center_px = bbox[0] * 640
        y_center_px = bbox[1] * 640
        width_px = bbox[2] * 640
        height_px = bbox[3] * 640
        
        # Convert to corner coordinates
        x1 = x_center_px - width_px / 2
        y1 = y_center_px - height_px / 2
        x2 = x_center_px + width_px / 2
        y2 = y_center_px + height_px / 2
        
        print(f"  Detection {i+1}: Class {det['class_id']}, Conf {det['confidence']:.3f}")
        print(f"    Pixel bbox: ({x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f})")

Batch 0: 8400 before NMS -> 19 after NMS
Batch 1: 8400 before NMS -> 23 after NMS
Batch 2: 8400 before NMS -> 20 after NMS
Batch 3: 8400 before NMS -> 20 after NMS

=== FINAL RESULTS AFTER NMS ===

Batch 0: 19 detections
  Detection 1:
    Class: 28, Confidence: 0.991
    BBox: x_center=0.557, y_center=0.802, width=21.043, height=55.393
    Scale: 1
  Detection 2:
    Class: 40, Confidence: 0.989
    BBox: x_center=0.381, y_center=0.457, width=341.420, height=195.348
    Scale: 2
  Detection 3:
    Class: 46, Confidence: 0.989
    BBox: x_center=0.530, y_center=0.929, width=9.783, height=2.808
    Scale: 0
  Detection 4:
    Class: 74, Confidence: 0.989
    BBox: x_center=0.511, y_center=0.601, width=239.343, height=14.701
    Scale: 0
  Detection 5:
    Class: 19, Confidence: 0.985
    BBox: x_center=0.537, y_center=0.492, width=0.918, height=34.276
    Scale: 0

Batch 1: 23 detections
  Detection 1:
    Class: 67, Confidence: 0.992
    BBox: x_center=0.847, y_center=0.402, width=5.69

In [31]:
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display
import numpy as np
import torch

def plot_yolov11_predictions(batch, predictions, class_names=None, conf_threshold=0.25):
    """
    Display images with YOLOv11 predicted bounding boxes in Jupyter
    
    Args:
        batch: Your input batch with images
        predictions: List of detections from our YOLOv11 processing
        class_names: Dictionary mapping class_ids to class names
        conf_threshold: Confidence threshold for display
    """
    if class_names is None:
        # Default COCO class names - adjust based on your dataset
        class_names = {
            0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
            5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
            # Add more classes as needed for your dataset
            22: 'class_22', 23: 'class_23', 45: 'class_45', 
            49: 'class_49', 50: 'class_50'
        }
    
    batch_size = batch['img'].shape[0]
    
    for img_idx in range(batch_size):
        # Get the image from batch
        img_tensor = batch['img'][img_idx]
        img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        
        # Normalize image if needed
        if img.max() <= 1.0:
            img = (img * 255).astype(np.uint8)
        else:
            img = img.astype(np.uint8)
        
        # Convert to PIL Image
        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)
        
        # Try to use a font
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            try:
                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
            except:
                # Fallback to default font
                font = ImageFont.load_default()
        
        # Colors for different classes
        colors = ['red', 'blue', 'green', 'orange', 'purple', 'yellow', 'cyan', 'magenta']
        
        # Get predictions for this image
        img_predictions = predictions[img_idx]
        
        # Filter by confidence threshold
        filtered_predictions = [det for det in img_predictions if det['confidence'] >= conf_threshold]
        
        # Draw bounding boxes for predictions
        for detection in filtered_predictions:
            bbox = detection['bbox']  # [x_center, y_center, width, height] normalized
            confidence = detection['confidence']
            class_id = detection['class_id']
            
            # Convert normalized coordinates to pixel coordinates
            x_center, y_center, width, height = bbox
            x1 = (x_center - width/2) * img.shape[1]
            y1 = (y_center - height/2) * img.shape[0]
            x2 = (x_center + width/2) * img.shape[1]
            y2 = (y_center + height/2) * img.shape[0]
            
            # Choose color based on class
            color = colors[class_id % len(colors)]
            
            # Draw rectangle
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            
            # Add class label with confidence
            class_name = class_names.get(int(class_id), f'class_{int(class_id)}')
            label = f'{class_name} {confidence:.2f}'
            
            # Draw text background and label
            text_bbox = draw.textbbox((x1, y1), label, font=font)
            # Expand text background slightly
            text_bbox = (text_bbox[0]-2, text_bbox[1]-2, text_bbox[2]+2, text_bbox[3]+2)
            draw.rectangle(text_bbox, fill=color)
            draw.text((x1, y1), label, fill='white', font=font)
        
        print(f"Image {img_idx} - {len(filtered_predictions)} predictions (confidence ≥ {conf_threshold})")
        display(pil_img)
        print("\n" + "="*50 + "\n")

# Alternative version that also shows ground truth for comparison
def plot_yolov11_vs_ground_truth(batch, predictions, class_names=None, conf_threshold=0.25):
    """
    Display images with both YOLOv11 predictions and ground truth bounding boxes
    """
    if class_names is None:
        class_names = {
            0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
            22: 'class_22', 23: 'class_23', 45: 'class_45', 
            49: 'class_49', 50: 'class_50'
        }
    
    batch_size = batch['img'].shape[0]
    batch_indices_np = batch['batch_idx'].cpu().numpy()
    bboxes_np = batch['bboxes'].cpu().numpy()
    classes_np = batch['cls'].cpu().numpy().flatten()
    
    for img_idx in range(batch_size):
        # Get ground truth for this image
        img_mask = batch_indices_np == img_idx
        gt_bboxes = bboxes_np[img_mask]
        gt_classes = classes_np[img_mask]
        
        # Get the image
        img_tensor = batch['img'][img_idx]
        img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        
        if img.max() <= 1.0:
            img = (img * 255).astype(np.uint8)
        
        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)
        
        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except:
            font = ImageFont.load_default()
        
        # Get predictions for this image
        img_predictions = predictions[img_idx]
        filtered_predictions = [det for det in img_predictions if det['confidence'] >= conf_threshold]
        
        # Draw ground truth boxes (green)
        for bbox, cls_idx in zip(gt_bboxes, gt_classes):
            x_center, y_center, width, height = bbox
            x1 = (x_center - width/2) * img.shape[1]
            y1 = (y_center - height/2) * img.shape[0]
            x2 = (x_center + width/2) * img.shape[1]
            y2 = (y_center + height/2) * img.shape[0]
            
            # Draw ground truth in green
            draw.rectangle([x1, y1, x2, y2], outline='green', width=2)
            
            class_name = class_names.get(int(cls_idx), f'class_{int(cls_idx)}')
            label = f'GT: {class_name}'
            text_bbox = draw.textbbox((x1, y1), label, font=font)
            draw.rectangle(text_bbox, fill='green')
            draw.text((x1, y1), label, fill='white', font=font)
        
        # Draw prediction boxes (red)
        for detection in filtered_predictions:
            bbox = detection['bbox']
            confidence = detection['confidence']
            class_id = detection['class_id']
            
            x_center, y_center, width, height = bbox
            x1 = (x_center - width/2) * img.shape[1]
            y1 = (y_center - height/2) * img.shape[0]
            x2 = (x_center + width/2) * img.shape[1]
            y2 = (y_center + height/2) * img.shape[0]
            
            # Draw predictions in red
            draw.rectangle([x1, y1, x2, y2], outline='red', width=3)
            
            class_name = class_names.get(int(class_id), f'class_{int(class_id)}')
            label = f'Pred: {class_name} {confidence:.2f}'
            text_bbox = draw.textbbox((x1, y1-20), label, font=font)  # Offset to avoid overlap
            draw.rectangle(text_bbox, fill='red')
            draw.text((x1, y1-20), label, fill='white', font=font)
        
        print(f"Image {img_idx} - GT: {len(gt_bboxes)}, Pred: {len(filtered_predictions)} (confidence ≥ {conf_threshold})")
        print("Green: Ground Truth, Red: Predictions")
        display(pil_img)
        print("\n" + "="*50 + "\n")

# Usage examples:

# 1. Plot only predictions
# plot_yolov11_predictions(batch, final_detections, conf_threshold=0.25)

# 2. Plot predictions vs ground truth
# plot_yolov11_vs_ground_truth(batch, final_detections, conf_threshold=0.25)

# If you want to use the same format as your original function:
def plot_batch_detections_pil_jupyter(batch, predictions=None, class_names=None):
    """
    Modified version of your original function that can handle both ground truth and predictions
    """
    if predictions is not None:
        # Use the new prediction plotting function
        plot_yolov11_predictions(batch, predictions, class_names)
    else:
        # Fall back to original ground truth plotting
        if class_names is None:
            class_names = {
                22: 'class_22', 23: 'class_23', 45: 'class_45', 
                49: 'class_49', 50: 'class_50'
            }
        
        batch_size = batch['img'].shape[0]
        batch_indices_np = batch['batch_idx'].cpu().numpy()
        bboxes_np = batch['bboxes'].cpu().numpy()
        classes_np = batch['cls'].cpu().numpy().flatten()
        
        for img_idx in range(batch_size):
            img_mask = batch_indices_np == img_idx
            img_bboxes = bboxes_np[img_mask]
            img_classes = classes_np[img_mask]
            
            img_tensor = batch['img'][img_idx]
            img = img_tensor.cpu().numpy().transpose(1, 2, 0)
            
            if img.max() <= 1.0:
                img = (img * 255).astype(np.uint8)
            
            pil_img = Image.fromarray(img)
            draw = ImageDraw.Draw(pil_img)
            
            try:
                font = ImageFont.truetype("arial.ttf", 20)
            except:
                font = ImageFont.load_default()
            
            colors = ['red', 'blue', 'green', 'orange', 'purple']
            
            for bbox, cls_idx in zip(img_bboxes, img_classes):
                x_center, y_center, width, height = bbox
                x1 = (x_center - width/2) * img.shape[1]
                y1 = (y_center - height/2) * img.shape[0]
                x2 = (x_center + width/2) * img.shape[1]
                y2 = (y_center + height/2) * img.shape[0]
                
                color = colors[int(cls_idx) % len(colors)]
                draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
                
                class_name = class_names.get(int(cls_idx), f'class_{int(cls_idx)}')
                label = f'{class_name}'
                
                text_bbox = draw.textbbox((x1, y1), label, font=font)
                draw.rectangle(text_bbox, fill=color)
                draw.text((x1, y1), label, fill='white', font=font)
            
            print(f"Image {img_idx} - {len(img_bboxes)} detections")
            display(pil_img)
            print("\n" + "="*50 + "\n")

# Usage with your predictions:
# plot_yolov11_predictions(batch, final_detections, conf_threshold=0.25)

In [None]:
# After running your YOLOv11 processing to get final_detections
plot_yolov11_predictions(batch, final_detections, conf_threshold=.25)