In [1]:
import numpy as np
from scipy.special import psi
from math import log, exp

def discrete_continuous_info(d, c, k=3, base=exp(1)):
    """
    Estimates mutual information between a discrete vector 'd' and continuous vectors 'c'.
    
    Args:
        d (np.ndarray): Discrete array of shape (1, N) or (N,)
        c (np.ndarray): Continuous data of shape (M, N) where M is feature dim
        k (int): Number of nearest neighbors
        base (float): Logarithm base (default: natural log)
    
    Returns:
        f (float): Estimated mutual information
        V (np.ndarray): Volume estimates for each sample
    """
    d = np.asarray(d)
    c = np.asarray(c)

    if d.ndim == 1:
        d = d.reshape(1, -1)
    if c.ndim == 1:
        c = c.reshape(1, -1)

    N = c.shape[1]
    symbol_IDs = np.zeros(N, dtype=int)
    first_symbol = []
    c_split = []
    cs_indices = []
    num_d_symbols = 0

    # Bin continuous data by discrete symbols
    for i in range(N):
        found = False
        for j in range(num_d_symbols):
            if d[:, i] == d[:, first_symbol[j]]:
                symbol_IDs[i] = j
                found = True
                break
        if not found:
            symbol_IDs[i] = num_d_symbols
            first_symbol.append(i)
            c_split.append([])
            cs_indices.append([])
            num_d_symbols += 1

        c_split[symbol_IDs[i]].append(c[:, i])
        cs_indices[symbol_IDs[i]].append(i)

    # Convert to numpy arrays
    for i in range(num_d_symbols):
        c_split[i] = np.column_stack(c_split[i])

    m_tot = 0
    av_psi_Nd = 0
    V = np.zeros(N)
    all_c_distances = np.zeros(N)
    psi_ks = 0

    for c_bin in range(num_d_symbols):
        n_bin = c_split[c_bin].shape[1]
        one_k = min(k, n_bin - 1)

        if one_k > 0:
            for pivot in range(n_bin):
                # Compute distances within bin
                c_pivot = c_split[c_bin][:, pivot]
                c_distances = np.linalg.norm(c_split[c_bin] - c_pivot[:, None], axis=0)
                sorted_distances = np.sort(c_distances)
                eps_over_2 = sorted_distances[one_k + 1]  # skip pivot

                # Count total points within volume
                all_distances = np.linalg.norm(c - c_pivot[:, None], axis=0)
                m = max(np.sum(all_distances <= eps_over_2) - 1, 0)

                m_tot += psi(m)
                V[cs_indices[c_bin][pivot]] = (2 * eps_over_2) ** d.shape[0]
        else:
            m_tot += psi(num_d_symbols * 2)

        p_d = n_bin / N
        av_psi_Nd += p_d * psi(p_d * N)
        psi_ks += p_d * psi(max(one_k, 1))

    f = (psi(N) - av_psi_Nd + psi_ks - m_tot / N) / log(base)
    return f, V


In [8]:
import numpy as np
from scipy.spatial import distance
from scipy.special import digamma
from math import log, exp

