In [1]:
import torch
from three_class_metrics import *

In [2]:
def compute_conf_mat(out, y, num_labels):
    """Build a ``num_labels x num_labels`` confusion matrix from label tensors.

    Entry ``[g, p]`` counts the positions where the ground truth ``y`` equals
    ``g`` and the prediction ``out`` equals ``p`` (rows = ground truth,
    columns = prediction). Values outside ``[0, num_labels)`` are not counted.

    Args:
        out: Tensor of predicted class indices, any shape (integer or float
            values are both accepted).
        y: Tensor of ground-truth class indices, normally the same shape as
            ``out`` (a single-element tensor broadcasts against ``out``).
        num_labels: Number of classes.

    Returns:
        A float32 tensor of shape ``(num_labels, num_labels)`` on ``out.device``.
    """
    # Flatten so batched / per-pixel label maps are handled like 1-D label
    # vectors; otherwise multi-dimensional inputs broadcast against the label
    # axis below and either raise or silently produce wrong counts.
    out = out.reshape(-1)
    y = y.reshape(-1).to(out.device)
    labels = torch.arange(num_labels, device=out.device)
    # pred_hits[p, n] = (out[n] == p)          -> shape (L, N)
    pred_hits = out == labels[:, None]
    # gt_hits[g, 0, n] = (y[n] == g)           -> shape (L, 1, N)
    gt_hits = y == labels[:, None, None]
    # Broadcast to (L, L, N) with [g, p, n] = (y[n] == g) & (out[n] == p),
    # then count matches over the sample axis.
    return (gt_hits & pred_hits).sum(dim=2, dtype=torch.float32)

In [3]:
# Prediction 0 (nonmine) vs. target 0 (nonmine): the count lands at [0, 0].
i = torch.zeros(1)
t = torch.zeros(1)
compute_conf_mat(i, t, num_labels=3)

tensor([[1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [4]:
# Prediction 1 (lc) vs. target 0 (nonmine): the count lands at [0, 1]
# (row = target, column = prediction).
i = torch.ones(1)
t = torch.zeros(1)
compute_conf_mat(i, t, num_labels=3)

tensor([[0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [5]:
# Prediction 2 (mine) vs. target 0 (nonmine): the count lands at [0, 2].
i = torch.full((1,), 2.0)
t = torch.zeros(1)
compute_conf_mat(i, t, num_labels=3)

tensor([[0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [6]:
# Second sweep: prediction fixed at 0 (nonmine), target varies.
# Target 0 (nonmine): the count lands at [0, 0].
i = torch.zeros(1)
t = torch.zeros(1)
compute_conf_mat(i, t, num_labels=3)

tensor([[1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [7]:
# Prediction 0 (nonmine) vs. target 1 (lc): the count lands at [1, 0].
i = torch.zeros(1)
t = torch.ones(1)
compute_conf_mat(i, t, num_labels=3)

tensor([[0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.]])

In [8]:
# Prediction 0 (nonmine) vs. target 2 (mine): the count lands at [2, 0].
i = torch.zeros(1)
t = torch.full((1,), 2.0)
compute_conf_mat(i, t, num_labels=3)

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.]])

In [9]:
# Reproducible random check: 100 (prediction, target) pairs over classes 0..2.
torch.manual_seed(42)
i = torch.randint(low=0, high=3, size=(100,))  # predictions (drawn first)
t = torch.randint(low=0, high=3, size=(100,))  # targets (drawn second)
conf_mat = compute_conf_mat(i, t, num_labels=3)
conf_mat

tensor([[ 5., 17.,  9.],
        [12.,  7., 14.],
        [10., 16., 10.]])

In [10]:
# Support-weighted precision / recall / F1 with one class excluded.
# The excluded class is removed by zeroing its row (ground truth) and column
# (prediction), so it contributes no hits, no errors, and gets zero weight.
IGNORED_CLASS = 1
eps = torch.tensor(1e-6)
# `.cpu()` returns the very same tensor when it already lives on the CPU, so
# clone before masking; otherwise the in-place writes below would also modify
# the tensor displayed (and cached as output) by the previous cell.
conf_mat = conf_mat.cpu().clone()
conf_mat[IGNORED_CLASS] = 0
conf_mat[:, IGNORED_CLASS] = 0

gt_count = conf_mat.sum(dim=1)    # row sums: ground-truth support per class
pred_count = conf_mat.sum(dim=0)  # column sums: predictions per class
total = conf_mat.sum()
true_pos = torch.diag(conf_mat)
# torch.max(..., eps) keeps every denominator strictly positive.
precision = true_pos / torch.max(pred_count, eps)
recall = true_pos / torch.max(gt_count, eps)
f1 = (2 * precision * recall) / torch.max(precision + recall, eps)

# Weight each class by its ground-truth share; guarded like the other
# divisions so an all-zero matrix yields 0 instead of NaN.
weights = gt_count / torch.max(total, eps)
weighted_precision = (weights * precision).sum()
weighted_recall = (weights * recall).sum()
# NOTE(review): this is the harmonic mean of the weighted precision and recall,
# not the support-weighted mean of the per-class `f1` computed above (the
# convention of e.g. scikit-learn's average='weighted'). Confirm which
# definition 'avg_f1' is meant to report.
weighted_f1 = ((2 * weighted_precision * weighted_recall) / torch.max(
    weighted_precision + weighted_recall, eps))

metrics = {
    'avg_precision': weighted_precision.item(),
    'avg_recall': weighted_recall.item(),
    'avg_f1': weighted_f1.item()
}

In [11]:
# Confusion matrix after the previous cell zeroed row 1 and column 1 (class 1 excluded).
conf_mat

tensor([[ 5.,  0.,  9.],
        [ 0.,  0.,  0.],
        [10.,  0., 10.]])

In [12]:
# Per-class precision: diagonal / column sums (eps-guarded); entry 1 is 0 because class 1 was masked.
precision

tensor([0.3333, 0.0000, 0.5263])

In [13]:
recall

tensor([0.3571, 0.0000, 0.5000])