In [None]:
"""
RMI (Recursive Model Index) Implementation in Python
Rust implementation at: https://github.com/learnedsystems/RMI
"""

import numpy as np
from typing import List, Tuple, Optional, Union, Iterator, Callable
from abc import ABC, abstractmethod
from dataclasses import dataclass
import time
import math
from enum import Enum
import bisect

# =============================================================================
# From: rmi_lib/src/models/mod.rs
# Model base classes and type definitions
# =============================================================================

class ModelDataType(Enum):
    INT = "uint64_t"
    INT128 = "uint128_t"
    FLOAT = "double"

class ModelInput:
    """From: rmi_lib/src/models/mod.rs"""
    def __init__(self, value: Union[int, float]):
        if isinstance(value, (int, np.integer)):
            self.value = int(value)
            self.is_float = False
        else:
            self.value = float(value)
            self.is_float = True
    
    def as_float(self) -> float:
        return float(self.value)
    
    def as_int(self) -> int:
        return int(self.value)

class RMITrainingData:
    """
    From: rmi_lib/src/models/mod.rs
    Training data container for RMI
    """
    def __init__(self, data: List[Tuple[Union[int, float], int]]):
        self.data = sorted(data, key=lambda x: x[0])
        self.scale = 1.0
    
    def __len__(self) -> int:
        return len(self.data)
    
    def set_scale(self, scale: float):
        self.scale = scale
    
    def get(self, idx: int) -> Tuple[Union[int, float], int]:
        key, offset = self.data[idx]
        if abs(self.scale - 1.0) > 1e-10:
            return (key, int(offset * self.scale))
        return (key, offset)
    
    def get_key(self, idx: int) -> Union[int, float]:
        return self.data[idx][0]
    
    def iter(self) -> Iterator[Tuple[Union[int, float], int]]:
        """Iterator with scale applied"""
        for key, offset in self.data:
            if abs(self.scale - 1.0) > 1e-10:
                yield (key, int(offset * self.scale))
            else:
                yield (key, offset)
    
    def iter_model_input(self) -> Iterator[Tuple[ModelInput, int]]:
        """Iterator yielding ModelInput objects"""
        for key, offset in self.iter():
            yield (ModelInput(key), offset)
    
    def iter_unique(self) -> Iterator[Tuple[Union[int, float], int]]:
        """Iterator that removes duplicate keys"""
        if len(self.data) == 0:
            return
        last_key = None
        for key, offset in self.iter():
            if last_key is None or key != last_key:
                yield (key, offset)
                last_key = key
    
    def lower_bound_by(self, cmp_func: Callable) -> int:
        """From: rmi_lib/src/models/mod.rs - binary search for lower bound"""
        size = len(self)
        if size == 0:
            return 0
        
        base = 0
        while size > 1:
            half = size // 2
            mid = base + half
            cmp_result = cmp_func(self.get(mid))
            if cmp_result < 0:  # Less
                base = mid
            size -= half
        
        cmp_result = cmp_func(self.get(base))
        base += (1 if cmp_result < 0 else 0)
        return base
    
    def soft_copy(self):
        """Create a shallow copy with same data"""
        new_data = RMITrainingData(self.data[:])
        new_data.scale = self.scale
        return new_data

class Model(ABC):
    """
    From: rmi_lib/src/models/mod.rs
    Base class for all RMI models
    """
    @abstractmethod
    def predict_to_float(self, inp: ModelInput) -> float:
        pass
    
    def predict_to_int(self, inp: ModelInput) -> int:
        return max(0, int(math.floor(self.predict_to_float(inp))))
    
    @abstractmethod
    def input_type(self) -> ModelDataType:
        pass
    
    @abstractmethod
    def output_type(self) -> ModelDataType:
        pass
    
    def needs_bounds_check(self) -> bool:
        return True
    
    def set_to_constant_model(self, constant: int) -> bool:
        return False

# =============================================================================
# From: rmi_lib/src/models/linear.rs
# Linear regression models
# =============================================================================

class LinearModel(Model):
    """From: rmi_lib/src/models/linear.rs"""
    def __init__(self, data: RMITrainingData):
        self.intercept, self.slope = self._slr(data)
    
    def _slr(self, data: RMITrainingData) -> Tuple[float, float]:
        """Simple linear regression using online algorithm"""
        mean_x = 0.0
        mean_y = 0.0
        c = 0.0
        n = 0
        m2 = 0.0
        data_size = 0
        
        for x, y in data.iter():
            n += 1
            dx = float(x) - mean_x
            mean_x += dx / n
            mean_y += (float(y) - mean_y) / n
            c += dx * (float(y) - mean_y)
            
            dx2 = float(x) - mean_x
            m2 += dx * dx2
            data_size += 1
        
        if data_size == 0:
            return (0.0, 0.0)
        if data_size == 1:
            return (mean_y, 0.0)
        
        cov = c / (n - 1)
        var = m2 / (n - 1)
        
        if var == 0.0:
            return (mean_y, 0.0)
        
        beta = cov / var
        alpha = mean_y - beta * mean_x
        
        return (alpha, beta)
    
    def predict_to_float(self, inp: ModelInput) -> float:
        return self.slope * inp.as_float() + self.intercept
    
    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def set_to_constant_model(self, constant: int) -> bool:
        self.intercept = float(constant)
        self.slope = 0.0
        return True

class RobustLinearModel(Model):
    """From: rmi_lib/src/models/linear.rs"""
    def __init__(self, data: RMITrainingData):
        total_items = len(data)
        if total_items == 0:
            self.intercept, self.slope = (0.0, 0.0)
            return
        
        # Skip 0.01% of data from each end for robustness
        bnd = max(1, int(total_items * 0.0001))
        
        # Need at least bnd*2+1 items
        if bnd * 2 + 1 >= total_items:
            # Not enough data, use regular linear
            self.intercept, self.slope = LinearModel(data)._slr(data)
            return
        
        # Create iterator skipping first and last bnd items
        subset = []
        for i, item in enumerate(data.iter()):
            if i < bnd:
                continue
            if i >= total_items - bnd:
                break
            subset.append(item)
        
        subset_data = RMITrainingData(subset)
        self.intercept, self.slope = LinearModel(subset_data)._slr(subset_data)
    
    def predict_to_float(self, inp: ModelInput) -> float:
        return self.slope * inp.as_float() + self.intercept
    
    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def set_to_constant_model(self, constant: int) -> bool:
        self.intercept = float(constant)
        self.slope = 0.0
        return True

# =============================================================================
# From: rmi_lib/src/models/linear_spline.rs
# Linear spline model
# =============================================================================

