## **Imports**

In [34]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
import torchmetrics
from torchmetrics import Metric, ConfusionMatrix

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix, confusion_matrix, accuracy_score, precision_score, recall_score

## **Binary confusion matrix**

In [36]:
torch.manual_seed(6)
#Inicialize the binary randomic values
binary_samples = 20
binary_classes = 2
binary_output = torch.randn(binary_samples, binary_classes)
binary_pred = torch.argmax(binary_output, 1)
binary_target = torch.randint(0, high = binary_classes, size = (binary_samples,))
print(f'Predict: {binary_pred}')
print(f'Target:  {binary_target}')

Predict: tensor([1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1])
Target:  tensor([0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1])


In [37]:
#Create the confusion matrix and each TP, TN, FP and FN of the distribution
binary_confmat = ConfusionMatrix(num_classes = 2)

binary_cm = binary_confmat(binary_pred, binary_target)
print(f'Pytorch confusion matrix: \n {binary_cm} \n')
[binary_TN, binary_FP], [binary_FN, binary_TP] =  binary_confmat(binary_pred, binary_target)
print('TP {}, TN {}, FP {}, FN {}'.format(binary_TP, binary_TN, binary_FP, binary_FN))

Pytorch confusion matrix: 
 tensor([[5, 4],
        [2, 9]]) 

TP 9, TN 5, FP 4, FN 2


In [38]:
# Calculating each metric by myself

binary_accuracy = (binary_TP + binary_TN) / (binary_TP + binary_FP + binary_TN + binary_FN)
binary_accuracy_cm = np.sum(np.diag(binary_cm)/np.sum(binary_cm.numpy()))

print(' Binary accuracy: {} \n Using confussion matrix: {}'.format(binary_accuracy, binary_accuracy_cm))

binary_precision   = binary_TP / (binary_TP + binary_FP)
binary_sensitivity = binary_TP / (binary_TP + binary_FN)
binary_specificity = binary_TN / (binary_TN + binary_FP)

print(' Precision: {} \n Sensitivity: {} \n Specificity: {}'.format(binary_precision, binary_sensitivity, binary_specificity))

 Binary accuracy: 0.699999988079071 
 Using confussion matrix: 0.7
 Precision: 0.692307710647583 
 Sensitivity: 0.8181818127632141 
 Specificity: 0.5555555820465088


In [39]:
#sk-learn metrics

b_acc = accuracy_score(binary_target, binary_pred)
b_pre = precision_score(binary_target, binary_pred)
b_rec = recall_score(binary_target, binary_pred)
print(' Accuracy: {} \n Precision: {} \n Sensitivity: {} \n'.format(b_acc, b_pre, b_rec))

 Accuracy: 0.7 
 Precision: 0.6923076923076923 
 Sensitivity: 0.8181818181818182 



## **Multi-class confusion matrix**

In [40]:
torch.manual_seed(0)
#Inicialize the multi-class randomic values
nb_samples = 20
nb_classes = 4
mc_output = torch.randn(nb_samples, nb_classes)
mc_pred = torch.argmax(mc_output, 1)
mc_target = torch.randint(0, high = nb_classes, size = (nb_samples,))
print(f'Predict: {mc_pred}')
print(f'Target:  {mc_target}')

Predict: tensor([2, 0, 2, 1, 3, 3, 0, 1, 3, 2, 1, 1, 2, 2, 3, 2, 3, 1, 2, 1])
Target:  tensor([2, 3, 0, 1, 1, 3, 1, 1, 3, 2, 3, 3, 2, 2, 3, 0, 2, 3, 1, 0])


In [41]:
# Pytorch confusion matrix
mc_conf_matrix = torch.zeros(nb_classes, nb_classes)
for t, p in zip(mc_target, mc_pred):
    mc_conf_matrix[t, p] += 1
print(f'Pytorch confusion matrix: \n {mc_conf_matrix} \n')

mc_confmat = ConfusionMatrix(num_classes = 4)
print(f'Pytorch confusion matrix: \n {mc_confmat(mc_pred, mc_target)} \n')

# Sklearn confusion matrix
mc_cm = confusion_matrix(mc_target.flatten(), mc_pred.flatten(), labels=[0,1,2,3])
print(f'Sklearn multi-label confusion matrix\n  {mc_cm} \n')

Pytorch confusion matrix: 
 tensor([[0., 1., 2., 0.],
        [1., 2., 1., 1.],
        [0., 0., 4., 1.],
        [1., 3., 0., 3.]]) 

Pytorch confusion matrix: 
 tensor([[0, 1, 2, 0],
        [1, 2, 1, 1],
        [0, 0, 4, 1],
        [1, 3, 0, 3]]) 

Sklearn multi-label confusion matrix
  [[0 1 2 0]
 [1 2 1 1]
 [0 0 4 1]
 [1 3 0 3]] 



