# In [None]:  (notebook cell marker left over from export)
# RMI (Recursive Model Index) Implementation in Python
# Translated from the Rust implementation at https://github.com/learnedsystems/RMI

import numpy as np
from typing import List, Tuple, Optional, Dict, Any, Callable, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
import time
from enum import Enum
import bisect

# ================================================
# From: rmi_lib/src/models/mod.rs
# ================================================

class KeyType(Enum):
    """Key encodings named by this module.

    Each member's value is the textual name of the encoding
    ("uint32", "uint64", "float64").
    """
    U32 = "uint32"   # 32-bit unsigned integer keys
    U64 = "uint64"   # 64-bit unsigned integer keys
    F64 = "float64"  # 64-bit floating point keys

class ModelDataType(Enum):
    """Numeric domain a model consumes or produces.

    Returned by ``Model.input_type()`` and ``Model.output_type()``.
    """
    INT = "int"      # integer-valued input/output
    FLOAT = "float"  # floating-point input/output

class ModelRestriction(Enum):
    """Placement constraint reported by ``Model.restriction()``."""
    NONE = 0            # default returned by the base Model class
    MUST_BE_TOP = 1     # reported by RadixModel
    MUST_BE_BOTTOM = 2  # not returned by any model visible in this file

@dataclass
class ModelInput:
    """Represents input to a model - can be int or float.

    Integer inputs (Python ``int`` or NumPy integer scalars) step by one in
    ``minus_epsilon`` / ``plus_epsilon``; float inputs step to the adjacent
    representable float.
    """
    value: Union[int, float]

    def _is_integral(self) -> bool:
        """Return True when the wrapped value is an integer type."""
        # NumPy integer scalars are not instances of the builtin int, so they
        # must be checked for explicitly.
        return isinstance(self.value, (int, np.integer))

    def as_float(self) -> float:
        """Return the wrapped value converted to a Python float."""
        return float(self.value)

    def as_int(self) -> int:
        """Return the wrapped value converted to a Python int (truncating)."""
        return int(self.value)

    def minus_epsilon(self):
        """Return a new ModelInput one step below this one.

        Integers are decremented by one but never go below zero. Floats move
        to the next representable value towards negative infinity; subtracting
        the machine epsilon would leave any value with magnitude >= 2
        unchanged.
        """
        if self._is_integral():
            return ModelInput(max(0, int(self.value) - 1))
        return ModelInput(float(np.nextafter(self.value, -np.inf)))

    def plus_epsilon(self):
        """Return a new ModelInput one step above this one.

        Integers are incremented by one. Floats move to the next representable
        value towards positive infinity.
        """
        if self._is_integral():
            return ModelInput(int(self.value) + 1)
        return ModelInput(float(np.nextafter(self.value, np.inf)))

class RMITrainingData:
    """Container of (key, position) training pairs for an RMI.
    From: rmi_lib/src/models/mod.rs

    Pairs are held in a NumPy structured array with a float64 ``key`` field
    and an int64 ``pos`` field. Positions are multiplied by ``scale`` (and
    truncated to int) whenever they are read through ``get`` or ``iter``.
    """

    def __init__(self, data: List[Tuple[Union[int, float], int]]):
        record_layout = [('key', 'f8'), ('pos', 'i8')]
        self.data = np.array(data, dtype=record_layout)
        self.scale = 1.0

    def _scaled(self, raw_pos) -> int:
        """Apply the current scale factor to a stored position."""
        if self.scale == 1.0:
            return int(raw_pos)
        return int(raw_pos * self.scale)

    def __len__(self):
        return len(self.data)

    def get(self, idx: int) -> Tuple[Union[int, float], int]:
        """Return the (key, scaled position) pair stored at ``idx``."""
        record = self.data[idx]
        return (record['key'], self._scaled(record['pos']))

    def get_key(self, idx: int):
        """Return only the key stored at ``idx``."""
        record = self.data[idx]
        return record['key']

    def set_scale(self, scale: float):
        """Set the factor applied to positions on subsequent reads."""
        self.scale = scale

    def iter(self):
        """Yield every (key, scaled position) pair in storage order."""
        for record in self.data:
            yield (record['key'], self._scaled(record['pos']))

    def soft_copy(self):
        """Return a new container that shares this one's array and scale."""
        clone = RMITrainingData([])
        clone.data = self.data
        clone.scale = self.scale
        return clone

# ================================================
# From: rmi_lib/src/models/mod.rs - Model trait
# ================================================

