# Utilities

## Imports

In [1]:
import torch
import pyro
import numpy as np



  from .autonotebook import tqdm as notebook_tqdm


## Misc

In [2]:

def entropy(dist, samples = False):
    """
        calculates entropy and joint entropy. If samples = True, returns torch array for each distribution sample.
    """
    dist = torch.Tensor(dist)

    # if distribution has 0 values, add small value to avoid log(0)
    dist += 1e-3

    if samples == False:

        if len(dist.shape) == 1:
        # expected value is just a dot product
            return -torch.dot(dist,torch.log2(dist))
        

        # if joint entropy is to be calculated
        else:
            n_variables = len(dist.shape)

            # dimensions of variables to reduce
            dims = tuple([i for i in range(n_variables)])

            return -(dist*torch.log2(dist)).sum(dim = dims)
    

    # assumes first dimension is samples
    else:
        n_variables = len(dist.shape)-1

        # dimensions of variables to reduce
        dims = tuple([i+1 for i in range(n_variables)])


        return -(dist*torch.log2(dist)).sum(dim = dims)



def renyi_entropy(dist, r = 1, samples = False):
    """calculates renyi entropy. Similar to DIT"""    
    dist = torch.Tensor(dist)
    
    
    if r < 0:
        msg = "`order` must be a non-negative real number"
        raise ValueError(msg)
    
    if r == 1:
        return entropy(dist, samples=samples)

    else:

        # if distribution has 0 values, add small value to avoid log(0)
        dist += 1e-3


        # if one dimensional distribution
        if len(dist.shape) == 1:
            return (1/(1-r))*torch.log2(torch.sum(dist**r))


        # assumes first dimension is for the samples
        elif samples:
            n_variables = len(dist.shape)

            # dimensions of variables to reduce
            dims = tuple([i+1 for i in range(n_variables)])
            
            return (1/(1-r))*torch.log2(torch.sum(dist**r, dim=dims))
        

        # joint renyi entropy
        else:
            n_variables = len(dist.shape)

            # dimensions of variables to reduce
            dims = tuple([i for i in range(n_variables)])
            
            return (1/(1-r))*torch.log2(torch.sum(dist**r, dim=dims))
        
            



def calculate__independent_joint(dist1, dist2):
    """ (n_samples, m), (n_samples, n) -> (n_samples, m,n)"""
    
    # squeeze distributions
    dist1 = dist1.squeeze()
    dist2 = dist2.squeeze()

    # adds a dimension for samples when only one sample is available
    if len(dist1.shape) == 1:
        dist1 = dist1.unsqueeze(0)

    if len(dist2.shape) == 1:
        dist2 = dist2.unsqueeze(0)

    # return joint probabilities
    return torch.bmm(dist1.unsqueeze(-1), dist2.unsqueeze(1))


# given joint distribution, calculate conditional
def calculate_conditional(dist, indices):
    
    # marginal distributions where certain dimensions are 1.
    marginals = torch.sum(dist, dim = indices, keepdim = True)

    # expand marginals to have similar shape as before
    marginals = marginals.expand(dist.shape)

    return dist/marginals
    




