In [None]:
import sys

# Put the parent directory at the front of the module search path (once),
# so the local imports of the following cell can be resolved from there.
path = ".."
if sys.path.count(path) == 0:
    sys.path.insert(0, path)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from data_retrieval import lipade_groundtruth
from absolute_path import absolutePath
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
import clip

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# CLIP variant used for both the image and the text encoders.
modelName = 'ViT-L/14'
model, preprocess = clip.load(modelName, device=device)

# Location of the generated CSV for the corpus under evaluation.
corpus = "lipade_groundtruth"
path = "".join((absolutePath, 'data_generation/generated/', corpus, '.csv'))

In [None]:
# getDataset returns three values; only the first two are used here.
x, m, _ = lipade_groundtruth.getDataset(mode="unique")

# m[1] is indexed further down with (file, prompt) tuples and yields a caption.
captions = m[1]

# m[2] selects/reorders the entries of x
# (presumably an index array or boolean mask - TODO confirm against getDataset).
x = np.array(x)[m[2]]

# Caption retrieval

In [None]:
def _caption_key(image_path):
    """Return the last two '/'-separated components of ``image_path``.

    This is the file part of the ``(file, prompt)`` tuples that index ``captions``.
    """
    return '/'.join(image_path.split('/')[-2:])


# One (initially empty) caption list per distinct prompt, in first-seen order.
images_per_captions = {prompt: [] for _file, prompt in captions.keys()}

# np.asarray(...).tolist() accepts both the ndarray built by the loading cell
# and a plain list, so this cell can be executed again on its own output.
all_image_paths = np.asarray(x).tolist()
caption_keys = [_caption_key(image_path) for image_path in all_image_paths]

# Keep an image only when it has a caption for *every* prompt. Checking a
# single prompt is not enough: the lookup loop below indexes `captions` with
# each prompt and would raise KeyError on any missing (file, prompt) pair.
is_fully_captioned = [
    all((key, prompt) in captions for prompt in images_per_captions)
    for key in caption_keys
]

# `images` holds the paths that are dropped from `x`.
images = [
    image_path
    for image_path, keep in zip(all_image_paths, is_fully_captioned)
    if not keep
]
x = [
    image_path
    for image_path, keep in zip(all_image_paths, is_fully_captioned)
    if keep
]

# images_per_captions[prompt][i] is the caption of x[i] for that prompt.
for prompt, prompt_captions in images_per_captions.items():
    prompt_captions.extend(
        captions[(key, prompt)]
        for key, keep in zip(caption_keys, is_fully_captioned)
        if keep
    )

# Caption evaluation

## Image embeddings

