# Programming Assignment 1 ARK - CS726

## Version 2 
- messes up calculation of Z, and for top-k assignments, is giving the same value of probability so some issue here as well
## Updates
- using tuples instead of lists, as they are not compatible with dictionaries (need to be immutable)
- completed framework of code and added snippet to check TestCases.json

In [21]:
import json
import heapq
import itertools

In [41]:


########################################################################
# Do not install any external packages. You can only use Python's default libraries such as:
# json, math, itertools, collections, functools, random, heapq, etc.
########################################################################

class Inference:
    def __init__(self, data):
        """
        Initialize the Inference class with the input data.
        
        Parameters:
        -----------
        data : dict
            The input data containing the graphical model details, such as variables, cliques, potentials, and k value.
        
        What to do here:
        ----------------
        - Parse the input data and store necessary attributes (e.g., variables, cliques, potentials, k value).
        - Initialize any data structures required for triangulation, junction tree creation, and message passing.
        
        Refer to the sample test case for the structure of the input data.
        """
        # Parse basic input data
        self.test_case_number = data.get("TestCaseNumber", None)
        self.variables_count = data.get("VariablesCount", None)
        self.potentials_count = data.get("Potentials_count", None)
        self.cliques_potentials = data.get("Cliques and Potentials", [])
        self.k = data.get("k value (in top k)", 1)
        
        # Build the initial undirected graph.
        # Assume nodes are labeled 0 to variables_count - 1.
        self.graph = {i: set() for i in range(self.variables_count)}     #initializing as a dictionary
        for cp in self.cliques_potentials:
            nodes = cp.get("cliques", [])
            # For each clique, add an edge between every pair of nodes.
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    u = nodes[i]
                    v = nodes[j]
                    self.graph[u].add(v)
                    self.graph[v].add(u)
        
        # Store the potentials for later use
        self.potentials = self.cliques_potentials
        
        # Placeholders for triangulated graph and maximal cliques
        self.triangulated_graph = None
        self.maximal_cliques = []
        self.junction_tree = None
        self.assigned_clique_potentials = None

    def triangulate_and_get_cliques(self):
        """
        Triangulate the undirected graph and extract the maximal cliques.
        
        What to do here:
        ----------------
        - Implement the triangulation algorithm to make the graph chordal.
        - Extract the maximal cliques from the triangulated graph.
        - Store the cliques for later use in junction tree creation.

        Refer to the problem statement for details on triangulation and clique extraction.
        """
        # Make a copy of the original graph to update fill edges
        triangulated = {node: set(neighbors) for node, neighbors in self.graph.items()}
        # Create a working copy for elimination
        work_graph = {node: set(neighbors) for node, neighbors in triangulated.items()}
        candidate_cliques = []

        while work_graph:
            # Select the vertex with the minimum degree.
            v = min(work_graph, key=lambda x: len(work_graph[x]))
            neighbors = work_graph[v]
            
            # Form a clique: vertex v and all its neighbors.
            current_clique = set(neighbors)
            current_clique.add(v)
            candidate_cliques.append(current_clique)
            
            # Add fill edges in the triangulated graph: connect all neighbors of v.
            for u in neighbors:
                for w in neighbors:
                    if u != w:
                        triangulated[u].add(w)
                        triangulated[w].add(u)
            
            # Remove v from the working graph.
            del work_graph[v]
            for u in work_graph:
                work_graph[u].discard(v)

        # Deduplicate candidate cliques: only retain cliques that are not subsets of others.
        maximal = []
        for clique in candidate_cliques:
            if not any(clique < other for other in candidate_cliques if clique != other):
                maximal.append(sorted(list(clique)))  # sort for consistency

        self.triangulated_graph = triangulated
        self.maximal_cliques = maximal


    def get_junction_tree(self):
        """
        Construct the junction tree from the maximal cliques.
        
        What to do here:
        ----------------
        - Create a junction tree using the maximal cliques obtained from the triangulated graph.
        - For each pair of cliques, compute the common variables.
          Then define the directed separator sets:
              S_ij = clique_i - (clique_i ∩ clique_j)
              S_ji = clique_j - (clique_i ∩ clique_j)
        - Use the size of the common set as the weight to construct a maximum spanning tree.
        - Store the junction tree as a list of tuples:
              (clique_i, clique_j, S_ij, S_ji)
          where S_ij is the separator when a message is passed from clique i to clique j,
          and S_ji is for the reverse direction.
        """
        if not self.maximal_cliques:
            raise Exception("Triangulation must be done before constructing the junction tree.")
        
        cliques = self.maximal_cliques
        n = len(cliques)
        edges = []
        # For every pair of cliques, compute the common variables and then the directed separator sets.
        for i in range(n):
            for j in range(i+1, n):
                common = set(cliques[i]).intersection(set(cliques[j]))
                weight = len(common)
                if weight > 0:
                    # Separator for message from clique i to clique j:
                    S_ij = sorted(list(set(cliques[i]) - common))
                    # Separator for message from clique j to clique i:
                    S_ji = sorted(list(set(cliques[j]) - common))
                    edges.append((i, j, weight, S_ij, S_ji))
        # Sort edges in descending order by weight (using common set size)
        edges.sort(key=lambda x: x[2], reverse=True)
        
        # Use union-find to build the maximum spanning tree.
        parent = list(range(n))
        def find(x):
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x
        def union(x, y):
            rootx = find(x)
            rooty = find(y)
            if rootx != rooty:
                parent[rooty] = rootx
        
        junction_tree = []
        for i, j, weight, S_ij, S_ji in edges:
            if find(i) != find(j):
                union(i, j)
                # Store the edge along with both directed separator sets.
                junction_tree.append((cliques[i], cliques[j], S_ij, S_ji))
        
        self.junction_tree = junction_tree

    def print_junction_tree(self):
        """Print the structure of the junction tree with directed separator sets."""
        if self.junction_tree is None:
            print("Junction tree has not been constructed yet.")
            return
        print("Junction Tree Structure (Directed Separator Sets):")
        for idx, (clique_i, clique_j, S_ij, S_ji) in enumerate(self.junction_tree, 1):
            print(f"Edge {idx}:")
            print(f"   Message from {clique_i} to {clique_j}: Separator = {S_ij}")
            print(f"   Message from {clique_j} to {clique_i}: Separator = {S_ji}")



    def assign_potentials_to_cliques(self):
        """
        Assign potentials to the cliques in the junction tree.
        
        What to do here:
        ----------------
        - Map the given potentials (from the input data) to the corresponding cliques in the junction tree.
        - Ensure the potentials are correctly associated with the cliques for message passing.
        
        Refer to the sample test case for how potentials are associated with cliques.
        """
        # Create a mapping from a sorted tuple of clique nodes to its potentials from the input.
        mapping = {}
        for cp in self.potentials:
            clique_nodes = sorted(cp.get("cliques", []))
            mapping[tuple(clique_nodes)] = cp.get("potentials", [])
        
        self.assigned_clique_potentials = {}
        for clique in self.maximal_cliques:
            key = tuple(clique)
            if key in mapping:
                self.assigned_clique_potentials[key] = mapping[key]
            '''else: ## VERY SUSPICIOUS DEAFULT UNIFORM POTENTIAL ##
                # If the clique is not provided explicitly, assume a default uniform potential.
                size = len(clique)
                self.assigned_clique_potentials[key] = [1] * (2 ** size)'''
            

    def get_z_value(self):
        """
            Compute the partition function (Z value) of the graphical model.
            Implements sum-product message passing on the junction tree.
        """
        if not self.junction_tree:
            raise Exception("Junction tree must be constructed before computing Z.")

        root = tuple(self.junction_tree[0][0])  # Ensure root is a tuple
        messages = {}

        def send_message(clique_from, clique_to, visited=set()):
            """Computes message from clique_from to clique_to recursively."""
            clique_from = tuple(clique_from)  # Convert to tuple
            clique_to = tuple(clique_to) if clique_to is not None else None  # Convert to tuple if not None

            # If message was already computed, return cached value
            if (clique_from, clique_to) in messages:
                return messages[(clique_from, clique_to)]

            # Avoid infinite loops
            if clique_from in visited:
                return 1  # Identity for multiplication

            visited.add(clique_from)

            # Get clique potential, defaulting to 1 if unassigned
            potential = sum(self.assigned_clique_potentials.get(clique_from, [1]))

            # Recursively send messages to all other neighbors except the target clique
            for neighbor in self.junction_tree:
                # Handle cases with more than two elements in the neighbor
                clique_list = [tuple(clique) for clique in neighbor]  # Ensure all cliques are tuples
                if clique_from in clique_list and clique_to not in clique_list:
                    other_clique = next(clique for clique in clique_list if clique != clique_from)
                    potential *= send_message(other_clique, clique_from, visited.copy())  # Multiply incoming messages

            messages[(clique_from, clique_to)] = potential
            return potential

        # Compute messages to the root from its neighbors
        Z = send_message(root, None)  # Start message passing with the root clique

        print(f"Debug: Computed Partition Function Z = {Z}")
        return Z

    def compute_marginals(self):
        """
        Compute the marginal probabilities for all variables in the graphical model.
        
        What to do here:
        ----------------
        - Use the message passing algorithm to compute the marginal probabilities for each variable.
        - Return the marginals as a list of lists, where each inner list contains the probabilities for a variable.
        
        Refer to the sample test case for the expected format of the marginals.
        """
        """Compute marginals for each variable using sum-product."""
        if not self.junction_tree:
            raise Exception("Junction tree must be constructed before computing marginals.")
        
        marginals = {i: [0, 0] for i in range(self.variables_count)}
        Z = self.get_z_value()
        
        for clique in self.maximal_cliques:
            key = tuple(clique)
            if key in self.assigned_clique_potentials:
                for var in clique:
                    marginals[var][0] += sum(self.assigned_clique_potentials[key]) / Z
                    marginals[var][1] += (1 - sum(self.assigned_clique_potentials[key]) / Z)
        
        return [[marginals[i][0], marginals[i][1]] for i in range(self.variables_count)]
    


    def compute_top_k(self, k=3):
        """
        Compute the top-k most probable assignments in the graphical model.
        Uses max-product message passing and normalizes probabilities using Z.
        """
        if not self.junction_tree:
            raise Exception("Junction tree must be constructed before computing top-k.")

        Z = self.get_z_value()
        if Z is None or Z == 0:
            raise ValueError("Partition function Z is invalid, cannot normalize probabilities.")

        assignments = []
        heap = []

        for clique in self.maximal_cliques:
            key = tuple(clique)
            if key in self.assigned_clique_potentials:
                for assignment in itertools.product([0, 1], repeat=len(clique)):
                    # Compute joint probability
                    joint_prob = self.assigned_clique_potentials[key][0]  # Start with first potential
                    for i, val in enumerate(assignment[1:]):  # Iterate through remaining variables
                        joint_prob *= self.assigned_clique_potentials[key][i + 1]  # Multiply correctly

                    # Normalize the probability
                    normalized_prob = joint_prob / Z

                    # Store negative probability for max heap
                    heapq.heappush(heap, (-normalized_prob, assignment))

        # Retrieve top-k assignments
        for _ in range(k):
            if heap:
                prob, assignment = heapq.heappop(heap)
                assignments.append((assignment, -prob))  # Convert back to positive probability

        # Debugging
        print(f"Debug: Computed Z = {Z}")
        print(f"Debug: Top-{k} Assignments = {assignments}")

        return assignments



    # Helper methods for display
    def display_graph(self):
        """Neatly display the original undirected graph."""
        print("Undirected Graph:")
        for node in sorted(self.graph.keys()):
            print(f"{node}: {sorted(list(self.graph[node]))}")

    def display_triangulated_graph(self):
        """Neatly display the triangulated graph."""
        if self.triangulated_graph is None:
            print("Triangulated graph not computed yet.")
        else:
            print("Triangulated Graph:")
            for node in sorted(self.triangulated_graph.keys()):
                print(f"{node}: {sorted(list(self.triangulated_graph[node]))}")