def common_entropy_oracle(x_y_distribution = torch.Tensor(), n_observed_states = (10,10), n_samples = 1, beta_values = np.linspace(0,0.1,50), iterations = 100):
    """find common entropy of distributions
        input is torch tensor of size n_samples*n*m 
    """

    latent_states = n_observed_states[0]

    # sample observed distributions. this is working properly.
    sampled_observed = x_y_distribution


    cmi_threshold = 0.001



    data = Data(n_observed_states=n_observed_states, 
                latent_entropy_threshold=1, model_type='latent')


    # create a null list of observed samples
    final_joints = torch.zeros(size = (n_samples, ) + n_observed_states + (latent_states,))


    z_entropies = 5*np.ones(n_samples)

    for beta in beta_values:
        latentsearch = LatentSearch(beta = beta, iterations = iterations)

        # initialize q(z|x,y)
        initializations = data.initialize_joint_conditional_distributions(n_samples=n_samples)


        joint = latentsearch.fit(observed_joint=sampled_observed, initialization=initializations)


        YZ = joint.sum(dim=1)
        XZ = joint.sum(dim=2)

        # calculate cmi
        cmi_ = cmi(XZ = XZ, YZ = YZ, XYZ = joint, Z = joint.sum(dim=(1,2)), samples = True)
        

        z = joint.sum(dim=(1,2))



        for i in range(n_samples):

            # check if cmi constraint is satisfied.
            if cmi_[i] <= cmi_threshold:
                
                z_entropy = entropy(z[i,:],samples=True)
                
                # if minimum, keep it.
                if z_entropy<z_entropies[i]:
                    final_joints[i,:,:,:] = joint[i,:,:,:]
                    z_entropies[i] = z_entropy





    #if some entropy never changed, its cmi was not found
    if z_entropies.max()>4:
        print("CMI condition was never satisfied")
        return ValueError



    # returns entropy of z
    return z_entropies



# returns wrong sized distribution. still using it for cmi because it's default
def calc_stats(data, var_size, weights=None):
    """
    Calculate the counts of instances in the data
    :param data: a dataset of categorical features
    :param var_size: a vector defining the cardinalities of the features
    :param weights: a vector of non-negative weights for data samples.
    :return:
    """
    sz_cum_prod = [1]
    for node_i in range(len(var_size) - 1):
        sz_cum_prod += [sz_cum_prod[-1] * var_size[node_i]]

    sz_prod = sz_cum_prod[-1] * var_size[-1]

    data_idx = np.dot(data, sz_cum_prod)



    try:
        # hist_count, _ = np.histogram(data_idx, np.arange(sz_prod + 1), weights=weights)
        hist_count = np.bincount(data_idx, minlength=sz_prod, weights=weights)
    except MemoryError as error:
        print('Out of memory')
        return None
    return hist_count



# calculate distribution from categorical data
def calc_joint_dist(data = np.array([]), states = (10,10)):
    """calculates empirical pdf of data with discrete states

            data = np.array(np.random.choice(3,(5,3)), dtype='int32')
            calc_joint_dist(data, states = (3,3,3))
            
    """
    joint_dist, _ = np.histogramdd(data, bins = states)
    joint_dist /= joint_dist.sum()

    return joint_dist



def cmi(XZ, YZ, XYZ, Z, samples = True):
    """Calculates the conditional mutual information (CMI) I(X;Y|Z)."""

    # Calculate joint and marginal entropies
    H_XYZ = entropy(XYZ, samples = samples)
    H_XZ = entropy(XZ, samples = samples)
    H_Z = entropy(Z, samples = samples)
    H_YZ = entropy(YZ, samples = samples)

    # Calculate CMI
    return H_XZ + H_YZ - H_XYZ - H_Z





## Data Generation


### Latent Search Sampling

