In [1]:
import pandas as pd
from typing import Iterable, Literal, overload
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2 as cv
import numpy as np
import math
import os

In [2]:
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

In [3]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img,img_to_array

2025-11-26 15:45:28.342592: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-11-26 15:45:28.362736: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-11-26 15:45:28.368157: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-26 15:45:28.383933: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
def cxcywh_toxyxy_core(boxes: tf.Tensor):

    tf.debugging.assert_equal(tf.shape(boxes)[-1], 4, message="boxes last dim must be 4")
    
    cx, cy, w, h = tf.split(boxes,num_or_size_splits = 4, axis=-1)

    x_min = cx - (0.5*w)
    y_min = cy - (0.5*h)

    x_max = cx + (0.5*w)
    y_max = cy + (0.5*h)

    return tf.concat([x_min,y_min,x_max,y_max], axis = -1)

In [5]:
def area_xyxy_core(boxes: tf.Tensor):
    # Making sure the shape is correct
    tf.debugging.assert_equal(tf.shape(boxes)[-1], 4, message="boxes last dim must be 4")
    
    # Split the coordinates
    x_min, y_min, x_max, y_max = tf.split(boxes,num_or_size_splits = 4, axis=-1)

    # Calculate the area
    w = tf.maximum(x_max - x_min, 0.0)
    h = tf.maximum(y_max - y_min, 0.0)

    area = w * h

    return tf.squeeze(area, axis=-1)

@tf.function(
    input_signature = [
        tf.TensorSpec(shape=[None, None, 4], dtype = tf.float32),
    ]
)
def area_xyxy_batched(boxes: tf.Tensor):

    B = tf.shape(boxes)[0]
    N = tf.shape(boxes)[1]

    # Flattening the boxes
    flattened_boxes_xyxy = tf.reshape(boxes,[-1,4])
    flattened_area = area_xyxy_core(flattened_boxes_xyxy)

    return tf.reshape(flattened_area, [B,N])

@tf.function(
    input_signature = [
        tf.TensorSpec(shape=[None, 4], dtype = tf.float32),
        tf.TensorSpec(shape=[None, 4], dtype = tf.float32)
    ]
)
def intersection_xyxy_core(boxes_1,boxes_2):
    tf.debugging.assert_equal(tf.shape(boxes_1)[-1], 4, message="boxes 1 last dim must be 4")
    tf.debugging.assert_equal(tf.shape(boxes_2)[-1], 4, message="boxes 2 last dim must be 4")

    # Split the coordinates
    ax_min,ay_min,ax_max,ay_max = tf.split(boxes_1,num_or_size_splits = 4, axis=-1)
    bx_min,by_min,bx_max,by_max = tf.split(boxes_2,num_or_size_splits = 4, axis=-1)

    # Calculating the proper coordinates
    x_min = tf.maximum(ax_min[:,None], bx_min[None,:])
    y_min = tf.maximum(ay_min[:,None], by_min[None,:])
    x_max = tf.minimum(ax_max[:,None], bx_max[None,:])
    y_max = tf.minimum(ay_max[:,None], by_max[None,:])

    # Calculating the intersection
    w = tf.maximum(x_max - x_min, 0.0)
    h = tf.maximum(y_max - y_min, 0.0)

    intersection = w * h

    return tf.squeeze(intersection,axis=-1)

@tf.function(
    input_signature=[
        tf.TensorSpec([None, None, 4], tf.float32), 
        tf.TensorSpec([None, None, 4], tf.float32),  
    ]
)
def intersection_xyxy_batched(boxes_1, boxes_2):
    
    tf.debugging.assert_equal(tf.shape(boxes_1)[-1], 4, message="boxes_1 last dim must be 4")
    tf.debugging.assert_equal(tf.shape(boxes_2)[-1], 4, message="boxes_2 last dim must be 4")

    a_xmin, a_ymin, a_xmax, a_ymax = tf.split(boxes_1, 4, axis=-1)  
    b_xmin, b_ymin, b_xmax, b_ymax = tf.split(boxes_2, 4, axis=-1) 

    x1 = tf.maximum(a_xmin[:, :, None, :], b_xmin[:, None, :, :])  
    y1 = tf.maximum(a_ymin[:, :, None, :], b_ymin[:, None, :, :])  
    x2 = tf.minimum(a_xmax[:, :, None, :], b_xmax[:, None, :, :])  
    y2 = tf.minimum(a_ymax[:, :, None, :], b_ymax[:, None, :, :])  

    w = tf.maximum(x2 - x1, 0.0)
    h = tf.maximum(y2 - y1, 0.0)
    inter = w * h                                                

    return tf.squeeze(inter, axis=-1)

