In [1]:
import jax
from jax import (
    Array,
    numpy as jnp,
    random as jrand,
    lax
)
from scipy.optimize import linear_sum_assignment as scipy_lsa
from typing import Sequence

import keras as nn
import keras_cv as ncv

2024-03-05 16:50:20.864795: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-05 16:50:20.864850: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-05 16:50:20.890411: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Hand-made toy batch used throughout the notebook.
# possible classes = [0, 1, 2]; B = 2 batch elements, N = 3 slots per element.
class_true = jnp.array([[1, 2, 0],
                        [2, 1, 0]], dtype=jnp.float32)  # (B, N)
class_pred = jnp.array([[1, 0, 2],
                        [0, 1, 2]], dtype=jnp.float32)  # (B, N)
N = class_true.shape[-1]  # number of slots per batch element

class_prob = jnp.array([[[0.1, 0.8, 0.1],
                         [0.8, 0.0, 0.2],
                         [0.1, 0.0, 0.9]],
                        [[0.9, 0.0, 0.1],
                         [0.4, 0.5, 0.1],
                         [0.3, 0.1, 0.6]]], dtype=jnp.float32)  # (B, N, n_classes)

# Trailing comments give the class of each box (same order as class_true).
bbox_true = jnp.array([[[0.2, 0.1, 0.6, 0.9],   # 1
                        [0.1, 0.4, 0.5, 0.6],   # 2
                        [0.0, 0.0, 0.0, 0.0]],  # 0
                       [[0.1, 0.6, 0.5, 0.2],   # 2
                        [0.1, 0.3, 0.5, 0.4],   # 1
                        [0.0, 0.0, 0.0, 0.0]]], # 0
                      dtype=jnp.float32)  # (B, N, 4)

# Trailing comments give the class of each box (same order as class_pred).
bbox_pred = jnp.array([[[0.198, 0.1, 0.601, 0.91],     # 1
                        [0.01, 0.009, 0.001, 0.0],     # 0
                        [0.101, 0.39, 0.501, 0.601]],  # 2
                       [[0.01, 0.009, 0.001, 0.0],     # 0
                        [0.11, 0.298, 0.499, 0.39],    # 1
                        [0.11, 0.62, 0.501, 0.2009]]], # 2
                      dtype=jnp.float32)  # (B, N, 4)

_toy_arrays = (
    ("class_true", class_true),
    ("class_pred", class_pred),
    ("class_prob", class_prob),
    ("bbox_true", bbox_true),
    ("bbox_pred", bbox_pred),
)

# First the shapes, one per line ...
for _label, _values in _toy_arrays:
    print(_label, _values.shape)

print("\nmaximum number of onbjects that can be detected is", N)
print("unique classes", jnp.unique(class_true), end="\n\n")

# ... then the full contents, each followed by a blank line.
for _label, _values in _toy_arrays:
    print(_label, _values, end="\n\n", sep="\n")

2024-03-05 16:50:22.961253: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


class_true (2, 3)
class_pred (2, 3)
class_prob (2, 3, 3)
bbox_true (2, 3, 4)
bbox_pred (2, 3, 4)

maximum number of onbjects that can be detected is 3
unique classes [0. 1. 2.]

class_true
[[1. 2. 0.]
 [2. 1. 0.]]

class_pred
[[1. 0. 2.]
 [0. 1. 2.]]

class_prob
[[[0.1 0.8 0.1]
  [0.8 0.  0.2]
  [0.1 0.  0.9]]

 [[0.9 0.  0.1]
  [0.4 0.5 0.1]
  [0.3 0.1 0.6]]]

bbox_true
[[[0.2 0.1 0.6 0.9]
  [0.1 0.4 0.5 0.6]
  [0.  0.  0.  0. ]]

 [[0.1 0.6 0.5 0.2]
  [0.1 0.3 0.5 0.4]
  [0.  0.  0.  0. ]]]

