In [None]:
import os
import sys
import numpy as np
import pickle
import pandas as pd
import json
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

root_dir = "../"
# The project packages imported below live at the repository root, one level up
# from this notebook, so make that directory importable first.
sys.path.append(root_dir)
# NOTE(review): `configs` and `datasets` are not referenced in this notebook
# chunk; presumably imported for registration side effects — confirm.
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.data import get_dataset
from ibydmt.classifiers import ZeroShotClassifier
from ibydmt.tester import sweep

# Shared state for every later cell: the CUB config, its test split, the
# ground-truth class index of each test sample (in dataset order), and one
# derived config per backbone in the "data.backbone" sweep.
config = get_config("cub")
dataset = get_dataset(config, train=False)
label = [sample_target for _sample_input, sample_target in dataset]
backbone_configs = sweep(config, sweep_keys=["data.backbone"])

In [None]:
# Aggregate the per-backbone CAV AUCs and keep the attributes whose mean AUC
# across backbones exceeds AUC_THRESHOLD ("good" attributes).
AUC_THRESHOLD = 0.85
# Set to True to (re)write concepts/cub/good_attributes.txt from this cell.
WRITE_GOOD_ATTRIBUTES = False

attribute_path = os.path.join(root_dir, "data", "CUB", "attributes", "attributes.txt")
with open(attribute_path, "r") as f:
    attributes = f.readlines()

auc = []
attribute_idx = None
for backbone_config in backbone_configs:
    backbone_cav_state_path = os.path.join(
        root_dir, "weights", "cub", f"{backbone_config.backbone_name()}_cav.pkl"
    )
    # pickle.load can execute arbitrary code: only load locally produced weights.
    with open(backbone_cav_state_path, "rb") as f:
        backbone_attribute_idx, _unused_state, backbone_auc = pickle.load(f)
    backbone_attribute_idx = np.asarray(backbone_attribute_idx)
    # AUCs are averaged position-wise below, so every backbone must report
    # them for the same attributes in the same order.
    if attribute_idx is None:
        attribute_idx = backbone_attribute_idx
    elif not np.array_equal(attribute_idx, backbone_attribute_idx):
        raise ValueError(
            f"{backbone_cav_state_path} lists attributes in a different order "
            "than the first backbone; AUCs cannot be averaged position-wise"
        )
    auc.append(backbone_auc)

if attribute_idx is None:
    raise ValueError("no backbone configs to aggregate CAV AUCs over")

auc = np.array(auc)  # (n_backbones, n_attributes)
mu_auc = auc.mean(axis=0)
std_auc = auc.std(axis=0)

(good_cav_idx,) = np.where(mu_auc > AUC_THRESHOLD)
good_attribute_idx = attribute_idx[good_cav_idx].tolist()

if WRITE_GOOD_ATTRIBUTES:
    good_attribute_path = os.path.join(root_dir, "concepts", "cub", "good_attributes.txt")
    os.makedirs(os.path.dirname(good_attribute_path), exist_ok=True)
    with open(good_attribute_path, "w") as f:
        for _idx in good_attribute_idx:
            # NOTE(review): assumes attribute ids are 1-based line numbers in
            # attributes.txt and each line reads "<id> <name>" — confirm.
            _text = attributes[_idx - 1].strip().split()[1]
            f.write(f"{_idx} {_text}\n")

In [None]:
# Build, for every test image, a label vector over the good attributes and save
# it as a parquet table. Reads the "image_idx", "attribute_idx" and "label"
# columns of the test annotation table.
test_df = pd.read_parquet(os.path.join(root_dir, "data", "CUB", "test_cub_cav.parquet"))

n_good_attributes = len(good_attribute_idx)
# attribute id -> slot in the per-image label vector (first occurrence wins,
# matching list.index semantics).
good_attribute_slot = {}
for slot, good_id in enumerate(good_attribute_idx):
    good_attribute_slot.setdefault(good_id, slot)

# Single pass over the annotation rows instead of re-filtering test_df once per
# image (quadratic). Loop names deliberately avoid `attribute_idx`, which holds
# the attribute-id array from the CAV cell.
labels_by_image = {}
for row_image_idx, row_attribute_idx, row_label in tqdm(
    zip(test_df["image_idx"], test_df["attribute_idx"], test_df["label"]),
    total=len(test_df),
):
    slot = good_attribute_slot.get(row_attribute_idx)
    if slot is None:
        continue
    image_labels = labels_by_image.setdefault(row_image_idx, [-1] * n_good_attributes)
    # Later rows for the same (image, attribute) overwrite earlier ones.
    image_labels[slot] = row_label

# -1 marks a good attribute with no annotation row for that image.
good_attribute_label = [
    list(labels_by_image.get(image_idx, [-1] * n_good_attributes))
    for image_idx in dataset.image_idx
]

