In [1]:
import numpy as np
from pomegranate import *
import itertools

In [4]:
class MarkovBlanket():
    """Markov blanket of one hidden node in a pomegranate BayesianNetwork.

    Nodes are referred to by their integer index in ``model.states``.

    Attributes
    ----------
    hidden : int
        index of the hidden node
    parents : list of int
        indices of the hidden node's parents
    children : list of int
        indices of the hidden node's children
    coparents : list of int
        indices of the other parents of the hidden node's children
    prob_table : dict
        maps the hidden node and each child to a dict of parameters. For a
        ConditionalProbabilityTable the inner keys are tuples of each table
        row without its trailing probability (presumably parent values then
        the node's own value); otherwise they are the node's values.
    """

    def __init__(self, ind_h):
        self.hidden = ind_h
        self.parents = []
        self.children = []
        self.coparents = []
        self.prob_table = {}

    def populate(self, model):
        """Populate the parents, children and coparents from the model's edges."""
        state_indices = {state.name: i for i, state in enumerate(model.states)}
        edges_list = [(state_indices[parent.name], state_indices[child.name])
                      for parent, child in model.edges]

        self.children = list(set([child for parent, child in edges_list if parent == self.hidden]))
        self.parents = list(set([parent for parent, child in edges_list if child == self.hidden]))
        self.coparents = list(set([parent for parent, child in edges_list if child in self.children]))
        # the hidden node is trivially a parent of its own children
        if self.hidden in self.coparents:
            self.coparents.remove(self.hidden)

    def calculate_prob(self, model):
        """Copy the current parameters of the hidden node and its children into prob_table."""
        for ind_state in [self.hidden] + self.children:
            distribution = model.states[ind_state].distribution

            if isinstance(distribution, ConditionalProbabilityTable):
                table = list(distribution.parameters[0])  # make a copy
                self.prob_table[ind_state] = {
                    tuple(row[:-1]): row[-1] for row in table}
            else:
                self.prob_table[ind_state] = dict(distribution.parameters[0])  # make a copy

    def update_prob(self, model, expected_counts, ct):
        """M-step: re-estimate prob_table from the expected counts.

        Every Markov-blanket configuration ``mb_key`` (hidden value first, then
        parents, children, coparents) is weighted by
        ``ct.table[mb_key[1:]] * expected_counts.counts[mb_key]``. For a
        conditional table the new parameter is the weight of the node+parents
        configuration divided by the weight of the parents configuration. For
        an unconditional distribution it is the summed weight divided by the
        number of data rows. A zero denominator yields 0.
        """
        ind = {x: i for i, x in enumerate(
            [self.hidden] + self.parents + self.children + self.coparents)}

        # Weight of each configuration, computed once and shared by all nodes
        # instead of being re-derived for every parameter.
        weights = [(mb_key, ct.table[mb_key[1:]] * count)
                   for mb_key, count in expected_counts.counts.items()]

        for ind_state in [self.hidden] + self.children:
            distribution = model.states[ind_state].distribution
            table = self.prob_table[ind_state]  # dict, updated in place

            if isinstance(distribution, ConditionalProbabilityTable):
                positions = [ind[x] for x in distribution.column_idxs]

                # Accumulate marginal counts of (parents..., node) and of
                # (parents...) in a single pass over the configurations.
                num = {}
                denom = {}
                for mb_key, weight in weights:
                    node_key = tuple(mb_key[p] for p in positions)
                    num[node_key] = num.get(node_key, 0) + weight
                    denom[node_key[:-1]] = denom.get(node_key[:-1], 0) + weight

                for key in table:
                    d = denom.get(key[:-1], 0)
                    # explicit check: numpy floats divide by zero to nan
                    # instead of raising ZeroDivisionError
                    table[key] = num.get(key, 0) / d if d else 0

            else:  # DiscreteDistribution
                pos = ind[ind_state]
                totals = {}
                for mb_key, weight in weights:
                    totals[mb_key[pos]] = totals.get(mb_key[pos], 0) + weight

                for key in table:
                    table[key] = totals.get(key, 0) / ct.size if ct.size else 0

