In [6]:
import os
import numpy as np
import faiss
from scipy.io import loadmat
import plotly.graph_objects as go
import csv



In [22]:
###HNSW_FLAT

In [48]:
def load_templates(data_folder):
    """
    Load enrollment templates (excluding '_3') from .mat files.
    Returns: vectors (ndarray), labels (list)
    """
    files = sorted([f for f in os.listdir(data_folder) if f.endswith('.mat')])
    vectors, labels = [], []

    for file in files:
        if '_3' in file:  # Skip probe templates
            continue

        path = os.path.join(data_folder, file)
        data = loadmat(path)

        vec = data['template'].flatten().astype('float32')
        pid, side = file.split('_')[0], file.split('_')[1]
        labels.append(f"{pid}_{side}")
        vectors.append(vec)

    return np.vstack(vectors), labels


def create_hnsw_index(vectors, M=12, ef_construction=30, ef_search=15):
    """
    Create an HNSW index with FAISS.
    """
    d = vectors.shape[1]
    index = faiss.IndexHNSWFlat(d, M)
    index.hnsw.efConstruction = ef_construction
    index.add(vectors)
    index.hnsw.efSearch = ef_search
    return index


def evaluate_accuracy(index, labels, data_folder, k):
    """
    Evaluate top-k accuracy using probe templates ('_3').
    Returns: accuracy (float), total_queries (int)
    """
    query_files = sorted([f for f in os.listdir(data_folder) if '_3' in f and f.endswith('.mat')])
    correct, total = 0, 0

    for qf in query_files:
        qpid = qf.split('_')[0]
        qpath = os.path.join(data_folder, qf)

        qvec = loadmat(qpath)['template'].flatten().astype('float32').reshape(1, -1)
        distances, indices = index.search(qvec, k)

        predicted_pids = [labels[i].split('_')[0] for i in indices[0]]
        if qpid in predicted_pids:
            correct += 1
        total += 1

    return (correct / total if total > 0 else 0)


def plot_hit_rate_plotly(data_folder, k_max=50):
    vectors, labels = load_templates(data_folder)
    index = create_hnsw_index(vectors)

    k_values = list(range(1, k_max + 1))
    hit_rates = []

    for k in k_values:
        acc = evaluate_accuracy(index, labels, data_folder, k)
        hit_rates.append(acc * 100)  # in %

    # ---- Save results to CSV ----
    with open("/home/nishkal/sg/iris_indexing/results/CASIA1_hnsw_2_results.csv", mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["k", "hit_rate_percent"])
        for k_val, hr in zip(k_values, hit_rates):
            writer.writerow([k_val, hr])
    # print(f"Results saved to {csv_path}")

    # ---- Print Y-range ----
    # print(f"Hit Rate Range: {min(hit_rates):.2f}% to {max(hit_rates):.2f}%")

    # ---- Plot interactive graph ----
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=k_values, y=hit_rates,
        mode='lines+markers',
        name='Hit Rate (%)',
        line=dict(color='blue', width=2)
    ))
    fig.update_layout(
        title="Hit Rate vs k",
        xaxis_title="k (Top-k)",
        yaxis_title="Hit Rate (%)",
        template="plotly_white"
    )
    fig.show()
    return k_values, hit_rates

In [49]:
data_folder = "templates/CASIA1/features"
plot_hit_rate_plotly(data_folder, k_max=108)