class Model(ABC):
    """Abstract interface shared by every RMI model.
    From: rmi_lib/src/models/mod.rs

    Subclasses must implement ``predict_to_float``, ``input_type`` and
    ``output_type``. The other methods have defaults that subclasses may
    override.
    """

    @abstractmethod
    def predict_to_float(self, inp: ModelInput) -> float:
        """Return the raw floating-point prediction for ``inp``."""

    def predict_to_int(self, inp: ModelInput) -> int:
        """Return the floored prediction, never smaller than zero."""
        floored = int(np.floor(self.predict_to_float(inp)))
        return floored if floored > 0 else 0

    @abstractmethod
    def input_type(self) -> ModelDataType:
        """Return the numeric domain this model accepts."""

    @abstractmethod
    def output_type(self) -> ModelDataType:
        """Return the numeric domain this model produces."""

    def needs_bounds_check(self) -> bool:
        """Return whether predictions need clamping; True by default."""
        return True

    def restriction(self) -> ModelRestriction:
        """Return the placement restriction; none by default."""
        return ModelRestriction.NONE

    def set_to_constant_model(self, constant: int) -> bool:
        """Try to turn this model into one that always predicts ``constant``.

        The default implementation does nothing and reports failure (False).
        """
        return False

# ================================================
# From: rmi_lib/src/models/linear.rs
# ================================================

class LinearModel(Model):
    """Simple linear regression model.
    From: rmi_lib/src/models/linear.rs

    Predictions are ``alpha + beta * key`` where ``alpha`` is the intercept
    and ``beta`` the slope fitted by ``_slr``.
    """

    def __init__(self, data: RMITrainingData):
        self.alpha, self.beta = self._slr(data)

    def _slr(self, data: RMITrainingData) -> Tuple[float, float]:
        """Fit a simple linear regression in a single pass over ``data``.
        From: rmi_lib/src/models/linear.rs - slr function

        This method does not read ``self``; RobustLinearModel calls it as
        ``LinearModel._slr(None, data)``, so it must stay that way.

        Returns:
            ``(alpha, beta)`` = (intercept, slope). With no data the result
            is ``(0.0, 0.0)``; with a single point, or when all keys are
            equal, it is the constant ``(mean_y, 0.0)``.
        """
        mean_x = 0.0
        mean_y = 0.0
        c = 0.0    # running co-moment of x and y
        n = 0
        m2 = 0.0   # running second moment of x

        # Online (Welford-style) update of the means and moments.
        for x, y in data.iter():
            n += 1
            dx = x - mean_x
            mean_x += dx / n
            mean_y += (y - mean_y) / n
            c += dx * (y - mean_y)
            dx2 = x - mean_x
            m2 += dx * dx2

        if n == 0:
            return (0.0, 0.0)
        if n == 1:
            return (mean_y, 0.0)

        cov = c / (n - 1)
        var = m2 / (n - 1)

        if var == 0.0:
            # All keys identical: the slope is undefined, so predict the mean.
            return (mean_y, 0.0)

        beta = cov / var
        alpha = mean_y - beta * mean_x

        return (alpha, beta)

    def predict_to_float(self, inp: ModelInput) -> float:
        """Return ``alpha + beta * key``."""
        # NumPy has no ``fma`` function, so use a plain multiply-add.
        return self.beta * inp.as_float() + self.alpha

    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def set_to_constant_model(self, constant: int) -> bool:
        """Make the model always predict ``constant``; returns True."""
        self.alpha = float(constant)
        self.beta = 0.0
        return True

# ================================================
# From: rmi_lib/src/models/linear.rs - RobustLinearModel
# ================================================

class RobustLinearModel(Model):
    """Robust linear regression that excludes outliers.
    From: rmi_lib/src/models/linear.rs - RobustLinearModel

    The regression is fitted on the data with 0.01% of the items (at least
    one) dropped from each end. When too few items remain, the full data set
    is used instead.
    """

    def __init__(self, data: RMITrainingData):
        total_items = len(data)
        if total_items == 0:
            self.alpha, self.beta = 0.0, 0.0
            return

        # Skip 0.01% of data from each end to avoid outliers
        bnd = max(1, int(total_items * 0.0001))

        if bnd * 2 + 1 >= total_items:
            # Not enough data to trim: fit on everything.
            training_set = data
        else:
            # Keep only the middle portion [bnd, total_items - bnd).
            trimmed = [
                pair
                for i, pair in enumerate(data.iter())
                if bnd <= i < total_items - bnd
            ]
            training_set = RMITrainingData(trimmed)

        # _slr does not use its instance, so no throwaway LinearModel (and
        # no redundant second fit) is needed.
        self.alpha, self.beta = LinearModel._slr(None, training_set)

    def predict_to_float(self, inp: ModelInput) -> float:
        """Return ``alpha + beta * key``."""
        # NumPy has no ``fma`` function, so use a plain multiply-add.
        return self.beta * inp.as_float() + self.alpha

    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def set_to_constant_model(self, constant: int) -> bool:
        """Make the model always predict ``constant``; returns True."""
        self.alpha = float(constant)
        self.beta = 0.0
        return True

