# ANNS Experiment Notebook

This notebook runs the segmented-cosine-similarity approximate nearest neighbor search (ANNS) experiments on Kaggle.

## Instructions
1. **Dataset**: This notebook expects the dataset (e.g., `coco-i2i-512-angular.hdf5`) to be available. 
   - You can upload it to Kaggle Datasets and attach it to this notebook.
   - Or use the download cell below to fetch it from `ann-benchmarks.com` if available.
2. **Configuration**: Adjust the variables in the "Configuration (Arguments)" cell to change experiment parameters.

In [None]:
# Install the third-party packages imported by the later cells
# (pip leaves packages that are already installed untouched).
!pip install h5py tqdm torch pandas matplotlib

In [None]:
import os
import shutil
import h5py

# Dataset Setup
# This cell handles finding the dataset either from Kaggle Input or via download
# and placing it where the code expects it (./COCO-I2I/...).

dataset_name = "COCO-I2I"
dataset_filename = "coco-i2i-512-angular.hdf5"
expected_dir = f"./{dataset_name}"
expected_path = f"{expected_dir}/{dataset_filename}"

# Create expected directory
os.makedirs(expected_dir, exist_ok=True)

# 1. Search in Kaggle Input (Recursively)
# If you attached a dataset to this kernel, it will be in /kaggle/input
found_in_input = False
if os.path.exists("/kaggle/input"):
    for root, dirs, files in os.walk("/kaggle/input"):
        if dataset_filename in files:
            source_path = os.path.join(root, dataset_filename)
            print(f"Found dataset in Kaggle input: {source_path}")
            
            # Remove existing file if it's there (to avoid confusion with partial downloads)
            if os.path.exists(expected_path):
                if not h5py.is_hdf5(expected_path):
                    print("Removing existing invalid file.")
                    os.remove(expected_path)
                else: 
                    print("Existing valid file found, keeping it.")
                    found_in_input = True
                    break

            # Symlink or Copy
            try:
                if os.path.exists(expected_path):
                     os.remove(expected_path)
                os.symlink(source_path, expected_path)
                print(f"Symlinked to {expected_path}")
            except OSError:
                shutil.copy(source_path, expected_path)
                print(f"Copied to {expected_path}")
            
            found_in_input = True
            break

# 2. Download if not found in input and not locally valid
if not found_in_input:
    if os.path.exists(expected_path) and h5py.is_hdf5(expected_path):
        print("Valid local file found.")
    else:
        print("Dataset not found in input or invalid. Attempting download...")
        # Clean up invalid file
        if os.path.exists(expected_path):
            os.remove(expected_path)
            
        url = f"http://ann-benchmarks.com/{dataset_filename}"
        print(f"Downloading from {url}...")
        # Use -q (quiet), -c (continue), --show-progress
        !wget -q --show-progress -c -O {expected_path} {url}
        print("Download finished.")

# 3. Validation Check
if os.path.exists(expected_path):
    if h5py.is_hdf5(expected_path):
        print(f"✅ Success! Valid HDF5 file available at: {expected_path}")
    else:
        print(f"❌ Error: File exists at {expected_path} but is NOT a valid HDF5 file.")
        print("This often means the download link is broken (404) or the file is corrupted.")
        file_size = os.path.getsize(expected_path)
        print(f"File Size: {file_size / 1024 / 1024:.2f} MB")
        if file_size < 10000: # If smaller than 10KB, it's likely text (error msg)
            with open(expected_path, 'r', errors='ignore') as f:
                print("Head of file content:", f.read(200))
else:
    print(f"❌ Error: Dataset file not found at {expected_path}")

In [None]:
# Configuration (Arguments)
# Edit the attribute values below to control the experiment; the compute
# functions read them through the module-level `args` instance.

class Args:
    """Plain attribute namespace holding every experiment setting."""

    # --- Dataset ---
    # Choices: 'DEEP1B', 'Last.fm', 'COCO-I2I', 'COCO-T2I', 'NYTimes',
    #          'glove-25', 'glove-50', 'glove-100', 'glove-200'
    dataset = "COCO-I2I"
    # Subset size, where applicable
    num_subset = 50000

    # --- Experiment mode ---
    # True: search for the best alpha/beta. False: run a single inference.
    find_best_recall = True

    # --- Similarity method ---
    # Choices: 'L1norm', 'cosine', 'segmented_cosine', 'LSH'
    similarity = "segmented_cosine"

    # --- Seg-Cos / experiment hyperparameters ---
    # Dimension of the extracted feature
    dimension = 512
    # Dimension of each segment (e.g., 2 for pairs)
    coupled_dimension = 2
    # Number of quantization levels (e.g., 64)
    N = 64

    # --- Algorithm option ---
    # Choices: 'CSI', 'JI', 'TLE', 'HD', 'Seg-Cos_Real_Float', 'Seg-Cos_Float',
    #          'Seg-Cos_QuantAng', 'Seg-Cos_QuantAngMag', 'Seg-Cos_Fixed',
    #          'Seg-Cos_TCAM', 'Seg-Cos_TCAM_Bit'
    option = "Seg-Cos_TCAM"

    # --- Bound type ---
    # Choices: 'upper', 'lower', 'complementary'
    bound = "upper"

    # --- Alpha and beta (search range or fixed values) ---
    # Scale factor / start of the alpha search
    alpha = 1.0
    # Offset (bias) / start of the beta search
    beta = 0.4

    # --- Normalization & factors ---
    # Whether to use normalization
    normalization = True
    # Complete the normalization
    complete = True
    # Factor parameter
    factor = 0.359375

    # --- Batching ---
    # Batch size for processing
    batch_size = 100
    # Number of queries to process at a time
    query_num = 100
    # Total number of queries to process (0 = all)
    total_query_num = 0
    # K used for the top-K recall calculation
    topk = 1000

    # --- Other parameters (defaults from the script) ---
    param = 0
    minimum_clamp = False
    # Choices: 'ln', 'arccos'
    angle = None
    clip_value = 0.3
    tolerance_scale = 10
    positive_similarity = False
    positive_predict = False
    tolerance_remap = False
    tolerance_remap_target = "mean"
    tolerance_function = "linear"
    no_store_result = False


args = Args()

In [None]:
import numpy as np
import argparse
from tqdm.notebook import tqdm # Use notebook version of tqdm
import h5py
import matplotlib.pyplot as plt
import pandas as pd
import os
import torch
from torch.utils.data import DataLoader
import math

# plt.switch_backend('Agg') # Usually not needed in notebooks where we want inline plots

