In [25]:
"""
Vertices will be stored in a list. Would a hashmap be better, with the
node name as key and the vertex object as value?

Each vertex object will have:
 - A list of its children.
 - A list of its parents.
 - The CPT of x|u, where x is the node itself and u is its parent set.
   This is in line with the structure of the single-processor version.
   It is stored in a separate object (a CPT object).
 - Finally, it'll have two more lists to store lambda and pi?


It'll also get a new graph if we choose to merge two nodes.


TODO:
 - In CPT, also allow initialization by node names.
   This requires passing a map from node name to node object.
"""
import itertools

import numpy as np




In [104]:
class Vertex:
    """
    A node of the network.

    Attributes:
        name      -- identifier of the node.
        values    -- list of values the node can take; the position of a
                     value in this list is its index.
        vIdx      -- dict mapping each value to its index in `values`
                     (filled by initializeV()).
        parents   -- list of parent vertices.
        children  -- list of child vertices.
        lmbda     -- list of lambda values, one per value (set to all ones
                     by initializeV()).
        pi        -- list for pi values (not filled in this class).
        CPT       -- the CPT object of this node given its parents, if any.
        is_init   -- True once initializeV() has run. A vertex is usually
                     created before all of its information (e.g. `values`)
                     is known, so it must be initialized before use.
        origG, origNodes -- only set when this vertex is derived from some
                     other network (presumably when nodes are merged).
    """

    def __init__(self, name):
        self.name = name
        self.values = []
        self.vIdx = {}  # value -> index in self.values
        self.parents = []
        self.children = []
        self.lmbda = []
        self.pi = []
        self.CPT = None
        # Not all info is available at creation time, so initializeV()
        # must run before the vertex is used (nCard() does this lazily).
        self.is_init = False
        # Populated only when the vertex is derived from some network:
        self.origG = None
        self.origNodes = []

    def initializeV(self):
        """
        Build the value -> index map and reset lambda to all ones.

        Call this once `values` is final. `vIdx` is rebuilt from scratch
        so that no stale entries survive. The vertex is then marked as
        initialized, so later nCard() calls do not re-run this method and
        wipe lambda values computed in the meantime. If `values` changes
        afterwards, call initializeV() again.
        """
        # NOTE: as dict keys 1 == True and 0 == False, so mixing ints and
        # bools inside ONE vertex's values would make them collide.
        self.vIdx = {value: i for i, value in enumerate(self.values)}
        self.lmbda = [1] * len(self.values)
        self.is_init = True

    def nCard(self):
        """Return the number of values (cardinality), initializing first if needed."""
        if not self.is_init:
            self.initializeV()
        return len(self.values)

In [103]:
class CPT:
    """
    Conditional probability table P(x | u) of a node x given its parents u.

    Kept as a separate object because the table has to change when, e.g.:
     - we decide to merge two nodes (very common in loopy networks), or
     - we condition on one of the nodes.

    `nodes[0]` is the node x itself and `nodes[1:]` are its parents.
    `vals` is flat, with the LAST node varying fastest, i.e. the same
    order as the `table` field of a BIF file.
    """

    def __init__(self, nodes, vals):
        """
        nodes -- list of node objects; nodes[0] is the node itself and
                 nodes[1:] are its parents (assigned to nodes[0].parents).
        vals  -- list of floats in the same order as the `table` field of
                 a BIF file (last node varies fastest).
        """
        self.nodes = nodes
        self.vals = vals
        nodes[0].parents = nodes[1:]

    def _checkArity(self, args):
        """Raise ValueError unless args has exactly one entry per node."""
        if len(args) != len(self.nodes):
            raise ValueError("args and nodes length %d and %d mismatched"
                             % (len(args), len(self.nodes)))

    def _flatIndex(self, valIdxs):
        """
        Map one value index per node to the position in self.vals.

        Row-major (Horner) form with the last node varying fastest:
        index = sum_i valIdxs[i] * prod_{j>i} nCard(nodes[j]).
        Raises IndexError if an index is out of range for its node.
        """
        index = 0
        for node, valIdx in zip(self.nodes, valIdxs):
            card = node.nCard()
            if not 0 <= valIdx < card:
                raise IndexError("value index %r out of range for node %r "
                                 "with %d values" % (valIdx, node.name, card))
            index = index * card + valIdx
        return index

    def getP4mIdx(self, args):
        """
        Return the probability for the given value INDICES.

        args -- one integer per node, in the order of self.nodes, each an
                index into that node's `values`.
        Raises ValueError if len(args) != len(self.nodes), and IndexError
        if an index is out of range for its node.
        """
        self._checkArity(args)
        return self.vals[self._flatIndex(args)]

    def getProb(self, args):
        """
        Return the probability for the given node VALUES.

        args -- one value per node, in the order of self.nodes
                (e.g. [1, False, "Matata"]).
        Raises ValueError if len(args) != len(self.nodes), and KeyError if
        a value is not one of its node's `values`.
        """
        self._checkArity(args)
        valIdxs = []
        for node, value in zip(self.nodes, args):
            node.nCard()  # ensures node.vIdx is populated before the lookup
            valIdxs.append(node.vIdx[value])
        return self.vals[self._flatIndex(valIdxs)]

    def printCPT(self):
        """
        Print the CPT in a readable format, e.g.:
        Node | P1, P2 | Prob
        ========================
        True | t , t  | 0.1 ...
        Node names and parent values are truncated to 4 characters.
        """
        for n in self.nodes:
            print(">>", n.name)
        print("")
        print(self.nodes[0].name[:4], "\t|",
              ",\t ".join(n.name[:4] for n in self.nodes[1:]),
              "\t| Prob")
        print("====================")
        # itertools.product varies the last node fastest, which matches
        # the order of self.vals.
        ranges = [range(n.nCard()) for n in self.nodes]
        for i, indices in enumerate(itertools.product(*ranges)):
            parentVals = ",\t ".join(
                str(n.values[idx])[:4]
                for n, idx in zip(self.nodes[1:], indices[1:]))
            print(self.nodes[0].values[indices[0]], "\t|", parentVals,
                  "\t| " + str(self.vals[i]))

In [100]:
# Three toy vertices whose values deliberately mix types (int, bool, str).
allV = [Vertex(vName) for vName in "ABC"]
allV[0].values, allV[1].values, allV[2].values = (
    [1, 2, 3], [True, False], ["Hakun", "Matata"])

In [101]:
# 3 (A) * 2 (B) * 2 (C) = 12 entries 0.1 .. 1.2, last node varying fastest.
cpt = CPT(allV, [k / 10 for k in range(1, 13)])
cpt.printCPT()


A 	| B,	 C 	| Prob
1 	| True,	 Haku 	| 0.1
1 	| True,	 Mata 	| 0.2
1 	| Fals,	 Haku 	| 0.3
1 	| Fals,	 Mata 	| 0.4
2 	| True,	 Haku 	| 0.5
2 	| True,	 Mata 	| 0.6
2 	| Fals,	 Haku 	| 0.7
2 	| Fals,	 Mata 	| 0.8
3 	| True,	 Haku 	| 0.9
3 	| True,	 Mata 	| 1.0
3 	| Fals,	 Haku 	| 1.1
3 	| Fals,	 Mata 	| 1.2


In [102]:
# A = 1, B = False, C = "Matata"  ->  flat index 0*4 + 1*2 + 1 = 3
cpt.getProb(args=[1, False, "Matata"])

0.4

In [6]:
# Looks like a leftover scratch check (int -> str); safe to remove.
format(2)

'2'