# ================================================
# From: rmi_lib/src/models/linear_spline.rs
# ================================================

class LinearSplineModel(Model):
    """Linear spline connecting first and last points.
    From: rmi_lib/src/models/linear_spline.rs

    Predictions are ``alpha + beta * key``.
    """

    def __init__(self, data: RMITrainingData):
        self.alpha, self.beta = self._linear_splines(data)

    def _linear_splines(self, data: RMITrainingData) -> Tuple[float, float]:
        """Compute the line through the first and last data points.
        From: rmi_lib/src/models/linear_spline.rs - linear_splines function

        Returns:
            ``(intercept, slope)``. Empty data gives ``(0.0, 0.0)``; a single
            point, or equal first and last keys, gives a constant line at the
            first position.
        """
        if len(data) == 0:
            return (0.0, 0.0)

        if len(data) == 1:
            return (float(data.get(0)[1]), 0.0)

        first_x, first_y = data.get(0)
        last_x, last_y = data.get(len(data) - 1)

        if first_x == last_x:
            # Vertical segment: the slope is undefined, fall back to constant.
            return (float(first_y), 0.0)

        slope = (first_y - last_y) / (first_x - last_x)
        intercept = first_y - slope * first_x

        return (intercept, slope)

    def predict_to_float(self, inp: ModelInput) -> float:
        """Return ``alpha + beta * key``."""
        # NumPy has no ``fma`` function, so use a plain multiply-add.
        return self.beta * inp.as_float() + self.alpha

    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def set_to_constant_model(self, constant: int) -> bool:
        """Make the model always predict ``constant``; returns True."""
        self.alpha = float(constant)
        self.beta = 0.0
        return True

# ================================================
# From: rmi_lib/src/models/cubic_spline.rs
# ================================================

class CubicSplineModel(Model):
    """Cubic spline model.
    From: rmi_lib/src/models/cubic_spline.rs

    Predictions are ``a*x**3 + b*x**2 + c*x + d``.
    """

    def __init__(self, data: RMITrainingData):
        self.a, self.b, self.c, self.d = self._cubic(data)

    def _cubic(self, data: RMITrainingData) -> Tuple[float, float, float, float]:
        """Fit the cubic through the first and last points of ``data``.
        From: rmi_lib/src/models/cubic_spline.rs - cubic function

        The end slopes are taken from the first point strictly right of the
        start and the last point strictly left of the end, after normalising
        both axes to [0, 1].

        Returns:
            The coefficients ``(a, b, c, d)``. Degenerate inputs (a single
            point, all keys equal, or a zero key/position range) yield a
            constant model; empty input yields ``(0.0, 0.0, 1.0, 0.0)``.
        """
        if len(data) == 0:
            return (0.0, 0.0, 1.0, 0.0)

        if len(data) == 1:
            return (0.0, 0.0, 0.0, float(data.get(0)[1]))

        # Check for unique values
        first_key = data.get(0)[0]
        has_unique = any(x != first_key for x, _ in data.iter())

        if not has_unique:
            return (0.0, 0.0, 0.0, float(data.get(0)[1]))

        first_x, first_y = data.get(0)
        last_x, last_y = data.get(len(data) - 1)
        xmin, ymin = float(first_x), float(first_y)
        xmax, ymax = float(last_x), float(last_y)

        # A zero range on either axis makes the normalisation below divide by
        # zero (e.g. when scaling collapses every position onto one value).
        # The only sensible fit is then a constant.
        if xmax == xmin or ymax == ymin:
            return (0.0, 0.0, 0.0, ymin)

        x_range = xmax - xmin
        y_range = ymax - ymin

        # End points of the spline in normalised coordinates.
        x1, y1 = 0.0, 0.0
        x2, y2 = 1.0, 1.0

        # Default both slopes to zero so they are defined even when no
        # interior point is found by the searches below.
        m1 = 0.0
        m2 = 0.0

        # Find m1: slope from the start to the first point right of it.
        for x, y in data.iter():
            sx = (x - xmin) / x_range
            if sx > 0.0:
                sy = (y - ymin) / y_range
                m1 = (sy - y1) / (sx - x1)
                break

        # Find m2: slope from the last point left of the end, to the end.
        for i in range(len(data) - 1, -1, -1):
            x, y = data.get(i)
            sx = (x - xmin) / x_range
            if sx < 1.0:
                sy = (y - ymin) / y_range
                m2 = (y2 - sy) / (x2 - sx)
                break

        # Keep it monotonic
        slope_norm_sq = m1**2 + m2**2
        if slope_norm_sq > 9.0:
            tau = 3.0 / np.sqrt(slope_norm_sq)
            m1 *= tau
            m2 *= tau

        # Expand the normalised Hermite cubic into coefficients of raw x.
        denom = x_range**3
        a = (m1 + m2 - 2.0) / denom
        b = -(xmax * (2.0 * m1 + m2 - 3.0) + xmin * (m1 + 2.0 * m2 - 3.0)) / denom
        c = (m1 * xmax**2 + m2 * xmin**2 + xmax * xmin * (2.0 * m1 + 2.0 * m2 - 6.0)) / denom
        d = -xmin * (m1 * xmax**2 + xmax * xmin * (m2 - 3.0) + xmin**2) / denom

        # Undo the y normalisation.
        a *= y_range
        b *= y_range
        c *= y_range
        d *= y_range
        d += ymin

        return (a, b, c, d)

    def predict_to_float(self, inp: ModelInput) -> float:
        """Evaluate the cubic at the input key using Horner's scheme."""
        # NumPy has no ``fma`` function, so use plain multiply-adds.
        val = inp.as_float()
        return ((self.a * val + self.b) * val + self.c) * val + self.d

    def input_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def output_type(self) -> ModelDataType:
        return ModelDataType.FLOAT

    def set_to_constant_model(self, constant: int) -> bool:
        """Make the model always predict ``constant``; returns True."""
        self.a = 0.0
        self.b = 0.0
        self.c = 0.0
        self.d = float(constant)
        return True

