In [None]:
import sys
import numpy as np
from sklearn import metrics

root_dir = "../"
sys.path.append(root_dir)
import configs
import datasets
from ibydmt.utils.pcbm import PCBM
from ibydmt.utils.config import get_config
from ibydmt.utils.concept_data import get_dataset_with_concepts
from ibydmt.utils.data import get_dataset
from ibydmt.utils.result import TestingResults
from ibydmt.classifiers import ZeroShotClassifier
from ibydmt.tester import get_test_classes

# Experiment selection: dataset config name and concept granularity.
config_name = "awa2"
concept_type = "class"

# Testing hyperparameters passed as `results_kw` when loading the saved
# SKIT ("global") and c-SKIT ("global_cond") results in the next cell.
skit_kw = {
    "testing.kernel_scale": 0.9,
    "testing.tau_max": 400,
}
cskit_kw = {
    "ckde.scale": 2000,
    "testing.kernel_scale": 0.9,
    "testing.tau_max": 800,
}

config = get_config(config_name)
# Presumably one config per value of `data.backbone` — TODO confirm.
backbone_configs = config.sweep(["data.backbone"])
dataset = get_dataset(config)
classes = dataset.classes

# Classes evaluated below, and their positions in `classes`.
test_classes = get_test_classes(config)
test_classes_idx = list(map(classes.index, test_classes))

In [None]:
def importance_f1(concepts, sorted_concepts):
    """F1 agreement between a concept ranking and the ground-truth importance split.

    Ground truth: the first ``len(concepts) // 2`` entries of ``concepts`` are
    the important concepts and the remaining ones are not. A ranking is scored
    by labelling its top ``len(concepts) // 2`` positions as 1 and the rest as
    0, and comparing that against the true importance of the concept found at
    each position.

    Args:
        concepts: Concepts in ground-truth order (important ones first).
        sorted_concepts: Concepts ordered by a method's estimated importance,
            most important first. Every entry must appear in ``concepts``, and
            it must have ``len(concepts)`` entries (otherwise
            ``metrics.f1_score`` raises on the length mismatch).

    Returns:
        The binary F1 score computed by ``metrics.f1_score``.
    """
    n_concepts = len(concepts)
    n_important = n_concepts // 2

    # Position-based labels: the top `n_important` ranks count as "important".
    # The tail is sized so the list has exactly `n_concepts` entries; the
    # previous `[0] * (n // 2)` dropped one label when `n_concepts` was odd.
    rank_labels = [1] * n_important + [0] * (n_concepts - n_important)

    # First occurrence of each concept, matching `list.index`, in O(1) per lookup.
    first_position = {}
    for position, concept in enumerate(concepts):
        first_position.setdefault(concept, position)

    # True importance of the concept at each rank.
    true_labels = [
        1 if first_position[concept] < n_important else 0
        for concept in sorted_concepts
    ]
    return metrics.f1_score(rank_labels, true_labels)


def _class_accuracy(labels, predictions, class_idx):
    """Per-class accuracy (recall) of ``predictions`` for class ``class_idx``.

    Same value as ``(cm.diagonal() / cm.sum(axis=1))[class_idx]`` for a
    confusion matrix whose rows are indexed by class id. Unlike
    ``metrics.confusion_matrix`` called without ``labels=``, whose rows cover
    only the labels actually present, it cannot select the wrong row when some
    lower-numbered class is missing. Assumes labels and predictions are
    integer class indices into ``classes``.

    Returns NaN when no sample has true label ``class_idx``.
    """
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    is_class = labels == class_idx
    if not is_class.any():
        return np.nan
    return float(np.mean(predictions[is_class] == class_idx))


# Result arrays, indexed by [backbone, test class].
zeroshot_accuracy = np.zeros((len(backbone_configs), len(test_classes)))
pcbm_accuracy = np.zeros((len(backbone_configs), len(test_classes)))

