# In [103]:
from itertools import combinations
import numpy as np
from torch.utils.data import IterableDataset

from npeet import entropy_estimators as ee
import gcmi

from tqdm.notebook import tqdm

class systemPartsDataset(IterableDataset):
    def __init__(self, X, order):
        """
            X = T samples x N features
        """
        
        self.X = X
        self.order = order
        self.N = self.X.shape[1]
        self.linparts_generator = combinations(range(self.N), self.order)

    def __len__(self):
        n_fact = np.math.factorial(self.N)
        r_fact = np.math.factorial(self.order)
        n_r_fact = np.math.factorial(self.N - self.order)
        return int(n_fact / (r_fact * n_r_fact))

    def __iter__(self):
        """
            # T x N
        """
        for part in self.linparts_generator:
            yield part, len(part), self.X[:,part]


def o_information(X: np.ndarray, allMin1: np.ndarray, entropy_func):
    """Compute the O-information of a multivariate sample.

    Evaluates  (N - 2) * H(X) + sum_j H(X_j) - sum_j H(X_{-j}),
    where X_j is column j alone and X_{-j} is every column except j.

    Args:
        X: array of T samples x N variables.
        allMin1: N x N boolean mask with False on the diagonal; row j selects
            every column except j. It is taken as a parameter (rather than
            rebuilt here) so callers can reuse it across many calls.
        entropy_func: callable mapping a (T, k) array to an entropy value.

    Returns:
        The combination of entropies above, as produced by ``entropy_func``.
    """
    n_vars = X.shape[1]

    # entropy_func is called in a fixed order: joint, then each single
    # column, then each leave-one-out subset.
    h_joint = entropy_func(X)
    marginal_terms = [entropy_func(X[:, [col]]) for col in range(n_vars)]
    leave_one_out_terms = [entropy_func(X[:, keep]) for keep in allMin1]

    h_marginals = sum(marginal_terms)
    h_leave_one_out = sum(leave_one_out_terms)
    return (n_vars - 2) * h_joint + h_marginals - h_leave_one_out


def multi_order_meas(data, entropy_func, min_n=2, max_n=None):
    """Compute the O-information of every n-plet of variables, per order.

    For each order ``n`` in ``[min_n, max_n]`` every combination of ``n``
    columns of ``data`` is evaluated with ``o_information``.

    Args:
        data: T samples x N variables matrix.
        entropy_func: callable mapping a (T, k) array to an entropy value;
            forwarded to ``o_information``.
        min_n: smallest n-plet size to evaluate (inclusive).
        max_n: largest n-plet size to evaluate (inclusive); defaults to N.

    Returns:
        dict mapping each order to a 1-D float array of O-information values.
        Entry ``k`` of ``result[order]`` belongs to the ``k``-th tuple of
        ``itertools.combinations(range(N), order)`` (lexicographic order), so
        the variable indices can be reconstructed without storing them.
    """
    from math import comb

    N = np.shape(data)[1]
    n = N if max_n is None else max_n

    assert n <= N, f"max_n must be lower or equal than the number of variables. {n} > {N}"
    assert min_n <= n, f"min_n must be lower or equal than max_n. {min_n} > {n}"

    results = {}

    # To compute using pytorch, we need to compute each order separately
    for order in range(min_n, n + 1):

        # Row j selects every column except j (False on the diagonal).
        allmin1 = ~np.eye(order, dtype=bool)
        dataset = systemPartsDataset(data, order)
        n_nplets = comb(N, order)
        o_infos = np.empty(n_nplets)
        pbar = tqdm(enumerate(dataset), total=n_nplets)

        for i, (_indexes, _size, X) in pbar:

            # X = T samples x order features

            # refresh=False lets tqdm throttle redraws instead of forcing a
            # redraw on every one of possibly millions of iterations.
            pbar.set_description(
                f'Processing order {order} - {i}: computing nplets', refresh=False
            )
            o_infos[i] = o_information(X, allmin1, entropy_func)

        results[order] = o_infos

    return results

# In [None]:
# Entropy estimator from npeet, exposed under a name usable as ``entropy_func``.
entropy_npeet = ee.entropy


def entropy_gcmi(X: np.ndarray):
    """Entropy of a (T samples x N variables) array via ``gcmi.ent_g``.

    ``gcmi.ent_g`` is called with variables along the first axis, so the
    samples-by-variables input is transposed before being passed on.
    NOTE(review): the units/estimator are whatever ``gcmi.ent_g`` provides
    with its default arguments -- confirm against the gcmi documentation.
    """
    variables_by_samples = X.T
    return gcmi.ent_g(variables_by_samples)

# In [None]:
X = np.random.rand(1000, 34)
# multi_order_meas requires an entropy estimator; omitting it raises TypeError.
# NOTE: C(34, 6) = 1,344,904 six-variable n-plets, so this is a long run.
results = multi_order_meas(X, entropy_func=entropy_gcmi, min_n=6, max_n=6)