class LinearSplineModel(Model):
    """From: rmi_lib/src/models/linear_spline.rs"""
    def __init__(self, data: RMITrainingData):
        self.intercept, self.slope = self._linear_splines(data)
    
    def _linear_splines(self, data: RMITrainingData) -> Tuple[float, float]:
        if len(data) == 0:
            return (0.0, 0.0)
        if len(data) == 1:
            return (float(data.get(0)[1]), 0.0)
        
        first_pt = data.get(0)
        last_pt = data.get(len(data) - 1)
        
        if first_pt[0] == last_pt[0]:
            return (float(data.get(0)[1]), 0.0)
        
        slope = (float(first_pt[1]) - float(last_pt[1])) / (float(first_pt[0]) - float(last_pt[0]))
        intercept = float(first_pt[1]) - slope * float(first_pt[0])
        
        return (intercept, slope)
    
    def predict_to_float(self, inp: ModelInput) -> float:
        return self.slope * inp.as_float() + self.intercept
    
    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def set_to_constant_model(self, constant: int) -> bool:
        self.intercept = float(constant)
        self.slope = 0.0
        return True

# =============================================================================
# From: rmi_lib/src/models/cubic_spline.rs
# Cubic spline model
# =============================================================================

class CubicSplineModel(Model):
    """From: rmi_lib/src/models/cubic_spline.rs"""
    def __init__(self, data: RMITrainingData):
        self.a, self.b, self.c, self.d = self._cubic(data)
        
        # Check against linear model - sometimes cubic doesn't work well
        linear = LinearSplineModel(data)
        
        our_error = 0.0
        lin_error = 0.0
        
        for x, y in data.iter_model_input():
            c_pred = self.predict_to_float(x)
            l_pred = linear.predict_to_float(x)
            
            our_error += abs(c_pred - float(y))
            lin_error += abs(l_pred - float(y))
        
        if lin_error < our_error:
            # Use linear instead
            self.a = 0.0
            self.b = 0.0
            self.c = linear.slope
            self.d = linear.intercept
    
    def _cubic(self, data: RMITrainingData) -> Tuple[float, float, float, float]:
        if len(data) == 0:
            return (0.0, 0.0, 1.0, 0.0)
        if len(data) == 1:
            return (0.0, 0.0, 0.0, float(data.get(0)[1]))
        
        # Check for unique values
        candidate = data.get(0)[0]
        uniq = any(x != candidate for x, _ in data.iter())
        if not uniq:
            return (0.0, 0.0, 0.0, float(data.get(0)[1]))
        
        first_pt = data.get(0)
        last_pt = data.get(len(data) - 1)
        xmin, ymin = float(first_pt[0]), float(first_pt[1])
        xmax, ymax = float(last_pt[0]), float(last_pt[1])
        
        x1, y1 = 0.0, 0.0
        x2, y2 = 1.0, 1.0
        
        # Find first point with scaled x > 0
        m1 = None
        for xn, yn in data.iter():
            sxn = (float(xn) - xmin) / (xmax - xmin)
            if sxn > 0.0:
                syn = (float(yn) - ymin) / (ymax - ymin)
                m1 = (syn - y1) / (sxn - x1)
                break
        
        if m1 is None:
            m1 = 0.0
        
        # Find last point with scaled x < 1
        m2 = None
        for i in range(len(data) - 1, -1, -1):
            xp, yp = data.get(i)
            sxp = (float(xp) - xmin) / (xmax - xmin)
            if sxp < 1.0:
                syp = (float(yp) - ymin) / (ymax - ymin)
                m2 = (y2 - syp) / (x2 - sxp)
                break
        
        if m2 is None:
            m2 = 0.0
        
        # Keep monotonic
        if m1**2 + m2**2 > 9.0:
            tau = 3.0 / math.sqrt(m1**2 + m2**2)
            m1 *= tau
            m2 *= tau
        
        # Compute coefficients
        a = (m1 + m2 - 2.0) / (xmax - xmin)**3
        b = -(xmax * (2.0*m1 + m2 - 3.0) + xmin * (m1 + 2.0*m2 - 3.0)) / (xmax - xmin)**3
        c = (m1*xmax**2 + m2*xmin**2 + xmax*xmin*(2.0*m1 + 2.0*m2 - 6.0)) / (xmax - xmin)**3
        d = -xmin * (m1*xmax**2 + xmax*xmin*(m2 - 3.0) + xmin**2) / (xmax - xmin)**3
        
        a *= ymax - ymin
        b *= ymax - ymin
        c *= ymax - ymin
        d *= ymax - ymin
        d += ymin
        
        return (a, b, c, d)
    
    def predict_to_float(self, inp: ModelInput) -> float:
        val = inp.as_float()
        # Use FMA-like computation
        v1 = self.a * val + self.b
        v2 = v1 * val + self.c
        v3 = v2 * val + self.d
        return v3
    
    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT
    
    def set_to_constant_model(self, constant: int) -> bool:
        self.a = 0.0
        self.b = 0.0
        self.c = 0.0
        self.d = float(constant)
        return True

# =============================================================================
# From: rmi_lib/src/models/radix.rs
# Radix models
# =============================================================================

def num_bits(largest_target: int) -> int:
    """From: rmi_lib/src/models/utils.rs"""
    nbits = 0
    while (1 << (nbits + 1)) - 1 <= largest_target:
        nbits += 1
    assert nbits >= 1
    return nbits

def common_prefix_size(data: RMITrainingData) -> int:
    """From: rmi_lib/src/models/utils.rs"""
    any_ones = 0
    no_ones = (1 << 64) - 1
    
    for x, _ in data.iter_model_input():
        val = x.as_int()
        any_ones |= val
        no_ones &= val
    
    any_zeros = ~no_ones & ((1 << 64) - 1)
    prefix_bits = any_zeros ^ any_ones
    
    # Count leading zeros
    prefix_bits &= ((1 << 64) - 1)
    if prefix_bits == 0:
        return 64
    
    leading_zeros = 0
    test_bit = 1 << 63
    while test_bit > 0 and (~prefix_bits & test_bit):
        leading_zeros += 1
        test_bit >>= 1
    
    return leading_zeros

class RadixModel(Model):
    """From: rmi_lib/src/models/radix.rs"""
    def __init__(self, data: RMITrainingData):
        if len(data) == 0:
            self.left_shift = 0
            self.num_bits = 0
            return
        
        largest_value = max(y for _, y in data.iter())
        bits = num_bits(largest_value)
        common_prefix = common_prefix_size(data)
        
        self.left_shift = common_prefix
        self.num_bits = bits
    
    def predict_to_int(self, inp: ModelInput) -> int:
        as_int = inp.as_int()
        res = (as_int << self.left_shift) >> (64 - self.num_bits)
        return res & ((1 << 64) - 1)
    
    def predict_to_float(self, inp: ModelInput) -> float:
        return float(self.predict_to_int(inp))
    
    def input_type(self) -> ModelDataType:
        return ModelDataType.INT
    
    def output_type(self) -> ModelDataType:
        return ModelDataType.INT
    
    def needs_bounds_check(self) -> bool:
        return False