# ================================================
# From: rmi_lib/src/models/utils.rs and radix.rs
# ================================================

def num_bits(largest_target: int) -> int:
    """Count how many low bits can be all-ones without exceeding the target.
    From: rmi_lib/src/models/utils.rs

    Returns the largest ``k`` such that ``2**k - 1 <= largest_target``, but
    never less than 1.
    """
    bits = 0
    all_ones = 1  # invariant: all_ones == (1 << (bits + 1)) - 1
    while all_ones <= largest_target:
        bits += 1
        all_ones = (all_ones << 1) | 1
    return bits if bits > 1 else 1

def common_prefix_size(data: RMITrainingData) -> int:
    """Return how many leading bits all keys share, as 64-bit integers.
    From: rmi_lib/src/models/utils.rs

    Keys are converted with ``int()`` and reduced to their low 64 bits.

    Returns:
        The number of most-significant bits (0..64) that are identical in
        every key. Empty data yields 0.
    """
    mask = (1 << 64) - 1
    any_ones = 0     # bit set where at least one key has a 1
    no_ones = mask   # ends up set where every key has a 1

    for x, _ in data.iter():
        x_int = int(x) & mask
        any_ones |= x_int
        no_ones &= x_int

    any_zeros = ~no_ones & mask          # bit set where some key has a 0
    prefix_bits = any_zeros ^ any_ones   # bit set where all keys agree
    differing = ~prefix_bits & mask      # bit set where keys disagree

    # The shared prefix is everything above the highest disagreeing bit,
    # i.e. the number of leading zeros of ``differing`` in a 64-bit word.
    return 64 - differing.bit_length()

class RadixModel(Model):
    """Radix-based model.
    From: rmi_lib/src/models/radix.rs

    Predicts by discarding ``left_shift`` leading bits of the 64-bit key and
    keeping the next ``num_bits`` bits.
    """

    # Keys are treated as 64-bit words; Python ints are unbounded, so the
    # width has to be enforced explicitly after shifting left.
    _U64_MASK = (1 << 64) - 1

    def __init__(self, data: RMITrainingData):
        if len(data) == 0:
            self.left_shift = 0
            self.num_bits = 0
            return

        largest_value = max(y for _, y in data.iter())
        bits = num_bits(largest_value)
        self.num_bits = bits
        self.left_shift = common_prefix_size(data)

    def predict_to_int(self, inp: ModelInput) -> int:
        """Return the ``num_bits`` key bits that follow the dropped prefix."""
        # Truncate to 64 bits so the left shift really drops the prefix
        # instead of growing the integer.
        shifted = (inp.as_int() << self.left_shift) & self._U64_MASK
        return shifted >> (64 - self.num_bits)

    def predict_to_float(self, inp: ModelInput) -> float:
        """Return the integer prediction as a float."""
        return float(self.predict_to_int(inp))

    def input_type(self) -> ModelDataType:
        return ModelDataType.INT

    def output_type(self) -> ModelDataType:
        return ModelDataType.INT

    def needs_bounds_check(self) -> bool:
        return False

    def restriction(self) -> ModelRestriction:
        return ModelRestriction.MUST_BE_TOP