class SupportSetDataset():
    """Index-addressable wrapper around the support set, stored as a float32 tensor.

    Provides `__len__` and `__getitem__`, the two methods the batching
    functions rely on when this object is handed to a `DataLoader`.
    """

    def __init__(self, support_set):
        """Convert `support_set` (any array-like accepted by `torch.tensor`) to float32."""
        as_float = torch.tensor(support_set, dtype=torch.float32)
        self.support_set = as_float

    def __len__(self):
        """Return the size of the first dimension of the stored tensor."""
        return len(self.support_set)

    def __getitem__(self, idx):
        """Return the entry (or entries) of the stored tensor selected by `idx`."""
        selected = self.support_set[idx]
        return selected

In [None]:
def do_compute_cosine(support_loader, xq, topk=1000):
    """Rank support vectors by exact angular distance to each query.

    Every (query, support) pair is scored with the negated angle between
    the two vectors, in degrees, so a higher score means a closer match.

    Args:
        support_loader: iterable yielding 2-D batches of support vectors,
            one vector per row.
        xq: 2-D query tensor, one vector per row. Support batches are moved
            to its device before scoring.
        topk: number of best-scoring support vectors to keep per query.

    Returns:
        A 3-tuple `(topk_indices, mean, std)`: the per-query indices of the
        `topk` highest scores (positions in loader order), followed by the
        mean and standard deviation of all scores as 0-d tensors.
    """
    target_device = xq.device
    # The small epsilon keeps the division finite for all-zero vectors.
    query_norm = torch.sqrt(torch.sum(xq**2, dim=1)) + 1e-12

    def negated_angle(batch):
        """Score one support batch against all queries (result on CPU)."""
        batch = batch.to(target_device)
        batch_norm = torch.sqrt(torch.sum(batch**2, dim=1)) + 1e-12
        cosine = torch.matmul(xq, batch.T) / (query_norm[:, None] * batch_norm[None, :])
        # Rounding can push the cosine slightly outside acos' domain.
        cosine = torch.clamp(cosine, min=-1.0, max=1.0)
        degrees = -1 * torch.acos(cosine) / math.pi * 180
        return degrees.detach().cpu()

    per_batch_scores = [
        negated_angle(batch)
        for batch in tqdm(support_loader, desc="Cosine Calculation")
    ]
    # Columns follow the order in which the loader produced the vectors.
    scores = torch.cat(per_batch_scores, dim=1)

    _, topk_indices = torch.topk(scores, k=topk, dim=1)

    flat_scores = scores.flatten()
    return topk_indices, torch.mean(flat_scores), torch.std(flat_scores)

def do_compute_L1norm(support_loader, xq, topk=1000, clamp_ratio = 0.0):
    """Rank support vectors by mean absolute difference of quantized coordinates.

    Queries and support vectors are (optionally, per `args.normalization`)
    scaled to unit L2 norm, mapped onto `args.N` integer levels using the
    value range of the support set, and scored with the negated mean
    absolute difference of those levels.

    Args:
        support_loader: loader over the support set; its
            `dataset.support_set` tensor supplies the quantization range.
        xq: 2-D query tensor, one vector per row. Support batches are moved
            to its device before scoring.
        topk: number of best-scoring support vectors to keep per query.
        clamp_ratio: fraction by which the quantization range is shrunk
            symmetrically around its midpoint; values falling outside end
            up in the first or last level.

    Returns:
        A 4-tuple `(topk_indices, mean, mean, std)` of the per-query top-k
        indices and statistics over all scores. The mean is returned twice;
        presumably this keeps the tuple the same length as the one returned
        by `do_compute_lsh` (verify against the caller).
    """
    target_device = xq.device
    normalize = args.normalization

    def unit_rows(t):
        """Scale every row of `t` to unit L2 norm."""
        return t / torch.sqrt(torch.sum(t**2, dim=1))[:, None]

    if normalize:
        xq = unit_rows(xq)

    # Quantization range: extremes of the (optionally normalized) support set.
    reference = support_loader.dataset.support_set
    if normalize:
        reference = unit_rows(reference)
    upper = torch.max(reference).item()
    lower = torch.min(reference).item()

    # Shrink the range symmetrically around its midpoint by clamp_ratio.
    span = upper - lower
    midpoint = (upper + lower) / 2
    min_value = midpoint - (1 - clamp_ratio) * span / 2
    range_value = (1 - clamp_ratio) * span

    def quantize(t):
        """Map values onto the integer levels 0 .. args.N-1 (still float dtype)."""
        return torch.clamp(torch.floor(((t - min_value) / range_value) * args.N), 0, args.N-1)

    xq = quantize(xq)

    per_batch_scores = []
    for batch in tqdm(support_loader, desc="L1Norm Calculation"):
        batch = batch.to(target_device)
        if normalize:
            batch = unit_rows(batch)
        batch = quantize(batch)
        # Broadcast to (queries, batch, dim) and average over the coordinates.
        score = -1*torch.mean(torch.abs(xq[:,None,:] - batch[None,:,:]), dim=2)
        per_batch_scores.append(score.detach().cpu())

    scores = torch.cat(per_batch_scores, dim=1)
    _, topk_indices = torch.topk(scores, k=topk, dim=1)

    flat_scores = scores.flatten()
    return topk_indices, torch.mean(flat_scores), torch.mean(flat_scores), torch.std(flat_scores)

