In [1]:
import numpy as np
import random
import time
import bisect

In [2]:
class Neighbor:
    """One candidate vertex in a search pool.

    Instances are ordered by ``distance`` alone (``__lt__``), which is what
    ``list.sort`` and the ``bisect`` calls in the search routines rely on.
    Equality is left as object identity. ``flag`` is set to 1 when a
    candidate enters a pool and cleared to 0 once its neighbor list has been
    expanded.
    """

    def __init__(self, id, distance, flag):
        self.id, self.distance, self.flag = id, distance, flag

    def __lt__(self, other: 'Neighbor'):
        # Reflected form of `self.distance < other.distance`.
        return other.distance > self.distance

    def printNeighbor(self):
        """Print this neighbor's id, distance and flag on one line."""
        fields = ("id: ", self.id, "distance: ", self.distance, "flag: ", self.flag)
        print(*fields)

### NSG Index structure

In [3]:
def compare(query_data, db_vec, dimension: int):
    """Return the squared Euclidean (L2) distance between two vectors.

    ``dimension`` is accepted but not used; the arrays' own length decides
    how many components are summed.
    """
    delta = query_data - db_vec
    return np.square(delta).sum()

def print_retset(retset):
    """Call ``printNeighbor()`` on every entry of ``retset``, in order."""
    for neighbor in retset:
        neighbor.printNeighbor()
        