# ================================================
# From: rmi_lib/src/train/two_layer.rs
# ================================================

@dataclass
class TrainedRMI:
    """Trained RMI structure
    From: rmi_lib/src/train/mod.rs

    Result record produced by ``train_two_layer``; the field notes below
    describe how that function fills each field.
    """
    # Number of training rows; train_two_layer sets both row counts to len(data).
    num_rmi_rows: int
    num_data_rows: int
    # Sum over leaves of (item count * leaf error), divided by the row count.
    model_avg_error: float
    # Sum over leaves of (item count * leaf error) squared, divided by the row count.
    model_avg_l2_error: float
    # Sum over leaves of item count * log2(2 * error + 2), divided by the row count.
    model_avg_log2_error: float
    # Largest per-leaf error and the index of the leaf that has it.
    model_max_error: int
    model_max_error_idx: int
    # log2 of the largest error (0.0 when the largest error is 0).
    model_max_log2_error: float
    # Per-leaf error bound, indexed by leaf model number.
    last_layer_max_l1s: List[int]
    # Model layers: [[top_model], leaf_models].
    rmi: List[List[Model]]
    # Comma-joined model type names, "layer1,layer2".
    models: str
    # Number of leaf models.
    branching_factor: int
    # Wall-clock training time in seconds (difference of time.time() values).
    build_time: float

def train_model(model_type: str, data: RMITrainingData) -> Model:
    """Build and train the model named by ``model_type`` on ``data``.
    From: rmi_lib/src/train/mod.rs - train_model function

    Recognised names are "linear", "robust_linear", "linear_spline", "cubic"
    and "radix".

    Raises:
        ValueError: if ``model_type`` is not one of the recognised names.
    """
    # Each entry is a deferred constructor call, so only the selected model
    # class is ever touched.
    builders = {
        "linear": lambda: LinearModel(data),
        "robust_linear": lambda: RobustLinearModel(data),
        "linear_spline": lambda: LinearSplineModel(data),
        "cubic": lambda: CubicSplineModel(data),
        "radix": lambda: RadixModel(data),
    }
    build = builders.get(model_type)
    if build is None:
        raise ValueError(f"Unknown model type: {model_type}")
    return build()

# ================================================
# From: rmi_lib/src/train/lower_bound_correction.rs
# ================================================

class LowerBoundCorrection:
    """Handles boundary cases for lower bound searches.
    From: rmi_lib/src/train/lower_bound_correction.rs

    For every leaf it records the first and last key routed to it, the first
    key of the next non-empty leaf, the last key of the previous non-empty
    leaf, and the longest run of equal keys inside the leaf.
    """

    def __init__(self, pred_func: Callable, num_leaf_models: int, data: RMITrainingData):
        """
        pred_func: function that predicts which leaf model a key maps to
        num_leaf_models: number of leaf models in the RMI
        data: training data
        """
        self.num_leaf_models = num_leaf_models

        # (position, key) of the first / last item in each leaf; None if empty.
        self.first_key_for_leaf = [None] * num_leaf_models
        self.last_key_for_leaf = [None] * num_leaf_models

        # (position, key) of the first item after each leaf.
        self.next_for_leaf = [(0, 0)] * num_leaf_models

        # (position, key) of the last item before each leaf; (0, 0) if none.
        self.prev_for_leaf = [(0, 0)] * num_leaf_models

        # Longest run of duplicate keys in each leaf.
        self.max_run_length = [0] * num_leaf_models

        self._build(pred_func, data)

    def _build(self, pred_func: Callable, data: RMITrainingData):
        """Fill all correction tables from ``data``."""
        if len(data) > 0:
            # With no data there is no first key to seed the run tracker.
            self._scan_leaves(pred_func, data)

        # next_for_leaf: sweep right-to-left carrying the first key of the
        # nearest non-empty leaf to the right. Past the last non-empty leaf
        # the "next" item is the end of the data set.
        upcoming = (len(data), float('inf'))
        for idx in range(self.num_leaf_models - 1, -1, -1):
            self.next_for_leaf[idx] = upcoming
            if self.first_key_for_leaf[idx] is not None:
                upcoming = self.first_key_for_leaf[idx]

        # prev_for_leaf: sweep left-to-right carrying the last key of the
        # nearest non-empty leaf to the left; (0, 0) when there is none.
        previous = (0, 0)
        for idx in range(self.num_leaf_models):
            self.prev_for_leaf[idx] = previous
            if self.last_key_for_leaf[idx] is not None:
                previous = self.last_key_for_leaf[idx]

    def _scan_leaves(self, pred_func: Callable, data: RMITrainingData):
        """Record first/last keys and duplicate-run lengths per leaf."""
        last_target = 0
        current_run_length = 0
        current_run_key = data.get_key(0)

        for x, y in data.iter():
            # Predictions past the last leaf are routed to the last leaf.
            target = min(self.num_leaf_models - 1, pred_func(x))

            if target == last_target and x == current_run_key:
                current_run_length += 1
            else:
                # The run ended: credit it to the leaf it belonged to.
                self.max_run_length[last_target] = max(
                    self.max_run_length[last_target],
                    current_run_length,
                )
                current_run_length = 1
                current_run_key = x
                last_target = target

            if self.first_key_for_leaf[target] is None:
                self.first_key_for_leaf[target] = (y, x)
            self.last_key_for_leaf[target] = (y, x)

        # Credit the run that was still open when the data ended.
        self.max_run_length[last_target] = max(
            self.max_run_length[last_target],
            current_run_length,
        )

    def first_key(self, leaf_idx: int):
        """Get first key in this leaf (or None if empty)"""
        entry = self.first_key_for_leaf[leaf_idx]
        return None if entry is None else entry[1]

    def last_key(self, leaf_idx: int):
        """Get last key in this leaf (or None if empty)"""
        entry = self.last_key_for_leaf[leaf_idx]
        return None if entry is None else entry[1]

    def next(self, leaf_idx: int):
        """Get (index, key) of first key after this leaf"""
        return self.next_for_leaf[leaf_idx]

    def next_index(self, leaf_idx: int):
        """Get index of first key after this leaf"""
        return self.next_for_leaf[leaf_idx][0]

    def prev_key(self, leaf_idx: int):
        """Get last key before this leaf (0 when there is none)"""
        return self.prev_for_leaf[leaf_idx][1]

    def longest_run(self, leaf_idx: int):
        """Get maximum run length of duplicate keys in this leaf"""
        return self.max_run_length[leaf_idx]