image_good_attribute_label_path = os.path.join(
    root_dir, "concepts", "cub", "image_good_attribute_labels.parquet"
)
good_attribute_df = pd.DataFrame(
    {"image_idx": dataset.image_idx, "label": good_attribute_label}
)
os.makedirs(os.path.dirname(image_good_attribute_label_path), exist_ok=True)
good_attribute_df.to_parquet(image_good_attribute_label_path)

In [None]:
# Per-class zero-shot accuracy for each backbone, then the k classes with the
# highest mean accuracy across backbones (written to good_classes.txt).
class_labels = np.arange(len(dataset.classes))

accuracy = []
for backbone_config in tqdm(backbone_configs):
    prediction_df = ZeroShotClassifier.get_predictions(backbone_config)
    # NOTE(review): assumes column 0 is an identifier and the remaining columns
    # are per-class scores in dataset.classes order — confirm against get_predictions.
    output = prediction_df.values[:, 1:]
    prediction = np.argmax(output, axis=-1)
    assert len(prediction) == len(label), "predictions and labels are misaligned"
    # Pin the label set so row c always corresponds to dataset.classes[c], even
    # if some class never occurs in `label` or `prediction`.
    backbone_confusion_matrix = confusion_matrix(label, prediction, labels=class_labels)
    class_support = backbone_confusion_matrix.sum(axis=1)
    # Classes without test samples get NaN instead of a 0/0 warning.
    backbone_accuracy = np.full(len(class_labels), np.nan)
    np.divide(
        np.diag(backbone_confusion_matrix),
        class_support,
        out=backbone_accuracy,
        where=class_support > 0,
    )
    accuracy.append(backbone_accuracy)

accuracy = np.array(accuracy)  # (n_backbones, n_classes)
mu_accuracy = np.mean(accuracy, axis=0)
std_accuracy = np.std(accuracy, axis=0)

k = 10  # number of top classes to keep
# NaN would otherwise sort to the end and then to the top after reversing.
ranking_score = np.where(np.isnan(mu_accuracy), -np.inf, mu_accuracy)
sorted_idx = np.argsort(ranking_score)[::-1][:k]
sorted_mu_accuracy = mu_accuracy[sorted_idx]
sorted_std_accuracy = std_accuracy[sorted_idx]
sorted_classes = np.array(dataset.classes)[sorted_idx]
print(list(zip(sorted_classes, sorted_mu_accuracy, sorted_std_accuracy)))

good_classes_path = os.path.join(root_dir, "concepts", "cub", "good_classes.txt")
os.makedirs(os.path.dirname(good_classes_path), exist_ok=True)
with open(good_classes_path, "w") as f:
    for class_idx, class_name in zip(sorted_idx, sorted_classes):
        f.write(f"{class_idx} {class_name}\n")

In [None]:
# For each top-k class, sample N_LOCAL_TEST test images whose good attributes
# all have an annotation, and save their dataset positions to JSON keyed by
# class name.
N_LOCAL_TEST = 2  # images sampled per class
LOCAL_TEST_SEED = 0  # fixed so local_test_idx.json is reproducible

good_class_idx = sorted_idx
good_classes = sorted_classes
good_class_set = {int(class_index) for class_index in good_class_idx}

# image_idx -> good-attribute label vector (first row wins, like .values[0]).
attribute_label_by_image = {}
for row_image_idx, row_labels in zip(good_attribute_df["image_idx"], good_attribute_df["label"]):
    attribute_label_by_image.setdefault(row_image_idx, row_labels)

# class name -> dataset positions of images with every good attribute annotated.
# NOTE: positions come from enumerate, which equals dataset.image_idx.index(...)
# as long as image ids are unique.
good_image_idx = {class_name: [] for class_name in good_classes}
for position, (image_idx, image_label) in enumerate(zip(dataset.image_idx, label)):
    if int(image_label) not in good_class_set:
        continue
    # -1 marks a good attribute with no annotation row for this image.
    if -1 in attribute_label_by_image[image_idx]:
        continue
    good_image_idx[dataset.classes[image_label]].append(position)

local_test_rng = np.random.default_rng(LOCAL_TEST_SEED)
local_test_idx = {}
for class_name, positions in good_image_idx.items():
    if len(positions) < N_LOCAL_TEST:
        raise ValueError(
            f"class {class_name!r} has only {len(positions)} test images with all "
            f"good attributes annotated; need {N_LOCAL_TEST}"
        )
    local_test_idx[str(class_name)] = (
        local_test_rng.choice(positions, size=N_LOCAL_TEST, replace=False)
        .squeeze()
        .tolist()
    )

local_test_idx_path = os.path.join(
    root_dir, "results", "cub", "local_cond", "local_test_idx.json"
)
os.makedirs(os.path.dirname(local_test_idx_path), exist_ok=True)
with open(local_test_idx_path, "w") as f:
    json.dump(local_test_idx, f)