In [2]:
"""
Vertices will be stored in a list. Better a hashmap, with key 
as name and value as the vertex object?

Each vertex object will have:
 - List of its children.
 - List of its parents.
 - CPT of x|u where x is the node itself and u is the parent set.
   This is inline with the structure of single processor….
   It's stored via another object (CPT object)
 - Finally it’ll have two more lists to store lambda and pi?


It'll also get a new graph if we choose to merge two nodes.


TODO:
 - In CPT, also initialize by name of nodes.
   So need to pass a map which will have name to node mapping.
"""
import numpy as np




In [38]:
class Vertex:
    """
    A graph vertex that carries the state needed for pi/lambda message
    passing: its possible values, links to parents and children, a CPT
    object, the per-child incoming lambda and outgoing pi messages, and
    the resulting belief.

    A vertex is created empty (only the name is known) and filled in by
    the caller; `initializeV` finalises it once `values` and `children`
    are set. `nCard` triggers that initialisation lazily.
    """

    def __init__(self, name):
        self.name = name
        self.values = []
        self.vIdx = {}  # maps each value to its index in `values`
        self.parents = []
        self.children = []
        # Incoming lambda (one per child) and outgoing pi (one per child).
        # Both have the same dimension as the values this vertex takes.
        self.inLmbda = []
        self.outPi = []
        self.nodePI = None
        self.nodeLmbda = None
        self.CPT = None
        self.BEL = None
        # Not all information is present at construction time, so the
        # vertex must be initialised (see initializeV) before use.
        self.is_init = False
        # Used when this vertex is derived from some other network:
        self.origG = None
        self.origNodes = []

    def initializeV(self):
        """
        One-time setup: build the value -> index map and create one
        all-ones incoming lambda vector per child. Subsequent calls are
        no-ops.
        """
        if self.is_init:
            return
        self.is_init = True
        for i in range(len(self.values)):
            self.vIdx[self.values[i]] = i

        self.inLmbda = []
        for _ in self.children:
            self.inLmbda += [np.ones(len(self.values))]

    def _syncLambdaSlots(self):
        """
        Make sure there is one incoming-lambda slot per child.

        `initializeV` runs only once (possibly triggered early through
        `nCard`), so children attached afterwards would otherwise have no
        slot. Missing slots are filled with all-ones vectors; existing
        entries are never touched.
        """
        missing = len(self.children) - len(self.inLmbda)
        if missing > 0:
            self.inLmbda += [np.ones(len(self.values))
                             for _ in range(missing)]

    def getPiProdCross(self, exclNodes):
        """
        Cross product of the pi messages coming from the parents, taken
        in the order of `self.parents`, as a column vector of shape
        (K, 1). The first parent varies slowest, the last one fastest.

        exclNodes: list of parents to leave out (None means none).

        Returns np.array([1.0]) when no parent contributes (no parents,
        or all of them excluded).
        """
        if exclNodes is None:
            exclNodes = []
        pPiProd = None
        for p in self.parents:  # order matters
            if p in exclNodes:
                continue
            # The pi message p sends to this vertex.
            msg = p.outPi[p.children.index(self)]
            if pPiProd is None:
                pPiProd = msg.reshape(-1, 1)
            else:
                # (K,1) * (1,m) -> (K,m), flattened row-major to (K*m,1)
                pPiProd = np.multiply(
                    pPiProd, msg.reshape(1, -1)
                ).reshape(-1, 1)
        if pPiProd is None:  # no contributing parent at all
            pPiProd = np.array([1.0])
        return pPiProd

    def initializePi(self):
        """
        First-run initialisation of pi.

        For a root, pi is simply the first (only) row of the CPT. For a
        non-root, the parents' outPi must already be populated (the
        caller is responsible for the ordering); their cross product is
        multiplied with the CPT matrix. The result is copied into outPi
        once per child. All lambdas are assumed to be 1 at this point.
        """
        cpt = self.CPT.getCPT()
        if len(self.parents) == 0:
            self.nodePI = cpt[0].reshape(-1, 1)
        else:
            self.nodePI = np.dot(self.getPiProdCross(None).T, cpt)
        self.outPi = [self.nodePI.copy() for _ in self.children]

    def nCard(self):
        """Number of values this vertex takes (initialises it if needed)."""
        if not self.is_init:
            self.initializeV()
        return len(self.values)

    def updatePi(self):
        """
        This will update the outgoing pi, wrt incoming pi and
        lmbda.
        Caller should ensure that the call is not redundant
        i.e., call only when something has changed.
        (Not implemented yet.)
        """

    def updateBeliefs(self):
        """
        Recompute this vertex's lambda, pi and belief from the messages
        currently stored, and push new messages out.

        Assumes incoming pi (parents' outPi) and incoming lambda
        (self.inLmbda) are already up to date. Side effects:
         - self.nodeLmbda, self.nodePI, self.outPi and self.BEL are
           updated;
         - for every parent p, the slot of p.inLmbda that belongs to this
           vertex is overwritten with the new outgoing lambda.
        """
        # Children may have been attached after the lazy initialisation.
        self._syncLambdaSlots()

        if len(self.children) != 0:
            # Node lambda is the element-wise product of all incoming
            # lambdas. For a leaf it is left as is (all ones, or
            # whatever the caller set, e.g. a one-hot vector).
            self.nodeLmbda = None
            for lam in self.inLmbda:
                if self.nodeLmbda is None:
                    self.nodeLmbda = lam
                else:
                    self.nodeLmbda = np.multiply(self.nodeLmbda, lam)
            self.nodeLmbda = self.nodeLmbda.reshape([1, -1])

        if self.nodeLmbda is None:
            self.nodeLmbda = np.ones(len(self.values))

        if len(self.parents) != 0:
            # Outgoing lambda, one message per parent.
            for p in self.parents:
                M = self.CPT.getCPTS(p)
                pPiProd = self.getPiProdCross([p]).reshape([1, -1])
                oLmbda = []
                for m in M:
                    temp = np.dot(pPiProd, m)
                    oLmbda += [np.dot(temp.reshape(-1),
                                      self.nodeLmbda.reshape(-1))]
                # The parent may not have a slot for this vertex yet.
                p._syncLambdaSlots()
                p.inLmbda[p.children.index(
                    self)] = np.array(oLmbda).reshape([1, -1])

        if len(self.parents) != 0:
            # Node pi from the parents' messages; a root's pi never
            # changes.
            pPiProd = self.getPiProdCross([]).reshape([1, -1])
            M = self.CPT.getCPT()
            self.nodePI = np.dot(pPiProd, M)

        if len(self.children) != 0:
            # Outgoing pi to child i: node pi times every incoming
            # lambda except the one received from child i, normalised.
            self.outPi = []
            for i in range(len(self.inLmbda)):
                tLmbda = np.ones(len(self.values)).reshape([-1, 1])
                for j in range(len(self.inLmbda)):
                    if i == j:
                        continue
                    tLmbda = np.multiply(
                        tLmbda,
                        self.inLmbda[j].reshape([-1, 1]))
                msg = np.multiply(tLmbda, self.nodePI.reshape([-1, 1]))
                self.outPi += [msg / np.sum(msg)]

        # Note: if multiple nodes are merged, the order in which their
        # values were merged could be saved. A bit cumbersome, but could
        # be done.
        self.BEL = np.multiply(
            self.nodeLmbda.reshape(1, -1),
            self.nodePI.reshape(1, -1))
        self.BEL = self.BEL / np.sum(self.BEL)
        