def discrete_continuous_info_2(d, c, k=3, base=np.e):
    """
    Estimates the mutual information between a discrete vector `d`
    and a continuous matrix `c` using a nearest-neighbor approach.
    
    Parameters:
    d (np.ndarray): 1D array of discrete labels (shape: [n_samples])
    c (np.ndarray): 2D array of continuous data (shape: [n_features, n_samples])
    k (int): number of nearest neighbors (default: 3)
    base (float): log base for the mutual information (default: natural log)
    
    Returns:
    f (float): estimated mutual information
    V (np.ndarray): estimated local volume per point
    """
    n_samples = c.shape[1]
    symbol_IDs = np.zeros(n_samples, dtype=int)
    c_split = {}
    cs_indices = {}
    unique_labels = []
    num_d_symbols = 0

    # Bin the data by discrete labels
    for i in range(n_samples):
        label = d[i]
        if label not in unique_labels:
            unique_labels.append(label)
            num_d_symbols += 1
            c_split[label] = []
            cs_indices[label] = []
        c_split[label].append(c[:, i])
        cs_indices[label].append(i)

    for label in c_split:
        c_split[label] = np.column_stack(c_split[label])  # Convert to numpy array

    m_tot = 0
    av_psi_Nd = 0
    V = np.zeros(n_samples)
    psi_ks = 0

    for label in c_split:
        group_c = c_split[label]
        group_indices = cs_indices[label]
        n_group = group_c.shape[1]
        one_k = min(k, n_group - 1)

        if one_k > 0:
            for pivot in range(n_group):
                pivot_vec = group_c[:, pivot].reshape(-1, 1)
                diffs = group_c - pivot_vec
                distances = np.linalg.norm(diffs, axis=0)
                sorted_distances = np.sort(distances)
                eps_over_2 = sorted_distances[one_k + 1]  # k-th neighbor (skip self)

                # Count neighbors in all data within this radius
                full_diffs = c - pivot_vec
                all_distances = np.linalg.norm(full_diffs, axis=0)
                m = max(np.sum(all_distances <= eps_over_2) - 1, 0)
                m_tot += digamma(m + 1)  # +1 for numerical stability
                V[group_indices[pivot]] = (2 * eps_over_2) ** c.shape[0]
        else:
            m_tot += digamma(num_d_symbols * 2)

        p_d = n_group / n_samples
        av_psi_Nd += p_d * digamma(n_group)
        psi_ks += p_d * digamma(max(one_k, 1))

    f = (digamma(n_samples) - av_psi_Nd + psi_ks - m_tot / n_samples) / log(base)
    return f, V


In [None]:
import numpy as np
from scipy.special import digamma
from sklearn.neighbors import NearestNeighbors
import sys

np.set_printoptions(threshold=sys.maxsize)

def compute_kmu(x, y, per_filter=True, avarage=True, n_neighbors=3):
    """Compute mutual information between continious and discrete variables
    :parameter
    x : ndarray, shape (batch_size, n_filters, height, width)
         4d continious variable,
    y : ndarray,  shape (batch_size, )
        1d discrete variable
    per_filter : bool,
        Whether to calculate mu between each 3d filter and discrete variable, or full 4d tensor and discrete variable
    avarage : bool,
        In case of per_filter=True, average the result or not
    n_neighbors: int,
        Number of nearest neighbors to search for each point
     :returns
     kmu : float, or list of floats (depends on per_filter parameter),
        Estimated mutual information
     """
    if per_filter:
        filters_count = x.shape[1]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        kmu = [mu_approximate(x[:, i, :], y, n_neighbors=n_neighbors) for i in range(filters_count)]
    else:
        x = x.reshape(x.shape[0], -1)
        kmu = mu_approximate(x, y, n_neighbors=n_neighbors)

    if avarage:
        kmu = np.mean(kmu)

    return kmu


def nn_sklearn(x, k):
    """Compute nearest neighbors for each point in the given set using sklearn's NearestNeighbors
    :parameter
    x : ndarray, shape (n_samples, )
        Set of points
    k : int,
        Number of nearest neighbors to search for each point
    :returns
    d : ndarray, shape (n_samples, n_neighbors)
        Distances between the point and each neighbor for each point
    """
    nn = NearestNeighbors(n_neighbors=k+1)
    nn.fit(x)
    d, i = nn.kneighbors(x)
    return d


def binary_search(arr, value):
    """Binary search to find the insertion index in a sorted array"""
    low, high = 0, len(arr)
    while low < high:
        mid = (low + high) // 2
        if arr[mid] < value:
            low = mid + 1
        else:
            high = mid
    return low


def mu_approximate(c, d, n_neighbors):
    """Mutual information calculation based on approximate nearest neighbors
    :parameter
    c : ndarray, shape (n_samples,)
        Samples of a continuous random variable.
    d : ndarray, shape (n_samples,)
        Samples of a discrete random variable.
    n_neighbors : int
        Number of nearest neighbors to search for each point
    :returns
    mi : float
        Estimated mutual information. If it turned out to be negative it is
        replaced by 0.
    Notes
    -----
    True mutual information can't be negative. If its estimate by a numerical
    method is negative, it means (providing the method is adequate) that the
    mutual information is close to 0 and replacing it by 0 is a reasonable
    strategy.
    """
    n_samples = c.shape[0]
    radius = np.empty(n_samples)
    label_counts = np.empty(n_samples)
    k_all = np.empty(n_samples)

    for label in np.unique(d):
        mask = (d == label).reshape(-1)
        count = np.sum(mask)
        if count > n_neighbors + 1:
            k = min(n_neighbors, count - 1)
            dist = nn_sklearn(c[mask, :], k=k)
            radius[mask] = np.nextafter(dist[:, -1], 0)

            k_all[mask] = k
        label_counts[mask] = count

    # Ignore points with unique labels.
    mask = label_counts > 1
    n_samples = np.sum(mask)
    label_counts = label_counts[mask]
    k_all = k_all[mask]
    c = c[mask]
    radius_sklearn = radius[mask]

    # Find nearest neighbors (at max 100) using sklearn
    D = nn_sklearn(c, k=100)
    idc_counts = np.array([max(0, binary_search(D[i], radius_sklearn[i])) for i in range(c.shape[0])])

    mi = (digamma(n_samples) + np.mean(digamma(k_all)) -
          np.mean(digamma(label_counts)) -
          np.mean(digamma(idc_counts + 1)))

    # Mutual information cannot be too high. It means that approximation gave bad results.
    if mi > 100:
        mi = -1
    return max(0, mi)