In [3]:
# Class for distribution generation
class Data():
    """Samples distributions from a given latent graph"""

    def __init__(self, n_observed_states = (10, 10), latent_entropy_threshold = 1, model_type = 'latent'):

        self.n_latent_states = max(n_observed_states)
        self.n_observed_states = n_observed_states
        self.model_type = model_type
        self.latent_entropy_treshold = latent_entropy_threshold


    # sample from the latent distributions
    def sample_latent(self, n_samples):
        
        # sampling low entropy latent distributions
        iterations = 10

        N_samples_factor = 10

        low_entropy_samples = None

        for i in range(iterations):

            # latent distribution
            alpha = torch.Tensor(self.n_latent_states*[1/(2**i)])
            latent_distribution = torch.distributions.Dirichlet(alpha)


            # sample distribution 10N times
            latent_samples = latent_distribution.sample(sample_shape = (N_samples_factor*n_samples,))

            # calculate entropy for all
            latent_samples_entropy = renyi_entropy(latent_samples, r=1, samples=True)


            # find low entropy samples
            low_entropy_samples = latent_samples[latent_samples_entropy<=self.latent_entropy_treshold, :]

            # if count >= N, break
            if low_entropy_samples.shape[0] >= n_samples:
                return low_entropy_samples[0:n_samples,:]


        # if not found, run error
        raise Exception("did not find enough low entropy samples")



    # sample conditional using method in paper. returns in shape (n,x,y,z)
    # assumes x, y are independent
    def sample_conditional(self, n_samples = 10):

        if self.model_type == 'latent':
            
            # sample conditionals independently
            alpha_x = torch.Tensor(self.n_observed_states[0]*[1])
            alpha_y = torch.Tensor(self.n_observed_states[1]*[1])
            x_z_distribution = torch.distributions.Dirichlet(alpha_x)
            y_z_distribution = torch.distributions.Dirichlet(alpha_y)

            # P(X|Z). has shape n_samples, observed_states, latent_states
            x_z_samples = x_z_distribution.sample(sample_shape = (n_samples, self.n_latent_states,))
            x_z_samples = torch.transpose(x_z_samples, 1,2)

            # P(Y|Z)
            y_z_samples = y_z_distribution.sample(sample_shape = (n_samples, self.n_latent_states,))
            y_z_samples = torch.transpose(y_z_samples, 1,2)

            # P(X,Y|Z). 
            x_z_samples = x_z_samples.unsqueeze(2)
            y_z_samples = y_z_samples.unsqueeze(1)
            P_x_y_given_z = x_z_samples * y_z_samples 

            return P_x_y_given_z



    # sample P(x,y)
    def sample_observed(self, n_samples = 10):

        # P(X,Y,Z) = P(X,Y|Z)*P(Z)
        # has shape (n,X,Y,Z)
        x_y_condition_z = self.sample_conditional(n_samples=n_samples)

        # has shape (n,Z)
        z = self.sample_latent(n_samples = n_samples)

        # dimensions: n,x,y,z
        x_y_z = x_y_condition_z*z.unsqueeze(1).unsqueeze(2)

        # marginalize over z dimension
        return torch.sum(x_y_z, dim=3), z
    



    def initialize_joint_conditional_distributions(self, n_samples = 10):
        """ initialization for latentsearch. returns shape (n, x, y, z)"""
        alpha = torch.Tensor(self.n_latent_states*[1])

        distribution = torch.distributions.Dirichlet(alpha)

        # note that this has shape n, x, y, z still.
        latent_conditional_samples = distribution.sample(sample_shape = (n_samples, self.n_observed_states[0], self.n_observed_states[1]))

        return latent_conditional_samples
    


### Causality-Lab Data Sampling

In [4]:


# sample from the latent distributions
def sample_latent_distributions(n_latent_states, latent_entropy_threshold, n_samples):
    """returns n_samples low_entropy samples from a Dirichlet distribution"""
    # sampling low entropy latent distributions
    iterations = 10

    N_samples_factor = 10

    low_entropy_samples = None

    for i in range(iterations):

        # latent distribution
        alpha = torch.Tensor(n_latent_states*[1/(2**i)])
        latent_distribution = torch.distributions.Dirichlet(alpha)


        # sample distribution 10N times
        latent_samples = latent_distribution.sample(sample_shape = (N_samples_factor*n_samples,))

        # calculate entropy for all
        latent_samples_entropy = renyi_entropy(latent_samples, r=1, samples=True)

        # find low entropy samples
        low_entropy_samples = latent_samples[latent_samples_entropy<=latent_entropy_threshold, :]


        # if count >= N, break
        if low_entropy_samples.shape[0] >= n_samples:
            return low_entropy_samples[0:n_samples,:]


    # if not found, run error
    raise Exception("did not find low entropy samples")