In [35]:
# Scratch check: np.dot on two 1-D arrays gives their scalar inner product.
a = np.array([1, 2])
b = np.array([3, 4])
a, b, a.shape, np.dot(a, b)

(array([1, 2]), array([3, 4]), (2,), 11)

In [53]:
# Scratch check: reshape([1, -1]) turns a 1-D vector into a 1 x N row matrix.
np.ones(2).reshape([1,-1]), np.ones(2)

(array([[1., 1.]]), array([1., 1.]))

In [27]:
# Scratch check of the column-times-row product flattened to a column, the
# same pattern Vertex.getPiProdCross uses. Relies on `a` and `b` from an
# earlier cell.
np.multiply(a.reshape(-1, 1),  b.reshape(1, -1)).reshape(-1, 1)

array([[3],
       [4],
       [6],
       [8]])

In [49]:
# Scratch check: list.pop(i) removes and returns the element at index i
# (the same call CPT.getCPTS applies to its dims list).
L = [4, 5, 7, 3, 4]
L.pop(1)


5

In [33]:
# Scratch equality check; evaluates to False.
1 == 2

False

In [103]:
class CPT:
    """
    Conditional probability table of one node given its parents.

    Kept as a separate object because the table changes when, e.g.,
     - two nodes are merged (very common in loopy networks), or
     - one of the nodes is conditioned on.

    `vals` is a flat list laid out with the node itself (nodes[0]) as the
    slowest-varying index and the last parent as the fastest-varying one.
    """

    def __init__(self, nodes, vals):
        """
        nodes: list of node objects; nodes[0] is the node the table
               belongs to, nodes[1:] are its parents.
        vals:  list of floating points, in the same order as the table
               field of a BIF file.

        Side effect: sets nodes[0].parents to nodes[1:].
        """
        self.nodes = nodes
        self.vals = vals
        nodes[0].parents = nodes[1:]

    def getCPT(self):
        """
        Return the table as a float matrix M where
            #cols = cardinality of the node (nodes[0]),
            #rows = product of the parents' cardinalities,
        i.e. M[r][c] is the probability of the c-th node value given the
        r-th combination of parent values (last parent varies fastest).

        Raises IndexError when `vals` holds fewer than rows*cols entries.
        """
        nRow = int(np.prod([n.nCard() for n in self.nodes[1:]]))
        nCol = self.nodes[0].nCard()
        needed = nRow * nCol
        if len(self.vals) < needed:
            raise IndexError(
                "CPT needs %d values but only %d were given"
                % (needed, len(self.vals)))
        # vals is stored column by column, so reshape and transpose.
        flat = np.array(self.vals[:needed], dtype=float)
        return flat.reshape(nCol, nRow).T.copy()

    def getCPTS(self, U):
        """
        Split the CPT by the values of parent U.

        Returns a list with one 2-D matrix per value of U. Matrix i holds
        the rows of the CPT where U takes its i-th value; its shape is
        (product of the other parents' cardinalities, node cardinality),
        with the remaining parents kept in their original order. With no
        other parent, each matrix has a single row.

        Raises ValueError when U is not one of the parents.
        """
        UIdx = self.nodes.index(U) - 1
        if UIdx < 0:
            raise ValueError("U must be one of the parents of the node")
        M = self.getCPT()
        nCol = M.shape[1]
        cards = [n.nCard() for n in self.nodes[1:]]
        # One axis per parent plus the node axis, as in
        #   https://stackoverflow.com/a/42817678/1953366
        R = M.reshape(cards + [nCol])
        nOther = int(np.prod(cards[:UIdx] + cards[UIdx + 1:]))
        # Always flatten the remaining parent axes so callers can use a
        # plain matrix product regardless of the number of parents.
        return [R.take(i, axis=UIdx).reshape(nOther, nCol)
                for i in range(len(U.values))]

    def _checkArgs(self, args):
        """Warn (without failing) when args does not match the nodes."""
        if len(args) != len(self.nodes):
            print("args and nodes length", len(args),
                  "and", len(self.nodes), "mismatched")

    def _flatIndex(self, valIdxs):
        """
        Position in `vals` for the given per-node value indices
        (mixed-radix number, last node is the least significant digit).
        """
        index = 0
        stride = 1
        for idx in range(len(valIdxs) - 1, -1, -1):
            index += stride * valIdxs[idx]
            stride *= self.nodes[idx].nCard()
        return index

    def getP4mIdx(self, args):
        """
        args: list with the value *index* of every node, in the order of
        self.nodes. Returns the matching entry of `vals`.
        """
        self._checkArgs(args)
        return self.vals[self._flatIndex(args)]

    def getProb(self, args):
        """
        args: list with the *value* of every node, in the order of
        self.nodes. Returns the matching entry of `vals`.
        """
        self._checkArgs(args)
        valIdxs = []
        for idx in range(len(args)):
            node = self.nodes[idx]
            # nCard() initialises the node if needed, which fills vIdx.
            node.nCard()
            valIdxs.append(node.vIdx[args[idx]])
        return self.vals[self._flatIndex(valIdxs)]

    def printCPT(self):
        """
        print CPT in nice format like:
        Node | P1, P2 | Prob
        ========================
        True | t , t  | 0.1 ...
        """
        for n in self.nodes:
            print(">>", n.name)
        print("")
        print(self.nodes[0].name[:4], "\t|",
              ",\t ".join([n.name[:4] for n in self.nodes[1:]]),
              "\t| Prob")
        print("====================")
        cards = [n.nCard() for n in self.nodes]
        # np.ndindex walks all index combinations with the last node
        # varying fastest, which is the order of `vals`.
        for i, indices in enumerate(np.ndindex(*cards)):
            print(self.nodes[0].values[indices[0]], "\t|",
                  ",\t ".join([
                      str(n.values[k])[:4]
                      for n, k in zip(self.nodes[1:], indices[1:])
                  ]),
                  "\t| " + str(self.vals[i]))

