# Meme Retrieval Evaluation
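This notebook implements standard information-retrieval (IR) metrics — Mean Reciprocal Rank (MRR), Precision@K, Recall@K, and mean Average Precision (mAP) — for text-to-image meme retrieval. It also provides analysis and visualizations, including Precision-Recall curves and MRR breakdowns by meme category and query length. Relevance labels in the test sets are graded; only images labelled `2` for a query are counted as relevant.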

In [5]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import precision_recall_curve, average_precision_score

# Load the held-out judgment files. Column names are supplied explicitly, so the
# first line of each CSV is read as data (no header row is expected). Each row is
# (query, relevance label, image path); rows with any missing field are dropped.
df_test_meme, df_test_template = (
    pd.read_csv(csv_path, names=['query', 'label', 'directory']).dropna()
    for csv_path in ('test_meme.csv', 'test_template.csv')
)
df_test_meme.head()

Unnamed: 0,query,label,directory
0,student life memes,2.0,test_images/meme_submissions_1490490.png
1,final exam memes,2.0,test_images/sad-baby_92.png
2,data science memes,2.0,test_images/John_Daly_and_Tiger_Woods_1.png
3,machine learning memes,1.0,test_images/batman-and-superman_20.png
4,math major memes,2.0,test_images/big-book-small-book_21.png


In [6]:
# --- 1. IR Metrics ---
def mean_reciprocal_rank(rs):
    """Mean Reciprocal Rank over a collection of rankings.

    Args:
        rs: iterable of relevance lists, each in rank order (index 0 = top hit);
            any non-zero entry counts as relevant.

    Returns:
        Mean over rankings of 1 / (1-based rank of the first relevant entry),
        with 0 for rankings that contain no relevant entry. An empty ``rs``
        yields NaN (``np.mean`` of an empty list).
    """
    reciprocal_ranks = []
    for ranking in rs:
        hit_positions = np.flatnonzero(ranking)
        # First hit at 0-based position p contributes 1 / (p + 1).
        reciprocal_ranks.append(1 / (hit_positions[0] + 1) if hit_positions.size else 0)
    return np.mean(reciprocal_ranks)

def precision_at_k(r, k):
    """Precision@k: fraction of the top-k ranked items that are relevant.

    Args:
        r: relevance scores in rank order (index 0 = top-ranked item); any
            non-zero value counts as relevant, so graded labels are binarized.
        k: cutoff rank; must be >= 1.

    Returns:
        (number of relevant items among the first k) / k. Positions beyond the
        end of ``r`` count as non-relevant, so a ranking shorter than k is not
        rewarded and an empty ranking scores 0.0 instead of NaN.

    Raises:
        ValueError: if k < 1 (a negative k would otherwise slice from the end).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # Divide by k, not by len(r[:k]): that is the standard definition of P@k.
    return np.count_nonzero(np.asarray(r)[:k]) / k

def recall_at_k(r, k, total_relevant):
    """Recall@k: share of all relevant items that appear in the top-k ranks.

    Args:
        r: relevance scores in rank order (index 0 = top-ranked item); any
            non-zero value counts as relevant, so graded labels are binarized.
        k: cutoff rank; must be >= 1.
        total_relevant: number of relevant items for the query (denominator).

    Returns:
        (number of relevant items among the first k) / total_relevant, or 0
        when ``total_relevant`` is 0/None (recall is undefined then).

    Raises:
        ValueError: if k < 1 (a negative k would otherwise slice from the end).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if not total_relevant:
        return 0
    # Count hits rather than summing scores so graded labels cannot push recall above 1.
    hits = np.count_nonzero(np.asarray(r)[:k])
    return hits / total_relevant

def average_precision(r):
    """Average precision of a single ranking.

    Args:
        r: relevance scores in rank order (index 0 = top-ranked item); truthy
            entries count as relevant.

    Returns:
        Mean of precision@i over every 1-based rank i that holds a relevant
        item, or 0 if there is none. The normalizer is the number of relevant
        items present in ``r``, so this equals textbook AP only when ``r``
        contains every relevant item for the query.
    """
    ranking = np.asarray(r)
    precisions = []
    for position, relevance in enumerate(ranking):
        if relevance:
            precisions.append(precision_at_k(ranking, position + 1))
    if not precisions:
        return 0
    return np.mean(precisions)

def mean_average_precision(rs):
    """Mean Average Precision: ``average_precision`` averaged over rankings.

    Args:
        rs: iterable of relevance lists, one per query, each in rank order.

    Returns:
        Arithmetic mean of the per-ranking average precision (NaN when ``rs``
        is empty, as ``np.mean`` of an empty list).
    """
    per_query_ap = list(map(average_precision, rs))
    return np.mean(per_query_ap)

In [7]:
# --- 2. Build Relevance Judgments ---
def build_relevance_judgments(df):
    """Collect graded relevance judgments from a test DataFrame.

    Args:
        df: frame with ``query``, ``directory`` (image path) and ``label``
            columns.

    Returns:
        ``defaultdict(dict)`` mapping each query to ``{image_path: int(label)}``.
        If a (query, image path) pair occurs more than once, the last row wins.
    """
    judgments = defaultdict(dict)
    # Zip the columns directly instead of building a Series per row with iterrows.
    triples = zip(df['query'], df['directory'], df['label'])
    for query, image_path, label in triples:
        judgments[query][image_path] = int(label)
    return judgments

# Graded judgments {query: {image_path: label}} for the meme and template test splits.
rel_meme, rel_template = map(build_relevance_judgments, (df_test_meme, df_test_template))

In [10]:
# --- Connect to your model for retrieval ---
from frontend import main  # main(query) returns image path for a given query

def model_retrieve_text_to_image(query, candidates):
    """Rank ``candidates`` for a text query using the retrieval model.

    ``main(query)`` (from ``frontend``) is expected to return a single best
    image path. That path is moved to the front; every other candidate keeps
    its incoming order. If the model's answer is not among ``candidates``, a
    copy of ``candidates`` is returned as-is.

    NOTE(review): only the top-1 position reflects the model. Replace this with
    a real top-k / scored ranking if ``main`` can provide one.

    Args:
        query: text query passed to the model.
        candidates: sequence of candidate image paths.

    Returns:
        The ranked candidate paths.
    """
    best = main(query)
    if best not in candidates:
        return candidates[:]
    return [best, *(path for path in candidates if path != best)]

# --- Use these in your evaluation functions ---
def evaluate_text_to_image_with_model(rel_judgments, k=5, relevant_label=2):
    """Run the retrieval model on every judged query and print IR metrics.

    Every image path judged for *any* query forms one shared candidate pool.
    Each query's ranking (from ``model_retrieve_text_to_image``) is binarized:
    an image is relevant iff its label for that query equals ``relevant_label``;
    images not judged for that query count as non-relevant.

    NOTE(review): as currently wired, the model only chooses the top-1 image;
    ranks below it follow the (sorted) pool order, so mAP, P@K and R@K beyond
    rank 1 do not measure model quality.

    Args:
        rel_judgments: mapping ``{query: {image_path: label}}`` as built by
            ``build_relevance_judgments``.
        k: cutoff for Precision@K and Recall@K.
        relevant_label: label value treated as relevant (default 2, the
            previously hard-coded threshold).

    Returns:
        List of binary relevance lists, one per query, in ``rel_judgments``
        iteration order; each list covers the whole candidate pool.
    """
    queries = list(rel_judgments.keys())
    # Sort the pool. A set of str iterates in hash order, which changes between
    # interpreter runs (hash randomization), so the unsorted pool made every rank
    # after the model's pick, and therefore all metrics, non-reproducible.
    all_dirs = sorted({d for judged in rel_judgments.values() for d in judged})
    rs = []
    for q in queries:
        # Pass a copy so a retriever that reorders in place cannot disturb the pool.
        ranked = model_retrieve_text_to_image(q, list(all_dirs))
        judged = rel_judgments[q]
        rs.append([1 if judged.get(d, 0) == relevant_label else 0 for d in ranked])
    print("Text-to-Image (Model):")
    print("MRR:", mean_reciprocal_rank(rs))
    print("mAP:", mean_average_precision(rs))
    print("Precision@K:", np.mean([precision_at_k(r, k) for r in rs]))
    # Each ranking contains the whole pool, so its hit count is the query's total relevant.
    print("Recall@K:", np.mean([recall_at_k(r, k, sum(1 for v in r if v == 1)) for r in rs]))
    return rs

In [None]:
# Keep the per-query binary relevance lists for later analysis. Assigning them
# (instead of leaving the call as the cell's last expression) stops Jupyter from
# dumping every list, each as long as the whole candidate pool, into the output.
rs_meme = evaluate_text_to_image_with_model(rel_meme)

## Analysis & Visualization

In [None]:
# --- 5. Precision-Recall Curve ---
def plot_pr_curve_for_query(rel_judgments, query, ranked_dirs, ax=None):
    """Plot the precision-recall curve for a single query's ranking.

    Args:
        rel_judgments: mapping ``{query: {image_path: label}}``; an image is
            relevant iff its label is 2.
        query: the text query whose ranking is plotted.
        ranked_dirs: image paths in ranked order, best first.
        ax: optional matplotlib Axes to draw into. When omitted, a new figure
            is created and shown immediately.

    Returns:
        The Axes the curve was drawn on, or None (after a warning) when
        ``ranked_dirs`` holds no relevant image, since precision/recall are
        undefined without positives.
    """
    import warnings

    # .get() instead of [] so an unknown query is not silently inserted into a
    # defaultdict, which would add an empty query to later evaluations.
    judged = rel_judgments.get(query, {})
    y_true = [1 if judged.get(d, 0) == 2 else 0 for d in ranked_dirs]
    if not any(y_true):
        warnings.warn(f'No relevant image for "{query}" in ranked_dirs; skipping PR curve.')
        return None

    # The model yields a ranking, not scores: give the top-ranked image the
    # highest score so the threshold sweep follows rank order.
    # TODO: replace with real model scores when available.
    y_scores = list(reversed(range(len(ranked_dirs))))
    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    ap = average_precision_score(y_true, y_scores)

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots()
    ax.plot(recall, precision, marker='.')
    ax.set_title(f'PR Curve for "{query}" (AP={ap:.2f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    if owns_figure:
        plt.show()
    return ax