In [35]:
import struct
from typing import List, Optional
from typing import List, Optional

class Partition:
    """Container for an encoded partition bitvector."""

    def __init__(self, partition_bitvector: int = 0):
        """
        Represents a Partition with a bitvector.

        :param partition_bitvector: Encoded partition bitvector.
        """
        # Store the value so the instance actually carries the bitvector it represents.
        self.partition_bitvector = partition_bitvector
        
def count_trailing_zeros_64(x: int) -> int:
    """Return the number of trailing zero bits in the low 64 bits of ``x``.

    ``x`` is first masked to its least-significant 64 bits (two's-complement
    view for negative inputs); if those bits are all zero the result is 64.
    """
    low_bits = x & 0xFFFFFFFFFFFFFFFF
    if not low_bits:
        return 64
    # Isolate the lowest set bit; its position equals the trailing-zero count.
    lowest_set_bit = low_bits & -low_bits
    return lowest_set_bit.bit_length() - 1

class FlatCutIndex:
    """Flattened hierarchical cut labels of a single vertex.

    The distances of every cut level are concatenated in ``distances``;
    ``dist_index[cl]`` is the exclusive end offset of level ``cl`` inside that
    list (level 0 starts at offset 0). ``partition_bitvector`` keeps the cut
    level in its low 6 bits and the partition bits above them.
    """

    def __init__(self, partition_bitvector: int = 0, dist_index: Optional[List[int]] = None, distances: Optional[List[int]] = None):
        """
        Create a FlatCutIndex.

        :param partition_bitvector: Encoded partition bitvector (partition bits above the low 6 bits, cut level in the low 6 bits).
        :param dist_index: Cumulative end offsets of each cut level in ``distances``;
            stored as given (not copied). Defaults to a new empty list.
        :param distances: Concatenated distance labels; stored as given. Defaults to a new empty list.
        """
        self.partition_bitvector = partition_bitvector
        self.dist_index = [] if dist_index is None else dist_index
        self.distances = [] if distances is None else distances

    def _level_start(self, cl: int) -> int:
        """Offset in ``distances`` where cut level ``cl`` begins (0 when ``cl <= 0``)."""
        if cl > 0:
            return self.dist_index[cl - 1]
        return 0

    def partition(self) -> int:
        """
        Return the partition bits, i.e. the bitvector with its low 6 bits shifted out.

        :return: Partition value.
        """
        return self.partition_bitvector >> 6

    def cut_level(self) -> int:
        """
        Return the cut level stored in the low 6 bits of the bitvector.

        :return: Cut level value.
        """
        low_six_bits_mask = 0b111111
        return self.partition_bitvector & low_six_bits_mask

    def size(self) -> int:
        """
        Estimate the encoded size in bytes: 8 for the bitvector, 2 per
        ``dist_index`` entry and 4 per distance.

        :return: Size in bytes.
        """
        partition_bytes = 8
        index_bytes = 2 * len(self.dist_index)
        distance_bytes = 4 * len(self.distances)
        return partition_bytes + index_bytes + distance_bytes

    def label_count(self) -> int:
        """
        Return how many distance labels are stored across all levels.

        :return: Label count.
        """
        return len(self.distances)

    def cut_size(self, cl: int) -> int:
        """
        Return how many labels belong to cut level ``cl``.

        :param cl: Cut level.
        :return: Number of labels at the given cut level.
        """
        start = 0 if cl == 0 else self.dist_index[cl - 1]
        return self.dist_index[cl] - start

    def bottom_cut_size(self) -> int:
        """
        Return how many labels belong to this vertex's own (deepest) cut level.

        :return: Number of labels at the lowest cut level.
        """
        deepest_level = self.cut_level()
        return self.cut_size(deepest_level)

    def empty(self) -> bool:
        """
        Tell whether the index holds no labels, no level offsets and a zero bitvector.

        :return: True if empty, False otherwise.
        """
        if len(self.distances) or len(self.dist_index):
            return False
        return self.partition_bitvector == 0

    def cl_begin(self, cl: int) -> List[int]:
        """
        Return the distances from the start of cut level ``cl`` to the end of the list.

        :param cl: Cut level.
        :return: List of distances starting at the given cut level.
        """
        return self.distances[self._level_start(cl):]

    def cl_end(self, cl: int) -> List[int]:
        """
        Return the distances from the beginning of the list up to the end of cut level ``cl``.

        :param cl: Cut level.
        :return: List of distances ending at the given cut level.
        """
        end = self.dist_index[cl]
        return self.distances[:end]

    def unflatten(self) -> List[List[int]]:
        """
        Split the flat distance list into one list per cut level (0..cut_level()).

        :return: List of lists of distances.
        """
        return [
            self.distances[self._level_start(level):self.dist_index[level]]
            for level in range(self.cut_level() + 1)
        ]

    def __repr__(self):
        """
        Debug representation listing all three stored fields.

        :return: String representation.
        """
        return f"FlatCutIndex(partition_bitvector={self.partition_bitvector}, dist_index={self.dist_index}, distances={self.distances})"

    def is_same(self, other: 'FlatCutIndex') -> bool:
        """
        Compare bitvector, level offsets and distances with another index.

        :param other: Another FlatCutIndex instance.
        :return: True if all three fields are equal, False otherwise.
        """
        same_bits = self.partition_bitvector == other.partition_bitvector
        return (same_bits
                and self.dist_index == other.dist_index
                and self.distances == other.distances)

    def get_lca(self, other: 'FlatCutIndex') -> int:
        """
        Return the cut level of the lowest common ancestor of two indices.

        The result is the smaller of the two cut levels, lowered further to
        the position of the lowest bit in which the two partitions differ.

        :param other: Another FlatCutIndex instance.
        :return: LCA cut level.
        """
        lca_level = min(self.cut_level(), other.cut_level())
        differing_bits = self.partition() ^ other.partition()
        if differing_bits:
            lca_level = min(lca_level, count_trailing_zeros_64(differing_bits))
        return lca_level
    