@tf.function(
    input_signature=[
        tf.TensorSpec([None], tf.float32),     
        tf.TensorSpec([None], tf.float32),     
        tf.TensorSpec([None, None], tf.float32)
    ]
)
def union_from_areas_core(a_area, b_area, inter):
    # Broadcast: (N,1) + (1,M) - (N,M)
    union = tf.maximum(tf.expand_dims(a_area, 1) + tf.expand_dims(b_area, 0) - inter, 1e-7)
    return union

@tf.function(
    input_signature=[
        tf.TensorSpec([None, None], tf.float32),   
        tf.TensorSpec([None, None], tf.float32),   
        tf.TensorSpec([None, None, None], tf.float32)  
    ]
)
def union_from_areas_batched(a_area, b_area, inter):
    union = a_area[:, :, None] + b_area[:, None, :] - inter
    return tf.maximum(union, tf.constant(1e-7, tf.float32))

@tf.function(
    input_signature=[
        tf.TensorSpec([None, 4], tf.float32), 
        tf.TensorSpec([None, 4], tf.float32),  
    ]
)
def iou_matrix_core(boxes_1, boxes_2):
    tf.debugging.assert_rank(boxes_1, 2, message="boxes1 must be (M,4)")
    tf.debugging.assert_equal(tf.shape(boxes_1)[-1], 4)
    tf.debugging.assert_rank(boxes_2, 2, message="boxes2 must be (N,4)")
    tf.debugging.assert_equal(tf.shape(boxes_2)[-1], 4)
    # areas
    
    a_area = area_xyxy_core(boxes_1)
    b_area = area_xyxy_core(boxes_2)
    # intersections
    inter = intersection_xyxy_core(boxes_1, boxes_2)
    # unions
    union = union_from_areas_core(a_area, b_area, inter)
    # IoU
    return inter / union

In [6]:
gt_boxes_xyxy = tf.constant([
    [0.1, 0.1, 0.3, 0.3],   # GT0: A (class 3)
    [0.6, 0.6, 0.9, 0.9],   # GT1: B (class 2)
], tf.float32)
gt_labels = tf.constant([3, 2], tf.int32)
gt_valid_mask = tf.constant([True, True], tf.bool)
priors_cxcywh = tf.constant([
    [0.2,  0.2,  0.20, 0.20],    # P0 → IoU 1.00 with A
    [0.2,  0.2,  0.16, 0.16],    # P1 → IoU 0.64 with A
    [0.2,  0.2,  0.12, 0.12],    # P2 → IoU 0.36 with A
    [0.75, 0.75, 0.30, 0.30],    # P3 → IoU 1.00 with B
    [0.75, 0.75, 0.24, 0.24],    # P4 → IoU 0.64 with B
    [0.05, 0.90, 0.10, 0.10],    # P5 → IoU 0.00
], tf.float32)

I0000 00:00:1764189931.120819   25720 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1764189931.227756   25720 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1764189931.227824   25720 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1764189931.229234   25720 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1764189931.229297   25720 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

In [7]:
def _check_for_center_alignment(priors_cxcywh: tf.Tensor, gt_boxes_xyxy: tf.Tensor):
    
    cx,cy,_,_ = tf.split(priors_cxcywh,num_or_size_splits = 4, axis=-1)
    x_min, y_min, x_max, y_max = tf.split(gt_boxes_xyxy,num_or_size_splits = 4, axis=-1)

    cx = tf.transpose(cx, perm=[1, 0]) 
    cy = tf.transpose(cy, perm=[1, 0])
    
    # Check if (x_min <= cx <= x_max) and (y_min <= cy <= y_max)
    center_alignment = (x_min <= cx) & (cx<= x_max) & (y_min <= cy) & (cy <= y_max)

    return center_alignment