# generate graph
# similar to causality-lab
def sample_data_from_dag_dirichlet(in_dag, num_samples, states):
    """
    :param num_samples: number of samples (dataset records)
    :param states: number of possible states for each variable in the dag
    :return: Sampled dataset in the form of a 2D NumPy array
    """


    # start with empty node. since we're running through the topological order, no problem should ocurr.
    data = np.empty((num_samples, len(in_dag.nodes_set)))


    topological_order = in_dag.find_topological_order()


    
    for node in topological_order:


        # find parents
        parents_set = in_dag.parents(node)


        # if no parent, use low-entropy sampling
        if len(parents_set) == 0:

            # find a distribution. cast np array because it returns torch
            distributions = np.repeat(np.array(sample_latent_distributions(states[node], latent_entropy_threshold=1, n_samples=1)), num_samples, axis =0)

        # use Chickering and Meek Method for sampling with dependencies
        else:

            # find identifiers for each sample
            current_identifiers = find_indentifiers(data, states, parents_set, num_samples)

            # sample dirichlet respecting parent set
            # construct mean vector: 
            mean_vector = np.array([1/(i+1) for i in range(states[node])])
            
            # cyclic shift
            # sample dirichlet with mean vector
            mean_vectors = [np.roll(mean_vector, current_identifiers[i]) for i in range(num_samples)]
            
            # sample dirichlet for mean_vectors[i]
            distributions = [np.random.dirichlet(mean_vectors[i]) for i in range(num_samples)]
            distributions = np.array(distributions).squeeze()



        
        
        # sample from distributions
        data[:,node] = vectorized_sampling(distributions)



    return data.astype(int)






def find_indentifiers(data, states, parents_set, num_samples):
    """Finds identifiers of the distribution conditioned on the parent set. 
    return shape: (n_samples,)"""

    # Generate possible identifiers based on the states of the parents
    total_configurations = np.prod([states[parent] for parent in parents_set])

    # shape (n, m, n, k) for parents with possible states m, n, and k. For better access to value of identifier.
    identifiers = np.arange(total_configurations).reshape((1,) + tuple(states[parent] for parent in parents_set))
    identifiers = np.repeat(identifiers, repeats=num_samples, axis=0)


    parent_values = np.stack([data[:, parent] for parent in parents_set], axis=1)


    samples = np.arange(num_samples)


    # shape: (num_samples, 1 + len(parents_set))
    indices = np.column_stack((samples, parent_values)).astype(int)


    current_identifiers = identifiers[tuple(indices.T)]
    
    return current_identifiers







def vectorized_sampling(distributions):
    """returns samples from given categorical distributions in form (n_samples, n_states)
    source: https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a"""
    
    n_samples = distributions.shape[0]

    # returns n samples from uniform distribution between 0 and 1
    r = np.random.rand(n_samples)
    
    # cumsum(1): cdf
    # argmax: first place when we're above the cdf
    return (distributions.cumsum(1) > r[:,None]).argmax(1)


## Latent Search

In [5]:
# calculates joint (renyi_{1}-common entropy)

