# Automatic concept annotation

In [None]:
import torch
import torchvision.transforms as T
import clip
import pandas as pd
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from pathlib import Path
import glob
from PIL import Image
import scipy
import numpy as np
import tqdm
from matplotlib import gridspec
import matplotlib.pyplot as plt

# Utility functions

In [None]:
class ImageDataset(torch.utils.data.Dataset):
    """
    A map-style dataset that loads images from disk together with their metadata.

    Args:
        image_path_list (list): A list of file paths to the images.
        transform (callable): A function/transform applied to each opened PIL image.
        metadata_df (pandas.DataFrame, optional): Metadata whose rows are
            positionally aligned with `image_path_list`. The caller's object is
            not modified; a copy re-indexed by image path is stored instead.
            When omitted, an all-NaN placeholder Series indexed by the image
            paths is used.

    Raises:
        AssertionError: If the length of `image_path_list` is not equal to the length of `metadata_df`.

    Returns:
        dict: For a given index, a dictionary with the keys "path", "image"
        and "metadata".

    """

    def __init__(self, image_path_list, transform, metadata_df=None):
        self.image_path_list = image_path_list
        self.transform = transform

        if metadata_df is None:
            # Explicit dtype keeps the placeholder consistent across pandas versions.
            self.metadata_df = pd.Series(index=self.image_path_list, dtype="float64")
        else:
            # Raised explicitly (rather than via `assert`) so the check also
            # runs when Python is started with -O.
            if len(self.image_path_list) != len(metadata_df):
                raise AssertionError(
                    "image_path_list and metadata_df must have the same length"
                )
            # Re-index a copy so the caller's DataFrame keeps its own index.
            self.metadata_df = metadata_df.copy()
            self.metadata_df.index = self.image_path_list

    def __getitem__(self, idx):
        path = self.image_path_list[idx]

        # Image.open is lazy and keeps the file open; read the pixels and apply
        # the transform inside the context manager so the handle is released.
        with Image.open(path) as image:
            image.load()
            transformed = self.transform(image)

        # self.metadata_df is always set by __init__, so "metadata" is always present.
        return {
            "path": str(path),
            "image": transformed,
            "metadata": self.metadata_df.iloc[idx],
        }

    def __len__(self):
        return len(self.image_path_list)

In [None]:
def custom_collate(batch):
    """Custom collate function for the dataloader.

    Values that are `pandas.Series` are stacked into a DataFrame with one row
    per sample; every other value goes through torch's `default_collate`.

    Args:
        batch (list): list of dictionaries, one per sample; the keys of the
            first dictionary determine which keys are collated

    Returns:
        dict: dictionary of collated data

    Raises:
        RuntimeError: If collating a key fails with a RuntimeError; the
            original error is attached as the cause.
    """

    ret = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        try:
            if isinstance(values[0], pd.Series):
                # Each Series becomes a column, so transpose to get one row per sample.
                ret[key] = pd.concat(values, axis=1).T
            else:
                ret[key] = torch.utils.data.dataloader.default_collate(values)
        except RuntimeError as err:
            # Chain the cause so the underlying failure stays in the traceback.
            raise RuntimeError(f"Error while concatenating {key}") from err
    return ret


def custom_collate_per_key(batch_all):
    """Merge per-batch outputs into one object per key.

    The type of the first element under each key decides how it is merged:
    DataFrames are concatenated row-wise, tensors are concatenated along
    axis 0, and anything else is flattened across batches and passed to
    torch's `default_collate`.

    Args:
        batch_all (dict): dictionary mapping each key to a list of per-batch objects
    Returns:
        dict: dictionary of collated data
    """

    def _merge(chunks):
        head = chunks[0]
        if isinstance(head, pd.DataFrame):
            return pd.concat(chunks, axis=0)
        if isinstance(head, torch.Tensor):
            return torch.concat(chunks, axis=0)

        # Generic fallback: flatten the batches (with a progress bar) and collate.
        flattened = []
        for chunk in tqdm.tqdm(chunks):
            flattened.extend(chunk)
        return torch.utils.data.dataloader.default_collate(flattened)

    return {key: _merge(chunks) for key, chunks in batch_all.items()}


def dataloader_apply_func(
    dataloader, func, collate_fn=custom_collate_per_key, verbose=True
):
    """Apply a function to every batch of a dataloader and collate the outputs.

    Args:
        dataloader (torch.utils.data.DataLoader): torch dataloader
        func (function): function applied to each batch; must return a dict
        collate_fn (function, optional): function that merges the per-key lists
            of batch outputs. Defaults to custom_collate_per_key.
        verbose (bool, optional): whether to show a tqdm progress bar while
            iterating over the batches. Defaults to True.

    Returns:
        dict: dictionary of outputs
    """
    # Only wrap the dataloader in a progress bar when it was asked for.
    batches = tqdm.tqdm(dataloader) if verbose else dataloader

    func_out_dict = {}
    for batch in batches:
        for key, func_out in func(batch).items():
            # Collect the outputs of all batches under their key.
            func_out_dict.setdefault(key, []).append(func_out)

    return collate_fn(func_out_dict)