class RadixTable(Model):
    """From: rmi_lib/src/models/radix.rs"""
    def __init__(self, data: RMITrainingData, bits: int):
        self.prefix_bits = common_prefix_size(data)
        self.table_bits = bits
        self.hint_table = [0] * (1 << bits)
        
        last_radix = 0
        for inp, y in data.iter_model_input():
            x = inp.as_int()
            num_bits = 0 if self.prefix_bits + bits > 64 else 64 - (self.prefix_bits + bits)
            current_radix = ((x << self.prefix_bits) >> self.prefix_bits) >> num_bits
            
            if current_radix == last_radix:
                continue
            
            self.hint_table[int(current_radix)] = y
            
            for i in range(int(last_radix) + 1, int(current_radix)):
                self.hint_table[i] = y
            
            last_radix = current_radix
        
        for i in range(int(last_radix) + 1, len(self.hint_table)):
            self.hint_table[i] = len(self.hint_table)
    
    def predict_to_int(self, inp: ModelInput) -> int:
        as_int = inp.as_int()
        num_bits = 0 if self.prefix_bits + self.table_bits > 64 else 64 - (self.prefix_bits + self.table_bits)
        res = ((as_int << self.prefix_bits) >> self.prefix_bits) >> num_bits
        idx = self.hint_table[int(res)]
        return idx
    
    def predict_to_float(self, inp: ModelInput) -> float:
        return float(self.predict_to_int(inp))
    
    def input_type(self) -> ModelDataType:
        return ModelDataType.INT
    
    def output_type(self) -> ModelDataType:
        return ModelDataType.INT
    
    def needs_bounds_check(self) -> bool:
        return False

# =============================================================================
# From: rmi_lib/src/train/lower_bound_correction.rs
# Lower bound correction for empty models
# =============================================================================

class LowerBoundCorrection:
    """From: rmi_lib/src/train/lower_bound_correction.rs"""
    def __init__(self, pred_func: Callable, num_leaf_models: int, data: RMITrainingData):
        self.first = [None] * num_leaf_models
        self.last = [None] * num_leaf_models
        self.next = [(0, 0)] * num_leaf_models
        self.prev = [(0, 0)] * num_leaf_models
        self.run_lengths = [0] * num_leaf_models
        
        last_target = 0
        current_run_length = 0
        current_run_key = data.get_key(0)
        
        for x, y in data.iter():
            leaf_idx = pred_func(x)
            target = min(num_leaf_models - 1, leaf_idx)
            
            if target == last_target and x == current_run_key:
                current_run_length += 1
            elif target != last_target or x != current_run_key:
                self.run_lengths[last_target] = max(
                    self.run_lengths[last_target], current_run_length
                )
                current_run_length = 1
                current_run_key = x
                last_target = target
            
            if self.first[target] is None:
                self.first[target] = (y, x)
            self.last[target] = (y, x)
        
        # Compute next_for_leaf
        idx = 0
        while idx < num_leaf_models:
            next_found = None
            for i in range(idx + 1, num_leaf_models):
                if self.first[i] is not None:
                    next_found = (i, self.first[i])
                    break
            
            if next_found:
                next_leaf_idx, val = next_found
                for i in range(idx, next_leaf_idx):
                    self.next[i] = val
                idx = next_leaf_idx
            else:
                for i in range(idx, num_leaf_models):
                    self.next[i] = (len(data), data.get_key(len(data) - 1) if len(data) > 0 else 0)
                break
        
        # Compute prev_for_leaf
        idx = num_leaf_models - 1
        while idx > 0:
            prev_found = None
            for i in range(idx - 1, -1, -1):
                if self.last[i] is not None:
                    prev_found = (i, self.last[i])
                    break
            
            if prev_found:
                prev_leaf_idx, val = prev_found
                for i in range(prev_leaf_idx + 1, idx + 1):
                    self.prev[i] = val
                idx = prev_leaf_idx
            else:
                break
    
    def first_key(self, leaf_idx: int) -> Optional[Union[int, float]]:
        return self.first[leaf_idx][1] if self.first[leaf_idx] else None
    
    def last_key(self, leaf_idx: int) -> Optional[Union[int, float]]:
        return self.last[leaf_idx][1] if self.last[leaf_idx] else None
    
    def next_index(self, leaf_idx: int) -> int:
        return self.next[leaf_idx][0]
    
    def prev_key(self, leaf_idx: int) -> Union[int, float]:
        return self.prev[leaf_idx][1]
    
    def longest_run(self, leaf_idx: int) -> int:
        return self.run_lengths[leaf_idx]

# =============================================================================
# From: rmi_lib/src/train/two_layer.rs
# Two-layer RMI training
# =============================================================================

def error_between(v1: int, v2: int, max_pred: int) -> int:
    """From: rmi_lib/src/train/two_layer.rs"""
    pred1 = min(v1, max_pred)
    pred2 = min(v2, max_pred)
    return max(pred1, pred2) - min(pred1, pred2)

def train_model(model_type: str, data: RMITrainingData) -> Model:
    """From: rmi_lib/src/train/mod.rs"""
    model_map = {
        'linear': LinearModel,
        'robust_linear': RobustLinearModel,
        'linear_spline': LinearSplineModel,
        'cubic': CubicSplineModel,
        'radix': RadixModel,
    }
    
    if model_type.startswith('radix') and model_type != 'radix':
        bits = int(model_type[5:])
        return RadixTable(data, bits)
    
    if model_type not in model_map:
        raise ValueError(f"Unknown model type: {model_type}")
    
    return model_map[model_type](data)

@dataclass
class TrainedRMI:
    """From: rmi_lib/src/train/mod.rs"""
    num_rmi_rows: int
    num_data_rows: int
    model_avg_error: float
    model_avg_l2_error: float
    model_avg_log2_error: float
    model_max_error: int
    model_max_error_idx: int
    model_max_log2_error: float
    last_layer_max_l1s: List[int]
    rmi: List[List[Model]]
    models: str
    branching_factor: int
    build_time: float
    
    def predict(self, key: Union[int, float]) -> Tuple[int, int]:
        """Predict position and error for a key"""
        inp = ModelInput(key)
        
        # First layer
        model_idx = self.rmi[0][0].predict_to_int(inp)
        
        # Second layer
        model_idx = min(len(self.rmi[1]) - 1, model_idx)
        prediction = self.rmi[1][model_idx].predict_to_int(inp)
        error = self.last_layer_max_l1s[model_idx]
        
        return (prediction, error)