In [8]:
def _calculate_matches(iou_matrix: tf.Tensor,gt_boxes: tf.Tensor,positive_iou_thresh: float,negative_iou_thresh: float,enforce_bipartite: bool = True):
    # Apply rules
    M = tf.shape(gt_boxes)[0]
    N = tf.shape(iou_matrix)[1]
    
    max_iou_per_anchor = tf.reduce_max(iou_matrix,axis=0)
    assigned_gt_box_index = tf.argmax(iou_matrix,axis = 0,output_type = tf.int32)

    positive_mask = max_iou_per_anchor >= positive_iou_thresh
    negative_mask = max_iou_per_anchor < negative_iou_thresh
    ignore_mask = tf.logical_not(tf.logical_or(positive_mask,negative_mask))

    if enforce_bipartite:
        # For each GT box get the prior with the most value
        best_prior_per_gt = tf.argmax(iou_matrix, axis=1, output_type = tf.int32)

        # Get the values of the IoUs for the best matches
        best_iou_per_gt = tf.reduce_max(iou_matrix, axis=1)
        valid_gt = best_iou_per_gt > tf.constant(-0.5, iou_matrix.dtype)

        best_indices_all = tf.stack([tf.range(M, dtype=tf.int32), best_prior_per_gt], axis=1)  # (M,2)
        best_indices = tf.boolean_mask(best_indices_all, valid_gt)
        best_values  = tf.boolean_mask(tf.gather_nd(iou_matrix, best_indices_all), valid_gt)
        
        # Create a sparse Matrix to resolve any potential conflicts
        sparse_iou = tf.scatter_nd(best_indices,best_values,shape=tf.stack([M,N]))

        # Find which of the columns are forced
        forced_cols = tf.reduce_any(sparse_iou > tf.constant(-0.5, iou_matrix.dtype),axis=0)

        # Now calculate the best gt box per anchor to remove the conflicts by having the best one pick the prior
        best_gt_per_anchor = tf.argmax(sparse_iou, axis=0, output_type = tf.int32)

        # Now override the values where there was a force that was occuring
        resolved_gt_assignment = tf.where(forced_cols, best_gt_per_anchor, assigned_gt_box_index)

        # Now override the assignment for the  previous indices with the new resolved one
        assigned_gt_box_index = tf.where(forced_cols, resolved_gt_assignment, assigned_gt_box_index)

        # Stack the assigned gt box and each prior
        assigned_gt_box_per_prior = tf.stack([assigned_gt_box_index, tf.range(N,dtype=tf.int32)], axis=1)

        # Get the last resolved IoU matrix using it
        resolved_iou = tf.gather_nd(iou_matrix,assigned_gt_box_per_prior)
        max_iou_per_anchor = tf.where(forced_cols, resolved_iou, max_iou_per_anchor)

        # Now Update the masks with the new forced picks
        positive_mask = tf.logical_or(positive_mask,forced_cols)
        negative_mask = tf.where(forced_cols, tf.zeros_like(negative_mask),negative_mask)
        ignore_mask = tf.where(forced_cols, tf.zeros_like(ignore_mask),ignore_mask)
        
    
    # Calculate the number of positives
    number_of_positive_priors = tf.reduce_sum(tf.cast(positive_mask,tf.int32))
    # Calculate where the labels need to be ignored
    assigned_gt_box_index = tf.where(positive_mask, assigned_gt_box_index, -tf.ones_like(assigned_gt_box_index))

    # return assigned_gt_box_index, max_iou_per_anchor, positive_mask, negative_mask, ignore_mask, number_of_positive_priors
    return {
        "assigned_gt_box_index": assigned_gt_box_index,
        "max_iou_per_prior": max_iou_per_anchor,
        "pos_mask": positive_mask,
        "neg_mask": negative_mask,
        "ignore_mask": ignore_mask,
        "num_pos": number_of_positive_priors,
    }
    