class IndexNSG:
    """NSG graph index: adjacency lists loaded from disk plus graph search.

    The database vectors are not stored on the index; each search call
    receives them as ``x`` (``x[i]`` is vertex ``i``'s vector).

    Attributes:
        final_graph: ``final_graph[v]`` is the list of neighbor ids of vertex ``v``.
        dimension: vector dimensionality, forwarded to ``compare``.
        width: first header field of the loaded ``.nsg`` file (presumably the
            maximum out-degree -- confirm).
        ep_: entry vertex; its neighbor list seeds every search.
        nd_: number of vertices, i.e. the valid id range ``[0, nd_)``.
    """

    def __init__(self, dimension: int, n: int):
        self.final_graph: list[list[int]] = []
        self.dimension = dimension
        self.width = 0
        self.ep_ = 0
        self.nd_ = n

    def Load(self, filename: str):
        """Read an ``.nsg`` graph file and append its adjacency lists.

        Layout (little-endian uint32 throughout): ``width``, ``ep_``, then,
        for each vertex in id order, a count ``k`` followed by ``k`` neighbor ids.

        Raises:
            ValueError: if the file is truncated inside the header or a record.
        """
        with open(filename, "rb") as f:
            header = f.read(8)
            if len(header) != 8:
                raise ValueError(f"{filename}: truncated NSG header")
            self.width = int.from_bytes(header[:4], byteorder='little', signed=False)
            self.ep_ = int.from_bytes(header[4:], byteorder='little', signed=False)
            while True:
                k_bytes = f.read(4)
                if not k_bytes:
                    break  # clean end of file
                if len(k_bytes) != 4:
                    raise ValueError(f"{filename}: truncated neighbor count")
                k = int.from_bytes(k_bytes, byteorder='little', signed=False)
                payload = f.read(k * 4)
                if len(payload) != k * 4:
                    raise ValueError(f"{filename}: truncated neighbor list")
                # Plain Python ints (rather than numpy uint32 scalars) are
                # cheaper to use as list indices and set members.
                self.final_graph.append(np.frombuffer(payload, dtype='<u4').tolist())

    def _check_pool_size(self, L):
        """Raise ValueError unless ``1 <= L <= nd_``.

        With ``L > nd_`` the random top-up in ``_initial_ids`` could never find
        L distinct vertices and would loop forever.
        """
        if not 1 <= L <= self.nd_:
            raise ValueError(f"L_search must be in [1, {self.nd_}], got {L}")

    def _initial_ids(self, L, visited):
        """Choose the L seed vertices of a search and add them to ``visited``.

        The entry point's neighbors come first (at most L of them); the rest
        are uniformly random vertices not yet in ``visited``.
        """
        init_ids = list(self.final_graph[self.ep_][:L])
        visited.update(init_ids)
        while len(init_ids) < L:
            candidate = random.randint(0, self.nd_ - 1)
            if candidate in visited:
                continue
            visited.add(candidate)
            init_ids.append(candidate)
        return init_ids

    def search_with_base_graph(self, query, x, K, parameters):
        """Greedy pool search for the K closest vertices to ``query``.

        A pool sorted by distance is seeded with L vertices. The first
        unexpanded entry among the top L is expanded: each unvisited neighbor
        is scored and inserted if it beats the current L-th best. The scan
        then restarts from the lowest position that changed. It stops when
        none of the top L entries is left to expand.

        Args:
            query: query vector.
            x: database vectors; ``x[i]`` is vertex ``i``'s vector.
            K: number of ids to return (expected ``K <= L``).
            parameters: dict whose ``'L_search'`` entry is the pool size L.

        Returns:
            ``(indices, node_counter)``: ids of the K closest pool entries
            (closest first) and the number of distance computations performed.

        Raises:
            ValueError: if L is not in ``[1, nd_]``.
        """
        L = parameters['L_search']
        self._check_pool_size(L)
        data_ = x
        visited = set()  # replaces a per-query O(nd_) flags list
        init_ids = self._initial_ids(L, visited)

        retset: list[Neighbor] = [
            Neighbor(v, compare(data_[v], query, self.dimension), 1) for v in init_ids
        ]
        node_counter = len(init_ids)
        retset.sort()

        k = 0
        while k < L:
            nk = L  # lowest pool position that received a new entry this round
            if retset[k].flag:
                retset[k].flag = 0
                n = retset[k].id
                for v in self.final_graph[n]:
                    if v in visited:
                        continue
                    visited.add(v)
                    dist = compare(query, data_[v], self.dimension)
                    node_counter += 1
                    if dist >= retset[L - 1].distance:
                        continue  # cannot enter the top L
                    nn = Neighbor(v, dist, 1)
                    # Same position insort_left would use, without a second O(L) scan to find it.
                    r = bisect.bisect_left(retset, nn)
                    retset.insert(r, nn)
                    nk = min(nk, r)
            # Go back to the lowest position that changed; otherwise move on.
            k = nk if nk <= k else k + 1

        indices = [retset[i].id for i in range(K)]
        return (indices, node_counter)

    def search_with_base_graph_2queue(self, query, x, K, parameters):
        """Best-first search with a candidate queue and a result queue.

        ``candidate_set`` holds vertices still to expand, closest first.
        ``top_candidates`` holds every scored vertex that passed the distance
        filter, sorted by distance. Neither list is truncated; only
        ``top_candidates[L - 1]`` (the current L-th best distance) is used as
        the cut-off, which is equivalent to a result queue capped at L. The
        search stops when the closest unexpanded candidate is farther than
        that cut-off. Exactly L seed vertices are used.

        Args, return value and exceptions are the same as for
        ``search_with_base_graph``.
        """
        L = parameters['L_search']
        self._check_pool_size(L)
        data_ = x
        visited = set()
        init_ids = self._initial_ids(L, visited)

        candidate_set: list[Neighbor] = []
        top_candidates: list[Neighbor] = []
        for v in init_ids:
            dist = compare(data_[v], query, self.dimension)
            candidate_set.append(Neighbor(v, dist, 1))
            top_candidates.append(Neighbor(v, dist, 1))
        node_counter = len(init_ids)

        candidate_set.sort()
        top_candidates.sort()

        while candidate_set:
            cur_node = candidate_set.pop(0)
            if cur_node.distance > top_candidates[L - 1].distance:
                break  # no remaining candidate can improve the top L

            for v in self.final_graph[cur_node.id]:
                if v in visited:
                    continue
                visited.add(v)
                dist = compare(query, data_[v], self.dimension)
                node_counter += 1
                # A vertex no closer than the current L-th best can never
                # enter the top L, so it is neither kept nor expanded.
                if dist >= top_candidates[L - 1].distance:
                    continue
                nn = Neighbor(v, dist, 1)
                bisect.insort_left(candidate_set, nn)
                bisect.insort_left(top_candidates, nn)

        indices = [top_candidates[i].id for i in range(K)]
        return (indices, node_counter)

### Load .fvecs data

In [4]:
def load_data(filename):
    """Load an ``.fvecs`` file into a float32 array of shape ``(num, dim)``.

    Each record is an int32 dimension ``d`` followed by ``d`` float32
    values. ``dim`` is taken from the first record. Trailing bytes that do
    not form a whole record are ignored.

    Returns:
        ``(data, num, dim)``: C-contiguous float32 array, number of vectors,
        and vector dimension.

    Raises:
        ValueError: if any record's dimension header differs from the first.
    """
    with open(filename, "rb") as file:
        dim = int.from_bytes(file.read(4), byteorder='little')
        file.seek(0, 2)
        num = file.tell() // ((dim + 1) * 4)
        if num == 0:
            return np.empty((0, dim), dtype=np.float32), 0, dim
        file.seek(0)
        # One bulk read instead of a seek + read per vector.
        flat = np.fromfile(file, dtype=np.float32, count=num * (dim + 1))

    records = flat.reshape(num, dim + 1)
    # Column 0 of each record is the int32 dimension header, stored in place of a float.
    headers = records.view(np.int32)[:, 0]
    if np.any(headers != dim):
        raise ValueError(f"{filename}: inconsistent vector dimensions in fvecs file")
    data = records[:, 1:].copy()  # drop the header column; contiguous copy
    return data, num, dim