class ContractionLabel:
    """Per-vertex label used by ContractionIndex.

    A label either owns its hierarchical cut labels (non-empty ``cut_index``,
    ``distance_offset`` 0, no parent) or hangs off ``parent`` in a tree rooted
    at such an owner, ``distance_offset`` away from it.
    """

    def __init__(self, cut_index: Optional["FlatCutIndex"] = None, distance_offset: int = 0, parent: Optional[int] = None):
        """
        Represents a contraction label.

        :param cut_index: Instance of FlatCutIndex or equivalent data structure;
            a new empty FlatCutIndex is created when omitted.
        :param distance_offset: Distance to the node owning the labels (default is 0).
        :param parent: Parent node in the tree rooted at the label-owning node (default is None).
        """
        # Build a fresh default per label so labels never share mutable lists.
        self.cut_index = cut_index if cut_index is not None else FlatCutIndex()
        self.distance_offset = distance_offset
        self.parent = parent

    def size(self) -> int:
        """
        Approximate the memory footprint of this label in bytes.

        Starts from the shallow Python object size (``__sizeof__``) and adds
        ``cut_index.size()`` only when ``distance_offset`` is 0, i.e. for
        labels that own their cut index.

        :return: Size of the contraction label.
        """
        total_size = self.__sizeof__()
        if self.distance_offset == 0 and self.cut_index is not None:
            total_size += self.cut_index.size()
        return total_size

    def __repr__(self):
        """
        String representation of the ContractionLabel.

        :return: String representation.
        """
        return f"ContractionLabel(cut_index={self.cut_index}, distance_offset={self.distance_offset}, parent={self.parent})"

