In [None]:
import sys

sys.path.append('../..')
import json
import os
from ILLUME.imagenet_hierarchy import direct_subclasses, hierarchy_attr_to_classes, super_class_dir
from collections import defaultdict
import numpy as np


import scipy.stats

# Variant of the ImageNet hierarchy used as the key into every
# ILLUME.imagenet_hierarchy lookup table in this notebook.
hierarchy_name = "extended_v1"


def compute_acc_and_ent(count_dict, set_dir, q_cluster, q_cls, *, skip_super_classes=("animal.n.01",)):
    """Compute top-1/top-5 accuracy and per-superclass prediction entropy for one query.

    Args:
        count_dict: Mapping with keys ``'top_1'`` and ``'top_5'``; each maps a
            predicted class name to the number of times it was predicted.
        set_dir: Set direction the query was projected onto. Not used in the
            computation; kept for call-site compatibility.
        q_cluster: Hierarchy node whose leaf classes count as a correct
            prediction. A name that is not a key of the hierarchy table is
            treated as a single leaf class.
        q_cls: Query class. Its superclasses (``super_class_dir``) select the
            leaf sets over which entropies are computed.
        skip_super_classes: Superclasses of ``q_cls`` for which no entropy is
            computed. Defaults to the root ``"animal.n.01"``.

    Returns:
        ``(acc_top_1, acc_top_5, ent)``:

        - ``acc_top_1`` / ``acc_top_5``: fraction of top-1 / top-5 predictions
          that fall inside ``q_cluster``.
        - ``ent``: maps each non-skipped superclass of ``q_cls`` to the entropy
          of the top-1 prediction probabilities restricted to that
          superclass's leaves.

        The probabilities are normalised over *all* top-1 predictions and are
        not re-normalised within the leaf set.

    Raises:
        ValueError: If the top-1 or top-5 counts sum to zero, which would make
            the accuracies undefined.
        KeyError: If a superclass of ``q_cls`` is missing from the hierarchy table.
    """
    hierarchy = hierarchy_attr_to_classes[hierarchy_name]

    top_1_dir = count_dict['top_1']
    top_5_dir = count_dict['top_5']

    total_count = sum(top_1_dir.values())
    total_count_top_5 = sum(top_5_dir.values())
    if total_count == 0 or total_count_top_5 == 0:
        raise ValueError(
            f"No predictions recorded for query class {q_cls!r} "
            f"(top-1 total: {total_count}, top-5 total: {total_count_top_5})"
        )

    # Empirical distribution of top-1 predictions over predicted classes.
    pred_prob_dict = {img_cls: float(pred_ct) / total_count for img_cls, pred_ct in top_1_dir.items()}

    # Inner nodes expand to their leaves; anything else is its own leaf.
    # Use a set for O(1) membership tests below.
    all_cluster_classes = set(hierarchy[q_cluster] if q_cluster in hierarchy else [q_cluster])

    count_correct_top_1 = sum(value for key, value in top_1_dir.items() if key in all_cluster_classes)
    acc_top_1 = float(count_correct_top_1) / float(total_count)

    count_correct_top_5 = sum(value for key, value in top_5_dir.items() if key in all_cluster_classes)
    acc_top_5 = float(count_correct_top_5) / float(total_count_top_5)
    print(f"count_correct_top_1: {count_correct_top_1} ; count_correct_top_5: {count_correct_top_5}")

    ent = {}
    for super_class in super_class_dir[hierarchy_name][q_cls]:
        if super_class in skip_super_classes:
            continue
        all_leafs = hierarchy[super_class]
        # Leaves that were never predicted contribute probability 0.
        pred_prob_list_filtered = [pred_prob_dict.get(key, 0.0) for key in all_leafs]
        ent[super_class] = categorical_entropy(pred_prob_list_filtered)

    return acc_top_1, acc_top_5, ent