########################################################################
# Do not change anything below this line
########################################################################

class Get_Input_and_Check_Output:
    def __init__(self, file_name):
        with open(file_name, 'r') as file:
            self.data = json.load(file)
    
    def get_output(self):
        n = len(self.data)
        output = []
        for i in range(n):
            inference = Inference(self.data[i]['Input'])
            inference.triangulate_and_get_cliques()
            inference.get_junction_tree()
            inference.assign_potentials_to_cliques()
            z_value = inference.get_z_value()
            marginals = inference.compute_marginals()
            top_k_assignments = inference.compute_top_k()
            output.append({
                'Marginals': marginals,
                'Top_k_assignments': top_k_assignments,
                'Z_value' : z_value
            })
        self.output = output

    def write_output(self, file_name):
        with open(file_name, 'w') as file:
            json.dump(self.output, file, indent=4)




In [46]:
if __name__ == '__main__':
    with open('Sample_Testcase.json', 'r') as f:
        data = json.load(f)

    test_inference = Inference(data[0]['Input'])

    print("\n=== Original Undirected Graph ===")
    test_inference.display_graph()

    test_inference.triangulate_and_get_cliques()
    print("\n=== Triangulated Graph ===")
    test_inference.display_triangulated_graph()

    test_inference.get_junction_tree()
    print("\n=== Junction Tree (Directed Separator Sets) ===")
    test_inference.print_junction_tree()

    test_inference.assign_potentials_to_cliques()
    print("\n=== Assigned Potentials to Cliques ===")
    for clique, potentials in test_inference.assigned_clique_potentials.items():
        print(f"Clique {list(clique) if isinstance(clique, tuple) else clique}: Potentials {potentials}")



