# Retrieve Text

In [1]:
import numpy as np
import faiss
import pandas as pd

# Directory of the parquet dataset to evaluate. Swap in the commented-out
# path below to point the notebook at the alternative dataset directory.
data_path = "/home/horton/local-repo/demo-run-clip/data-windows/stage03_image_caption/clip-roberta"
# data_path = "/home/horton/local-repo/demo-run-clip/data-windows/stage03_image_caption/sclip-finetuned"

# Read the dataset with the pyarrow backend and keep only its first 100 rows.
df = pd.read_parquet(
    path=data_path,
    engine="pyarrow",
)[:100]

In [2]:
# Inspect which columns the loaded dataset provides.
df.columns

Index(['document_id', 'caption', 'image_path', 'image_type', 'first_level_dir',
       'second_level_dir', 'fit_context', 'image_file_exist', 'original_index',
       'encoded_caption', 'encoded_image', 'load_status'],
      dtype='object')

In [3]:
# Create Indexes
# Build an exact (brute-force) L2 nearest-neighbour index over the caption
# embeddings. FAISS only accepts C-contiguous float32 matrices, so coerce
# explicitly instead of relying on how the parquet column happens to be stored.
vectors = np.ascontiguousarray(list(df["encoded_caption"]), dtype=np.float32)
# Take the dimensionality from the data rather than hard-coding one model's width.
vector_dimension = vectors.shape[1]
# A bare flat index reports neighbours as row positions (0..n-1 in insertion
# order). The evaluation below compares search results with `original_index`,
# so store the vectors under those ids; the two id spaces then agree even if
# `original_index` is not simply the row position.
flat_index = faiss.IndexFlatL2(vector_dimension)
index = faiss.IndexIDMap(flat_index)
# Unit-normalise in place: on unit vectors, L2 ranking equals cosine ranking.
faiss.normalize_L2(vectors)
index.add_with_ids(vectors, df["original_index"].to_numpy(dtype=np.int64))

def retrieve_text(encoded_image, k=4):
    """Return the ids of the ``k`` indexed captions nearest to an image embedding.

    Args:
        encoded_image: 1-D image embedding (array or sequence of floats). Its
            length has to match the dimensionality of the module-level
            ``index``.
        k: Number of nearest neighbours to request from the index.

    Returns:
        set: The ids reported by ``index.search`` for this query. Ranking
        order is not preserved. NOTE(review): FAISS pads the result with -1
        when fewer than ``k`` vectors are indexed, so -1 may appear in the
        set -- confirm callers never request more than the index holds.
    """
    # Build a fresh (1, d) float32 matrix: FAISS rejects other dtypes, and the
    # copy keeps the in-place normalisation from touching the caller's data.
    search_key = np.array([encoded_image], dtype=np.float32)
    faiss.normalize_L2(search_key)
    # Only the neighbour ids are needed; the distances are discarded.
    _, ann = index.search(search_key, k=k)
    return set(ann[0])

from sklearn.metrics import accuracy_score

def evalSearch(df, k, samples):
    """Print and return the top-``k`` retrieval accuracy over the first rows of ``df``.

    For each of the first ``samples`` rows, the row's ``encoded_image`` is
    used as a query to ``retrieve_text``; the row counts as a hit when its
    ``original_index`` is among the returned ids. This assumes the ids
    returned by ``retrieve_text`` are directly comparable with
    ``original_index`` values.

    Args:
        df: DataFrame with ``original_index`` and ``encoded_image`` columns.
        k: Number of neighbours retrieved per query.
        samples: Number of leading rows of ``df`` to evaluate.

    Returns:
        float: Fraction of evaluated rows that were hits (``nan`` when no
        rows were evaluated). The same value is printed next to ``k``.
    """
    sample_set = df[:samples]

    # Count hits directly. This is the same number accuracy_score produced
    # from the fabricated y_pred, without picking an arbitrary set element as
    # the "wrong" prediction (which also failed on an empty result set).
    hits = sum(
        label in retrieve_text(image_vector, k)
        for label, image_vector in zip(
            sample_set["original_index"], sample_set["encoded_image"]
        )
    )
    total = len(sample_set)
    accuracy = hits / total if total else float("nan")

    print("{:12d}| {:12.4f}".format(k, accuracy))
    return accuracy

# Sweep k over the even values 4..98, evaluating the first 100 rows each time.
for top_k in range(4, 100, 2):
    evalSearch(df, k=top_k, samples=100)

           4|       0.0200
           6|       0.0300
           8|       0.0500
          10|       0.0700
          12|       0.0900
          14|       0.1000
          16|       0.1200
          18|       0.1300
          20|       0.1600
          22|       0.1600
          24|       0.1700
          26|       0.1700
          28|       0.1800
          30|       0.2200
          32|       0.2500
          34|       0.2600
          36|       0.2700
          38|       0.3200
          40|       0.3500
          42|       0.3700
          44|       0.4100
          46|       0.4600
          48|       0.4600
          50|       0.5000
          52|       0.5200
          54|       0.5300
          56|       0.5800
          58|       0.5900
          60|       0.6100
          62|       0.6500
          64|       0.6800
          66|       0.7200
          68|       0.7300
          70|       0.7400
          72|       0.7600
          74|       0.7900
          76|       0.8000
 