def categorical_entropy(probs, base=np.e):
    """Return ``-sum(p * log_base(p))`` over the positive entries of *probs*.

    Zero (and negative) entries are dropped before taking logarithms, so
    ``0 * log 0`` contributes nothing. The input is not re-normalised. For a
    vector that does not sum to 1, the result is the partial sum over its
    positive entries.

    Args:
        probs: Sequence of probabilities (anything ``np.array`` accepts).
        base: Logarithm base. ``np.e`` (the default) gives nats; ``2`` gives bits.

    Returns:
        The entropy as a NumPy scalar. Empty, all-zero, or one-hot input
        yields ``-0.0``.
    """
    arr = np.array(probs)
    positive = arr[arr > 0]  # log(0) is undefined; discard those terms
    weighted_logs = positive * np.log(positive)
    return -np.sum(weighted_logs / np.log(base))

# Top-level set direction of the evaluated animal hierarchy.
set_direction_attr = "animal.n.01"

# Model whose prediction-count JSON files are evaluated; must be a key of MODEL_OUTPUT_DIRS.
model_name = "illume"

# Directory holding the "<class>_predicted_cls_count_per_dir_per_q.json" files for each model.
MODEL_OUTPUT_DIRS = {
    "Set Learner": "/export/scratch/ra48gaq/set_rf_dit/outputs/eval_paper",
    "Visual Prompting v1": "/export/scratch/ra48gaq/set_rf_dit/outputs/eval_paper/vis_prompt",
    "Visual Prompting v2": "/export/scratch/ra48gaq/set_rf_dit/outputs/eval_paper/vis_prompt_v2",
    "Set Learner Sketch qs": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/ours_sketch_qs",
    "Set Learner Sketch ctx": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/ours_sketch_ctx",
    "Set Learner Sketch both": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/ours_sketch_both",
    "Visual Prompting Sketch qs": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/vis_prompt_sketch_qs",
    "Visual Prompting Sketch ctx": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/vis_prompt_sketch_ctx",
    "Visual Prompting Sketch both": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/vis_prompt_sketch_both",
    "Visual Prompting v3": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/vis_prompt_v3",
    "Ablation FFN": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/ours_ablation_ffn",
    "Ablation 2dirs fixes setsize": "/home/hpc/v104dd/v104dd24/code/diffusion/outputs/ours_ablation_2dirs_fix_setsize",
    "Visual Prompting 21k ctx": "/anvme/workspace/v104dd11-compvis_data/kolja/set_learner/outputs/vis_prompt_21k_ctx",
    "illume": "/export/home/ra48gaq/code/ILLUME_plus/outputs",
}

try:
    json_dir = MODEL_OUTPUT_DIRS[model_name]
except KeyError:
    # Fail fast: nothing downstream can run without json_dir.
    raise ValueError(
        f"Error: No valid model name provided: {model_name!r} "
        f"(expected one of {sorted(MODEL_OUTPUT_DIRS)})"
    ) from None

# Accumulators filled by the loop below:
# - one accuracy per (query class, set direction) pair;
# - the same accuracies grouped by set direction;
# - one entropy ratio per (set direction, next more specific set direction) pair.
all_animal_classes = list(hierarchy_attr_to_classes[hierarchy_name]['animal.n.01'])

acc_list_top_1 = []
acc_list_top_5 = []
accs_per_set_dir_top_1 = defaultdict(list)
accs_per_set_dir_top_5 = defaultdict(list)

entropy_ratios = []

