In [1]:
import sys
# Put the parent directory at the front of the module search path (once) so
# that project-local packages can be imported from this notebook.
path = ".."
if sys.path.count(path) == 0:
    sys.path.insert(0, path)

In [2]:
from sklearn.metrics.pairwise import cosine_similarity
from data_retrieval import lipade_groundtruth
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
import clip

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the CLIP checkpoint together with the preprocessing callable that is
# applied to PIL images before they are fed to the image encoder.
modelName = 'ViT-L/14'
model, preprocess = clip.load(modelName, device)

In [3]:
# Load the ground-truth dataset in "similar" mode.
# How the three results are used by the cells below:
#   x    -- indexed and passed to Image.open, i.e. one image path/file per entry
#   m[1] -- one string per image holding newline-separated "<key>;<caption>" lines
#   y    -- per-image labels; equal labels are reported as "similar" images
x,m,y = lipade_groundtruth.getDataset(mode="similar")

# Caption retrieval

In [4]:
# Each entry of m[1] holds the captions of one image, one per line, in the
# form "<prompt key>;<caption text>".
captions_per_image = [captions.split('\n') for captions in m[1]]

# The prompt keys are read from the first image; blank lines (e.g. produced
# by a trailing newline) carry no caption and are ignored.
# Splitting only on the FIRST ';' keeps caption text that itself contains a
# semicolon intact instead of silently truncating it.
keys = [
    caption.split(';', 1)[0]
    for caption in captions_per_image[0]
    if caption.strip()
]

# Maps prompt key -> list of caption texts, in image order.
images_per_captions = {key: [] for key in keys}

for image_captions in captions_per_image:
    for caption in image_captions:
        if not caption.strip():
            continue
        key, text = caption.split(';', 1)
        # A key that the first image did not declare raises KeyError here.
        images_per_captions[key].append(text)

# Caption evaluation

## Image embeddings

In [5]:
# Encode every dataset image with the CLIP image encoder.
# The list collects one NumPy array per image (batch of one, so presumably
# of shape (1, embedding_dim) -- confirm against the model).
image_embeddings = []
for i in tqdm(range(len(x))):
    # Open through a context manager so the file handle is released as soon
    # as the pixels are decoded; convert() returns an independent RGB image.
    # The name `image` also avoids shadowing the builtin `input`.
    with Image.open(x[i]) as image_file:
        image = image_file.convert('RGB')
    preprocessed = preprocess(image).unsqueeze(0).to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model.encode_image(preprocessed)
    image_embeddings.append(output.cpu().numpy())

 20%|█▉        | 55/279 [00:06<00:25,  8.84it/s]


KeyboardInterrupt: 

In [None]:
# Stack the per-image embeddings into one (n_images, embedding_dim) matrix.
# Guard against a partially filled list (e.g. the encoding loop was
# interrupted): reshape(len(x), -1) would otherwise fail obscurely or, when
# the element count happens to divide evenly, silently build a wrong matrix.
assert len(image_embeddings) == len(x), (
    f"expected {len(x)} image embeddings, got {len(image_embeddings)}; "
    "re-run the image embedding cell to completion"
)
image_embeddings = np.array(image_embeddings).reshape(len(x), -1)
image_embeddings.shape

## Caption embeddings

In [None]:
# Encode the captions of every prompt with the CLIP text encoder.
# Result: a list of (prompt key, matrix) pairs, where row i of the matrix is
# the embedding of the caption generated for image i.
caption_embeddings_per_prompt = []
for key in keys:
    # NOTE: `captions` stays bound to the LAST prompt's captions once this
    # loop finishes.
    captions = images_per_captions[key]
    # Row i must describe image i, so each prompt needs exactly one caption
    # per image; fail loudly instead of silently mis-aligning rows.
    assert len(captions) == len(x), (
        f"prompt {key!r}: expected {len(x)} captions, got {len(captions)}"
    )
    caption_embeddings = []
    for i in tqdm(range(len(x))):
        # `tokens` (rather than `input`) avoids shadowing the builtin.
        tokens = clip.tokenize([captions[i]]).to(device)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = model.encode_text(tokens)
        caption_embeddings.append(output.cpu().numpy())
    caption_embeddings_per_prompt.append((key, np.array(caption_embeddings).reshape(len(x), -1)))

