In [262]:
from typing import List, Tuple

import numpy as np
from numpy import ndarray


def get_mAP(
    preds: ndarray,
    targets: ndarray
) -> Tuple[float, ndarray]:
    """Score every class (column) and return the mean score plus per-class scores.

    Each class is scored with ``_average_precision_2``, which computes
    per-class precision TP / (TP + FP) on binary predictions. Despite the
    function name, the mean returned here is therefore the mean per-class
    *precision*, not a ranked mean average precision.

    Args:
        preds: 2-D array of predictions; column ``k`` holds the predictions
            for class ``k``.
        targets: Ground-truth labels with exactly the same shape as ``preds``.

    Returns:
        ``(mean_score, per_class_scores)`` where ``per_class_scores`` is a
        float array with one entry per column.

    Raises:
        ValueError: If ``preds`` is not 2-D or its shape differs from
            ``targets``.
    """
    preds = np.asarray(preds)
    targets = np.asarray(targets)
    # Misaligned inputs would otherwise either silently score only a subset
    # of the target classes or fail deep inside the per-class helper.
    if preds.ndim != 2 or preds.shape != targets.shape:
        raise ValueError(
            "preds and targets must be 2-D arrays of the same shape, "
            f"got {preds.shape} and {targets.shape}"
        )

    n_classes = preds.shape[1]
    APs = np.zeros(n_classes)
    for k in range(n_classes):
        APs[k] = _average_precision_2(preds[:, k], targets[:, k])

    return APs.mean(), APs


def _average_precision(output: ndarray, target: ndarray) -> float:
    # print(output, target)
    epsilon = 1e-8

    # sort examples
    indices = output.argsort()[::-1]
    # Computes prec@i
    total_count_ = np.cumsum(np.ones((len(output), 1)))

    target_ = target[indices]

    ind = target_ == 1

    pos_count_ = np.cumsum(ind)

    total = pos_count_[-1]
    pos_count_[np.logical_not(ind)] = 0

    pp = pos_count_ / total_count_
    precision_at_i_ = np.sum(pp)
    # print(pp, precision_at_i_, pos_count_, total_count_)
    precision_at_i = precision_at_i_ / (total + epsilon)

    return precision_at_i


def _average_precision_2(output: ndarray, target: ndarray) -> float:
    # print(output, target)
    epsilon = 1e-8

    total_pred_pos = np.count_nonzero(output)

    TP = np.count_nonzero(np.logical_and(output, target))
    FP = total_pred_pos - TP

    AP = TP / (TP + FP + epsilon)

    return AP

In [263]:
# Toy multi-label example: each row is a sample, each column a class,
# with binary (0/1) predictions and ground-truth labels.
y_pred = np.array([
    [0, 0, 0, 0, 0, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 0, 0, 1],
])
y_true = np.array([
    [1, 1, 1, 1, 1, 1, 1, 1, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1],
])

In [264]:
# Displays (mean score, per-class scores) for y_pred against y_true.
get_mAP(y_pred, y_true)

(0.14814814746913582,
 array([0.        , 0.        , 0.        , 0.        , 0.        ,
        1.        , 0.        , 0.        , 0.33333333]))

In [265]:
from sklearn import metrics

# Cross-check against scikit-learn: per-class (average=None) precision on the
# same multi-label arrays. Note that `ap` holds precision scores, not average
# precision. zero_division=0 scores classes with no positive predictions as
# 0.0 explicitly instead of emitting an UndefinedMetricWarning.
ap = metrics.precision_score(y_true, y_pred, average=None, zero_division=0)
ap

[0.         0.         0.         0.         0.         1.
 0.         0.         0.33333333]


  _warn_prf(average, modifier, msg_start, len(result))


In [266]:
# Mean over classes of the per-class scores held in `ap`.
print(np.mean(ap))

0.14814814814814814