=== Original Undirected Graph ===
Undirected Graph:
0: [1, 3]
1: [0, 2]
2: [1]
3: [0]

=== Triangulated Graph ===
Triangulated Graph:
0: [1, 3]
1: [0, 2]
2: [1]
3: [0]

=== Junction Tree (Directed Separator Sets) ===
Junction Tree Structure (Directed Separator Sets):
Edge 1:
   Message from [1, 2] to [0, 1]: Separator = [2]
   Message from [0, 1] to [1, 2]: Separator = [0]
Edge 2:
   Message from [0, 1] to [0, 3]: Separator = [1]
   Message from [0, 3] to [0, 1]: Separator = [3]

=== Assigned Potentials to Cliques ===
Clique [1, 2]: Potentials [2, 7, 1, 3]
Clique [0, 1]: Potentials [3, 4, 5, 6]
Clique [0, 3]: Potentials [5, 8, 2, 7]


In [45]:
if __name__ == '__main__':
    with open('TestCases.json', 'r') as f:
        data = json.load(f)

    for i, test_case in enumerate(data):
        print(f"\n\n===================== Test Case {i + 1} =====================\n")

        test_inference = Inference(test_case['Input'])

        print("\n=== Original Undirected Graph ===")
        test_inference.display_graph()

        test_inference.triangulate_and_get_cliques()
        print("\n=== Triangulated Graph ===")
        test_inference.display_triangulated_graph()

        test_inference.get_junction_tree()
        print("\n=== Junction Tree (Directed Separator Sets) ===")
        test_inference.print_junction_tree()

        test_inference.assign_potentials_to_cliques()
        print("\n=== Assigned Potentials to Cliques ===")
        for clique, potentials in test_inference.assigned_clique_potentials.items():
            print(f"Clique {list(clique)}: Potentials {potentials}")

        # Add placeholders for additional computations
        print("\n=== Partition Function (Z Value) ===")
        try:
            z_value = test_inference.get_z_value()
            print(f"Z = {z_value}")
        except TypeError as e:
            print(f"Error in computing Z value: {e}")

        print("\n=== Marginal Probabilities ===")
        try:
            marginals = test_inference.compute_marginals()
            for var_idx, probs in enumerate(marginals):
                print(f"Variable {var_idx}: {probs}")
        except Exception as e:
            print(f"Error in computing marginals: {e}")

        print("\n=== Top-k Most Probable Assignments ===")
        try:
            top_k_results = test_inference.compute_top_k()
            for rank, (assignment, prob) in enumerate(top_k_results, 1):
                print(f"Rank {rank}: Assignment {assignment}, Probability {prob}")
        except Exception as e:
            print(f"Error in computing top-k: {e}")

        print("\n============================================================\n")






