In [None]:
import os
import argparse
import wandb
import re
import clip
import cv2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# add the parent directory to the path
import sys
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.insert(0, parent_dir)
                
from utils.dcbm import *



In [None]:
# ----------------- Fixed paths -----------------
embed_path = "../data/Embeddings/"
dataset = "cifar100"
class_labels_path = "../data/classes/cifar100_classes.txt"
segment_path = "../data/Segments/"
selected_image_concepts = "../data/Embeddings/subsets"
# The two raw-image locations are only needed by the concept-visualisation
# cells (`display_concepts`). They default to None so that this cell is valid
# Python and the training cells can run; a bare `name = # TODO` is a SyntaxError.
# TODO: directory whose sub-folders are listed and which holds the test images.
raw_path_dataset = None
# TODO: Define the path to the raw images (root used to locate concept images).
raw_path = None
# ----------------- Hyperparameters -----------------

model_name = "CLIP-ViT-L14"  # "CLIP-ViT-L14", "CLIP-RN50", CLIP-ViT-B16

segmentation_technique = "SAM2"  # GDINO, SAM, SAM2, DETR, MaskRCNN
concept_name = None  # Define for GDINO [awa, sun, sun-lowthresh, cub...]

device = "cpu"

clusters = 2048  # None skips clustering in the training cell
cluster_method = "kmeans"  # "hierarchical", "kmeans"
centroid_method = "median"  # "mean", "median"

concept_per_class = 50  # How many images for each class 5,10,20,50, None

one_hot = False
epochs = 200
lambda_1 = 1e-4
lr = 1e-4
batch_size = 32

crop = False  # True without background

use_wandb = False
project = "YOUR_PROJECT_NAME"  # Define your own project name within wandb

In [None]:
# Initialize the CBM from the configured paths and backbone name.
cbm = CBM(embed_path, dataset, model_name, class_labels_path, device=device)

# Load concepts with the predefined segmentation technique and hyperparameters.
cbm.load_concepts(
    segment_path,
    segmentation_technique,
    concept_name,
    selected_image_concepts,
    concept_per_class,
    crop=crop,
)

# Without a cluster count the image segments are used as concepts directly;
# otherwise they are clustered first.
if clusters is None:
    cbm.clustered_concepts = cbm.image_segments
else:
    cbm.cluster_image_concepts(cluster_method, clusters, pca=True)

# Calculate the centroids of the concepts with the configured method.
cbm.centroid_concepts(centroid_method)

# Preprocess the data for training.
cbm.preprocess_data(type_="standard", label_type=one_hot)

# Train the model.
cbm.train(
    num_epochs=epochs,
    lambda_1=lambda_1,
    lr=lr,
    device=device,
    batch_size=batch_size,
    project=project,
    to_print=False,
    early_stopping_patience=None,
    one_hot=one_hot,
    use_wandb=use_wandb,
)

In [None]:
# NOTE(review): this is a self-assignment. No visible cell defines `cbm_model`,
# so on a fresh kernel this raises NameError unless `from utils.dcbm import *`
# happens to provide the name. Presumably it should reference the model produced
# by `cbm.train(...)` above — TODO confirm and bind it explicitly.
cbm_model = cbm_model
# Wrap the model together with the clustered concepts, the preprocessing module
# and the scaler taken from `cbm`.
# NOTE: `model` is rebound later in the notebook by `clip.load(...)`.
model = CBM_Model(
    cbm_model, cbm.clustered_concepts, cbm.preprocess_module, cbm.scaler, device=device
)
# Compare predictions against the labels for the first 200 test rows.
print("Predictions: ")
print(model.predict_processed(cbm.X_test[:200]))
print("True Classes: ")
# NOTE(review): argmax over axis 1 assumes `cbm.y_test` is 2-D with one column
# per class — TODO confirm this holds when `one_hot = False`.
print(np.argmax(cbm.y_test[:200], axis=1))

In [None]:
# Instance to explain. Presumably an index into the test split (later cells use
# the same variable to index `cbm.data_test_raw`) — TODO confirm.
id_image = 2

# Returns the concept ids and their weights for this instance; judging by its
# name the call also draws a plot — verify against utils.dcbm.
concept_ids, concept_weights = cbm.plot_instance_feature_importance(id_image)

In [None]:
# Load the CLIP backbone used to name concepts.
# NOTE: this rebinds `model`, which the prediction cell bound to the CBM_Model
# wrapper. The architecture is hard-coded to "ViT-L/14" and is not derived from
# `model_name` — keep the two in sync when changing the config.
model, preprocess = clip.load("ViT-L/14", device=device, jit=False)

# Vocabulary of candidate concept names, one entry per line.
# The path previously read "..data/classes/20k.txt" (missing "/"), which does
# not resolve to the ../data directory used by every other path in the notebook.
with open("../data/classes/20k.txt", "r") as f:
    g20k = f.readlines()
names = [i.strip() for i in g20k]
tokenized_text = clip.tokenize(names).to(device)

# Text embeddings are only used for similarity look-ups, so no gradients needed.
with torch.no_grad():
    text_features = model.encode_text(tokenized_text)
        