def _train_leaf_models(data: RMITrainingData, top_model: Model,
                       layer2_model: str, num_leaf_models: int) -> List[Model]:
    """Train one leaf model per leaf index; returns exactly num_leaf_models models."""
    leaf_models = []
    bucket = []       # training pairs collected for the leaf being filled
    last_target = 0   # invariant: len(leaf_models) == last_target

    for x, y in data.iter():
        model_pred = top_model.predict_to_int(ModelInput(x))
        target = min(num_leaf_models - 1, model_pred)

        if target > last_target:
            # First point of a later leaf: train the leaf collected so far.
            # This happens even when nothing was collected (the very first
            # key may map to a leaf > 0); otherwise every later model would
            # land at the wrong list index.
            last_item = bucket[-1] if bucket else None

            # Include the first point of the next leaf in this leaf's data.
            bucket.append((x, y))
            leaf_models.append(train_model(layer2_model, RMITrainingData(bucket)))

            # Leave empty models for any leaf indexes that were skipped.
            for _ in range(last_target + 1, target):
                leaf_models.append(train_model(layer2_model, RMITrainingData([])))

            # Carry the previous leaf's last point into the new leaf's data.
            bucket = [] if last_item is None else [last_item]
            last_target = target

        # A point whose prediction steps backwards stays with the current
        # leaf, so the list index invariant above keeps holding.
        bucket.append((x, y))

    # Train last model
    if bucket:
        leaf_models.append(train_model(layer2_model, RMITrainingData(bucket)))

    # Fill remaining empty models
    for _ in range(len(leaf_models), num_leaf_models):
        leaf_models.append(train_model(layer2_model, RMITrainingData([])))

    return leaf_models