skit_f1 = np.zeros((len(backbone_configs), len(test_classes)))
cskit_f1 = np.zeros((len(backbone_configs), len(test_classes)))
pcbm_f1 = np.zeros((len(backbone_configs), len(test_classes)))
for i, backbone_config in enumerate(backbone_configs):
    skit_results = TestingResults.load(
        backbone_config, "global", concept_type, results_kw=skit_kw
    )
    cskit_results = TestingResults.load(
        backbone_config, "global_cond", concept_type, results_kw=cskit_kw
    )

    # The first column is skipped before the argmax; presumably it is not a
    # class score (e.g. an index/id column) — TODO confirm.
    zeroshot_output = ZeroShotClassifier.get_predictions(backbone_config).values[:, 1:]
    zeroshot_prediction = np.argmax(zeroshot_output, axis=1)

    for j, class_name in enumerate(test_classes):
        class_idx = classes.index(class_name)

        test_concept_dataset = get_dataset_with_concepts(
            backbone_config, train=False, concept_class_name=class_name
        )
        test_semantics = test_concept_dataset.semantics
        test_label = test_concept_dataset.label
        concepts = test_concept_dataset.concepts

        zeroshot_accuracy[i, j] = _class_accuracy(
            test_label, zeroshot_prediction, class_idx
        )

        # PCBM accuracy is averaged over every classifier in its history.
        pcbm = PCBM.load_or_train(backbone_config, concept_class_name=class_name)
        pcbm_accuracy[i, j] = np.mean(
            [
                _class_accuracy(
                    test_label, classifier.predict(test_semantics), class_idx
                )
                for classifier in pcbm.classifier_hist
            ]
        )

        _, skit_sorted_concepts, _, _ = skit_results.sort(class_name, fdr_control=True)
        skit_f1[i, j] = importance_f1(concepts, skit_sorted_concepts)

        _, cskit_sorted_concepts, _, _ = cskit_results.sort(
            class_name, fdr_control=True
        )
        cskit_f1[i, j] = importance_f1(concepts, cskit_sorted_concepts)

        # Rank concepts by |weight| for this class, largest first.
        # Assumes weight columns are aligned with `concepts` — TODO confirm.
        class_weights = np.abs(pcbm.weights()[class_idx])
        pcbm_sorted_idx = np.argsort(class_weights)[::-1]
        pcbm_sorted_concepts = [concepts[k] for k in pcbm_sorted_idx]
        pcbm_f1[i, j] = importance_f1(concepts, pcbm_sorted_concepts)

    # Both SKIT rows report the accuracy of the zero-shot predictions loaded above.
    print(f"Backbone {backbone_config.data.backbone}:")
    print(
        f"\tSKIT results: {zeroshot_accuracy[i].mean():.2%}, f1:"
        f" {skit_f1[i].mean():.2f}"
    )
    print(
        f"\tc-SKIT results: {zeroshot_accuracy[i].mean():.2%}, f1:"
        f" {cskit_f1[i].mean():.2f}"
    )
    print(f"\tPCBM results: {pcbm_accuracy[i].mean():.2%}, f1: {pcbm_f1[i].mean():.2f}")

In [None]:
# Mean (std) of each metric over all backbones and test classes. Every
# statistic uses the full [backbone, class] array; earlier versions sliced
# column 2 for the SKIT f1 std and for both PCBM f1 statistics.
# Both SKIT rows report the zero-shot classifier's accuracy.
summary_rows = [
    ("SKIT", zeroshot_accuracy, skit_f1),
    ("c-SKIT", zeroshot_accuracy, cskit_f1),
    ("PCBM", pcbm_accuracy, pcbm_f1),
]

print("Average:")
for method_name, method_accuracy, method_f1 in summary_rows:
    print(
        f"\t{method_name} results:"
        f" {method_accuracy.mean():.2%} ({method_accuracy.std():.2%}), f1:"
        f" {method_f1.mean():.2f} ({method_f1.std():.2f})"
    )