In [104]:
from scipy.sparse.csgraph import minimum_spanning_tree as mst
from scipy.sparse.csgraph import breadth_first_order as bfo
from scipy.special import logsumexp
import numpy as np
import csv

# Calcs mutual information

In [105]:
def compute_mi(cols):
    cols[:, 0] *= 2
    names, occurs = np.unique(np.sum(cols, axis=1), return_counts=True)
    namelen = len(names)
    # if only 2 marginals, variables must be independent so mutual information = 0
    if namelen == 2:
        return 0
    
    # Use small value alpha to substitute the 0-values
    if namelen == 3:
        occurs = np.insert(occurs.astype('float'), np.argmin(np.isin([0,1,2,3], names)),0.01)

        
    probs = np.log(occurs / sum(occurs))

    
    px0 = logsumexp([probs[0], probs[1]])
    px1 = logsumexp([probs[2], probs[3]])
    py0 = logsumexp([probs[0], probs[2]])
    py1 = logsumexp([probs[1], probs[3]])

    mi = np.exp(probs[0]) * (probs[0] - (px0+ py0)) + \
         np.exp(probs[1]) * (probs[1] - (px0+ py1)) + \
         np.exp(probs[2]) * (probs[2] - (px1+ py0)) + \
         np.exp(probs[3]) * (probs[3] - (px1+ py1))
    
    return np.log(mi)

# Calcs the probabilities

In [106]:
def prob_col(col, val, alpha):
    return np.log(2*alpha*len(col[col == val]))-np.log(4*alpha*len(col))


def joint_prob_col(both_cols, val1, val2, alpha):
    return np.log(alpha*len(both_cols[np.sum(both_cols == (val1, val2), axis=1) == 2]))-np.log(4*alpha*len(both_cols))


def cond_prob_col(both_cols, vals, alpha):
    return joint_prob_col(both_cols, vals[0], vals[1], alpha)-prob_col(both_cols[:, 0], vals[0], alpha)

# The actual class

In [107]:
class BinaryCLT:
    def __init__(self, data, root: int = None, alpha: float = 0.01):
        self.cols = data.shape[1]
        self.data = data
        self.root = root
        self.alpha = alpha
        self.tree = self.gettree()
        self.pmfs = self.getlogparams()

    def gettree(self):
        # create the mutual information matrix
        mi_matrix = np.array(
            [[compute_mi(self.data[:, [i, j]]) if j > i else 0 for j in range(self.cols)] for i in range(self.cols)])
        # invert the mutual information
        # to get the maximum spanning tree by calculating the minimum spanning tree of the inverse
        mi_matrix_inv = -mi_matrix
        tree = mst(mi_matrix_inv)
        # add connections to the tree in both directions
        tree = tree.toarray().astype(float)
        tree = tree.T + tree
        if not self.root:
            self.root = np.random.choice(range(0, self.cols))
        clt = bfo(tree, self.root)
        clt = clt[1]
        clt[clt == -9999] = -1
        return clt

    def getlogparams(self):
        pmfs = []
        for i in range(self.cols):

            if self.tree[i] == -1:
                vals = [prob_col(self.data[:, i], j, self.alpha) for j in range(2)]
                pmfs.append([[vals[0], vals[1]], [vals[0], vals[1]]])
            else:
                vals = [cond_prob_col(self.data[:, [i, self.tree[i]]], j, self.alpha) for j in
                        [(0, 0), (0, 1), (1, 0), (1, 1)]]
                pmfs.append([[vals[0], vals[1]], [vals[2], vals[3]]])
        # multiply by 2 so every column adds up to 1
        return pmfs + np.log(2)

    # def logprob(self, x, exhaustive: bool = False):

    # def sample(self, nsamples: int)

# Runs the thing

In [108]:
with open('nltcs.train.data', 'r') as file:
    reader = csv.reader(file, delimiter=',')
    dataset = np.array(list(reader)).astype(int)
mytree = BinaryCLT(dataset, 2)
print(mytree.pmfs)
print(mytree.tree)

[[[-0.15248733 -1.95594879]
  [-1.43703145 -0.27132608]]

 [[-0.14665299 -1.99211664]
  [-1.25232511 -0.33664743]]

 [[-0.26420763 -1.46021695]
  [-0.26420763 -1.46021695]]

 [[-0.20355576 -1.69186732]
  [-1.59337381 -0.22720022]]

 [[-0.06372132 -2.78492763]
  [-1.11394298 -0.39788701]]

 [[-0.0860395  -2.49566009]
  [-1.03002662 -0.4416066 ]]

 [[-0.10763256 -2.28236571]
  [-0.92915381 -0.50234125]]

 [[-0.04676177 -3.08597915]
  [-1.03893093 -0.43669691]]

 [[-0.10082742 -2.34433505]
  [-1.86895965 -0.16757178]]

 [[-0.02595521 -3.66433251]
  [-0.71373993 -0.67296995]]

 [[-0.13953996 -2.03836304]
  [-1.22551868 -0.34758124]]

 [[-0.05226251 -2.97749336]
  [-0.69399166 -0.69230342]]

 [[-0.10390755 -2.31575762]
  [-1.11428943 -0.39771775]]

 [[-0.03786339 -3.29264247]
  [-0.983091   -0.46864903]]

 [[-0.06691135 -2.73765578]
  [-0.8765093  -0.53825391]]

 [[-0.13362015 -2.07882049]
  [-2.34109767 -0.10117148]]]
[ 2  6 -1  5 13  7  2  6  6  7 14 10  8 14 12 12]