In [None]:
def display_images_with_main(image_paths, main_image_path, concept_names, num_secondary=5):
    """
    Display a main image in its own figure and a row of smaller images below it.

    Parameters:
    - image_paths (list): List of file paths for the secondary images. Only the
      first `num_secondary` entries are shown.
    - main_image_path (str): File path of the main image.
    - concept_names (list): Titles for the secondary images, index-aligned with
      `image_paths`. Missing titles are simply left out.
    - num_secondary (int): Number of secondary panels to lay out. Default is 5.

    Secondary images that cannot be read are replaced by a text placeholder and
    a warning is printed, instead of aborting the whole figure.
    """
    # Keep at most `num_secondary` secondary images.
    image_paths = list(image_paths)[:num_secondary]
    print(main_image_path)

    # Main image: drawn into its own figure. The context manager closes the file
    # once matplotlib has copied the pixel data.
    with Image.open(main_image_path) as main_img:
        plt.imshow(main_img)
    plt.axis("off")

    # Nothing to lay out for the secondary row.
    if num_secondary < 1:
        plt.show()
        return

    # squeeze=False always yields a 2-D array of axes, so indexing also works
    # when num_secondary == 1.
    fig, axes = plt.subplots(1, num_secondary, figsize=(9, 4), squeeze=False)
    row_axes = axes[0]

    # Hide every frame up front so unused panels do not show empty ticks.
    for ax in row_axes:
        ax.axis("off")

    for i, (ax, image_path) in enumerate(zip(row_axes, image_paths)):
        # cv2.imread returns None (no exception) when the file cannot be read.
        img = cv2.imread(image_path)
        if img is None:
            print(f"Warning: could not read image: {image_path}")
            ax.text(0.5, 0.5, "image not found", ha="center", va="center",
                    transform=ax.transAxes)
        else:
            # OpenCV loads BGR, matplotlib expects RGB: always convert.
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if i < len(concept_names):
            ax.set_title(concept_names[i])

    plt.tight_layout()
    plt.show()


def display_concepts(id_image, cbm, raw_path, raw_path_dataset, num_concepts=5):
    """
    Show a test image next to the most important concepts for its prediction.

    For each of the first `num_concepts` concept ids returned by
    `cbm.plot_instance_feature_importance(id_image)`, the cluster member closest
    (L1 distance) to the cluster's element-wise median is selected. Its image is
    displayed, titled with the vocabulary entry whose text embedding has the
    highest cosine similarity to that member's embedding.

    Parameters:
    - id_image (int): Position of the image in `cbm.data_test_raw`.
    - cbm: Object providing `plot_instance_feature_importance`, `data_test_raw`,
      `clustered_concepts_all` and `image_segments_names`.
    - raw_path (str): Root directory used to build the concept image paths.
    - raw_path_dataset (str): Directory whose sub-folders are listed and which
      contains the main (test) image.
    - num_concepts (int): How many concepts to display. Default is 5.

    Relies on notebook globals set by earlier cells: `torch`, `device`,
    `text_features`, `names` and `display_images_with_main`.

    Raises:
    - ValueError: if `raw_path` or `raw_path_dataset` is None.
    """
    if raw_path is None or raw_path_dataset is None:
        raise ValueError(
            "raw_path and raw_path_dataset must point to the raw image "
            "directories before concepts can be displayed."
        )

    def process_string(s):
        # Leading part of the name up to the first digit, without underscores
        # at either end; the whole string if it starts with a digit.
        match = re.match(r"^([^0-9]+)", s)
        return match.group(1).strip("_") if match else s

    concept_ids, _ = cbm.plot_instance_feature_importance(id_image)

    # Get the main image path.
    print("id_image", id_image)
    folder_names = os.listdir(raw_path_dataset)
    image_name = list(cbm.data_test_raw.keys())[id_image]
    # NOTE(review): the key is split on the literal "_Places365". If a key does
    # not contain that marker, the second split has no element [1] and raises
    # IndexError — TODO confirm the key format for other datasets.
    folder_name = image_name.split("_Places365")[0]
    image_name = image_name.split(folder_name + "_")[1]
    image_path_org = os.path.join(raw_path_dataset, folder_name, image_name)

    def resolve_segment_path(segment_name):
        # Map the name prefix onto the first dataset folder that contains it
        # (case-insensitive); fall back to the prefix itself if none matches.
        prefix = process_string(segment_name)
        needle = prefix.lower()
        folder = next(
            (candidate for candidate in folder_names if needle in candidate.lower()),
            prefix,
        )
        return os.path.join(raw_path, folder, segment_name)

    # Cast once so the similarity never mixes dtypes with the concept embedding.
    text_bank = text_features.float()

    image_median_paths = []
    concept_names = []
    for idx in concept_ids[:num_concepts]:
        idx = int(idx)
        data = cbm.clustered_concepts_all[idx]

        # Cluster member closest to the element-wise median (L1 distance).
        median_values = np.median(data, axis=0)
        distances = np.sum(np.abs(data - median_values), axis=1)
        nearest = int(np.argmin(distances))

        # Name the concept after the most similar vocabulary entry.
        concept_emb = torch.tensor(data[nearest], dtype=torch.float32).to(device)
        sim = torch.nn.functional.cosine_similarity(
            concept_emb.unsqueeze(0), text_bank, dim=1
        )
        concept_names.append(names[int(sim.argmax())])

        # Path of the image that member was taken from.
        segment_name = cbm.image_segments_names[idx][nearest]
        image_median_paths.append(resolve_segment_path(segment_name))

    # Call the function to display images.
    display_images_with_main(
        image_median_paths, image_path_org, concept_names, num_secondary=num_concepts
    )


# Imported here because the notebook's import cell does not import `random`
# explicitly; this guarantees the standard-library module is the one in use.
import random

# random.randint(a, b) includes b, so randint(0, len(...)) could return an index
# one past the last test image; randrange(n) draws from 0 .. n-1.
id_image = random.randrange(len(cbm.data_test_raw))

display_concepts(id_image, cbm, raw_path, raw_path_dataset)