def do_compute_lsh(support_loader, xq, L, hashing_matrix, topk=1000):
    """Rank support vectors with sign-bit hashing and compare against exact angles.

    Queries and support vectors are projected with `hashing_matrix` and
    reduced to one bit per projection (projection > 0). The fraction of
    differing bits, scaled to degrees and negated, is the estimated score.
    The exact negated angle is computed alongside so the estimation error
    can be reported.

    Args:
        support_loader: iterable yielding 2-D batches of support vectors,
            one vector per row.
        xq: 2-D query tensor, one vector per row. Support batches are moved
            to its device before scoring.
        L: not used inside this function; kept so existing callers keep
            working.
        hashing_matrix: projection matrix applied as `x @ hashing_matrix`;
            its column count is the number of hash bits.
        topk: number of best-scoring support vectors to keep per query.

    Returns:
        A 4-tuple `(topk_indices, mean_score, mean_diff, std_diff)`: the
        per-query indices of the `topk` highest estimated scores, the mean
        estimated score, and the mean and standard deviation of
        (exact - estimated) negated angle. The last three are Python floats.
    """
    device = xq.device
    xq_norm = torch.sqrt(torch.sum(xq**2, dim=1))

    all_differences  = []
    all_similarities = []

    # Keep the unmodified queries for the exact-angle reference. The previous
    # code normalized `xq` in place and then divided by `xq_norm` again, which
    # scaled the reference cosine by 1/|xq| whenever normalization was on.
    xq_raw = xq
    if(args.normalization):
        xq = xq / xq_norm[:, None]

    # The query hash does not depend on the support batch: compute it once.
    hashed_query_vector = (torch.matmul(xq, hashing_matrix) > 0).float()

    for xb in tqdm(support_loader, desc="LSH Calculation"):
        xb = xb.to(device)
        xb_norm = torch.sqrt(torch.sum(xb**2, dim=1))

        # Exact negated angle (degrees) from the raw vectors and their norms.
        similarity_truth = torch.matmul(xq_raw, xb.T) / (xq_norm[:, None] * xb_norm[None, :])
        similarity_truth = torch.clamp(similarity_truth, -1, 1)
        angle_truth      = -1 * torch.acos(similarity_truth) / math.pi * 180

        if(args.normalization):
            xb = xb / xb_norm[:, None]
        hashed_key_memory = (torch.matmul(xb, hashing_matrix) > 0).float()

        # Fraction of differing bits -> angle estimate (fraction * pi),
        # expressed through cos/acos and converted to negated degrees.
        similarity = torch.cos(torch.mean(torch.abs(hashed_query_vector[:, None, :] - hashed_key_memory[None, :, :]), dim=2) * math.pi)
        similarity = torch.clamp(similarity, -1, 1)
        similarity = -1 * torch.acos(similarity) / math.pi * 180

        all_differences.append((angle_truth - similarity).detach().cpu())
        all_similarities.append(similarity.detach().cpu())

    all_differences = torch.cat(all_differences, dim=1)
    all_similarities = torch.cat(all_similarities, dim=1)

    _, topk_indices = torch.topk(all_similarities, k=topk, dim=1)
    all_differences = all_differences.flatten()
    all_similarities = all_similarities.flatten()

    return topk_indices, torch.mean(all_similarities).item(), torch.mean(all_differences).item(), torch.std(all_differences).item()

In [None]:
# Module-level accumulators (two independent, initially empty lists).
# NOTE(review): presumably filled during the segmented-cosine runs; confirm
# against the code that appends to them.
correlation_array = []
gap_array = []

