In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from numba import jit

In [3]:
# Seed numpy's legacy global RNG so the random label arrays generated in the
# next cells are reproducible under Restart & Run All.
seed=1234
np.random.seed(seed)

# Render matplotlib figures inline, as SVG.
%matplotlib inline
%config InlineBackend.figure_format="svg"

In [4]:
# Function to get the confusion matrix
def get_confusion_matrix(y_true, y_pred):
    """
    Creates a confusion matrix.

    Rows correspond to ground-truth labels and columns to predicted labels.
    Both axes are ordered by the sorted union of the labels found in
    ``y_true`` and ``y_pred`` (the same convention as
    ``sklearn.metrics.confusion_matrix`` with ``labels=None``), so labels do
    not have to be the contiguous integers ``0..k-1``.

    Args:
        y_true: Ground truth with n elements (1-D array-like of sortable
            labels, e.g. ints or strings)
        y_pred: predictions with n elements

    Returns:
        conf_mtx: int64 confusion matrix of size kXk, where k is the number
            of distinct labels across y_true and y_pred. Entry [i, j] counts
            the samples whose true label is the i-th label and whose
            predicted label is the j-th label.

    Raises:
        AssertionError: if y_true and y_pred differ in length
    """
    assert len(y_true) == len(y_pred), \
        "Size of groundtruth vector and prediction vector didn't match"

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_samples = y_true.shape[0]

    # Map every label to its position in the sorted union of labels. Indexing
    # by position (instead of by the raw label value) avoids IndexError for
    # label sets such as {1, 2} or {0, 2}, and supports non-integer labels.
    labels, label_pos = np.unique(np.concatenate((y_true, y_pred)), return_inverse=True)
    true_pos, pred_pos = label_pos[:n_samples], label_pos[n_samples:]

    n_classes = labels.shape[0]
    conf_mtx = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates repeated (row, col) pairs correctly; a plain
    # fancy-index `conf_mtx[rows, cols] += 1` would count each pair only once.
    np.add.at(conf_mtx, (true_pos, pred_pos), 1)

    return conf_mtx

In [5]:
# Random ground-truth / prediction label vectors for the sanity checks below.
# Draws happen in a fixed order (binary true, binary pred, multi true,
# multi pred) so that, with the global seed set earlier, results are
# reproducible.

# Binary problem: 20 samples with labels drawn from {0, 1}
binary_y_true, binary_y_pred = (np.random.randint(low=0, high=2, size=20) for _ in range(2))

# Multi-class problem: 30 samples with labels drawn from {0, ..., 4}
multi_y_true, multi_y_pred = (np.random.randint(low=0, high=5, size=30) for _ in range(2))

In [6]:
## 1. Check for binary case

# get the confusion matrix from custom implementation
conf_mtx_custom = get_confusion_matrix(binary_y_true, binary_y_pred)

# get the confusion matrix from sklearn implementation
conf_mtx = confusion_matrix(binary_y_true, binary_y_pred)

print("Custom implementation: ")
print(conf_mtx_custom)
print("\nscikit-learn implementation: ")
print(conf_mtx)

# np.array_equal also compares shapes: `np.all(a == b)` would silently
# broadcast (e.g. a 1x1 against a 2x2) or raise on incompatible shapes.
# The assert makes a mismatch fail loudly under Restart & Run All.
matches = np.array_equal(conf_mtx, conf_mtx_custom)
assert matches, "custom confusion matrix disagrees with scikit-learn"
matches

Custom implementation: 
[[6 5]
 [7 2]]

scikit-learn implementation: 
[[6 5]
 [7 2]]


True

In [7]:
## 2. Check for multi-class

conf_mtx_custom = get_confusion_matrix(multi_y_true, multi_y_pred)
conf_mtx = confusion_matrix(multi_y_true, multi_y_pred)

print("Custom implementation: ")
print(conf_mtx_custom)
print("\nscikit-learn implementation: ")
print(conf_mtx)

# np.array_equal also compares shapes: `np.all(a == b)` would silently
# broadcast (e.g. a 1x1 against a kxk) or raise on incompatible shapes.
# The assert makes a mismatch fail loudly under Restart & Run All.
matches = np.array_equal(conf_mtx, conf_mtx_custom)
assert matches, "custom confusion matrix disagrees with scikit-learn"
matches

