In [1]:
import torch
from torchmetrics import classification

In [2]:
# Ground-truth label map; label 0 is treated as background below.
targets = torch.tensor([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 0, 3],
])

# Predicted label map with the same layout as `targets`.
preds = torch.tensor([
    [0, 1, 1, 0, 0],
    [0, 2, 1, 0, 0],
    [0, 0, 0, 1, 2],
    [0, 0, 0, 2, 2],
    [3, 0, 0, 0, 3],
])

# Every label that occurs in the ground truth, background (0) included.
unique_classes = torch.unique(targets)

for class_index in unique_classes:
    # Choose the elements this class is scored on, based on the ground truth:
    # class 0 is scored only where the target is background, every other
    # class only where the target is non-background.
    if class_index == 0:
        scored_mask = targets == 0
    else:
        scored_mask = targets != 0

    targets_fields = targets[scored_mask]
    preds_fields = preds[scored_mask]

    print(f"targets_fields: {targets_fields}")
    print(f"preds_fields: {preds_fields}")

    # One-vs-rest view of the selected elements for the current class.
    targets_class = targets_fields.eq(class_index)
    preds_class = preds_fields.eq(class_index)

    print(f"targets_class: {targets_class}")
    print(f"preds_class: {preds_class}")

    # Confusion counts. For class 0 every selected target equals the class,
    # so TN and FP are necessarily 0 there.
    TP = (preds_class & targets_class).sum().item()
    TN = (~(targets_class | preds_class)).sum().item()
    FP = (preds_class & ~targets_class).sum().item()
    FN = (~preds_class & targets_class).sum().item()

    print(f"TP: {TP}")
    print(f"TN: {TN}")
    print(f"FP: {FP}")
    print(f"FN: {FN}")

    # Guard against an empty selection before dividing.
    total = TP + TN + FP + FN
    if total > 0:
        accuracy = (TP + TN) / total
    else:
        accuracy = 0
    print(f"Accuracy for class {class_index}: {accuracy * 100:.2f}%")

targets_fields: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
preds_fields: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0])
targets_class: tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])
preds_class: tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True])
TP: 15
TN: 0
FP: 0
FN: 1
Accuracy for class 0: 93.75%
targets_fields: tensor([1, 1, 1, 1, 2, 2, 2, 2, 3])
preds_fields: tensor([1, 1, 2, 1, 1, 2, 2, 2, 3])
targets_class: tensor([ True,  True,  True,  True, False, False, False, False, False])
preds_class: tensor([ True,  True, False,  True,  True, False, False, False, False])
TP: 3
TN: 4
FP: 1
FN: 1
Accuracy for class 1: 77.78%
targets_fields: tensor([1, 1, 1, 1, 2, 2, 2, 2, 3])
preds_fields: tensor([1, 1, 2, 1, 1, 2, 2, 2, 3])
targets_class: tensor([False, False, False, False,  True,  True,  True,  True, False])
preds_class: ten

In [3]:
# Ground-truth label map; label 0 is treated as background below.
targets = torch.tensor([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 0, 3],
])

# Predicted label map with the same layout as `targets`.
preds = torch.tensor([
    [0, 1, 1, 0, 0],
    [0, 2, 1, 0, 0],
    [0, 0, 0, 1, 2],
    [0, 0, 0, 2, 2],
    [3, 0, 0, 0, 3],
])

# All labels present in the ground truth -- the background class 0 is included.
unique_classes = torch.unique(targets)

# One independent BinaryAccuracy accumulator per label, keyed by a plain Python int.
per_class_accuracies = {
    label: classification.BinaryAccuracy() for label in unique_classes.tolist()
}

# Feed each class's one-vs-rest labels into its own metric.
for class_index in unique_classes:
    print("---")
    print(f"Class index: {class_index}")

    # Class 0 is scored only where the target is background; every other
    # class only where the target is non-background.
    if class_index == 0:
        scored_mask = targets == 0
    else:
        scored_mask = targets != 0

    targets_fields = targets[scored_mask]
    preds_fields = preds[scored_mask]

    print(f"targets_fields: {targets_fields}")
    print(f"preds_fields: {preds_fields}")

    # Binary labels for the current class, cast to float before the update.
    targets_class = targets_fields.eq(class_index).float()
    preds_class = preds_fields.eq(class_index).float()

    print(f"targets_class: {targets_class}")
    print(f"preds_class: {preds_class}")

    # Accumulate this class's statistics in its dedicated metric.
    class_metric = per_class_accuracies[class_index.item()]
    class_metric.update(preds_class, targets_class)

# Report the accumulated accuracy of every class as a percentage.
for class_index, class_accuracy in per_class_accuracies.items():
    class_acc_value = class_accuracy.compute().item()
    class_acc_percent = class_acc_value * 100
    print(f"Accuracy for class {class_index}: {class_acc_percent:.2f}%")


