In [1]:
import os
import numpy as np
import faiss
from scipy.io import loadmat
import plotly.graph_objects as go



In [4]:
def load_templates(data_folder):
    """
    Load enrollment templates (excluding '_3') from .mat files.
    Returns: vectors (ndarray), labels (list)
    """
    files = sorted([f for f in os.listdir(data_folder) if f.endswith('.mat')])
    vectors, labels = [], []

    for file in files:
        if '_3' in file:  # Skip probe templates
            continue

        path = os.path.join(data_folder, file)
        data = loadmat(path)

        vec = data['template'].flatten().astype('float32')
        pid, side = file.split('_')[0], file.split('_')[1]
        labels.append(f"{pid}_{side}")
        vectors.append(vec)

    return np.vstack(vectors), labels


def create_hnsw_index(vectors, M=32, ef_construction=100, ef_search=50):
    """
    Create an HNSW index with FAISS.
    """
    d = vectors.shape[1]
    index = faiss.IndexHNSWFlat(d, M)
    index.hnsw.efConstruction = ef_construction
    index.add(vectors)
    index.hnsw.efSearch = ef_search
    return index


def evaluate_accuracy(index, labels, data_folder, k):
    """
    Evaluate top-k accuracy using probe templates ('_3').
    Returns: accuracy (float), total_queries (int)
    """
    query_files = sorted([f for f in os.listdir(data_folder) if '_3' in f and f.endswith('.mat')])
    correct, total = 0, 0

    for qf in query_files:
        qpid = qf.split('_')[0]
        qpath = os.path.join(data_folder, qf)

        qvec = loadmat(qpath)['template'].flatten().astype('float32').reshape(1, -1)
        distances, indices = index.search(qvec, k)

        predicted_pids = [labels[i].split('_')[0] for i in indices[0]]
        if qpid in predicted_pids:
            correct += 1
        total += 1

    return (correct / total if total > 0 else 0), total


def plot_hit_rate_plotly(data_folder, k_max=50):
    """
    Plot interactive Hit Rate vs k using Plotly.
    """
    vectors, labels = load_templates(data_folder)
    index = create_hnsw_index(vectors)

    k_values = list(range(1, k_max + 1))
    hit_rates = []
    penetration_rates = []

    total_queries = None
    for k in k_values:
        acc, total = evaluate_accuracy(index, labels, data_folder, k)
        hit_rates.append(acc * 100)  # in %
        if total_queries is None:
            total_queries = total
        penetration_rates.append((k / total_queries) * 100)  # in %

    
    with open(csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["k", "hit_rate_percent", "penetration_rate_percent"])
        for k_val, hr, pr in zip(k_values, hit_rates, penetration_rates):
            writer.writerow([k_val, hr, pr])
    print(f"Results saved to {csv_path}")


    # Plot with Plotly
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=k_values, y=hit_rates,
        mode='lines+markers',
        name='Hit Rate (%)',
        line=dict(color='blue', width=2)
    ))

    fig.add_trace(go.Scatter(
        x=k_values, y=penetration_rates,
        mode='lines+markers',
        name='Penetration Rate (%)',
        line=dict(color='red', dash='dash')
    ))

    fig.update_layout(
        title="Hit Rate and Penetration Rate vs k",
        xaxis_title="k (Top-k)",
        yaxis_title="Percentage (%)",
        legend=dict(x=0.7, y=0.1),
        template="plotly_white"
    )

    fig.show()
    return k_values, hit_rates, penetration_rates

In [5]:
data_folder = "templates/CASIA1/features"
plot_hit_rate_plotly(data_folder, k_max=54)


NameError: name 'csv_path' is not defined