In [None]:
import os
import argparse
import wandb
import re
import clip
from sklearn.metrics import (
    adjusted_rand_score,
    jaccard_score,
    normalized_mutual_info_score,
)
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# add the parent directory to the path
import sys
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.insert(0, parent_dir)
                
from utils.dcbm import *



# Load data & Apply Concept extraction

In [None]:
# ----------------- Fixed paths and dataset -----------------
# All paths are relative ("../"), i.e. resolved against the notebook's
# working directory.
embed_path = "../data/Embeddings/"
dataset = "cub"
class_labels_path = "../data/classes/cub_classes.txt"
segment_path = "../data/Segments/"
selected_image_concepts = "../data/Embeddings/subsets"

# ----------------- Hyperparameters -----------------

# NOTE: the clustering cells below loop over their own list of backbones and
# reuse `model_name` as the loop variable, so this value is overwritten there.
model_name = "CLIP-ViT-L14"  # "CLIP-ViT-L14", "CLIP-RN50", "CLIP-ViT-B16"

# Segmentation used by the first clustering cell. Later cells reassign both
# `segmentation_technique` and `concept_name`, so re-run this cell before
# re-running the first clustering cell.
# NOTE(review): a later cell spells the value "MASKRCNN" while this list says
# "MaskRCNN" - confirm which spelling the loader expects.
segmentation_technique = "SAM2"  # GDINO, SAM, SAM2, DETR, MaskRCNN
concept_name = None  # Define for GDINO [awa, sun, sun-lowthresh, cub...]

# NOTE(review): `device` is not referenced by the cells visible in this
# notebook - confirm whether it is still needed.
device = "cpu"

# Cluster counts to evaluate; each value becomes a key of the result dicts
# built by the clustering cells.
cluster_array = [128, 256, 512, 1024, 2048]
# The clustering cells also accept "dbscan", which is constructed without a
# cluster count.
cluster_method = "kmeans"  # "hierarchical", "kmeans", "dbscan"
# NOTE(review): `centroid_method` is not referenced by the cells visible in
# this notebook.
centroid_method = "median"  # "mean", "median"

# Passed to `cbm.load_concepts(...)` by the clustering cells.
concept_per_class = 50  # How many images for each class 5,10,20,50, None

# NOTE(review): the training / logging settings below (`one_hot` ... `project`)
# are not referenced by the cells visible in this notebook - presumably kept
# for parity with the training scripts; confirm before removing.
one_hot = False
epochs = 200
lambda_1 = 1e-4
lr = 1e-4
batch_size = 32

crop = False  # True without background

use_wandb = False
project = "YOUR_PROJECT_NAME"  # Define your own project name within wandb

## Segcrop

In [None]:
# Cluster the segment embeddings of every CLIP backbone for each requested
# cluster count, using the `segmentation_technique` / `concept_name` currently
# set in the config cell.
# Result layout: {requested cluster count: [cluster ids per backbone]}, where
# the inner list follows the order of `model_names`.
# NOTE: later cells reassign `segmentation_technique` and `concept_name`;
# re-run the config cell before re-running this one, otherwise other segments
# end up in the SAM2 result file.
model_names = ["CLIP-ViT-B16", "CLIP-ViT-L14", "CLIP-RN50"]
cluster_array_1 = {}

for clusters in cluster_array:
    cluster_array_1[clusters] = []

    for model_name in model_names:
        cbm = CBM(embed_path, dataset, model_name, class_labels_path)
        cbm.load_concepts(
            segment_path,
            segmentation_technique,
            concept_name,
            selected_image_concepts,
            concept_per_class,
        )

        # Scale every embedding (row) to unit L2 norm.
        norm = np.linalg.norm(cbm.image_segments, axis=1, keepdims=True)
        image_embedding = cbm.image_segments / norm

        # Never request more clusters than there are embeddings. The clamped
        # value lives in its own variable: overwriting `clusters` would make
        # the dictionary lookup below fail with a KeyError (the key was created
        # from the requested count) and would leak into the next backbone.
        n_clusters = min(clusters, image_embedding.shape[0])

        print("Amount of image embeddings: ", image_embedding.shape[0])
        print("Amount of clusters: ", n_clusters)

        # PCA cannot keep more components than min(n_samples, n_features).
        # A fixed seed keeps the (possibly randomized) solver reproducible.
        n_components = min(100, *image_embedding.shape)
        pca = PCA(n_components=n_components, random_state=42)
        image_embedding_pca = pca.fit_transform(image_embedding)

        cbm.cluster_technique = cluster_method
        if cluster_method == "kmeans":
            clustering_model = KMeans(n_clusters=n_clusters, random_state=42)
        elif cluster_method == "hierarchical":
            clustering_model = AgglomerativeClustering(n_clusters=n_clusters)
        elif cluster_method == "dbscan":
            # DBSCAN determines the number of clusters itself.
            clustering_model = DBSCAN()
        else:
            raise ValueError(f"Unsupported clustering method: {cluster_method}")
        cluster_ids = clustering_model.fit_predict(image_embedding_pca)

        # Always file the result under the requested cluster count.
        cluster_array_1[clusters].append(cluster_ids)