([1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108],
 [90.74074074074075,
  92.5925925925926,
  92.5925925925926,
  93.51851851851852,
  93.51851851851852,
  94.44444444444444,
  94.44444444444444,
  94.44444444444444,
  94.44444444444444,
  94.44444444444444,
  95.37037037037037,
  95.37037037037037,
  95.37037037037037,
  96.29629629629629,
  96.29629629629629,
  96.29629629629629,
  96.296296296296

In [15]:
def get_non_matching_queries(data_folder, k_max=50):
    """
    Returns a list of person IDs whose queries never match
    in any top-k (from 1 to k_max).
    """
    # Load enrollment templates
    vectors, labels = load_templates(data_folder)
    index = create_hnsw_index(vectors)

    # Get all probe files ('_3')
    query_files = sorted([f for f in os.listdir(data_folder) if '_3' in f and f.endswith('.mat')])
    never_matched_ids = []

    for qf in query_files:
        qpid = qf.split('_')[0]  # person ID
        qpath = os.path.join(data_folder, qf)
        qvec = loadmat(qpath)['template'].flatten().astype('float32').reshape(1, -1)

        matched = False
        for k in range(1, k_max + 1):
            distances, indices = index.search(qvec, k)
            predicted_pids = [labels[i].split('_')[0] for i in indices[0]]
            if qpid in predicted_pids:
                matched = True
                break  # stop checking higher k values

        if not matched:
            never_matched_ids.append(qpid)

    return never_matched_ids


In [20]:
data_folder = "templates/CASIA1/features"
k_max = 10

non_matching_ids = get_non_matching_queries(data_folder, k_max=k_max)
print(f"Number of queries that never matched: {len(non_matching_ids)}")
print("Person IDs:", non_matching_ids)

Number of queries that never matched: 4
Person IDs: ['028', '038', '042', '063']


In [None]:
### IVF

In [33]:
def load_templates(data_folder):
    """
    Load enrollment templates (excluding '_3') from .mat files.
    Returns: vectors (ndarray), labels (list)
    """
    files = sorted([f for f in os.listdir(data_folder) if f.endswith('.mat')])
    vectors, labels = [], []

    for file in files:
        if '_3' in file:  # Skip probe templates
            continue

        path = os.path.join(data_folder, file)
        data = loadmat(path)

        vec = data['template'].flatten().astype('float32')
        pid, side = file.split('_')[0], file.split('_')[1]
        labels.append(f"{pid}_{side}")
        vectors.append(vec)

    return np.vstack(vectors), labels


def create_ivf_index(vectors, nlist=5):
    """
    Create a standard IVF index with L2 metric.
    nlist: number of clusters (set low for small datasets, e.g., 10)
    """
    d = vectors.shape[1]
    quantizer = faiss.IndexFlatL2(d)  # quantizer for clustering
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
    index.train(vectors)  # IVF requires training
    index.add(vectors)
    return index


def evaluate_accuracy(index, labels, data_folder, k):
    """
    Evaluate top-k accuracy using probe templates ('_3').
    Returns: accuracy (float), total_queries (int)
    """
    query_files = sorted([f for f in os.listdir(data_folder) if '_3' in f and f.endswith('.mat')])
    correct, total = 0, 0

    for qf in query_files:
        qpid = qf.split('_')[0]
        qpath = os.path.join(data_folder, qf)

        qvec = loadmat(qpath)['template'].flatten().astype('float32').reshape(1, -1)
        distances, indices = index.search(qvec, k)

        predicted_pids = [labels[i].split('_')[0] for i in indices[0]]
        if qpid in predicted_pids:
            correct += 1
        total += 1

    return (correct / total if total > 0 else 0)


def plot_hit_rate_plotly(data_folder, k_max=50):
    vectors, labels = load_templates(data_folder)
    index = create_ivf_index(vectors)

    k_values = list(range(1, k_max + 1))
    hit_rates = []

    for k in k_values:
        acc = evaluate_accuracy(index, labels, data_folder, k)
        hit_rates.append(acc * 100)  # in %

    # ---- Save results to CSV ----
    with open("/home/nishkal/sg/iris_indexing/results/CASIA1_IVF_results.csv", mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["k", "hit_rate_percent"])
        for k_val, hr in zip(k_values, hit_rates):
            writer.writerow([k_val, hr])
    # print(f"Results saved to {csv_path}")

    # ---- Print Y-range ----
    # print(f"Hit Rate Range: {min(hit_rates):.2f}% to {max(hit_rates):.2f}%")

    # ---- Plot interactive graph ----
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=k_values, y=hit_rates,
        mode='lines+markers',
        name='Hit Rate (%)',
        line=dict(color='blue', width=2)
    ))
    fig.update_layout(
        title="Hit Rate vs k",
        xaxis_title="k (Top-k)",
        yaxis_title="Hit Rate (%)",
        template="plotly_white"
    )
    fig.show()
    return k_values, hit_rates