In [24]:
def match_priors(priors_cxcywh: tf.Tensor, gt_boxes_xyxy: tf.Tensor, gt_labels: tf.Tensor, gt_valid_mask: tf.Tensor | None, positive_iou_thresh: float, negative_iou_thresh: float, max_pos_per_gt: list[int] | None, allow_low_qual_matches: bool = True, center_in_gt: bool = True , return_iou: bool = False):

    priors_cxcywh = tf.reshape(priors_cxcywh, [-1, 4])
    gt_boxes_xyxy = tf.reshape(gt_boxes_xyxy, [-1, 4])

    N = tf.shape(priors_cxcywh)[0]
    M = tf.shape(gt_boxes_xyxy)[0]

    # Fringe case, if gt_valid_mask is None then treat all boxes as valid
    if gt_valid_mask is None:
        gt_valid_mask = tf.ones_like(gt_labels,tf.bool)

    validity_check = tf.cast(gt_valid_mask,dtype=tf.int32)

    if tf.equal(tf.size(gt_boxes_xyxy),0) or tf.equal(tf.reduce_sum(validity_check),tf.constant(0,dtype=tf.int32)):
        return {
        "matched_gt_xyxy": tf.zeros([N, 4], tf.float32),
        "matched_labels":  tf.zeros([N], tf.int32),
        "pos_mask":        tf.zeros([N], tf.bool),
        "neg_mask":        tf.ones([N],  tf.bool),
        "ignore_mask":     tf.zeros([N], tf.bool),
        "matched_gt_idx":  -tf.ones([N], tf.int32),
        "matched_iou":     tf.zeros([N], tf.float32),
        "num_pos":         tf.zeros([], tf.int32),
        }

    # Need to compute which of the gt boxes are valid
    valid_indices = tf.where(gt_valid_mask)

    valid_gt_boxes = tf.gather_nd(gt_boxes_xyxy, valid_indices)
    valid_labels = tf.gather_nd(gt_labels, valid_indices)

    # Compute the IoU Matrix
    priors_xyxy = cxcywh_toxyxy_core(priors_cxcywh)
    iou_matrix = iou_matrix_core(valid_gt_boxes,priors_xyxy)

    if center_in_gt:
        center_aligned = _check_for_center_alignment(priors_cxcywh,valid_gt_boxes)
        # Filter out the Non centered priors
        iou_matrix = tf.where(center_aligned, iou_matrix, tf.zeros_like(iou_matrix))

    # Calculate matching mask
    match_dict = _calculate_matches(iou_matrix,valid_gt_boxes,positive_iou_thresh,negative_iou_thresh,enforce_bipartite = allow_low_qual_matches)

    labels_g = tf.gather(valid_labels,match_dict["assigned_gt_box_index"])
    boxes_g  = tf.gather(valid_gt_boxes, match_dict["assigned_gt_box_index"])

    zeros_labels = tf.zeros_like(match_dict["assigned_gt_box_index"], dtype=tf.int32)
    zeros_boxes  = tf.zeros([N, 4], dtype=gt_boxes_xyxy.dtype)
    
    # Calculate the matching labels
    matched_labels  = tf.where(match_dict["pos_mask"], labels_g, zeros_labels)

    # Calculate the matching ground truth boxes
    matched_gt_xyxy = tf.where(tf.expand_dims(match_dict["pos_mask"], 1), boxes_g, zeros_boxes)

    matched_gt_idx = tf.where(match_dict["pos_mask"], match_dict["assigned_gt_box_index"], -tf.ones_like(match_dict["assigned_gt_box_index"]))

    return_dict = {
        "matched_gt_xyxy" : matched_gt_xyxy,
        "matched_gt_labels": matched_labels,
        "pos_mask": match_dict['pos_mask'],
        "neg_mask": match_dict['neg_mask'],
        "ignore_mask": match_dict['ignore_mask'],
        "matched_gt_idx": matched_gt_idx,
        "num_pos": match_dict['num_pos']
    }

    if return_iou:
        # Calculate IoU for the images
        max_iou = match_dict["max_iou_per_prior"] 
        matched_iou = tf.where(match_dict["pos_mask"], max_iou, tf.zeros_like(max_iou))
        return_dict['matched_iou'] = matched_iou

    return return_dict

In [25]:
match_priors(priors_cxcywh,gt_boxes_xyxy,gt_labels,gt_valid_mask,0.5,0.3,None,center_in_gt = True,allow_low_qual_matches = True,return_iou = True)

