In [None]:
"""
RMI (Recursive Model Index) Implementation in Python
Translated from the Go implementation at github.com/BenJoyenConseil/rmi
"""

# ============================================================
# Part 1: Linear Regression Estimator
# Translated from: estimate/linear/estimator.go
# ============================================================

import numpy as np
from scipy import stats
import pandas as pd
from typing import Tuple, List, Optional
import time

class RegressionModel:
    """Straight-line model ``intercept + slope * x`` used to approximate a CDF."""

    def __init__(self, intercept: float, slope: float):
        self.intercept, self.slope = intercept, slope

    def predict(self, x: float) -> float:
        """Return the model's CDF estimate for ``x``."""
        return self.slope * x + self.intercept


def fit(x: np.ndarray, y: np.ndarray) -> RegressionModel:
    """
    Fit a linear regression model on x and y.

    Returns a RegressionModel with alpha (intercept) and beta (slope). When the
    slope is undefined (a single point, or all x values identical), or the
    regression yields NaN, the model degenerates to the constant ``mean(y)``
    with slope 0.
    """
    x_arr = np.asarray(x)
    # SciPy's linregress raises ValueError when all x values are identical
    # instead of returning NaN, so handle that degenerate case up front.
    # A single point is covered here too (its slope would be 0/0).
    if x_arr.size > 0 and np.all(x_arr == x_arr[0]):
        return RegressionModel(np.mean(y), 0.0)

    slope, intercept, _, _, _ = stats.linregress(x, y)

    # Any remaining NaN result (e.g. when the inputs contain NaN) falls back
    # to the constant model.
    if np.isnan(intercept) or np.isnan(slope):
        intercept = np.mean(y)
        slope = 0.0

    return RegressionModel(intercept, slope)