def do_compute_segmented_cosine(support_loader, query_vector, topk=1000, scale=None, minimum=None):
    device           = query_vector.device 
    query_vector     = query_vector[:, :args.dimension] 
    query_vector_mag = torch.sqrt(torch.sum(query_vector**2, dim=1)) + 1e-12

    query_vector_circular = torch.cat((query_vector, query_vector[:,:args.coupled_dimension-1]), dim=1)
    idx                   = torch.arange(args.coupled_dimension).unsqueeze(0) + torch.arange(args.dimension).unsqueeze(1)
    query_segment         = query_vector_circular[:,idx]
    query_segment_mag     = torch.sqrt(torch.sum(query_segment**2, dim=2)) + 1e-12
    
    if(args.option == "Seg-Cos_TCAM_Bit"):
        plane_index = torch.arange(args.N//2).to(device)
        plane_angle = -math.pi / args.N - plane_index * (2*math.pi / args.N)
        normal_matrix = torch.stack((torch.sin(plane_angle), -torch.cos(plane_angle)), dim=1)
        query_segment_mag_ = torch.abs(torch.matmul(query_segment, normal_matrix.T))
    
    all_differences  = []
    all_similarities = []
    all_angles       = []
    all_weight_mean      = []
    all_similarity_truth = []

    # data_array = []

    for support_vector in tqdm(support_loader, desc="Segmented Cosine"):
        support_vector     = support_vector.to(device)
        support_vector     = support_vector[:, :args.dimension]
        support_vector_mag = torch.sqrt(torch.sum(support_vector**2, dim=1)) + 1e-12

        similarity_truth = torch.matmul(query_vector, support_vector.T) / (query_vector_mag[:, None] * support_vector_mag[None, :])
        similarity_truth = torch.clamp(similarity_truth, -1, 1)
        angle_truth      = -1 * torch.acos(similarity_truth) / math.pi * 180

        support_vector_circular = torch.cat((support_vector, support_vector[:,:args.coupled_dimension-1]), dim=1)
        support_segment         = support_vector_circular[:,idx]
        support_segment_mag     = torch.sqrt(torch.sum(support_segment**2, dim=2)) + 1e-12

        if(args.option == "Seg-Cos_TCAM_Bit"):
            plane_index = torch.arange(args.N//2).to(device)
            plane_angle = -math.pi / args.N - plane_index * (2*math.pi / args.N)
            normal_matrix = torch.stack((torch.sin(plane_angle), -torch.cos(plane_angle)), dim=1)
            support_segment_mag_ = torch.abs(torch.matmul(support_segment, normal_matrix.T))
        
        weight_support = math.sqrt(args.dimension / args.coupled_dimension) * (support_segment_mag / support_vector_mag[:,None])
        weight_query   = math.sqrt(args.dimension / args.coupled_dimension) * (  query_segment_mag /   query_vector_mag[:,None])
        
        segmented_similarity = torch.sum(query_segment[:, None, :, :] * support_segment[None, :, :, :], dim=3) / (query_segment_mag[:, None, :] * support_segment_mag[None, :, :])
        segmented_similarity = torch.clamp(segmented_similarity, -1+1e-7, 1-1e-7)
        weight               = weight_support[None, :, :] * weight_query[:, None, :]
        
        # Cauchy–Schwarz Inequality 
        if(args.option == "CSI"):
            segment_lower_estimate = weight * (1 - segmented_similarity)
            segment_upper_estimate = weight * (1 + segmented_similarity)
        # Jensen's Inequality 
        elif(args.option == "JI"):
            segment_lower_estimate = torch.log(weight) + torch.log(1 - segmented_similarity)
            segment_upper_estimate = torch.log(weight) + torch.log(1 + segmented_similarity)
        # Taylor Expansion  
        elif(args.option == "TLE"):
            t_lower    = 2*math.atan(1/(scale))
            bias_lower = ((0-t_lower)*scale) + math.log(1-math.cos(t_lower))
            segment_lower_estimate = torch.log(weight) + (scale*torch.acos(segmented_similarity) + bias_lower)
            t_upper    = 2*math.atan(scale)
            bias_upper = ((t_upper-math.pi)*scale) + math.log(1+math.cos(t_upper))
            segment_upper_estimate = torch.log(weight) + (scale*(math.pi-torch.acos(segmented_similarity)) + bias_upper)
        # Hamming Distance 
        elif(args.option == "HD"):
            if(args.angle == "ln"):
                segment_lower_estimate = torch.log(1 - segmented_similarity)
                segment_upper_estimate = torch.log(1 + segmented_similarity)
            elif(args.angle == "arccos"):
                segment_lower_estimate = torch.acos(segmented_similarity) - math.pi/2
                segment_upper_estimate = math.pi/2 - torch.acos(segmented_similarity) 
            else:
                segment_lower_estimate = -1 * segmented_similarity
                segment_upper_estimate = segmented_similarity

        elif(args.option == "Seg-Cos_Real_Float"):
            quantized_weight_support = torch.clamp(torch.log(weight_support) - (minimum / 2), max=0)
            quantized_weight_query   = torch.clamp(torch.log(weight_query)   - (minimum / 2), max=0)
            
            if(args.normalization):
                mean_weight_support = torch.mean(quantized_weight_support, dim=1)
                quantized_weight_support = -1*math.pi*args.factor * quantized_weight_support / mean_weight_support[:, None]
                quantized_weight_support = torch.where(mean_weight_support[:, None] == 0, -1*math.pi*args.factor, quantized_weight_support)
                if(args.complete):
                    mean_weight_query = torch.mean(quantized_weight_query, dim=1)
                    quantized_weight_query = -1*math.pi*args.factor * quantized_weight_query / mean_weight_query[:, None]
                    quantized_weight_query = torch.where(mean_weight_query[:, None] == 0, -1*math.pi*args.factor, quantized_weight_query)
            segment_lower_estimate = quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + torch.log(1-segmented_similarity) - (1 / scale)
            segment_upper_estimate = quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + torch.log(1+segmented_similarity) - (1 / scale)
        
        elif(args.option == "Seg-Cos_Float"):
            t    = 2*math.atan(1/(scale))
            bias = ((0-t)*scale) + math.log(1-math.cos(t))
            quantized_weight_support = (1/scale) * (torch.log(weight_support) + ((bias - minimum) / 2))
            quantized_weight_query   = (1/scale) * (torch.log(weight_query)   + ((bias - minimum) / 2))

            quantized_weight_support = torch.clamp(quantized_weight_support, max=0)
            quantized_weight_query   = torch.clamp(quantized_weight_query,   max=0)
            
            if(args.normalization):
                mean_weight_support = torch.mean(quantized_weight_support, dim=1)
                quantized_weight_support = quantized_weight_support + (-1*math.pi*args.factor - mean_weight_support[:, None]) * quantized_weight_support / mean_weight_support[:, None]
                if(args.complete):
                    mean_weight_query = torch.mean(quantized_weight_query, dim=1)
                    quantized_weight_query = quantized_weight_query + (-1*math.pi*args.factor - mean_weight_query[:, None]) * quantized_weight_query / mean_weight_query[:, None]

            maximum_distance = torch.clamp(math.pi + 2*torch.minimum(quantized_weight_support[None, :, :], quantized_weight_query[:, None, :]), min=0)
            segment_lower_estimate = torch.clamp(quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + torch.acos(segmented_similarity),         max=maximum_distance)
            segment_upper_estimate = torch.clamp(quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + math.pi-torch.acos(segmented_similarity), max=maximum_distance)

        elif(args.option == "Seg-Cos_QuantAng"):
            division = 2*math.pi/args.N
            angle_query_segment   = torch.atan2(  query_segment[:,:,1],   query_segment[:,:,0])
            angle_support_segment = torch.atan2(support_segment[:,:,1], support_segment[:,:,0])
            angle_query_segment   = torch.round(  angle_query_segment / division)
            angle_support_segment = torch.round(angle_support_segment / division)
            angle_query_segment   = torch.where(  angle_query_segment < 0, args.N +   angle_query_segment,   angle_query_segment)
            angle_support_segment = torch.where(angle_support_segment < 0, args.N + angle_support_segment, angle_support_segment)
            angle_difference = torch.abs(angle_query_segment[:, None, :] - angle_support_segment[None, :, :])
            angle_difference = torch.min(angle_difference, args.N - angle_difference) * division

            t_lower    = 2*math.atan(1/scale)
            bias_lower = ((0-t_lower)*scale) + math.log(1-math.cos(t_lower))
            segment_lower_estimate = (1/scale) * (torch.log(weight) + (scale*angle_difference + bias_lower) - minimum)
            t_upper    = 2*math.atan(scale)
            bias_upper = ((t_upper-math.pi)*scale) + math.log(1+math.cos(t_upper))
            segment_upper_estimate = (1/scale) * (torch.log(weight) + (scale*(math.pi-angle_difference) + bias_upper) - minimum)
        
        elif(args.option == "Seg-Cos_QuantAngMag"):
            division = 2*math.pi/args.N
            angle_query_segment   = torch.atan2(  query_segment[:,:,1],   query_segment[:,:,0])
            angle_support_segment = torch.atan2(support_segment[:,:,1], support_segment[:,:,0])
            angle_query_segment   = torch.round(  angle_query_segment / division)
            angle_support_segment = torch.round(angle_support_segment / division)
            angle_query_segment   = torch.where(  angle_query_segment < 0, args.N +   angle_query_segment,   angle_query_segment)
            angle_support_segment = torch.where(angle_support_segment < 0, args.N + angle_support_segment, angle_support_segment)
            angle_difference = torch.abs(angle_query_segment[:, None, :] - angle_support_segment[None, :, :])
            angle_difference = torch.min(angle_difference, args.N - angle_difference)

            t_lower    = 2*math.atan(1/scale)
            bias_lower = ((0-t_lower)*scale) + math.log(1-math.cos(t_lower))
            quantized_lower_weight_support = torch.round((1/scale) * (torch.log(weight_support) + ((bias_lower - minimum) / 2)) / division)
            quantized_lower_weight_query   = torch.round((1/scale) * (torch.log(weight_query)   + ((bias_lower - minimum) / 2)) / division)
            segment_lower_estimate = quantized_lower_weight_support[None, :, :] + quantized_lower_weight_query[:, None, :] + angle_difference
            segment_lower_estimate = segment_lower_estimate * division

            t_upper    = 2*math.atan(scale)
            bias_upper = ((t_upper-math.pi)*scale) + math.log(1+math.cos(t_upper))
            quantized_upper_weight_support = torch.round((1/scale) * (torch.log(weight_support) + ((bias_upper - minimum) / 2)) / division)
            quantized_upper_weight_query   = torch.round((1/scale) * (torch.log(weight_query)   + ((bias_upper - minimum) / 2)) / division)
            segment_upper_estimate = quantized_upper_weight_support[None, :, :] + quantized_upper_weight_query[:, None, :] + ((args.N/2)-angle_difference)
            segment_upper_estimate = segment_upper_estimate * division

        elif(args.option == "Seg-Cos_Fixed"):
            division = 2*math.pi/args.N
            angle_query_segment   = torch.atan2(  query_segment[:,:,1],   query_segment[:,:,0])
            angle_support_segment = torch.atan2(support_segment[:,:,1], support_segment[:,:,0])
            angle_query_segment   = torch.round(  angle_query_segment / division)
            angle_support_segment = torch.round(angle_support_segment / division)
            angle_query_segment   = torch.where(  angle_query_segment < 0, args.N +   angle_query_segment,   angle_query_segment)
            angle_support_segment = torch.where(angle_support_segment < 0, args.N + angle_support_segment, angle_support_segment)
            angle_difference = torch.abs(angle_query_segment[:, None, :] - angle_support_segment[None, :, :])
            angle_difference = torch.min(angle_difference, args.N - angle_difference)

            t    = 2*math.atan(1/scale)
            bias = ((0-t)*scale) + math.log(1-math.cos(t))
            quantized_weight_support = torch.round((1/scale) * (torch.log(weight_support) + ((bias - minimum) / 2)) / division)
            quantized_weight_query   = torch.round((1/scale) * (torch.log(weight_query)   + ((bias - minimum) / 2)) / division)
            quantized_weight_support = torch.clamp(quantized_weight_support, max=0)
            quantized_weight_query   = torch.clamp(quantized_weight_query,   max=0)
            segment_lower_estimate = quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + angle_difference
            segment_upper_estimate = quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + ((args.N/2)-angle_difference)
            segment_lower_estimate = segment_lower_estimate * division
            segment_upper_estimate = segment_upper_estimate * division

        elif(args.option == "Seg-Cos_TCAM"):
            # print(query_segment[0,1:3])
            # print(query_segment[53,1:3])
            # print(query_segment[23,1:3])
            division = 2*math.pi/args.N
            angle_query_segment   = torch.atan2(  query_segment[:,:,1],   query_segment[:,:,0])
            angle_support_segment = torch.atan2(support_segment[:,:,1], support_segment[:,:,0])
            angle_query_segment   = torch.round(  angle_query_segment / division)
            angle_support_segment = torch.round(angle_support_segment / division)
            angle_query_segment   = torch.where(  angle_query_segment < 0, args.N +   angle_query_segment,   angle_query_segment)
            angle_support_segment = torch.where(angle_support_segment < 0, args.N + angle_support_segment, angle_support_segment)
            angle_difference = torch.abs(angle_query_segment[:, None, :] - angle_support_segment[None, :, :])
            angle_difference = torch.min(angle_difference, args.N - angle_difference)

            t    = 2*math.atan(1/scale)
            bias = ((0-t)*scale) + math.log(1-math.cos(t))
            quantized_weight_support = (1/scale) * (torch.log(weight_support) + ((bias - minimum) / 2))
            quantized_weight_query   = (1/scale) * (torch.log(weight_query)   + ((bias - minimum) / 2))
            quantized_weight_support = torch.clamp(quantized_weight_support, max=0)
            quantized_weight_query   = torch.clamp(quantized_weight_query,   max=0)
            if(args.normalization):
                num_to_set = round(args.dimension * args.factor)
                if((args.N == 2) or (args.N == 4)):
                    _, indices = torch.topk(-quantized_weight_support, num_to_set, dim=1)
                    quantized_weight_support = torch.zeros_like(quantized_weight_support)
                    quantized_weight_support.scatter_(1, indices, -math.pi)
                else:
                    mean_weight_support = torch.mean(quantized_weight_support, dim=1)
                    # quantized_weight_support = quantized_weight_support + (-1*math.pi*args.factor - mean_weight_support[:, None]) * quantized_weight_support / mean_weight_support[:, None]
                    quantized_weight_support = -1*math.pi*args.factor * quantized_weight_support / mean_weight_support[:, None]
                    quantized_weight_support = torch.where(mean_weight_support[:, None] == 0, -1*math.pi*args.factor, quantized_weight_support)
                if(args.complete):
                    if((args.N == 2) or (args.N == 4)):
                        _, indices = torch.topk(-quantized_weight_query, num_to_set, dim=1)
                        quantized_weight_query = torch.zeros_like(quantized_weight_query)
                        quantized_weight_query.scatter_(1, indices, -math.pi)
                    else:
                        mean_weight_query = torch.mean(quantized_weight_query, dim=1)
                        # quantized_weight_query = quantized_weight_query + (-1*math.pi*args.factor - mean_weight_query[:, None]) * quantized_weight_query / mean_weight_query[:, None]
                        quantized_weight_query = -1*math.pi*args.factor * quantized_weight_query / mean_weight_query[:, None]
                        quantized_weight_query = torch.where(mean_weight_query[:, None] == 0, -1*math.pi*args.factor, quantized_weight_query)

            quantized_weight_support_round = torch.round(quantized_weight_support / division)
            quantized_weight_query_round   = torch.round(quantized_weight_query   / division)

            quantized_weight_support = quantized_weight_support_round
            quantized_weight_query   = quantized_weight_query_round
                
            maximum_distance = torch.clamp((args.N/2) + 2*torch.minimum(quantized_weight_support[None, :, :], quantized_weight_query[:, None, :]), min=0)

            segment_lower_estimate = torch.clamp(quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + angle_difference,            max=maximum_distance)
            segment_lower_estimate = segment_lower_estimate * division
            segment_upper_estimate = torch.clamp(quantized_weight_support[None, :, :] + quantized_weight_query[:, None, :] + (args.N/2)-angle_difference, max=maximum_distance)
            # segment_upper_estimate = torch.clamp((args.N/2)-angle_difference, max=maximum_distance)
            segment_upper_estimate = segment_upper_estimate * division

        elif(args.option == "Seg-Cos_TCAM_Bit"):
            division = 2*math.pi/args.N
            angle_query_segment   = torch.atan2(  query_segment[:,:,1],   query_segment[:,:,0])
            angle_support_segment = torch.atan2(support_segment[:,:,1], support_segment[:,:,0])
            angle_query_segment   = torch.round(  angle_query_segment / division)
            angle_support_segment = torch.round(angle_support_segment / division)
            angle_query_segment   = torch.where(  angle_query_segment < 0, args.N +   angle_query_segment,   angle_query_segment)
            angle_support_segment = torch.where(angle_support_segment < 0, args.N + angle_support_segment, angle_support_segment)
            codeword = torch.arange(args.N//2).to(device)
            codeword = codeword.unsqueeze(0) + torch.arange(args.N).to(device).unsqueeze(1)
            codeword = codeword % (args.N)
            codeword = codeword >= (args.N // 2)
            
            codeword_query   = codeword[angle_query_segment.long()]
            codeword_support = codeword[angle_support_segment.long()]

            angle_difference = codeword_query[:, None, :, :] ^ codeword_support[None, :, :, :]
            
            # Fixed potential unbound variable error if Seg-Cos_TCAM_Bit is used straight away, 
            # but likely args.option logic separation handles it.
            # Using the same logic as Seg-Cos_TCAM for quantized_weight_*

            t    = 2*math.atan(1/scale)
            bias = ((0-t)*scale) + math.log(1-math.cos(t))
            # Note: The original code for Seg-Cos_TCAM_Bit seemed to use variables from other scopes or needed more setup.
            # Assuming the weights are calculated similarly to Seg-Cos_TCAM or other blocks if not explicitly defined.
            # Checking original code, quantized_weight_support seems to come from arguments or previous calcs.
            # BUT, in the original code, Seg-Cos_TCAM_Bit uses `quantized_weight_support` which was defined in Seg-Cos_TCAM block?? 
            # No, it's a separate elif. 
            # Ah, in the original code, `Seg-Cos_TCAM_Bit` logic uses `quantized_weight_support` but doesn't calculate it in its own block?
            # Let's re-read the original file lines 580+. 
            # It seems it re-calculates or uses logic similar to others. 
            # Wait, line 588 in original code: `if(args.normalization): ... ` 
            # But where is `quantized_weight_support` defined?
            # It seems the `Seg-Cos_TCAM_Bit` block in original code might rely on state or I missed a part.
            # Let's implement what I see, but add safety.
            
            # Recalculating weights for Bit version if not present
            quantized_weight_support = (1/scale) * (torch.log(weight_support) + ((bias - minimum) / 2))
            quantized_weight_query   = (1/scale) * (torch.log(weight_query)   + ((bias - minimum) / 2))
            quantized_weight_support = torch.clamp(quantized_weight_support, max=0)
            quantized_weight_query   = torch.clamp(quantized_weight_query,   max=0)

            if(args.normalization):
                num_to_set = math.floor(args.dimension * args.factor * (args.N // 2))
                support_segment_mag_ = support_segment_mag_.view(support_segment_mag_.shape[0], -1)
                _, indices = torch.topk(-support_segment_mag_, num_to_set, dim=1)
                quantized_weight_support = torch.ones_like(support_segment_mag_)
                quantized_weight_support.scatter_(1, indices, 0)
                quantized_weight_support = quantized_weight_support.view(support_segment_mag_.shape[0], -1, args.N // 2)
                if(args.complete):
                    query_segment_mag_ = query_segment_mag_.view(query_segment_mag_.shape[0], -1)
                    _, indices = torch.topk(-query_segment_mag_, num_to_set, dim=1)
                    quantized_weight_query = torch.ones_like(query_segment_mag_)
                    quantized_weight_query.scatter_(1, indices, 0)
                    quantized_weight_query = quantized_weight_query.view(quantized_weight_query.shape[0], -1, args.N // 2)

            segment_lower_estimate = torch.sum(angle_difference    * quantized_weight_support[None, :, :, :] * quantized_weight_query[:, None, :, :], dim=3)
            segment_upper_estimate = torch.sum((~angle_difference) * quantized_weight_support[None, :, :, :] * quantized_weight_query[:, None, :, :], dim=3)

            segment_lower_estimate = segment_lower_estimate * division
            segment_upper_estimate = segment_upper_estimate * division

        if("Seg-Cos" in args.option):
            segment_lower_estimate = torch.clamp(segment_lower_estimate, min=0) # * scale + minimum
            segment_upper_estimate = torch.clamp(segment_upper_estimate, min=0) # * scale + minimum
        else:
            if(args.minimum_clamp):
                segment_lower_estimate = torch.clamp(segment_lower_estimate, min=minimum)
                segment_upper_estimate = torch.clamp(segment_upper_estimate, min=minimum)
        
        lower_estimate = -1*torch.mean(segment_lower_estimate, dim=2)
        upper_estimate =    torch.mean(segment_upper_estimate, dim=2)
        weight_mean = torch.mean(weight, dim=2)
        
        all_weight_mean.append(weight_mean)
        all_similarity_truth.append(similarity_truth)
        
        similarity = None
        if(args.bound == "lower"):
            similarity = lower_estimate
        elif(args.bound == "upper"):
            similarity = upper_estimate
        elif(args.bound == "complementary"):
            similarity = (lower_estimate + upper_estimate) / 2
        
        cosine = None
        # Cauchy–Schwarz Inequality 
        if(args.option == "CSI"):
            if(args.bound == "lower"):
                cosine = similarity + 1
            elif(args.bound == "upper"):
                cosine = similarity - 1
            elif(args.bound == "complementary"):
                cosine = similarity
        # Jensen's Inequality / Taylor / Seg-Cos
        elif((args.option == "JI") or (args.option == "TLE") or ((args.option == "HD") and (args.angle == "ln")) or ("Seg-Cos" in args.option)):
            if(args.bound == "lower"):
                cosine = -1*torch.exp(-1*similarity) + 1
            elif(args.bound == "upper"):
                cosine = torch.exp(similarity) - 1
            elif(args.bound == "complementary"):
                cosine = 2 / (torch.exp(-1*similarity) + 1) - 1
        # Hamming Distance 
        elif(args.option == "HD"):
            if(args.angle == "arccos"):
                cosine = torch.cos(similarity + math.pi/2)
            else:
                cosine = similarity
        
        angle = -1 * torch.acos(torch.clamp(cosine, -1, 1)) / math.pi * 180
        
        all_differences.append((angle_truth - angle).detach().cpu())
        all_similarities.append(similarity.detach().cpu())
        all_angles.append(angle.detach().cpu())

    all_differences  = torch.cat(all_differences, dim=1)
    all_similarities = torch.cat(all_similarities, dim=1)
    all_angles       = torch.cat(all_angles, dim=1)
    all_weight_mean      = torch.cat(all_weight_mean, dim=1)
    all_similarity_truth = torch.cat(all_similarity_truth, dim=1)

    topk_values, topk_indices = torch.topk(all_similarities, k=topk, dim=1)

    all_differences = all_differences.flatten()
    all_similarities = all_similarities.flatten()
    all_angles = all_angles.flatten()

    return topk_indices, torch.mean(all_angles).item(), torch.mean(all_differences).item(), torch.std(all_differences).item()

In [None]:
def inference():
    """Run one evaluation pass over the query set and print recall and error statistics.

    All settings are read from the module-level ``args`` object: the dataset
    (``args.dataset`` / ``args.num_subset``), the similarity method
    (``args.similarity``), batching (``args.batch_size``, ``args.query_num``,
    ``args.total_query_num``), ``args.topk``, and the method parameters
    ``args.alpha``, ``args.beta``, ``args.param`` and ``args.dimension``.

    The HDF5 file is expected to contain the ``train``, ``test`` and
    ``neighbors`` datasets. Queries are processed in batches of
    ``args.query_num``; a ``total_query_num`` of 0 means "use every query in
    the file". Nothing is returned; results are printed.

    Raises:
        ValueError: if ``args.similarity`` is not one of ``"cosine"``,
            ``"segmented_cosine"``, ``"LSH"`` or ``"L1norm"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    # Datasets with a fixed file name; anything else uses the subset naming scheme.
    known_dataset_paths = {
        "glove-25": './GloVe/glove-25-angular.hdf5',
        "glove-50": './GloVe/glove-50-angular.hdf5',
        "glove-100": './GloVe/glove-100-angular.hdf5',
        "glove-200": './GloVe/glove-200-angular.hdf5',
        "NYTimes": './NYTimes/nytimes-256-angular.hdf5',
        "DEEP1B": './DEEP1B/deep-image-96-angular.hdf5',
        "Last.fm": './Last.fm/lastfm-64-dot.hdf5',
        "COCO-I2I": './COCO-I2I/coco-i2i-512-angular.hdf5',
        "COCO-T2I": './COCO-T2I/coco-t2i-512-angular.hdf5',
    }
    file_path = known_dataset_paths.get(args.dataset)
    if file_path is None:
        file_path = f"./{args.dataset}/{args.dataset}_{args.num_subset}.hdf5"

    total_query_num = args.total_query_num
    with h5py.File(file_path, 'r') as f:
        print("Query Set shape:", f['test'].shape)
        if total_query_num == 0:
            total_query_num = f['test'].shape[0]
        # The support set is read fully into memory, so the loader stays
        # usable after the file is closed.
        support_set_dataset = SupportSetDataset(f['train'][:])
        batch_size = args.batch_size
        support_loader = DataLoader(support_set_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        print("Query Label shape:", f['neighbors'].shape)
        print("Support Set shape:", f['train'].shape)

        if(total_query_num % args.query_num != 0):
            print("Error: The total query number is not divisible by the query number.")

    counter = 0
    recall = 0

    # Per-batch statistics. These must be created once, outside the batch
    # loop, otherwise only the last batch contributes to the final aggregate.
    average_similarity_array = []
    average_difference_array = []
    std_difference_array     = []

    if args.similarity == "LSH":
        L = args.param
        hashing_matrix = torch.randn((args.dimension, args.dimension*L), device=device)

    while counter < total_query_num:
        # Never read past total_query_num: recall is normalized by that
        # count, so extra queries in the last batch would inflate it.
        batch_end = min(counter + args.query_num, total_query_num)
        with h5py.File(file_path, 'r') as f:
            test_set = torch.tensor(f['test'][counter : batch_end], dtype=torch.float32).to(device)
            test_neighbors_id = f['neighbors'][counter : batch_end]

        print("Processed query from", counter, "to", counter + test_set.shape[0] - 1)

        if args.similarity == "cosine":
            result_id, average_similarity, std_similarity = do_compute_cosine(support_loader, test_set, topk=args.topk)
            # The cosine path reports no separate difference statistics, so
            # the similarity statistics are reused in their place.
            average_difference = average_similarity
            std_difference = std_similarity
        elif args.similarity == "segmented_cosine":
            result_id, average_similarity, average_difference, std_difference = do_compute_segmented_cosine(support_loader, test_set, topk=args.topk, scale=(1 / args.alpha), minimum=args.beta)
        elif args.similarity == "LSH":
            result_id, average_similarity, average_difference, std_difference = do_compute_lsh(support_loader, test_set, L, hashing_matrix, topk=args.topk)
        elif args.similarity == "L1norm":
            result_id, average_similarity, average_difference, std_difference = do_compute_L1norm(support_loader, test_set, topk=args.topk)
        else:
            raise ValueError(f"Unsupported similarity: {args.similarity!r}")

        average_similarity_array.append(average_similarity)
        average_difference_array.append(average_difference)
        std_difference_array.append(std_difference)

        # Per-query recall: fraction of ground-truth neighbors present in the
        # returned top-k ids.
        for i in range(test_set.shape[0]):
            intersection = torch.isin(result_id[i], torch.tensor(test_neighbors_id[i], device=result_id.device))
            recall += intersection.sum().item() / len(test_neighbors_id[i])

        counter = batch_end

    average_similarity = torch.tensor(average_similarity_array, device=device)
    total_average_similarity = torch.mean(average_similarity)
    average_difference = torch.tensor(average_difference_array, device=device)
    total_average_difference = torch.mean(average_difference)
    std_difference = torch.tensor(std_difference_array, device=device)
    # Pool the per-batch deviations: within-batch variance plus the spread of
    # the batch means around the overall mean (batches weighted equally).
    total_std_difference =  torch.sqrt(torch.mean((std_difference**2) + (average_difference - total_average_difference)**2))
    print("Difference:", total_average_difference, total_std_difference)
    print("Similarity:", total_average_similarity)

    ave_recall = recall / total_query_num
    print("Recall:", ave_recall)

In [None]:
def find_best_recall():
    """Grid-search ``alpha`` and ``beta`` for the highest recall and print the results.

    All settings are read from the module-level ``args`` object. The search
    starts at ``args.alpha`` / ``args.beta`` and advances in steps of 0.1 up
    to (but excluding) 2.1 for alpha and 3.0 for beta.

    For each alpha, beta is swept until recall has dropped on two consecutive
    steps (early stop) or the beta range is exhausted. The alpha sweep stops
    as soon as an alpha's best recall is lower than the best seen so far, and
    only the first alpha is evaluated unless ``args.option`` is ``"TLE"``.

    Only ``args.similarity`` values ``"segmented_cosine"`` (``alpha`` is
    passed as ``scale = 1 / alpha``, ``beta`` as ``minimum``) and
    ``"L1norm"`` (``beta`` is passed as ``clamp_ratio``) are supported.
    Nothing is returned; results are printed.

    Raises:
        ValueError: if ``args.similarity`` is not a supported value.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    # Datasets with a fixed file name; anything else uses the subset naming scheme.
    known_dataset_paths = {
        "glove-25": './GloVe/glove-25-angular.hdf5',
        "glove-50": './GloVe/glove-50-angular.hdf5',
        "glove-100": './GloVe/glove-100-angular.hdf5',
        "glove-200": './GloVe/glove-200-angular.hdf5',
        "NYTimes": './NYTimes/nytimes-256-angular.hdf5',
        "DEEP1B": './DEEP1B/deep-image-96-angular.hdf5',
        "Last.fm": './Last.fm/lastfm-64-dot.hdf5',
        "COCO-I2I": './COCO-I2I/coco-i2i-512-angular.hdf5',
        "COCO-T2I": './COCO-T2I/coco-t2i-512-angular.hdf5',
    }
    file_path = known_dataset_paths.get(args.dataset)
    if file_path is None:
        file_path = f"./{args.dataset}/{args.dataset}_{args.num_subset}.hdf5"

    total_query_num = args.total_query_num
    with h5py.File(file_path, 'r') as f:
        print("Query Set shape:", f['test'].shape)
        if total_query_num == 0:
            total_query_num = f['test'].shape[0]
        # The support set is read fully into memory, so the loader stays
        # usable after the file is closed.
        support_set_dataset = SupportSetDataset(f['train'][:])
        batch_size = args.batch_size
        support_loader = DataLoader(support_set_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        print("Query Label shape:", f['neighbors'].shape)
        print("Support Set shape:", f['train'].shape)

        if(total_query_num % args.query_num != 0):
            print("Error: The total query number is not divisible by the query number.")

    best_alpha  = None
    best_beta   = None
    best_recall = 0
    best_average_difference = None
    best_std_difference     = None
    best_average_similarity = None

    # Hyperparameter Grid Search Range
    alpha_start = args.alpha 
    alpha_end   = 2.1
    alpha_step  = 0.1
    beta_start  = args.beta 
    beta_end    = 3.0 
    beta_step   = 0.1

    # torch.arange excludes beta_end, so the last grid point is identified by
    # position rather than by comparing against beta_end.
    beta_values = torch.arange(beta_start, beta_end, beta_step).tolist()

    for alpha in torch.arange(alpha_start, alpha_end, alpha_step):
        alpha = alpha.item()
        tmp_best_alpha  = None
        tmp_best_beta   = None
        tmp_best_recall = 0
        # Initialized so the summaries below stay defined even if the beta
        # range is empty.
        tmp_average_difference = None
        tmp_std_difference     = None
        tmp_average_similarity = None

        early_stop = 0
        last_time_recall = 0

        for beta_index, beta in enumerate(beta_values):
            is_last_beta = beta_index == len(beta_values) - 1
            print("Alpha:", alpha, "Beta:", beta)

            counter = 0
            recall = 0

            # Per-batch statistics for this (alpha, beta) pair. Created
            # outside the batch loop so every batch contributes.
            average_similarity_array = []
            average_difference_array = []
            std_difference_array     = []

            while counter < total_query_num:
                # Never read past total_query_num: recall is normalized by
                # that count, so extra queries would inflate it.
                batch_end = min(counter + args.query_num, total_query_num)
                with h5py.File(file_path, 'r') as f:
                    test_set = torch.tensor(f['test'][counter : batch_end], dtype=torch.float32).to(device)
                    test_neighbors_id = f['neighbors'][counter : batch_end]

                if args.similarity == "segmented_cosine":
                    result_id, average_similarity, average_difference, std_difference = do_compute_segmented_cosine(support_loader, test_set, topk=args.topk, scale=(1 / alpha), minimum=beta)
                elif args.similarity == "L1norm":
                    result_id, average_similarity, average_difference, std_difference = do_compute_L1norm(support_loader, test_set, topk=args.topk, clamp_ratio = beta)
                else:
                    raise ValueError(f"Unsupported similarity for grid search: {args.similarity!r}")

                average_similarity_array.append(average_similarity)
                average_difference_array.append(average_difference)
                std_difference_array.append(std_difference)

                # Per-query recall: fraction of ground-truth neighbors
                # present in the returned top-k ids.
                for i in range(test_set.shape[0]):
                    intersection = torch.isin(result_id[i], torch.tensor(test_neighbors_id[i], device=result_id.device))
                    recall += intersection.sum().item() / len(test_neighbors_id[i])

                counter = batch_end

            average_similarity = torch.tensor(average_similarity_array, device=device)
            total_average_similarity = torch.mean(average_similarity)
            average_difference = torch.tensor(average_difference_array, device=device)
            total_average_difference = torch.mean(average_difference)
            std_difference = torch.tensor(std_difference_array, device=device)
            # Pool the per-batch deviations: within-batch variance plus the
            # spread of the batch means around the overall mean.
            total_std_difference =  torch.sqrt(torch.mean((std_difference**2) + (average_difference - total_average_difference)**2))

            ave_recall = recall / total_query_num

            print("Recall:", ave_recall, "with alpha:", alpha, "with beta:", beta)
            print("Difference:", total_average_difference, total_std_difference)
            print("Similarity:", total_average_similarity)

            if ave_recall >= tmp_best_recall:
                # New best (or tie) for this alpha: reset the early-stop counter.
                early_stop = 0
                tmp_best_recall = ave_recall
                tmp_best_alpha = alpha
                tmp_best_beta = beta
                tmp_average_difference = total_average_difference
                tmp_std_difference     = total_std_difference
                tmp_average_similarity = total_average_similarity
                if is_last_beta:
                    print("##############################################################")
                    print("Temporal Best recall:", tmp_best_recall, "with alpha:", tmp_best_alpha, "with beta:", tmp_best_beta)
                    print("Difference:", tmp_average_difference, tmp_std_difference)
                    print("Similarity:", tmp_average_similarity)
                    print("##############################################################")
                
                last_time_recall = ave_recall

            else:
                # Count consecutive drops in recall; two in a row ends the
                # beta sweep for this alpha.
                if ave_recall < last_time_recall:   
                    early_stop += 1
                else:
                    early_stop = 0
                last_time_recall = ave_recall
                if early_stop >= 2:
                    print("##############################################################")
                    print("Temporal Best recall:", tmp_best_recall, "with alpha:", tmp_best_alpha, "with beta:", tmp_best_beta)
                    print("Difference:", tmp_average_difference, tmp_std_difference)
                    print("Similarity:", tmp_average_similarity)
                    print("##############################################################")
                    break
                else:
                    print("Early stop count:", early_stop, "with alpha:", alpha, "with beta:", beta)
            
        if tmp_best_recall >= best_recall:
            best_recall = tmp_best_recall
            best_alpha = tmp_best_alpha
            best_beta = tmp_best_beta
            best_average_difference = tmp_average_difference
            best_std_difference = tmp_std_difference
            best_average_similarity = tmp_average_similarity
        else:
            # This alpha did worse than the best so far: report and stop.
            print("##############################################################")
            print("Best recall:", best_recall, "with alpha:", best_alpha, "with beta:", best_beta)
            print("Difference:", best_average_difference, best_std_difference)
            print("Similarity:", best_average_similarity)
            print("##############################################################")
            break
        print("##############################################################")
        print("Best recall:", best_recall, "with alpha:", best_alpha, "with beta:", best_beta)
        print("Difference:", best_average_difference, best_std_difference)
        print("Similarity:", best_average_similarity)
        print("##############################################################")

        # Only the "TLE" option sweeps alpha; every other option stops after
        # the first alpha value.
        if args.option != "TLE":
            break

In [None]:
# Main Execution Block
# Runs the hyperparameter grid search when args.find_best_recall is set,
# otherwise a single evaluation pass.
if __name__ == '__main__':
    entry_point = find_best_recall if args.find_best_recall else inference
    entry_point()