# Persist the assignments; create the output directory on first use.
os.makedirs("similarities", exist_ok=True)
with open("similarities/cub_SAM_2.pkl", "wb") as f:
    pickle.dump(cluster_array_1, f)

In [None]:
import seaborn as sns

# Agreement between the clusterings of the different backbones: one row per
# (cluster count, model pair) holding the NMI of their cluster assignments.
# `model_names` is the list defined by the clustering cell; its order matches
# the order in which the assignments were appended to `cluster_array_1`.
data = []

for key, local_array in cluster_array_1.items():
    for i in range(len(local_array)):
        for j in range(i + 1, len(local_array)):
            nmi_score = normalized_mutual_info_score(local_array[i], local_array[j])
            data.append([key, model_names[i], model_names[j], nmi_score])

# Long-format table of all scores.
pandas_df = pd.DataFrame(data, columns=["Clusters", "Model 1", "Model 2", "NMI_score"])
# Wide format for the heatmap: rows are cluster counts, columns are model pairs.
pivot_df = pandas_df.pivot(
    index="Clusters", columns=["Model 1", "Model 2"], values="NMI_score"
)

# Draw on an explicit figure/axes pair instead of the implicit pyplot state.
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(pivot_df, annot=True, cmap="YlGnBu", fmt=".3f", ax=ax)
ax.set_title("NMI Scores Heatmap")
fig.tight_layout()
plt.show()

# Display the original dataframe as well
display(pandas_df)

In [None]:
# Same clustering as for the first segmentation, now on the GDINO segments
# obtained with the "partimagenet" concept set.
# Result layout: {requested cluster count: [cluster ids per backbone]}, where
# the inner list follows the order of `model_names`.
segmentation_technique = "GDINO"
concept_name = "partimagenet"

model_names = ["CLIP-ViT-B16", "CLIP-ViT-L14", "CLIP-RN50"]
cluster_array_2 = {}

for clusters in cluster_array:
    cluster_array_2[clusters] = []

    for model_name in model_names:
        cbm = CBM(embed_path, dataset, model_name, class_labels_path)
        cbm.load_concepts(
            segment_path,
            segmentation_technique,
            concept_name,
            selected_image_concepts,
            concept_per_class,
        )

        # Scale every embedding (row) to unit L2 norm.
        norm = np.linalg.norm(cbm.image_segments, axis=1, keepdims=True)
        image_embedding = cbm.image_segments / norm

        # Never request more clusters than there are embeddings. The clamped
        # value lives in its own variable: overwriting `clusters` would make
        # the dictionary lookup below fail with a KeyError (the key was created
        # from the requested count) and would leak into the next backbone.
        n_clusters = min(clusters, image_embedding.shape[0])

        print("Amount of image embeddings: ", image_embedding.shape[0])
        print("Amount of clusters: ", n_clusters)

        # PCA cannot keep more components than min(n_samples, n_features).
        # A fixed seed keeps the (possibly randomized) solver reproducible.
        n_components = min(100, *image_embedding.shape)
        pca = PCA(n_components=n_components, random_state=42)
        image_embedding_pca = pca.fit_transform(image_embedding)

        cbm.cluster_technique = cluster_method
        if cluster_method == "kmeans":
            clustering_model = KMeans(n_clusters=n_clusters, random_state=42)
        elif cluster_method == "hierarchical":
            clustering_model = AgglomerativeClustering(n_clusters=n_clusters)
        elif cluster_method == "dbscan":
            # DBSCAN determines the number of clusters itself.
            clustering_model = DBSCAN()
        else:
            raise ValueError(f"Unsupported clustering method: {cluster_method}")
        cluster_ids = clustering_model.fit_predict(image_embedding_pca)

        # Always file the result under the requested cluster count.
        cluster_array_2[clusters].append(cluster_ids)


# Persist the assignments; create the output directory on first use.
# NOTE(review): "partimegenet" looks like a typo for "partimagenet"; the file
# name is left unchanged so that existing readers of this path keep working.
os.makedirs("similarities", exist_ok=True)
with open("similarities/cub_partimegenet.pkl", "wb") as f:
    pickle.dump(cluster_array_2, f)

In [None]:
# Same clustering as for the other segmentations, now on the Mask R-CNN
# segments.
# Result layout: {requested cluster count: [cluster ids per backbone]}, where
# the inner list follows the order of `model_names`.
# NOTE(review): the value is spelled "MASKRCNN" here while the config cell
# lists "MaskRCNN" - confirm which spelling `load_concepts` expects.
segmentation_technique = "MASKRCNN"
concept_name = None

model_names = ["CLIP-ViT-B16", "CLIP-ViT-L14", "CLIP-RN50"]
cluster_array_3 = {}