def cdf(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return the x array and the y array containing empirical CDF value for each x's value.

    ``y[i]`` is the fraction of elements of ``x`` that are ``<= x[i]``, and
    ``len(x) == len(y)``. NaN entries compare False with everything, so their
    CDF value is 0. The cost is one sort plus a vectorised binary search
    (O(n log n)) rather than n full-array comparisons (O(n^2)).
    """
    n = len(x)
    if n == 0:
        return x, np.zeros(0)
    # The right-side insertion point of a value in the sorted array equals the
    # number of elements <= that value. np.sort places NaN last, so NaN
    # elements are never counted for a finite value.
    counts = np.searchsorted(np.sort(x), x, side='right')
    y = counts / n
    # A NaN query would land after every element. The pairwise definition
    # (x <= nan is always False) gives 0, so restore that.
    y[x != x] = 0.0
    return x, y


# ============================================================
# Part 2: Sorted Table
# Translated from: search/sort.go
# ============================================================

class SortedTable:
    """
    Key/offset pairs kept in key order.

    ``keys[i]`` travels with ``offsets[i]``. When built by ``new_sorted_table``,
    the offset is the row position the key had in the unsorted input.
    """

    def __init__(self, keys: np.ndarray, offsets: np.ndarray):
        # Stored as given: the constructor neither sorts nor copies.
        self.keys, self.offsets = keys, offsets


def new_sorted_table(x: np.ndarray) -> SortedTable:
    """
    Return a Sorted Table structure sorted by key in ascending order.

    ``offsets[i]`` is the position that ``keys[i]`` had in ``x``. Equal keys
    keep their original relative order, so offsets of a duplicated key appear
    in ascending order. ``x`` itself is not modified.
    """
    keys = np.asarray(x)
    # A stable sort makes the order of duplicates deterministic. The default
    # quicksort gives no ordering guarantee for equal keys.
    order = np.argsort(keys, kind='stable')
    # Fancy indexing returns a copy, and the sort permutation is exactly the
    # original position of each sorted key.
    return SortedTable(keys[order], order)


# ============================================================
# Part 3: Index Utilities
# Translated from: index/bound.go
# ============================================================

def residual(guess: int, y: int) -> int:
    """Return how far the actual position ``y`` lies from ``guess`` (positive = to the right)."""
    delta = y - guess
    return delta


def scale(cdf_val: float, dataset_len: int) -> int:
    """
    Convert a CDF value into a 0-based position in a sorted table.

    The position is ``cdf_val * dataset_len - 1``, rounded with ``np.round``
    (round half to even) and returned as a Python ``int``.
    """
    position = cdf_val * dataset_len - 1
    return int(np.round(position))


# ============================================================
# Part 4: Learned Index
# Translated from: index/learn.go
# ============================================================

class LearnedIndex:
    """
    LearnedIndex is an index structure that uses inference to locate keys.

    The model ``m`` maps a key to a CDF estimate, ``scale`` turns that estimate
    into a position in the sorted table ``st``, and the search is confined to
    ``[guess + min_err_bound, guess + max_err_bound]``, clamped to the table.
    """

    def __init__(self, m: RegressionModel, st: SortedTable, length: int,
                 min_err_bound: int, max_err_bound: int):
        self.m = m                          # model exposing predict(key) -> CDF estimate
        self.st = st                        # sorted keys paired with their original offsets
        self.len = length                   # number of rows in the table
        self.min_err_bound = min_err_bound  # added to the guess to get the lower search bound
        self.max_err_bound = max_err_bound  # added to the guess to get the upper search bound

    def guess_index(self, key: float) -> Tuple[int, int, int]:
        """
        Return the predicted position of the key in the index
        and upper/lower positions' search interval.

        Guess, lower and upper always have values between 0 and len(keys)-1
        (for a non-empty table).

        Returns:
            ``(guess, lower, upper)``
        """
        last = self.len - 1
        guess = scale(self.m.predict(key), self.len)
        # The bounds are offsets from the *unclamped* guess, so compute them
        # before clamping the guess itself.
        lower = min(max(self.min_err_bound + guess, 0), last)
        upper = min(max(guess + self.max_err_bound, 0), last)
        guess = min(max(guess, 0), last)
        return guess, lower, upper

    def lookup(self, key: float) -> Tuple[Optional[List[int]], Optional[str]]:
        """
        Lookup returns the offsets of the key or error if the key is not found.

        Returns:
            ``(offsets, None)`` with the offsets of every occurrence of ``key``,
            or ``(None, message)`` when the key is not present.
        """
        # NaN never equals a stored key, and scale() cannot turn a NaN
        # prediction into a position (int(nan) raises ValueError).
        if key != key:
            return None, f"The following key <{key}> is not found in the index"

        guess, lower, upper = self.guess_index(key)
        keys = self.st.keys

        # Binary search inside the bounded window, on the side of the guess
        # where the key must lie.
        if key > keys[guess]:
            i = guess + 1 + int(np.searchsorted(keys[guess + 1:upper + 1], key, side='left'))
        else:
            i = lower + int(np.searchsorted(keys[lower:guess + 1], key, side='left'))

        # For a duplicated key, the window's lower edge may fall inside the run
        # of equal keys. Rewind to the first occurrence so no offset is missed.
        while i > 0 and keys[i - 1] == key:
            i -= 1

        # Collect the whole run of equal keys.
        offsets = []
        n = len(keys)
        while i < n and keys[i] == key:
            offsets.append(int(self.st.offsets[i]))
            i += 1

        if not offsets:
            return None, f"The following key <{key}> is not found in the index"

        return offsets, None


def new_learned_index(dataset: np.ndarray) -> LearnedIndex:
    """
    Create a new LearnedIndex fitted over the dataset with linear regression.

    The model is fitted on (key, empirical CDF) pairs of the sorted keys. The
    error bounds are the most negative and most positive differences between
    each key's actual position in the sorted table and the model's scaled
    guess. Every occurrence of every indexed key therefore lies inside
    ``[guess + min_err_bound, guess + max_err_bound]``.

    NaN keys stay in the table but are excluded from the fit and from the
    error bounds. A lookup can never match them (NaN != NaN), and a NaN
    prediction cannot be converted to a position by ``scale``.
    """
    st = new_sorted_table(dataset)

    x, y = cdf(st.keys)
    len_ = len(dataset)

    comparable = ~pd.isna(x)
    m = fit(x[comparable], y[comparable])

    max_err, min_err = 0, 0
    for pos in np.flatnonzero(comparable):
        guess = scale(m.predict(x[pos]), len_)
        # Measure against the key's true sorted position. scale(y[pos]) maps
        # every duplicate of a key to its LAST position, so bounds built from
        # it alone could exclude the earlier duplicates.
        res = residual(guess, int(pos))
        max_err = max(max_err, res)
        min_err = min(min_err, res)

    return LearnedIndex(m, st, len_, min_err, max_err)


# ============================================================
# Part 5: Search Algorithms for Comparison
# Translated from: search/fullscan.go and search/binary.go
# ============================================================

def full_scan_lookup(key: float, st: SortedTable) -> Tuple[Optional[List[int]], Optional[str]]:
    """Full scan lookup for comparison"""
    offsets = []
    
    for i in range(len(st.keys)):
        if st.keys[i] == key:
            offsets.append(int(st.offsets[i]))
    
    if len(offsets) > 0:
        return offsets, None
    return None, f"The following key <{key}> is not found in the index"


def binary_search_lookup(key: float, st: SortedTable) -> Tuple[Optional[List[int]], Optional[str]]:
    """Binary search lookup for comparison"""
    i = np.searchsorted(st.keys, key, side='left')
    offsets = []
    
    while i < len(st.keys):
        if st.keys[i] > key:
            break
        elif st.keys[i] == key:
            offsets.append(int(st.offsets[i]))
        i += 1
    
    if len(offsets) > 0:
        return offsets, None
    return None, f"The following key <{key}> is not found in the index"


# ============================================================
# Part 6: Benchmarking Functions
# Translated from: functional_test.go (benchmark sections)
# ============================================================

def _time_lookups(lookup, n_queries: int, key_range: Tuple[float, float],
                  extra_args: tuple = ()) -> dict:
    """
    Time ``lookup(k, *extra_args)`` over ``n_queries`` random keys.

    Keys are drawn uniformly from ``key_range`` and rounded with ``np.round``.
    Only the lookups are inside the timed region. Returns the same result dict
    as the public benchmark functions.
    """
    min_val, max_val = key_range
    # Generate every query before starting the clock, so the measured time
    # covers lookups only, not random-number generation or bookkeeping.
    queries = np.round(min_val + np.random.random(n_queries) * (max_val - min_val))

    start_time = time.perf_counter()  # monotonic, high-resolution clock
    results = [lookup(k, *extra_args) for k in queries]
    elapsed_time = time.perf_counter() - start_time

    # Distinct keys that were / were not found.
    keys_found = set()
    keys_not_found = set()
    for k, (_, err) in zip(queries, results):
        if err is None:
            keys_found.add(k)
        else:
            keys_not_found.add(k)

    return {
        'time': elapsed_time,
        'queries': n_queries,
        'keys_found': len(keys_found),
        'keys_not_found': len(keys_not_found),
        # microseconds; guarded so n_queries=0 does not divide by zero
        'time_per_query': elapsed_time / n_queries * 1e6 if n_queries else 0.0,
    }


def benchmark_learned_index(dataset: np.ndarray, n_queries: int = 10000,
                            key_range: Tuple[float, float] = (0., 100.)) -> dict:
    """Benchmark the learned index (index construction is not timed)."""
    idx = new_learned_index(dataset)
    return _time_lookups(idx.lookup, n_queries, key_range)


def benchmark_binary_search(dataset: np.ndarray, n_queries: int = 10000,
                            key_range: Tuple[float, float] = (0., 100.)) -> dict:
    """Benchmark binary search (table construction is not timed)."""
    st = new_sorted_table(dataset)
    return _time_lookups(binary_search_lookup, n_queries, key_range, (st,))


# ============================================================
# Part 7: Example Usage and Testing
# Translated from: main.go and functional_test.go
# ============================================================

def extract_column(file_path: str, col_name: str) -> np.ndarray:
    """
    Extract a column from a CSV file as a NumPy array.

    Raises KeyError if ``col_name`` is not a column of the file. Missing
    values are kept (as NaN for numeric columns).
    """
    df = pd.read_csv(file_path)
    # to_numpy() always yields an ndarray, whereas .values may return a pandas
    # ExtensionArray (e.g. for nullable or categorical dtypes).
    return df[col_name].to_numpy()


def run_example():
    """
    Run the example from main.go: index the ages of a small people table and
    look up everyone aged 23.

    Returns the built index and the indexed age column.
    """
    separator = "=" * 60
    print(separator)
    print("RMI Example - People Dataset")
    print(separator)

    # Sample data (from data/people.csv); only the ages are indexed.
    people_data = {
        'name': ['jeanne', 'jean', 'Carlos', 'Carlotta', 'Miguel', 'Martine', 'Georgette'],
        'age': [90, 23, 3, 45, 1, 1.5, 23],
        'sex': ['F', 'M', 'M', 'F', 'M', 'F', 'F'],
    }
    age_column = np.array(people_data['age'])
    index = new_learned_index(age_column)

    summary = (
        f"Index created with {index.len} elements",
        f"Max Error Bound: {index.max_err_bound}",
        f"Min Error Bound: {index.min_err_bound}",
        "",
    )
    for text in summary:
        print(text)

    search_age = 23.0
    lines, err = index.lookup(search_age)
    if not err:
        print(f"People who are {search_age} years old are located at {lines}")
    else:
        print(f"Error: {err}")

    return index, age_column


def run_tests(dataset: np.ndarray):
    """
    Check that the learned index returns the same offsets as a full scan for
    every distinct key of ``dataset``.

    Prints a summary and returns True when every key matched.
    """
    print("\n" + "=" * 60)
    print("Functional Tests - Verifying Correctness")
    print("=" * 60)

    li = new_learned_index(dataset)
    st = new_sorted_table(dataset)

    # NaN never compares equal to anything (itself included), so no lookup
    # method can find it. Exclude it instead of comparing two "not found"s.
    unique_keys = np.unique(dataset)
    unique_keys = unique_keys[~pd.isna(unique_keys)]
    all_passed = True

    for key in unique_keys:
        result_fs, _ = full_scan_lookup(key, st)
        result_li, _ = li.lookup(key)

        # Either lookup returns None for a miss. Treat that as "no offsets" so
        # a miss is reported as FAIL instead of raising TypeError from set(None).
        if set(result_fs or []) != set(result_li or []):
            print(f"FAIL: Key {key} - FS: {result_fs}, LI: {result_li}")
            all_passed = False

    if all_passed:
        print(f"✓ All {len(unique_keys)} keys passed functional tests")
    else:
        print("✗ Some tests failed")

    return all_passed


def run_benchmarks(dataset: np.ndarray, n_queries: int = 10000):
    """
    Run benchmarks comparing learned index vs binary search.

    Query keys are drawn from the dataset's [min, max] range, ignoring NaN.
    Returns the two result dicts ``(learned_index, binary_search)``.
    """
    print("\n" + "=" * 60)
    print(f"Benchmarks - {n_queries} queries")
    print("=" * 60)

    # nanmin/nanmax: a single missing value would otherwise turn the whole key
    # range (and therefore every query key) into NaN.
    key_min, key_max = np.nanmin(dataset), np.nanmax(dataset)

    # Benchmark Learned Index
    print("\nRunning Learned Index benchmark...")
    li_results = benchmark_learned_index(dataset, n_queries, (key_min, key_max))

    # Benchmark Binary Search
    print("Running Binary Search benchmark...")
    bs_results = benchmark_binary_search(dataset, n_queries, (key_min, key_max))

    # Print results
    print("\n" + "-" * 60)
    print("LEARNED INDEX:")
    print(f"  Total time: {li_results['time']:.4f} seconds")
    print(f"  Time per query: {li_results['time_per_query']:.2f} μs")
    print(f"  Keys found: {li_results['keys_found']}")
    print(f"  Keys not found: {li_results['keys_not_found']}")

    print("\nBINARY SEARCH:")
    print(f"  Total time: {bs_results['time']:.4f} seconds")
    print(f"  Time per query: {bs_results['time_per_query']:.2f} μs")
    print(f"  Keys found: {bs_results['keys_found']}")
    print(f"  Keys not found: {bs_results['keys_not_found']}")

    speedup = bs_results['time'] / li_results['time']
    print(f"\nSpeedup: {speedup:.2f}x")
    print("-" * 60)

    return li_results, bs_results


# ============================================================
# Main execution
# ============================================================

if __name__ == "__main__":
    # Run the small built-in example, then show how to test and benchmark a
    # larger dataset loaded from CSV.
    run_example()

    print("\n\nFor comprehensive benchmarks, load your dataset:")
    for hint in (
        "  dataset = extract_column('path/to/titanic.csv', 'age')",
        "  run_tests(dataset)",
        "  run_benchmarks(dataset, n_queries=10000)",
    ):
        print(hint)

RMI Example - People Dataset
Index created with 7 elements
Max Error Bound: 1
Min Error Bound: -2

People who are 23.0 years old are located at [1, 6]


For comprehensive benchmarks, load your dataset:
  dataset = extract_column('path/to/titanic.csv', 'age')
  run_tests(dataset)
  run_benchmarks(dataset, n_queries=10000)


: 