# In [26]:
import torch
import numpy as np
import pickle
from torch.nn.functional import cosine_similarity
import json
import csv

def load_pkl(file_path):
    """Return the object deserialized from the pickle file at ``file_path``.

    The file is opened in binary mode and closed before returning.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing, so
    this should only be pointed at files from a trusted source.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

def load_json(json_file):
    """Parse the JSON document at ``json_file`` and return the decoded value.

    The file is opened in text mode with the platform default encoding and is
    closed before returning.
    """
    with open(json_file, 'r') as handle:
        return json.load(handle)

def find_closest_and_furthest_labels(target_label, label_embeddings, label_names, top_ns, num_closest=20):
    """Rank every other label by cosine similarity to ``target_label``.

    For each ``n`` in ``top_ns`` the embedding stored under ``f'top{n}'`` for
    the target label is compared with the same-keyed embedding of every other
    entry in ``label_embeddings``.

    Args:
        target_label: Key of the reference label in ``label_embeddings``; it
            is excluded from the comparison.
        label_embeddings: Mapping ``label -> {f'top{n}': embedding}``. Each
            embedding must support ``.unsqueeze(0)`` (presumably a 1-D torch
            tensor -- confirm against whatever produced the pickle).
        label_names: Container indexed by label that yields its display name.
        top_ns: Iterable of ``n`` values selecting which embeddings to use.
        num_closest: How many labels to report at each end of the ranking.

    Returns:
        A dict keyed by ``f'top{n}'``. Each value holds:
          * ``'closest'``: up to ``num_closest`` ``(label, name, similarity)``
            tuples, most similar first.
          * ``'furthest'``: up to ``num_closest`` tuples taken from the tail
            of the same descending ranking, so the least similar label is
            the last element.
          * ``'average_similarity'``: mean similarity over all other labels
            (NaN when there are none).
        When fewer than ``2 * num_closest`` other labels exist the two lists
        overlap.

    Raises:
        ValueError: If ``num_closest`` is negative.
        KeyError: If ``target_label`` or a ``f'top{n}'`` key is missing.
    """
    if num_closest < 0:
        raise ValueError(f"num_closest must be non-negative, got {num_closest}")

    results = {}

    for top_n in top_ns:
        key = f'top{top_n}'
        # Batch dimension is added once here instead of on every comparison.
        target_embedding = label_embeddings[target_label][key].unsqueeze(0)

        similarities = [
            (label, cosine_similarity(target_embedding, embeddings[key].unsqueeze(0)).item())
            for label, embeddings in label_embeddings.items()
            if label != target_label
        ]
        # Stable sort: labels with equal similarity keep their mapping order.
        similarities.sort(key=lambda pair: pair[1], reverse=True)

        # An explicit start index is required: ``similarities[-0:]`` would
        # return the whole list instead of an empty one.
        tail_start = max(len(similarities) - num_closest, 0)

        closest_labels = [
            (label, label_names[label], similarity)
            for label, similarity in similarities[:num_closest]
        ]
        furthest_labels = [
            (label, label_names[label], similarity)
            for label, similarity in similarities[tail_start:]
        ]

        if similarities:
            average_similarity = np.mean([similarity for _, similarity in similarities])
        else:
            # np.mean([]) also yields NaN but emits a RuntimeWarning.
            average_similarity = np.float64(np.nan)

        results[key] = {
            'closest': closest_labels,
            'furthest': furthest_labels,
            'average_similarity': average_similarity
        }

    return results

def save_results_as_csv(results, target_label, target_label_name, top_ns, file_path_base):
    """Write the similarity report for one target label to a CSV file.

    The output path is ``file_path_base`` directly concatenated with
    ``label_<target_label>_<target_label_name>.csv``, so ``file_path_base``
    must already end with a path separator if it names a directory. An
    existing file at that path is overwritten.

    Layout: one title row, then for every ``n`` in ``top_ns`` a ``Top n`` row,
    a "Closest labels" table, a "Furthest labels" table and an
    "Average similarity" row. Similarities are formatted to four decimals.

    Args:
        results: Mapping ``f'top{n}' -> {'closest', 'furthest',
            'average_similarity'}`` where the two lists hold
            ``(label, label_name, similarity)`` tuples.
        target_label: Identifier of the label the report is about.
        target_label_name: Display name of that label.
        top_ns: Iterable of ``n`` values selecting sections of ``results``.
        file_path_base: Prefix prepended to the generated file name.
    """
    sections = (("Closest labels", 'closest'), ("Furthest labels", 'furthest'))
    column_titles = ["Label ID", "Label Name", "Similarity"]
    file_path = f"{file_path_base}label_{target_label}_{target_label_name}.csv"

    with open(file_path, 'w', newline='') as csvfile:
        emit = csv.writer(csvfile).writerow
        emit([f"Results for target label: {target_label} ({target_label_name})"])

        for top_n in top_ns:
            key = f'top{top_n}'
            emit([f"Top {top_n}"])

            # Both tables share the same shape; only heading and source differ.
            for heading, field in sections:
                emit(["", heading])
                emit(column_titles)
                for label, label_name, similarity in results[key][field]:
                    emit([label, label_name, f"{similarity:.4f}"])

            emit([f"Average similarity: {results[key]['average_similarity']:.4f}"])

if __name__ == "__main__":
    import random

    # Input label names plus the directory prefix that holds the pickled
    # embeddings and receives the CSV reports. The trailing slash matters:
    # paths below are built by plain string concatenation.
    json_file = "/data2/ty45972_data2/taming-transformers/datasets/imagenet/imagenet-labels.json"
    file_path_base = '/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/generated_data/similarity_results/'

    label_names = load_json(json_file)
    label_embeddings = load_pkl(file_path_base + "label_embeddings.pkl")

    # Embedding variants to analyse; each n selects the 'top<n>' entry.
    top_ns = [1, 5, 10, 20]

    # Draw 100 distinct label ids from 0..499. The generator is unseeded, so
    # every run picks a different subset.
    random_labels = random.sample(range(500), 100)

    # One CSV report per sampled label.
    for target_label in random_labels:
        results = find_closest_and_furthest_labels(
            target_label=target_label,
            label_embeddings=label_embeddings,
            label_names=label_names,
            top_ns=top_ns,
        )

        target_label_name = label_names[target_label]
        save_results_as_csv(
            results=results,
            target_label=target_label,
            target_label_name=target_label_name,
            top_ns=top_ns,
            file_path_base=file_path_base,
        )