{'matched_gt_xyxy': <tf.Tensor: shape=(6, 4), dtype=float32, numpy=
 array([[0.1, 0.1, 0.3, 0.3],
        [0.1, 0.1, 0.3, 0.3],
        [0.1, 0.1, 0.3, 0.3],
        [0.6, 0.6, 0.9, 0.9],
        [0.1, 0.1, 0.3, 0.3],
        [0.1, 0.1, 0.3, 0.3]], dtype=float32)>,
 'matched_gt_labels': <tf.Tensor: shape=(6,), dtype=int32, numpy=array([3, 3, 3, 2, 3, 3], dtype=int32)>,
 'pos_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([ True,  True,  True,  True,  True,  True])>,
 'neg_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([False, False, False, False, False, False])>,
 'ignore_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([False, False, False, False, False, False])>,
 'matched_gt_idx': <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 0, 1, 0, 0], dtype=int32)>,
 'num_pos': <tf.Tensor: shape=(), dtype=int32, numpy=6>,
 'matched_iou': <tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([1.        , 0.63999987, 0.3599999 , 1.        , 0.        ,
        0.    

In [26]:
gt_boxes_xyxy = tf.constant([
    [0.0, 0.0, 0.0, 0.0],   # padding
    [0.0, 0.0, 0.0, 0.0],   # padding
    [0.0, 0.0, 0.0, 0.0],   # padding
    [0.0, 0.0, 0.0, 0.0],   # padding
], tf.float32)

gt_labels = tf.constant([0, 0, 0, 0], tf.int32)

gt_valid_mask = tf.constant([False, False, False, False], tf.bool)

priors_cxcywh = tf.constant([
    [0.2,  0.2,  0.20, 0.20],    # P0 → IoU 1.00 with A
    [0.2,  0.2,  0.16, 0.16],    # P1 → IoU 0.64 with A
    [0.2,  0.2,  0.12, 0.12],    # P2 → IoU 0.36 with A
    [0.75, 0.75, 0.30, 0.30],    # P3 → IoU 1.00 with B
    [0.75, 0.75, 0.24, 0.24],    # P4 → IoU 0.64 with B
    [0.05, 0.90, 0.10, 0.10],    # P5 → IoU 0.00
], tf.float32)

In [27]:
match_priors(priors_cxcywh,gt_boxes_xyxy,gt_labels,gt_valid_mask,0.5,0.3,None,center_in_gt = True,allow_low_qual_matches = True,return_iou = True)

{'matched_gt_xyxy': <tf.Tensor: shape=(6, 4), dtype=float32, numpy=
 array([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=float32)>,
 'matched_labels': <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0], dtype=int32)>,
 'pos_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([False, False, False, False, False, False])>,
 'neg_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([ True,  True,  True,  True,  True,  True])>,
 'ignore_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([False, False, False, False, False, False])>,
 'matched_gt_idx': <tf.Tensor: shape=(6,), dtype=int32, numpy=array([-1, -1, -1, -1, -1, -1], dtype=int32)>,
 'matched_iou': <tf.Tensor: shape=(6,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0.], dtype=float32)>,
 'num_pos': <tf.Tensor: shape=(), dtype=int32, numpy=0>}

In [28]:
gt_boxes_xyxy = tf.constant([
    [0.1, 0.1, 0.3, 0.3],   # GT0: A (class 3)
    [0.6, 0.6, 0.9, 0.9],   # GT1: B (class 2)
    [0.0, 0.0, 0.0, 0.0],   # padding
    [0.0, 0.0, 0.0, 0.0],   # padding
], tf.float32)


gt_labels = tf.constant(
    [3, 2, 0, 0],  # padded rows → label ignored (valid_mask decides!)
    tf.int32
)

gt_valid_mask = tf.constant(
    [True, True, False, False],  # only 0 and 1 are real GTs
    tf.bool
)

priors_cxcywh = tf.constant([
    [0.2,  0.2,  0.20, 0.20],    # P0 → IoU 1.00 with A
    [0.2,  0.2,  0.16, 0.16],    # P1 → IoU 0.64 with A
    [0.2,  0.2,  0.12, 0.12],    # P2 → IoU 0.36 with A (ignore band)
    [0.75, 0.75, 0.30, 0.30],    # P3 → IoU 1.00 with B
    [0.75, 0.75, 0.24, 0.24],    # P4 → IoU 0.64 with B
    [0.05, 0.90, 0.10, 0.10],    # P5 → IoU 0.00
], tf.float32)

In [32]:
match_priors(priors_cxcywh,gt_boxes_xyxy,gt_labels,gt_valid_mask,0.5,0.3,None,center_in_gt = False,allow_low_qual_matches = True,return_iou = True)