In [100]:
# Demo data: three vertices with int, bool and string values respectively.
allV = []
allV += [Vertex("A"), Vertex("B"), Vertex("C")]
allV[0].values = [1, 2, 3]
allV[1].values = [True, False]
allV[2].values = ["Hakun", "Matata"]

In [101]:
# Demo CPT with A as the node and B, C as its parents (3*2*2 = 12 entries).
# The numbers are not valid probabilities (some exceed 1); they only act as
# labels to check the table layout. Relies on `allV` from the previous cell.
cpt = CPT(allV, [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])
cpt.printCPT()


A 	| B,	 C 	| Prob
1 	| True,	 Haku 	| 0.1
1 	| True,	 Mata 	| 0.2
1 	| Fals,	 Haku 	| 0.3
1 	| Fals,	 Mata 	| 0.4
2 	| True,	 Haku 	| 0.5
2 	| True,	 Mata 	| 0.6
2 	| Fals,	 Haku 	| 0.7
2 	| Fals,	 Mata 	| 0.8
3 	| True,	 Haku 	| 0.9
3 	| True,	 Mata 	| 1.0
3 	| Fals,	 Haku 	| 1.1
3 	| Fals,	 Mata 	| 1.2


In [3]:
import numpy as np

In [8]:
# Scratch 8x3 matrix whose entries encode their position (10*row + col).
M = np.array([[i*10 + j for j in range(3)] for i in range(8)])
M