Custom implementation: 
[[2 1 1 0 0]
 [2 1 3 1 0]
 [0 0 1 0 2]
 [0 1 3 2 3]
 [0 4 1 0 2]]

scikit-learn implementation: 
[[2 1 1 0 0]
 [2 1 3 1 0]
 [0 0 1 0 2]
 [0 1 3 2 3]
 [0 4 1 0 2]]


True

## Count TP, FP, TN and FN

In [8]:
def get_count(conf_mtx, kind='binary'):
    """
    Counts TP, FP, TN and FN in a confusion matrix.

    Args:
        conf_mtx: square confusion matrix (array-like) of size kXk, with
            ground truth along rows and predictions along columns (the
            layout produced by get_confusion_matrix)
        kind: 'binary' or 'multi_class'.
            - 'binary': conf_mtx must be 2x2 with the negative class first;
              the four quantities are returned as scalars.
            - 'multi_class': one-vs-rest counts are computed for every
              class; each of the four quantities is a 1-D numpy array of
              length k whose i-th entry refers to class i. Works for any k,
              including 2.

    Returns:
        tn, fp, fn, tp

    Raises:
        AssertionError: if kind is not 'binary' or 'multi_class'
        ValueError: if kind='binary' and conf_mtx is not 2x2
    """
    assert kind in ('binary', 'multi_class'), "Only binary and multi-class are supported"

    # Accept nested lists as well as numpy arrays.
    conf_mtx = np.asarray(conf_mtx)

    if kind == 'binary':
        # Fail with an actionable message instead of a bare unpacking error
        # when e.g. only one class is present or there are more than two.
        if conf_mtx.shape != (2, 2):
            raise ValueError(
                f"kind='binary' requires a 2x2 confusion matrix, got shape "
                f"{conf_mtx.shape}; use kind='multi_class' for other sizes"
            )
        # Row-major layout [[TN, FP], [FN, TP]] for labels ordered (negative, positive)
        tn, fp, fn, tp = conf_mtx.ravel()
    else:
        # Correct predictions lie on the diagonal
        tp = np.diag(conf_mtx)
        # Column sum = everything predicted as the class; minus hits = FP
        fp = conf_mtx.sum(axis=0) - tp
        # Row sum = everything truly in the class; minus hits = FN
        fn = conf_mtx.sum(axis=1) - tp
        # Remaining samples are neither truly nor predicted as the class
        tn = conf_mtx.sum() - (tp + fp + fn)

    return tn, fp, fn, tp

In [9]:
# Binary case: derive the four counts from the 2x2 confusion matrix
conf_mtx = get_confusion_matrix(binary_y_true, binary_y_pred)
tn, fp, fn, tp = get_count(conf_mtx, kind='binary')

print("Confusion matrix: ")
print(conf_mtx)
# Report in TP/FP/FN/TN order; the leading newline separates it from the matrix
for label, count in (("\nTP: ", tp), ("FP: ", fp), ("FN: ", fn), ("TN: ", tn)):
    print(label, count)

Confusion matrix: 
[[6 5]
 [7 2]]

TP:  2
FP:  5
FN:  7
TN:  6


In [10]:
# Multi-class case: one-vs-rest counts, reported class by class
conf_mtx = get_confusion_matrix(multi_y_true, multi_y_pred)
tn, fp, fn, tp = get_count(conf_mtx, kind='multi_class')

print("Confusion matrix: ")
print(conf_mtx, "\n")

for i, (tn_i, fp_i, fn_i, tp_i) in enumerate(zip(tn, fp, fn, tp)):
    print(f"Class: {i} \n TN: {tn_i} FP: {fp_i} \n FN: {fn_i} TP: {tp_i}\n")

Confusion matrix: 
[[2 1 1 0 0]
 [2 1 3 1 0]
 [0 0 1 0 2]
 [0 1 3 2 3]
 [0 4 1 0 2]] 

Class: 0 
 TN: 24 FP: 2 
 FN: 2 TP: 2

Class: 1 
 TN: 17 FP: 6 
 FN: 6 TP: 1

Class: 2 
 TN: 19 FP: 8 
 FN: 2 TP: 1

Class: 3 
 TN: 20 FP: 1 
 FN: 7 TP: 2

Class: 4 
 TN: 18 FP: 5 
 FN: 5 TP: 2

