In [1]:
import time
from collections import defaultdict

import numpy as np
from scipy.special import digamma


class JoinAwareMIEstimator:
    """
    k-nearest-neighbour (Kraskov-Stoegbauer-Grassberger, algorithm 1) estimator
    of the mutual information between two scalar variables.

    Each variable is supplied as a dictionary mapping join_key -> value; only
    keys present in both dictionaries contribute a sample, so the joined table
    never has to be built by the caller.
    """

    def __init__(self, k=3):
        """
        Initialize MI estimator with k nearest neighbors

        Args:
            k: Number of nearest neighbors to use (must be >= 1)

        Raises:
            ValueError: If k is smaller than 1
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self.k = k
        self._valid_keys = None  # Cache for valid join keys
        self._x_values = None  # Cache for valid X values
        self._y_values = None  # Cache for valid Y values

    @staticmethod
    def _paired_values(X, Y, join_key):
        """
        Collect the samples for every key of join_key found in both X and Y

        Returns:
            shared: List of usable keys, in join_key order
            xs, ys: Float arrays of the matching X and Y values, aligned with shared
        """
        shared = [key for key in join_key if key in X and key in Y]
        xs = np.array([X[key] for key in shared], dtype=float)
        ys = np.array([Y[key] for key in shared], dtype=float)
        return shared, xs, ys

    def _compute_distances(self, X, Y, join_key):
        """
        Compute k-nearest neighbor distances without materializing full join

        The distance between two samples is the L-infinity norm in the joint
        (x, y) space, i.e. max(|dx|, |dy|), and a sample is never counted as
        its own neighbor.

        Args:
            X: Dictionary mapping join_key -> x_value
            Y: Dictionary mapping join_key -> y_value
            join_key: Iterable of join keys; keys missing from X or Y are ignored

        Returns:
            epsilon: Array with one k-th nearest neighbor distance per usable
                key (in join_key order), or an empty array when there are not
                more than k usable samples
        """
        _, xs, ys = self._paired_values(X, Y, join_key)
        n_points = xs.size
        # Every sample needs k neighbours other than itself.
        if n_points <= self.k:
            return np.empty(0)

        epsilon = np.empty(n_points)
        for i in range(n_points):
            joint = np.maximum(np.abs(xs - xs[i]), np.abs(ys - ys[i]))
            joint[i] = np.inf  # exclude the sample itself
            # k-th smallest distance without sorting the whole row
            epsilon[i] = np.partition(joint, self.k - 1)[self.k - 1]
        return epsilon

    def _count_neighbors(self, X, Y, join_key, epsilon):
        """
        Count points within epsilon radius for each dimension

        A point counts as a neighbor when its marginal distance is strictly
        smaller than the sample's radius; the sample itself is not counted.

        Args:
            X, Y: Dictionaries mapping join_key -> value
            join_key: Iterable of join keys; keys missing from X or Y are ignored
            epsilon: Array of radii, one per usable key in join_key order

        Returns:
            nx, ny: Integer arrays containing number of neighbors in each dimension

        Raises:
            ValueError: If epsilon does not hold exactly one radius per usable key
        """
        _, xs, ys = self._paired_values(X, Y, join_key)
        epsilon = np.asarray(epsilon, dtype=float)
        if epsilon.shape != xs.shape:
            raise ValueError(
                f"epsilon has shape {epsilon.shape}, expected {xs.shape} "
                "(one radius per key present in both X and Y)"
            )

        n_points = xs.size
        nx = np.zeros(n_points, dtype=int)
        ny = np.zeros(n_points, dtype=int)
        for i in range(n_points):
            x_close = np.abs(xs - xs[i]) < epsilon[i]
            y_close = np.abs(ys - ys[i]) < epsilon[i]
            # Drop the sample itself regardless of whether its radius is zero.
            x_close[i] = False
            y_close[i] = False
            nx[i] = np.count_nonzero(x_close)
            ny[i] = np.count_nonzero(y_close)
        return nx, ny

    def estimate(self, X, Y, join_key=None):
        """
        Estimate mutual information between X and Y using k-nearest neighbors

        Args:
            X: Dictionary mapping join_key -> x_value
            Y: Dictionary mapping join_key -> y_value
            join_key: Optional iterable of join keys. Keys that are missing from
                X or Y are ignored. If None, every key of X that also appears
                in Y is used.

        Returns:
            mi: Estimated mutual information value (in nats), clipped at 0.
                Returns 0.0 when there are not more than k usable samples.
        """
        # Compute valid keys if not provided (X's insertion order keeps the
        # result independent of set iteration order)
        if join_key is None:
            join_key = [key for key in X if key in Y]

        # Cache valid keys and values
        shared, xs, ys = self._paired_values(X, Y, join_key)
        self._valid_keys = set(shared)  # Convert to set for O(1) lookup
        self._x_values = xs
        self._y_values = ys

        # Get k-nearest distances
        epsilon = self._compute_distances(X, Y, shared)
        if epsilon.size == 0:
            return 0.0

        # Count neighbors within epsilon radius
        nx, ny = self._count_neighbors(X, Y, shared, epsilon)

        # KSG estimate: psi(k) + psi(N) - <psi(nx + 1) + psi(ny + 1)>
        n = epsilon.size
        mi = digamma(self.k) + digamma(n) - np.mean(digamma(nx + 1) + digamma(ny + 1))

        return max(0.0, float(mi))  # MI should be non-negative


def generate_test_data(n_samples=1000, correlation=0.5, noise=0.1):
    """
    Generate correlated normal samples together with their analytic MI

    Draws n_samples points from a standard bivariate normal with the given
    correlation, then adds independent zero-mean Gaussian noise with standard
    deviation `noise` to each coordinate. Uses NumPy's global random state.

    Args:
        n_samples: Number of samples to draw
        correlation: Correlation of the bivariate normal before noise is added
        noise: Standard deviation of the independent noise added to x and y

    Returns:
        X: Dictionary mapping 'key{i}' -> x value
        Y: Dictionary mapping 'key{i}' -> y value
        theoretical_mi: Mutual information (in nats) of the returned noisy
            pair; infinite when the pair is perfectly correlated
    """
    # Generate correlated normal data
    mean = [0, 0]
    cov = [[1, correlation], [correlation, 1]]
    x, y = np.random.multivariate_normal(mean, cov, n_samples).T

    # Add some noise
    x = x + np.random.normal(0, noise, n_samples)
    y = y + np.random.normal(0, noise, n_samples)

    # Convert to dictionary format with keys
    X = {f'key{i}': val for i, val in enumerate(x)}
    Y = {f'key{i}': val for i, val in enumerate(y)}

    # The noise raises each variance from 1 to 1 + noise^2 while leaving the
    # covariance untouched, so the data's correlation is smaller than requested.
    effective_correlation = correlation / (1.0 + noise**2)

    # Theoretical MI for bivariate normal is -0.5 * log(1 - rho^2)
    if abs(effective_correlation) >= 1.0:
        theoretical_mi = np.inf
    else:
        theoretical_mi = -0.5 * np.log(1 - effective_correlation**2)

    return X, Y, theoretical_mi


def test_mi_estimators(n_samples=1000, correlations=(0.0, 0.3, 0.6, 0.9), k=3):
    """
    Print a comparison table of MI estimates for several correlations

    For every correlation a fresh dataset is generated and the mutual
    information is estimated with sklearn's mutual_info_regression and with
    JoinAwareMIEstimator; the table shows both next to the analytic value and
    the ratio of the two wall-clock times.

    Args:
        n_samples: Number of samples generated per correlation
        correlations: Iterable of correlations to test
        k: Number of nearest neighbors passed to both estimators
    """
    # NOTE(review): mutual_info_regression is not imported in the visible
    # cell; it presumably comes from sklearn.feature_selection - confirm it is
    # in scope before running.
    column_titles = ('Correlation', 'Theoretical', 'Sklearn', 'JoinAware', 'Time Ratio')

    print("\nTesting MI Estimators")
    print("-" * 80)
    print(" | ".join(f"{title:^12}" for title in column_titles))
    print("-" * 80)

    for correlation in correlations:
        # Generate test data
        X_dict, Y_dict, theoretical_mi = generate_test_data(
            n_samples, correlation)
        join_keys = list(X_dict.keys())

        # Convert to array format for sklearn
        X_arr = np.array([X_dict[key] for key in join_keys]).reshape(-1, 1)
        y_arr = np.array([Y_dict[key] for key in join_keys])

        # Time sklearn implementation
        t0 = time.perf_counter()
        sklearn_mi = mutual_info_regression(X_arr, y_arr, n_neighbors=k)[0]
        sklearn_time = time.perf_counter() - t0

        # Time our implementation
        t0 = time.perf_counter()
        estimator = JoinAwareMIEstimator(k=k)
        our_mi = estimator.estimate(X_dict, Y_dict, join_keys)
        our_time = time.perf_counter() - t0

        # A zero reference time would make the ratio undefined
        time_ratio = our_time / sklearn_time if sklearn_time > 0 else float("inf")

        # Print results
        row = (correlation, theoretical_mi, sklearn_mi, our_mi, time_ratio)
        print(" | ".join(f"{value:^12.3f}" for value in row))


if __name__ == "__main__":
    # Seed NumPy's global generator so repeated runs print the same tables
    np.random.seed(42)

    # Sweep the sample size while keeping the default k
    for n_samples in (100, 1000, 10000):
        print(f"\nTesting with {n_samples} samples:")
        test_mi_estimators(n_samples=n_samples)

    # Sweep the neighbour count at a fixed sample size
    print("\nTesting with different k values (1000 samples):")
    for k in (1, 3, 5, 10):
        print(f"\nk = {k}:")
        test_mi_estimators(n_samples=1000, k=k)

In [None]:
# Build an example dataset in this cell so it runs on a fresh kernel
# (X_dict / Y_dict are otherwise only locals of test_mi_estimators)
np.random.seed(42)
X_dict, Y_dict, theoretical_mi = generate_test_data(n_samples=500)

# Method 1: Let it compute valid keys
estimator = JoinAwareMIEstimator(k=3)
mi = estimator.estimate(X_dict, Y_dict)

# Method 2: Provide pre-computed join keys
valid_keys = list(set(X_dict.keys()) & set(Y_dict.keys()))
mi = estimator.estimate(X_dict, Y_dict, join_key=valid_keys)

# Method 3: Provide specific subset of keys
# The subset must use keys that exist in the data and be larger than k,
# otherwise the estimate is 0.0
important_keys = [f'key{i}' for i in range(100)]  # Only compute MI for these keys
mi = estimator.estimate(X_dict, Y_dict, join_key=important_keys)