{'matched_gt_xyxy': <tf.Tensor: shape=(6, 4), dtype=float32, numpy=
 array([[0.1, 0.1, 0.3, 0.3],
        [0.1, 0.1, 0.3, 0.3],
        [0.1, 0.1, 0.3, 0.3],
        [0.6, 0.6, 0.9, 0.9],
        [0.1, 0.1, 0.3, 0.3],
        [0.1, 0.1, 0.3, 0.3]], dtype=float32)>,
 'matched_gt_labels': <tf.Tensor: shape=(6,), dtype=int32, numpy=array([3, 3, 3, 2, 3, 3], dtype=int32)>,
 'pos_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([ True,  True,  True,  True,  True,  True])>,
 'neg_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([False, False, False, False, False, False])>,
 'ignore_mask': <tf.Tensor: shape=(6,), dtype=bool, numpy=array([False, False, False, False, False, False])>,
 'matched_gt_idx': <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 0, 1, 0, 0], dtype=int32)>,
 'num_pos': <tf.Tensor: shape=(), dtype=int32, numpy=6>,
 'matched_iou': <tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([1.        , 0.63999987, 0.3599999 , 1.        , 0.        ,
        0.    

In [15]:
pos_mask = tf.constant([False,  True, False, False,  True, False, False, False])
neg_mask = tf.constant([ True, False,  True,  True, False,  True,  True,  True])
ignore_mask = tf.constant([False, False, False, False, False, False, False, False])
conf_loss = tf.constant([0.05,  1.20,  0.80,  0.02,  0.40,  1.50,  0.30,  0.90])

In [16]:
def hard_negative_mining(conf_loss: tf.Tensor, pos_mask: tf.Tensor, neg_mask:tf.Tensor, neg_ratio: float, min_neg: int| None, max_neg: int| None):
    
    num_positive = tf.reduce_sum(tf.cast(pos_mask, tf.int32))

    K = tf.math.floor(tf.cast(neg_ratio,dtype=tf.float32) * tf.cast(num_positive,dtype=tf.float32))
    K = tf.cast(K,tf.int32)

    if max_neg is not None:
        K = tf.minimum(K,tf.cast(max_neg,tf.int32))

    if min_neg is not None:
        K = tf.maximum(K,tf.cast(min_neg,tf.int32))

    K = tf.cast(tf.where(num_positive > 0, K, tf.zeros_like(K)),dtype=tf.int32)

    # Getting the indices for the negative boxes
    negative_indices = tf.where(neg_mask)[:,0]

    negative_losses = tf.gather(conf_loss,negative_indices)

    # Filtering the losses to not include NaN or inf
    valid_mask = tf.logical_not(tf.math.is_nan(negative_losses))
    valid_negative_indices = tf.boolean_mask(negative_indices,valid_mask)
    valid_negative_losses = tf.boolean_mask(negative_losses,valid_mask)

    # Checking if there are no valid losses
    num_valid_losses = tf.shape(valid_negative_losses)[0]
    k = tf.minimum(K,num_valid_losses)

    top_k_losses, top_k_indices = tf.math.top_k(valid_negative_losses,k=k,sorted=True)
    
    hard_negative_indices = tf.cast(tf.gather(valid_negative_indices,top_k_indices),tf.int32)

    hard_negative_indices = tf.expand_dims(hard_negative_indices,axis=1)

    selected_negative_mask = tf.scatter_nd(indices = hard_negative_indices, updates = tf.ones(k,dtype=tf.bool),shape=[tf.shape(conf_loss)[0]])

    selected_negative_indices = tf.reduce_sum(tf.cast(selected_negative_mask,tf.int32))

    return selected_negative_mask, selected_negative_indices, tf.boolean_mask(conf_loss,selected_negative_mask)

In [17]:
hard_negative_mining(conf_loss,pos_mask,neg_mask,neg_ratio = 1.0,min_neg = None, max_neg = None)

(<tf.Tensor: shape=(8,), dtype=bool, numpy=array([False, False, False, False, False,  True, False,  True])>,
 <tf.Tensor: shape=(), dtype=int32, numpy=2>,
 <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1.5, 0.9], dtype=float32)>)

In [18]:
pos_mask

<tf.Tensor: shape=(8,), dtype=bool, numpy=array([False,  True, False, False,  True, False, False, False])>

In [19]:
neg_mask

<tf.Tensor: shape=(8,), dtype=bool, numpy=array([ True, False,  True,  True, False,  True,  True,  True])>

In [20]:
ignore_mask

<tf.Tensor: shape=(8,), dtype=bool, numpy=array([False, False, False, False, False, False, False, False])>