class ContractionIndex:
    """Distance oracle over per-vertex contraction labels.

    ``labels[v]`` is the label of vertex ``v``; index 0 is skipped when
    building ``merge_map``. A label whose ``parent`` is neither ``None`` nor
    ``0`` belongs to a contraction tree; ``merge_map`` maps each such vertex
    to the root of its tree, i.e. the vertex whose cut labels it uses.
    """

    def __init__(self, labels: List["ContractionLabel"]):
        """
        Represents a contraction index.

        :param labels: List of ContractionLabel objects, indexed by vertex id.
        """
        self.labels = labels
        self.merge_map = self._build_merge_map(labels)

    @staticmethod
    def _has_parent(label) -> bool:
        """Return True when ``label`` hangs off a parent (parent is not None/0)."""
        return label.parent is not None and label.parent != 0

    @classmethod
    def _build_merge_map(cls, labels) -> dict:
        """Map every vertex id >= 1 that has a parent to the root of its tree.

        Roots that are already known are reused, so each parent chain is
        walked at most once instead of once per descendant.
        """
        merge_map = {}
        for i in range(1, len(labels)):
            if i in merge_map or not cls._has_parent(labels[i]):
                continue
            path = []
            node = i
            while True:
                path.append(node)
                parent = labels[node].parent
                if parent in merge_map:
                    root = merge_map[parent]
                    break
                if not cls._has_parent(labels[parent]):
                    root = parent
                    break
                node = parent
            for visited in path:
                merge_map[visited] = root
        return merge_map

    def _owner(self, v: int) -> int:
        """Return the vertex whose cut labels ``v`` uses (``v`` itself if it has no parent)."""
        return self.merge_map.get(v, v)

    def get_distance(self, v: int, w: int) -> int:
        """
        Computes the distance between two nodes using the contraction index.

        Vertices in the same contraction tree are answered by walking the
        tree; otherwise both tree offsets are added to the hierarchical
        distance between the two roots' cut indices.

        :param v: Node ID of the first node.
        :param w: Node ID of the second node.
        :return: Distance between the two nodes (``float('inf')`` if the
            roots share no labels at their LCA level).
        :raises ValueError: If a tree root has an empty cut index.
        """
        cv = self.labels[v]
        cw = self.labels[w]
        if v == w:
            return 0

        v_root = self._owner(v)
        w_root = self._owner(w)
        if v_root == w_root:
            return self._tree_distance(v, w, cv, cw)

        v_cut = self.labels[v_root].cut_index
        w_cut = self.labels[w_root].cut_index
        for root, cut in ((v_root, v_cut), (w_root, w_cut)):
            if cut.empty():
                raise ValueError(f"vertex {root} has no cut labels and no parent")
        return cv.distance_offset + cw.distance_offset + self.get_hierarchical_distance(v_cut, w_cut)

    def _tree_distance(self, v: int, w: int, cv, cw) -> int:
        """Distance between two distinct vertices of the same contraction tree."""
        # A zero offset marks the tree root itself.
        if cv.distance_offset == 0:
            return cw.distance_offset
        if cw.distance_offset == 0:
            return cv.distance_offset
        if cv.parent == w:
            return cv.distance_offset - cw.distance_offset
        if cw.parent == v:
            return cw.distance_offset - cv.distance_offset

        # Find the lowest common ancestor: always step up from the vertex that
        # is farther from the root (larger offset); step both when tied.
        v_anc, w_anc = v, w
        cv_anc, cw_anc = cv, cw
        while v_anc != w_anc:
            if cv_anc.distance_offset < cw_anc.distance_offset:
                w_anc = cw_anc.parent
                cw_anc = self.labels[w_anc]
            elif cv_anc.distance_offset > cw_anc.distance_offset:
                v_anc = cv_anc.parent
                cv_anc = self.labels[v_anc]
            else:
                v_anc = cv_anc.parent
                w_anc = cw_anc.parent
                cv_anc = self.labels[v_anc]
                cw_anc = self.labels[w_anc]
        return cv.distance_offset + cw.distance_offset - 2 * cv_anc.distance_offset

    def get_hierarchical_distance(self, a: "FlatCutIndex", b: "FlatCutIndex") -> int:
        """
        Computes the hierarchical distance between two FlatCutIndex objects.

        Takes the minimum of ``a_label + b_label`` over the labels of the two
        indices at their LCA cut level (paired by position, up to the shorter
        of the two levels).

        :param a: First FlatCutIndex.
        :param b: Second FlatCutIndex.
        :return: Hierarchical distance, or ``float('inf')`` if the level is empty.
        """
        cut_level = a.get_lca(b)
        a_begin = a.dist_index[cut_level - 1] if cut_level > 0 else 0
        b_begin = b.dist_index[cut_level - 1] if cut_level > 0 else 0
        a_level = a.distances[a_begin:a.dist_index[cut_level]]
        b_level = b.distances[b_begin:b.dist_index[cut_level]]
        # zip stops at the shorter level, matching the min-length bound.
        return min((x + y for x, y in zip(a_level, b_level)), default=float('inf'))

class HCL:
    """Minimal HCL label-file reader; currently it only echoes the file to stdout."""

    def __init__(self, filename):
        """
        Open ``filename`` and echo it immediately.

        :param filename: Path of the text file to read.
        """
        self.filename = filename
        self.data = {}  # reserved for parsed content; parse_file does not fill it yet
        self.parse_file()

    def parse_file(self):
        """Print every line of ``self.filename`` to stdout, one output line per input line."""
        with open(self.filename, 'r') as file:
            # Stream the file instead of loading it all with readlines().
            for line in file:
                # Lines keep their trailing newline; strip it so print() does
                # not insert an extra blank line after each one.
                print(line.rstrip("\n"))