def build_models_from(
    data: RMITrainingData,
    top_model: Model,
    model_type: str,
    start_idx: int,
    end_idx: int,
    first_model_idx: int,
    num_models: int
) -> List[Model]:
    """From: rmi_lib/src/train/two_layer.rs"""
    
    assert end_idx > start_idx
    assert end_idx <= len(data)
    assert start_idx <= len(data)
    
    leaf_models = []
    second_layer_data = []
    last_target = first_model_idx
    
    # Get bounded iterator
    for i, (x, y) in enumerate(data.iter()):
        if i < start_idx:
            continue
        if i >= end_idx:
            break
        
        model_pred = top_model.predict_to_int(ModelInput(x))
        target = min(first_model_idx + num_models - 1, model_pred)
        assert target >= last_target
        
        if target > last_target:
            # Train previous model
            last_item = second_layer_data[-1] if second_layer_data else None
            second_layer_data.append((x, y))
            
            container = RMITrainingData(second_layer_data)
            leaf_model = train_model(model_type, container)
            leaf_models.append(leaf_model)
            
            # Add empty models for skipped indices
            for _ in range(last_target + 1, target):
                empty_data = RMITrainingData([])
                leaf_models.append(train_model(model_type, empty_data))
            
            assert len(leaf_models) + first_model_idx == target
            
            second_layer_data = []
            if last_item:
                second_layer_data.append(last_item)
        
        second_layer_data.append((x, y))
        last_target = target
    
    # Train last model
    assert second_layer_data
    container = RMITrainingData(second_layer_data)
    leaf_model = train_model(model_type, container)
    leaf_models.append(leaf_model)
    
    # Add remaining empty models
    for _ in range(last_target + 1, first_model_idx + num_models):
        empty_data = RMITrainingData([])
        leaf_models.append(train_model(model_type, empty_data))
    
    assert len(leaf_models) == num_models
    return leaf_models

def train_two_layer(
    data: RMITrainingData,
    layer1_model: str,
    layer2_model: str,
    num_leaf_models: int
) -> TrainedRMI:
    """From: rmi_lib/src/train/two_layer.rs"""
    
    num_rows = len(data)
    
    print(f"Training top-level {layer1_model} model...")
    data.set_scale(num_leaf_models / num_rows)
    top_model = train_model(layer1_model, data)
    
    print(f"Training second-level {layer2_model} models (num models = {num_leaf_models})...")
    data.set_scale(1.0)
    
    # Find split point near middle
    midpoint_model = num_leaf_models // 2
    split_idx = data.lower_bound_by(
        lambda x: -1 if top_model.predict_to_int(ModelInput(x[0])) < midpoint_model 
                  else (1 if top_model.predict_to_int(ModelInput(x[0])) > midpoint_model else 0)
    )
    
    # Build leaf models
    if split_idx >= len(data):
        leaf_models = build_models_from(
            data, top_model, layer2_model, 0, len(data), 0, num_leaf_models
        )
    else:
        split_idx_target = min(
            num_leaf_models - 1,
            top_model.predict_to_int(ModelInput(data.get_key(split_idx)))
        )
        
        first_half_models = split_idx_target
        second_half_models = num_leaf_models - split_idx_target
        
        # Build first half
        hf1 = build_models_from(
            data, top_model, layer2_model, 0, split_idx, 0, first_half_models
        )
        
        # Build second half
        hf2 = build_models_from(
            data, top_model, layer2_model, split_idx + 1, len(data),
            split_idx_target, second_half_models
        )
        
        leaf_models = hf1 + hf2
    
    print("Computing lower bound corrections...")
    lb_corrections = LowerBoundCorrection(
        lambda x: top_model.predict_to_int(ModelInput(x)),
        num_leaf_models,
        data
    )
    
    print("Fixing empty models...")
    # Replace empty models with constants
    for idx in range(num_leaf_models - 1):
        if lb_corrections.first_key(idx) is None:
            upper_bound = lb_corrections.next_index(idx)
            leaf_models[idx].set_to_constant_model(upper_bound)
    
    print("Computing last level errors...")
    # Compute errors for each leaf model
    last_layer_max_l1s = [(0, 0)] * num_leaf_models
    
    for x, y in data.iter_model_input():
        leaf_idx = top_model.predict_to_int(x)
        target = min(num_leaf_models - 1, leaf_idx)
        
        pred = leaf_models[target].predict_to_int(x)
        err = error_between(pred, y, len(data))
        
        cur_count, cur_max = last_layer_max_l1s[target]
        last_layer_max_l1s[target] = (cur_count + 1, max(err, cur_max))
    
    # Adjust errors for lower bound correctness
    for leaf_idx in range(num_leaf_models):
        curr_err = last_layer_max_l1s[leaf_idx][1]
        
        # Upper error
        idx_of_next, key_of_next = lb_corrections.next[leaf_idx]
        key_minus_eps = key_of_next - (1 if isinstance(key_of_next, int) else 1e-10)
        pred = leaf_models[leaf_idx].predict_to_int(ModelInput(key_minus_eps))
        upper_error = error_between(pred, idx_of_next + 1, len(data))
        
        # Lower error
        first_key_before = lb_corrections.prev_key(leaf_idx)
        prev_idx = 0 if leaf_idx == 0 else leaf_idx - 1
        first_idx = lb_corrections.next_index(prev_idx)
        
        key_plus_eps = first_key_before + (1 if isinstance(first_key_before, int) else 1e-10)
        pred = leaf_models[leaf_idx].predict_to_int(ModelInput(key_plus_eps))
        lower_error = error_between(pred, first_idx, len(data))
        
        new_err = max(curr_err, upper_error, lower_error) + lb_corrections.longest_run(leaf_idx)
        
        num_items = last_layer_max_l1s[leaf_idx][0]
        last_layer_max_l1s[leaf_idx] = (num_items, new_err)
    
    print("Evaluating RMI...")
    # Compute statistics
    model_max_error_idx, (_, model_max_error) = max(
        enumerate(last_layer_max_l1s), key=lambda x: x[1][1]
    )
    
    total_items = sum(n for n, _ in last_layer_max_l1s)
    model_avg_error = sum(n * err for n, err in last_layer_max_l1s) / total_items
    model_avg_l2_error = sum((n * err) ** 2 for n, err in last_layer_max_l1s) / total_items
    model_avg_log2_error = sum(n * math.log2(2 * err + 2) for n, err in last_layer_max_l1s) / total_items
    model_max_log2_error = math.log2(model_max_error) if model_max_error > 0 else 0.0
    
    final_errors = [err for _, err in last_layer_max_l1s]
    
    return TrainedRMI(
        num_rmi_rows=len(data),
        num_data_rows=len(data),
        model_avg_error=model_avg_error,
        model_avg_l2_error=model_avg_l2_error,
        model_avg_log2_error=model_avg_log2_error,
        model_max_error=model_max_error,
        model_max_error_idx=model_max_error_idx,
        model_max_log2_error=model_max_log2_error,
        last_layer_max_l1s=final_errors,
        rmi=[[top_model], leaf_models],
        models=f"{layer1_model},{layer2_model}",
        branching_factor=num_leaf_models,
        build_time=0.0
    )

