In [92]:
import json
import heapq
import collections
import itertools
from typing import List, Dict, Tuple, Any, Union

In [93]:
# Read in the data from the json file 
path = 'Sample_Testcase.json'
with open(path, 'r') as file:
    data = json.load(file)

In [94]:
# Define helper functions for message passing
def multiplyFactors(factorList: List[Dict], domain: list) -> dict:
    if not factorList:
        return oneFactor(domain) # Return a factor that is one for all assignments
    
    result = factorList[0].copy()
    # print(type(result))
    for factor in factorList[1:]:
        result = multiplyTwoFactors(result, factor)
    return result

def multiplyTwoFactors(factor1 : Dict, factor2 : Dict) -> Dict:
    result = {}
    for assignment1, value1 in factor1.items():
        for assignment2, value2 in factor2.items():
            merged = unifyAssignments(assignment1, assignment2)
            if merged is not None:
                result[merged] = result.get(merged, 0.0) + (value1 * value2)   
    return result

def unifyAssignments(assignment1 : tuple, assignment2 : tuple) -> Union[tuple, None]:
    mergedDict = dict(assignment1)
    for variable, value in assignment2:
        if variable in mergedDict and mergedDict[variable] != value:
            return None
        mergedDict[variable] = value
    return tuple(sorted(mergedDict.items()))

def oneFactor(domain : list) -> dict:
    # Make all possible assignments
    domainList = sorted(domain)
    allAssignments = {}

    for assignmentTuple in itertools.product([0, 1], repeat = len(domainList)):
        assignment = tuple(zip(domainList, assignmentTuple))
        allAssignments[assignment] = 1.0
    return allAssignments
    

def sumOut(factor: Dict[Tuple[Tuple[int, int], ...], float], variablesToSumOut: List[int], domain: List[int]) -> Dict[Tuple[Tuple[int, int], ...], float]:
    # Summation over a specified variables in the factor
    result = {}
    print(factor)
    for assignment, value in factor:
        newAssignment = tuple((var, val) for var, val in assignment if var not in variablesToSumOut)
        result[newAssignment] = result.get(newAssignment, 0) + value
    return result

In [95]:
domain = [0, 1, 2]
print(oneFactor(domain))

{((0, 0), (1, 0), (2, 0)): 1.0, ((0, 0), (1, 0), (2, 1)): 1.0, ((0, 0), (1, 1), (2, 0)): 1.0, ((0, 0), (1, 1), (2, 1)): 1.0, ((0, 1), (1, 0), (2, 0)): 1.0, ((0, 1), (1, 0), (2, 1)): 1.0, ((0, 1), (1, 1), (2, 0)): 1.0, ((0, 1), (1, 1), (2, 1)): 1.0}


In [96]:
# Testing helper functions
# a1 = ((1, 0), (0, 1))
# a2 = ((1, 1), (3, 0))
# a3 = unifyAssignments(a1, a2)
# print(a3)

factor1 = {
  ((0,0),(1,0)): 2.0,
  ((2,2),(1,0)): 4.0
}
factor2 = {
  ((2,0),(3,1)): 3.0,
  ((2,1),(3,0)): 5.0
}

# Expected Multiplication : {((0, 0),): 0.1, ((0, 1),): 0.4}
result = multiplyTwoFactors(factor1, factor2)
print(result)

{((0, 0), (1, 0), (2, 0), (3, 1)): 6.0, ((0, 0), (1, 0), (2, 1), (3, 0)): 10.0}