class LatentSearch():
    
    def __init__(self, beta = 1, iterations = 100):
        self.beta = beta
        self.data_over_time = []
        self.iterations = iterations


    def fit(self, observed_joint = torch.Tensor([]), initialization = torch.Tensor([])):
        
        """    
        calculates q(z|x,y) from data. Best possible distribution of z given x and y, while keeping x and y's mutual information minimum

        observed joint has shape (n, x, y). Initialization has shape (n, x, y, z)"""

        # if not one sample, add one.
        n_samples = observed_joint.shape[0]
        latent_states = initialization.shape[3]
        x_states = observed_joint.shape[1]
        y_states = observed_joint.shape[2]


        # q(z|x,y)
        z_conditioned_x_y = initialization

        # add dimension to observed. this is for expanding the last dimension
        observed = observed_joint.unsqueeze(-1)



        for i in range(self.iterations):

        # form joint.

            # joint has shape (n,x,y,z)
            joint = z_conditioned_x_y*observed.expand(-1, -1, -1, latent_states)

            # calculates q(z|x), q(z|y), and q(z)
            q_z_x, q_z_y, q_z = self.calculate(joint)

            # numerical convergence
            small_value = 1e-2
            q_z_x += (q_z_x<small_value).float()*small_value
            q_z_y += (q_z_y<small_value).float()*small_value
            q_z += (q_z<small_value).float()*small_value

            # updates q(z|x,y)
            z_conditioned_x_y = self.update(q_z_x, q_z_y, q_z)

            # save update?

        small_value = 1e-4
        z_conditioned_x_y += (z_conditioned_x_y).float()*small_value


        # joint has shape (n,x,y,z)
        joint = z_conditioned_x_y*observed.expand(-1, -1, -1, latent_states) 

    
        return joint

        
    # joint has shape (n,x,y,z)
    def calculate(self, joint):

        # q(z|x). has shape (n, x, 1, z)
        q_z_x = torch.sum(joint, dim=2, keepdim=True)/torch.sum(joint, dim = (2,3), keepdim=True)

        # q(z|y). has shape (n, 1, y, z)
        q_z_y = torch.sum(joint, dim=1, keepdim=True)/torch.sum(joint, dim = (1,3), keepdim=True)

        # q(z). has shape (n, 1, 1, z)
        q_z = torch.sum(joint, dim=(1,2), keepdim=True)


        return q_z_x, q_z_y, q_z



    # update joint distribution
    def update(self, q_z_x, q_z_y, q_z):


        x_states = q_z_x.shape[1]
        y_states = q_z_y.shape[2]
        z_states = q_z.shape[3]

        # q(z|x)q(z|y). has shape (n, x, y, z)
        numerator = q_z_x * q_z_y
        
        # # add term to denominator to avoid division by zero. has shape (n, 1, 1, z)
        denominator = (q_z**(1 - self.beta))
        # denominator = (q_z**(1 - self.beta))

        # has shape (n, x, y, z)
        update = numerator/denominator
    

        # has shape (n, x, y, 1)
        normalization = torch.sum(update, dim = 3, keepdim=True)


        # returns q(z|x,y)
        return (1/normalization) * update


    
        

## EntropicPC

In [6]:
from itertools import combinations, chain, tee
from causality_lab.causal_discovery_utils.constraint_based import LearnStructBase, unique_element_iterator
from causality_lab.graphical_models import PDAG

import numpy as np