for clusters in cluster_array:
    cluster_array_3[clusters] = []

    for model_name in model_names:
        cbm = CBM(embed_path, dataset, model_name, class_labels_path)
        cbm.load_concepts(
            segment_path,
            segmentation_technique,
            concept_name,
            selected_image_concepts,
            concept_per_class,
        )

        # Scale every embedding (row) to unit L2 norm.
        norm = np.linalg.norm(cbm.image_segments, axis=1, keepdims=True)
        image_embedding = cbm.image_segments / norm

        # Never request more clusters than there are embeddings. The clamped
        # value lives in its own variable: overwriting `clusters` would make
        # the dictionary lookup below fail with a KeyError (the key was created
        # from the requested count) and would leak into the next backbone.
        n_clusters = min(clusters, image_embedding.shape[0])

        print("Amount of image embeddings: ", image_embedding.shape[0])
        print("Amount of clusters: ", n_clusters)

        # PCA cannot keep more components than min(n_samples, n_features).
        # A fixed seed keeps the (possibly randomized) solver reproducible.
        n_components = min(100, *image_embedding.shape)
        pca = PCA(n_components=n_components, random_state=42)
        image_embedding_pca = pca.fit_transform(image_embedding)

        cbm.cluster_technique = cluster_method
        if cluster_method == "kmeans":
            clustering_model = KMeans(n_clusters=n_clusters, random_state=42)
        elif cluster_method == "hierarchical":
            clustering_model = AgglomerativeClustering(n_clusters=n_clusters)
        elif cluster_method == "dbscan":
            # DBSCAN determines the number of clusters itself.
            clustering_model = DBSCAN()
        else:
            raise ValueError(f"Unsupported clustering method: {cluster_method}")
        cluster_ids = clustering_model.fit_predict(image_embedding_pca)

        # Always file the result under the requested cluster count.
        cluster_array_3[clusters].append(cluster_ids)


# Persist the assignments; create the output directory on first use.
os.makedirs("similarities", exist_ok=True)
with open("similarities/cub_maskrcnn.pkl", "wb") as f:
    pickle.dump(cluster_array_3, f)

In [None]:
import seaborn as sns


def _pairwise_nmi_rows(assignments_by_clusters, names):
    """Build one ``[clusters, model 1, model 2, NMI]`` row per model pair.

    Args:
        assignments_by_clusters: Mapping from cluster count to a list of
            cluster-id arrays, one array per model.
        names: Model names, indexed like the lists in
            ``assignments_by_clusters``.

    Returns:
        A list of rows; for every cluster count, one row for each unordered
        pair of models with the normalized mutual information of their
        cluster assignments.
    """
    rows = []
    for key, local_array in assignments_by_clusters.items():
        for i in range(len(local_array)):
            for j in range(i + 1, len(local_array)):
                score = normalized_mutual_info_score(local_array[i], local_array[j])
                rows.append([key, names[i], names[j], score])
    return rows


# NMI rows per segmentation technique. `model_names` is the backbone list
# defined by the clustering cells; its order matches the order in which the
# assignments were appended to each `cluster_array_*`.
nmi_dict = {
    "SAM2": _pairwise_nmi_rows(cluster_array_1, model_names),
    "GDINO": _pairwise_nmi_rows(cluster_array_2, model_names),
    "MASKRCNN": _pairwise_nmi_rows(cluster_array_3, model_names),
}

# One heatmap per segmentation technique, side by side. The figure is created
# at its final size instead of being resized inside the loop.
fig, axes = plt.subplots(1, 3, figsize=(50, 20))

for ax, (seg_technique, rows) in zip(axes, nmi_dict.items()):
    # Long-format table for the current segmentation technique
    pandas_df = pd.DataFrame(
        rows, columns=["Clusters", "Model 1", "Model 2", "NMI_score"]
    )

    # Wide format: rows are cluster counts, columns are model pairs
    pivot_df = pandas_df.pivot(
        index="Clusters", columns=["Model 1", "Model 2"], values="NMI_score"
    )

    # Heatmap without per-cell annotations (annot=False); values are read off
    # the colour bar
    heatmap = sns.heatmap(
        pivot_df, annot=False, cmap="YlGnBu", fmt=".3f", ax=ax, cbar=True
    )
    ax.set_title(f"{seg_technique} Segmentation", fontsize=50)
    ax.set_ylabel("Clusters", fontsize=35)
    ax.set_xlabel("Model Pairs", fontsize=35)

    # Rotate x-axis labels for better readability
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=30)
    ax.tick_params(axis="y", labelsize=30)

    # Increase text size of the colour bar
    cbar = heatmap.collections[0].colorbar
    cbar.ax.tick_params(labelsize=25)
    # No-op while annot=False; kept so that switching annotations on gives
    # them a matching font size.
    for t in heatmap.texts:
        t.set_fontsize(25)

plt.tight_layout()
plt.show()
