In [5]:
import numpy as np

def L2_square_float(pa, pb, pd):
    """Return the squared Euclidean (L2) distance between two float vectors.

    Args:
        pa: First vector (any array-like accepted by ``np.array``).
        pb: Second vector, broadcast-compatible with ``pa``.
        pd: Dimension argument kept for signature parity with the other
            distance functions; it is not read here.

    Returns:
        NumPy scalar holding ``sum((pa - pb) ** 2)``.
    """
    lhs = np.array(pa)
    rhs = np.array(pb)
    delta = lhs - rhs
    squared = delta ** 2
    return np.sum(squared)

def L2_square_int(pa, pb, pd):
    """Return the squared Euclidean (L2) distance between two integer vectors.

    Integer inputs are widened to ``int64`` before subtracting. Without the
    widening, unsigned inputs wrap around on subtraction (``uint8`` 0 - 20
    becomes 236) and narrow dtypes overflow when squared, which silently
    produces wrong distances.

    Args:
        pa: First vector (any array-like accepted by ``np.asarray``).
        pb: Second vector, broadcast-compatible with ``pa``.
        pd: Dimension argument kept for signature parity with the other
            distance functions; it is not read here.

    Returns:
        NumPy scalar holding ``sum((pa - pb) ** 2)``.
    """
    lhs = np.asarray(pa)
    rhs = np.asarray(pb)
    # Only integer dtypes are widened; other dtypes keep their original
    # arithmetic so existing non-integer callers see no change.
    if np.issubdtype(lhs.dtype, np.integer):
        lhs = lhs.astype(np.int64)
    if np.issubdtype(rhs.dtype, np.integer):
        rhs = rhs.astype(np.int64)
    diff = lhs - rhs
    return np.sum(diff ** 2)

def inner_product_float(pa, pb, pd):
    """Return the inner (dot) product of two float vectors.

    Args:
        pa: First vector (any array-like accepted by ``np.array``).
        pb: Second vector.
        pd: Dimension argument kept for signature parity with the other
            distance functions; it is not read here.

    Returns:
        Result of ``np.dot`` applied to the two converted arrays.
    """
    lhs = np.array(pa)
    rhs = np.array(pb)
    return np.dot(lhs, rhs)

def inner_product_int(pa, pb, pd):
    """Return the inner (dot) product of two integer vectors.

    Integer inputs are widened to ``int64`` first. ``np.dot`` otherwise
    accumulates in the input dtype, so e.g. ``uint8`` 200 * 200 wraps to 64
    instead of 40000.

    Args:
        pa: First vector (any array-like accepted by ``np.asarray``).
        pb: Second vector.
        pd: Dimension argument kept for signature parity with the other
            distance functions; it is not read here.

    Returns:
        Result of ``np.dot`` applied to the two (widened) arrays.
    """
    lhs = np.asarray(pa)
    rhs = np.asarray(pb)
    # Only integer dtypes are widened; other dtypes keep their original
    # arithmetic so existing non-integer callers see no change.
    if np.issubdtype(lhs.dtype, np.integer):
        lhs = lhs.astype(np.int64)
    if np.issubdtype(rhs.dtype, np.integer):
        rhs = rhs.astype(np.int64)
    return np.dot(lhs, rhs)

In [3]:
import numpy as np

class MetricType:
    """Integer tags naming the supported distance metrics.

    NOTE(review): ``Vamana.__init__`` in this file selects the metric by
    comparing against the strings "L2" / "IP", not against these constants -
    confirm which convention callers are expected to use.
    """
    # Presumably squared Euclidean distance (cf. L2SpaceF) - TODO confirm.
    L2 = 100
    # Presumably inner product (cf. IPSapceF) - TODO confirm.
    IP = 200

class DataType:
    """Integer tags naming element types of stored vectors.

    NOTE(review): nothing in the visible code reads these constants; the
    ``dt`` argument of ``Vamana.__init__`` is currently ignored and float32
    data is assumed there.
    """
    # Presumably 32-bit signed integers - TODO confirm.
    INT32 = 100
    # Presumably 32-bit floats - TODO confirm.
    FLOAT = 200
    # Presumably 8-bit unsigned integers - TODO confirm.
    UINT8 = 300