class LearnStructPC(LearnStructBase):
    """Entropic Version of the Causality-lab implementation of the PC algorithm"""

    def __init__(self, nodes_set, ci_test, entropic = False, F = False, data = None, states = None):
        super().__init__(PDAG, nodes_set=nodes_set, ci_test=ci_test)
        self.graph.create_complete_graph(nodes_set)  # Create a fully connected graph
        self.overwrite_starting_graph = True  # if True, the sequence at which the CIs are tested affects the result
        
        # maysep
        self.maysep = np.ones((len(nodes_set), len(nodes_set)))
        
        self.entropic = entropic
        self.F = F
        
        
        self.data = data

        # find values, states
        self.states = states

        # find empirical distribution of data
        self.distribution = calc_joint_dist(data, states=self.states)
        

        
    def learn_structure(self):
        """
        Learn a CPDAG (completed partially directed graph) using the PC algorithm
        :return:
        """

        self.learn_skeleton()

        self.orient_v_structures()
        self.graph.convert_bidirected_to_undirected()  # treat bi-directed (spurious) as undirected

        # meek rules
        self.graph.maximally_orient_pattern([1, 2, 3])


    def _exit_cond(self, order):
        """
        Check if the max fan-in is lower or equal to the order (exit-cond. is met)
        :param order: condition set size of the CI-test
        :return: True if exit condition is met
        """
        for node in self.graph.nodes_set:
            if self.graph.fan_in(node) > order:  # if a node have a large enough number of parents, exit cond. is false
                return False
        else:
            return True  # didn't find a node with a large enough number of parents for CI test, so exit




    def learn_skeleton(self):
        cond_indep = self.ci_test.cond_indep

        if self.overwrite_starting_graph:
            source_cpdag = self.graph  # Not a copy!!! thus, edge deletions affect consequent CI queries
        else:
            source_cpdag = self.graph.copy()  # slower, but removes the dependence on the sequence of CI testing

        cond_set_size = 0
        while not self._exit_cond(cond_set_size):
            for node_i, node_j in combinations(source_cpdag.nodes_set, 2):
                if not source_cpdag.is_connected(node_i, node_j):
                    continue

                # maysep condition
                if not self.maysep[node_i][node_j]: 
                    continue

                pot_parents_i = source_cpdag.undirected_neighbors(node_i) - {node_j}
                pot_parents_j = source_cpdag.undirected_neighbors(node_j) - {node_i}
                cond_sets_i = combinations(pot_parents_i, cond_set_size)
                cond_sets_j = combinations(pot_parents_j, cond_set_size)
                cond_sets = unique_element_iterator(  # unique of
                    chain(cond_sets_i, cond_sets_j)  # neighbors of node_i OR neighbors of node_j
                )



                for cond_set in cond_sets:
                    if cond_indep(node_i, node_j, cond_set):

                        # if using entropicPC algorithm
                        if self.entropic:
                            # put cond_sets into a set for index selection
                            all_indices = list(np.arange(len(self.distribution.shape)))                        

                            # calculate distributions
                            cond_set_dist = np.sum(self.distribution, axis = tuple(set(all_indices) - set(cond_set)))
                            nodes_dist = np.sum(self.distribution, axis = tuple(set(all_indices) - {node_i, node_j}))
                            
                            # joint entropy of the conditioning set
                            if len(cond_set) == 0:
                                conditional_set_joint_entropy = 0.1*min(entropy(np.sum(nodes_dist, axis=1)), entropy(np.sum(nodes_dist, axis = 0)))
                                # print("conditional len is 0", conditional_set_joint_entropy)
                            else:
                                conditional_set_joint_entropy = entropy(cond_set_dist)
                                # print("conditional len is not 0", conditional_set_joint_entropy)

  
                            
                            
                            # common entropy of nodes i and j
                            # this won't work if you're dealing with more than 1 sample!!!!
                            common_entropy = common_entropy_oracle(torch.Tensor(nodes_dist).unsqueeze(0), 
                                                                   n_observed_states = (self.states[node_i], self.states[node_j]), 
                                                                   n_samples = 1, beta_values = np.linspace(0,0.1,50), iterations = 500)
                            common_entropy = np.array(common_entropy)[0]

                            # if common entropy condition is satisfied
                            if conditional_set_joint_entropy >= common_entropy:
                                self.graph.delete_edge(node_i, node_j)  # remove directed/undirected edge
                                self.sepset.set_sepset(node_i, node_j, cond_set)
                                break  # stop searching for independence as we found one and updated the graph accordingly

                            elif self.F == True: 
                                # set mayseps of both node directions to False

                                self.maysep[node_j,node_i] = False
                                self.maysep[node_i, node_j] = False


                            # else if common_entropy(x,y) >= 0.8*min(entropy(x), entropy(y))
                            elif common_entropy >= 0.8*min(entropy(np.sum(nodes_dist, axis=1)), entropy(np.sum(nodes_dist, axis = 0))):
                                self.maysep[node_i, node_j] = False
                                self.maysep[node_j, node_i] = False

                        
                        else:                                                        
                            self.graph.delete_edge(node_i, node_j)  # remove directed/undirected edge
                            self.sepset.set_sepset(node_i, node_j, cond_set)
                            break  # stop searching for independence as we found one and updated the graph accordingly



            cond_set_size += 1  # now go again over all the edges and try to remove using a condition set size +1





    def orient_v_structures(self):
        # ToDo: Move this function to the PDAG class
        # create a copy of edges
        pre_neighbors = dict()
        for node in self.graph.nodes_set:
            pre_neighbors[node] = self.graph.undirected_neighbors(node).copy()  # undirected neighbors pre graph changes

        # check each node if it can serve as new collider for a disjoint neighbors
        for node_z in self.graph.nodes_set:
            # check undirected neighbors
            xy_nodes = pre_neighbors[node_z]  # undirected neighbors
            for node_x, node_y in combinations(xy_nodes, 2):
                if self.graph.is_connected(node_x, node_y):
                    continue  # skip this pair as they are connected
                if node_z not in self.sepset.get_sepset(node_x, node_y):
                    self.graph.orient_edge(source_node=node_x, target_node=node_z)  # orient X --> Z
                    self.graph.orient_edge(source_node=node_y, target_node=node_z)  # orient Y --> Z


