In [1]:
import os
import numpy as np
import scipy.io
import faiss
import plotly.graph_objects as go


In [7]:
import json

# Quick look at the raw feature file: load it and print the full label list.
# NOTE(review): the printed labels include values like 10013 alongside 13, which
# looks like a second identity per subject (e.g. left vs. right hand, cf. the
# filename) — confirm with the feature-extraction script.
with open('iitd_features_1_with seprate_left_abd_wrute.json', 'r') as file:
    data = json.load(file)  # parsed JSON; later cells read its "X" and "y" keys

# Feature dimensionality can be inspected via len(data["X"][0]).
print(data["y"])



[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 10013, 10013, 13, 10013, 13, 10013, 13, 13, 10013, 14, 10014, 14, 10014, 10014, 10014, 14, 14, 14, 10014, 15, 15, 10015, 10015, 15, 15, 10015, 10015, 10015, 15, 10016, 16, 16, 16, 10016, 10016, 16, 16, 10016, 10016, 10017, 10017, 10017, 10017, 17, 17, 17, 17, 10017, 17, 18, 18, 10018, 18, 10018, 10018, 18, 10018, 18, 10018, 19, 10019, 10019, 10019, 19, 19, 19, 10019, 10019, 19, 10020, 20, 10020, 20, 10020, 20, 20, 10020, 10020, 20, 21, 21, 10021, 21, 10021, 21, 10021, 21, 10021, 10021, 10022, 22, 22, 10022, 22, 10022, 22, 22, 10022, 10022, 10023, 10023, 23, 23, 10023, 

In [4]:
import json
import numpy as np
import pandas as pd
import faiss
import plotly.graph_objects as go

# ---------------------
# Load JSON data
# ---------------------
def load_data(json_path):
    """Read features and labels from a JSON file.

    The file is expected to hold an object with an ``"X"`` entry (feature
    vectors) and a ``"y"`` entry (labels aligned with ``"X"``).

    Returns:
        ``(x, y)``: ``x`` is ``data["X"]`` as a float32 NumPy array and ``y``
        is ``data["y"]`` as a NumPy array.
    """
    with open(json_path, 'r') as handle:
        payload = json.load(handle)
    return np.array(payload["X"], dtype='float32'), np.array(payload["y"])

# ---------------------
# Split representatives & queries
# ---------------------
def split_representatives(x, y):
    """Split samples into one gallery representative per label plus queries.

    Labels are visited in ascending order (``np.unique``). For each label the
    first sample carrying it becomes the indexed representative; every later
    sample with that label becomes a query. Queries therefore come out grouped
    by label, in original order within each label.

    Args:
        x: Feature array with one sample per row.
        y: Labels aligned with the rows of ``x``.

    Returns:
        ``(index_vectors, index_labels, query_vectors, query_labels)``. The
        vector arrays are float32 and keep the per-sample shape of ``x`` even
        when empty (e.g. ``(0, dim)`` when every label has a single sample),
        which is the 2-D layout faiss ``add``/``search`` expect.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    row_shape = x.shape[1:]

    index_vectors = []
    index_labels = []
    query_vectors = []
    query_labels = []

    # Every label returned by np.unique(y) occurs at least once in y, so
    # `positions` is never empty.
    for label in np.unique(y):
        positions = np.flatnonzero(y == label)
        index_vectors.append(x[positions[0]])
        index_labels.append(label)
        rest = positions[1:]
        query_vectors.extend(x[rest])
        query_labels.extend(y[rest])

    def _as_float32(rows):
        # Explicit leading length keeps e.g. (0, dim) for an empty list, where
        # np.array([]) alone would collapse to the 1-D shape (0,).
        return np.array(rows, dtype='float32').reshape((len(rows),) + row_shape)

    return (_as_float32(index_vectors),
            np.array(index_labels, dtype=y.dtype),
            _as_float32(query_vectors),
            np.array(query_labels, dtype=y.dtype))

# ---------------------
# Build HNSW Index
# ---------------------
def build_hnsw_index(vectors, dim, M=16, efConstruction=50):
    """Create a faiss ``IndexHNSWFlat`` over ``vectors`` and return it.

    Args:
        vectors: Vectors to insert (passed straight to ``index.add``).
        dim: Vector dimensionality.
        M: HNSW graph connectivity parameter.
        efConstruction: Candidate-list size used while inserting vectors.
    """
    hnsw_index = faiss.IndexHNSWFlat(dim, M)
    # Configured before add() so it applies while the graph is being built.
    hnsw_index.hnsw.efConstruction = efConstruction
    hnsw_index.add(vectors)
    return hnsw_index

# ---------------------
# Evaluate Hit Rate for varying efSearch and save CSV
# ---------------------
def evaluate_hit_rate(index, index_labels, query_vectors, query_labels, ef_max=100, csv_path=None):
    """Sweep HNSW ``efSearch`` from 1 to ``ef_max`` and record the top-1 hit rate.

    For each efSearch value every query is searched with ``k=1``. A query is a
    hit when the label of the returned gallery vector equals its own label.
    Faiss reports "no neighbour found" as id ``-1``. Such queries count as
    misses; they are not mapped (via NumPy negative indexing) onto the last
    entry of ``index_labels``.

    Args:
        index: HNSW index exposing ``hnsw.efSearch`` and ``search(x, k)``.
        index_labels: Labels aligned with the vectors stored in ``index``.
        query_vectors: Query matrix passed to ``index.search``.
        query_labels: Ground-truth label for each query.
        ef_max: Largest efSearch value evaluated (inclusive).
        csv_path: Optional output path. When given, one row per efSearch
            (columns ``efSearch``, ``penetration``, ``hit_rate``) is written;
            missing parent directories are created.

    Returns:
        ``(ef_values, hit_rates)``: the ``range`` of efSearch values and the
        list of hit rates in the same order.
    """
    import os

    index_labels = np.asarray(index_labels)
    query_labels = np.asarray(query_labels)
    hit_rates = []
    csv_data = []
    total_indexed = len(index_labels)
    total = len(query_labels)
    ef_values = range(1, ef_max + 1)

    for ef in ef_values:
        index.hnsw.efSearch = ef
        _, indices = index.search(query_vectors, k=1)

        # One neighbour id per query; -1 means the search returned nothing.
        nearest = np.asarray(indices).reshape(-1)
        found = nearest >= 0
        hits = int(np.count_nonzero(index_labels[nearest[found]] == query_labels[found]))
        hit_rate = hits / total if total > 0 else 0
        hit_rates.append(hit_rate)

        # efSearch relative to gallery size. This is only a proxy: efSearch sizes
        # the candidate list, not the exact number of vectors compared.
        penetration = ef / total_indexed if total_indexed else float("nan")
        csv_data.append({"efSearch": ef, "penetration": penetration, "hit_rate": hit_rate})

        print(f"efSearch={ef}, Hit rate={hit_rate:.4f}")

    if csv_path:
        parent_dir = os.path.dirname(csv_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        pd.DataFrame(csv_data).to_csv(csv_path, index=False)
        print(f"CSV saved to: {csv_path}")

    return ef_values, hit_rates

# ---------------------
# Plot Hit Rate vs efSearch (using Plotly)
# ---------------------
def plot_hit_rate_plotly(ef_values, hit_rates):
    """Show an interactive Plotly line chart of hit rate against efSearch.

    Args:
        ef_values: Iterable of efSearch values (x axis).
        hit_rates: Hit rate for each efSearch value (y axis).
    """
    hit_rate_trace = go.Scatter(
        x=list(ef_values),
        y=hit_rates,
        mode='lines+markers',
        name="Hit Rate",
        line=dict(width=2),
        marker=dict(size=6),
    )
    figure = go.Figure(data=[hit_rate_trace])
    figure.update_layout(
        title="IITDV1_HNSW",
        xaxis_title="efSearch (with Penetration in CSV)",
        yaxis_title="Hit Rate",
        template="plotly_white",
        hovermode="x unified",
    )
    figure.show()

# ---------------------
# Main Workflow
# ---------------------
def main(json_path, ef_max=50, csv_output="hnsw_hit_rate.csv"):
    """Run the HNSW efSearch sweep end to end.

    Steps: load features, use one representative per label as the gallery and
    the remaining samples as queries, build the HNSW index, evaluate the top-1
    hit rate for efSearch = 1..ef_max (saving it to ``csv_output``), and plot it.
    """
    features, labels = load_data(json_path)
    gallery_x, gallery_y, probe_x, probe_y = split_representatives(features, labels)

    hnsw_index = build_hnsw_index(gallery_x, gallery_x.shape[1])

    ef_values, hit_rates = evaluate_hit_rate(
        hnsw_index, gallery_y, probe_x, probe_y,
        ef_max=ef_max, csv_path=csv_output,
    )
    plot_hit_rate_plotly(ef_values, hit_rates)


# Example usage
if __name__ == "__main__":
    main(
        "iitd_features_1_with seprate_left_abd_wrute.json",
        ef_max=100,
        csv_output="results/IITDV1_HNSW.csv",
    )


efSearch=1, Hit rate=0.6964
efSearch=2, Hit rate=0.8078
efSearch=3, Hit rate=0.8609
efSearch=4, Hit rate=0.8925
efSearch=5, Hit rate=0.9030
efSearch=6, Hit rate=0.9158
efSearch=7, Hit rate=0.9224
efSearch=8, Hit rate=0.9263
efSearch=9, Hit rate=0.9285
efSearch=10, Hit rate=0.9307
efSearch=11, Hit rate=0.9313
efSearch=12, Hit rate=0.9335
efSearch=13, Hit rate=0.9352
efSearch=14, Hit rate=0.9357
efSearch=15, Hit rate=0.9374
efSearch=16, Hit rate=0.9380
efSearch=17, Hit rate=0.9396
efSearch=18, Hit rate=0.9396
efSearch=19, Hit rate=0.9396
efSearch=20, Hit rate=0.9402
efSearch=21, Hit rate=0.9402
efSearch=22, Hit rate=0.9402
efSearch=23, Hit rate=0.9407
efSearch=24, Hit rate=0.9407
efSearch=25, Hit rate=0.9418
efSearch=26, Hit rate=0.9424
efSearch=27, Hit rate=0.9424
efSearch=28, Hit rate=0.9424
efSearch=29, Hit rate=0.9424
efSearch=30, Hit rate=0.9429
efSearch=31, Hit rate=0.9429
efSearch=32, Hit rate=0.9429
efSearch=33, Hit rate=0.9429
efSearch=34, Hit rate=0.9429
efSearch=35, Hit rate=0

In [None]:
import json
import numpy as np
import faiss
import plotly.graph_objects as go

# ---------------------
# Load JSON data
# ---------------------
def load_data(json_path):
    """Read features and labels from a JSON file.

    The file is expected to hold an object with an ``"X"`` entry (feature
    vectors) and a ``"y"`` entry (labels aligned with ``"X"``).

    Returns:
        ``(x, y)``: ``x`` is ``data["X"]`` as a float32 NumPy array and ``y``
        is ``data["y"]`` as a NumPy array.
    """
    with open(json_path, 'r') as handle:
        payload = json.load(handle)
    return np.array(payload["X"], dtype='float32'), np.array(payload["y"])

# ---------------------
# Split representatives & queries
# ---------------------
def split_representatives(x, y):
    """Split samples into one gallery representative per label plus queries.

    Labels are visited in ascending order (``np.unique``). For each label the
    first sample carrying it becomes the indexed representative; every later
    sample with that label becomes a query. Queries therefore come out grouped
    by label, in original order within each label.

    Args:
        x: Feature array with one sample per row.
        y: Labels aligned with the rows of ``x``.

    Returns:
        ``(index_vectors, index_labels, query_vectors, query_labels)``. The
        vector arrays are float32 and keep the per-sample shape of ``x`` even
        when empty (e.g. ``(0, dim)`` when every label has a single sample),
        which is the 2-D layout faiss ``add``/``search`` expect.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    row_shape = x.shape[1:]

    index_vectors = []
    index_labels = []
    query_vectors = []
    query_labels = []

    # Every label returned by np.unique(y) occurs at least once in y, so
    # `positions` is never empty.
    for label in np.unique(y):
        positions = np.flatnonzero(y == label)
        index_vectors.append(x[positions[0]])
        index_labels.append(label)
        rest = positions[1:]
        query_vectors.extend(x[rest])
        query_labels.extend(y[rest])

    def _as_float32(rows):
        # Explicit leading length keeps e.g. (0, dim) for an empty list, where
        # np.array([]) alone would collapse to the 1-D shape (0,).
        return np.array(rows, dtype='float32').reshape((len(rows),) + row_shape)

    return (_as_float32(index_vectors),
            np.array(index_labels, dtype=y.dtype),
            _as_float32(query_vectors),
            np.array(query_labels, dtype=y.dtype))

# ---------------------
# Build HNSW Index
# ---------------------
def build_hnsw_index(vectors, dim, M=16, efConstruction=200):
    """Create a faiss ``IndexHNSWFlat`` over ``vectors`` and return it.

    Args:
        vectors: Vectors to insert (passed straight to ``index.add``).
        dim: Vector dimensionality.
        M: HNSW graph connectivity parameter.
        efConstruction: Candidate-list size used while inserting vectors.
    """
    hnsw_index = faiss.IndexHNSWFlat(dim, M)
    # Configured before add() so it applies while the graph is being built.
    hnsw_index.hnsw.efConstruction = efConstruction
    hnsw_index.add(vectors)
    return hnsw_index

# ---------------------
# Evaluate Hit Rate and Label X-Axis with Penetration
# ---------------------
def evaluate_hit_rate(index, index_labels, query_vectors, query_labels, ef_max=100,
                      highlight_threshold=0.99):
    """Sweep HNSW ``efSearch`` from 1 to ``ef_max`` and record top-1 hit rate.

    For each efSearch value every query is searched with ``k=1``. A query is a
    hit when the label of the returned gallery vector equals its own label.
    Faiss reports "no neighbour found" as id ``-1``. Such queries count as
    misses; they are not mapped (via NumPy negative indexing) onto the last
    entry of ``index_labels``.

    Args:
        index: HNSW index exposing ``hnsw.efSearch`` and ``search(x, k)``.
        index_labels: Labels aligned with the vectors stored in ``index``.
        query_vectors: Query matrix passed to ``index.search``.
        query_labels: Ground-truth label for each query.
        ef_max: Largest efSearch value evaluated (inclusive).
        highlight_threshold: Hit rate at which the first efSearch value is
            reported as ``highlight_point``. The default of 0.99 matches the
            "First ≥99% Hit" marker drawn by the plotting helper.

    Returns:
        ``(x_labels, hit_rates, highlight_point)``, where:
        - ``x_labels`` holds strings ``"<ef> (<penetration:.2f>)"``.
        - ``hit_rates`` holds the matching hit rates.
        - ``highlight_point`` is ``(ef, hit_rate, penetration)`` for the first
          efSearch reaching ``highlight_threshold``, or ``None``.
    """
    index_labels = np.asarray(index_labels)
    query_labels = np.asarray(query_labels)
    hit_rates = []
    x_labels = []
    total_indexed = len(index_labels)
    total = len(query_labels)
    highlight_point = None

    for ef in range(1, ef_max + 1):
        index.hnsw.efSearch = ef
        _, indices = index.search(query_vectors, k=1)

        # One neighbour id per query; -1 means the search returned nothing.
        nearest = np.asarray(indices).reshape(-1)
        found = nearest >= 0
        hits = int(np.count_nonzero(index_labels[nearest[found]] == query_labels[found]))
        hit_rate = hits / total if total > 0 else 0
        hit_rates.append(hit_rate)

        # efSearch relative to gallery size. This is only a proxy: efSearch sizes
        # the candidate list, not the exact number of vectors compared.
        penetration = ef / total_indexed if total_indexed else float("nan")
        x_labels.append(f"{ef} ({penetration:.2f})")

        if highlight_point is None and hit_rate >= highlight_threshold:
            highlight_point = (ef, hit_rate, penetration)

        print(f"efSearch={ef}, Hit rate={hit_rate:.4f}")

    return x_labels, hit_rates, highlight_point

# ---------------------
# Plot Hit Rate with Penetration Labels
# ---------------------
def plot_hit_rate_plotly(x_labels, hit_rates, highlight_point):
    """Show hit rate against "efSearch (penetration)" labels as a line chart.

    Args:
        x_labels: Category labels for the x axis, one per efSearch value.
        hit_rates: Hit rate for each label.
        highlight_point: Optional ``(ef, hit_rate, penetration)``. When truthy,
            it is drawn as a red star annotated with its penetration.
    """
    traces = [
        go.Scatter(
            x=x_labels,
            y=hit_rates,
            mode='lines+markers',
            name="Hit Rate",
            line=dict(width=2),
            marker=dict(size=6),
        )
    ]

    # Highlight first ≥99% hit rate
    if highlight_point:
        star_ef, star_rate, star_penetration = highlight_point
        traces.append(
            go.Scatter(
                x=[f"{star_ef} ({star_penetration:.2f})"],
                y=[star_rate],
                mode='markers+text',
                marker=dict(color='red', size=12, symbol='star'),
                text=[f"Penetration={star_penetration:.2f}"],
                textposition="top center",
                name="First ≥99% Hit",
            )
        )

    figure = go.Figure()
    for trace in traces:
        figure.add_trace(trace)
    figure.update_layout(
        title="IITDV1_HNSW",
        xaxis_title="efSearch (Penetration)",
        yaxis_title="Hit Rate",
        template="plotly_white",
        hovermode="x unified",
    )
    figure.show()

# ---------------------
# Main Workflow
# ---------------------
def main(json_path, ef_max=50):
    """Evaluate an HNSW gallery for efSearch = 1..ef_max and plot the result.

    One representative per label is indexed and all remaining samples are
    used as queries. This variant does not write a CSV.
    """
    features, labels = load_data(json_path)
    gallery_x, gallery_y, probe_x, probe_y = split_representatives(features, labels)
    hnsw_index = build_hnsw_index(gallery_x, gallery_x.shape[1])

    axis_labels, hit_rates, best_point = evaluate_hit_rate(
        hnsw_index, gallery_y, probe_x, probe_y, ef_max=ef_max
    )
    plot_hit_rate_plotly(axis_labels, hit_rates, best_point)


# Example usage
if __name__ == "__main__":
    main(json_path="iitd_features_1_with seprate_left_abd_wrute.json", ef_max=50)


efSearch=1, Hit rate=0.7302
efSearch=2, Hit rate=0.8360
efSearch=3, Hit rate=0.8787
efSearch=4, Hit rate=0.9047
efSearch=5, Hit rate=0.9136
efSearch=6, Hit rate=0.9224
efSearch=7, Hit rate=0.9269
efSearch=8, Hit rate=0.9291
efSearch=9, Hit rate=0.9319
efSearch=10, Hit rate=0.9341
efSearch=11, Hit rate=0.9363
efSearch=12, Hit rate=0.9380
efSearch=13, Hit rate=0.9396
efSearch=14, Hit rate=0.9396
efSearch=15, Hit rate=0.9396
efSearch=16, Hit rate=0.9396
efSearch=17, Hit rate=0.9396
efSearch=18, Hit rate=0.9407
efSearch=19, Hit rate=0.9413
efSearch=20, Hit rate=0.9413
efSearch=21, Hit rate=0.9413
efSearch=22, Hit rate=0.9413
efSearch=23, Hit rate=0.9413
efSearch=24, Hit rate=0.9413
efSearch=25, Hit rate=0.9424
efSearch=26, Hit rate=0.9429
efSearch=27, Hit rate=0.9429
efSearch=28, Hit rate=0.9429
efSearch=29, Hit rate=0.9429
efSearch=30, Hit rate=0.9429
efSearch=31, Hit rate=0.9429
efSearch=32, Hit rate=0.9429
efSearch=33, Hit rate=0.9429
efSearch=34, Hit rate=0.9429
efSearch=35, Hit rate=0

In [16]:
def get_incorrect_indices(index, index_labels, query_vectors, query_labels, ef_search=50):
    """
    Find the queries whose top-1 neighbour carries the wrong label.

    Sets ``index.hnsw.efSearch = ef_search``, runs a k=1 search and compares
    the neighbour's label with each query's true label.

    Returns:
        ``(positions, predicted, actual)``: positions of the mismatching
        queries, the labels that were predicted for them, and their true labels.

    NOTE(review): a faiss "no result" id of -1 would index the last entry of
    ``index_labels`` here rather than being flagged separately.
    """
    index.hnsw.efSearch = ef_search
    _, neighbour_ids = index.search(query_vectors, k=1)

    predicted = index_labels[neighbour_ids.flatten()]
    mismatch = np.flatnonzero(predicted != query_labels)
    return mismatch, predicted[mismatch], query_labels[mismatch]


In [19]:
def main_check_incorrect(json_path, ef_search=50):
    """Build the HNSW gallery and print details of every misclassified query.

    Prints, in order: the number of indexed representatives, the number of
    queries, the number of incorrect matches, and then their positions,
    predicted labels and true labels.
    """
    features, labels = load_data(json_path)
    gallery_x, gallery_y, probe_x, probe_y = split_representatives(features, labels)
    hnsw_index = build_hnsw_index(gallery_x, gallery_x.shape[1])

    wrong_positions, wrong_predicted, wrong_actual = get_incorrect_indices(
        hnsw_index, gallery_y, probe_x, probe_y, ef_search=ef_search
    )

    print(len(gallery_x))
    print(f"Total queries: {len(probe_y)}")
    print(f"Incorrect matches: {len(wrong_positions)}")
    print("Indices of incorrect queries:", wrong_positions)
    print("Predicted labels:", wrong_predicted)
    print("Actual labels:", wrong_actual)


# Example usage
if __name__ == "__main__":
    main_check_incorrect(json_path="iitd_features_1_with seprate_left_abd_wrute.json", ef_search=50)


435
Total queries: 1805
Incorrect matches: 103
Indices of incorrect queries: [   2   13   16   27   33  101  103  105  111  137  144  161  165  166
  167  168  169  170  173  233  290  340  341  342  343  392  394  412
  413  414  415  428  429  430  431  444  493  497  544  556  614  836
  837  838  839  869  870  871  967 1004 1005 1007 1014 1035 1036 1056
 1059 1060 1061 1082 1113 1116 1156 1179 1189 1190 1194 1282 1318 1319
 1321 1322 1323 1326 1365 1371 1372 1373 1374 1375 1406 1409 1431 1449
 1451 1452 1485 1486 1487 1560 1590 1592 1613 1614 1615 1616 1630 1632
 1690 1693 1695 1696 1745]
Predicted labels: [  198   189 10054    69   197 10192 10192 10194    77 10018 10019 10081
   105   105   170   105 10054 10211   223 10189 10211    36    36    36
    36    70 10081    77    98   120   220   159 10090 10046   212    86
   187 10072 10064   157 10019    39   190   153   153   205   205   205
 10100    23 10091 10032   159   120 10098 10083    37    48 10221    70
    50 10217    

In [None]:
### FLAT index — exact (brute-force) L2 search, top-k hit rate

In [2]:
import json
import numpy as np
import pandas as pd
import faiss
import plotly.graph_objects as go

# ---------------------
# Load JSON data
# ---------------------
def load_data(json_path):
    """Read features and labels from a JSON file.

    The file is expected to hold an object with an ``"X"`` entry (feature
    vectors) and a ``"y"`` entry (labels aligned with ``"X"``).

    Returns:
        ``(x, y)``: ``x`` is ``data["X"]`` as a float32 NumPy array and ``y``
        is ``data["y"]`` as a NumPy array.
    """
    with open(json_path, 'r') as handle:
        payload = json.load(handle)
    return np.array(payload["X"], dtype='float32'), np.array(payload["y"])

# ---------------------
# Split representatives & queries
# ---------------------
def split_representatives(x, y):
    """Split samples into one gallery representative per label plus queries.

    Labels are visited in ascending order (``np.unique``). For each label the
    first sample carrying it becomes the indexed representative; every later
    sample with that label becomes a query. Queries therefore come out grouped
    by label, in original order within each label.

    Args:
        x: Feature array with one sample per row.
        y: Labels aligned with the rows of ``x``.

    Returns:
        ``(index_vectors, index_labels, query_vectors, query_labels)``. The
        vector arrays are float32 and keep the per-sample shape of ``x`` even
        when empty (e.g. ``(0, dim)`` when every label has a single sample),
        which is the 2-D layout faiss ``add``/``search`` expect.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    row_shape = x.shape[1:]

    index_vectors = []
    index_labels = []
    query_vectors = []
    query_labels = []

    # Every label returned by np.unique(y) occurs at least once in y, so
    # `positions` is never empty.
    for label in np.unique(y):
        positions = np.flatnonzero(y == label)
        index_vectors.append(x[positions[0]])
        index_labels.append(label)
        rest = positions[1:]
        query_vectors.extend(x[rest])
        query_labels.extend(y[rest])

    def _as_float32(rows):
        # Explicit leading length keeps e.g. (0, dim) for an empty list, where
        # np.array([]) alone would collapse to the 1-D shape (0,).
        return np.array(rows, dtype='float32').reshape((len(rows),) + row_shape)

    return (_as_float32(index_vectors),
            np.array(index_labels, dtype=y.dtype),
            _as_float32(query_vectors),
            np.array(query_labels, dtype=y.dtype))

# ---------------------
# Build FLAT Index (Euclidean)
# ---------------------
def build_flat_index(vectors, dim):
    """Create an exact (exhaustive) Euclidean faiss index holding ``vectors``.

    Args:
        vectors: Vectors to insert (passed straight to ``index.add``).
        dim: Vector dimensionality.
    """
    flat_index = faiss.IndexFlatL2(dim)  # brute-force L2 search, no training needed
    flat_index.add(vectors)
    return flat_index

# ---------------------
# Evaluate Hit Rate vs top-k and Save to CSV
# ---------------------
def evaluate_hit_rate_topk(index, index_labels, query_vectors, query_labels, k_max=50, csv_path=None):
    """Measure top-k hit rate for k = 1..k_max.

    A query is a hit at a given k when its own label appears among the labels
    of its k returned neighbours. Faiss pads result rows with id ``-1`` when
    fewer than k neighbours exist (e.g. k larger than the number of indexed
    vectors). Those slots are ignored instead of being aliased, through NumPy
    negative indexing, to the last entry of ``index_labels``.

    Args:
        index: Faiss index exposing ``search(x, k)``.
        index_labels: Labels aligned with the vectors stored in ``index``.
        query_vectors: Query matrix passed to ``index.search``.
        query_labels: 1-D array of ground-truth labels, one per query.
        k_max: Largest k evaluated (inclusive).
        csv_path: Optional output path. When given, one row per k (columns
            ``k``, ``penetration``, ``hit_rate``) is written; missing parent
            directories are created.

    Returns:
        ``(x_labels, hit_rates, highlight_point)``, where:
        - ``x_labels`` holds strings ``"<k> (<penetration:.2f>)"``.
        - ``hit_rates`` holds the matching hit rates.
        - ``highlight_point`` is ``(k, hit_rate, penetration)`` for the first k
          whose hit rate is ≥ 0.99, or ``None``.
    """
    import os

    index_labels = np.asarray(index_labels)
    query_labels = np.asarray(query_labels)
    hit_rates = []
    x_labels = []
    csv_data = []
    total_indexed = len(index_labels)
    total = len(query_labels)
    highlight_point = None

    for k in range(1, k_max + 1):
        _, indices = index.search(query_vectors, k=k)
        indices = np.asarray(indices)

        # Mask out -1 padding before comparing labels.
        valid = indices >= 0
        if valid.any():
            candidate_labels = index_labels[np.where(valid, indices, 0)]
            matched = (valid & (candidate_labels == query_labels[:, None])).any(axis=1)
            hits = int(np.count_nonzero(matched))
        else:
            hits = 0
        hit_rate = hits / total if total > 0 else 0
        hit_rates.append(hit_rate)

        penetration = k / total_indexed if total_indexed else float("nan")
        x_labels.append(f"{k} ({penetration:.2f})")
        csv_data.append({"k": k, "penetration": penetration, "hit_rate": hit_rate})

        if highlight_point is None and hit_rate >= 0.99:
            highlight_point = (k, hit_rate, penetration)

        print(f"top-k={k}, Hit rate={hit_rate:.4f}")

    if csv_path:
        parent_dir = os.path.dirname(csv_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        pd.DataFrame(csv_data).to_csv(csv_path, index=False)
        print(f"CSV saved to: {csv_path}")

    return x_labels, hit_rates, highlight_point

# ---------------------
# Plot Hit Rate vs top-k
# ---------------------
def plot_hit_rate_plotly(x_labels, hit_rates, highlight_point):
    """Show top-k hit rate against "k (penetration)" labels as a line chart.

    Args:
        x_labels: Category labels for the x axis, one per k.
        hit_rates: Hit rate for each label.
        highlight_point: Optional ``(k, hit_rate, penetration)``. When truthy,
            it is drawn as a red star annotated with its penetration.
    """
    traces = [
        go.Scatter(
            x=x_labels,
            y=hit_rates,
            mode='lines+markers',
            name="Hit Rate",
            line=dict(width=2),
            marker=dict(size=6),
        )
    ]

    if highlight_point:
        star_k, star_rate, star_penetration = highlight_point
        traces.append(
            go.Scatter(
                x=[f"{star_k} ({star_penetration:.2f})"],
                y=[star_rate],
                mode='markers+text',
                marker=dict(color='red', size=12, symbol='star'),
                text=[f"Penetration={star_penetration:.2f}"],
                textposition="top center",
                name="First ≥99% Hit",
            )
        )

    figure = go.Figure()
    for trace in traces:
        figure.add_trace(trace)
    figure.update_layout(
        title="IITDV1_FLAT (Top-k Hit Rate)",
        xaxis_title="Top-k (Penetration)",
        yaxis_title="Hit Rate",
        template="plotly_white",
        hovermode="x unified",
    )
    figure.show()

# ---------------------
# Main Workflow
# ---------------------
def main(json_path, k_max=50, csv_output="hit_rate_topk.csv"):
    """Evaluate exact (flat L2) top-k hit rate for k = 1..k_max and plot it.

    One representative per label is indexed and all remaining samples are
    queries. The per-k results are written to ``csv_output``.
    """
    features, labels = load_data(json_path)
    gallery_x, gallery_y, probe_x, probe_y = split_representatives(features, labels)
    flat_index = build_flat_index(gallery_x, gallery_x.shape[1])

    axis_labels, hit_rates, best_point = evaluate_hit_rate_topk(
        flat_index, gallery_y, probe_x, probe_y,
        k_max=k_max, csv_path=csv_output,
    )
    plot_hit_rate_plotly(axis_labels, hit_rates, best_point)


# Example usage
if __name__ == "__main__":
    main(
        "iitd_features_1_with seprate_left_abd_wrute.json",
        k_max=100,
        csv_output="results/IITDV1_FLAT.csv",
    )



top-k=1, Hit rate=0.9429
top-k=2, Hit rate=0.9529
top-k=3, Hit rate=0.9584
top-k=4, Hit rate=0.9629
top-k=5, Hit rate=0.9645
top-k=6, Hit rate=0.9662
top-k=7, Hit rate=0.9668
top-k=8, Hit rate=0.9684
top-k=9, Hit rate=0.9684
top-k=10, Hit rate=0.9684
top-k=11, Hit rate=0.9695
top-k=12, Hit rate=0.9701
top-k=13, Hit rate=0.9706
top-k=14, Hit rate=0.9706
top-k=15, Hit rate=0.9712
top-k=16, Hit rate=0.9712
top-k=17, Hit rate=0.9712
top-k=18, Hit rate=0.9717
top-k=19, Hit rate=0.9723
top-k=20, Hit rate=0.9734
top-k=21, Hit rate=0.9740
top-k=22, Hit rate=0.9751
top-k=23, Hit rate=0.9751
top-k=24, Hit rate=0.9751
top-k=25, Hit rate=0.9751
top-k=26, Hit rate=0.9756
top-k=27, Hit rate=0.9756
top-k=28, Hit rate=0.9756
top-k=29, Hit rate=0.9767
top-k=30, Hit rate=0.9767
top-k=31, Hit rate=0.9767
top-k=32, Hit rate=0.9767
top-k=33, Hit rate=0.9778
top-k=34, Hit rate=0.9784
top-k=35, Hit rate=0.9789
top-k=36, Hit rate=0.9789
top-k=37, Hit rate=0.9789
top-k=38, Hit rate=0.9801
top-k=39, Hit rate=0.

In [None]:
### IVF index — IVFFlat top-1 hit rate over clusters searched (nprobe) × total clusters (nlist)

In [14]:
import json
import numpy as np
import pandas as pd
import faiss
import plotly.graph_objects as go

# ---------------------
# Load JSON data
# ---------------------
def load_data(json_path):
    """Read features and labels from a JSON file.

    The file is expected to hold an object with an ``"X"`` entry (feature
    vectors) and a ``"y"`` entry (labels aligned with ``"X"``).

    Returns:
        ``(x, y)``: ``x`` is ``data["X"]`` as a float32 NumPy array and ``y``
        is ``data["y"]`` as a NumPy array.
    """
    with open(json_path, 'r') as handle:
        payload = json.load(handle)
    return np.array(payload["X"], dtype='float32'), np.array(payload["y"])

# ---------------------
# Split representatives & queries
# ---------------------
def split_representatives(x, y):
    """Split samples into one gallery representative per label plus queries.

    Labels are visited in ascending order (``np.unique``). For each label the
    first sample carrying it becomes the indexed representative; every later
    sample with that label becomes a query. Queries therefore come out grouped
    by label, in original order within each label.

    Args:
        x: Feature array with one sample per row.
        y: Labels aligned with the rows of ``x``.

    Returns:
        ``(index_vectors, index_labels, query_vectors, query_labels)``. The
        vector arrays are float32 and keep the per-sample shape of ``x`` even
        when empty (e.g. ``(0, dim)`` when every label has a single sample),
        which is the 2-D layout faiss ``add``/``search`` expect.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    row_shape = x.shape[1:]

    index_vectors = []
    index_labels = []
    query_vectors = []
    query_labels = []

    # Every label returned by np.unique(y) occurs at least once in y, so
    # `positions` is never empty.
    for label in np.unique(y):
        positions = np.flatnonzero(y == label)
        index_vectors.append(x[positions[0]])
        index_labels.append(label)
        rest = positions[1:]
        query_vectors.extend(x[rest])
        query_labels.extend(y[rest])

    def _as_float32(rows):
        # Explicit leading length keeps e.g. (0, dim) for an empty list, where
        # np.array([]) alone would collapse to the 1-D shape (0,).
        return np.array(rows, dtype='float32').reshape((len(rows),) + row_shape)

    return (_as_float32(index_vectors),
            np.array(index_labels, dtype=y.dtype),
            _as_float32(query_vectors),
            np.array(query_labels, dtype=y.dtype))

# ---------------------
# Build IVF Index (Euclidean)
# ---------------------
def build_ivf_index(vectors, dim, nlist):
    """Train and fill a faiss ``IndexIVFFlat`` (L2 metric) over ``vectors``.

    Args:
        vectors: Vectors used both to train the coarse quantizer and as the
            stored gallery.
        dim: Vector dimensionality.
        nlist: Number of inverted lists (clusters).
    """
    coarse_quantizer = faiss.IndexFlatL2(dim)
    ivf_index = faiss.IndexIVFFlat(coarse_quantizer, dim, nlist, faiss.METRIC_L2)
    ivf_index.train(vectors)  # learns the nlist cluster centroids
    ivf_index.add(vectors)
    return ivf_index

# ---------------------
# Evaluate Hit Rate (k=1) vs (nprobe, nlist)
# ---------------------
def evaluate_hit_rate_clusters(index_vectors, index_labels, query_vectors, query_labels,
                               nlist_values, csv_path=None, index_builder=None):
    """Measure top-1 IVF hit rate for every (nlist, nprobe) with nprobe <= nlist.

    For each ``nlist`` a fresh index is built over ``index_vectors``. Every
    ``nprobe`` in ``1..nlist`` is then evaluated. A query is a hit when the
    label of its top-1 neighbour equals its own label. When the probed lists
    contain no vectors, faiss returns id ``-1``. Those queries count as misses
    instead of being aliased, via negative indexing, to the last label.

    Args:
        index_vectors: Gallery vectors (2-D; ``shape[1]`` is the dimension).
        index_labels: Labels aligned with ``index_vectors``.
        query_vectors: Query matrix passed to ``index.search``.
        query_labels: 1-D array of ground-truth labels, one per query.
        nlist_values: Iterable of cluster counts to evaluate.
        csv_path: Optional output path for the results table. Missing parent
            directories are created.
        index_builder: Callable ``(vectors, dim, nlist=...) -> index`` used to
            build each index. It defaults to ``build_ivf_index``; it can be
            replaced to sweep other IVF variants.

    Returns:
        DataFrame with columns ``nlist``, ``nprobe``, ``penetration``
        (``nprobe / nlist``) and ``hit_rate``, one row per evaluated pair.
    """
    import os

    if index_builder is None:
        index_builder = build_ivf_index

    index_labels = np.asarray(index_labels)
    query_labels = np.asarray(query_labels)
    total = len(query_labels)
    results = []
    k = 1  # Fixed top-1

    for nlist in nlist_values:
        index = index_builder(index_vectors, index_vectors.shape[1], nlist=nlist)

        for nprobe in range(1, nlist + 1):
            index.nprobe = nprobe
            _, indices = index.search(query_vectors, k=k)
            indices = np.asarray(indices)

            # Ignore -1 "no result" ids before comparing labels.
            valid = indices >= 0
            if valid.any():
                candidate_labels = index_labels[np.where(valid, indices, 0)]
                matched = (valid & (candidate_labels == query_labels[:, None])).any(axis=1)
                hits = int(np.count_nonzero(matched))
            else:
                hits = 0
            hit_rate = hits / total if total > 0 else 0

            results.append({
                "nlist": nlist,
                "nprobe": nprobe,
                "penetration": nprobe / nlist,
                "hit_rate": hit_rate,
            })

    df = pd.DataFrame(results)
    if csv_path:
        parent_dir = os.path.dirname(csv_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        df.to_csv(csv_path, index=False)
        print(f"CSV saved to: {csv_path}")
    return df

# ---------------------
# Plot Heatmap (nprobe vs nlist) with Penetration
# ---------------------
def plot_hit_rate_heatmap(df):
    """Show a heatmap of top-1 hit rate over nprobe (x) and nlist (y).

    Args:
        df: DataFrame with columns ``nlist``, ``nprobe``, ``hit_rate`` and
            ``penetration`` (as produced by ``evaluate_hit_rate_clusters``).

    Combinations that were never evaluated (nprobe > nlist) stay NaN. Plotly
    draws them as gaps with hover disabled, so they cannot be mistaken for a
    measured 0.0 hit rate.
    """
    pivot_hit_rate = (
        df.pivot(index="nlist", columns="nprobe", values="hit_rate")
        .sort_index(axis=0)
        .sort_index(axis=1)
    )
    # Same row/column order as the hit-rate grid so customdata lines up with z.
    pivot_penetration = (
        df.pivot(index="nlist", columns="nprobe", values="penetration")
        .reindex_like(pivot_hit_rate)
    )

    fig = go.Figure(data=go.Heatmap(
        z=pivot_hit_rate.values,
        x=pivot_hit_rate.columns,
        y=pivot_hit_rate.index,
        colorscale="Viridis",
        colorbar=dict(title="Hit Rate"),
        customdata=pivot_penetration.values.astype(float),
        hoverongaps=False,
        hovertemplate=(
            "Total Clusters (nlist): %{y}<br>"
            "Clusters Searched (nprobe): %{x}<br>"
            "Hit Rate: %{z:.4f}<br>"
            "Penetration (nprobe/nlist): %{customdata}<extra></extra>"
        )
    ))

    fig.update_layout(
        title="Top-1 Hit Rate: Clusters Searched vs Total Clusters",
        xaxis_title="Clusters Searched (nprobe)",
        yaxis_title="Total Clusters (nlist)",
        template="plotly_white"
    )
    fig.show()



# ---------------------
# Main Workflow
# ---------------------
def main(json_path, csv_output="hit_rate_top1_clusters.csv"):
    """Sweep IVF nlist = 1..50 (and every nprobe <= nlist), save and plot results.

    One representative per label is indexed and all remaining samples are
    queries. The full grid is written to ``csv_output`` and drawn as a heatmap.
    """
    features, labels = load_data(json_path)
    gallery_x, gallery_y, probe_x, probe_y = split_representatives(features, labels)

    sweep_df = evaluate_hit_rate_clusters(
        gallery_x, gallery_y, probe_x, probe_y,
        nlist_values=list(range(1, 51)),  # nlist = 1 to 50
        csv_path=csv_output,
    )
    plot_hit_rate_heatmap(sweep_df)


# Example usage
if __name__ == "__main__":
    main(
        json_path="iitd_features_1_with seprate_left_abd_wrute.json",
        csv_output="results/IITDV1_IVF_top1_clusters.csv",
    )




CSV saved to: results/IITDV1_IVF_top1_clusters.csv