In [62]:
import numpy as np
import matplotlib.pyplot as plt

# --- Step 1: Generate synthetic data ---
np.random.seed(42)
N = 300  # number of samples
dim = 1  # dimensionality of continuous variable

# Create 3 discrete classes
n_classes = 3
samples_per_class = N // n_classes

c = []
d = []

for i in range(n_classes):
    # Continuous values clustered around different means
    mean = np.random.randn(dim) * 5
    cov = np.eye(dim) * 0.5
    cluster = np.random.multivariate_normal(mean, cov, size=samples_per_class).T
    c.append(cluster)
    d += [i] * samples_per_class

# Stack all continuous samples into a (dim, N) array
c = np.hstack(c)
d = np.array(d).reshape(1, -1)

# --- Step 2: Estimate mutual information ---
mi, V = discrete_continuous_info(d, c, k=3, base=np.e)
print(f"Estimated Mutual Information I(D; C): {mi:.4f}")
mi, V = discrete_continuous_info_2(d.reshape(-1), c, k=3, base=np.e)
print(f"Estimated Mutual Information I(D; C): {mi:.4f}")
mi, V = mu_approximate(c.reshape(-1), d.reshape(-1), 3)
print(f"Estimated Mutual Information I(D; C): {mi:.4f}")

# --- Step 3: Optional - Visualize data ---
plt.figure(figsize=(6, 5))
colors = ['red', 'green', 'blue']
for i in range(n_classes):
    class_data = c[:, d.flatten() == i]
    plt.scatter(class_data[0], class_data[1], label=f'Class {i}', alpha=0.7, color=colors[i])
plt.title("Synthetic 2D Continuous Data by Class")
plt.xlabel("x1")
plt.ylabel("x2")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


Estimated Mutual Information I(D; C): 0.7523
Estimated Mutual Information I(D; C): 0.5042


ValueError: Expected 2D array, got 1D array instead:
array=[2.38580314 2.94155572 3.5605155  2.31799933 2.31801094 3.60024286
 3.02622907 2.15160224 2.86721865 2.15588497 2.1542501  2.65466393
 1.13067733 1.26386967 2.08597344 1.76739101 2.70577718 1.84150078
 1.48492124 3.51994095 2.32392281 2.53132042 1.47612166 2.09863405
 2.56200488 1.6696954  2.74922938 2.05885507 2.27731214 2.05809994
 3.79332923 2.47402679 1.73565619 3.06519785 1.62030394 2.63125963
 1.09787473 1.5444014  2.62277268 3.00574549 2.60474644 2.40179508
 2.2706583  1.43809784 1.97456404 2.15784997 3.23106906 2.72654559
 1.23691312 2.71273274 2.21127647 2.00491463 2.91609122 3.21259752
 3.14208525 1.89015436 2.2649246  2.71780938 3.17338534 2.14474341
 2.35229004 1.7012738  1.63772495 3.05811328 3.44257729 2.43265192
 3.19317568 2.73928605 2.02740221 2.73911605 3.57112685 2.45823793
 3.5899409  0.63113124 3.0647436  2.54512234 2.27214064 2.54845543
 1.07814731 2.32823928 2.73608749 3.52859967 2.11709838 1.91187946
 2.12877496 3.13085781 2.7160329  2.10897373 2.84650565 2.55221496
 3.16850621 1.98714426 2.25187864 2.20630843 1.44870942 2.69295942
 2.66816472 2.48718652 2.31769261 1.48275252].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