In [97]:
class Inference:
    def __init__(self, data : dict):
        """
        Initialize the Inference class with the input data.
        
        Parameters:
        -----------
        data : dict
            The input data containing the graphical model details, such as variables, cliques, potentials, and k value.
        
        What to do here:
        ----------------
        - Parse the input data and store necessary attributes (e.g., variables, cliques, potentials, k value).
        - Initialize any data structures required for triangulation, junction tree creation, and message passing.
        
        Refer to the sample test case for the structure of the input data.
        """
        self.testcaseNumber = data.get("TestCaseNumber", 0)
        self.variableCount = data.get("VariablesCount", 0)
        self.potentialCount = data.get("Potentials_count", 0)
        # List of cliques (each clique is a dictionary)
        self.cliquesAndPotentials = data.get("Cliques and Potentials", []) 
        self.k = data.get("k value (in top k)", 0)

        # Build the graph from the cliques. Nodes are labelled from 0 to variableCount - 1
        self.undirectedGraph = {i: set() for i in range(self.variableCount)}
        for clique in self.cliquesAndPotentials:
            nodes = clique.get("cliques", [])
            # For each clique, add edges between each pair of nodes
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    self.undirectedGraph[nodes[i]].add(nodes[j])
                    self.undirectedGraph[nodes[j]].add(nodes[i])
        
        # Define the variables needed to construct the triangulated graph
        self.triangulatedGraph = {}
        self.maximalCliques = []

        # Define the variables needed to construct the junction tree
        self.junctionTree = []

        # Make a dictionary to store the potentials for each clique
        self.cliquePotentials = {}


    def triangulate_and_get_cliques(self):
        """
        Triangulate the undirected graph and extract the maximal cliques.
        
        What to do here:
        ----------------
        - Implement the triangulation algorithm to make the graph chordal.
        - Extract the maximal cliques from the triangulated graph.
        - Store the cliques for later use in junction tree creation.

        Refer to the problem statement for details on triangulation and clique extraction.
        """
        temporaryGraph = {n : set(self.undirectedGraph[n]) for n in self.undirectedGraph}
        chordalGraph = {n : set(self.undirectedGraph[n]) for n in self.undirectedGraph}

        # Initialize the degree info into a min heap
        degreeInfo = {node : len(temporaryGraph[node]) for node in temporaryGraph}
        # Initialize the heap
        degreeHeap = [(degreeInfo[node], node) for node in degreeInfo]
        heapq.heapify(degreeHeap)
        # Maintain a set for the nodes that have been eliminated
        eliminatedNodes = set()
        # Store the cliques from each elimination step
        candidateCliques = []

        # Eliminate the nodes one by one to triangulate the graph
        while degreeHeap:
            currDegree, currNode = heapq.heappop(degreeHeap)
            if currNode in eliminatedNodes or degreeInfo[currNode] != currDegree:
                continue

            # Mark the node as eliminated
            eliminatedNodes.add(currNode)
            # Get the neighbours of the node that are not eliminated
            neighbours = temporaryGraph[currNode] - eliminatedNodes
            # Record the clique as the current node and the neighbours
            currClique = set(neighbours)
            currClique.add(currNode)
            candidateCliques.append(currClique)

            # Add fill edges among the neighours in the chordal graph
            for i in neighbours:
                for j in neighbours:
                    # Add edge in the chordal graph if it is not already present
                    if i == j:
                        continue
                    if j not in chordalGraph[i]:
                        chordalGraph[i].add(j)
                        chordalGraph[j].add(i)
                    # Add edge in the temporary graph if it is not already present
                    if j not in temporaryGraph[i]:
                        temporaryGraph[i].add(j)
                        temporaryGraph[j].add(i)
                        degreeInfo[i] += 1
                        degreeInfo[j] += 1
            
            # Remove the current node from the neighbours
            for neighbour in neighbours:
                temporaryGraph[neighbour].discard(currNode)
            temporaryGraph[currNode].clear()
            # Update the degree info
            degreeInfo[currNode] = 0

            # Update the heap with the new degrees
            for neighbour in neighbours:
                heapq.heappush(degreeHeap, (degreeInfo[neighbour], neighbour))
        
        # After exiting the while loop the graph is now triangulated
        self.triangulatedGraph = chordalGraph
        # Extract the maximal cliques from the candidate cliques by removing the duplicates and subsets
        candidateCliques = list(set(frozenset(clique) for clique in candidateCliques))
        finalCliques = []
        for clique in candidateCliques:
            if not any(clique < others for others in candidateCliques if clique != others):
                finalCliques.append(clique)
        self.maximalCliques = [set(clique) for clique in finalCliques]

    
    def get_junction_tree(self):
        """
        Construct the junction tree from the maximal cliques.
        
        What to do here:
        ----------------
        - Create a junction tree using the maximal cliques obtained from the triangulated graph.
        - For each pair of cliques, compute the common variables.
          Then define the directed separator sets:
              S_ij = clique_i - (clique_i ∩ clique_j)
              S_ji = clique_j - (clique_i ∩ clique_j)
        - Use the size of the common set as the weight to construct a maximum spanning tree.
        - Store the junction tree as a list of tuples:
              (clique_i, clique_j, S_ij, S_ji)
          where S_ij is the separator when a message is passed from clique i to clique j,
          and S_ji is for the reverse direction.
        """
        # Build the weighted graph among the cliques
        cliques = self.maximalCliques
        cliqueCount = len(cliques)

        if cliqueCount <= 1:
            self.junctionTree = []
            return
        
        # Collect the edges of the graph as (weight, i, j) where i and j are the indices of the cliques
        weightedEdges = []
        for i in range(cliqueCount):
            for j in range(i + 1, cliqueCount):
                intersection = cliques[i].intersection(cliques[j])
                weight = len(intersection)
                # Add the edge only if the intersection is non-empty
                if weight > 0:
                    # Store the negative weight to use the min heap as a max heap
                    weightedEdges.append((-weight, i, j))
        
        # If there are no edges, then the junction tree is empty
        if not weightedEdges and cliqueCount > 0:
            # Implies that original graph has no intersecting cliques
            self.junctionTree = []
            return
        
        # Initialize the heap with the edges
        heapq.heapify(weightedEdges)

        # Initialize the union-find set to keep track of the connected components
        parent = list(range(cliqueCount))
        rank = [0] * cliqueCount

        # Define the find and union functions for the union-find set
        def find(x):
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]

        def union(x, y):
            rootX, rootY = find(x), find(y)
            if rootX == rootY:
                return False
            # Union by Rank
            if rank[rootX] > rank[rootY]:
                parent[rootY] = rootX
            elif rank[rootX] < rank[rootY]:
                parent[rootX] = rootY
            else:
                parent[rootY] = rootX
                rank[rootX] += 1
            return True
        
        # Initialize the junction tree as a list of tuples
        spanningTreeEdges = []
        components = cliqueCount
        while weightedEdges and components > 1:
            weight, i, j = heapq.heappop(weightedEdges)
            if union(i, j):
                components -= 1
                spanningTreeEdges.append((i, j, -weight))
        
        # Define the directed separator sets for each edge in the junction tree
        junctionTree = []
        for i, j, weight in spanningTreeEdges:
            intersection = cliques[i].intersection(cliques[j])
            separatorIJ = cliques[i] - intersection
            separatorJI = cliques[j] - intersection
            junctionTree.append((cliques[i], cliques[j], separatorIJ, separatorJI))
            junctionTree.append((cliques[j], cliques[i], separatorJI, separatorIJ))
        
        self.junctionTree = junctionTree

    
    def assign_potentials_to_cliques(self):
        """
        Assign potentials to the cliques in the junction tree.
        
        What to do here:
        ----------------
        - Map the given potentials (from the input data) to the corresponding cliques in the junction tree.
        - Ensure the potentials are correctly associated with the cliques for message passing.
        
        Refer to the sample test case for how potentials are associated with cliques.
        """
        # Convert the maximal cliques into a set of frozensets for easy comparison
        frozenMaximalCliques = [frozenset(clique) for clique in self.maximalCliques]
        self.cliquePotentials = {frozenClique: [] for frozenClique in frozenMaximalCliques}

        # Iterate over each potential from the input data and assign it to the corresponding clique
        for potentialDict in self.cliquesAndPotentials:
            variableList = potentialDict.get("cliques", [])
            potentialValues = potentialDict.get("potentials", [])
            frozenVariableList = frozenset(variableList)
            # Find all the maximal cliques containing the variables in the potential
            for frozenClique in frozenMaximalCliques:
                if frozenVariableList.issubset(frozenClique):
                    self.cliquePotentials[frozenClique].append(potentialValues)
    

    def get_z_value(self):
        """
            Compute the partition function (Z value) of the graphical model.
            Implements sum-product message passing on the junction tree.

            Broad Algorithm:
            1. Choose a clique as the root of the junction tree.
            2. Build a tree adjacency and find the post-order traversal of the junction tree from the leaves upto the root
            3. For each directed edge (child -> parent), compute the message by:
                - Taking child's factor by * all incoming messages from child's children
                - Summing out the child's variables NOT in child's separator set
            4. At the root, multiply the root's factors by all the factors from its children
            5. Sum out the variables in the root factor
        """
        # Edge case to consider when there are no cliques
        if not self.maximalCliques:
            return 0.0
        
        






