In [52]:
import yaml
from pomegranate import bayesian_network
from pomegranate.distributions import Normal, Exponential

# Define the function to parse the YAML and build the network
def build_pomegranate_network(yaml_data):
    # Create a dictionary to store the states (variables) by their names
    states = {}

    # Initialize the Bayesian Network
    network = bayesian_network.BayesianNetwork()

    # First pass: Create distributions and add states to the network
    for var in yaml_data['variables']:
        var_name = var['name']
        var_type = var['type']
        distribution_type = var['distribution']
        parameters = var['parameters']

        # Create the distribution based on the type
        if distribution_type == 'Gaussian':
            dist = Normal([parameters['mean']],
                          [parameters['std_dev']],
                          covariance_type='diag')
        elif distribution_type == 'Exponential':
            dist = Exponential([parameters['mean']])
            raise ValueError(f"Unsupported distribution type: {distribution_type}")

        # Create the state for the variable and add it to the network
        state = State(dist, name=var_name)
        states[var_name] = state
        network.add_state(state)

    # Second pass: Set up relationships (edges) based on parent-child dependencies
    for var in yaml_data['variables']:
        var_name = var['name']
        parents = var.get('parents', None)

        if parents:
        # For each parent-child relationship
            for parent in parents:
                parent_name = parent['variable']
                parent_relationship_type = parent['relationship_type']
                parent_coefficients = parent['coefficients']
    
                if parent_relationship_type == 'linear':
                    # Linear relationship: Linear Gaussian Distribution
                    intercept = parent_coefficients.get('intercept', 0.0)
                    slope = parent_coefficients.get('slope', 0.0)
                    dist = LinearGaussianDistribution(
                        intercept=intercept,
                        slope=slope,
                        mean=0.0,  # These can be adjusted based on your specific model
                        std=1.0
                    )
                    state = State(dist, name=f"{parent_name}_to_{var_name}")
                    network.add_state(state)
                    network.add_edge(states[parent_name], state)
                    network.add_edge(state, states[var_name])
    
                elif parent_relationship_type == 'exponential':
                    # Exponential relationship: Use ExponentialDistribution
                    a = parent_coefficients.get('a', 1.0)
                    b = parent_coefficients.get('b', 0.02)
                    dist = ExponentialDistribution(a)  # Exponential distribution, parameterized by a
                    state = State(dist, name=f"{parent_name}_to_{var_name}")
                    network.add_state(state)
                    network.add_edge(states[parent_name], state)
                    network.add_edge(state, states[var_name])
    
                else:
                    raise ValueError(f"Unsupported relationship type: {parent_relationship_type}")

    # Bake the network to finalize its structure
    network.bake()

    return network

In [53]:
# Load your YAML file (Assume it's loaded as a dictionary)
yaml_file = 'config.yaml'
with open(yaml_file, 'r') as file:
    yaml_data = yaml.safe_load(file)

In [54]:
# Build the network from the YAML data
network = build_pomegranate_network(yaml_data)

NameError: name 'State' is not defined

In [22]:
from pomegranate.

In [48]:
x = Normal([yaml_data['variables'][0]['parameters']['mean']],
           [yaml_data['variables'][0]['parameters']['std_dev']],
           covariance_type='diag')

In [51]:
x.covs

Parameter containing:
tensor([3])

In [55]:
from pomegranate import *

In [61]:
net = bayesian_network.BayesianNetwork()
net.add_distribution(x)

ValueError: Must be Categorical or ConditionalCategorical