bbox_pred
[[[0.198  0.1    0.601  0.91  ]
  [0.01   0.009  0.001  0.    ]
  [0.101  0.39   0.501  0.601 ]]

 [[0.01   0.009  0.001  0.    ]
  [0.11   0.298  0.499  0.39  ]
  [0.11   0.62   0.501  0.2009]]]



In [3]:
def BoxLoss(bbox_true:Array, bbox_pred:Array, lambda_giou:float=2., lambda_l1:float=5.):
    """```
    Weighted sum of a GIoU loss and a mean-absolute-error (L1) loss between boxes.

    Inputs:
        bbox_true:Array   => target boxes, 4 values per box
        bbox_pred:Array   => predicted boxes, same shape as bbox_true
        lambda_giou:float => weight on the GIoU term
        lambda_l1:float   => weight on the L1 term
    Outputs:
        lambda_giou*GIoU + lambda_l1*L1 (a scalar in the test cell below)
    ```"""
    # Both loss objects are built with their default settings on every call.
    giou_fn = ncv.losses.GIoULoss(bounding_box_format="xyWH")
    l1_fn = nn.losses.MeanAbsoluteError()

    weighted_giou = lambda_giou*giou_fn(bbox_true, bbox_pred)
    weighted_l1 = lambda_l1*l1_fn(bbox_true, bbox_pred)
    return weighted_giou + weighted_l1

def ClassLoss(class_true:Array, class_prob:Array, down_weight_no_object_class:bool=False):
    """```
    Sparse categorical cross-entropy, i.e. -logprob(class_true), on class_prob.

    Inputs:
        class_true:Array => integer-valued class labels
        class_prob:Array => per-class values passed to the loss as predictions
        down_weight_no_object_class:bool => if True, samples whose label is 0
            (the no-object class) get weight 0.1 and all others weight 1.0
    Outputs:
        the value returned by SparseCategoricalCrossentropy
    ```"""
    cross_entropy = nn.losses.SparseCategoricalCrossentropy()
    if down_weight_no_object_class:
        # class 0 counts for a tenth of any other class
        weights = jnp.where(class_true==0, 0.1, 1.)
    else:
        weights = None
    return cross_entropy(class_true, class_prob, sample_weight=weights)

########## Test ############
# Box loss on the first batch element, then the class loss without and with
# down-weighting of the no-object class.
(
    BoxLoss(bbox_true[0], bbox_pred[0]),
    ClassLoss(class_true, class_prob, False),
    ClassLoss(class_true, class_prob, True),
)

(Array(3.3262436, dtype=float32),
 Array(1.3891453, dtype=float32),
 Array(0.86316156, dtype=float32))

In [4]:
# Box loss for one (target, prediction) pair: box 0 of batch element 0.
# [None] adds a leading axis so each argument has shape (1, 4).
(
    BoxLoss(bbox_true[0, 0][None], bbox_pred[0, 0][None]),
    bbox_true[0, 0][None],
    bbox_pred[0, 0][None],
)

(Array(0.0480862, dtype=float32),
 Array([[0.2, 0.1, 0.6, 0.9]], dtype=float32),
 Array([[0.198, 0.1  , 0.601, 0.91 ]], dtype=float32))

In [5]:
# Show the toy inputs again: name on one line, array below, then a blank line.
for _label, _values in (
    ("class_true", class_true),
    ("class_pred", class_pred),
    ("class_prob", class_prob),
    ("bbox_true", bbox_true),
    ("bbox_pred", bbox_pred),
):
    print(_label, _values, end="\n\n", sep="\n")

class_true
[[1. 2. 0.]
 [2. 1. 0.]]

class_pred
[[1. 0. 2.]
 [0. 1. 2.]]

class_prob
[[[0.1 0.8 0.1]
  [0.8 0.  0.2]
  [0.1 0.  0.9]]

 [[0.9 0.  0.1]
  [0.4 0.5 0.1]
  [0.3 0.1 0.6]]]

bbox_true
[[[0.2 0.1 0.6 0.9]
  [0.1 0.4 0.5 0.6]
  [0.  0.  0.  0. ]]

 [[0.1 0.6 0.5 0.2]
  [0.1 0.3 0.5 0.4]
  [0.  0.  0.  0. ]]]