In [5]:
class ExpectedCounts():
    """Calculate the expected counts using the model parameters

    Parameters
    ----------
    model : a BayesianNetwork object

    mb : a MarkovBlanket object

    Attributes
    ----------
    counts : dict
        a dict keyed by Markov-blanket configuration
        (hidden value, parent values..., child values..., coparent values...).
        After ``update`` each value is the probability of the hidden value
        given the rest of the key.
    """

    def __init__(self, model, mb):
        self.counts = {}

        self.populate(model, mb)

    def populate(self, model, mb):
        """Initialise every Markov-blanket configuration with a count of 0."""
        keys_list = [model.states[ind].distribution.keys()
                     for ind in [mb.hidden] + mb.parents + mb.children + mb.coparents]

        self.counts = {p: 0 for p in itertools.product(*keys_list)}

    def update(self, model, mb):
        """E-step: set counts[key] to P(hidden = key[0] | key[1:]).

        The joint probability of each configuration is the product of the
        hidden node's and its children's parameters in ``mb.prob_table``.
        It is then normalised over the hidden values that share the same
        visible part ``key[1:]``. Configurations whose marginal is zero get 0.
        """
        ind = {x: i for i, x in enumerate([mb.hidden] + mb.parents + mb.children + mb.coparents)}
        nodes = [mb.hidden] + mb.children

        marginal_prob = {}

        # calculate joint probability and marginal probability
        for key in self.counts:
            prob = 1
            for ind_state in nodes:
                distribution = model.states[ind_state].distribution

                if isinstance(distribution, ConditionalProbabilityTable):
                    state_key = tuple(key[ind[x]] for x in distribution.column_idxs)
                else:
                    state_key = key[ind[ind_state]]

                prob = prob * mb.prob_table[ind_state][state_key]

            self.counts[key] = prob
            marginal_prob[key[1:]] = marginal_prob.get(key[1:], 0) + prob

        # divide the joint prob by the marginal prob to get the conditional;
        # explicit zero check because numpy floats give nan instead of raising
        for key in self.counts:
            marginal = marginal_prob[key[1:]]
            self.counts[key] = self.counts[key] / marginal if marginal else 0

In [6]:
class CountTable():
    """Count the data rows per configuration of the visible Markov-blanket nodes.

    Attributes
    ----------
    table : dict
        maps every tuple of (parent values..., child values..., coparent values...)
        to the number of rows of ``items`` equal to it (0 if never seen)
    ind : dict
        same keys as ``table``; maps each to the list of row indices of ``items``
    size : int
        number of rows in ``items``
    """

    def __init__(self, model, mb, items):
        """
        Parameters
        ----------
        model : BayesianNetwork object

        mb : MarkovBlanket object

        items : ndarray
            columns are data for parents, children, coparents

        Raises
        ------
        KeyError
            if a row is not one of the configurations built from the nodes'
            ``distribution.keys()``; the offending row tuple is the key.
        """
        self.table = {}
        self.ind = {}
        self.size = items.shape[0]

        self.populate(model, mb, items)

    def populate(self, model, mb, items):
        """Initialise all configurations to zero, then count the rows of ``items``."""
        keys_list = [model.states[ind].distribution.keys()
                     for ind in mb.parents + mb.children + mb.coparents]
        all_keys = list(itertools.product(*keys_list))

        # init
        self.table = {p: 0 for p in all_keys}
        self.ind = {p: [] for p in all_keys}

        # count
        for i, row in enumerate(items):
            key = tuple(row)
            if key not in self.table:
                print ('Items in row', i, 'does not match the set of keys.')
                # carry the offending row in the exception instead of a bare KeyError
                raise KeyError(key)
            self.table[key] += 1
            self.ind[key].append(i)

