In [3]:
from pandas import DataFrame
import re
import random


class Node:
    def __init__(self,name,parents,children,table):
        self.name = name
        self.parents = parents
        self.children = children
        self.table = table

def parser(file):
    infile = open(file)
    #print (infile)
    
    # Regex patterns for parsing
    variable_pattern = re.compile(r"  type discrete \[ \d+ \] \{ (.+) \};\s*")
    prior_probability_pattern_1 = re.compile(
        r"probability \( ([^|]+) \) \{\s*")
    prior_probability_pattern_2 = re.compile(r"  table (.+);\s*")
    conditional_probability_pattern_1 = (
        re.compile(r"probability \( (.+) \| (.+) \) \{\s*"))
    conditional_probability_pattern_2 = re.compile(r"  \((.+)\) (.+);\s*")
    
    variables = {}  # domains
    functions = []  # function names (nodes/variables)
    
    nodes = []
    # For every line in the file
    
    while True:
        line = infile.readline()

        # End of file
        if not line:
            break
            
        # Variable declaration    
        if line.startswith("variable"):
            match = variable_pattern.match(infile.readline())
            # Extract domain and place into dictionary
            if match:
                variables[line[9:-3]] = match.group(1).split(", ")
                #print (variables)
                
        
        # Probability distribution
        elif line.startswith("probability"):
            
            match = prior_probability_pattern_1.match(line)
            if match:

                # Prior probabilities
                variable = match.group(1)
                #print (variable)
                function_name = variable
                functions.append(function_name)
                line = infile.readline()
                match = prior_probability_pattern_2.match(line)
                dictionary = dict(zip(variables[variable],map(float, match.group(1).split(", "))))
                
                
                rows = []
                for row in dictionary:
                    rows.append([row,dictionary[row]])
                table = DataFrame(rows,columns=[variable,'prob'])
                nodes.append(Node(variable,[],[],table))
                
                
            else:
                match = conditional_probability_pattern_1.match(line)
                if match:
                    
                    # Conditional probabilities
                    variable = match.group(1)
                    #print ("1",variable)
                    
                    function_name = variable
                    functions.append(function_name)
                    given = match.group(2).split(", ")
                    
                    
                    allCols = [x for x in given]
                    allCols.append(variable)
                    allCols.append('prob')
                    
                    #print (given)
                    dictionary = {}
                
                while True:
                        line = infile.readline()  # line of the CPT
                        if line == '}\n':
                            break
                        match = conditional_probability_pattern_2.match(line)
                        given_values = match.group(1).split(", ")
                        for value, prob in zip(
                                variables[variable],
                                map(float, match.group(2).split(", "))):
                            dictionary[tuple(given_values + [value])] = prob
                        #print (dictionary)
                        
                rows = []
                for row in dictionary:
                    x = []
                    for var in row:
                        x.append(var)
                    x.append(dictionary[row])
                    rows.append(x)
                #print (rows)
                
                table = DataFrame(rows,columns=allCols)
                
                
                tmp = Node(variable,[],[],table)
                
                for node in nodes:
                        if node.name in given:
                            node.children.append(tmp)
                            tmp.parents.append(node)
                
                nodes.append(tmp)
    return nodes

    

    
def eliminate_variable(nodes,variable,target,debug=False):    
    if debug: print ("eliminating",variable)
    node = next((x for x in nodes if x.name == variable), None)
    if target in node.table.columns.tolist(): return
    #print(node.table,'\n')
    for child in node.children:
        if debug: print('merging with',child.name,'\n',child.table,'\n')
        
        child.table = child.table.merge(right=node.table,on=variable)
        
        if debug: print('merged\n',child.table,'\n')
        
        cols = [col for col in child.table.columns if 'prob' in col]
        child.table['prob'] = child.table.loc[:, cols].prod(axis=1)
        child.table = child.table.drop(cols,axis=1)
        
        #child.table = merge(child.table,node.table,variable,debug)
        #child.table = sumOut(child.table,node.name)
        
        group = child.table.columns.values.tolist()
        group.remove(node.name)
        group.remove('prob')
        
        if debug: print('group by',group)
        
        #df.groupby('Id', as_index=False).agg(lambda x: set(x))
        
        child.table = child.table.groupby(group,as_index=False).agg('sum')
        child.parents.remove(node)
        child.parents= child.parents+node.parents 
        
        if debug: print('new child table\n',child.table,'\n\n')
        #print ('\n',child.table.dtypes,'\n')
    nodes.remove(node)
      
    
def elimination_algorithm(nodes,var,debug=False):
    for variable in variables:
        if variable != var:
            eliminate_variable(nodes,variable,var,debug=False)
    
    fNode = nodes.pop()
    for node in nodes:
        if debug: print (node.table)
        #fNode.table = merge(fNode.table,node.table,var)
        #fNode.table = sumOut(fNode.table,var)
        
        fNode.table = fNode.table.merge(right=node.table,on=var)
        cols = [col for col in fNode.table.columns if 'prob' in col]
        fNode.table['prob'] = fNode.table.loc[:, cols].prod(axis=1)
        fNode.table = fNode.table.drop(cols,axis=1)
        fNode.table = fNode.table.groupby(var,as_index=False).agg('sum')
        if debug: print (fNode.table) 
    return fNode.table    
    
nodes = parser('asia.bif')
variables = [x.name for x in nodes]