def ivecs_read(fname):
    """Load an ``.ivecs`` file as a 2-D int32 array, one row per record.

    Each record is an int32 count ``d`` followed by ``d`` int32 values; the
    count column is dropped. Every record is assumed to have the same ``d``
    as the first one. For a ground-truth file this gives one row of
    nearest-neighbor ids per query (e.g. 10000 rows of 100 ids for SIFT).
    """
    flat = np.fromfile(fname, dtype=np.int32)
    row_len = flat[0] + 1
    table = flat.reshape(-1, row_len)
    return table[:, 1:].copy()

# Change to the path of your data
DATA_DIR = "../../nsg_eva/sift"
data_load, points_num, dim = load_data(f"{DATA_DIR}/sift_base.fvecs")
query_load, query_num, query_dim = load_data(f"{DATA_DIR}/sift_query.fvecs")
# Explicit checks rather than `assert`, which is skipped under `python -O`.
if dim != query_dim:
    raise ValueError(f"base vectors have dim {dim} but queries have dim {query_dim}")

gt_load = ivecs_read(f"{DATA_DIR}/sift_groundtruth.ivecs")
# The evaluation indexes gt_load[i] per query, so one row per query is required.
if gt_load.shape[0] != query_num:
    raise ValueError(
        f"ground truth has {gt_load.shape[0]} rows but there are {query_num} queries"
    )

print("data_load: ", np.shape(data_load))
print("query_load: ", np.shape(query_load))
print("gt_load: ", np.shape(gt_load))

data_load:  (1000000, 128)
query_load:  (10000, 128)
gt_load:  (10000, 100)


In [7]:
# Create the index (vectors are supplied per search call) and load the prebuilt graph.
index = IndexNSG(dimension=dim, n=points_num)
index.Load("../../sift.nsg")

### Process queries

In [9]:
K = 100  # number of neighbors returned per query
L = 100  # search pool size (L_search); must be >= K
if L < K:
    raise ValueError(f"L_search ({L}) must be >= K ({K})")
qsize = min(100, query_num)  # number of queries to evaluate
paras = {'L_search': L}

total = 0
correct = 0

total_counter = 0

for i in range(qsize):
    indices, node_counter = index.search_with_base_graph(query_load[i], data_load, K, paras)
    total_counter += node_counter
    # Recall@K: compare against the K true nearest neighbors only, not the
    # whole ground-truth row (which may hold more than K ids).
    gt = gt_load[i][:K]
    g = set(gt.tolist())
    total += len(gt)
    correct += sum(1 for item in indices if item in g)

acc = correct / total  # recall@K averaged over the evaluated queries
avg_counter = total_counter / qsize  # mean distance computations per query
print("acc: ", acc, "avg_counter: ", avg_counter)

acc:  0.9721 avg_counter:  2448.02


search:
123065  154617  123741  154491  620808  
cc: 1	704709 	946422 	957180 	188228 	75672 	
	1 	1 	1 	1 	1 	
cc: 2	935185 	232764 	285450 	809320 	173180 	
	1 	1 	1 	1 	1 	
cc: 3	746931 	538785 	753423 	393275 	935185 	
	1 	1 	1 	1 	0 	
cc: 4	934876 	600499 	746931 	886630 	394507 	
	1 	1 	0 	1 	1 	
cc: 5	932085 	934876 	695756 	600499 	746931 	
	1 	0 	1 	1 	0 	
cc: 6	932085 	934876 	695756 	600499 	746931 	
	0 	0 	1 	1 	0 	
cc: 7	932085 	934876 	695756 	600499 	746931 	
	0 	0 	1 	1 	0 	
cc: 8	932085 	934876 	695756 	600499 	746931 	
	0 	0 	0 	1 	0 	
cc: 9	932085 	934876 	695756 	600499 	562167 	
	0 	0 	0 	0 	1 	
cc: 10	932085 	934876 	695756 	562594 	600499 	
	0 	0 	0 	1 	0 	