# =============================================================================
# From: rmi_lib/src/train/mod.rs
# Main training function
# =============================================================================

def train(data: RMITrainingData, model_spec: str, branch_factor: int) -> TrainedRMI:
    """From: rmi_lib/src/train/mod.rs"""
    
    start_time = time.time()
    
    model_list = model_spec.split(',')
    if len(model_list) != 2:
        raise ValueError("Only two-layer RMIs are currently supported")
    
    layer1_model, layer2_model = model_list
    
    result = train_two_layer(data, layer1_model, layer2_model, branch_factor)
    
    build_time = time.time() - start_time
    result.build_time = build_time
    
    print(f"\nRMI Training Complete!")
    print(f"Build time: {build_time:.3f}s")
    print(f"Average error: {result.model_avg_error:.2f}")
    print(f"Max error: {result.model_max_error}")
    print(f"Average log2 error: {result.model_avg_log2_error:.3f}")
    print(f"Max log2 error: {result.model_max_log2_error:.3f}")
    
    return result

# =============================================================================
# Example usage and testing
# =============================================================================

def load_binary_data(filename: str) -> RMITrainingData:
    """
    Load binary data file in the format expected by RMI:
    - First 8 bytes: number of items (uint64, little endian)
    - Remaining bytes: data items (uint64 or uint32, little endian)
    """
    import struct
    
    with open(filename, 'rb') as f:
        # Read number of items
        num_items_bytes = f.read(8)
        num_items = struct.unpack('<Q', num_items_bytes)[0]
        
        # Determine data type from filename
        if 'uint32' in filename:
            fmt = '<I'
            item_size = 4
        elif 'uint64' in filename:
            fmt = '<Q'
            item_size = 8
        else:
            raise ValueError("Filename must contain 'uint32' or 'uint64'")
        
        # Read all data
        data = []
        for i in range(num_items):
            item_bytes = f.read(item_size)
            if len(item_bytes) < item_size:
                break
            value = struct.unpack(fmt, item_bytes)[0]
            data.append((value, i))
    
    return RMITrainingData(data)

def create_synthetic_data(n: int, distribution: str = 'uniform') -> RMITrainingData:
    """Create synthetic sorted data for testing"""
    
    if distribution == 'uniform':
        keys = np.sort(np.random.randint(0, n * 10, size=n))
    elif distribution == 'normal':
        keys = np.sort(np.random.normal(n * 5, n, size=n).astype(int))
        keys = np.maximum(keys, 0)  # Ensure non-negative
    elif distribution == 'lognormal':
        keys = np.sort(np.random.lognormal(10, 2, size=n).astype(int))
    elif distribution == 'exponential':
        keys = np.sort(np.random.exponential(n / 10, size=n).astype(int))
    else:
        raise ValueError(f"Unknown distribution: {distribution}")
    
    # Create (key, position) pairs
    data = [(int(k), i) for i, k in enumerate(keys)]
    return RMITrainingData(data)

def benchmark_rmi(rmi: TrainedRMI, test_keys: List[int], actual_data: List[Tuple[int, int]]):
    """
    Benchmark RMI performance
    From: general benchmarking approach
    """
    print("\n=== Benchmarking RMI ===")
    
    # Create sorted array for binary search comparison
    sorted_keys = [k for k, _ in actual_data]
    
    # Test predictions
    total_search_range = 0
    correct_predictions = 0
    
    for key in test_keys:
        pred, err = rmi.predict(key)
        
        # Check if prediction is within error bound
        lower = max(0, pred - err)
        upper = min(len(actual_data) - 1, pred + err)
        total_search_range += (upper - lower + 1)
        
        # Find actual position
        actual_pos = bisect.bisect_left(sorted_keys, key)
        
        if lower <= actual_pos <= upper:
            correct_predictions += 1
    
    avg_search_range = total_search_range / len(test_keys)
    accuracy = 100 * correct_predictions / len(test_keys)
    
    print(f"Tested {len(test_keys)} keys")
    print(f"Average search range: {avg_search_range:.2f}")
    print(f"Prediction accuracy: {accuracy:.2f}%")
    print(f"Avg log2(search range): {math.log2(avg_search_range + 1):.3f}")
    
    return {
        'avg_search_range': avg_search_range,
        'accuracy': accuracy,
        'total_tested': len(test_keys)
    }

# =============================================================================
# Main demonstration
# =============================================================================

def main():
    """Demonstrate RMI training and usage"""
    
    print("=" * 70)
    print("RMI (Recursive Model Index) Implementation")
    print("=" * 70)
    
    # Create synthetic dataset
    print("\n1. Creating synthetic dataset...")
    n = 1_000_000
    data = create_synthetic_data(n, distribution='normal')
    print(f"Created {len(data)} data points")
    
    # Train RMI with different configurations
    configs = [
        ("linear,linear", 1000),
        ("robust_linear,linear", 1000),
        ("cubic,linear", 1000),
        ("linear_spline,linear", 1000),
    ]
    
    results = []
    
    for model_spec, branch_factor in configs:
        print(f"\n{'=' * 70}")
        print(f"2. Training RMI: {model_spec} with branching factor {branch_factor}")
        print(f"{'=' * 70}")
        
        rmi = train(data.soft_copy(), model_spec, branch_factor)
        
        # Test on random keys
        print("\n3. Testing predictions...")
        test_keys = [data.get_key(i) for i in np.random.randint(0, len(data), size=1000)]
        
        benchmark_results = benchmark_rmi(rmi, test_keys, data.data)
        
        results.append({
            'config': model_spec,
            'branching_factor': branch_factor,
            'build_time': rmi.build_time,
            'avg_error': rmi.model_avg_error,
            'max_error': rmi.model_max_error,
            'avg_log2_error': rmi.model_avg_log2_error,
            'avg_search_range': benchmark_results['avg_search_range']
        })
    
    # Print comparison table
    print("\n" + "=" * 70)
    print("RESULTS COMPARISON")
    print("=" * 70)
    print(f"{'Config':<25} {'Branch':<8} {'Build(s)':<10} {'AvgErr':<10} {'MaxErr':<10} {'AvgLog2':<10} {'SearchRng':<10}")
    print("-" * 70)
    
    for r in results:
        print(f"{r['config']:<25} {r['branching_factor']:<8} {r['build_time']:<10.3f} "
              f"{r['avg_error']:<10.2f} {r['max_error']:<10} {r['avg_log2_error']:<10.3f} "
              f"{r['avg_search_range']:<10.2f}")
    
    print("\n" + "=" * 70)
    print("Done!")
    print("=" * 70)

