<a href="https://colab.research.google.com/github/2grep/ScienceNotes/blob/main/generalized_metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

In [None]:
def generalized_accuracy(conf_matrix, weights):
    '''
    Return the distance-weighted accuracy of a classifier.

    Every cell (i, j) of the confusion matrix contributes its count
    multiplied by weights[|i - j|]; the weighted total is then divided by
    the total number of counted elements.

    Parameters:
        conf_matrix: sequence of rows of counts (a 2-D confusion matrix).
        weights: sequence indexed by the absolute distance |i - j| between
            the row index and the column index.

    Returns:
        The weighted sum of all cells divided by the sum of all cells.

    Raises:
        ValueError: if the confusion matrix contains no elements
            (its cells sum to zero).
    '''
    num_elements = sum(sum(row) for row in conf_matrix)
    if num_elements == 0:
        # A bare string cannot be raised in Python 3; use a real exception
        # so the caller actually sees this message.
        raise ValueError("No elements")

    score = 0
    for i, row in enumerate(conf_matrix):
        for j, count in enumerate(row):
            # Distance from the diagonal selects the credit for this cell.
            score += count * weights[abs(i - j)]

    return score / num_elements

def generalized_recall(conf_matrix, weights, label_num):
    '''
    Return the distance-weighted recall for one label and print the working.

    The function reads column `label_num` of the confusion matrix. Each entry
    in row i of that column contributes its count multiplied by
    weights[|i - label_num|]; the weighted total is divided by the column sum.
    A human-readable line showing the arithmetic and the result is printed
    to stdout.

    Parameters:
        conf_matrix: sequence of rows of counts (a 2-D confusion matrix).
        weights: sequence indexed by the absolute distance
            |i - label_num|.
        label_num: index of the column (label) to evaluate.

    Returns:
        The weighted column sum divided by the plain column sum.

    Raises:
        ValueError: if the selected column sums to zero.
    '''
    col = [row[label_num] for row in conf_matrix]
    total = sum(col)

    if total == 0:
        # A bare string cannot be raised in Python 3; use a real exception
        # so the caller actually sees this message.
        raise ValueError("No elements with that ground truth label")

    score = 0
    terms = []  # one "count * weight" fragment per row, for the printed working
    for i, count in enumerate(col):
        weight = weights[abs(i - label_num)]
        score += count * weight
        terms.append("{} * {:.2f}".format(count, weight))

    show_work = "(" + " + ".join(terms) + ") / ({}) = ".format(total)

    gen_recall = score / total
    print(show_work, gen_recall)
    return gen_recall

def generalized_precision(conf_matrix, weights, label_num):
    '''
    Return the distance-weighted precision for one label and print the working.

    The function reads row `label_num` of the confusion matrix. Each entry in
    column i of that row contributes its count multiplied by
    weights[|i - label_num|]; the weighted total is divided by the row sum.
    A human-readable line showing the arithmetic and the result is printed
    to stdout.

    Parameters:
        conf_matrix: sequence of rows of counts (a 2-D confusion matrix).
        weights: sequence indexed by the absolute distance
            |i - label_num|.
        label_num: index of the row (label) to evaluate.

    Returns:
        The weighted row sum divided by the plain row sum.

    Raises:
        ValueError: if the selected row sums to zero.
    '''
    row = conf_matrix[label_num]
    total = sum(row)

    if total == 0:
        # A bare string cannot be raised in Python 3; use a real exception
        # so the caller actually sees this message.
        raise ValueError("No elements with that predicted label")

    score = 0
    terms = []  # one "count * weight" fragment per column, for the printed working
    for i, count in enumerate(row):
        weight = weights[abs(i - label_num)]
        score += count * weight
        terms.append("{} * {:.2f}".format(count, weight))

    show_work = "(" + " + ".join(terms) + ") / ({}) = ".format(total)

    gen_prec = score / total
    print(show_work, gen_prec)
    return gen_prec

In [None]:
num_classes = 4
# One weight per possible label distance (0 .. num_classes - 1): an exact match
# earns full credit (1) and the farthest possible miss earns none (0).
# Passing num_classes explicitly matters: np.linspace defaults to 50 samples,
# which would make the first few weights all nearly 1.
weights = np.linspace(1, 0, num_classes)
# weights must be monotonically decreasing, but otherwise can be however you want.
# NOTE: the functions called below index weights by the absolute label distance
# |i - j|, so they expect a 1-D sequence with at least num_classes entries; a
# num_classes by num_classes weight matrix would require indexing by (i, j) instead.

conf_matrix = [
    [150, 30, 10, 7],
    [17, 200, 20, 8],
    [6, 35, 160, 21],
    [4, 17, 32, 214],
]
# Confusion matrix follows the convention of https://i.stack.imgur.com/a3hnS.png. The element in the ith row and jth column is the number of elements with ground truth j which were classified as i by the model.
# Confusion matrix must have num_classes rows and columns. All elements should be non-negative integers.

for i in range(num_classes):
    # Each call prints its own working; the returned value is not needed here.
    print("Gen. Prec for label", i)
    generalized_precision(conf_matrix, weights, i)
    print("Gen. Recall for label", i)
    generalized_recall(conf_matrix, weights, i)
print("Overall Gen. Accuracy:", generalized_accuracy(conf_matrix, weights))

Gen. Prec for label 0
(150 * 1.00 + 30 * 0.67 + 10 * 0.33 + 7 * 0.00) / (197) =  0.8798646362098139
Gen. Recall for label 0
(150 * 1.00 + 17 * 0.67 + 6 * 0.33 + 4 * 0.00) / (177) =  0.9227871939736347
Gen. Prec for label 1
(17 * 0.67 + 200 * 1.00 + 20 * 0.67 + 8 * 0.33) / (245) =  0.927891156462585
Gen. Recall for label 1
(30 * 0.67 + 200 * 1.00 + 35 * 0.67 + 17 * 0.33) / (282) =  0.8829787234042553
Gen. Prec for label 2
(6 * 0.33 + 35 * 0.67 + 160 * 1.00 + 21 * 0.67) / (222) =  0.897897897897898
Gen. Recall for label 2
(10 * 0.33 + 20 * 0.67 + 160 * 1.00 + 32 * 0.67) / (222) =  0.891891891891892
Gen. Prec for label 3
(4 * 0.00 + 17 * 0.33 + 32 * 0.67 + 214 * 1.00) / (267) =  0.9026217228464419
Gen. Recall for label 3
(7 * 0.00 + 8 * 0.33 + 21 * 0.67 + 214 * 1.00) / (250) =  0.9226666666666666
Overall Gen. Accuracy: 0.9033297529538131