array([[ 0,  1,  2],
       [10, 11, 12],
       [20, 21, 22],
       [30, 31, 32],
       [40, 41, 42],
       [50, 51, 52],
       [60, 61, 62],
       [70, 71, 72]])

In [21]:
# Split the 8 rows into three binary axes, keeping the 3 columns as the last
# axis (the reshape CPT.getCPTS performs). Relies on `M` from an earlier cell.
R = M.reshape([2, 2, 2, -1])
R

array([[[[ 0,  1,  2],
         [10, 11, 12]],

        [[20, 21, 22],
         [30, 31, 32]]],


       [[[40, 41, 42],
         [50, 51, 52]],

        [[60, 61, 62],
         [70, 71, 72]]]])

In [31]:
# Select index 0 along axis 2; that axis is dropped from the result.
R.take(0, axis=2)

array([[[ 0,  1,  2],
        [20, 21, 22]],

       [[40, 41, 42],
        [60, 61, 62]]])

In [26]:
# Keep the slice so the following cells can inspect and modify it.
M1 = R.take(0, axis=2)

In [27]:
# Shape after dropping axis 2.
M1.shape

(2, 2, 3)

In [28]:
# Modify the slice to find out whether take() returned a view or a copy.
M1[0,0,0] = 100

In [29]:
# The slice now shows the modified entry.
M1

array([[[100,   1,   2],
        [ 20,  21,  22]],

       [[ 40,  41,  42],
        [ 60,  61,  62]]])

In [30]:
# R is unchanged, so take() returned a copy rather than a view.
R

array([[[[ 0,  1,  2],
         [10, 11, 12]],

        [[20, 21, 22],
         [30, 31, 32]]],


       [[[40, 41, 42],
         [50, 51, 52]],

        [[60, 61, 62],
         [70, 71, 72]]]])