if __name__ == "__main__":
    main()

RMI (Recursive Model Index) Implementation
Translated from Rust implementation

1. Creating synthetic dataset...
Created 1000000 data points

2. Training RMI: linear,linear with branching factor 1000
Training top-level linear model...
Training second-level linear models (num models = 1000)...
Computing lower bound corrections...
Fixing empty models...
Computing last level errors...
Evaluating RMI...

RMI Training Complete!
Build time: 5.543s
Average error: 1841.99
Max error: 38888
Average log2 error: 6.352
Max log2 error: 15.247

3. Testing predictions...

=== Benchmarking RMI ===
Tested 1000 keys
Average search range: 3015.47
Prediction accuracy: 99.80%
Avg log2(search range): 11.559

2. Training RMI: robust_linear,linear with branching factor 1000
Training top-level robust_linear model...
Training second-level linear models (num models = 1000)...
Computing lower bound corrections...
Fixing empty models...
Computing last level errors...
Evaluating RMI...

RMI Training Complete!
Build 

In [15]:
"""
RMI Implementation with Real Dataset Support
Optimized for SOSD Benchmark and Real-World Data
"""

import numpy as np
import struct
import requests
import os
from pathlib import Path
from typing import List, Tuple, Optional, Union
import time
import json
from dataclasses import dataclass, asdict

# =============================================================================
# DATASET DOWNLOADING AND LOADING
# =============================================================================

class SOSDDatasetDownloader:
    """
    Download and manage SOSD benchmark datasets
    Source: https://github.com/learnedsystems/SOSD
    """
    
    DATASETS = {
        'books_200M_uint64': {
            'url': 'https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/JGVF9A/MZZUP2',
            'size': 1600000000,  # bytes
            'records': 200000000,
            'description': 'Amazon book popularity (sorted by popularity rank)'
        },
        'fb_200M_uint64': {
            'url': 'https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/JGVF9A/SVN8PI',
            'size': 1600000000,
            'records': 200000000,
            'description': 'Facebook user IDs'
        },
        'osm_cellids_200M_uint64': {
            'url': 'https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/JGVF9A/LMTZJA',
            'size': 1600000000,
            'records': 200000000,
            'description': 'OpenStreetMap cell IDs'
        },
        'wiki_ts_200M_uint64': {
            'url': 'https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/JGVF9A/HJPVBB',
            'size': 1600000000,
            'records': 200000000,
            'description': 'Wikipedia edit timestamps'
        }
    }
    
    def __init__(self, data_dir: str = './sosd_data'):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)
    
    def download_dataset(self, dataset_name: str, force: bool = False):
        """Download a specific dataset"""
        if dataset_name not in self.DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")
        
        dataset_info = self.DATASETS[dataset_name]
        filepath = self.data_dir / dataset_name
        
        if filepath.exists() and not force:
            print(f"Dataset {dataset_name} already exists. Use force=True to re-download.")
            return filepath
        
        print(f"Downloading {dataset_name}...")
        print(f"Description: {dataset_info['description']}")
        print(f"Size: {dataset_info['size'] / 1e9:.2f} GB")
        print(f"Records: {dataset_info['records']:,}")
        
        # Download with progress bar
        response = requests.get(dataset_info['url'], stream=True)
        total_size = int(response.headers.get('content-length', 0))
        
        with open(filepath, 'wb') as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        progress = (downloaded / total_size) * 100
                        print(f"\rProgress: {progress:.1f}%", end='', flush=True)
        
        print(f"\n✓ Downloaded {dataset_name} to {filepath}")
        return filepath
    
    def list_available(self):
        """List all available datasets"""
        print("\n=== Available SOSD Datasets ===")
        for name, info in self.DATASETS.items():
            exists = (self.data_dir / name).exists()
            status = "✓ Downloaded" if exists else "✗ Not downloaded"
            print(f"\n{name}")
            print(f"  Status: {status}")
            print(f"  Description: {info['description']}")
            print(f"  Records: {info['records']:,}")
            print(f"  Size: {info['size'] / 1e9:.2f} GB")


def load_binary_dataset(filepath: str, max_records: Optional[int] = None) -> List[Tuple[int, int]]:
    """
    Load binary dataset in SOSD format:
    - First 8 bytes: number of items (uint64, little endian)
    - Remaining bytes: data items (uint64 or uint32, little endian)
    
    Returns: List of (key, position) tuples
    """
    filepath = Path(filepath)
    
    if not filepath.exists():
        raise FileNotFoundError(f"Dataset not found: {filepath}")
    
    # Determine data type from filename
    if 'uint32' in filepath.name:
        dtype = np.uint32
        item_size = 4
    elif 'uint64' in filepath.name or 'f64' in filepath.name:
        dtype = np.uint64
        item_size = 8
    else:
        # Default to uint64
        dtype = np.uint64
        item_size = 8
    
    print(f"Loading dataset: {filepath.name}")
    
    with open(filepath, 'rb') as f:
        # Read number of items
        num_items_bytes = f.read(8)
        num_items = struct.unpack('<Q', num_items_bytes)[0]
        
        if max_records:
            num_items = min(num_items, max_records)
        
        print(f"Total records in file: {num_items:,}")
        
        # Read data items
        print("Reading data...")
        data = np.fromfile(f, dtype=dtype, count=num_items)
    
    print(f"Loaded {len(data):,} records")
    
    # Sort and create (key, position) pairs
    print("Sorting data...")
    sorted_indices = np.argsort(data)
    sorted_data = data[sorted_indices]
    
    # Create position mapping
    result = [(int(key), int(pos)) for pos, key in enumerate(sorted_data)]
    
    print(f"✓ Dataset ready: {len(result):,} records")
    return result


def load_csv_dataset(filepath: str, key_column: int = 0, max_records: Optional[int] = None) -> List[Tuple[int, int]]:
    """
    Load CSV dataset and convert to RMI format
    
    Args:
        filepath: Path to CSV file
        key_column: Which column to use as key (0-indexed)
        max_records: Maximum number of records to load
    """
    import pandas as pd
    
    print(f"Loading CSV: {filepath}")
    df = pd.read_csv(filepath, nrows=max_records)
    
    # Get key column
    keys = df.iloc[:, key_column].values
    
    # Sort and create positions
    sorted_indices = np.argsort(keys)
    sorted_keys = keys[sorted_indices]
    
    result = [(int(key), int(pos)) for pos, key in enumerate(sorted_keys)]
    print(f"✓ Loaded {len(result):,} records from CSV")
    return result


# =============================================================================
# OPTIMIZED RMI IMPLEMENTATION
# =============================================================================

from dataclasses import dataclass
from typing import List, Tuple
import numpy as np

