In [4]:
import torch
from abc import ABC, abstractmethod
import numpy as np
from scipy.special import digamma
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from sklearn.neighbors import KDTree, NearestNeighbors

def std_matrix(X):
    """Standardize a tensor along dim 0: subtract the mean, divide by the std.

    Args:
        X: torch.Tensor. Statistics are taken over dim 0, so for a 2-D input
            every column is standardized independently and a 1-D input is
            standardized as a single variable. Non-floating tensors are
            promoted to the default float dtype first, because
            ``torch.mean``/``torch.std`` reject integer input.

    Returns:
        torch.Tensor with the same shape as ``X``. Entries whose standard
        deviation is zero or undefined (constant column, single sample) are
        centered only, so they come out as zeros instead of NaN/inf.
    """
    if not X.is_floating_point():
        X = X.to(torch.get_default_dtype())

    x_means = torch.mean(X, dim=0)
    x_stds = torch.std(X, dim=0)

    # `x_stds > 0` is False for both 0 and NaN, so either case falls back to a
    # divisor of 1 and the column is left merely centered.
    safe_stds = torch.where(x_stds > 0, x_stds, torch.ones_like(x_stds))
    return (X - x_means) / safe_stds

class GPUMIComputer:
    """Column-wise k-nearest-neighbour mutual-information estimator on a torch device.

    For every column j the estimate between ``X[:, j]`` and ``Y[:, j]`` is

        MI_j = psi(N) + psi(k) - mean(psi(nx + 1)) - mean(psi(ny + 1))

    where the radius of a sample is the Chebyshev distance to its k-th nearest
    neighbour in the joint (x, y) plane and ``nx`` / ``ny`` count the other
    samples lying strictly inside that radius along x / y alone.
    """

    def __init__(self, n_neighbors=3, device='cuda', batch_size=1000):
        """Configure the estimator.

        Args:
            n_neighbors: k, the neighbour rank that defines each sample's radius.
            device: preferred torch device (string or ``torch.device``). It is
                used when its backend is available; otherwise the first
                available of ``'cuda'``, ``'mps'``, ``'cpu'`` is used.
            batch_size: number of query rows processed at once. Each batch
                materialises a ``(batch_size, n_samples, n_columns)`` tensor,
                so lower this when memory is tight.

        Raises:
            ValueError: if ``n_neighbors`` or ``batch_size`` is smaller than 1.
        """
        if n_neighbors < 1:
            raise ValueError("n_neighbors must be >= 1")
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.k = n_neighbors
        self.batch_size = batch_size
        self.device = self._resolve_device(device)
        print(self.device)

    @staticmethod
    def _resolve_device(requested):
        """Return ``requested`` if its backend is usable, else the best available device."""
        mps_backend = getattr(torch.backends, 'mps', None)
        available = {
            'cuda': torch.cuda.is_available(),
            'mps': mps_backend is not None and mps_backend.is_available(),
            'cpu': True,
        }
        requested = str(requested)
        # 'cuda:1' -> 'cuda'; unknown backends are treated as unavailable.
        if available.get(requested.split(':')[0], False):
            return requested
        for fallback in ('cuda', 'mps'):
            if available[fallback]:
                return fallback
        print("No GPU backend (CUDA/MPS) is available; falling back to CPU.")
        return 'cpu'

    def _to_device(self, data):
        """Convert array-like ``data`` to a float32 tensor on ``self.device``."""
        # float32 everywhere: the MPS backend has no float64 support.
        return torch.as_tensor(data, dtype=torch.float32).to(self.device)

    @staticmethod
    def _as_columns(t):
        """View a 1-D tensor as a single column; reject anything that is not 1-D/2-D."""
        if t.dim() == 1:
            return t.unsqueeze(1)
        if t.dim() != 2:
            raise ValueError("expected a 1-D or 2-D input, got %d dimensions" % t.dim())
        return t

    def _paired_columns(self, X, Y):
        """Bring X and Y onto the device as equally shaped ``(n, m)`` tensors."""
        X = self._as_columns(self._to_device(X))
        Y = self._as_columns(self._to_device(Y))
        if X.shape != Y.shape:
            raise ValueError(
                "X and Y must have the same shape, got %s and %s"
                % (tuple(X.shape), tuple(Y.shape))
            )
        return X, Y

    @staticmethod
    def _count_within(rows, values, radius):
        """Count, per row and column, the other samples closer than ``radius``.

        ``rows`` is ``(b, m)``, ``values`` is ``(n, m)`` and ``radius`` is
        broadcastable to ``(b, 1, m)``. Returns a ``(b, m)`` integer tensor.
        """
        dist = torch.abs(rows.unsqueeze(1) - values.unsqueeze(0))  # (b, n, m)
        # Strictly inside the radius. Exact duplicates (distance 0) always
        # count, which matches `<= nextafter(radius, 0)` in CPUMIComputer and
        # guarantees the sample itself is included even when radius == 0.
        inside = (dist < radius) | (dist == 0)
        return inside.sum(dim=1) - 1  # drop the sample itself

    def compute_radius(self, X, Y):
        """Distance from every sample to its k-th nearest neighbour in (x, y).

        Args:
            X, Y: array-likes of identical shape ``(n,)`` or ``(n, m)``.

        Returns:
            ``(n, m)`` tensor on ``self.device``; entry ``[i, j]`` is the
            Chebyshev distance from sample i to its k-th nearest neighbour,
            computed independently for each column pair j.

        Raises:
            ValueError: on shape mismatch or when ``n <= n_neighbors``.
        """
        X, Y = self._paired_columns(X, Y)
        n, m = X.shape
        if n <= self.k:
            raise ValueError(
                "need more than n_neighbors=%d samples, got %d" % (self.k, n)
            )

        radius = torch.empty(n, m, dtype=X.dtype, device=self.device)
        for start in range(0, n, self.batch_size):
            stop = min(start + self.batch_size, n)
            dx = torch.abs(X[start:stop].unsqueeze(1) - X.unsqueeze(0))  # (b, n, m)
            dy = torch.abs(Y[start:stop].unsqueeze(1) - Y.unsqueeze(0))
            joint = torch.maximum(dx, dy)  # Chebyshev distance in the (x, y) plane

            # The smallest entry along dim 1 is the sample itself (distance 0),
            # so the k-th neighbour is the (k + 1)-th smallest value.
            nearest, _ = torch.topk(joint, k=self.k + 1, dim=1, largest=False)
            radius[start:stop] = nearest[:, -1, :]
        return radius

    def compute_counts(self, X, Y, radius):
        """Count marginal neighbours strictly inside each sample's radius.

        Args:
            X, Y: array-likes of identical shape ``(n,)`` or ``(n, m)``.
            radius: per-sample radii of shape ``(n, m)``, or ``(n,)`` to apply
                the same radius to every column.

        Returns:
            ``(nx, ny)``: two ``(n, m)`` float tensors holding, for each sample
            and column, how many other samples are closer than the radius
            along x and along y respectively.

        Raises:
            ValueError: on shape mismatch between X and Y.
        """
        X, Y = self._paired_columns(X, Y)
        n, m = X.shape
        radius = self._to_device(radius).reshape(n, -1).unsqueeze(1)  # (n, 1, m or 1)

        nx = torch.zeros(n, m, dtype=X.dtype, device=self.device)
        ny = torch.zeros(n, m, dtype=X.dtype, device=self.device)
        for start in range(0, n, self.batch_size):
            stop = min(start + self.batch_size, n)
            batch_radius = radius[start:stop]
            nx[start:stop] = self._count_within(X[start:stop], X, batch_radius)
            ny[start:stop] = self._count_within(Y[start:stop], Y, batch_radius)
        return nx, ny

    def compute_MI(self, X, Y):
        """Estimate the mutual information between matching columns of X and Y.

        Args:
            X, Y: array-likes (tensor, ndarray, list) of identical shape
                ``(n,)`` or ``(n, m)``. They are moved to ``self.device`` and
                standardized with ``std_matrix`` before estimation.

        Returns:
            For 2-D input a ``(m,)`` tensor with one non-negative estimate per
            column; for 1-D input a 0-dim tensor (so ``.item()`` works).

        Raises:
            ValueError: on shape mismatch or when ``n <= n_neighbors``.
        """
        X = self._to_device(X)
        Y = self._to_device(Y)
        if X.shape != Y.shape:
            raise ValueError(
                "X and Y must have the same shape, got %s and %s"
                % (tuple(X.shape), tuple(Y.shape))
            )
        scalar_output = X.dim() == 1

        # Standardize inputs
        X = std_matrix(self._as_columns(X))
        Y = std_matrix(self._as_columns(Y))

        # Compute radius and counts
        radius = self.compute_radius(X, Y)
        nx, ny = self.compute_counts(X, Y, radius)

        # Float tensors on purpose: digamma is only defined for floating input.
        n_samples = torch.tensor(float(X.shape[0]), device=self.device)
        k = torch.tensor(float(self.k), device=self.device)
        mi = (
            torch.digamma(n_samples)
            + torch.digamma(k)
            - torch.digamma(nx + 1).mean(dim=0)
            - torch.digamma(ny + 1).mean(dim=0)
        )
        # The estimator can dip slightly below zero; MI itself cannot.
        mi = torch.clamp(mi, min=0)
        return mi[0] if scalar_output else mi