In [7]:
def em_bayesnet(model, data, ind_h, max_iter = 50, criteria = 0.005):
    """Returns the data array with the hidden node filled in.
    (model is not modified; ``data`` IS modified in place and returned.)

    Runs EM on the Markov blanket of the hidden node. It then fills the hidden
    column by sampling each row's label from the converged conditional
    distribution P(hidden | visible Markov-blanket values).

    Parameters
    ----------
    model : a BayesianNetwork object
        an already baked BayesianNetwork object with initialized parameters

    data : an ndarray
        each column is the data for the node in the same order as the nodes in the model
        the hidden node should be a column of NaNs

    ind_h : int
        index of the hidden node

    max_iter : int
        maximum number of iterations

    criteria : float between 0 and 1
        iteration stops once no parameter of the hidden node changes by
        ``criteria`` or more between consecutive iterations

    Returns
    -------
    data : an ndarray
        the same data array with the hidden node column filled in
    """

    # create the Markov blanket object for the hidden node
    mb = MarkovBlanket(ind_h)
    mb.populate(model)
    mb.calculate_prob(model)

    # create the count table from data
    items = data[:, mb.parents + mb.children + mb.coparents]
    ct = CountTable(model, mb, items)

    # create expected counts
    expected_counts = ExpectedCounts(model, mb)
    expected_counts.update(model, mb)

    # ---- iterate over the E-M steps
    i = 0
    previous_params = list(mb.prob_table[mb.hidden].values())
    convergence = False

    while (not convergence) and (i < max_iter):
        mb.update_prob(model, expected_counts, ct)  # M-step
        expected_counts.update(model, mb)           # E-step

        # Compare every parameter of the hidden node (not just the first two)
        # so hidden nodes with any number of states are supported.
        hidden_params = list(mb.prob_table[mb.hidden].values())
        change = max((abs(new - old) for new, old in zip(hidden_params, previous_params)),
                     default=0.0)
        convergence = change < criteria
        previous_params = hidden_params

        i += 1

    # only warn when the loop really stopped without converging
    if not convergence:
        print ('Maximum iterations reached.')

    # ---- fill in the hidden node data by sampling the distribution
    # group (hidden value, probability) pairs by the visible part of the key
    labels = {}
    for key, prob in expected_counts.counts.items():
        labels.setdefault(key[1:], []).append((key[0], prob))

    for key, n_rows in ct.table.items():
        if n_rows == 0:
            continue
        label, prob = zip(*labels[key])
        prob = np.asarray(prob, dtype=float)
        total = prob.sum()
        if total == 0:
            continue
        # Normalise rather than round: rounding can make p stop summing to 1,
        # which np.random.choice rejects.
        samples = np.random.choice(label, size=n_rows, p=prob / total)
        data[ct.ind[key], ind_h] = samples

    return data

In [8]:
# Example data set. Column 0 is the hidden node, left as NaN (which becomes
# the string 'nan' in this string array). The other three columns are the
# observed values (presumably colour, taste and shape).
data = np.array([
    [np.nan, colour, taste, shape]
    for colour, taste, shape in (
        ('yellow', 'sweet', 'long'),
        ('green', 'sour', 'round'),
        ('green', 'sour', 'round'),
        ('yellow', 'sweet', 'long'),
        ('yellow', 'sweet', 'long'),
        ('green', 'sour', 'round'),
        ('green', 'sweet', 'long'),
        ('green', 'sweet', 'round'),
    )
])

In [9]:
# Display the raw data array. Because the array holds strings, np.nan in the
# hidden column shows up as the string 'nan'.
data

array([['nan', 'yellow', 'sweet', 'long'],
       ['nan', 'green', 'sour', 'round'],
       ['nan', 'green', 'sour', 'round'],
       ['nan', 'yellow', 'sweet', 'long'],
       ['nan', 'yellow', 'sweet', 'long'],
       ['nan', 'green', 'sour', 'round'],
       ['nan', 'green', 'sweet', 'long'],
       ['nan', 'green', 'sweet', 'round']], dtype='<U32')