class RMITrainingData:
    """Optimized training data with NumPy backend"""
    def __init__(self, data: List[Tuple[Union[int, float], int]]):
        # Convert to NumPy arrays for performance
        self.keys = np.array([k for k, _ in data])
        self.positions = np.array([p for _, p in data])
        self.scale = 1.0
        
        # Sort by keys
        sort_idx = np.argsort(self.keys)
        self.keys = self.keys[sort_idx]
        self.positions = self.positions[sort_idx]
    
    def __len__(self) -> int:
        return len(self.keys)
    
    def set_scale(self, scale: float):
        self.scale = scale
    
    def get(self, idx: int) -> Tuple[Union[int, float], int]:
        key = self.keys[idx]
        pos = self.positions[idx]
        if abs(self.scale - 1.0) > 1e-10:
            pos = int(pos * self.scale)
        return (key, pos)
    
    def get_key(self, idx: int):
        return self.keys[idx]
    
    def iter(self):
        """Efficient iterator using NumPy"""
        if abs(self.scale - 1.0) > 1e-10:
            scaled_pos = (self.positions * self.scale).astype(int)
            for k, p in zip(self.keys, scaled_pos):
                yield (k, p)
        else:
            for k, p in zip(self.keys, self.positions):
                yield (k, p)
    
    def soft_copy(self):
        """Create shallow copy"""
        new_data = RMITrainingData.__new__(RMITrainingData)
        new_data.keys = self.keys
        new_data.positions = self.positions
        new_data.scale = self.scale
        return new_data


class LinearModel:
    """Optimized linear regression using NumPy"""
    def __init__(self, data: RMITrainingData):
        if len(data) == 0:
            self.intercept, self.slope = 0.0, 0.0
            return
        
        if len(data) == 1:
            _, y = data.get(0)
            self.intercept, self.slope = float(y), 0.0
            return
        
        # Vectorized linear regression
        X = data.keys.astype(np.float64)
        y = np.array([p for _, p in data.iter()], dtype=np.float64)
        
        if len(np.unique(X)) == 1:
            self.intercept, self.slope = np.mean(y), 0.0
            return
        
        # Compute slope and intercept
        X_mean = np.mean(X)
        y_mean = np.mean(y)
        
        numerator = np.sum((X - X_mean) * (y - y_mean))
        denominator = np.sum((X - X_mean) ** 2)
        
        if denominator == 0:
            self.slope = 0.0
        else:
            self.slope = numerator / denominator
        
        self.intercept = y_mean - self.slope * X_mean
    
    def predict(self, key: Union[int, float]) -> int:
        return max(0, int(self.slope * float(key) + self.intercept))


# =============================================================================
# COMPREHENSIVE BENCHMARKING FRAMEWORK
# =============================================================================

@dataclass
class BenchmarkResult:
    dataset_name: str
    dataset_size: int
    model_config: str
    branch_factor: int
    
    # Build metrics
    build_time: float
    
    # Error metrics
    avg_error: float
    max_error: int
    avg_log2_error: float
    max_log2_error: float
    
    # Query performance
    avg_search_range: float
    prediction_accuracy: float
    
    # Memory
    model_size_bytes: int
    
    def to_dict(self):
        """Convert to dict with JSON-serializable types"""
        result = asdict(self)
        # Convert numpy types to native Python types
        for key, value in result.items():
            if hasattr(value, 'item'):  # NumPy scalar
                result[key] = value.item()
            elif isinstance(value, (np.integer, np.floating)):
                result[key] = value.item()
        return result
    
    def summary(self) -> str:
        return f"""=== Benchmark Result ===
Dataset: {self.dataset_name} ({self.dataset_size:,} records)
Model: {self.model_config} (branch={self.branch_factor})

Build Time: {self.build_time:.3f}s
Average Error: {self.avg_error:.2f}
Max Error: {self.max_error:,}
Avg Log2 Error: {self.avg_log2_error:.3f}

Average Search Range: {self.avg_search_range:.2f}
Prediction Accuracy: {self.prediction_accuracy:.1f}%
Log2(Search Range): {np.log2(self.avg_search_range + 1):.3f}
"""


