In [None]:
# imports
import os
import numpy as np
import json
from matplotlib import pyplot as plt
from PIL import Image
%load_ext autoreload
%autoreload 2
from clip_model import CLIPModel
from keyword_extractor import KeyWordExtractor
from pipeline import Pipeline

In [None]:
# set up pipeline
# Device string handed to both the CLIP model and the pipeline.
# NOTE(review): 'mps' presumably selects Apple's Metal backend — confirm; other
# machines will likely need a different value here.
device = 'mps'
# Path of the image-embedding database passed to the pipeline.
embedding_database_path = 'tests/flickr8k/img_embeddings.h5'
# `clipmodel` is also read as a global by benchmark_semantic further down.
clipmodel = CLIPModel(device)
# NOTE(review): 'kw-extactor-model' looks like a typo of 'kw-extractor-model' —
# check the actual name on disk before changing the string.
keywordextractor = KeyWordExtractor('kw-extractor-tokenizer', 'kw-extactor-model')
# `pipeline` is read as a global by the example query and by both benchmarks.
pipeline = Pipeline(clipmodel,keywordextractor, embedding_database_path, device)

In [None]:
# get some matches
# Run one example text query through the pipeline. The result is sliced and its
# items are opened as image files when plotted, so it is presumably a ranked
# list of image paths — TODO confirm against Pipeline.run.
# NOTE(review): the meaning of the second positional argument (True) is not
# visible in this notebook.
matches = pipeline.run("dogs playing in a park", True)

In [None]:
def plot_matches(selected_imgs, cols=4):
    """Display the given image files in a grid, each titled with its file name.

    Args:
        selected_imgs: Sequence of image file paths to show.
        cols: Number of grid columns. Defaults to 4, the previously fixed layout.

    Returns:
        None. The figure is rendered through ``plt.show()``.

    Raises:
        ValueError: If ``cols`` is smaller than 1.
    """
    if cols < 1:
        raise ValueError("cols must be at least 1")

    n_images = len(selected_imgs)
    if n_images == 0:
        # Zero images would mean zero rows and therefore a zero-height figure.
        print("plot_matches: no images to display")
        return

    rows = (n_images + cols - 1) // cols  # Round up division

    plt.figure(figsize=(cols * 4, rows * 4))

    for i, path in enumerate(selected_imgs):
        # convert() returns a new in-memory image, so the underlying file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img)
        plt.title(os.path.basename(path))  # Show filename as title
        plt.axis("off")

    plt.tight_layout()
    plt.show()
# Show the first five results of the example query in a grid.
plot_matches(matches[:5])

In [None]:
# recall@k benchmark
def benchmark_rk(captions_file, image_dir='/Users/benjaminkasper/Documents/Uni/RCI/Module/in2106/dataset/Flicker8k_Dataset/'):
    """Print recall@1, recall@5 and recall@10 of the retrieval pipeline.

    For every image in the captions file, the image's first caption is used as
    the query. A hit at k is counted when the image's own path appears among the
    first k results returned by the module-level ``pipeline``.

    Args:
        captions_file: Path to a JSON file mapping image file name to a list of
            captions.
        image_dir: Directory that is joined with each image file name to build
            the path compared against the results. Defaults to the dataset
            location this notebook was written against.

    Returns:
        None. Running counts are printed per image and the three recall values
        are printed at the end.

    Note:
        Hits are detected with ``in`` on the result list, so the pipeline must
        return paths that compare equal to ``os.path.join(image_dir, name)``.
        Presumably these are plain strings — verify against Pipeline.run.
    """
    with open(captions_file, 'r', encoding='utf-8') as file:
        captions_dict = json.load(file)

    n_images = len(captions_dict)
    if n_images == 0:
        # Without this guard the final averages would divide by zero.
        print("benchmark_rk: no images found in", captions_file)
        return

    r1, r5, r10 = 0, 0, 0

    for i, (img_name, captions) in enumerate(captions_dict.items()):
        caption = captions[0]  # get the first caption of the image

        matches = pipeline.run(caption, True)[:10]

        img_path = os.path.join(image_dir, img_name)

        # Each membership test yields a bool, which adds to the counter as 0 or 1.
        r1 += img_path in matches[:1]   # R @ 1
        r5 += img_path in matches[:5]   # R @ 5
        r10 += img_path in matches      # R @ 10 (already cut to ten above)

        print(i, r1, r5, r10)

    print(r1/n_images, r5/n_images, r10/n_images)

In [None]:
# semantic benchmark
def benchmark_semantic(captions_file, image_dir='/Users/benjaminkasper/Documents/Uni/RCI/Module/in2106/dataset/Flicker8k_Dataset/'):
    """Print the mean embedding similarity between top match and queried image.

    For every image in the captions file, the image's first caption is used as
    the query for the module-level ``pipeline``. The top result and the original
    image are both embedded with the module-level ``clipmodel``, and the inner
    product of the two embeddings is accumulated.

    Args:
        captions_file: Path to a JSON file mapping image file name to a list of
            captions.
        image_dir: Directory that is joined with each image file name to locate
            the original image. Defaults to the dataset location this notebook
            was written against.

    Returns:
        None. The running sum is printed per image and the mean similarity is
        printed at the end.

    Note:
        The ``[0, 0]`` index assumes ``get_img_embedding`` returns a 2-D array
        with a single row — TODO confirm. The inner product is only a cosine
        similarity if those embeddings are unit-normalised, which is not
        visible from this notebook.
    """
    with open(captions_file, 'r', encoding='utf-8') as file:
        captions_dict = json.load(file)

    n_images = len(captions_dict)
    if n_images == 0:
        # Without this guard the final mean would divide by zero.
        print("benchmark_semantic: no images found in", captions_file)
        return

    sum_similarity = 0

    for i, (img_name, captions) in enumerate(captions_dict.items()):  # iterate through all images in the dataset
        caption = captions[0]  # get the first caption of the image

        top_match = pipeline.run(caption, True)[:1]  # keep only the top match of the query

        img_path = os.path.join(image_dir, img_name)

        matched_image_emb = clipmodel.get_img_embedding(top_match[0])  # embedding of the top matched image
        queried_image_emb = clipmodel.get_img_embedding(img_path)  # embedding of the original image

        # Inner product of the two embeddings; highest when the retrieved image
        # is the original image itself.
        similarity = np.inner(matched_image_emb, queried_image_emb)

        sum_similarity += similarity[0, 0]

        print(i, sum_similarity)

    # Divide by the number of images; the previous code divided by the list
    # itself, which raised a TypeError after the whole loop had already run.
    print(sum_similarity/n_images)

In [None]:
# Run the semantic benchmark on the flickr8k captions file. This issues one
# pipeline query and two image embeddings per image in that file.
benchmark_semantic('tests/flickr8k/flickr8k_captions.json')