bbox_pred
[[[0.198  0.1    0.601  0.91  ]
  [0.01   0.009  0.001  0.    ]
  [0.101  0.39   0.501  0.601 ]]

 [[0.01   0.009  0.001  0.    ]
  [0.11   0.298  0.499  0.39  ]
  [0.11   0.62   0.501  0.2009]]]



In [6]:
# Scratch literal: one row of prediction indices per batch element. It equals the
# order in which the Matcher call further down returns this toy batch's predictions.
[[0, 2, 1], [2, 1, 0]]

[[0, 2, 1], [2, 1, 0]]

In [7]:
@jax.jit
def MatchLoss(class_true:Array, class_prob:Array, bbox_true:Array, bbox_pred:Array):
    """```
    Matching cost for one (target, prediction) pair.

    Inputs:
        class_true:Array => shape(,)   target class, 0 means no_object
        class_prob:Array => shape(,)   probability the prediction gives to class_true
        bbox_true:Array => shape(4,)
        bbox_pred:Array => shape(4,)
    Outputs:
        match_loss:Array => shape(,)
            BoxLoss - class_prob   when class_true != 0
            0                      when class_true == 0
    ```"""
    # 1.0 for a real object, 0.0 for the no_object class
    is_object = (class_true!=0).astype(float)

    # BoxLoss is given a leading axis of size 1 on each box
    box_term = is_object*BoxLoss(bbox_true[None], bbox_pred[None])
    class_term = is_object*class_prob
    return box_term - class_term

In [8]:
# Reference cost matrix, filled one entry at a time:
#   C[b, i, j] = MatchLoss(target i, prediction j) for batch element b.
# The batch size is read from the data instead of being hard-coded to 2.
batch_size = class_true.shape[0]
C = jnp.zeros((batch_size, N, N))
for b in range(batch_size):
    for i in range(N): # target slot
        true_label = class_true[b][i].astype(int)  # used to pick the prob of the true class
        for j in range(N): # prediction slot
            cost = MatchLoss(
                class_true[b][i],
                class_prob[b][j, true_label],
                bbox_true[b][i],
                bbox_pred[b][j],
            )
            C = C.at[b, i, j].set(cost)

In [9]:
# Display the cost matrix built in the previous cell. In the recorded output,
# the rows for targets with class_true == 0 (no_object) are all zero.
C