def _leaf_error_bounds(data: RMITrainingData, top_model: Model,
                       leaf_models: List[Model],
                       lb_corrections: LowerBoundCorrection,
                       num_leaf_models: int) -> List[Tuple[int, int]]:
    """Return per-leaf (item count, error bound) including lower bound corrections."""
    num_rows = len(data)
    stats = [(0, 0)] * num_leaf_models

    # Largest absolute prediction error observed for each leaf.
    for x, y in data.iter():
        leaf_idx = top_model.predict_to_int(ModelInput(x))
        target = min(num_leaf_models - 1, leaf_idx)

        pred = leaf_models[target].predict_to_int(ModelInput(x))
        err = abs(pred - y)

        count, max_err = stats[target]
        stats[target] = (count + 1, max(err, max_err))

    # Widen each bound so queries for absent keys that fall between leaves
    # are covered as well.
    for leaf_idx in range(num_leaf_models):
        num_items_in_leaf, curr_err = stats[leaf_idx]
        leaf = leaf_models[leaf_idx]

        # Upper error: can we handle queries for (next_key - epsilon)?
        idx_of_next, key_of_next = lb_corrections.next(leaf_idx)
        if key_of_next != float('inf'):
            pred = leaf.predict_to_int(ModelInput(max(0, key_of_next - 1)))
            upper_error = abs(pred - min(idx_of_next + 1, num_rows))
        else:
            # No key follows this leaf.
            upper_error = curr_err

        # Lower error: can we handle queries for (prev_key + epsilon)?
        first_key_before = lb_corrections.prev_key(leaf_idx)
        first_idx = lb_corrections.next_index(max(0, leaf_idx - 1))
        if first_key_before != 0:
            pred = leaf.predict_to_int(ModelInput(first_key_before + 1))
            lower_error = abs(pred - first_idx)
        else:
            # prev_key() reports 0 when no key precedes this leaf.
            lower_error = curr_err

        # Take maximum error and add run length
        new_err = (max(curr_err, upper_error, lower_error)
                   + lb_corrections.longest_run(leaf_idx))
        stats[leaf_idx] = (num_items_in_leaf, new_err)

    return stats


def train_two_layer(data: RMITrainingData, layer1_model: str,
                   layer2_model: str, num_leaf_models: int) -> TrainedRMI:
    """Train a two-layer RMI with proper lower bound correction.
    From: rmi_lib/src/train/two_layer.rs - train_two_layer function

    Args:
        data: training pairs of (key, position).
        layer1_model: model type name for the single top-level model.
        layer2_model: model type name for every leaf model.
        num_leaf_models: number of leaf models (the branching factor).

    Returns:
        A TrainedRMI holding the top model, the leaf models, per-leaf error
        bounds and summary error statistics.

    Raises:
        ValueError: if ``data`` is empty or ``num_leaf_models`` is below 1,
            or if a model type name is unknown (raised by ``train_model``).
    """
    num_rows = len(data)
    if num_rows == 0:
        # The position scale below divides by the row count.
        raise ValueError("Cannot train an RMI on an empty dataset")
    if num_leaf_models < 1:
        raise ValueError("num_leaf_models must be at least 1")

    start_time = time.time()

    # Train top model on positions rescaled to leaf indexes, then restore
    # the scale so the caller's data is left as it was handed in.
    print(f"Training top-level {layer1_model} model...")
    data.set_scale(num_leaf_models / num_rows)
    top_model = train_model(layer1_model, data)
    data.set_scale(1.0)

    # Train leaf models
    print(f"Training {num_leaf_models} {layer2_model} leaf models...")
    leaf_models = _train_leaf_models(data, top_model, layer2_model, num_leaf_models)

    print("Computing lower bound corrections...")
    lb_corrections = LowerBoundCorrection(
        lambda x: top_model.predict_to_int(ModelInput(x)),
        num_leaf_models,
        data
    )

    # Fix empty models to return correct constants: an empty leaf should
    # point at the first position after it.
    print("Fixing empty models...")
    for idx in range(num_leaf_models - 1):
        if lb_corrections.first_key(idx) is None:
            upper_bound = lb_corrections.next_index(idx)
            leaf_models[idx].set_to_constant_model(upper_bound)

    # Compute errors with lower bound correction
    print("Computing errors...")
    last_layer_max_l1s = _leaf_error_bounds(
        data, top_model, leaf_models, lb_corrections, num_leaf_models
    )

    # Calculate statistics
    final_errors = [err for _, err in last_layer_max_l1s]
    # np.argmax returns a NumPy integer; store a plain int.
    model_max_error_idx = int(np.argmax(final_errors))
    model_max_error = final_errors[model_max_error_idx]

    model_avg_error = sum(n * err for n, err in last_layer_max_l1s) / num_rows
    model_avg_l2_error = sum((n * err)**2 for n, err in last_layer_max_l1s) / num_rows
    model_avg_log2_error = float(
        sum(n * np.log2(2 * err + 2) for n, err in last_layer_max_l1s) / num_rows
    )
    model_max_log2_error = float(np.log2(model_max_error)) if model_max_error > 0 else 0.0

    build_time = time.time() - start_time

    print(f"Training complete in {build_time:.2f}s")

    return TrainedRMI(
        num_rmi_rows=num_rows,
        num_data_rows=num_rows,
        model_avg_error=model_avg_error,
        model_avg_l2_error=model_avg_l2_error,
        model_avg_log2_error=model_avg_log2_error,
        model_max_error=model_max_error,
        model_max_error_idx=model_max_error_idx,
        model_max_log2_error=model_max_log2_error,
        last_layer_max_l1s=final_errors,
        rmi=[[top_model], leaf_models],
        models=f"{layer1_model},{layer2_model}",
        branching_factor=num_leaf_models,
        build_time=build_time
    )