In [36]:
# Example usage

import json
import random
import time
import torch
from torch.utils.data import Dataset, DataLoader
def experiment(hcl: "ContractionIndex", filename: str):
    """Time ``hcl.get_distance`` over a query file and print summary statistics.

    The first non-empty line of the file is the announced query count; every
    following non-empty line is a ``u v`` pair of integer vertex ids. Blank
    lines are ignored. Statistics are computed from the queries actually run.

    :param hcl: Object providing ``get_distance(u, v)``.
    :param filename: Path of the query file.
    """
    announced_queries = None
    executed_queries = 0
    total_time = 0.0
    with open(filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if announced_queries is None:
                announced_queries = int(line)
                continue
            u_text, v_text = line.split()
            u, v = int(u_text), int(v_text)
            start_time = time.perf_counter()
            hcl.get_distance(u, v)
            total_time += time.perf_counter() - start_time
            executed_queries += 1

    if announced_queries is not None and announced_queries != executed_queries:
        print(f"Warning: header announced {announced_queries} queries but {executed_queries} were run")
    print(f"Total queries: {executed_queries}")
    print(f"Total time: {total_time:.4f} seconds")
    if executed_queries:
        print(f"Average time per query: {total_time / executed_queries:.6f} seconds")
    else:
        # Avoid ZeroDivisionError for files with no query lines.
        print("Average time per query: n/a (no queries)")

def parse_hcl_file(filename: str) -> List["ContractionLabel"]:
    """Parse an HCL label file into a list of ContractionLabel objects.

    Index 0 of the result is a dummy label, so ``result[v]`` is the label of
    vertex ``v``. Lines starting with ``{`` or ``}`` advance the expected
    key counter. Every other non-empty line has the form ``key: value`` with
    an optional trailing comma, where ``key`` must equal that counter and
    ``value`` is either:

    * a JSON object with ``p`` (parent) and ``d`` (distance offset), giving a
      contracted label; or
    * ``<partition bitvector>, <JSON list of per-level distance lists>``,
      giving a label that owns a FlatCutIndex.

    :param filename: Path of the label file.
    :return: List of labels indexed by vertex id.
    :raises ValueError: If a key does not match its expected position.
    """
    import json

    label_list = [
        ContractionLabel(cut_index=FlatCutIndex(partition_bitvector=0, dist_index=[0], distances=[]), distance_offset=0, parent=None)
    ]
    with open(filename, 'r') as f:
        count = 0
        for raw_line in f:
            if raw_line.startswith(('{', '}')):
                count += 1
                continue
            line = raw_line.strip().rstrip(',')
            if not line:
                continue
            key_text, value = line.split(':', 1)
            key = int(key_text)
            if key != count:
                raise ValueError(f"unexpected key {key} in {filename!r}; expected {count}")
            try:
                parsed_value = json.loads(value)
            except json.JSONDecodeError:
                # Not a JSON object: "<bitvector>, [[level 0 distances], [level 1 distances], ...]".
                int_part, list_part = value.split(',', 1)
                partition_vector = int(int_part.strip())
                levels = json.loads(list_part)
                distances = []
                dist_index = []
                for level in levels:
                    distances.extend(int(entry) for entry in level)
                    # Cumulative end offset of this level inside `distances`.
                    dist_index.append(len(distances))
                cut_index = FlatCutIndex(partition_bitvector=partition_vector, dist_index=dist_index, distances=distances)
                label_list.append(ContractionLabel(cut_index=cut_index, distance_offset=0, parent=None))
            else:
                label_list.append(ContractionLabel(parent=parsed_value.get('p'), distance_offset=parsed_value.get('d')))
            count += 1

    return label_list
       
# Manual driver (disabled): load the NY road-network labels and run random queries.
# label_list = parse_hcl_file("/Users/libinzhou/Documents/HCL-python/USA-road-d.NY.gr-label.hl")
# for label in label_list:
#     print(label)
# hci = ContractionIndex(label_list)
# print("index loaded")
#
# for i in range(10000):
#     # Vertex ids start at 1; index 0 is the dummy label.
#     u = random.randint(1, len(label_list) - 1)
#     v = random.randint(1, len(label_list) - 1)
#     if i % 1000 == 0:
#         print(f"Computing distance for pair ({u}, {v})")
#         print(hci.get_distance(u, v))  # Example distance computation
#     else:
#         hci.get_distance(u, v)
#
# experiment(hci, "/Users/libinzhou/Documents/HCL-python/NY_queries.txt")

# Define a custom dataset class
class CustomDataset(Dataset):
    """Random (source, target) vertex-id pairs for benchmarking distance queries.

    Ids are drawn uniformly from ``1..vertex_count`` inclusive, where
    ``vertex_count = len(hci.labels) - 1`` because ``labels[0]`` is a dummy.
    """

    def __init__(self, hci: "ContractionIndex", num_samples=1000):
        """
        :param hci: Index whose ``labels`` list determines the vertex range.
        :param num_samples: Number of (source, target) pairs to draw.
        :raises ValueError: If the index has no vertices besides the dummy label.
        """
        vertex_count = len(hci.labels) - 1  # the first label is a dummy label; vertex ids are 1..vertex_count
        if vertex_count < 1:
            raise ValueError("index contains no vertices to sample from")
        # torch.randint excludes `high`, so +1 makes vertex_count itself reachable.
        self.source_nodes = torch.randint(1, vertex_count + 1, (num_samples,), dtype=torch.int32)  # Random source nodes
        self.target_nodes = torch.randint(1, vertex_count + 1, (num_samples,), dtype=torch.int32)  # Random target nodes

    def __len__(self):
        """Number of sampled pairs."""
        return len(self.source_nodes)

    def __getitem__(self, idx):
        """Return the ``idx``-th (source, target) pair as 0-d int32 tensors."""
        return self.source_nodes[idx], self.target_nodes[idx]
    
def experimental_evaluation_torch(num_queries: int = 1000, label_path: str = "/Users/libinzhou/Documents/HCL-python/USA-road-d.NY.gr-label.hl"):
    """Benchmark ContractionIndex.get_distance on random vertex pairs.

    :param num_queries: Number of random (source, target) pairs to time.
    :param label_path: Path of the HCL label file to load.
    """
    # Instantiate the contraction index from the requested file.
    label_list = parse_hcl_file(label_path)
    hci = ContractionIndex(label_list)
    print("index loaded")

    # Instantiate the random dataset
    dataset = CustomDataset(hci, num_queries)
    print(f"Dataset length: {len(dataset)}")

    ########
    # Experimental Evaluation
    ########
    total_time = 0.0
    total_queries = len(dataset)
    print("Iterating through the dataset one sample at a time:")
    for i in range(total_queries):
        source_node, target_node = dataset[i]
        # Convert tensors to ints outside the timed region so only the query is measured.
        source_id, target_id = source_node.item(), target_node.item()
        start_time = time.perf_counter()
        distance = hci.get_distance(source_id, target_id)
        total_time += time.perf_counter() - start_time
        if i % 100 == 0:
            print(f"Query {i+1}/{total_queries}: Distance between {source_id} and {target_id} is {distance}")

    if total_queries == 0:
        # Avoid ZeroDivisionError when no queries were requested.
        print("No queries were run.")
        return
    print(f"Average time per query: {total_time / total_queries:.6f} seconds")
    print(f"Total time for {total_queries} queries: {total_time:.4f} seconds")

# Run the benchmark only when executed as a script or notebook cell
# (__name__ is "__main__" in both), not when this module is imported.
if __name__ == "__main__":
    experimental_evaluation_torch(10000)

index loaded
Dataset length: 10000
Iterating through the dataset one sample at a time:
Query 1/10000: Distance between 128915 and 179686 is 663004
Query 101/10000: Distance between 49933 and 172900 is 674397
Query 201/10000: Distance between 149308 and 10050 is 538770
Query 301/10000: Distance between 237201 and 241183 is 281204
Query 401/10000: Distance between 102719 and 78341 is 677829
Query 501/10000: Distance between 76587 and 124973 is 290085
Query 601/10000: Distance between 36459 and 264295 is 652582
Query 701/10000: Distance between 188869 and 134312 is 117903
Query 801/10000: Distance between 171288 and 248377 is 508775
Query 901/10000: Distance between 232904 and 27300 is 317526
Query 1001/10000: Distance between 93250 and 27948 is 341582
Query 1101/10000: Distance between 13212 and 53843 is 254657
Query 1201/10000: Distance between 131819 and 33722 is 333062
Query 1301/10000: Distance between 10405 and 106006 is 323606
Query 1401/10000: Distance between 47673 and 134189 is 