In [38]:
plot_hit_rate_plotly(data_folder, k_max=108)

([1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108],
 [25.0,
  35.18518518518518,
  41.66666666666667,
  48.148148148148145,
  48.148148148148145,
  52.77777777777778,
  55.55555555555556,
  61.111111111111114,
  62.96296296296296,
  62.96296296296296,
  62.96296296296296,
  64.81481481481481,
  66.66666666666666,
  69.44444444444444,
  70.37037037037037,
  71.29629629629629,
  71.29629629629629,
  72

In [None]:
### LSH

In [39]:


def load_templates(data_folder):
    """
    Load enrollment templates (excluding '_3') from .mat files.
    Returns: vectors (ndarray), labels (list)
    """
    files = sorted([f for f in os.listdir(data_folder) if f.endswith('.mat')])
    vectors, labels = [], []

    for file in files:
        if '_3' in file:  # Skip probe templates
            continue

        path = os.path.join(data_folder, file)
        data = loadmat(path)

        vec = data['template'].flatten().astype('float32')
        pid, side = file.split('_')[0], file.split('_')[1]
        labels.append(f"{pid}_{side}")
        vectors.append(vec)

    return np.vstack(vectors), labels


def create_lsh_index(vectors, nbits=128):
    """
    Create a standard LSH index (binary hashing).
    nbits: number of bits for hashing.
    """
    d = vectors.shape[1]
    index = faiss.IndexLSH(d, nbits)
    index.add(vectors)
    return index


def evaluate_accuracy(index, labels, data_folder, k):
    """
    Evaluate top-k accuracy using probe templates ('_3').
    Returns: accuracy (float), total_queries (int)
    """
    query_files = sorted([f for f in os.listdir(data_folder) if '_3' in f and f.endswith('.mat')])
    correct, total = 0, 0

    for qf in query_files:
        qpid = qf.split('_')[0]
        qpath = os.path.join(data_folder, qf)

        qvec = loadmat(qpath)['template'].flatten().astype('float32').reshape(1, -1)
        distances, indices = index.search(qvec, k)

        predicted_pids = [labels[i].split('_')[0] for i in indices[0]]
        if qpid in predicted_pids:
            correct += 1
        total += 1

    return (correct / total if total > 0 else 0)


def plot_hit_rate_plotly(data_folder, k_max=50):
    vectors, labels = load_templates(data_folder)
    index = create_lsh_index(vectors)

    k_values = list(range(1, k_max + 1))
    hit_rates = []

    for k in k_values:
        acc = evaluate_accuracy(index, labels, data_folder, k)
        hit_rates.append(acc * 100)  # in %

    # ---- Save results to CSV ----
    with open("/home/nishkal/sg/iris_indexing/results/CASIA1_LSH_results.csv", mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["k", "hit_rate_percent"])
        for k_val, hr in zip(k_values, hit_rates):
            writer.writerow([k_val, hr])
    # print(f"Results saved to {csv_path}")

    

    # ---- Plot interactive graph ----
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=k_values, y=hit_rates,
        mode='lines+markers',
        name='Hit Rate (%)',
        line=dict(color='blue', width=2)
    ))
    fig.update_layout(
        title="Hit Rate vs k",
        xaxis_title="k (Top-k)",
        yaxis_title="Hit Rate (%)",
        template="plotly_white"
    )
    fig.show()
    return k_values, hit_rates

In [40]:
plot_hit_rate_plotly(data_folder, k_max=108)

