In [1]:
import torch
import json
from tqdm import tqdm
from load import *
import os

def save_to_json(similarity, json_output_path):
    # Check if the file already exists
    if os.path.exists(json_output_path):
        with open(json_output_path, 'r') as f:
            existing_data = json.load(f)
    else:
        existing_data = {}

    # Update existing data with new similarity results
    existing_data.update(similarity)

    # Write the updated data back to the JSON file
    with open(json_output_path, 'w') as f:
        json.dump(existing_data, f, indent=4)

def compute_class_description_cosine_similarity(data):
    """
    Input:
    data: dictionary containing class descriptions
    
    Compute the cosine similarity between each class and all images,
    normalize these values, and save the results in JSON format.
    """
    device = torch.device(hparams['device'])
    model, preprocess = clip.load(hparams['model_size'], device=device, jit=False)
    model.eval()

    class_list = compute_class_list(data, sort_config=True)

    class_descriptor_dict = load_json(hparams['descriptor_fname'])
    class_list = compute_class_list(class_descriptor_dict, sort_config=True)
    descriptor_list = compute_descriptor_list(class_descriptor_dict, sort_config=True)

    class_list = [c.replace('-', ' ') for c in class_list]

    seed_everything(hparams['seed'])

    # Load the model and preprocessing
    print("Loading model...")
    device = torch.device(hparams['device'])
    model, preprocess = clip.load(hparams['model_size'], device=device, jit=False)
    model.eval()
    model.requires_grad_(False)

    # Encode descriptions and labels
    print("Encoding descriptions...")
    description_encodings = F.normalize(model.encode_text(clip.tokenize(descriptor_list).to(device)))
    label_encodings = F.normalize(model.encode_text(clip.tokenize(class_list).to(device)))

    cosine_similarity = torch.mm(description_encodings, label_encodings.T)
    print(cosine_similarity.size())

    cosine_similarity_per_class = cosine_similarity.cpu().detach().numpy()
    print("cosine_similarity:", cosine_similarity_per_class)

    # Calculate average cosine similarity for each class
    average_cosine_similarity = cosine_similarity.mean(dim=0).tolist()

    similarity = {}
    similarity[hparams['dataset']] = {}
    for i, class_name in enumerate(class_list):
        similarity[hparams['dataset']][class_name] = {}
        similarity[hparams['dataset']][class_name]["cos_sim"] = cosine_similarity_per_class[i].tolist()
        if len(similarity[hparams['dataset']][class_name]["cos_sim"]) == len(dataset_classes):
            similarity[hparams['dataset']][class_name]["cos_sim_sorted"] = sorted(similarity[hparams['dataset']][class_name]["cos_sim"], reverse=True)
            similarity[hparams['dataset']][class_name]["average_cos_sim"] = average_cosine_similarity[i]

    return similarity

Creating descriptors from descriptors_cub_gpt4_8_descriptors...
Example description for class 'Black-footed Albatross': "Black-footed Albatross, which has Seabird with contrasting black and white plumage"

Creating descriptor frequencies...


In [2]:
descriptor_file_path = hparams['descriptor_fname']
print(descriptor_file_path)

class_descriptor_dict = load_json(descriptor_file_path)
analysis_dict = compute_class_description_cosine_similarity(class_descriptor_dict)

output_path_name = descriptor_file_path.split("/")[-1].split(".")[0].split("_")[1]
print(output_path_name)
json_file_path = f'class_analysis/class_analysis_{output_path_name}.json'

save_to_json(analysis_dict, json_file_path)

./descriptors/descriptors_cub_gpt4_8_descriptors
Loading model...
Encoding descriptions...
torch.Size([1561, 200])
cosine_similarity: [[0.6025 0.6504 0.541  ... 0.648  0.5273 0.578 ]
 [0.6167 0.689  0.551  ... 0.6553 0.535  0.587 ]
 [0.5977 0.6987 0.555  ... 0.699  0.6143 0.6064]
 ...
 [0.5723 0.6284 0.4668 ... 0.574  0.4639 0.5317]
 [0.59   0.6562 0.4878 ... 0.633  0.4963 0.5522]
 [0.5923 0.6777 0.4985 ... 0.6377 0.5063 0.562 ]]
cub


#### Visualise differences in all datasets from a particular dataset

In [3]:
## This file will:
# 1. Load the JSON file containing the cosine similarities between classes and descriptors
# 2. Visualize the cosine similarities using heatmaps (one each for each key in "data", and one for the difference between each)
#   - There is no limit to the number of keys in "data"
#   - Use seaborn, with class names on the x-axis and descriptor names on the y-axis
#   - The heatmaps should be colored according to the cosine similarity values
#   - The class names should be rotated 90 degrees
# 3. Save the heatmap as a PNG file

import json
import numpy as np
import plotly.express as px
import os

# Load the JSON file
def load_json(json_file_path):
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

# Function to dynamically create the descriptor file path
def get_descriptor_file_path(dataset_name):
    descriptor_file_path = f"./descriptors/descriptors_{dataset_name}.json"
    return descriptor_file_path

# Function to generate a single heatmap for all class-descriptor similarities using Plotly
def generate_combined_heatmap(data, dataset_name, output_dir, descriptor_dict):
    os.makedirs(output_dir, exist_ok=True)

    # Extract class names and descriptors
    class_list = compute_class_list(descriptor_dict, sort_config=True)
    descriptor_list = compute_descriptor_list(descriptor_dict, sort_config=True)

    # Extract cosine similarity values for each class
    similarities_per_class = [data[dataset_name][class_name]["cos_sim"] for class_name in class_list]

    # Combine all cosine similarities into a single matrix
    cos_sim_matrix = np.array(similarities_per_class)

    # Create the heatmap using Plotly
    fig = px.imshow(
        cos_sim_matrix,
        labels=dict(x="Descriptors", y="Classes", color="Cosine Similarity"),
        x=descriptor_list,
        y=class_list,
        color_continuous_scale='RdBu_r',
        aspect='auto'
    )

    fig.update_layout(
        title=f"Cosine Similarity Heatmap for {dataset_name}",
        xaxis_nticks=len(descriptor_list),
        yaxis_nticks=len(class_list)
    )

    # Save the heatmap as an interactive HTML file
    fig.write_html(os.path.join(output_dir, f"{dataset_name}_combined_heatmap.html"))

    # Optionally, display the figure
    fig.show()

# Main code execution
output_dir = 'combined_heatmaps'  # Output directory for heatmaps

# Load the data from the JSON file
data = load_json(json_file_path)

# Iterate through each dataset in the JSON and generate combined heatmaps
for dataset_name in data.keys():
    # Dynamically generate the descriptor file path
    descriptor_file_path = get_descriptor_file_path(dataset_name)
    
    # Load the descriptor file for the current dataset
    descriptor_dict = load_json(descriptor_file_path)
    
    # Generate and save the combined heatmap
    generate_combined_heatmap(data, dataset_name, output_dir, descriptor_dict)

KeyError: 'American-Three-toed Woodpecker'