In [1]:
import torch

from metric_utils.metrics import CategoricalAccuracy as CA1
from mlu.metrics import CategoricalAccuracy as CA2

In [2]:
# Axis along which the class scores are reduced; shared by the metric
# configuration here and by the argmax calls in the data-preparation cell.
dim = 1

# Two implementations of categorical accuracy that are compared below.
ca1 = CA1()
# Pass `dim` rather than repeating the literal so the metric and the
# argmax calls cannot silently drift apart if `dim` is changed.
ca2 = CA2(vector_input=False, vector_target=False, dim=dim)

In [3]:
# Random scores for 16 samples over 10 classes.
pred = torch.rand(16, 10)

# One-hot targets: row r carries its single 1.0 in column r mod n_columns.
target = torch.zeros(16, 10)
rows = torch.arange(target.shape[0])
target[rows, rows % target.shape[1]] = 1.0

# Collapse both score matrices to class indices along `dim`.
pred = torch.argmax(pred, dim=dim)
target = torch.argmax(target, dim=dim)

In [4]:
# Show the predicted and the expected class indices, one tensor per line.
print(pred, target, sep="\n")

tensor([4, 9, 9, 1, 7, 3, 4, 1, 6, 1, 8, 1, 0, 4, 8, 2])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5])


In [5]:
# Evaluate both accuracy implementations on the same class indices.
r1, r2 = ca1(pred, target), ca2(pred, target)

# Print each result followed by the outcome of their comparison.
print(r1, r2, r1 == r2, sep="\n")

# The two implementations are expected to return the same value.
assert r1 == r2, "Not eq"

tensor(0.0625)
tensor(0.0625)
tensor(True)


In [6]:
from metric_utils.metrics import FScore
from mlu.metrics import FScore as FScoreMLU

In [7]:
# Fix the RNG so the random predictions below are reproducible.
seed = 1234
torch.manual_seed(seed)

# The two F-score implementations under comparison, both reducing over dim 1.
f1, f2 = FScore(dim=1), FScoreMLU(dim=1)

# Binary predictions: uniform noise thresholded at 0.5.
pred = (torch.rand(4, 10) >= 0.5).float()
# NOTE(review): thresholding an all-zero tensor at 0.5 always yields zeros, so
# the metrics below are only compared on targets without any positive label.
# If random binary targets were intended, this should probably start from
# torch.rand - confirm before changing.
target = (torch.zeros(4, 10) >= 0.5).float()

def assert_eq(a, b):
    """Assert that tensors ``a`` and ``b`` are equal element by element.

    The comparison is exact (``a.eq(b)``), with no tolerance. On any
    mismatch an ``AssertionError`` is raised whose message contains both
    tensors.
    """
    matches = a.eq(b)
    assert matches.all(), f"Not eq: {a} != {b}"

# F-score from each implementation (first the reference, then the MLU one).
s1, s2 = f1(pred, target), f2(pred, target)

# Recall, through each implementation's own accessor.
r1, r2 = f1.recall_func(pred, target), f2.recall(pred, target)

# Precision, through each implementation's own accessor.
p1, p2 = f1.precision_func(pred, target), f2.precision(pred, target)

# Check recall, precision and F-score pairwise, then the mean F-score.
for lhs, rhs in (
    (r1, r2),
    (p1, p2),
    (s1, s2),
    (torch.mean(s1), torch.mean(s2)),
):
    assert_eq(lhs, rhs)

In [8]:
# Marker printed once every preceding cell has run without a failed assertion.
print("Done")

Done