# Experiments

## Latent Search

In [8]:
n_samples = 200
n_observed_states = (5, 5)
latent_states = n_observed_states[0]
latent_entropy_threshold = 1
model_type = 'latent'

data = Data(n_observed_states=n_observed_states, 
            latent_entropy_threshold=1, model_type='latent')


# sample observed distributions. this is working properly.
sampled_observed, z = data.sample_observed(n_samples=n_samples)



# feed into latentsearch
beta_values = np.linspace(0, 0.1, 50)

joint_distributions = []
cmi_threshold = 0.001




cmi_values = []

# create a null list of observed samples
final_joints = torch.zeros(size = (n_samples, ) + n_observed_states + (latent_states,))

# upper bound for z entropies. This is used to find the minimum z entropy for each sample
z_entropies = 5*np.ones(n_samples)


for beta in beta_values:
    latentsearch = LatentSearch(beta = beta, iterations = 500)

    # initialize q(z|x,y)
    initializations = data.initialize_joint_conditional_distributions(n_samples=n_samples)


    joint = latentsearch.fit(observed_joint=sampled_observed, initialization=initializations)


    YZ = joint.sum(dim=1)
    XZ = joint.sum(dim=2)

    # calculate cmi
    cmi_ = cmi(XZ = XZ, YZ = YZ, XYZ = joint, Z = joint.sum(dim=(1,2)), samples = True)
    

    z = joint.sum(dim=(1,2))


    # for each sample
    for i in range(n_samples):

        # check if cmi constraint is satisfied.
        if cmi_[i] <= cmi_threshold:
            
            # z_entropy = entropy(z[i,:],samples=True)
            z_entropy = entropy(z[i,:], samples=True)
            
            # if minimum, keep it.
            if z_entropy<z_entropies[i]:
                final_joints[i,:,:,:] = joint[i,:,:,:]
                z_entropies[i] = z_entropy





#if some entropy never changed, its cmi was not found
if z_entropies.max()>4:
    print("CMI condition was never satisfied")
    print(z_entropies)




In [9]:
print((z_entropies<1).sum())


188


## EntropicPC

In [None]:
# create dag

from causality_lab.graphical_models import PDAG
from causality_lab.causal_discovery_utils.cond_indep_tests import CondIndepCMI
from causality_lab.causal_discovery_utils.performance_measures import structural_hamming_distance_cpdag as SHD

nodes_set = set(range(3))
dag = DAG(nodes_set)

pdag_copy = PDAG(nodes_set)

# latent graph
dag.add_edges({0}, 1)
dag.add_edges({0}, 2)
# dag.add_edges({1}, 2)

# give it to sample data from dirichlet
states = [5,5,5]
data = sample_data_from_dag_dirichlet(dag, 50000, states=states)


print(data.shape)

# run it.
ci_test = CondIndepCMI(data, threshold=0.001)


entropicpc = LearnStructPC(nodes_set = nodes_set, ci_test = ci_test, entropic = True, F = False, data = data, states = states)



entropicpc.learn_structure()

# to compute SHD, both graphs need to be CPDAG
dag.convert_to_cpdag(pdag_copy)