class Factor:
    def __init__(self,name,table):
        self.name=name
        self.table=table
        

def getFactors(nodes,debug=False):
    factors = []
    for node in nodes:
        name = [node.name] + [x.name for x in node.parents]
        if debug: print (name)
        factors.append(Factor(name,node.table))
        if debug: print (node.table)
    return factors
    
factors = getFactors(nodes,debug=False)

def getNVars(factor):
    return len(factor.name)

def getTableSize(factor):
    return factor.table.size


#returns a ordered factor list
def sort(factors,orderBy=1):
    if orderBy==1: #order by number of vars
        return sorted(factors,key=getNVars)
    
    if orderBy==2: #order by table size
        return sorted(factors,key=getTableSize)
    
    if orderBy==3:#shuffle factor list (random) 
        random.shuffle(factors)
        return factors
    


def elimination_algorithm(factors,var,value=[],orderBy=1,debug=False):
    
    factors = sort(factors,orderBy)
    elimVars = [x.name[0] for x in factors if x.name[0] not in var]
    
    if debug: print('variables to remove\n',elimVars,'\n')
    for variable in elimVars:
        
        if debug: print('eliminating',variable)
        
        joinFactors = [factor for factor in factors if variable in factor.name]
        joinFactors = sort(joinFactors,orderBy)
        
        if debug:
            for f in joinFactors:
                print(f.name)
        
        
        if debug:
            print('factors to join:') 
            for factor in joinFactors:
                print(factor.name)
            print()
        
        newFactor = eliminate(variable,joinFactors,debug=False)
        
        factors = list(set(factors)-set(joinFactors))
        factors.append(newFactor)
            
    if debug: 
        print('\nmostly done') 
        for factor in factors:
            if debug: print(factor.table)
    
    factors = sort(factors,orderBy)
    newF = eliminate('',factors,last=True)
    
    if debug: print(newF.table)
    
    #if there is evidence
    if(len(var)>1):
        for i in range(1,len(var)):
            
            #filter the table
            if debug: print('filter var :',var[i],'\tvalue :',value[i-1])
            newF.table = newF.table.loc[newF.table[var[i]] == value[i-1]]
            
        #normalize the prod column    
        newF.table['prob'] = newF.table['prob'].div(newF.table['prob'].sum())
    
    #round the probabilities 
    newF.table['prob'] = newF.table['prob'].round(2)
    
    return newF.table
        

def merge(left,right,joinOn,debug=False):
    print (joinOn)
    print (left.table,'\n',right.table,'\n\n')
    #merge dataframes
    left.table = left.table.merge(right=right.table,on=joinOn)
    
    #variables necessary to make new probability column
    cols = [col for col in left.table.columns if 'prob' in col]
    
    #create new probability column
    #and drop the ones to make it
    left.table['prob'] = left.table.loc[:, cols].prod(axis=1)
    left.table = left.table.drop(cols,axis=1)
    if debug: print('merged:\n',newF.table,'\n')
    
    #update factor name
    left.name = left.table.columns.tolist()
    left.name.remove('prob')
    
    return left

def eliminate(variable,factors,last=False,debug=False):
    
    newF = factors[0] 
    for factor in factors[1:]:
        
        #variables in common
        joinOn = list(set(newF.name) & set(factor.name))
        
        #merge dataframes on commom variables
        newF = merge(newF,factor,joinOn,debug)
        
    #columns to do groupby operation
    if last: #last var to group by is the target(s) 
        cols = [col for col in newF.table.columns if col!='prob']
        
    else: #if not the last elim, dont group by the variable
        cols = [col for col in newF.table.columns if col!='prob' and col!=variable]
    if debug: print ('group by',cols)

    #sum values on groupby
    newF.table = newF.table.groupby(cols,as_index=False).agg('sum')
    
    return newF
    

answer = elimination_algorithm(factors,['dysp'],[],orderBy=1,debug=False)
answer

['asia']
  asia  prob
0  yes  0.01
1   no  0.99 
   asia  tub  prob
0  yes  yes  0.05
1  yes   no  0.95
2   no  yes  0.01
3   no   no  0.99 


['smoke']
  smoke  prob
0   yes   0.5
1    no   0.5 
   smoke lung  prob
0   yes  yes  0.10
1   yes   no  0.90
2    no  yes  0.01
3    no   no  0.99 


['smoke']
  smoke lung   prob
0   yes  yes  0.050
1   yes   no  0.450
2    no  yes  0.005
3    no   no  0.495 
   smoke bronc  prob
0   yes   yes   0.6
1   yes    no   0.4
2    no   yes   0.3
3    no    no   0.7 


['tub']
   tub    prob
0   no  0.9896
1  yes  0.0104 
   lung  tub either  prob
0  yes  yes    yes   1.0
1  yes  yes     no   0.0
2   no  yes    yes   1.0
3   no  yes     no   0.0
4  yes   no    yes   1.0
5  yes   no     no   0.0
6   no   no    yes   0.0
7   no   no     no   1.0 


['lung']
  lung bronc    prob
0   no    no  0.5265
1   no   yes  0.4185
2  yes    no  0.0235
3  yes   yes  0.0315 
   lung either    prob
0   no     no  0.9896
1   no    yes  0.0104
2  yes     no  0.0000
3  

Unnamed: 0,dysp,prob
0,no,0.56
1,yes,0.44