for animal_class in all_animal_classes:
    count_file = os.path.join(json_dir, f"{animal_class}_predicted_cls_count_per_dir_per_q.json")
    with open(count_file, 'r', encoding='utf-8') as f:
        # Presumably {set_direction: {query_class: {'top_1': {...}, 'top_5': {...}}}},
        # judging by how it is indexed below -- TODO confirm against the writer.
        count_dict = json.load(f)

    # These depend only on the query class, not on the set direction.
    super_classes = set(super_class_dir[hierarchy_name][animal_class])
    set_directions = set(count_dict)

    entropies = {}
    for set_direction, count_dict_dir in count_dict.items():
        # Sub-cluster of set_direction that contains the query class (or the
        # query class itself when none matches). Sort so the choice is
        # reproducible if several match: raw set iteration order depends on
        # the string hash seed.
        q_sub_candidates = sorted(super_classes & set(direct_subclasses[hierarchy_name][set_direction]))
        q_sub_cls = q_sub_candidates[0] if q_sub_candidates else animal_class

        acc_top_1, acc_top_5, ent = compute_acc_and_ent(count_dict_dir[animal_class], set_dir=set_direction, q_cluster=q_sub_cls, q_cls=animal_class)
        print(f"For query class {animal_class} projected onto {set_direction} direction, top-1 acc for subcluster {q_sub_cls} is {acc_top_1 * 100}% ;  entropy is {ent}")
        acc_list_top_1.append(acc_top_1)
        acc_list_top_5.append(acc_top_5)
        accs_per_set_dir_top_1[set_direction].append(acc_top_1)
        accs_per_set_dir_top_5[set_direction].append(acc_top_5)

        entropies[set_direction] = ent

    ### Calculate entropy ratios
    print(f"entropy dict for animal class {animal_class} is {entropies}")
    for set_direction in count_dict:
        # Compare to the next more specific set direction present in the file,
        # if there is one (sorted for a reproducible choice, as above).
        direct_sub_candidates = sorted(set(direct_subclasses[hierarchy_name][set_direction]) & set_directions)
        if not direct_sub_candidates:
            continue
        direct_sub = direct_sub_candidates[0]
        print(f"Comparing entropies for animal class {animal_class} projected onto {set_direction} and {direct_sub}")
        entropy_abstract = entropies[set_direction][direct_sub]
        entropy_specific = entropies[direct_sub][direct_sub]
        if entropy_specific == 0:
            # The ratio is undefined; skip it rather than record inf/nan,
            # which would corrupt the mean entropy ratio.
            print(f"Skipping entropy ratio for {set_direction} vs {direct_sub}: specific entropy is zero")
            continue
        entropy_ratio = entropy_abstract / entropy_specific
        print(f"Entropy ratio is {entropy_ratio}")
        entropy_ratios.append(max(entropy_ratio, 0))

def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else float("nan")


# Mean over every (query class, set direction) pair.
mean_acc_top_1 = _mean(acc_list_top_1)
mean_acc_top_5 = _mean(acc_list_top_5)
# Mean within each set direction, then averaged with every direction weighted equally.
mean_acc_per_set_dir_top_1 = {set_dir: _mean(acc_list_per_dir) for set_dir, acc_list_per_dir in accs_per_set_dir_top_1.items()}
mean_acc_per_set_dir_top_5 = {set_dir: _mean(acc_list_per_dir) for set_dir, acc_list_per_dir in accs_per_set_dir_top_5.items()}
mean_acc_avg_set_top_1 = _mean(mean_acc_per_set_dir_top_1.values())
mean_acc_avg_set_top_5 = _mean(mean_acc_per_set_dir_top_5.values())
# Mean entropy ratio (NaN if no ratio could be computed).
mean_ent_ratio = _mean(entropy_ratios)

print(f" ===== RESULTS FOR MODEL {model_name} =======")
print(f" ===== MEAN ACCURACY avg over set-query combinations: TOP-1: {mean_acc_top_1 * 100}% =======")
print(f" ===== MEAN ACCURACY avg over set-query combinations: TOP-5: {mean_acc_top_5 * 100}% =======")
print(f" ===== MEAN ACCURACY avg over set directions: TOP-1: {mean_acc_avg_set_top_1 * 100}% =======")
print(f" ===== MEAN ACCURACY avg over set directions: TOP-5: {mean_acc_avg_set_top_5 * 100}% =======")
print(f" ===== MEAN ENTROPY RATIO: {mean_ent_ratio} =======")

SyntaxError: EOL while scanning string literal (2823154798.py, line 106)