---
Class index: 0
targets_fields: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
preds_fields: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0])
targets_class: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
preds_class: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.])
---
Class index: 1
targets_fields: tensor([1, 1, 1, 1, 2, 2, 2, 2, 3])
preds_fields: tensor([1, 1, 2, 1, 1, 2, 2, 2, 3])
targets_class: tensor([1., 1., 1., 1., 0., 0., 0., 0., 0.])
preds_class: tensor([1., 1., 0., 1., 1., 0., 0., 0., 0.])
---
Class index: 2
targets_fields: tensor([1, 1, 1, 1, 2, 2, 2, 2, 3])
preds_fields: tensor([1, 1, 2, 1, 1, 2, 2, 2, 3])
targets_class: tensor([0., 0., 0., 0., 1., 1., 1., 1., 0.])
preds_class: tensor([0., 0., 1., 0., 0., 1., 1., 1., 0.])
---
Class index: 3
targets_fields: tensor([1, 1, 1, 1, 2, 2, 2, 2, 3])
preds_fields: tensor([1, 1, 2, 1, 1, 2, 2, 2, 3])
targets_class: tensor([0., 0., 0., 0., 0., 0., 0., 0., 1.])
pred

# simulate training cycle

In [7]:
# Ground-truth label map; label 0 is treated as background by the update helper.
targets = torch.tensor([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 0, 3]
])

# Predicted label map with the same layout as `targets`.
preds = torch.tensor([
    [0, 1, 1, 0, 0],
    [0, 2, 1, 0, 0],
    [0, 0, 0, 1, 2],
    [0, 0, 0, 2, 2],
    [3, 0, 0, 0, 3]
])

# Derive the label set from THIS cell's targets instead of relying on a
# `unique_classes` left behind by an earlier cell, so the cell does not depend
# on hidden kernel state and the metrics always match the data defined above.
unique_classes = torch.unique(targets)

# Fresh metric objects: one BinaryAccuracy accumulator per label, keyed by int.
per_class_accuracies = {class_index.item(): classification.BinaryAccuracy() for class_index in unique_classes}

def __update_per_class_accuracy(preds, targets, per_class_accuracies):
    """Feed one batch of predictions into every per-class accuracy metric.

    Each class is scored one-vs-rest on a subset of elements chosen from the
    ground truth: class 0 only where ``targets == 0``, every other class only
    where ``targets != 0``. A metric is updated only when the selected targets
    contain its class at least once; otherwise it is left untouched.

    Args:
        preds: Tensor of predicted class labels; it is indexed with a boolean
            mask built from ``targets``.
        targets: Tensor of ground-truth class labels.
        per_class_accuracies: Mapping from class label to a metric object
            exposing ``update(preds, targets)``. Its ``_update_count``
            attribute is read for the progress message.

    Returns:
        None. The metric objects are updated in place and progress is printed.
    """
    for class_index, class_accuracy in per_class_accuracies.items():
        # Elements left out for this class (reported below as "Background mask").
        excluded = (targets != 0) if class_index == 0 else (targets == 0)
        selected = ~excluded

        # Binary one-vs-rest labels, cast to float before being handed to update().
        targets_class = (targets[selected] == class_index).float()
        preds_class = (preds[selected] == class_index).float()

        print(f"Class index: {class_index}, Background mask: {excluded}")
        print(f"Preds class: {preds_class}, Targets class: {targets_class}")

        # Nothing to score when the class never occurs in the selected targets.
        if not targets_class.any():
            continue

        class_accuracy.update(preds_class, targets_class)
        print(f"Per-class accuracy for class {class_index} updated. Update count: {class_accuracy._update_count}")

# Apply the same batch twice so the metric state accumulates over repeated
# updates, mimicking successive steps of a training cycle.
for _ in range(2):
    __update_per_class_accuracy(
        preds=preds,
        targets=targets,
        per_class_accuracies=per_class_accuracies,
    )

# Report the accuracy each metric accumulated over all updates, as a percentage.
for class_index, class_accuracy in per_class_accuracies.items():
    class_acc_value = class_accuracy.compute().item()
    class_acc_percent = class_acc_value * 100
    print(f"Accuracy for class {class_index}: {class_acc_percent:.2f}%")

Class index: 0, Background mask: tensor([[False,  True,  True, False, False],
        [False,  True,  True, False, False],
        [False, False, False,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True]])
Preds class: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.]), Targets class: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Per-class accuracy for class 0 updated. Update count: 1
Class index: 1, Background mask: tensor([[ True, False, False,  True,  True],
        [ True, False, False,  True,  True],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False]])
Preds class: tensor([1., 1., 0., 1., 1., 0., 0., 0., 0.]), Targets class: tensor([1., 1., 1., 1., 0., 0., 0., 0., 0.])
Per-class accuracy for class 1 updated. Update count: 1
Class index: 2, Background mask: tensor([[ True, False, False,  True,  True]