In [None]:
# Sanity check: display the shape of the first prompt's caption-embedding
# matrix (its first dimension is len(x), as set by the reshape above).
caption_embeddings_per_prompt[0][1].shape

## Similarity

In [None]:
# For every prompt, build the image-vs-caption cosine similarity matrix:
# rows follow the image embeddings, columns follow that prompt's captions.
similarity_per_prompt = [
    (prompt_key, cosine_similarity(image_embeddings, prompt_caption_embeddings))
    for prompt_key, prompt_caption_embeddings in caption_embeddings_per_prompt
]

## Visualisation

In [None]:
# Keep an RGB copy of every dataset image in memory for the plots below.
images = [Image.open(x[i]).convert('RGB') for i in range(len(x))]

In [None]:
# Visualisation settings: how many queries to display and which prompt's
# similarity matrix to inspect.
limit = 10
prompt_i = 0

# `prompt` is the prompt key, `M` the matching image-vs-caption similarity matrix.
prompt = similarity_per_prompt[prompt_i][0]
M = similarity_per_prompt[prompt_i][1]

In [None]:
# Image -> caption retrieval: for each of the first `limit` images, show the
# image and print the best-matching captions of the selected prompt.
#
# The captions must come from the SAME prompt as the similarity matrix `M`.
# The global `captions` is whatever the caption-embedding loop processed
# last, which is a different prompt whenever more than one prompt exists.
prompt_captions = images_per_captions[prompt]
top_k = 5

for i in range(len(x))[:limit]:
    # Row i of M: similarity between image i and every caption.
    row = M[i, :]
    sorted_indices = np.argsort(-row)  # best score first
    sorted_scores = row[sorted_indices]

    # Keep the top_k best scores plus anything tied with the top_k-th one;
    # with fewer than top_k candidates, keep them all.
    cutoff = sorted_scores[min(top_k, len(sorted_scores)) - 1]
    top_indices = sorted_indices[sorted_scores >= cutoff]

    plt.imshow(images[i])
    plt.show()
    for idx in top_indices:
        score = row[idx]
        if idx == i:
            # The caption generated for this very image.
            prefix = "CORRECT - "
        elif y[idx] == y[i]:
            # The caption of another image carrying the same label.
            prefix = "SIMILAR - "
        else:
            prefix = "INCORRECT - "
        print(f"{prefix}{score:.4f} : {prompt_captions[idx]}")

In [None]:
# Caption -> image retrieval: for each of the first `limit` captions of the
# selected prompt, print the caption and plot its best-matching images.
#
# The captions must come from the SAME prompt as the similarity matrix `M`.
# The global `captions` is whatever the caption-embedding loop processed
# last, which is a different prompt whenever more than one prompt exists.
prompt_captions = images_per_captions[prompt]
top_k = 5

for j in range(len(x))[:limit]:
    # Column j of M: similarity between caption j and every image.
    col = M[:, j]
    sorted_indices = np.argsort(-col)  # best score first
    sorted_scores = col[sorted_indices]

    # Keep the top_k best scores plus anything tied with the top_k-th one;
    # with fewer than top_k candidates, keep them all.
    cutoff = sorted_scores[min(top_k, len(sorted_scores)) - 1]
    top_indices = sorted_indices[sorted_scores >= cutoff]

    print(prompt_captions[j])
    num_images = len(top_indices)
    plt.figure(figsize=(15, 3))
    for pos, img_idx in enumerate(top_indices):
        plt.subplot(1, num_images, pos + 1)
        plt.imshow(images[img_idx])

        width, height = images[img_idx].size
        # Frame colour: green = the image this caption was generated for,
        # yellow = another image with the same label, red = unrelated image.
        if img_idx == j:
            edgecolor = 'green'
        elif y[img_idx] == y[j]:
            edgecolor = 'yellow'
        else:
            edgecolor = 'red'
        plt.gca().add_patch(plt.Rectangle((0, 0), width - 1, height - 1, edgecolor=edgecolor, linewidth=3, fill=False))
        plt.title(f"Sim: {col[img_idx]:.4f}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()