In [3]:
from scipy.sparse.csgraph import minimum_spanning_tree as mst
from scipy.sparse.csgraph import breadth_first_order as bfo
from scipy.special import logsumexp
import numpy as np
import csv

# Calcs mutual information

In [4]:
def compute_mi(cols):
    cols[:, 0] *= 2
    names, occurs = np.unique(np.sum(cols, axis=1), return_counts=True)
    namelen = len(names)
    # if only 2 marginals, variables must be independent so mutual information = 0
    if namelen == 2:
        return 0
    
    # Use small value alpha to substitute the 0-values
    if namelen == 3:
        occurs = np.insert(occurs.astype('float'), np.argmin(np.isin([0,1,2,3], names)),0.01)

        
    probs = np.log(occurs / sum(occurs))

    
    px0 = logsumexp([probs[0], probs[1]])
    px1 = logsumexp([probs[2], probs[3]])
    py0 = logsumexp([probs[0], probs[2]])
    py1 = logsumexp([probs[1], probs[3]])

    mi = np.exp(probs[0]) * (probs[0] - (px0+ py0)) + \
         np.exp(probs[1]) * (probs[1] - (px0+ py1)) + \
         np.exp(probs[2]) * (probs[2] - (px1+ py0)) + \
         np.exp(probs[3]) * (probs[3] - (px1+ py1))
    
    return np.log(mi)

# Calcs the probabilities

In [5]:
def prob_col(col, val, alpha):
    return np.log(2*alpha*len(col[col == val]))-np.log(4*alpha*len(col))


def joint_prob_col(both_cols, val1, val2, alpha):
    return np.log(alpha*len(both_cols[np.sum(both_cols == (val1, val2), axis=1) == 2]))-np.log(4*alpha*len(both_cols))


def cond_prob_col(both_cols, vals, alpha):
    return joint_prob_col(both_cols, vals[0], vals[1], alpha)-prob_col(both_cols[:, 0], vals[0], alpha)

# The actual class

In [6]:
class BinaryCLT:
    def __init__(self, data, root: int = None, alpha: float = 0.01):
        self.cols = data.shape[1]
        self.data = data
        self.root = root
        self.alpha = alpha
        self.tree = self.gettree()
        self.pmfs = self.getlogparams()

    def gettree(self):
        # create the mutual information matrix
        mi_matrix = np.array(
            [[compute_mi(self.data[:, [i, j]]) if j > i else 0 for j in range(self.cols)] for i in range(self.cols)])
        # invert the mutual information
        # to get the maximum spanning tree by calculating the minimum spanning tree of the inverse

        tree = mst(-mi_matrix)
        # add connections to the tree in both directions
        tree = tree.toarray().astype(float)
        tree = tree.T + tree
        if not self.root:
            self.root = np.random.choice(range(0, self.cols))
        clt = bfo(tree, self.root)
        clt = clt[1]
        clt[clt == -9999] = -1
        return clt

    def getlogparams(self):
        pmfs = []
        for i in range(self.cols):

            if self.tree[i] == -1:
                vals = [prob_col(self.data[:, i], j, self.alpha) for j in range(2)]
                pmfs.append([[vals[0], vals[1]], [vals[0], vals[1]]])
            else:
                vals = [cond_prob_col(self.data[:, [i, self.tree[i]]], j, self.alpha) for j in
                        [(0, 0), (0, 1), (1, 0), (1, 1)]]
                pmfs.append([[vals[0], vals[1]], [vals[2], vals[3]]])
        # multiply by 2 so every column adds up to 1
        return pmfs + np.log(2)

    # def logprob(self, x, exhaustive: bool = False):

    # def sample(self, nsamples: int)

# Runs the thing

In [7]:
with open('nltcs.train.data', 'r') as file:
    reader = csv.reader(file, delimiter=',')
    dataset = np.array(list(reader)).astype(int)
mytree = BinaryCLT(dataset, 2)
print(mytree.pmfs)
print(mytree.tree)

[[[-0.49047817 -0.94760991]
  [-1.35016836 -0.30002003]]

 [[-0.18194533 -1.79364271]
  [-0.80714178 -0.59082778]]

 [[-0.26420763 -1.46021695]
  [-0.26420763 -1.46021695]]

 [[-0.12551216 -2.13745242]
  [-0.48268729 -0.96004089]]

 [[-0.0732722  -2.6499864 ]
  [-0.49430542 -0.9415943 ]]

 [[-0.10188142 -2.33445398]
  [-0.52556476 -0.89458153]]

 [[-0.1470204  -1.98979365]
  [-0.84007761 -0.56506539]]

 [[-0.12490777 -2.14198353]
  [-0.66667151 -0.72034292]]

 [[-0.15454545 -1.94354478]
  [-0.98797112 -0.46574286]]

 [[-0.05952958 -2.85089916]
  [-0.41309193 -1.08353095]]

 [[-0.17299176 -1.8397606 ]
  [-0.60530675 -0.78945313]]

 [[-0.08868607 -2.46666773]
  [-0.46764288 -0.98477626]]

 [[-0.07772157 -2.59323152]
  [-0.54521209 -0.8668313 ]]

 [[-0.03786339 -3.29264247]
  [-0.983091   -0.46864903]]

 [[-0.14607542 -1.99578101]
  [-0.87977372 -0.53593282]]

 [[-0.22328162 -1.60888585]
  [-2.29904944 -0.10575414]]]
[11 10 -1 10 10 10 10 10 10 10  2  1  0 14  6 14]