In [42]:
# Pytorch TP - TN - FP - FN for each class
mc_TP = mc_conf_matrix.diag()
for c in range(nb_classes):
    idx = torch.ones(nb_classes).byte()
    idx[c] = 0
    # all non-class samples classified as non-class
    mc_TN = mc_conf_matrix[idx.nonzero()[:, None], idx.nonzero()].sum() 
    # all non-class samples classified as class
    mc_FP = mc_conf_matrix[idx, c].sum()
    # all class samples not classified as class
    mc_FN = mc_conf_matrix[c, idx].sum()
    
    print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(
        c, mc_TP[c], mc_TN, mc_FP, mc_FN))

Class 0
TP 0.0, TN 15.0, FP 2.0, FN 3.0
Class 1
TP 2.0, TN 11.0, FP 4.0, FN 3.0
Class 2
TP 4.0, TN 12.0, FP 3.0, FN 1.0
Class 3
TP 3.0, TN 11.0, FP 2.0, FN 4.0


  if __name__ == '__main__':
  # This is added back by InteractiveShellApp.init_path()


In [43]:
#Sklearn  TP - TN - FP - FN for each class
cm_ml = multilabel_confusion_matrix(mc_target, mc_pred)
print('Sklearn multi-label confusion matrix\n', cm_ml)

Sklearn multi-label confusion matrix
 [[[15  2]
  [ 3  0]]

 [[11  4]
  [ 3  2]]

 [[12  3]
  [ 1  4]]

 [[11  2]
  [ 4  3]]]


TorchMetrics is an open-source PyTorch native collection of functional and module-wise metrics for simple performance evaluations. **You can use out-of-the-box implementations** for common metrics such as Accuracy, Recall, Precision, AUROC, RMSE, R² etc. or create your own metric.

## Functional metrics

They are simple python functions that, as input, take torch.tensors and return the corresponding metric as a torch.tensor. Nevertheless there is an actual problem when evaluating the metrics in for binary problems. 

In [48]:
mc_acc  = torchmetrics.functional.accuracy(mc_pred, mc_target, num_classes = 4, average='macro')
bin_acc = torchmetrics.functional.accuracy(binary_pred, binary_target, num_classes = 2, average= None).mean()
print(' Multi-class accuracy: {} \n Binary accuracy: {} \n'.format(mc_acc, bin_acc))

mc_prec  = torchmetrics.functional.precision(mc_pred, mc_target, num_classes = 4, average='macro')
bin_prec = torchmetrics.functional.precision(binary_pred, binary_target, num_classes = 2, average= None)[-1]
print(' Multi-class precision: {} \n Binary precision: {} \n'.format(mc_prec, bin_prec))

mc_sens  = torchmetrics.functional.recall(mc_pred, mc_target, num_classes = 4, average='macro')
bin_sens = torchmetrics.functional.recall(binary_pred, binary_target, num_classes = 2, average= None)[-1]
print(' Multi-class sensitivity: {} \n Binary sensitivity: {} \n'.format(mc_sens, bin_sens))

mc_speci  = torchmetrics.functional.specificity(mc_pred, mc_target, num_classes = 4, average='macro')
bin_speci = torchmetrics.functional.specificity(binary_pred, binary_target, num_classes = 2, average= None)[-1]
print(' Multi-class specificity: {} \n Binary specificity: {} \n'.format(mc_speci, bin_speci))

 Multi-class accuracy: 0.40714287757873535 
 Binary accuracy: 0.6868686676025391 

 Multi-class precision: 0.3761904835700989 
 Binary precision: 0.692307710647583 

 Multi-class sensitivity: 0.40714287757873535 
 Binary sensitivity: 0.8181818127632141 

 Multi-class specificity: 0.8154600858688354 
 Binary specificity: 0.5555555820465088 



## **Module metrics**
The class-based metrics are characterized by having one or more internal metrics states (similar to the parameters of the PyTorch module) that allow them to offer additional functionalities:

- Accumulation of multiple batches
- Automatic synchronization between multiple devices
- Metric arithmetic

In [47]:
torch.manual_seed(42)

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Reseting internal state such that metric ready for new data
metric.reset()

Accuracy on batch 0: 0.20000000298023224
Accuracy on batch 1: 0.20000000298023224
Accuracy on batch 2: 0.4000000059604645
Accuracy on batch 3: 0.30000001192092896
Accuracy on batch 4: 0.10000000149011612
Accuracy on batch 5: 0.10000000149011612
Accuracy on batch 6: 0.10000000149011612
Accuracy on batch 7: 0.20000000298023224
Accuracy on batch 8: 0.20000000298023224
Accuracy on batch 9: 0.10000000149011612
Accuracy on all data: 0.1899999976158142


## Implementing a metric

In [31]:
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        print("Calling update")
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        print("Calling compute")
        return self.correct.float() / self.total

In [33]:
torch.manual_seed(100)
#Inicialize the multi-class randomic values
met_output = torch.randn(20, 2)
met_pred = torch.argmax(met_output, 1)
met_target = torch.randint(0, high = 2, size = (20,))

custom_acc = MyAccuracy()
custom_acc(met_pred, met_target)

Calling update
Calling update
Calling compute


tensor(0.6000)

## Data of this class was find on:

- https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html
- https://torchmetrics.readthedocs.io/en/stable/classification/confusion_matrix.html
- https://github.com/Lightning-AI/metrics/issues/1113
- https://scikit-learn.org/stable/modules/model_evaluation.html