Array([[[-0.7519138 ,  4.8454957 ,  2.271843  ],
        [ 2.2103505 ,  4.748815  , -0.8098676 ],
        [ 0.        ,  0.        ,  0.        ]],

       [[ 5.1964498 ,  2.2173624 , -0.12120765],
        [ 4.618862  , -0.3294528 ,  2.2982922 ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)

In [10]:
class Matcher:
    """```
    Reorders predictions so that prediction slot i is the one assigned to target i.

    A (B, N, N) cost tensor is built from MatchLoss, then
    scipy.optimize.linear_sum_assignment is run on each batch element's
    cost matrix to pick the assignment.

    vmaped:bool => True  -> cost tensor built with jax.vmap
                   False -> cost tensor built with explicit Python loops
    ```"""
    def __init__(self, vmaped:bool=True):
        self.vmaped:bool = vmaped

    @staticmethod
    @jax.jit
    def MatchLoss(class_true:Array, class_prob:Array, bbox_true:Array, bbox_pred:Array):
        """```
        Matching cost for one (target, prediction) pair.

        Inputs:
            class_true:Array => shape(,)   target class, 0 means no_object
            class_prob:Array => shape(,)   probability the prediction gives to class_true
            bbox_true:Array => shape(4,)
            bbox_pred:Array => shape(4,)
        Outputs:
            match_loss:Array => shape(,)
                -class_prob + BoxLoss   when class_true != 0
                0                       when class_true == 0
        ```"""
        class_bool = (class_true!=0).astype(float) # int(not class_true==0)
        # class_true = no_object = 0 => int(not True) = 0
        # class_true != no_object != 0 => int(not False) = 1

        match_loss = -class_bool*class_prob + class_bool*BoxLoss(bbox_true[None], bbox_pred[None])
        return match_loss

    @staticmethod
    @jax.jit
    def compute_unbatched_cost_matrix(class_true:Array, class_prob:Array, bbox_true:Array, bbox_pred:Array):
        """```
        Cost matrix for a single batch element.

        Inputs:
            class_true:Array => shape(N,)
            class_prob:Array => shape(N, num_classes)
            bbox_true:Array => shape(N, 4)
            bbox_pred:Array => shape(N, 4)
        Outputs:
            unbatched_cost:Array => shape(N, N)
                unbatched_cost[i, j] = MatchLoss(target i, prediction j)
        ```"""
        labels = class_true.astype(int)
        # true_class_prob[i, j] = class_prob[j, labels[i]]:
        # probability that prediction j gives to the class of target i
        true_class_prob = class_prob[:, labels].T # (N, N)

        # inner vmap runs over predictions j, outer vmap runs over targets i
        over_predictions = jax.vmap(Matcher.MatchLoss, in_axes=(None, 0, None, 0))
        over_targets = jax.vmap(over_predictions, in_axes=(0, 0, 0, None))
        unbatched_cost = over_targets(class_true, true_class_prob, bbox_true, bbox_pred)
        return unbatched_cost # (N, N)

    @staticmethod
    @jax.jit
    def compute_batched_cost_matrix(class_true:Array, class_prob:Array, bbox_true:Array, bbox_pred:Array):
        """```
        Cost matrices for a whole batch (vmap of compute_unbatched_cost_matrix).

        Inputs:
            class_true:Array => shape(B, N)
            class_prob:Array => shape(B, N, num_classes)
            bbox_true:Array => shape(B, N, 4)
            bbox_pred:Array => shape(B, N, 4)
        Outputs:
            C:Array => shape(B, N, N)
        ```"""
        # map straight over the leading (batch) axis of every argument
        C = jax.vmap(Matcher.compute_unbatched_cost_matrix)(
            class_true, class_prob, bbox_true, bbox_pred
        )
        return C # (B, N, N)

    @staticmethod
    @jax.jit
    def unvmaped_compute_batched_cost_matrix(class_true:Array, class_prob:Array, bbox_true:Array, bbox_pred:Array):
        """```
        Loop-based version of compute_batched_cost_matrix.

        Inputs:
            class_true:Array => shape(B, N)
            class_prob:Array => shape(B, N, num_classes)
            bbox_true:Array => shape(B, N, 4)
            bbox_pred:Array => shape(B, N, 4)
        Outputs:
            C:Array => shape(B, N, N)
        ```"""
        # Sizes come from the argument itself, not from a notebook-level global.
        batch_size, num_slots = class_true.shape
        C = jnp.zeros((batch_size, num_slots, num_slots))
        for b in range(batch_size):
            for i in range(num_slots):
                true_label = class_true[b][i].astype(int)
                for j in range(num_slots):
                    C = C.at[b, i, j].set(
                        Matcher.MatchLoss(
                            class_true[b][i],
                            class_prob[b][j, true_label], # prob of true class
                            bbox_true[b][i],
                            bbox_pred[b][j]
                        )
                    )
        return C

    @staticmethod
    # Cannot jit this function as linear_sum_assignment is used which is a numpy function not a jax function
    def match(class_true:Array, class_prob:Array, bbox_true:Array, bbox_pred:Array, vmaped:bool=True):
        """```
        Permute the predictions of every batch element by its optimal assignment.

        Inputs:
            class_true:Array => shape(B, N)
            class_prob:Array => shape(B, N, num_classes)
            bbox_true:Array => shape(B, N, 4)
            bbox_pred:Array => shape(B, N, 4)
            vmaped:bool => which cost-matrix implementation to use
        Outputs:
            matched_class_prob:Array => shape(B, N, num_classes)
            matched_bbox_pred:Array => shape(B, N, 4)
        ```"""
        cost_fn = (
            Matcher.compute_batched_cost_matrix
            if vmaped else
            Matcher.unvmaped_compute_batched_cost_matrix
        )
        C:Array = cost_fn(class_true, class_prob, bbox_true, bbox_pred) # (B, N, N)

        # Bring the whole cost tensor to host once; scipy works on numpy arrays.
        host_cost = jax.device_get(C)
        # scipy_lsa returns (row_ind, col_ind); col_ind[i] is the prediction assigned to target i
        to_indices = jnp.stack([scipy_lsa(cost_b)[1] for cost_b in host_cost])[..., None] # (B, N, 1)

        matched_class_prob = jnp.take_along_axis(class_prob, to_indices, axis=1) # (B, N, num_classes)
        matched_bbox_pred = jnp.take_along_axis(bbox_pred, to_indices, axis=1)   # (B, N, 4)
        return matched_class_prob, matched_bbox_pred # (B, N, num_classes), (B, N, 4)

    def __call__(self, y_true:Sequence[Array], y_pred:Sequence[Array]):
        """```
        Inputs:
            y_true:Sequence[Array] => (class_true, bbox_true)
            y_pred:Sequence[Array] => (class_prob, bbox_pred)
        Outputs:
            (matched_class_prob, matched_bbox_pred) as returned by Matcher.match
        ```"""
        (class_true, bbox_true), (class_prob, bbox_pred) = y_true, y_pred
        # every input is wrapped in stop_gradient before matching
        (class_true, bbox_true) = lax.stop_gradient(class_true.astype(int)), lax.stop_gradient(bbox_true)
        (class_prob, bbox_pred) = lax.stop_gradient(class_prob), lax.stop_gradient(bbox_pred)

        y_matched_pred = Matcher.match(class_true, class_prob, bbox_true, bbox_pred, vmaped=self.vmaped)
        return y_matched_pred

In [13]:
# Unmatched inputs, shown for comparison with the Matcher output.
(
    class_prob,
    bbox_pred,
    class_pred,
    class_true,
)

(Array([[[0.1, 0.8, 0.1],
         [0.8, 0. , 0.2],
         [0.1, 0. , 0.9]],
 
        [[0.9, 0. , 0.1],
         [0.4, 0.5, 0.1],
         [0.3, 0.1, 0.6]]], dtype=float32),
 Array([[[0.198 , 0.1   , 0.601 , 0.91  ],
         [0.01  , 0.009 , 0.001 , 0.    ],
         [0.101 , 0.39  , 0.501 , 0.601 ]],
 
        [[0.01  , 0.009 , 0.001 , 0.    ],
         [0.11  , 0.298 , 0.499 , 0.39  ],
         [0.11  , 0.62  , 0.501 , 0.2009]]], dtype=float32),
 Array([[1., 0., 2.],
        [0., 1., 2.]], dtype=float32),
 Array([[1., 2., 0.],
        [2., 1., 0.]], dtype=float32))

In [12]:
# Run the default (vmaped) matcher on the toy batch; the result is the pair
# (matched_class_prob, matched_bbox_pred).
Matcher()(
    y_true=(class_true, bbox_true),
    y_pred=(class_prob, bbox_pred),
)

(Array([[[0.1, 0.8, 0.1],
         [0.1, 0. , 0.9],
         [0.8, 0. , 0.2]],
 
        [[0.3, 0.1, 0.6],
         [0.4, 0.5, 0.1],
         [0.9, 0. , 0.1]]], dtype=float32),
 Array([[[0.198 , 0.1   , 0.601 , 0.91  ],
         [0.101 , 0.39  , 0.501 , 0.601 ],
         [0.01  , 0.009 , 0.001 , 0.    ]],
 
        [[0.11  , 0.62  , 0.501 , 0.2009],
         [0.11  , 0.298 , 0.499 , 0.39  ],
         [0.01  , 0.009 , 0.001 , 0.    ]]], dtype=float32))

In [14]:
# Same toy batch as defined near the top of the notebook, rebuilt with the same values.
# possible classes = [0, 1, 2]; B = 2 batch elements, N = 3 slots per element.
class_true = jnp.array([[1, 2, 0],
                        [2, 1, 0]], dtype=jnp.float32)  # (B, N)
class_pred = jnp.array([[1, 0, 2],
                        [0, 1, 2]], dtype=jnp.float32)  # (B, N)
N = class_true.shape[-1]  # number of slots per batch element

class_prob = jnp.array([[[0.1, 0.8, 0.1],
                         [0.8, 0.0, 0.2],
                         [0.1, 0.0, 0.9]],
                        [[0.9, 0.0, 0.1],
                         [0.4, 0.5, 0.1],
                         [0.3, 0.1, 0.6]]], dtype=jnp.float32)  # (B, N, n_classes)

# Trailing comments give the class of each box (same order as class_true).
bbox_true = jnp.array([[[0.2, 0.1, 0.6, 0.9],   # 1
                        [0.1, 0.4, 0.5, 0.6],   # 2
                        [0.0, 0.0, 0.0, 0.0]],  # 0
                       [[0.1, 0.6, 0.5, 0.2],   # 2
                        [0.1, 0.3, 0.5, 0.4],   # 1
                        [0.0, 0.0, 0.0, 0.0]]], # 0
                      dtype=jnp.float32)  # (B, N, 4)

# Trailing comments give the class of each box (same order as class_pred).
bbox_pred = jnp.array([[[0.198, 0.1, 0.601, 0.91],     # 1
                        [0.01, 0.009, 0.001, 0.0],     # 0
                        [0.101, 0.39, 0.501, 0.601]],  # 2
                       [[0.01, 0.009, 0.001, 0.0],     # 0
                        [0.11, 0.298, 0.499, 0.39],    # 1
                        [0.11, 0.62, 0.501, 0.2009]]], # 2
                      dtype=jnp.float32)  # (B, N, 4)

_toy_arrays = (
    ("class_true", class_true),
    ("class_pred", class_pred),
    ("class_prob", class_prob),
    ("bbox_true", bbox_true),
    ("bbox_pred", bbox_pred),
)

# First the shapes, one per line ...
for _label, _values in _toy_arrays:
    print(_label, _values.shape)

print("\nmaximum number of onbjects that can be detected is", N)
print("unique classes", jnp.unique(class_true), end="\n\n")

# ... then the full contents, each followed by a blank line.
for _label, _values in _toy_arrays:
    print(_label, _values, end="\n\n", sep="\n")

class_true (2, 3)
class_pred (2, 3)
class_prob (2, 3, 3)
bbox_true (2, 3, 4)
bbox_pred (2, 3, 4)

maximum number of onbjects that can be detected is 3
unique classes [0. 1. 2.]

class_true
[[1. 2. 0.]
 [2. 1. 0.]]

class_pred
[[1. 0. 2.]
 [0. 1. 2.]]

class_prob
[[[0.1 0.8 0.1]
  [0.8 0.  0.2]
  [0.1 0.  0.9]]

 [[0.9 0.  0.1]
  [0.4 0.5 0.1]
  [0.3 0.1 0.6]]]

bbox_true
[[[0.2 0.1 0.6 0.9]
  [0.1 0.4 0.5 0.6]
  [0.  0.  0.  0. ]]

 [[0.1 0.6 0.5 0.2]
  [0.1 0.3 0.5 0.4]
  [0.  0.  0.  0. ]]]

bbox_pred
[[[0.198  0.1    0.601  0.91  ]
  [0.01   0.009  0.001  0.    ]
  [0.101  0.39   0.501  0.601 ]]

 [[0.01   0.009  0.001  0.    ]
  [0.11   0.298  0.499  0.39  ]
  [0.11   0.62   0.501  0.2009]]]



In [None]:
# Scratch: draw a single scalar (shape ()) from jrand.uniform with a key seeded by 42.
# NOTE: this rebinds class_true, which the cells above use as the (B, N) label array.
class_true = jrand.uniform(key=jrand.PRNGKey(42), shape=())