class RMIBenchmark:
    """Comprehensive benchmarking framework"""
    
    def __init__(self, output_dir: str = './benchmark_results'):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        self.results = []
    
    def benchmark_dataset(
        self,
        dataset_path: str,
        model_configs: List[Tuple[str, int]],
        max_records: Optional[int] = None,
        num_test_queries: int = 1000
    ):
        """
        Benchmark multiple RMI configurations on a dataset
        
        Args:
            dataset_path: Path to dataset file
            model_configs: List of (model_spec, branch_factor) tuples
            max_records: Limit number of records (for testing)
            num_test_queries: Number of queries for testing
        """
        print(f"\n{'='*70}")
        print(f"Benchmarking: {Path(dataset_path).name}")
        print(f"{'='*70}")
        
        # Load dataset
        if dataset_path.endswith('.csv'):
            data_list = load_csv_dataset(dataset_path, max_records=max_records)
        else:
            data_list = load_binary_dataset(dataset_path, max_records=max_records)
        
        data = RMITrainingData(data_list)
        dataset_name = Path(dataset_path).stem
        
        # Generate test queries
        test_indices = np.random.randint(0, len(data), size=num_test_queries)
        test_keys = [data.get_key(i) for i in test_indices]
        
        # Benchmark each configuration
        for model_spec, branch_factor in model_configs:
            print(f"\n--- Testing: {model_spec} (branch={branch_factor}) ---")
            
            try:
                # Train RMI
                start_time = time.time()
                rmi = self._train_simple_rmi(data.soft_copy(), model_spec, branch_factor)
                build_time = time.time() - start_time
                
                # Test queries
                total_search_range = 0
                correct = 0
                
                for key in test_keys:
                    pred, err = rmi['predict'](key)
                    
                    lower = max(0, pred - err)
                    upper = min(len(data) - 1, pred + err)
                    search_range = upper - lower + 1
                    total_search_range += search_range
                    
                    # Check correctness using binary search
                    actual_pos = np.searchsorted(data.keys, key)
                    if lower <= actual_pos <= upper:
                        correct += 1
                
                avg_search_range = total_search_range / len(test_keys)
                accuracy = 100 * correct / len(test_keys)
                
                # Create result
                result = BenchmarkResult(
                    dataset_name=dataset_name,
                    dataset_size=len(data),
                    model_config=model_spec,
                    branch_factor=branch_factor,
                    build_time=build_time,
                    avg_error=rmi['avg_error'],
                    max_error=rmi['max_error'],
                    avg_log2_error=rmi['avg_log2_error'],
                    max_log2_error=rmi['max_log2_error'],
                    avg_search_range=avg_search_range,
                    prediction_accuracy=accuracy,
                    model_size_bytes=0  # TODO: calculate actual size
                )
                
                print(result.summary())
                self.results.append(result)
                
            except Exception as e:
                print(f"✗ Failed: {e}")
                continue
    
    def _train_simple_rmi(self, data: RMITrainingData, model_spec: str, branch_factor: int):
        """Simplified 2-layer RMI training"""
        layer1, layer2 = model_spec.split(',')
        
        # Train top model
        data.set_scale(branch_factor / len(data))
        top_model = LinearModel(data)
        
        # Train leaf models
        data.set_scale(1.0)
        leaf_models = []
        leaf_errors = []
        
        # Partition data
        partitions = [[] for _ in range(branch_factor)]
        for key, pos in data.iter():
            pred = top_model.predict(key)
            target = min(branch_factor - 1, pred)
            partitions[target].append((key, pos))
        
        # Train each leaf
        for partition in partitions:
            if len(partition) == 0:
                leaf_models.append(LinearModel(RMITrainingData([(0, 0)])))
                leaf_errors.append(0)
            else:
                leaf_data = RMITrainingData(partition)
                leaf_model = LinearModel(leaf_data)
                leaf_models.append(leaf_model)
                
                # Calculate error
                max_err = 0
                for key, pos in partition:
                    pred = leaf_model.predict(key)
                    err = abs(pred - pos)
                    max_err = max(max_err, err)
                leaf_errors.append(max_err)
        
        # Create prediction function
        def predict(key):
            pred = top_model.predict(key)
            target = min(branch_factor - 1, pred)
            leaf_pred = leaf_models[target].predict(key)
            return (leaf_pred, leaf_errors[target])
        
        return {
            'predict': predict,
            'avg_error': np.mean(leaf_errors),
            'max_error': max(leaf_errors),
            'avg_log2_error': np.mean([np.log2(2*e + 2) for e in leaf_errors]),
            'max_log2_error': np.log2(2*max(leaf_errors) + 2) if max(leaf_errors) > 0 else 0
        }
    
    def save_results(self, filename: str = 'benchmark_results.json'):
        """Save all results to JSON"""
        filepath = self.output_dir / filename
        with open(filepath, 'w') as f:
            json.dump([r.to_dict() for r in self.results], f, indent=2)
        print(f"\n✓ Results saved to {filepath}")
    
    def print_comparison(self):
        """Print comparison table"""
        if not self.results:
            print("No results to compare")
            return
        
        print("\n" + "="*120)
        print("BENCHMARK COMPARISON")
        print("="*120)
        print(f"{'Dataset':<25} {'Model':<20} {'Branch':<8} {'Build(s)':<10} {'AvgErr':<12} "
              f"{'SearchRng':<12} {'Accuracy':<10}")
        print("-"*120)
        
        for r in self.results:
            print(f"{r.dataset_name:<25} {r.model_config:<20} {r.branch_factor:<8} "
                  f"{r.build_time:<10.3f} {r.avg_error:<12.2f} {r.avg_search_range:<12.2f} "
                  f"{r.prediction_accuracy:<10.1f}%")


# =============================================================================
# EXAMPLE USAGE
# =============================================================================

def run_comprehensive_benchmark():
    """Run comprehensive benchmark on multiple datasets"""
    
    print("="*70)
    print("RMI Comprehensive Benchmark Suite")
    print("="*70)
    
    # Initialize benchmark
    benchmark = RMIBenchmark()
    
    # Model configurations to test
    configs = [
        ("linear,linear", 1000),
        ("linear,linear", 5000),
        ("linear,linear", 10000),
    ]
    
    # Test on synthetic data first (fast)
    print("\n1. Creating synthetic datasets...")
    
    # Uniform distribution
    uniform_data = [(i, i) for i in range(1000000)]
    uniform_path = Path('./test_data/uniform_1M.dat')
    uniform_path.parent.mkdir(exist_ok=True)
    
    # Save as binary
    with open(uniform_path, 'wb') as f:
        f.write(struct.pack('<Q', len(uniform_data)))
        for key, _ in uniform_data:
            f.write(struct.pack('<Q', key))
    
    benchmark.benchmark_dataset(
        str(uniform_path),
        configs,
        max_records=100000,  # Use subset for speed
        num_test_queries=1000
    )
    
    # Test on real SOSD data if available
    sosd_dir = Path('./sosd_data')
    if sosd_dir.exists():
        for dataset_file in sosd_dir.glob('*_uint64'):
            print(f"\nFound real dataset: {dataset_file.name}")
            benchmark.benchmark_dataset(
                str(dataset_file),
                configs,
                max_records=1000000,  # Use 1M records for faster testing
                num_test_queries=1000
            )
    else:
        print("\nNo SOSD datasets found. Run downloader to get real datasets.")
    
    # Print results
    benchmark.print_comparison()
    benchmark.save_results()
    
    print("\n" + "="*70)
    print("Benchmark Complete!")
    print("="*70)


if __name__ == "__main__":
    # Example 1: Download SOSD datasets
    print("Example 1: Dataset Downloader")
    print("-" * 50)
    
    downloader = SOSDDatasetDownloader()
    downloader.list_available()
    
    # downloader.download_dataset('fb_200M_uint64')
    
    # Example 2: Load and inspect a dataset
    print("\n\nExample 2: Load and Inspect Dataset")
    print("-" * 50)
    print("# To load a dataset:")
    data = load_binary_dataset('./sosd_data/fb_200M_uint64', max_records=1000000)
    print(f'First 10 keys: {[k for k, _ in data[:10]]}')
    
    # Example 3: Run benchmark
    print("\n\nExample 3: Run Comprehensive Benchmark")
    print("-" * 50)
    print("Uncomment the line below to run:")
    run_comprehensive_benchmark()

Example 1: Dataset Downloader
--------------------------------------------------

=== Available SOSD Datasets ===

books_200M_uint64
  Status: ✓ Downloaded
  Description: Amazon book popularity (sorted by popularity rank)
  Records: 200,000,000
  Size: 1.60 GB

fb_200M_uint64
  Status: ✓ Downloaded
  Description: Facebook user IDs
  Records: 200,000,000
  Size: 1.60 GB

osm_cellids_200M_uint64
  Status: ✗ Not downloaded
  Description: OpenStreetMap cell IDs
  Records: 200,000,000
  Size: 1.60 GB

wiki_ts_200M_uint64
  Status: ✗ Not downloaded
  Description: Wikipedia edit timestamps
  Records: 200,000,000
  Size: 1.60 GB


Example 2: Load and Inspect Dataset
--------------------------------------------------
# To load a dataset:
Loading dataset: fb_200M_uint64
Total records in file: 1,000,000
Reading data...
Loaded 1,000,000 records
Sorting data...
✓ Dataset ready: 1,000,000 records
First 10 keys: [200000000, 979672113, 1000931008, 1013842213, 1023286855, 1027557666, 1029155329, 103992