In [58]:
batch_size = 100
c = np.repeat(np.random.rand(int(batch_size / 10), 10, 64, 64).astype(np.float32), 10, axis=0)
d = np.random.randint(0, 5, batch_size)

print(c.shape, d.shape)
print(mu_approximate(c, d, 3))


(100, 10, 64, 64) (100,)
[False False False False False False False False False False False False
  True  True False False False False  True False False False  True False
 False False False False False False False  True False False False False
 False  True  True False  True False False False False False False False
 False False False False False False False False False False False False
 False  True False False  True False False  True False False  True False
 False  True False False False False False False False  True False False
 False  True False  True False False False False  True False False False
  True False  True False]


ValueError: Found array with dim 4. NearestNeighbors expected <= 2.

In [4]:
import numpy as np
from scipy.special import psi


def findpt(c_sorted, target):
    """
    Find the "index" of where `target` would go in sorted array `c_sorted`.
    This mimics MATLAB's 0.5 behavior between indices.
    """
    if target < c_sorted[0]:
        return 0.5
    elif target > c_sorted[-1]:
        return len(c_sorted) + 0.5
    else:
        return np.searchsorted(c_sorted, target, side='left')


def discrete_continuous_info_fast(d, c, k=3, base=np.e):
    d = np.asarray(d).flatten()
    c = np.asarray(c).flatten()
    
    assert len(d) == len(c), "d and c must be the same length"

    # Sort c and reorder d accordingly
    idx = np.argsort(c)
    c = c[idx]
    d = d[idx]

    # Bin the continuous data by discrete labels
    unique_labels, symbol_IDs = np.unique(d, return_inverse=True)
    num_d_symbols = len(unique_labels)

    c_split = [[] for _ in range(num_d_symbols)]
    cs_indices = [[] for _ in range(num_d_symbols)]
    
    for i in range(len(d)):
        label_id = symbol_IDs[i]
        c_split[label_id].append(c[i])
        cs_indices[label_id].append(i)

    c_split = [np.array(cs) for cs in c_split]
    V = np.zeros(len(d))
    m_tot = 0
    av_psi_Nd = 0
    psi_ks = 0

    for c_bin in range(num_d_symbols):
        bin_c = c_split[c_bin]
        indices = cs_indices[c_bin]
        bin_len = len(bin_c)
        one_k = min(k, bin_len - 1)

        if one_k > 0:
            for pivot in range(bin_len):
                one_c = bin_c[pivot]
                left = pivot
                right = pivot

                # Find the k-th nearest neighbor (1D)
                for _ in range(one_k):
                    if left == 0:
                        right += 1
                    elif right == bin_len - 1:
                        left -= 1
                    else:
                        if abs(bin_c[left - 1] - one_c) < abs(bin_c[right + 1] - one_c):
                            left -= 1
                        else:
                            right += 1

                # Radius of neighborhood
                distance_to_neighbor = abs(bin_c[right] - one_c) if right > pivot else abs(bin_c[left] - one_c)

                # Count number of total samples in full c within the same radius
                if right > pivot:
                    m = int(np.floor(findpt(c, bin_c[right]) - findpt(c, one_c - distance_to_neighbor)))
                else:
                    m = int(np.floor(findpt(c, one_c + distance_to_neighbor) - findpt(c, bin_c[left])))

                if m < one_k:
                    m = one_k

                m_tot += psi(m)
                V[indices[pivot]] = 2 * distance_to_neighbor
        else:
            m_tot += psi(num_d_symbols * 2)
            V[indices[0]] = 2 * (c[-1] - c[0])

        p_d = bin_len / len(d)
        av_psi_Nd += p_d * psi(p_d * len(d))
        psi_ks += p_d * psi(max(one_k, 1))

    f = (psi(len(d)) - av_psi_Nd + psi_ks - m_tot / len(d)) / np.log(base)
    return f, V


In [7]:
# Example with synthetic data
np.random.seed(1)
N = 300
d = np.random.choice([0, 1, 2], size=N)
c = np.random.randn(N) + d  # continuous variable depends on discrete label

mi, V = discrete_continuous_info_fast(d, c)
print(f"Estimated MI (fast version): {mi:.4f}")

Estimated MI (fast version): 0.4344