cc: 11	932085 	934876 	695756 	562594 	600499 	
	0 	0 	0 	0 	0 	
cc: 12	932085 	934876 	695756 	562594 	600499 	
	0 	0 	0 	0 	0 

In [8]:
def compare(*args):
    """Distance function used by the search routines in this notebook.

    Defining this cell re-binds the module-level name ``compare``. The
    3-argument form is therefore kept working, alongside the 4-argument form
    called by ``search_with_opt_graph``:

    * ``compare(a, b, dimension)`` -> squared L2 distance ``sum((a - b) ** 2)``;
      ``dimension`` is ignored.
    * ``compare(a, b, norm, dimension)`` -> ``norm - 2 * <a, b>``. When
      ``norm`` is the squared norm of the database vector, this equals the
      squared L2 distance minus ``||query||^2``. It therefore ranks
      candidates identically for a fixed query (assumed intent of the former
      TODO -- confirm).
    * ``compare()`` -> 0.0, the previous placeholder's result.

    Raises:
        TypeError: for any other number of arguments.
    """
    if not args:
        return 0.0
    if len(args) == 3:
        query_data, db_vec, _dimension = args
        return np.sum((query_data - db_vec) ** 2)
    if len(args) == 4:
        a, b, norm, _dimension = args
        return norm - 2 * np.dot(a, b)
    raise TypeError(f"compare() takes 0, 3 or 4 positional arguments ({len(args)} given)")



# define but not used
# NOTE(review): unfinished translation of a C++ search routine. It reads
# globals that nothing visible here defines (nd_, ep_, opt_graph, node_size,
# data_len, dimension_) and calls the 4-argument form of compare(), so it
# cannot run as-is.
def search_with_opt_graph(query, K, parameters, indices):
    """Greedy pool search over a flattened ("optimized") graph layout.

    Writes the ids of the K best pool entries into ``indices[:K]`` in place
    and returns None.

    Assumed ``opt_graph`` layout, inferred from the slicing below (confirm):
    node ``v`` starts at offset ``node_size * v`` with its norm followed by
    its vector. At offset ``node_size * v + data_len`` come its neighbor
    count and then its neighbor ids.
    """
    L = parameters['L_search']

    retset: list[Neighbor] = []
    init_ids = []
    flags = [0] * nd_

    neighbors = opt_graph[node_size * ep_ + data_len:]
    MaxM_ep = neighbors[0]
    neighbors = neighbors[1:]

    for tmp_l in range(min(L, MaxM_ep)):
        init_ids.append(neighbors[tmp_l])
        flags[init_ids[tmp_l]] = 1
    # Count what was actually taken. The loop variable ends one short of
    # that count and is unbound when the entry point has no neighbors.
    tmp_l = len(init_ids)

    while tmp_l < L:
        nid = random.randint(0, nd_ - 1)
        if flags[nid] == 1:
            continue
        flags[nid] = 1
        init_ids.append(nid)
        tmp_l += 1

    # From here on, L is the number of valid entries actually placed in the pool.
    L = 0
    for nid in init_ids:
        if nid >= nd_:
            continue
        x = opt_graph[node_size * nid:]
        norm_x = x[0]
        x = x[1:]
        dist = compare(x, query, norm_x, dimension_)
        retset.append(Neighbor(nid, dist, 1))
        flags[nid] = 1
        L += 1

    retset.sort()
    k = 0
    while k < L:
        nk = L  # lowest pool position that received a new entry this round

        if retset[k].flag:
            retset[k].flag = 0
            n = retset[k].id
            neighbors = opt_graph[node_size * n + data_len:]
            MaxM = neighbors[0]
            neighbors = neighbors[1:]
            for m in range(MaxM):
                nid = neighbors[m]
                if flags[nid] == 1:
                    continue
                # Mark before scoring so a vertex reachable from several
                # expanded nodes is scored and inserted only once.
                flags[nid] = 1
                data = opt_graph[node_size * nid:]
                norm = data[0]
                data = data[1:]
                dist = compare(query, data, norm, dimension_)
                if dist >= retset[L - 1].distance:
                    continue
                nn = Neighbor(nid, dist, 1)
                # Sorted insert. This replaces the undefined
                # InsertIntoPool(retset.data(), ...) call; Python lists have
                # no .data(). Entries pushed past position L are dropped,
                # since they could never be expanded.
                r = bisect.bisect_left(retset, nn)
                retset.insert(r, nn)
                del retset[L:]
                if r < nk:
                    nk = r

        if nk <= k:
            k = nk
        else:
            k += 1

    for i in range(K):
        indices[i] = retset[i].id
    