def convert_image_to_rgb(image):
    """Return `image` converted to RGB mode via its `convert` method."""
    rgb_image = image.convert("RGB")
    return rgb_image


def get_transform(n_px):
    """Build the image preprocessing pipeline.

    Args:
        n_px (int): Size passed to both the resize and the center crop.

    Returns:
        torchvision.transforms.Compose: Pipeline that resizes (bicubic),
        center-crops, converts to RGB, converts to a tensor and normalizes.
    """
    # The plotting code elsewhere in this file indexes transforms[0] and
    # transforms[1], so resize and crop must remain the first two steps.
    steps = [
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(n_px),
        convert_image_to_rgb,
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return T.Compose(steps)

# Model initialization

In [None]:
# Device used for inference; set to "cuda:0" to run on the first GPU instead.
device = "cpu"

model_api = "clip"

# The image transform is needed by the dataset and the plots regardless of
# which backend is loaded, so build it outside the branches.
preprocess = get_transform(n_px=224)

if model_api == "clip":
    # Load model using original clip implementation. Only the model returned by
    # clip.load is kept; its bundled preprocessing is replaced by get_transform.
    model = clip.load("ViT-L/14", device=device, jit=False)[0]
    # Overwrite the weights with the checkpoint downloaded from the URL below.
    model.load_state_dict(
        torch.hub.load_state_dict_from_url(
            "https://aimslab.cs.washington.edu/MONET/weight_clip.pt",
            map_location=torch.device("cpu"),
        )
    )
    model.eval()
    print("model was loaded using original clip implementation")
else:
    # Load model using huggingface clip implementation
    # NOTE(review): get_prompt_embedding calls `model` directly, which this
    # branch does not define.
    processor_hf = AutoProcessor.from_pretrained("chanwkim/monet")
    model_hf = AutoModelForZeroShotImageClassification.from_pretrained("chanwkim/monet")
    model_hf.to(device)
    model_hf.eval()
    print("model was loaded using huggingface clip implementation")

## Define dataset

In [None]:
import os

annotated = True

data_dir = os.path.join(os.getcwd(), "data")
image_dir = os.path.join(data_dir, "Fitzpatric_subset")

# With annotated=True the annotated table is loaded: it already contains the
# score for each concept, generated by MONET. Otherwise data.csv is loaded.
csv_name = "annotated_data.csv" if annotated else "data.csv"
data_csv = os.path.join(data_dir, csv_name)

data = pd.read_csv(data_csv)
suffix = ".jpg"

# The "image_path" column holds names without extension; turn them into full paths.
image_path_list = [
    os.path.join(image_dir, f"{stem}{suffix}") for stem in data["image_path"]
]
data["path"] = image_path_list

print(f"total number of images = {len(image_path_list)}")
data.head()

# Pick 20 images for demo only; pass the full list instead of the slice to
# sort and plot the entire dataset.
demo_paths = image_path_list[:20]
image_dataset = ImageDataset(demo_paths, preprocess)

## Get image embedding

In [None]:
# Run this section to let MONET generate the concept scores; skip the entire
# section otherwise.

# shuffle=False keeps the batches in dataset order.
dataloader = torch.utils.data.DataLoader(
    dataset=image_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    collate_fn=custom_collate,
)


def batch_func(batch):
    """Encode the images of one batch with the selected model.

    Args:
        batch (dict): Collated batch providing the keys "image" and "metadata".

    Returns:
        dict: "image_features" (the encoder output moved to the CPU) and the
        batch's "metadata", passed through unchanged.
    """
    pixels = batch["image"].to(device)

    with torch.no_grad():
        if model_api == "clip":
            features = model.encode_image(pixels)
        else:
            features = model_hf.get_image_features(pixels)

    return {
        "image_features": features.detach().cpu(),
        "metadata": batch["metadata"],
    }


# Encode every batch and merge the per-batch outputs into one entry per key.
image_embedding = dataloader_apply_func(
    dataloader=dataloader,
    func=batch_func,
    collate_fn=custom_collate_per_key,
)

# Single quotes inside the replacement field keep this f-string valid on
# Python versions before 3.12.
print(f"embedding shape = {image_embedding['image_features'].shape}")

## Get concept embedding

In [None]:
def _encode_prompt_groups_normalized(prompt_groups):
    """Tokenize and encode groups of prompts, then L2-normalize the embeddings.

    Args:
        prompt_groups (list): List of prompt lists, one list per template. All
            groups must hold the same number of prompts so they can be stacked.

    Returns:
        torch.Tensor: CPU tensor with one stacked entry per group, normalized
        along dim 2 (the embedding axis).
    """
    # Look up the model's device once instead of once per group.
    text_device = next(model.parameters()).device

    with torch.no_grad():
        embedding = torch.stack(
            [
                model.encode_text(
                    clip.tokenize(prompt_list, truncate=True).to(text_device)
                )
                .detach()
                .cpu()
                for prompt_list in prompt_groups
            ]
        )  # [num_groups, prompts_per_group, embedding_dim]

    return embedding / embedding.norm(dim=2, keepdim=True)


def get_prompt_embedding(
    concept_term_list=None,
    prompt_template_list=None,
    prompt_ref_list=None,
):
    """
    Generate prompt embeddings for a concept

    Args:
        concept_term_list (list, optional): List of concept terms that will be
            used to generate prompt target embeddings. Defaults to an empty list.
        prompt_template_list (list, optional): List of prompt templates, each
            with one `{}` placeholder for the term. Defaults to the three
            "This is ... image of {}" templates below.
        prompt_ref_list (list, optional): List of reference phrase lists, one
            per template. Defaults to the three matching phrases without a term.

    Returns:
        dict: A dictionary containing the normalized prompt target embeddings and prompt reference embeddings.
    """
    # Defaults are created per call rather than as shared mutable default arguments.
    if concept_term_list is None:
        concept_term_list = []
    if prompt_template_list is None:
        prompt_template_list = [
            "This is skin image of {}",
            "This is dermatology image of {}",
            "This is image of {}",
        ]
    if prompt_ref_list is None:
        prompt_ref_list = [
            ["This is skin image"],
            ["This is dermatology image"],
            ["This is image"],
        ]

    # target prompts: one list per template, one prompt per concept term
    prompt_target = [
        [prompt_template.format(term) for term in concept_term_list]
        for prompt_template in prompt_template_list
    ]  # [num_templates, num_terms]

    return {
        "prompt_target_embedding_norm": _encode_prompt_groups_normalized(
            prompt_target
        ),
        "prompt_ref_embedding_norm": _encode_prompt_groups_normalized(
            prompt_ref_list
        ),
    }

In [None]:
# Artifact concepts to annotate. Each term is inserted into the prompt templates
# to build that concept's prompt embedding.
artifacts = [
    "pen", "hair strands", "nail", "finger",
    "ear", "eye", "nostril", "lip",
]

In [None]:
# Sort all images by their score for each artifact and display them. This needs
# the scores to be present in `data` already and can take some time because of
# the large number of images.

desc = True
for artifact in artifacts[6:]:
    scores = data[artifact].values
    example_per_concept = len(scores)

    # Rank once per artifact instead of re-sorting for every image; the same
    # order drives both the displayed image and the score in its title.
    order = np.argsort(scores)
    if desc:
        order = order[::-1]

    fig = plt.figure(
        figsize=(100 * 2, 1.3 * (example_per_concept // 100 + 1) * 2), facecolor="white"
    )

    # Main GridSpec (a single cell that hosts the nested image grid)
    main_gs = gridspec.GridSpec(1, 1, figure=fig)

    # Nested GridSpec: 100 images per row, as many rows as needed
    nested_gs = gridspec.GridSpecFromSubplotSpec(
        example_per_concept // 100
        + (1 if example_per_concept % 100 > 0 else 0),  # rows (ceiling division)
        100,
        subplot_spec=main_gs[0],
        wspace=0,
        hspace=0.1,
        width_ratios=[1] * 100,  # equal-width columns
    )

    for rank_num, image_index in enumerate(order):
        ax = plt.Subplot(fig, nested_gs[rank_num])
        fig.add_subplot(ax)

        # Close each file after use so plotting many images does not exhaust
        # the open-file limit. Only resize + center crop are applied for display.
        with Image.open(image_path_list[image_index]) as image:
            ax.imshow(preprocess.transforms[1](preprocess.transforms[0](image)))
        ax.axis("off")  # Remove axes for a cleaner look
        ax.set_facecolor("white")

        ax.set_title(
            f"Rank: {rank_num}\nScore: {scores[image_index]:.2f}",
            fontsize=8,
        )

    plt.tight_layout()
    # Single quotes inside the replacement field keep this f-string valid on
    # Python versions before 3.12.
    plt.suptitle(
        f"{artifact} Scores {'Desc' if desc else 'Asc'}",
        fontsize=16,
        fontweight="bold",
        y=1.005,
    )
    plt.show()

In [None]:
# One prompt-embedding dictionary per artifact, keyed by the artifact term.
concept_embedding_dict = {
    term: get_prompt_embedding(concept_term_list=[term]) for term in artifacts
}

## Calculate concept presence score

In [None]:
def calculate_concept_presence_score(
    image_features_norm,
    prompt_target_embedding_norm,
    prompt_ref_embedding_norm,
    temp=1 / np.exp(4.5944),
):
    """
    Score how strongly each image matches a concept relative to reference prompts.

    For every prompt template the image/target and image/reference similarities
    are averaged over the prompts, turned into a two-way softmax (target vs.
    reference) at temperature `temp`, and the target probability is then
    averaged over the templates.

    Args:
        image_features_norm (torch.Tensor): Normalized image features.
        prompt_target_embedding_norm (torch.Tensor): Normalized concept target embedding.
        prompt_ref_embedding_norm (torch.Tensor): Normalized concept reference embedding.
        temp (float, optional): Temperature parameter for softmax. Defaults to 1 / np.exp(4.5944).

    Returns:
        np.array: Concept presence score, one value per image.
    """
    image_matrix = image_features_norm.T.float()

    def _mean_similarity(prompt_embedding_norm):
        # Similarity of every prompt with every image, averaged over the prompts (dim 1).
        similarity = prompt_embedding_norm.float() @ image_matrix
        return similarity.mean(dim=1).numpy()

    scaled_logits = [
        _mean_similarity(prompt_target_embedding_norm) / temp,
        _mean_similarity(prompt_ref_embedding_norm) / temp,
    ]

    # Softmax across the (target, reference) pair; row 0 is the target probability.
    probabilities = scipy.special.softmax(scaled_logits, axis=0)
    target_probability = probabilities[0, :]

    return target_probability.mean(axis=0)


# L2-normalize every image embedding (one row per image).
image_features = image_embedding["image_features"]
image_features_norm = image_features / image_features.norm(dim=1, keepdim=True)

# Concept presence scores for all images, keyed by artifact.
concept_presence_score_dict = {
    artifact: calculate_concept_presence_score(
        image_features_norm=image_features_norm,
        prompt_target_embedding_norm=concept_embedding["prompt_target_embedding_norm"],
        prompt_ref_embedding_norm=concept_embedding["prompt_ref_embedding_norm"],
    )
    for artifact, concept_embedding in concept_embedding_dict.items()
}

# Optional export of the table (the scores above are not added to `data` here):
# data.to_csv("annotated_data.csv", index=False)

## Plot top 10 images

In [None]:
example_per_concept = 10

num_artifacts = len(artifacts)

# Create a figure
fig = plt.figure(
    figsize=(10 * 2, num_artifacts / 2 * (example_per_concept // 10 + 1) * 2)
)

# Main GridSpec (num_artifacts row, 1 column)
main_gs = gridspec.GridSpec(num_artifacts, 1, figure=fig)


for idx, (artifact, concept_presence_score) in enumerate(
    concept_presence_score_dict.items()
):
    # Nested GridSpec inside this artifact's row of the main GridSpec:
    # one label column followed by 10 image columns.
    nested_gs = gridspec.GridSpecFromSubplotSpec(
        example_per_concept // 10 + (1 if example_per_concept % 10 > 0 else 0),  # rows
        10 + 1,
        subplot_spec=main_gs[idx],
        wspace=0,
        hspace=0.1,
        width_ratios=[1.5] + [1] * 10,  # label column is wider than the image columns
    )

    label_ax = plt.Subplot(fig, nested_gs[0])
    fig.add_subplot(label_ax)
    label_ax.axis("off")
    label_ax.text(0.5, 0.5, artifact, fontsize=12, ha="center", va="center")

    # Image indices sorted by descending score, computed once per artifact and
    # capped at the number of available images.
    ranked_indices = np.argsort(concept_presence_score)[::-1][:example_per_concept]

    # Ranks are 1-based for display; rank 1 is the highest-scoring image.
    for rank_num, image_index in enumerate(ranked_indices, start=1):
        # Column 0 of every row is reserved for the label, so images occupy
        # columns 1..10 and wrap to the next row after 10 images.
        row, col = divmod(rank_num - 1, 10)
        ax = plt.Subplot(fig, nested_gs[row, col + 1])
        fig.add_subplot(ax)

        # Close each file after use; only resize + center crop are applied for display.
        with Image.open(image_path_list[image_index]) as image:
            ax.imshow(preprocess.transforms[1](preprocess.transforms[0](image)))
        ax.axis("off")  # Remove axes for a cleaner look

        ax.set_title(
            f"Rank: {rank_num}\nScore: {concept_presence_score[image_index]:.2f}",
            fontsize=8,
        )


plt.tight_layout()
plt.show()