In [None]:
import sys
import numpy as np
from sklearn import metrics

# Put the parent directory on the import path before the project imports that
# follow (presumably `configs`, `datasets` and `ibydmt` live there -- confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to `sys.path` each time.
root_dir = "../"
if root_dir not in sys.path:
    sys.path.append(root_dir)
import configs
import datasets
from ibydmt.utils.pcbm import PCBM
from ibydmt.utils.config import get_config
from ibydmt.utils.concept_data import get_dataset_with_concepts
from ibydmt.utils.data import get_dataset
from ibydmt.utils.result import TestingResults
from ibydmt.classifiers import ZeroShotClassifier
from ibydmt.tester import get_test_classes

# Which experiment to analyse, and the hyper-parameter values passed as
# `results_kw` when the saved testing results are loaded in the next cell.
config_name = "awa2"
results_kw = {
    "ckde.scale": 2000,
    "testing.kernel_scale": 0.9,
    "testing.tau_max": 800,
}

config = get_config(config_name)
# Swept over `data.backbone`; the next cell iterates over the result and reads
# `data.backbone` from each element (presumably one config per backbone).
backbone_configs = config.sweep(["data.backbone"])

dataset = get_dataset(config)
classes = dataset.classes

# Classes that were tested, together with their positions in `classes`.
test_classes = get_test_classes(config)
test_classes_idx = list(map(classes.index, test_classes))

In [None]:
def _class_recall(y_true, y_pred, class_idx, n_classes):
    """Return the fraction of samples of class `class_idx` that were predicted correctly.

    The confusion matrix is built with an explicit label list so that row
    `class_idx` always refers to class `class_idx`. Without `labels`,
    scikit-learn only keeps the labels that occur in `y_true` or `y_pred`, so
    rows shift (or the lookup raises) as soon as one class is missing from both.

    Assumes labels and predictions are integer indices into `classes`, which is
    what indexing the matrix with `classes.index(...)` already relied on.
    The result is NaN when the class has no samples in `y_true`.
    """
    confusion = metrics.confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
    return confusion[class_idx, class_idx] / confusion[class_idx].sum()


def _top_half_scores(ranked_concepts, concepts, gt_importance):
    """Score a concept ranking against the ground-truth importance labels.

    The first half of the ranking is treated as "predicted important" and the
    second half as "predicted unimportant"; this is compared with the
    ground-truth label of the concept found at each rank.

    Returns `[precision, recall, f1]` for the positive (important) label.
    """
    half = len(concepts) // 2
    predicted_importance = [1] * half + [0] * half
    true_importance = [gt_importance[concepts.index(c)] for c in ranked_concepts]
    precision, recall, f1_score, _ = metrics.precision_recall_fscore_support(
        true_importance,
        predicted_importance,
        average="binary",
    )
    return [precision, recall, f1_score]


for backbone_config in backbone_configs:
    results = TestingResults.load(
        backbone_config, "global_cond", "class", results_kw=results_kw
    )

    # Column 0 of the stored predictions is dropped; the arg-max over the
    # remaining columns is used as the predicted class index.
    zero_shot_output = ZeroShotClassifier.get_predictions(backbone_config).values[:, 1:]
    zero_shot_prediction = np.argmax(zero_shot_output, axis=1)

    # Per-test-class accuracy of each model.
    zero_shot_accuracy = np.zeros(len(test_classes))
    pcbm_accuracy = np.zeros(len(test_classes))

    # One row per test class; columns are precision, recall, F1.
    cskit_performance = np.zeros((len(test_classes), 3))
    pcbm_performance = np.zeros((len(test_classes), 3))
    for j, class_name in enumerate(test_classes):
        class_idx = classes.index(class_name)

        test_concept_dataset = get_dataset_with_concepts(
            backbone_config, train=False, concept_class_name=class_name
        )
        test_semantics = test_concept_dataset.semantics
        test_label = test_concept_dataset.label
        concepts = test_concept_dataset.concepts

        # Ground truth: the first half of `concepts` is labelled important (1)
        # and the second half unimportant (0).
        # NOTE(review): this relies on the ordering of the concept list and on
        # an even number of concepts -- confirm against the concept dataset.
        half = len(concepts) // 2
        gt_importance = [1] * half + [0] * half

        zero_shot_accuracy[j] = _class_recall(
            test_label, zero_shot_prediction, class_idx, len(classes)
        )

        # Average the class accuracy over every classifier kept in the PCBM's
        # history.
        pcbm = PCBM.load_or_train(backbone_config, concept_class_name=class_name)
        pcbm_class_accuracy = np.array(
            [
                _class_recall(
                    test_label,
                    classifier.predict(test_semantics),
                    class_idx,
                    len(classes),
                )
                for classifier in pcbm.classifier_hist
            ]
        )
        pcbm_accuracy[j] = np.mean(pcbm_class_accuracy)

        # Rank concepts for the PCBM by normalized absolute weight of this
        # class, largest first.
        weights = pcbm.weights()
        class_weights = np.abs(weights[class_idx])
        class_weights /= class_weights.sum()

        sorted_idx = np.argsort(class_weights)[::-1]
        sorted_concepts = [concepts[idx] for idx in sorted_idx]
        pcbm_performance[j] = _top_half_scores(sorted_concepts, concepts, gt_importance)

        # Rank concepts with the saved testing results (FDR control enabled);
        # only the sorted concept names are needed here.
        _, sorted_concepts, _, _ = results.sort(class_name, fdr_control=True)
        cskit_performance[j] = _top_half_scores(sorted_concepts, concepts, gt_importance)

    # Averages over the test classes. The accuracy reported on the c-SKIT line
    # is the accuracy of the zero-shot classifier.
    print(f"Backbone {backbone_config.data.backbone}:")
    print(
        f"\tPCBM results: {pcbm_accuracy.mean():.2%}, f1:"
        f" {pcbm_performance[:, 2].mean():.2f}"
    )
    print(
        f"\tc-SKIT accuracy: {zero_shot_accuracy.mean():.2%}, f1:"
        f" {cskit_performance[:, 2].mean():.2f}"
    )