([1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108],
 [25.0,
  35.18518518518518,
  41.66666666666667,
  48.148148148148145,
  48.148148148148145,
  52.77777777777778,
  55.55555555555556,
  61.111111111111114,
  62.96296296296296,
  62.96296296296296,
  62.96296296296296,
  64.81481481481481,
  66.66666666666666,
  69.44444444444444,
  70.37037037037037,
  71.29629629629629,
  71.29629629629629,
  72

In [50]:


def load_templates(data_folder):
    """
    Load enrollment templates (excluding '_3') from .mat files.
    Returns: vectors (ndarray), labels (list)
    """
    files = sorted([f for f in os.listdir(data_folder) if f.endswith('.mat')])
    vectors, labels = [], []

    for file in files:
        if '_3' in file:  # Skip probe templates
            continue

        path = os.path.join(data_folder, file)
        data = loadmat(path)

        vec = data['template'].flatten().astype('float32')
        pid, side = file.split('_')[0], file.split('_')[1]
        labels.append(f"{pid}_{side}")
        vectors.append(vec)

    return np.vstack(vectors), labels


def create_flat_index(vectors, use_cosine=False):
    """
    Create a standard Flat index (exact search).
    """
    d = vectors.shape[1]

    if use_cosine:
        faiss.normalize_L2(vectors)  # normalize for cosine similarity
        metric = faiss.METRIC_INNER_PRODUCT
    else:
        metric = faiss.METRIC_L2

    index = faiss.IndexFlat(d, metric)
    index.add(vectors)
    return index



def evaluate_accuracy(index, labels, data_folder, k):
    """
    Evaluate top-k accuracy using probe templates ('_3').
    Returns: accuracy (float), total_queries (int)
    """
    query_files = sorted([f for f in os.listdir(data_folder) if '_3' in f and f.endswith('.mat')])
    correct, total = 0, 0

    for qf in query_files:
        qpid = qf.split('_')[0]
        qpath = os.path.join(data_folder, qf)

        qvec = loadmat(qpath)['template'].flatten().astype('float32').reshape(1, -1)
        distances, indices = index.search(qvec, k)

        predicted_pids = [labels[i].split('_')[0] for i in indices[0]]
        if qpid in predicted_pids:
            correct += 1
        total += 1

    return (correct / total if total > 0 else 0)


def plot_hit_rate_plotly(data_folder, k_max=50):
    vectors, labels = load_templates(data_folder)
    index = create_flat_index(vectors)

    k_values = list(range(1, k_max + 1))
    hit_rates = []

    for k in k_values:
        acc = evaluate_accuracy(index, labels, data_folder, k)
        hit_rates.append(acc * 100)  # in %

    # ---- Save results to CSV ----
    with open("/home/nishkal/sg/iris_indexing/results/CASIA1_FLAT_results.csv", mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["k", "hit_rate_percent"])
        for k_val, hr in zip(k_values, hit_rates):
            writer.writerow([k_val, hr])
    # print(f"Results saved to {csv_path}")

    

    # ---- Plot interactive graph ----
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=k_values, y=hit_rates,
        mode='lines+markers',
        name='Hit Rate (%)',
        line=dict(color='blue', width=2)
    ))
    fig.update_layout(
        title="Hit Rate vs k",
        xaxis_title="k (Top-k)",
        yaxis_title="Hit Rate (%)",
        template="plotly_white"
    )
    fig.show()
    return k_values, hit_rates

In [51]:
plot_hit_rate_plotly(data_folder, k_max=108)

([1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108],
 [91.66666666666666,
  93.51851851851852,
  94.44444444444444,
  95.37037037037037,
  95.37037037037037,
  95.37037037037037,
  96.29629629629629,
  96.29629629629629,
  96.29629629629629,
  96.29629629629629,
  96.29629629629629,
  96.29629629629629,
  96.29629629629629,
  97.22222222222221,
  97.22222222222221,
  97.22222222222221,
  97.2222222222