# ================================================
# From: rmi_lib/src/train/mod.rs - train function
# ================================================

def train(data: RMITrainingData, model_spec: str, branch_factor: int) -> TrainedRMI:
    """Train an RMI described by a comma-separated model specification.
    From: rmi_lib/src/train/mod.rs - train function

    ``model_spec`` must name exactly two layers, e.g. "cubic,linear"; the
    first is the top model and the second the leaf model type.

    Raises:
        NotImplementedError: if the specification does not have two layers.
    """
    layer_names = model_spec.split(',')

    if len(layer_names) != 2:
        raise NotImplementedError("Only two-layer RMIs are currently supported")

    top_name, leaf_name = layer_names
    return train_two_layer(data, top_name, leaf_name, branch_factor)

# ================================================
# Lookup function
# ================================================

def rmi_lookup(rmi: TrainedRMI, key: Union[int, float]) -> Tuple[int, int]:
    """Predict where ``key`` is located using a trained two-layer RMI.

    Returns: (predicted_position, error_bound), where the error bound is the
    one recorded for the leaf model that served the key.
    """
    wrapped_key = ModelInput(key)

    # Top layer: choose the leaf, clamped to the last available leaf.
    leaf_no = rmi.rmi[0][0].predict_to_int(wrapped_key)
    last_leaf = len(rmi.rmi[1]) - 1
    if leaf_no > last_leaf:
        leaf_no = last_leaf

    # Leaf layer: predict the position, clamped to the last row.
    position = rmi.rmi[1][leaf_no].predict_to_int(wrapped_key)
    last_row = rmi.num_rmi_rows - 1
    if position > last_row:
        position = last_row

    return (position, rmi.last_layer_max_l1s[leaf_no])

# ================================================
# Example Usage and Testing
# ================================================

def generate_test_data(n: int, distribution: str = "linear") -> List[Tuple[int, int]]:
    """Generate ``n`` sorted integer keys paired with their positions.

    ``distribution`` selects how keys are produced: "linear" gives 0..n-1,
    "random" gives sorted uniform integers below ``n * 10``, and
    "exponential" gives sorted truncated exponential samples (scale 1000).

    Raises:
        ValueError: if ``distribution`` is not one of the names above.
    """
    if distribution not in ("linear", "random", "exponential"):
        raise ValueError(f"Unknown distribution: {distribution}")

    if distribution == "linear":
        keys = np.arange(n)
    elif distribution == "random":
        samples = np.random.randint(0, n * 10, n)
        keys = np.sort(samples)
    else:
        samples = np.random.exponential(scale=1000, size=n)
        keys = np.sort(samples.astype(int))

    pairs = []
    for position, key in enumerate(keys):
        pairs.append((int(key), position))
    return pairs

# Example: Train and test RMI
# These statements run at module level, so they execute (and print) whenever
# this file is run or imported.
print("=" * 60)
print("RMI Implementation Example")
print("=" * 60)

# Generate test data: n consecutive integer keys, each at position == key
n = 10000
test_data = generate_test_data(n, "linear")
data = RMITrainingData(test_data)

print(f"\nDataset size: {len(data)} keys")
print(f"Key range: [{data.get_key(0)}, {data.get_key(len(data)-1)}]")

# Train RMI: cubic top-level model routing to 100 linear leaf models
print("\nTraining RMI with cubic,linear model...")
rmi = train(data, "cubic,linear", branch_factor=100)

# Report the summary statistics stored on the TrainedRMI result
print(f"\nTraining completed in {rmi.build_time:.4f} seconds")
print(f"Model configuration: {rmi.models}")
print(f"Branching factor: {rmi.branching_factor}")
print(f"Average error: {rmi.model_avg_error:.2f}")
print(f"Average log2 error: {rmi.model_avg_log2_error:.2f}")
print(f"Max error: {rmi.model_max_error}")
print(f"Max log2 error: {rmi.model_max_log2_error:.2f}")

# Test lookups
print("\n" + "=" * 60)
print("Testing Lookups")
print("=" * 60)

# Probe the first key, the three quartile keys and the last key
test_keys = [0, n//4, n//2, 3*n//4, n-1]
for key in test_keys:
    # rmi_lookup returns the predicted position and that leaf's error bound
    pred, err = rmi_lookup(rmi, key)
    actual = key  # For linear data, key == position
    print(f"Key: {key:6d} | Predicted: {pred:6d} ± {err:4d} | Actual: {actual:6d} | Error: {abs(pred-actual):4d}")