In [98]:
# Helper Functions to test the implementation
def printUndirectedGraph(graph):
    print("Undirected Graph:")
    for node, neighbors in graph.items():
        print(f"{node}: {sorted(neighbors)}")

def printTriangulatedGraph(graph):
    print("Triangulated Graph:")
    for node, neighbors in graph.items():
        print(f"{node}: {sorted(neighbors)}")

In [99]:
sampleInference = Inference(data[0]['Input'])
print(sampleInference.testcaseNumber)
print(sampleInference.variableCount)
print(sampleInference.potentialCount)
print(sampleInference.cliquesAndPotentials)

1
4
3
[{'clique_size': 2, 'cliques': [0, 1], 'potentials': [3, 4, 5, 6]}, {'clique_size': 2, 'cliques': [1, 2], 'potentials': [2, 7, 1, 3]}, {'clique_size': 2, 'cliques': [0, 3], 'potentials': [5, 8, 2, 7]}]


In [100]:
printUndirectedGraph(sampleInference.undirectedGraph)

Undirected Graph:
0: [1, 3]
1: [0, 2]
2: [1]
3: [0]


In [101]:
sampleInference.triangulate_and_get_cliques()
printTriangulatedGraph(sampleInference.triangulatedGraph)

Triangulated Graph:
0: [1, 3]
1: [0, 2]
2: [1]
3: [0]


In [102]:
sampleInference.get_junction_tree()
print("Junction Tree:")
for clique in sampleInference.junctionTree:
    print(clique)

Junction Tree:
({0, 1}, {1, 2}, {0}, {2})
({1, 2}, {0, 1}, {2}, {0})
({0, 1}, {0, 3}, {1}, {3})
({0, 3}, {0, 1}, {3}, {1})


In [103]:
sampleInference.assign_potentials_to_cliques()
print("Clique Potentials:")
for clique, potentials in sampleInference.cliquePotentials.items():
    print(f"Clique : {clique} : Potential {potentials}")

Clique Potentials:
Clique : frozenset({0, 1}) : Potential [[3, 4, 5, 6]]
Clique : frozenset({1, 2}) : Potential [[2, 7, 1, 3]]
Clique : frozenset({0, 3}) : Potential [[5, 8, 2, 7]]


In [104]:
# Test the get_z_value function
zValue = sampleInference.get_z_value()
print(f"Z Value: {zValue}")

[2, 7, 1, 3]


TypeError: cannot unpack non-iterable int object