=== Original Undirected Graph ===
Undirected Graph:
0: [1, 2, 3]
1: [0, 4]
2: [0, 3, 4]
3: [0, 2]
4: [1, 2]

=== Triangulated Graph ===
Triangulated Graph:
0: [1, 2, 3, 4]
1: [0, 4]
2: [0, 3, 4]
3: [0, 2]
4: [0, 1, 2]

=== Junction Tree (Directed Separator Sets) ===
Junction Tree Structure (Directed Separator Sets):
Edge 1:
   Message from [0, 1, 4] to [2, 4]: Separator = [0, 1]
   Message from [2, 4] to [0, 1, 4]: Separator = [2]
Edge 2:
   Message from [0, 1, 4] to [0, 2, 3]: Separator = [1, 4]
   Message from [0, 2, 3] to [0, 1, 4]: Separator = [2, 3]

=== Assigned Potentials to Cliques ===
Clique [2, 4]: Potentials [15, 18, 12, 14]

=== Partition Function (Z Value) ===
Debug: Computed Partition Function Z = 59
Z = 59

=== Marginal Probabilities ===
Debug: Computed Partition Function Z = 59
Variable 0: [0, 0]
Variable 1: [0, 0]
Variable 2: [1.0, 0.0]
Variable 3: [0, 0]
Variable 4: [1.0, 0.0]

=== Top-k Most Probable Assignments ===
Debug: Computed Partition Function Z = 59
Debug

In [44]:
if __name__ == '__main__':
    evaluator = Get_Input_and_Check_Output('Sample_Testcase.json')
    evaluator.get_output()
    evaluator.write_output('Sample_Testcase_Output.json')

Debug: Computed Partition Function Z = 5148
Debug: Computed Partition Function Z = 5148
Debug: Computed Partition Function Z = 5148
Debug: Computed Z = 5148
Debug: Top-3 Assignments = [((0, 0), 0.00777000777000777), ((0, 1), 0.00777000777000777), ((1, 0), 0.00777000777000777)]
