# FOPPL grammar rules adapted to Python
The idea is to use standard Python syntax for control structures, function calls, and assignments, mirroring the 8 grammar constructs of FOPPL (the first-order probabilistic programming language).

1. Variables (v) and constants (c):
In Python, variables can be assigned directly, and constants are just literal values:

```python
v = 42    # variable
c = 3.14  # constant
```
2. Let-bindings (let [v e1] e2):

Python has no `let` keyword, but an assignment followed by the code of e2 has the same structural effect, except that v stays defined after e2.
```python
v = e1
# Now use v in e2
result = e2
```
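
A concrete case, assuming the FOPPL expression `(let [x 2] (* x 3))`: the plain assignment follows the pattern above, while applying a one-argument lambda to e1 is a standard way to keep the binding local to e2, as FOPPL's `let` does. Both forms are illustrative sketches, not output of the translator.

```python
# FOPPL: (let [x 2] (* x 3))

# 1. Sequential assignment: same value, but x remains defined afterwards
x = 2
result_assign = x * 3

# 2. Scoped version: bind v to e1 by applying a lambda whose body is e2
result_scoped = (lambda v: v * 3)(2)

print(result_assign, result_scoped)
```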

3. Conditionals (if e1 e2 e3):

```python
if e1:
    result = e2
else:
    result = e3
```

4. Function calls (f e1 ... en) and (c e1 ... en):

If f is a user-defined function and c is a primitive operation (such as + or *), f is called directly and c becomes the corresponding Python operator:
```python
result = f(e1, e2, ..., en)
result = e1 + e2  # if c is the '+' operation
```

5. Sampling (sample e) and Observing (observe e1 e2):

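Note: `sample` can appear either as the right-hand side of an assignment or as a bare expression, so a `sample` call can also be passed directly as an argument, for example as the first argument (e1) of `observe`.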
```python
x = sample(e)
observe(e1, e2)
```
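
The grammar only fixes where `sample` and `observe` may appear. To make their meaning concrete, here is a minimal sketch under an assumed likelihood-weighting interpretation, which is not how the PyGM translator works (the translator builds a graph instead of running the program): `sample` draws a value, and `observe` adds the log-score of the observed value to a running weight. `WeightedRun` is a hypothetical name, and the sketch only handles frozen univariate `scipy.stats` distributions.

```python
import numpy as np
from scipy.stats import norm, poisson, rv_discrete


class WeightedRun:
    """Hypothetical runtime: sample draws a value, observe accumulates a log-likelihood."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.log_weight = 0.0  # sum of the log-scores of all observations so far

    def sample(self, dist):
        # Draw one value from a frozen univariate scipy.stats distribution
        return dist.rvs(random_state=self.rng)

    def observe(self, dist, value):
        # Discrete distributions are scored with logpmf, continuous ones with logpdf
        score = dist.logpmf if isinstance(dist.dist, rv_discrete) else dist.logpdf
        self.log_weight += float(score(value))
        return value


run = WeightedRun(seed=1)
mu = run.sample(norm(0.0, 10.0))  # x = sample(e)
run.observe(norm(mu, 1.0), 2.1)   # observe(e1, e2) with a continuous likelihood
run.observe(poisson(3.0), 2)      # observe(e1, e2) with a discrete likelihood
print(mu, run.log_weight)
```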

6. Function Definitions (defn f [v1 ... vn] e) q:

```python
def f(v1, v2, ..., vn):
    return e
# Then q follows as normal Python statements
q
```

7. Foreach (foreach c [v1 e1 ... vn en] e1' ... ek'):

```python
for i in range(c):
    v1 = e1[i]
    ...
    vn = en[i]
    # Then execute e1', e2', ..., ek' inside the loop
```
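
In FOPPL, `foreach` is an expression: as we understand its definition, it evaluates the body once per index and returns a vector of the values of the last body expression, which the loop above discards. A minimal sketch with the hypothetical helper `foppl_foreach`, which takes the bindings as a dict of sequences and the body as a function:

```python
def foppl_foreach(c, bindings, body):
    """Sketch of (foreach c [v1 e1 ... vn en] body): one body value per index 0..c-1."""
    results = []
    for i in range(c):
        # Bind each vj to the i-th element of ej, as in vj = ej[i]
        env = {name: seq[i] for name, seq in bindings.items()}
        results.append(body(**env))
    return results


pairs = foppl_foreach(3, {"x": [1, 2, 3], "y": [10, 20, 30]}, lambda x, y: x + y)
print(pairs)  # [11, 22, 33]
```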

8. Loop (loop c e f e1 ... en):

```python
# Assume c is an integer, e is some initial value, f is a function, and
# e1, e2, ..., en are expressions producing values.
a1 = e1
a2 = e2
...
an = en
# Start from the initial value e, then apply f c times, passing the current index
v = e
for i in range(c):
    v = f(i, v, a1, a2, ..., an)
# After c steps, v holds the final result that would appear at the end of the desugared let chain.
```
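
The same desugaring as a function, with hypothetical names `foppl_loop` and `step`. The unrolled chain `v0` to `v4` below is the let chain that `(loop 4 0 step 10)` stands for, so both results should agree:

```python
def foppl_loop(c, e, f, *args):
    """Sketch of (loop c e f e1 ... en): evaluate e1..en once, then thread v through f c times."""
    v = e
    for i in range(c):
        v = f(i, v, *args)
    return v


def step(i, acc, scale):
    # Example body: add scale * i to the accumulator
    return acc + scale * i


looped = foppl_loop(4, 0, step, 10)

# The same computation written as the desugared let chain
v0 = 0
v1 = step(0, v0, 10)
v2 = step(1, v1, 10)
v3 = step(2, v2, 10)
v4 = step(3, v3, 10)

print(looped, v4)  # 60 60
```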

# Examples of models written in PyGM
We will use these models to help us debug as we code our translator. (TODO: Add the functionality of using an input data file)

1. Example 1

```python
def sqrt(x):
    return x ** 0.5

z = sample(bernoulli(0.5))

d = norm(z, 1.0)

obs0 = 0.25*2
observe(d, obs0)

observe(z, 1)

obs2 = sqrt(4)
observe(sample(norm(5, 1.0)), obs2)

z

```

2. GMM

```python
# Given data.
data = [1.1, 2.1, 2.0, 1.9, 0.0, -0.1, -0.05]
likes = []
for _ in range(3):
    mu = sample(norm(0.0, 10.0))
    sigma = sample(gamma(1.0, 1.0))
    append(likes, norm(mu, sigma))
pi = sample(dirichlet([1.0, 1.0, 1.0]))
z_prior = multinomial(1, pi)
for y in data:
    z = sample(z_prior)
    observe(likes[z], y)
```
3. HMM

```python
data = [0.9, 0.8, 0.7, 0.0, -0.025, -5.0, -2.0, -0.1, 0.0, 0.13, 0.45, 6, 0.2, 0.3, -1, -1]

trans_dists = [
    multinomial(1, [0.10, 0.50, 0.40]),
    multinomial(1, [0.20, 0.20, 0.60]),
    multinomial(1, [0.15, 0.15, 0.70])
]

likes = [
    norm(-1.0, 1.0),
    norm(1.0, 1.0),
    norm(0.0, 1.0)
]

states = [sample(multinomial(1, [0.33, 0.33, 0.34]))]

for i in range(16):
    z = sample(trans_dists[states[-1]])
    observe(likes[z], data[i])
    append(states, z)

```

Guide to using PyGM:

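PyGM is a probabilistic programming language written in Python syntax: a program defines a joint distribution over latent variables (introduced by "sample") and observed variables (introduced by "observe"), and the goal is to approximate the posterior distribution of the latent variables conditioned on the observed values.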

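1. Names of the form "zi" or "yi", where i is an integer (for example "z0" or "y12"), are reserved and cannot be used for variables or functions: the translator uses them to name latent (z) and observed (y) vertices.
2. To append to a list, write "append(array, elem)" instead of "array.append(elem)".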
3. "for" loops can be written in two ways:
- "for _ in range():"

  "range()" accepts 1, 2 or 3 arguments, but each must be a constant or the name of a variable bound to a constant.

- "for _ in stored_var:"

  "stored_var" must have been assigned a list before the loop.

In both cases the loop target "\_" can have any name. After the loop, the target stays in the environment, bound to the last value it took.

4. To use a distribution, write it with the following syntax:

- Normal : "norm(mu, sigma)"
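- Gamma : "gamma(alpha, theta)" (note: the arguments are passed positionally to scipy.stats.gamma(a, loc, scale), so a second argument is read as loc, not as the scale theta; write "gamma(alpha, 0, theta)" to set the scale)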
- Dirichlet : "dirichlet([elem1, elem2, ..., elemk])"
- Beta : "beta(alpha, beta_arg)"
- Binomial : "binomial(n, p)"
- Discrete : "multinomial(1, [prob1, prob2, ..., probk])"
- Multinomial : "multinomial(n, [prob1, prob2, ..., probk])"
- Poisson : "poisson(lambda)"
- Bernoulli : "bernoulli(p)"

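5. The following programming constructs are not implemented yet:

- "if" statements
- function definitions
- calls to any function other than the distributions above, "sample", "observe", "append" and "range"; arithmetic is limited to `+`, `-`, `*`, `/`, `**`, `%` and unary minus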
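
Rules 1, 2 and 5 can be checked before translation. A minimal sketch of such a check using Python's `ast` module; `check_pygm_source` is a hypothetical helper, not part of the translator, and it covers only these rules.

```python
import ast
import re

# Names of the form z<digits> / y<digits> are reserved for graph vertices
RESERVED_NAME = re.compile(r"[zy]\d+")


def check_pygm_source(source):
    """Hypothetical linter: report guide violations found in a PyGM program."""
    problems = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Name) and RESERVED_NAME.fullmatch(node.id):
            problems.append(f"line {node.lineno}: reserved name {node.id!r}")
        elif isinstance(node, ast.FunctionDef):
            problems.append(f"line {node.lineno}: function definitions are not supported")
        elif isinstance(node, ast.If):
            problems.append(f"line {node.lineno}: if statements are not supported")
        elif (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
              and node.func.attr == "append"):
            problems.append(f"line {node.lineno}: write append(array, elem) instead of array.append(elem)")
    return problems


program = """
z1 = sample(norm(0.0, 1.0))
xs = []
xs.append(z1)
if z1 > 0:
    observe(norm(z1, 1.0), 0.5)
"""
for problem in check_pygm_source(program):
    print(problem)
```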


# Import source code from Drive

In [None]:
import ast, copy
import os
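from scipy.stats import norm, gamma, poisson, dirichlet, beta, binom, bernoulli, multinomial
binomial = binom  # PyGM programs write "binomial(n, p)"; scipy.stats names this distribution binom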
import numpy as np
from google.colab import drive
from dataclasses import dataclass, field
from typing import Set, List, Tuple, Dict, Any
import re
from copy import deepcopy
import math
import networkx as nx
import matplotlib.pyplot as plt

#drive.mount('/content/drive')
#shared_folder_path = '/content/drive/My Drive/Source_code_examples'

# Translator

In [None]:
class Graph:
    def __init__(self, v, a, p, y):
        """
        v is the set of vertices : {x, y}
        a is the set of arcs (directed, from x to y here) : {(x, y)}
        p is a dictionary from vertex to mass/density function : {x : distribution_1, y : distribution_2}
        y is a dictionary from vertex to its observed value : {y : observed_value}
        """
        self.v = v
        self.a = a
        self.p = p
        self.y = y

    def __str__(self):
        return f"Graph(\nv={self.v},\na={self.a},\np={self.p_to_string()},\ny={self.y})"

    def p_to_string(self):
        # One "vertex: scoring AST" entry per line; an empty p prints as "{}"
        entries = [f"{key}: {value}" for key, value in self.p.items()]
        return "{" + "\n".join(entries) + "}"

    def print_p(self):
      for key, value in self.p.items():
        print(key)
        print(ast.dump(value, indent=4))

    def print_y(self):
      for key, value in self.y.items():
        print(key)
        print(ast.dump(value, indent=4))


def create_empty_graph():
    graph = Graph(set(), set(), {}, {})
    return graph

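def combine_disjoint_graphs(g1, g2):
    """
    Merge two graphs (or lists of graphs) into one by taking the union of their
    vertices and arcs and merging their p and y dictionaries. Vertices are not
    renamed: the translator already gives every vertex a unique name
    (z0, z1, ..., y0, y1, ...), and if a key appears in both graphs, g2 wins.
    """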
    # If g1 or g2 is a list of graph
    if isinstance(g1, list):
        new_g1 = create_empty_graph()
        for graph in g1:
            new_g1 = combine_disjoint_graphs(new_g1, graph)
        g1 = new_g1
    if isinstance(g2, list):
        new_g2 = create_empty_graph()
        for graph in g2:
            new_g2 = combine_disjoint_graphs(new_g2, graph)
        g2 = new_g2

    combined_graph = Graph(g1.v.union(g2.v), g1.a.union(g2.a), {**g1.p, **g2.p}, {**g1.y, **g2.y})
    return combined_graph

def visualize_graph_variables(graph):
    print(str(graph))
    return

def copy_graph(graph):
    # For list of graphs
    if isinstance(graph, list):
        new_graph_list = []
        for g in graph:
            new_graph_list.append(copy_graph(g))
        return new_graph_list

    new_graph = Graph(graph.v.copy(), graph.a.copy(), graph.p.copy(), graph.y.copy())
    return new_graph

# Use this function to see the vertices and the edges of a graph
def show_graph(graph):
    # Your vertex and edge lists
    v=graph.v
    a=graph.a

    # Create a graph
    G = nx.DiGraph()
    G.add_nodes_from(v)
    G.add_edges_from(a)

    # Assign colors based on node prefix
    color_map = []
    for node in G.nodes:
        if node.startswith('z'):
            color_map.append('lightblue')  # Latent vertices (z...) get light blue
        elif node.startswith('y'):
            color_map.append('lightgreen')  # Observed vertices (y...) get light green
        else:
            color_map.append('lightgray')  # Keeps color_map aligned with G.nodes

    # Draw the graph
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)  # Position nodes for visualization
    nx.draw(
        G, pos, with_labels=True, node_size=1500, node_color=color_map,
        font_size=25, font_weight="bold", arrowsize=20
    )

    # Display the plot
    plt.title("Graph Visualization of the source code")
    plt.show()

# Function to evaluate ast.node that is deterministic :
def partial_evaluate(node):
    ast_to_eval = ast.Assign(targets=[ast.Name(id='result', ctx=ast.Store())], value=node)
    ast.fix_missing_locations(ast_to_eval)
    module = ast.Module(body=[ast_to_eval], type_ignores=[])
    ast.fix_missing_locations(module)
    compiled_code = compile(module, filename="<ast>", mode="exec")

    # Prepare a namespace to store the execution result
    namespace = {}
    exec(compiled_code, namespace)

    # Access the assigned value of "result"
    result_value = namespace.get("result")
    return result_value
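
`partial_evaluate` wraps the node in an assignment to `result`, executes a module, and reads `result` back. Since only an expression is evaluated, a shorter variant can compile the node in `"eval"` mode; emptying `__builtins__` also keeps a stray name from reaching Python built-ins. This is an alternative sketch with the hypothetical name `partial_evaluate_expr`, not a tested drop-in replacement for the notebook's function.

```python
import ast


def partial_evaluate_expr(node):
    """Evaluate a deterministic AST expression directly in 'eval' mode, without builtins."""
    expression = ast.fix_missing_locations(ast.Expression(body=node))
    return eval(compile(expression, "<ast>", "eval"), {"__builtins__": {}})


# y0 = 0.25*2 from Example 1, as the translator's BinOp branch would see it
node = ast.BinOp(left=ast.Constant(value=0.25), op=ast.Mult(), right=ast.Constant(value=2))
print(partial_evaluate_expr(node))  # 0.5
```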

In [None]:
# Prototype 3 ; recursive translation
class PPL_translator():
    def __init__(self):
        self.phi = True
        self.rho = {}  # map from distribution names to their scipy scoring method ("logpdf" or "logpmf")
        self.env = {}  # map from variable names to variable values
        self.load_rho()
        self.lv_counter = 0  # to keep track of latent variables indexing
        self.ov_counter = 0  # to keep track of observed variables indexing
        self.free_vars = []

    def load_rho(self):
        #Defined functions
        self.rho["norm"] = "logpdf"
        self.rho["gamma"] = "logpdf"
        self.rho["dirichlet"] = "logpdf"
        self.rho["beta"] = "logpdf"
        self.rho["binomial"] = "logpmf"
        self.rho["multinomial"] = "logpmf"
        self.rho["poisson"] = "logpmf"
        self.rho["bernoulli"] = "logpmf"


    def score(self, node, vertex_added):
        if isinstance(node, ast.Call) and node.func.id in self.rho:
            p_zi = "p" + vertex_added
            dist_type = self.rho[node.func.id]
            return ast.Assign(targets=[ast.Name(id=p_zi)], value=ast.Attribute(value=node, attr=dist_type, ctx=ast.Load()))
        elif isinstance(node, ast.Name):
            if bool(re.fullmatch(r'z\d+', node.id)):  # Check if the name refers to a vertex
                p_zi = "p" + vertex_added
                return ast.Assign(targets=[ast.Name(id=p_zi)], value=node)

            name_graph, name_value = self.process(node)  # After that, name_value is supposed to be an ast.Call
            p_zi = "p" + vertex_added
            dist_type = self.rho[name_value.func.id]
            return ast.Assign(targets=[ast.Name(id=p_zi)], value=ast.Attribute(value=name_value, attr=dist_type, ctx=ast.Load()))

        elif isinstance(node, ast.Subscript):
            p_zi = "p" + vertex_added
            distr_list = node.value
            new_distr_elts = []
            for distr in distr_list.elts:
                assert isinstance(distr, ast.Call), "Element in subscript is not a Call"
                dist_type = self.rho[distr.func.id]
                new_distr_elts.append(ast.Attribute(value=distr, attr=dist_type, ctx=ast.Load()))
            return ast.Assign(targets=[ast.Name(id=p_zi)], value=ast.Subscript(value=ast.List(elts=new_distr_elts), slice=node.slice, ctx=ast.Load()))
        else:
            raise ValueError(f"invalid argument for sample: {type(node).__name__}")

    def search_free_vars(self, node):
        if isinstance(node, ast.Name):
            if bool(re.fullmatch(r'z\d+', node.id)):
                  self.free_vars.append(node.id)

        for child in ast.iter_child_nodes(node):
            self.search_free_vars(child)



    def process(self, node):
        """
        Recursively process an AST node and return the graph fragment corresponding to it.
        """
        if isinstance(node, ast.Module):
            # This is the root
            length = len(node.body)
            assert length > 0, "By convention, program to translate has to have at least 1 line of code"
            # assert isinstance(node.body[length-1], ast.Expr), "By convention, last line of the program has to be an ast.Expr node indicating the return value"
            final_graph = create_empty_graph()

            for i in range(length):
                statement = node.body[i]
                graph, value = self.process(statement)
                final_graph = combine_disjoint_graphs(final_graph, graph)
            # return the graph and value of the last statement aka the ast.Expr
            # node that indicates the return value
            return final_graph, value

        elif isinstance(node, ast.Assign):
            # Handle assignment statements
            target = node.targets[0].id  # Assuming single variable on the left
            graph, expression = self.process(node.value)  # Process the right-hand side
            self.env[target] = (graph, expression)  # Save variable for future use
            return graph, ast.Assign(targets=[ast.Name(id=target)], value=expression)

        elif isinstance(node, ast.Call):
            # Handle function calls
            func_name = node.func.id
            if func_name == "sample":
                # Special handling for `sample`
                dist_graph, dist_value = self.process(node.args[0])  # Process the argument
                num_v = self.lv_counter
                self.lv_counter += 1

                vertex_added = "z" + str(num_v)
                dist_graph.v.add(vertex_added)

                self.free_vars = []
                self.search_free_vars(dist_value)

                for var in self.free_vars:
                    dist_graph.a.add(tuple([var, vertex_added]))

                dist_graph.p[vertex_added] = self.score(dist_value, vertex_added)
                return dist_graph, ast.Name(id=vertex_added)

            elif func_name == "observe":
                # Special handling for `observe`
                dist_graph, dist_value = self.process(node.args[0])  # Process the distribution
                obs_graph, obs_value = self.process(node.args[1])  # Process the observed value

                if isinstance(obs_graph, list):
                    for graph in obs_graph:
                        if graph.v != set():
                            raise ValueError("observed value is not deterministic")
                else:
                    if obs_graph.v != set():
                        raise ValueError("observed value is not deterministic")

                combined_graph = combine_disjoint_graphs(dist_graph, obs_graph)
                num_v = self.ov_counter
                self.ov_counter += 1
                vertex_added = "y" + str(num_v)

                score_func = self.score(dist_value, vertex_added)
                self.free_vars = []
                self.search_free_vars(score_func)

                for var in self.free_vars:
                    combined_graph.a.add(tuple([var, vertex_added]))

                combined_graph.v.add(vertex_added)
                combined_graph.p[vertex_added] = score_func
                combined_graph.y[vertex_added] = obs_value

                return combined_graph, obs_value

            elif func_name == "append":
                element = node.args[1]
                element_graph, element_value = self.process(element)
                array = self.env[node.args[0].id][1]
                graph_array = self.env[node.args[0].id][0]
                array.elts.append(element_value)
                graph_array.append(element_graph)
                return element_graph, None

            elif func_name == "range":
                # Assume deterministic range for now
                # Also assume range args are ast.Constant nodes (can change later if needed)
                args = []
                for arg in node.args:
                    arg_graph, arg_value = self.process(arg)  # arg_graph should be empty, and arg_value should be ast.Constant even if arg is ast.Name
                    args.append(arg_value)
                if len(node.args) == 1:
                    range_return = list(range(args[0].value))
                    return create_empty_graph(), ast.List(elts=[ast.Constant(value=x) for x in range_return], ctx=ast.Load())
                elif len(node.args) == 2:
                    range_return = list(range(args[0].value, args[1].value))
                    return create_empty_graph(), ast.List(elts=[ast.Constant(value=x) for x in range_return], ctx=ast.Load())
                elif len(node.args) == 3:
                    range_return = list(range(args[0].value, args[1].value, args[2].value))
                    return create_empty_graph(), ast.List(elts=[ast.Constant(value=x) for x in range_return], ctx=ast.Load())
                else:
                    raise ValueError("Invalid number of arguments for range")

            else:
                # Generic function call handling
                graph = create_empty_graph()
                args_list = []
                if node.func.id in self.rho:
                    for arg in node.args:
                        arg_graph, arg_value = self.process(arg)
                        graph = combine_disjoint_graphs(graph, arg_graph)
                        args_list.append(arg_value)
                    name = node.func.id
                    value = ast.Call(func=ast.Name(id=name, ctx=ast.Load()), args=args_list, keywords=[])
                else:
                    raise ValueError(f"Unknown function: {node.func.id}")

                return graph, value

        elif isinstance(node, ast.Expr):
            expr_graph, expr_value = self.process(node.value)
            return expr_graph, ast.Expr(value=expr_value)

        elif isinstance(node, ast.Subscript):
            index_graph, index_value = self.process(node.slice)
            value_graph, value_value = self.process(node.value)
            # Evaluate when index is deterministic :
            if index_graph.v == set():
                if isinstance(index_value, ast.Constant):
                    index = partial_evaluate(index_value)
                    elem_value = value_value.elts[index]  # Assuming value_value is ast.List
                    elem_graph = value_graph[index]
                    return elem_graph, elem_value

            final_graph = combine_disjoint_graphs(value_graph, index_graph)
            return final_graph, ast.Subscript(value=value_value, slice=index_value, ctx=ast.Load())

        elif isinstance(node, ast.List):
            l = []
            # lgraph = create_empty_graph()
            lgraph = []
            for elem in node.elts:
                elem_graph, elem_value = self.process(elem)
                # lgraph = combine_disjoint_graphs(lgraph, elem_graph)
                lgraph.append(elem_graph)
                l.append(elem_value)
            return lgraph, ast.List(elts=l, ctx=ast.Load())  # ast.List returns a list of graphs

        elif isinstance(node, ast.Name):
            # Variable reference (return a minimal graph with this variable)
            if node.id in self.env:
                if isinstance(self.env[node.id][1], ast.List):  # Copy lists, so that a reference inside a loop sees the array as it is now rather than after later append calls
                    graph, value = self.env[node.id]
                    return copy_graph(graph), copy.deepcopy(value)
                else:
                    graph, value = self.env[node.id]
            else:
                raise ValueError(f"Unknown variable: {node.id}")
            return graph, value


        elif isinstance(node, ast.For):
            for_var = node.target.id
            iter_graph, iter_value = self.process(node.iter)  # We will assume a deterministic iter for now
            for_graph = create_empty_graph()
            for elem in iter_value.elts:
                self.env[for_var] = (iter_graph, elem)
                for statement in node.body:
                    statement_graph, statement_value = self.process(statement)
                    for_graph = combine_disjoint_graphs(for_graph, statement_graph)
            return for_graph, ast.For(target=node.target, iter=node.iter, body=node.body)  # Not sure about the expression to return here


        elif isinstance(node, ast.Constant):
            # Literal value (return an empty graph)
            return create_empty_graph(), ast.Constant(value=node.value)

        elif isinstance(node, ast.BinOp):
            left_graph, left_value = self.process(node.left)
            right_graph, right_value = self.process(node.right)
            op_graph, op_value = self.process(node.op)

            # Partial evaluation
            if left_graph.v == set() and right_graph.v == set():
                result = partial_evaluate(ast.BinOp(left_value, op_value, right_value))
                return create_empty_graph(), ast.Constant(value=result)

            return combine_disjoint_graphs(left_graph, right_graph), ast.BinOp(left_value, op_value, right_value)

        elif isinstance(node, ast.UnaryOp):
            operand_graph, operand_value = self.process(node.operand)

            # Partial evaluation
            if operand_graph.v == set():
                result = partial_evaluate(ast.UnaryOp(node.op, operand_value))
                return create_empty_graph(), ast.Constant(value=result)
            op_graph, op_value = self.process(node.op)

            return combine_disjoint_graphs(operand_graph, op_graph), ast.UnaryOp(op_value, operand_value)

        elif isinstance(node, ast.Add):
            return create_empty_graph(), ast.Add()

        elif isinstance(node, ast.Sub):
            return create_empty_graph(), ast.Sub()

        elif isinstance(node, ast.Mult):
            return create_empty_graph(), ast.Mult()

        elif isinstance(node, ast.Div):
            return create_empty_graph(), ast.Div()

        elif isinstance(node, ast.Pow):
            return create_empty_graph(), ast.Pow()

        elif isinstance(node, ast.Mod):
            return create_empty_graph(), ast.Mod()

        elif isinstance(node, ast.USub):
            return create_empty_graph(), ast.USub()

        else:
            # Any other node type is outside the supported PyGM subset
            raise ValueError(f"Unhandled node type: {type(node).__name__}")
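
The `sample` and `observe` branches of `process` share one rule for arcs: each latent vertex `zN` that occurs in the translated distribution expression becomes a parent of the new vertex (this is what `search_free_vars` collects). A standalone sketch of that rule with `ast.walk`; `parents_of` and `arcs_for` are hypothetical helpers, and the example expression is a two-component version of the form `likes[z]` takes after translation (a list of components indexed by a latent assignment).

```python
import ast
import re

VERTEX = re.compile(r"z\d+")


def parents_of(expr_source):
    """Latent vertices zN that occur in a translated distribution expression."""
    tree = ast.parse(expr_source, mode="eval")
    return sorted({n.id for n in ast.walk(tree)
                   if isinstance(n, ast.Name) and VERTEX.fullmatch(n.id)})


def arcs_for(expr_source, new_vertex):
    """Directed arcs (parent, child) added for a new sample/observe vertex."""
    return {(parent, new_vertex) for parent in parents_of(expr_source)}


print(sorted(arcs_for("[norm(z0, z1), norm(z2, z3)][z5]", "y0")))
```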

In [None]:
# Example function to parse and translate Python source code
def parse_and_translate(source_code: str) -> Tuple[Graph, Any, Dict[str, Any]]:
    """
    Parse Python source code and translate it into a Graph and deterministic expression.

    Parameters:
        source_code (str): The Python source code to translate.

    Returns:
        Tuple[Graph, Any, Dict[str, Any]]: The combined graphical model, the final
        deterministic expression, and the translator's variable environment.
    """
    tree = ast.parse(source_code)


    ppl_t = PPL_translator()
    combined_graph, expr = ppl_t.process(tree)
    return combined_graph, expr, ppl_t.env
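
The `score` method turns a distribution call into the attribute `<call>.logpdf` or `<call>.logpmf`, and `evaluate_distribution` later compiles and calls it, so every argument reaches the matching `scipy.stats` object positionally. A standalone sketch of that step (hypothetical helper `score_function`) shows what this means for Gamma: `scipy.stats.gamma(a, loc, scale)` reads a second positional argument as `loc`, so `gamma(1.0, 1.0)` is a unit-rate exponential shifted by 1, with log-density about -3.0 at 4.0, while `gamma(1.0, 0.0, 1.0)` gives about -4.0. Passing the scale third should also go through the translator, whose generic call branch forwards any number of positional arguments (keyword arguments are dropped there).

```python
import ast
from scipy.stats import gamma


def score_function(call_source, method):
    """Build <call>.<method> as an AST, as score() does, and evaluate it like evaluate_distribution()."""
    call = ast.parse(call_source, mode="eval").body
    scored = ast.Attribute(value=call, attr=method, ctx=ast.Load())
    expression = ast.fix_missing_locations(ast.Expression(body=scored))
    return eval(compile(expression, "<ast>", "eval"), {"gamma": gamma})


as_written = score_function("gamma(1.0, 1.0)", "logpdf")       # a=1.0, loc=1.0, scale=1.0
with_scale = score_function("gamma(1.0, 0.0, 1.0)", "logpdf")  # a=1.0, loc=0.0, scale=1.0
print(as_written(4.0), with_scale(4.0))
```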

# Evaluator

In [None]:
def is_one_hot_list(elts):
    # True if elts are all ast.Constant nodes holding 0 or 1, with exactly one 1
    values = [e.value for e in elts if isinstance(e, ast.Constant)]
    return len(values) == len(elts) and sum(values) == 1 and all(v in [0,1] for v in values)

def one_hot_index_list(elts):
    values = [e.value for e in elts if isinstance(e, ast.Constant)]
    return values.index(1)

class ZYReplacer(ast.NodeTransformer):
    def __init__(self, z, y):
        super().__init__()
        self.z = z
        self.y = y

    def visit_Name(self, node):
        # For z_i
        if node.id.startswith('z') and node.id[1:].isdigit():
            idx = int(node.id[1:])
            val = self.z[idx]
            # If val is a scalar, just return a Constant
            if isinstance(val, (int, float)):
                return ast.Constant(value=val)
            # If val is a list, return an ast.List of ast.Constant
            elif isinstance(val, list):
                return ast.List(
                    elts=[ast.Constant(value=v) for v in val],
                    ctx=ast.Load()
                )
            else:
                # Handle other types as needed
                return ast.Constant(value=val)

        # For y_i
        elif node.id.startswith('y') and node.id[1:].isdigit():
            return self.y[node.id]  # already an ast.Constant

        return node

class OneHotIndexTransformer(ast.NodeTransformer):
    def visit_Subscript(self, node):
        self.generic_visit(node)
        # Now check if the slice is a List node
        if isinstance(node.slice, ast.List):
            # Check if it's a one-hot list
            if is_one_hot_list(node.slice.elts):
                idx = one_hot_index_list(node.slice.elts)
                node.slice = ast.Constant(value=idx)
            else:
                raise ValueError("Non one-hot vector used as index.")
            if isinstance(node.value, ast.List) and not hasattr(node.value, 'ctx'):
                node.value.ctx = ast.Load()
        return node


def evaluate_distribution(expr, val):
    code_ast = ast.fix_missing_locations(ast.Expression(expr))
    code = compile(code_ast, "<ast>", "eval")
    dist_func = eval(code)
    return dist_func(val)

def process_and_evaluate(p, y, z):
    p_copy = deepcopy(p)
    replacer = ZYReplacer(z, y)
    for k in p_copy:
        p_copy[k] = replacer.visit(p_copy[k])
    indexer = OneHotIndexTransformer()
    for k in p_copy:
        p_copy[k] = indexer.visit(p_copy[k])

    results = []
    for k, assign_node in p_copy.items():
        var_name = assign_node.targets[0].id
        dist_expr = assign_node.value
        def func(val):
            return evaluate_distribution(dist_expr, val)

        if var_name.startswith('pz'):
            i = int(var_name[2:])
            results.append(func(z[i]))
        elif var_name.startswith('py'):
            i = int(var_name[2:])
            results.append(func(y[f"y{i}"].value))

    logprob = sum(results)
    return logprob, p_copy
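
What `process_and_evaluate` computes is the joint log-probability of one full assignment: after substituting latent and observed values, each `pzN` entry scores zN under its prior, each `pyN` entry scores the observation, and the results are summed. A hand-written sketch for a hypothetical two-component model with one latent mean (z0), one one-hot assignment (z1) and one observation (y0), using the same one-hot indexing rule as `OneHotIndexTransformer`:

```python
import math
from scipy.stats import multinomial, norm


def one_hot_index(vector):
    """Position of the single 1 in a one-hot list (the rule OneHotIndexTransformer applies)."""
    if sorted(vector) != [0] * (len(vector) - 1) + [1]:
        raise ValueError("Non one-hot vector used as index.")
    return vector.index(1)


# Hypothetical model, in PyGM syntax:
#   mu = sample(norm(0.0, 10.0))                      -> z0
#   z = sample(multinomial(1, [0.5, 0.5]))            -> z1
#   observe([norm(0.0, 1.0), norm(mu, 1.0)][z], 1.9)  -> y0
z = [2.0, [0, 1]]  # latent values, position N holds zN
y = {"y0": 1.9}    # observed values keyed by vertex name

log_terms = {
    "pz0": norm(0.0, 10.0).logpdf(z[0]),
    "pz1": multinomial(1, [0.5, 0.5]).logpmf(z[1]),
    "py0": [norm(0.0, 1.0), norm(z[0], 1.0)][one_hot_index(z[1])].logpdf(y["y0"]),
}
logprob = sum(log_terms.values())
print(logprob, math.exp(logprob))
```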

# Examples

In [None]:
gmm_source_code = """
data = [1.1, 2.1, 2.0, 1.9, 0.0, -0.1, -0.05]
likes = []
for _ in range(3):
    mu = sample(norm(0.0, 10.0))
    sigma = sample(gamma(1.0, 1.0))
    append(likes, norm(mu, sigma))
pi = sample(dirichlet([1.0, 1.0, 1.0]))
z_prior = multinomial(1, pi)
for y in data:
    z = sample(z_prior)
    observe(likes[z], y)
"""

gmm_source_code_2 = """
data = [1.1, 2.1]
likes = []
for _ in range(2):
    mu = sample(norm(0.0, 10.0))
    sigma = sample(gamma(1.0, 1.0))
    append(likes, norm(mu, sigma))
pi = sample(dirichlet([1.0, 1.0]))
z_prior = multinomial(1, pi)
for y in data:
    z = sample(z_prior)
    observe(likes[z], y)
"""

hmm_source_code = """
data = [0.9, 0.8, 0.7, 0.0, -0.025, -5.0, -2.0, -0.1, 0.0, 0.13, 0.45, 6, 0.2, 0.3, -1, -1]

trans_dists = [
    multinomial(1, [0.10, 0.50, 0.40]),
    multinomial(1, [0.20, 0.20, 0.60]),
    multinomial(1, [0.15, 0.15, 0.70])
]

likes = [
    norm(-1.0, 1.0),
    norm(1.0, 1.0),
    norm(0.0, 1.0)
]

states = [sample(multinomial(1, [0.33, 0.33, 0.34]))]

for i in range(16):
    z = sample(trans_dists[states[-1]])
    observe(likes[z], data[i])
    append(states, z)
"""

hmm_source_code_2 = """
data = [0.9, 0.8, 0.7, 0.0]

trans_dists = [
    multinomial(1, [0.10, 0.50, 0.40]),
    multinomial(1, [0.20, 0.20, 0.60]),
    multinomial(1, [0.15, 0.15, 0.70])
]

likes = [
    norm(-1.0, 1.0),
    norm(1.0, 1.0),
    norm(0.0, 1.0)
]

states = [sample(multinomial(1, [0.33, 0.33, 0.34]))]

for y in data:
    z = sample(trans_dists[states[-1]])
    observe(likes[z], y)
    append(states, z)
"""

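`process_and_evaluate` reads the value of `zN` from position N of the latent list, and the translator numbers latent vertices in the order the `sample` calls are executed. In `gmm_source_code`, each iteration of the first loop samples `mu` and then `sigma`, so the vector interleaves them (mu_1, sigma_1, mu_2, ...) before `pi` and the seven assignments. A hypothetical helper that builds the vector in that order:

```python
def gmm_latent_vector(mus, sigmas, pi, assignments):
    """Hypothetical helper: order GMM latent values the way the translator numbers them."""
    latent = []
    for mu, sigma in zip(mus, sigmas):
        # One loop iteration samples mu then sigma, so they get consecutive indices
        latent.extend([mu, sigma])
    latent.append(list(pi))                      # mixing weights, sampled after the loop
    latent.extend(list(a) for a in assignments)  # one one-hot assignment per data point
    return latent


latent = gmm_latent_vector([1, 3, 5], [2, 4, 6], [0.2, 0.3, 0.5],
                           [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0],
                            [0, 1, 0], [0, 0, 1], [1, 0, 0]])
print(latent[:7])  # [1, 2, 3, 4, 5, 6, [0.2, 0.3, 0.5]]
```

With these arguments it yields the same list as the GMM1 cell.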
GMM1

In [None]:
latent_variables = [
    1,  # mu_1 (z0)
    2,  # sigma_1 (z1)
    3,  # mu_2 (z2)
    4,  # sigma_2 (z3)
    5,  # mu_3 (z4)
    6,  # sigma_3 (z5)
    [0.2, 0.3, 0.5],  # pi (mixing weights as a single vector)
    [1, 0, 0],  # z_1 (one-hot encoding for cluster assignment)
    [0, 1, 0],  # z_2
    [0, 0, 1],  # z_3
    [1, 0, 0],  # z_4
    [0, 1, 0],  # z_5
    [0, 0, 1],  # z_6
    [1, 0, 0]   # z_7
]
source_code = gmm_source_code

graph, expr, vars = parse_and_translate(source_code)
# print(graph)
# print("\n")
# graph.print_p()
# print("\n")
# graph.print_y()
# print("\n")
# print("Variables:\n", vars)

logprob, _ = process_and_evaluate(graph.p, graph.y, latent_variables)
print("Sum of all logprob:", logprob)
print("Product of all prob", math.exp(logprob))

Sum of all logprob: -42.67022133991871
Product of all prob 2.941428733159413e-19


Product of all evaluations: 2.9414287331594206e-19

GMM2

In [None]:
latent_variables = [
    1,  # mu_1 (z0)
    2,  # sigma_1 (z1)
    3,  # mu_2 (z2)
    4,  # sigma_2 (z3)
    [0.2, 0.8],  # pi (mixing weights as a single vector)
    [1, 0],  # z_1 (one-hot encoding for cluster assignment)
    [0, 1],  # z_2
]
source_code = gmm_source_code_2

graph, expr, vars = parse_and_translate(source_code)

# print(graph)
# print("\n")
# graph.print_p()
# print("\n")
# graph.print_y()
# print("\n")
# print("Variables:\n", vars)
logprob, _ = process_and_evaluate(graph.p, graph.y, latent_variables)
print("Sum of all logprob:", logprob)
print("Product of all prob", math.exp(logprob))

Sum of all logprob: -16.269509824234927
Product of all prob 8.594916464592117e-08


Product of all evaluations: 8.594916464592108e-08

HMM1

In [None]:
latent_variables = [
    [1, 0, 0],  # states[0] (initial state, one-hot encoded)
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0]
]
source_code = hmm_source_code
graph, expr, vars = parse_and_translate(source_code)
logprob, _ = process_and_evaluate(graph.p, graph.y, latent_variables)
print("Sum of all logprob:", logprob)
print("Product of all prob", math.exp(logprob))

Sum of all logprob: -66.5625237305059
Product of all prob 1.2366968149519487e-29


Product of all evaluations: 1.2366968149519645e-29

HMM2

In [None]:
latent_variables = [
    [1, 0, 0],  # states[0] (initial state, one-hot encoded)
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0]
]

source_code = hmm_source_code_2

graph, expr, vars = parse_and_translate(source_code)

logprob, _ = process_and_evaluate(graph.p, graph.y, latent_variables)
print("Sum of all logprob:", logprob)
print("Product of all prob", math.exp(logprob))

Sum of all logprob: -10.848656727112065
Product of all prob 1.943069108809343e-05


Product of all evaluations: 1.9430691088093445e-05

# Inference Engine


In [None]:
import ast
import re

class DualTreeTransformer:
    """Fill the zN placeholders of a partially evaluated tree (tree2) with the
    constants found at the same position in a fully evaluated tree (tree1)."""
    def __init__(self, tree1, tree2):
        self.tree1 = tree1  # Fully evaluated
        self.tree2 = tree2  # Partially evaluated (modified in place)

    def transform(self):
        # Start transforming tree2 using tree1 as a reference
        return self._traverse_and_transform(self.tree1, self.tree2)

    def _traverse_and_transform(self, node1, node2):
        if isinstance(node2, ast.Name):
            # Replace a zN placeholder when the reference holds a constant there
            if re.fullmatch(r'z\d+', node2.id) and isinstance(node1, ast.Constant):
                return ast.Constant(value=node1.value)

        # Walk the children of node2, pairing each one with the child of node1
        # in the same field and position (None when node1 has no such child)
        for field, value in ast.iter_fields(node2):
            ref_value = getattr(node1, field, None)
            if isinstance(value, list):
                ref_items = ref_value if isinstance(ref_value, list) else []
                new_values = []
                for i, item in enumerate(value):
                    if isinstance(item, ast.AST):
                        # Recursively transform child nodes in both trees
                        ref_item = ref_items[i] if i < len(ref_items) else None
                        new_values.append(self._traverse_and_transform(ref_item, item))
                    else:
                        new_values.append(item)
                setattr(node2, field, new_values)
            elif isinstance(value, ast.AST):
                # Recursively transform single child nodes
                setattr(node2, field, self._traverse_and_transform(ref_value, value))
        return node2


def combine_ps(p_all_eval, p_partially_eval):
    """Fill the zN placeholders of every Assign in p_partially_eval using the fully
    evaluated Assign with the same key in p_all_eval (same p structure assumed)."""
    for key, assign_eval in p_all_eval.items():
        assign_p_eval = p_partially_eval[key]
        transformer = DualTreeTransformer(assign_eval.value, assign_p_eval.value)
        # Keep the Assign node (later code reads .targets and .value); only its value changes
        assign_p_eval.value = transformer.transform()
    return p_partially_eval
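`DualTreeTransformer` relies on both trees having the same shape, so that each child of `tree2` can be paired with the child of `tree1` in the same field and position. The standalone sketch below shows that lockstep walk on two hypothetical `scipy.stats`-style expressions parsed from strings; `fill_placeholders` is an illustrative name, not part of the notebook, and nothing is evaluated, so SciPy is not needed (`ast.unparse` requires Python 3.9+).

```python
import ast
import re

PLACEHOLDER = re.compile(r"z\d+")


def fill_placeholders(reference, partial):
    """Hypothetical helper: replace each zN Name in `partial` by the Constant
    found at the same position in `reference` (both trees share one shape)."""
    if isinstance(partial, ast.Name) and PLACEHOLDER.fullmatch(partial.id):
        if isinstance(reference, ast.Constant):
            return ast.Constant(value=reference.value)
        return partial  # nothing concrete to copy at this position
    for field, value in ast.iter_fields(partial):
        ref_value = getattr(reference, field, None)
        if isinstance(value, list):
            refs = ref_value if isinstance(ref_value, list) else []
            new_items = []
            for i, item in enumerate(value):
                ref_item = refs[i] if i < len(refs) else None
                if isinstance(item, ast.AST):
                    item = fill_placeholders(ref_item, item)
                new_items.append(item)
            setattr(partial, field, new_items)
        elif isinstance(value, ast.AST):
            setattr(partial, field, fill_placeholders(ref_value, value))
    return partial


reference = ast.parse("norm(0.5, 1.0).logpdf", mode="eval").body
partial = ast.parse("norm(z0, 1.0).logpdf", mode="eval").body
filled = fill_placeholders(reference, partial)
print(ast.unparse(filled))  # norm(0.5, 1.0).logpdf
```

A placeholder is only filled when the reference holds a `Constant` at that spot, so a reference that still contains a name leaves the placeholder untouched.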

In [None]:
import ast
from copy import deepcopy

import numpy as np

def ast_to_python_value(node):
    """
    Convert an AST node (like Constant or List of Constants) to a Python value.
    Adjust this function depending on the node types you expect.
    """
    if isinstance(node, ast.Constant):
        return node.value
    elif isinstance(node, ast.List):
        return [ast_to_python_value(e) for e in node.elts]
    # If you have other node types to handle, add them here
    # For now, assume we only deal with Constants or Lists of Constants.
    raise ValueError("Unsupported node type for conversion: {}".format(type(node)))

def find_z_values_in_ast(unfilled_node, filled_node, z_values):
    """
    Recursively compare unfilled_node and filled_node.
    Whenever we encounter a Name node referencing 'zX' in unfilled_node
    and a non-Name node in filled_node, record that mapping in z_values.
    """
    # If unfilled_node is a Name referencing zX and filled_node is not a Name,
    # we have found the value of zX.
    if (isinstance(unfilled_node, ast.Name) and unfilled_node.id.startswith('z')
            and unfilled_node.id[1:].isdigit()):
        if not isinstance(filled_node, ast.Name):
            # Extract the integer index from 'zX'
            index = int(unfilled_node.id[1:])
            z_values[index] = filled_node
        return

    # If both are lists, compare element by element
    if isinstance(unfilled_node, list) and isinstance(filled_node, list):
        for u_node, f_node in zip(unfilled_node, filled_node):
            find_z_values_in_ast(u_node, f_node, z_values)
        return

    # If both are AST nodes, recurse into fields
    if isinstance(unfilled_node, ast.AST) and isinstance(filled_node, ast.AST):
        # Recurse into matching fields
        for field in unfilled_node._fields:
            val_unfilled = getattr(unfilled_node, field, None)
            val_filled = getattr(filled_node, field, None)
            if isinstance(val_unfilled, (ast.AST, list)) and isinstance(val_filled, (ast.AST, list)):
                find_z_values_in_ast(val_unfilled, val_filled, z_values)

def extract_z_values(p_filled, p_unfilled):
    """
    Given p_filled and p_unfilled dictionaries:
    - p_unfilled and p_filled have keys like 'z5', 'y0', etc.
    - Each value is an ast.Assign node.
    We want to find all occurrences where p_unfilled references a z_i as a Name,
    and p_filled has the actual value. Then return a list z where z[i] is the value of z_i.
    """
    z_values = {}

    # Compare each key's nodes to find where z_i appear as placeholders in unfilled
    # and actual values in filled.
    for key in p_unfilled.keys():
        unfilled_node = p_unfilled[key]
        filled_node = p_filled[key]
        # unfilled_node and filled_node should be Assign nodes
        # Compare their value fields
        find_z_values_in_ast(unfilled_node.value, filled_node.value, z_values)

    # Now we have a dict {index: ast_node_value} for z_i
    # Convert them to Python objects
    max_index = max(z_values.keys()) if z_values else -1
    z_list = [None]*(max_index+1)
    for i, node in z_values.items():
        z_list[i] = ast_to_python_value(node)

    return z_list

def is_one_hot_list(elts):
    # elts are a list of ast.Constant nodes with either 0 or 1 values
    values = [e.value for e in elts if isinstance(e, ast.Constant)]
    return len(values) == len(elts) and sum(values) == 1 and all(v in [0,1] for v in values)

def one_hot_index_list(elts):
    values = [e.value for e in elts if isinstance(e, ast.Constant)]
    return values.index(1)

class OneHotIndexTransformer(ast.NodeTransformer):
    def visit_Subscript(self, node):
        self.generic_visit(node)
        # Now check if the slice is a List node
        if isinstance(node.slice, ast.List):
            # Check if it's a one-hot list
            if is_one_hot_list(node.slice.elts):
                idx = one_hot_index_list(node.slice.elts)
                node.slice = ast.Constant(value=idx)
            else:
                raise ValueError("Non one-hot vector used as index.")
            if isinstance(node.value, ast.List) and not hasattr(node.value, 'ctx'):
                node.value.ctx = ast.Load()
        return node

def evaluate_distribution(expr, val):
    code_ast = ast.fix_missing_locations(ast.Expression(expr))
    code = compile(code_ast, "<ast>", "eval")
    dist_func = eval(code)
    return dist_func(val)

def eval_joint(p, y, z):

    # print("After Initialization Values:")
    # for k, v in p.items():
    #     print(k, ast.dump(v, indent=4))
    # print("__________________________________________\n")

    results = []
    for k, assign_node in p.items():
        var_name = assign_node.targets[0].id
        dist_expr = assign_node.value
        def func(val):
            return evaluate_distribution(dist_expr, val)

        if var_name.startswith('pz'):
            i = int(var_name[2:])
            results.append(func(z[i]))
        elif var_name.startswith('py'):
            i = int(var_name[2:])
            results.append(func(y[f"y{i}"].value))

    logprob = sum(results)
    return logprob

class YReplacer(ast.NodeTransformer):
    # Replaces the y_i Name nodes in place (ast.NodeTransformer edits the tree it visits)
    def __init__(self, y):
        self.y = y

    def visit_Name(self, node):
        # For y_i
        if node.id.startswith('y') and node.id[1:].isdigit():
            idx = int(node.id[1:])
            val = self.y[idx]
            # If val is a scalar, just return a Constant
            if isinstance(val, (int, float)):
                return ast.Constant(value=val)
            # If val is a list, return an ast.List of ast.Constant
            elif isinstance(val, list):
                return ast.List(
                    elts=[ast.Constant(value=v) for v in val],
                    ctx=ast.Load()
                )
            else:
                print("Suspicious type in Y map")
                return ast.Constant(value=val)

        return node

class ZReplacer(ast.NodeTransformer):
    # Replaces the Name nodes of one z_i in place (ast.NodeTransformer edits the tree it visits)
    def __init__(self, zi_sample):
        self.id = zi_sample[0]   # variable name, e.g. 'z3'
        self.val = zi_sample[1]  # sampled value (scalar or list)

    def visit_Name(self, node):
        # For z_i
        if node.id == self.id:
            # If val is a scalar, just return a Constant
            if isinstance(self.val, (int, float)):
                return ast.Constant(value=self.val)
            # If val is a list, return an ast.List of ast.Constant
            elif isinstance(self.val, list):
                return ast.List(
                    elts=[ast.Constant(value=v) for v in self.val],
                    ctx=ast.Load()
                )
            else:
                print("Suspicious type for a z_i sample")
                return ast.Constant(value=self.val)

        return node

class ZSampler(ast.NodeTransformer):
    def get_distribution_object(self, expr):
        """
        Given an AST expression (like Attribute(Call(norm(...)), 'logpdf')),
        extract the base distribution Call node and evaluate it to get the distribution object.
        """
        base_expr = expr
        if isinstance(expr, ast.Attribute):
            # If attr is 'logpdf', 'logpmf', etc., take expr.value which is the Call node.
            if expr.attr in ('logpdf', 'logpmf', 'pdf', 'pmf'):
                base_expr = expr.value

        # Name nodes built by hand may lack a ctx, which compile() requires
        for node in ast.walk(base_expr):
            if isinstance(node, ast.Name) and not hasattr(node, 'ctx'):
                node.ctx = ast.Load()

        code_ast = ast.fix_missing_locations(ast.Expression(base_expr))
        code = compile(code_ast, "<ast>", "eval")
        dist_obj = eval(code)
        return dist_obj

    def visit_Assign(self, assign_node):
        var_name = assign_node.targets[0].id

        # print(var_name[1:], "being sampled")

        dist_expr = assign_node.value
        dist_obj = self.get_distribution_object(dist_expr)
        sampled_val = dist_obj.rvs()

        if isinstance(sampled_val, np.ndarray):
            if sampled_val.ndim == 2:
                sampled_val = [float(x) for x in sampled_val.flatten().tolist()]
            elif sampled_val.ndim == 1:
                sampled_val = [float(x) for x in sampled_val.tolist()]
            else:
                raise NotImplementedError("Need to implement logic for general tensors as variables")
        else:
            # Convert scalar numpy types to Python float
            sampled_val = float(sampled_val)

        return (var_name[1:], sampled_val)

class BlockMetropolisWithinGibbsSampler:
    def __init__(self, v, p, y, edges):
        """
        v: list of vertex names of the graph (z_i and y_i)
        p: dictionary of ast.Assign nodes defining pz_i and py_i
        y: dictionary of constants for y_i
        edges: list of (a,b) tuples indicating a->b in the graph (a is parent of b)
        """
        self.p = p
        self.y = y
        self.edges = edges
        self.v = v
        self.blocks = []
        self.p_current = {}   # p with every y_i and z_i replaced by its current value
        self.z_current = []   # z_current[i] is the current value of z_i
        self.initialize()     # Ancestral sampling
        self.build_blocks()   # Build blocks according to the rule: {z_i} U children(z_i) U parents(children(z_i))

    @staticmethod
    def _substitute(p, name, value):
        """Replace every reference to `name` in p by `value`, then resolve one-hot indices."""
        replacer = ZReplacer((name, value))
        indexer = OneHotIndexTransformer()
        for k in p:
            p[k] = indexer.visit(replacer.visit(p[k]))

    def initialize(self):
        # Replace the y_i's (the transformers edit self.p in place)
        replacer = YReplacer(self.y)
        indexer = OneHotIndexTransformer()
        for k in self.p:
            self.p[k] = indexer.visit(replacer.visit(self.p[k]))

        # Sample the z_i's in the order of p (assumed topological), substituting
        # each sample everywhere before the next z_i is sampled
        self.p_current = deepcopy(self.p)
        n_z = max((int(k[1:]) for k in self.p if k.startswith('z')), default=-1) + 1
        self.z_current = [None] * n_z
        zsampler = ZSampler()
        for k in self.p_current:
            if k.startswith('z'):
                name, value = zsampler.visit(self.p_current[k])
                self.z_current[int(name[1:])] = value
                self._substitute(self.p_current, name, value)

    def build_blocks(self):
        """
        Partition the z_i into blocks according to the rule: {z_i} U children(z_i) U parents(children(z_i))
        """
        remaining_v = deepcopy(self.v)
        remaining_edges = deepcopy(self.edges)
        remaining_z = [node for node in remaining_v if node.startswith("z")]
        remaining_z.sort(key=lambda x: int(x[1:]))

        while remaining_z:
            # Build adjacency structures for the edges still available
            # (every remaining z_i appears, even without edges)
            children_map = {z: [] for z in remaining_z}
            parents_map = {z: [] for z in remaining_z}
            for (a, b) in remaining_edges:
                children_map.setdefault(a, []).append(b)
                parents_map.setdefault(b, []).append(a)

            z = remaining_z.pop(0)
            block = {z}
            for c in children_map[z]:
                if c in remaining_v:
                    remaining_v.remove(c)
                if c.startswith("z"):
                    block.add(c)
                    if c in remaining_z:
                        remaining_z.remove(c)
                for parent in parents_map[c]:
                    if parent in remaining_v:
                        remaining_v.remove(parent)
                    if parent.startswith("z"):
                        block.add(parent)
                        if parent in remaining_z:
                            remaining_z.remove(parent)
            # Keep only the edges whose endpoints are both still unassigned
            remaining_edges = [
                edge
                for edge in remaining_edges
                if edge[0] in remaining_v and edge[1] in remaining_v
            ]

            self.blocks.append(list(block))  # store as a list

        print(self.blocks)

    def propose_bloc(self, block):
        """
        Propose new values for the z_i in `block` by ancestral sampling, with every
        other z_i fixed at its current value. Returns (p_prop, z_prop, z_current).
        """
        p_prop = deepcopy(self.p)  # the y_i's were already substituted by initialize()

        # Fix the z_i outside the block at their current values; otherwise sampling
        # a block variable whose distribution depends on them fails
        # (e.g. NameError: name 'z4' is not defined)
        for i, value in enumerate(self.z_current):
            name = f"z{i}"
            if name not in block and value is not None:
                self._substitute(p_prop, name, value)

        # Sample the z_i's in the block, substituting each new value as we go
        z_prop = list(self.z_current)
        zsampler = ZSampler()
        for k in p_prop:
            if k in block:
                name, value = zsampler.visit(p_prop[k])
                z_prop[int(name[1:])] = value
                self._substitute(p_prop, name, value)

        return p_prop, z_prop, list(self.z_current)

    def run(self):
        for block in self.blocks:
            p_proposal, z_prop, z_current = self.propose_bloc(block)

            # Log joint density of the proposed and of the current state
            logjoint_prop = eval_joint(p_proposal, self.y, z_prop)
            logjoint_curr = eval_joint(self.p_current, self.y, z_current)

            # Terms intended for the acceptance ratio; the accept/reject step is not implemented yet
            log_reverse_q = eval_joint(p_proposal, self.y, z_current)
            log_forward_q = eval_joint(self.p_current, self.y, z_prop)

            print(logjoint_prop)
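`run()` computes the log terms of the acceptance ratio but does not yet accept or reject the proposal. One way to finish that step is the usual log-space Metropolis-Hastings test, sketched here as a hypothetical helper `mh_accept` that is not part of the notebook:

```python
import math
import random


def mh_accept(logjoint_prop, logjoint_curr, log_forward_q, log_reverse_q, rng=random):
    """Hypothetical helper: Metropolis-Hastings accept/reject in log space.

    Accepts with probability min(1, exp(log_alpha)), where
    log_alpha = (logjoint_prop + log_reverse_q) - (logjoint_curr + log_forward_q).
    """
    log_alpha = (logjoint_prop + log_reverse_q) - (logjoint_curr + log_forward_q)
    if log_alpha >= 0:
        return True
    # 1 - random() lies in (0, 1], so its log is always defined
    return math.log(1.0 - rng.random()) < log_alpha


# Usage sketch: acceptance rate for a fixed log-ratio of log(0.3)
rng = random.Random(0)
trials = 10_000
accepted = sum(mh_accept(math.log(0.3), 0.0, 0.0, 0.0, rng) for _ in range(trials))
print(accepted / trials)  # roughly 0.3
```

Comparing logs avoids exponentiating large negative log joints. On acceptance, the proposal would become the new state (for example, `p_proposal` replacing `p_current`); on rejection the current state is kept.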


In [None]:
source_code = gmm_source_code_2

graph, expr, vars = parse_and_translate(source_code)

# print(graph)
# print("\n")
# graph.print_p()
# print("\n")
# graph.print_y()
# print("\n")
# print("Variables:\n", vars)
sampler = BlockMetropolisWithinGibbsSampler(graph.v, graph.p, graph.y, graph.a)
sampler.run()

[['z1', 'z5', 'z0', 'z2', 'z3', 'z6']]
[['z1', 'z5', 'z0', 'z2', 'z3', 'z6'], ['z4']]
[['z1', 'z5', 'z0', 'z2', 'z3', 'z6']]
[['z1', 'z5', 'z0', 'z2', 'z3', 'z6'], ['z4']]


NameError: name 'z4' is not defined
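The printed blocks come from the rule in `build_blocks`: a block is {z_i} U children(z_i) U parents(children(z_i)), restricted to latent variables. A minimal sketch of that rule for a single seed variable, on a made-up edge list, is below; `latent_block` is a hypothetical helper, and it skips the bookkeeping `build_blocks` does to keep blocks disjoint.

```python
def latent_block(z, edges):
    """Hypothetical helper: {z} U children(z) U parents(children(z)), latent nodes only."""
    children = [b for a, b in edges if a == z]
    block = {z}
    for child in children:
        if child.startswith("z"):
            block.add(child)
        # Co-parents of each child (this includes z itself)
        block.update(a for a, b in edges if b == child and a.startswith("z"))
    return block


# Toy graph: z0 and z1 both generate y0, z0 is a parent of z3, z2 only generates y1
edges = [("z0", "y0"), ("z1", "y0"), ("z0", "z3"), ("z2", "y1")]
print(sorted(latent_block("z0", edges)))  # ['z0', 'z1', 'z3']
print(sorted(latent_block("z2", edges)))  # ['z2']
```

Co-parents such as `z1` end up in the same block because they share an observed child with `z0`, so their values are updated jointly rather than one at a time.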