class MetricSpace:
    """Base class for distance metrics over fixed-dimension vectors.

    Subclasses implement :meth:`full_dist`. :meth:`half_dist` and
    :meth:`post_half` are derived from it and evaluate the metric over the
    first and the second half of the coordinates respectively.

    Attributes:
        dis_func_: Callable used by subclasses to evaluate the metric
            (``None`` in the base class).
        esize: Size in bytes of one vector element (``None`` in the base
            class). It is informational only; slicing below is done in
            elements, not bytes.
    """

    def __init__(self, dim):
        """Initialise the metric.

        Args:
            dim: Vector dimension. Accepted for interface compatibility; the
                base class does not store it.
        """
        self.dis_func_ = None
        self.esize = None

    def full_dist(self, pa, pb, pd):
        """Return the distance between ``pa`` and ``pb``.

        Args:
            pa: First vector.
            pb: Second vector.
            pd: Indexable whose element 0 is the number of coordinates.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def half_dist(self, pa, pb, pd):
        """Return the distance over the first ``pd[0] // 2`` coordinates.

        The inputs are sliced here because the concrete ``full_dist``
        implementations in this file ignore ``pd`` and operate on whatever
        arrays they receive.
        """
        half = int(pd[0]) >> 1
        head_a = np.asarray(pa)[:half]
        head_b = np.asarray(pb)[:half]
        return self.full_dist(head_a, head_b, np.array([half]))

    def post_half(self, pa, pb, pd):
        """Return the distance over coordinates ``pd[0] // 2 .. pd[0] - 1``.

        The split point is an element index. The vectors are numeric arrays,
        so no byte-size scaling is applied to the offset.
        """
        dim = int(pd[0])
        half = dim >> 1
        tail_a = np.asarray(pa)[half:dim]
        tail_b = np.asarray(pb)[half:dim]
        return self.full_dist(tail_a, tail_b, np.array([dim - half]))

class L2SpaceF(MetricSpace):
    """Squared-Euclidean metric space over float32 vectors."""

    @staticmethod
    def L2_square_f(pa, pb, pd):
        """Return ``sum((pa - pb) ** 2)`` as a plain Python scalar.

        ``pd`` is accepted for signature compatibility and is not read.
        """
        lhs = np.array(pa)
        rhs = np.array(pb)
        delta = lhs - rhs
        return np.sum(delta ** 2).item()

    def __init__(self, dim):
        """Bind the distance kernel and record the float32 element size."""
        super().__init__(dim)
        self.esize = np.dtype('float32').itemsize
        self.dis_func_ = self.L2_square_f

    def full_dist(self, pa, pb, pd):
        """Evaluate the bound distance kernel on ``pa`` and ``pb``."""
        kernel = self.dis_func_
        return kernel(pa, pb, pd)

class IPSapceF(MetricSpace):
    """Inner-product metric space over float32 vectors.

    NOTE(review): the value returned is the raw dot product, so a larger
    value means "more similar" - confirm that callers which minimise
    distances account for this.
    """

    @staticmethod
    def inner_product_f(pa, pb, pd):
        """Return ``dot(pa, pb)`` as a plain Python scalar.

        ``pd`` is accepted for signature compatibility and is not read.
        """
        lhs = np.array(pa)
        rhs = np.array(pb)
        return np.dot(lhs, rhs).item()

    def __init__(self, dim):
        """Bind the distance kernel and record the float32 element size."""
        super().__init__(dim)
        self.esize = np.dtype('float32').itemsize
        self.dis_func_ = self.inner_product_f

    def full_dist(self, pa, pb, pd):
        """Evaluate the bound distance kernel on ``pa`` and ``pb``."""
        kernel = self.dis_func_
        return kernel(pa, pb, pd)

In [4]:
import ctypes
import math
import random
import numpy as np
import heapq
import threading
import time
import concurrent.futures

class Vamana:
    """Vamana-style proximity-graph index for nearest-neighbour search.

    Every point owns one fixed-size row inside the ``graph_`` byte buffer::

        [ degree (int32) | R_ neighbour ids (int32) | dim_ coords (float32) ]

    The buffer is exposed through two NumPy views so that rows can be read
    and written without pointer arithmetic. Distances are delegated to the
    metric object in ``ms_`` and are treated as "smaller is closer".

    Search results are lists of ``(distance, point_id)`` tuples sorted by
    ascending distance.

    Randomness (initial links and the first start node) comes from the
    global ``random`` module, so seeding it makes a build reproducible.
    """

    def __init__(self, mt, dt, r, l, alp, dim):
        """Configure an empty index.

        Args:
            mt: Metric selector, either ``"L2"`` or ``"IP"``.
            dt: Element data type. Currently unused: float32 is assumed.
            r: Maximum out-degree of a node (``R_``).
            l: Beam width used while building and searching (``L_``).
            alp: Pruning slack ``alpha`` used in the second build pass.
            dim: Number of coordinates per vector.

        Raises:
            ValueError: If ``r``, ``l`` or ``dim`` is not positive, or if
                ``mt`` names an unsupported metric.
        """
        # Set before any validation so that __del__ -> drop_index is safe
        # even when construction fails.
        self.graph_ = None
        self._links = None
        self._vecs = None
        self.index_built_ = False

        if r < 1 or l < 1 or dim < 1:
            raise ValueError("r, l and dim must all be positive")

        self.R_ = r
        self.L_ = l
        self.alpha_ = alp
        self.dim_ = dim
        # Kept for attribute compatibility; the build is single-threaded and
        # does not need per-node locks.
        self.link_list_locks_ = []

        if mt == "L2":
            # todo: init by DataType, specified float temporarily
            self.ms_ = L2SpaceF(dim)
        elif mt == "IP":
            self.ms_ = IPSapceF(dim)
        else:
            raise ValueError(f"unsupported metric type: {mt!r}")

        # Metric spaces in this file read the dimension as pd[0].
        self._pd = np.array([dim])

        # todo: default deal with float data
        self.data_size_ = dim * 4  # assuming 4 bytes per float
        self.link_size_ = self.R_ * 4 + 4  # assuming 4 bytes per idx_t
        self.node_size_ = self.link_size_ + self.data_size_
        self.sp_ = 0
        self.ntotal_ = 0

    def __del__(self):
        self.drop_index()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_float32(data):
        """Return ``data`` as a flat float32 array.

        Arrays, lists and tuples are converted by value; any other object is
        interpreted as a raw buffer of float32 values.
        """
        if isinstance(data, np.ndarray):
            return np.ascontiguousarray(data, dtype=np.float32).reshape(-1)
        if isinstance(data, (list, tuple)):
            return np.asarray(data, dtype=np.float32).reshape(-1)
        return np.frombuffer(data, dtype=np.float32)

    def _as_matrix(self, pdata, n):
        """Return the first ``n`` vectors of ``pdata`` as an ``(n, dim_)`` array."""
        flat = self._as_float32(pdata)
        needed = n * self.dim_
        if flat.size < needed:
            raise ValueError(
                f"expected at least {needed} float32 values, got {flat.size}")
        return flat[:needed].reshape(n, self.dim_)

    def _bind_views(self):
        """Create NumPy views over ``graph_`` for link rows and vector rows."""
        row_words = self.R_ + 1 + self.dim_
        as_int = np.frombuffer(self.graph_, dtype=np.int32)
        as_float = np.frombuffer(self.graph_, dtype=np.float32)
        self._links = as_int.reshape(self.ntotal_, row_words)[:, :self.R_ + 1]
        self._vecs = as_float.reshape(self.ntotal_, row_words)[:, self.R_ + 1:]

    def _dist(self, a, b):
        """Distance between two vectors under the configured metric."""
        return self.ms_.full_dist(a, b, self._pd)

    def _require_storage(self):
        """Raise if no graph storage is currently allocated."""
        if self._links is None:
            raise RuntimeError("index is empty; call create_index first")

    def _beam_search(self, query, start, beam):
        """Best-first graph search keeping the ``beam`` closest nodes seen.

        Returns:
            List of ``(distance, id)`` tuples, ascending by distance. The
            start node is always a candidate.
        """
        d_start = self._dist(query, self._vecs[start])
        frontier = [(d_start, start)]   # min-heap: closest unexpanded first
        best = [(-d_start, start)]      # max-heap by negation: best[0] is worst kept
        seen = {start}

        while frontier:
            d_cur, cur = heapq.heappop(frontier)
            # Nothing left in the frontier can improve a full result set.
            if len(best) >= beam and d_cur > -best[0][0]:
                break

            link = self._links[cur]
            for nb in link[1:int(link[0]) + 1].tolist():
                if nb in seen:
                    continue
                seen.add(nb)
                d_nb = self._dist(query, self._vecs[nb])
                if len(best) < beam or d_nb < -best[0][0]:
                    heapq.heappush(frontier, (d_nb, nb))
                    heapq.heappush(best, (-d_nb, nb))
                    if len(best) > beam:
                        heapq.heappop(best)

        return sorted((-neg, node) for neg, node in best)

    def _refresh_start_point(self, center, start):
        """Set ``sp_`` to the node closest to ``center`` reachable from ``start``."""
        self.sp_ = self._beam_search(center, start, self.L_)[0][1]

    def _refine_pass(self, vectors, order, alpha, flag):
        """Run one build pass: search, prune and back-link every node in ``order``."""
        for i in order:
            candidates = self._beam_search(vectors[i], self.sp_, self.L_)
            self.robust_prune(i, candidates, alpha, flag)
            self.make_edge(i, alpha)

    # ------------------------------------------------------------------
    # Index lifecycle
    # ------------------------------------------------------------------
    def drop_index(self):
        """Release the graph storage and mark the index as not built."""
        # The views keep the buffer alive, so they must be dropped as well.
        self._links = None
        self._vecs = None
        self.graph_ = None
        self.index_built_ = False

    def create_index(self, pdata, n):
        """Allocate storage for ``n`` points and build the graph.

        Args:
            pdata: ``n * dim_`` float32 values: an array, a (nested) list, or
                a raw buffer.
            n: Number of points.

        Raises:
            ValueError: If ``n`` is not positive or ``pdata`` is too short.
        """
        if self.index_built_:
            print("FBI Warning: index already built, if want re-build, call DropIndex first")
            return
        if n <= 0:
            raise ValueError("n must be positive")

        vectors = self._as_matrix(pdata, n)

        self.ntotal_ = n
        # create_string_buffer already returns zero-filled memory.
        self.graph_ = ctypes.create_string_buffer(self.node_size_ * n)
        self._bind_views()

        if self.R_ < math.ceil(math.log2(self.ntotal_)):
            print("FBI Warning: the parameter is less than log2(n), maybe result in low recall")

        self.add_points(vectors, self.ntotal_)
        self.random_init()
        self.healthy_check()
        self.build_index(vectors)
        self.index_built_ = True

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def search(self, qp, sp, topk=None):
        """Return the ``topk`` nearest stored points to ``qp``.

        Two call forms are supported:

        * ``search(qp, topk)`` starts from the index's start point ``sp_``.
        * ``search(qp, sp, topk)`` starts from node ``sp``.

        Args:
            qp: Query vector with at least ``dim_`` float32 values.
            sp: Start node id, or ``topk`` in the two-argument form.
            topk: Number of results wanted.

        Returns:
            Up to ``topk`` ``(distance, id)`` tuples, ascending by distance.

        Raises:
            RuntimeError: If no index storage exists.
            IndexError: If the start node is out of range.
            ValueError: If the query has fewer than ``dim_`` values.
        """
        if topk is None:
            sp, topk = self.sp_, sp
        self._require_storage()

        topk = int(topk)
        if topk <= 0:
            return []
        sp = int(sp)
        if not 0 <= sp < self.ntotal_:
            raise IndexError(f"start point {sp} out of range [0, {self.ntotal_})")

        query = self._as_float32(qp)
        if query.size < self.dim_:
            raise ValueError(f"query has {query.size} values, expected {self.dim_}")

        beam = max(self.L_, topk)
        return self._beam_search(query[:self.dim_], sp, beam)[:topk]

    def search_multiple_queries(self, pqueries, topk):
        """Search every query in ``pqueries``; returns one result list per query."""
        return self.search_parallel(pqueries, topk)

    def search_parallel(self, pqueries, topk):
        """Search a batch of queries on a thread pool.

        Searching only reads the graph, so concurrent queries do not need
        locking as long as the index is not being modified at the same time.

        Returns:
            List of result lists, in the same order as ``pqueries``.
        """
        queries = list(pqueries)
        if not queries:
            return []
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return list(pool.map(lambda q: self.search(q, topk), queries))

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------
    def healthy_check(self):
        """Print the out-degree histogram of the graph."""
        degree_hist = self.scan_graph()
        print("show degree histogram of graph:")
        for i, cnt in enumerate(degree_hist):
            print(f"degree = {i}: cnt = {cnt}")

    def scan_graph(self, degree_histogram=None):
        """Validate all link lists and count nodes per out-degree.

        Args:
            degree_histogram: Optional list that is overwritten in place
                with the histogram (out-parameter form).

        Returns:
            List of length ``R_ + 1`` where entry ``d`` is the number of
            nodes with out-degree ``d``.
        """
        self._require_storage()
        histogram = [0] * (self.R_ + 1)
        duplicate_total = 0

        for i in range(self.ntotal_):
            link = self.get_link_by_id(i)
            degree = int(link[0])
            if degree > self.R_:
                print(f"scan_graph: *(link) = {degree} which is > R_ = {self.R_}")
            assert degree <= self.R_
            histogram[degree] += 1

            neighbors = link[1:degree + 1].tolist()
            assert all(0 <= v < self.ntotal_ for v in neighbors)
            duplicate_total += degree - len(set(neighbors))

        print(f"scan_graph done, duplicate total = {duplicate_total}")

        if degree_histogram is not None:
            degree_histogram[:] = histogram
            return degree_histogram
        return histogram

    # ------------------------------------------------------------------
    # Raw storage access
    # ------------------------------------------------------------------
    def get_link_by_id(self, idx):
        """Return the writable int32 link row of node ``idx``.

        Element 0 is the out-degree; elements ``1..degree`` are neighbour ids.
        """
        return self._links[idx]

    def get_data_by_id(self, idx):
        """Return the writable float32 coordinate row of node ``idx``."""
        return self._vecs[idx]

    def add_points(self, pdata, n):
        """Copy the first ``n`` vectors of ``pdata`` into the graph storage."""
        self._require_storage()
        self._vecs[:n] = self._as_matrix(pdata, n)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def random_init(self):
        """Give every node ``min(R_, ntotal_ - 1)`` distinct random neighbours.

        A node is never linked to itself. Sampling is capped at the number
        of other nodes so small data sets cannot loop forever.
        """
        self._require_storage()
        n = self.ntotal_
        degree = min(self.R_, n - 1)

        for i in range(n):
            # Draw from n - 1 slots, then shift ids >= i by one to skip i.
            picks = random.sample(range(n - 1), degree)
            link = self.get_link_by_id(i)
            link[0] = degree
            for slot, v in enumerate(picks, start=1):
                link[slot] = v if v < i else v + 1

    def build_index(self, pdata):
        """Refine the randomly initialised graph in two passes.

        Pass 1 uses ``alpha = 1``; pass 2 visits nodes in reverse order with
        ``alpha = alpha_``. The start point ``sp_`` is re-centred on the data
        centroid before, between and after the passes.

        The passes run sequentially because every step mutates link lists
        that later searches read.

        Args:
            pdata: The same vectors that were stored with ``add_points``.

        Raises:
            RuntimeError: If no storage has been allocated.
        """
        self._require_storage()
        vectors = self._as_matrix(pdata, self.ntotal_)

        # Step 1: Calculate start point, i.e., navigate point in NSG
        center = vectors.mean(axis=0, dtype=np.float64).astype(np.float32)

        tstart = time.time()
        self._refresh_start_point(center, random.randrange(self.ntotal_))
        tend = time.time()
        print(f"first search 4 sp_ finished in {tend - tstart:.3f} seconds.")
        print(f"init sp_ = {self.sp_}")

        # Step 2: Do the first iteration with alpha = 1
        tstart = time.time()
        self._refine_pass(vectors, range(self.ntotal_), 1.0, 1)
        tend = time.time()
        print(f"the first round iteration finished in {tend - tstart:.3f} seconds.")

        self._refresh_start_point(center, self.sp_)
        print(f"updated sp_ after 1st iteration: {self.sp_}")

        print("HealthyCheck after the 1st round iteration:")
        self.healthy_check()

        # Step 3: Do the second iteration with alpha = alpha_
        tstart = time.time()
        self._refine_pass(vectors, range(self.ntotal_ - 1, -1, -1), self.alpha_, 2)
        tend = time.time()
        print(f"the second round iteration finished in {tend - tstart:.3f} seconds.")

        # Step 4: Update sp_
        self._refresh_start_point(center, self.sp_)
        print(f"updated sp_ after 2nd iteration: {self.sp_}")

        print("HealthyCheck after the 2nd round iteration:")
        self.healthy_check()

    def robust_prune(self, p, cand_set, alpha, flag):
        """Replace the links of ``p`` with a pruned subset of ``cand_set``.

        Candidates are visited closest-first. A candidate ``c`` is dropped
        when an already selected neighbour ``s`` satisfies
        ``alpha * dist(c, s) < dist(p, c)``. If at most ``R_`` distinct
        candidates remain after removing ``p`` itself, all are kept.

        Args:
            p: Node whose link list is rewritten.
            cand_set: Iterable of ``(distance_to_p, id)`` pairs. It is not
                modified.
            alpha: Pruning slack factor.
            flag: Caller tag; unused.
        """
        p = int(p)

        # Keep one entry per id (smallest distance) and never link p to itself.
        nearest = {}
        for dist, cid in cand_set:
            cid = int(cid)
            if cid == p:
                continue
            dist = float(dist)
            if cid not in nearest or dist < nearest[cid]:
                nearest[cid] = dist
        ordered = sorted((d, c) for c, d in nearest.items())

        if len(ordered) <= self.R_:
            chosen = [c for _, c in ordered]
        else:
            chosen = []
            for d_pc, c in ordered:
                if len(chosen) >= self.R_:
                    break
                c_vec = self._vecs[c]
                if all(self._dist(c_vec, self._vecs[s]) * alpha >= d_pc
                       for s in chosen):
                    chosen.append(c)

        link = self.get_link_by_id(p)
        link[0] = len(chosen)
        if chosen:
            link[1:len(chosen) + 1] = chosen

    def is_duplicate(self, p, link):
        """Return True if ``p`` already appears in the link row ``link``."""
        degree = int(link[0])
        assert degree <= self.R_
        return bool(np.any(link[1:degree + 1] == p))

    def make_edge(self, p, alpha):
        """Add the reverse edge ``q -> p`` for every neighbour ``q`` of ``p``.

        When ``q`` is already at full degree, its links plus ``p`` are
        re-pruned with :meth:`robust_prune`.
        """
        p = int(p)
        link = self.get_link_by_id(p)
        # Snapshot: pruning below rewrites other rows while we iterate.
        for q in link[1:int(link[0]) + 1].tolist():
            if q == p:
                continue
            q_link = self.get_link_by_id(q)
            if self.is_duplicate(p, q_link):
                continue

            degree = int(q_link[0])
            if degree < self.R_:
                q_link[degree + 1] = p
                q_link[0] = degree + 1
            else:
                q_vec = self._vecs[q]
                pool = [(self._dist(q_vec, self._vecs[p]), p)]
                pool.extend(
                    (self._dist(q_vec, self._vecs[v]), v)
                    for v in q_link[1:degree + 1].tolist()
                )
                self.robust_prune(q, pool, alpha, 3)