class CPUMIComputer:
    """Column-wise k-nearest-neighbour mutual-information estimator on the CPU.

    Each column pair ``(X[:, j], Y[:, j])`` is processed independently with
    scikit-learn neighbour searches, and the columns are dispatched to a
    thread pool.
    """

    def __init__(self, n_neighbors=3):
        """Store k (``n_neighbors``) and size the thread pool to the CPU count."""
        self.k = n_neighbors
        self.threads = mp.cpu_count()

    @staticmethod
    def _as_columns(data):
        """Coerce ``data`` to a 2-D ndarray, viewing 1-D input as one column."""
        arr = np.asarray(data)
        if arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        if arr.ndim != 2:
            raise ValueError("expected a 1-D or 2-D input, got %d dimensions" % arr.ndim)
        return arr

    def process_column(self, X, Y, n_neighbors, col_idx):
        """Estimate the MI between column ``col_idx`` of X and of Y.

        Args:
            X, Y: 2-D arrays with the same number of rows.
            n_neighbors: k used both for the neighbour search and for the
                ``digamma(k)`` term of the estimator.
            col_idx: index of the column pair to process.

        Returns:
            ``(col_idx, mi)`` where ``mi`` is the estimate clipped at 0.
        """
        col_X = X[:, col_idx:col_idx+1]
        col_Y = Y[:, col_idx:col_idx+1]

        # Radius = Chebyshev distance to the k-th neighbour in the joint space.
        # kneighbors() without arguments excludes each point from its own
        # neighbour list. nextafter(..., 0) shrinks the radius by one ulp so
        # the inclusive query_radius below acts as a strict "<".
        xy = np.hstack((col_X, col_Y))
        nn = NearestNeighbors(metric="chebyshev", n_neighbors=n_neighbors)
        nn.fit(xy)
        radius = np.nextafter(nn.kneighbors()[0][:, -1], 0)

        # Marginal neighbour counts; "- 1" removes the query point itself.
        kdx = KDTree(col_X, metric="chebyshev")
        kdy = KDTree(col_Y, metric="chebyshev")
        nx = kdx.query_radius(col_X, radius, count_only=True, return_distance=False) - 1
        ny = kdy.query_radius(col_Y, radius, count_only=True, return_distance=False) - 1

        mi = (
            digamma(len(col_X))
            + digamma(n_neighbors)
            - np.mean(digamma(nx + 1))
            - np.mean(digamma(ny + 1))
        )
        return col_idx, max(0, mi)

    def compute_MI(self, X, Y):
        """Estimate the MI for every column pair of X and Y in parallel.

        Args:
            X, Y: array-likes of identical shape ``(n,)`` or ``(n, m)``;
                1-D input is treated as a single column.

        Returns:
            List of ``m`` estimates, ordered by column index.

        Raises:
            ValueError: if X and Y differ in shape or are not 1-D/2-D.
        """
        X = self._as_columns(X)
        Y = self._as_columns(Y)
        if X.shape != Y.shape:
            raise ValueError(
                "X and Y must have the same shape, got %s and %s" % (X.shape, Y.shape)
            )

        n_cols = X.shape[1]
        mi_results = [None] * n_cols
        if n_cols == 0:
            return mi_results

        # No point in spawning more workers than there are columns.
        workers = max(1, min(self.threads, n_cols))
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(self.process_column, X, Y, self.k, i)
                for i in range(n_cols)
            ]
            for future in futures:
                col_idx, mi = future.result()
                mi_results[col_idx] = mi

        return mi_results

In [7]:
# Example usage
# NOTE(review): OptimizedMIComputer is not defined in the cells shown here; it
# presumably comes from an earlier cell — confirm it survives Restart & Run All.
# Fixed seed so the random draws (and the printed value) are reproducible.
torch.manual_seed(0)

# Two independent standard-normal samples, so the true mutual information is 0;
# an estimate far from 0 would point at a problem in the estimator.
X = torch.randn(1000)
Y = torch.randn(1000)

mi_computer = OptimizedMIComputer(n_neighbors=3, batch_size=1000)
mi = mi_computer.compute_MI(X, Y)
print(f"Mutual Information: {mi.item()}")

mps
Mutual Information: 8.984471321105957