In [None]:
# Encode every retained image with the CLIP image encoder, one image per
# forward pass. Each appended item is the encoder output for a batch of one.
image_embeddings = []
for image_path in tqdm(x):
    # The context manager closes the underlying file deterministically;
    # convert() returns an independent in-memory copy that stays usable.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')
    # Avoid the name `input`: at notebook scope it would hide the built-in.
    preprocessed = preprocess(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model.encode_image(preprocessed)
    image_embeddings.append(output.cpu().numpy())

In [None]:
# Collapse the per-image outputs into a 2-D matrix with one row per image.
image_embeddings = np.reshape(np.array(image_embeddings), (len(x), -1))
image_embeddings.shape

## Caption embeddings

In [None]:
# For every prompt, encode the caption of each image with the CLIP text
# encoder. The result is a list of (prompt, matrix) pairs where row i of the
# matrix is the embedding of the caption of x[i].
caption_embeddings_per_prompt = []
for key in images_per_captions.keys():
    caption_embeddings = []
    # NOTE: `captions` is rebound from the (file, prompt) mapping to this
    # prompt's caption list and stays bound to the last prompt's list after
    # the loop. The binding is kept because later cells may read it; as a
    # consequence the retrieval cell cannot be re-run without reloading the
    # dataset first.
    captions = images_per_captions[key]
    for i in tqdm(range(len(x))):
        # Avoid the name `input`: at notebook scope it would hide the built-in.
        tokens = clip.tokenize([captions[i]]).to(device)
        with torch.no_grad():
            output = model.encode_text(tokens)
        caption_embeddings.append(output.cpu().numpy())
    caption_embeddings_per_prompt.append((key, np.array(caption_embeddings).reshape(len(x), -1)))

In [None]:
# Quick check of the embedding matrix built for the first prompt.
first_prompt_name, first_prompt_matrix = caption_embeddings_per_prompt[0]
first_prompt_matrix.shape

## Similarity

In [None]:
# One (prompt, matrix) pair per prompt; matrix[i, j] is the cosine similarity
# between the embedding of image i and the embedding of caption j.
similarity_per_prompt = [
    (prompt_name, cosine_similarity(image_embeddings, prompt_matrix))
    for prompt_name, prompt_matrix in caption_embeddings_per_prompt
]

## Visualisation

In [None]:
# Maximum number of queries (images or captions) displayed by the cells below.
limit = 50

# Index of the prompt whose similarity matrix is visualised.
prompt_i = 0
selected_entry = similarity_per_prompt[prompt_i]
prompt, M = selected_entry

In [None]:
# For each of the first `limit` images, show the image followed by its best
# matching captions (the five highest scores, plus any tie with the fifth).
# The captions must belong to the same prompt as the similarity matrix M, so
# they are looked up through `prompt` rather than through whichever caption
# list an earlier loop happened to leave behind.
prompt_captions = images_per_captions[prompt]

for i in range(len(x))[:limit]:
    row = M[i, :]
    sorted_indices = np.argsort(-row)  # best match first
    sorted_scores = row[sorted_indices]

    # Score of the 5th best caption, or of the worst one if there are fewer.
    cutoff = sorted_scores[min(5, len(sorted_scores)) - 1]
    top_indices = sorted_indices[sorted_scores >= cutoff]

    with Image.open(x[i]) as raw_image:
        plt.imshow(raw_image.convert('RGB'))
    plt.show()

    for idx in top_indices:
        score = row[idx]
        # Caption idx was generated for image idx, so idx == i is the true pair.
        prefix = "CORRECT - " if idx == i else ""
        print(f"{prefix}{score:.4f} : {prompt_captions[idx]}")

In [None]:
# For each of the first `limit` captions, print the caption and show its best
# matching images (the five highest scores, plus any tie with the fifth).
# The image the caption was generated for is framed in green, others in red.
# Captions are looked up through `prompt` so that they match the matrix M.
prompt_captions = images_per_captions[prompt]

for j in range(len(x))[:limit]:
    col = M[:, j]
    sorted_indices = np.argsort(-col)  # best match first
    sorted_scores = col[sorted_indices]

    # Score of the 5th best image, or of the worst one if there are fewer.
    cutoff = sorted_scores[min(5, len(sorted_scores)) - 1]
    top_indices = sorted_indices[sorted_scores >= cutoff]

    print(prompt_captions[j])
    num_images = len(top_indices)
    plt.figure(figsize=(15, 3))
    for pos, img_idx in enumerate(top_indices):
        plt.subplot(1, num_images, pos + 1)
        with Image.open(x[img_idx]) as raw_image:
            img = raw_image.convert('RGB')
        plt.imshow(img)

        width, height = img.size[0], img.size[1]
        frame_colour = 'green' if img_idx == j else 'red'
        plt.gca().add_patch(plt.Rectangle((0, 0), width - 1, height - 1, edgecolor=frame_colour, linewidth=3, fill=False))
        plt.title(f"Sim: {col[img_idx]:.4f}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Evaluation
Position moyenne relative

In [None]:
# Mean relative position of the true match, per prompt.
# Printed as a pair: [image -> caption, caption -> image]. A value of 1 means
# the true match is always ranked first.
for prompt, M in similarity_per_prompt:
    positions = []

    # Image -> caption: rank of caption i among all captions for image i.
    for i in range(len(x)):
        sorted_indices = np.argsort(-M[i, :])
        positions.append([np.argwhere(sorted_indices == i)[0][0]])

    # Caption -> image: rank of image j among all images for caption j.
    # The lookup must use j; using the index left over from the loop above
    # would search for the last image in every column.
    for j in range(len(x)):
        sorted_indices = np.argsort(-M[:, j])
        positions[j].append(np.argwhere(sorted_indices == j)[0][0])

    positions = np.array(positions)

    print(prompt, ':', 1 - positions.mean(axis=0) / len(M))


Score strict k : "% d'images d'origine classées dans les k premiers %"

In [None]:
# Strict score at k %: share of queries whose true match is ranked within the
# first k % of the candidates. Printed per prompt as a pair:
# [image -> caption, caption -> image].

# The ranks do not depend on k, so they are computed once per prompt.
ranks_per_prompt = []
for prompt, M in similarity_per_prompt:
    ranks = []
    for i in range(len(x)):
        # Rank of caption i for image i (row), and of image i for caption i
        # (column). Both lookups search for the query's own index.
        caption_rank = np.argwhere(np.argsort(-M[i, :]) == i)[0][0]
        image_rank = np.argwhere(np.argsort(-M[:, i]) == i)[0][0]
        ranks.append([caption_rank, image_rank])
    ranks_per_prompt.append((prompt, len(M), np.array(ranks)))

for k in [1, 5, 10, 25, 50]:
    print(k)
    for prompt, N, ranks in ranks_per_prompt:
        # NOTE: k * N // 100 is 0 when k * N < 100, in which case no rank can
        # pass the test and the score is 0 for that k.
        scores = (ranks < k * N // 100).astype(int)

        print(prompt, ':', scores.mean(axis=0))