In [87]:
import re
from collections import defaultdict
from scipy.stats import norm
import itertools
import numpy as np
import io
from contextlib import redirect_stdout
from more_itertools import random_product

import pyagrum as gum
import pyagrum.lib.notebook as gnb

In [88]:
# Create CN
bn=gum.fastBN("A[2]->B[2]<-C[3]")
bn_min=gum.BayesNet(bn)
bn_max=gum.BayesNet(bn)
for n in bn.nodes():
  x=0.4*min(bn.cpt(n).min(),1-bn.cpt(n).max())
  bn_min.cpt(n).translate(-x)
  bn_max.cpt(n).translate(x)

cn=gum.CredalNet(bn_min,bn_max)
cn.intervalToCredal()

In [89]:
#TODO: improve code
def parse_credal_net(cn_str: str):
    """
    Parsa l'output testuale di una credal net da PyAgrum.
    Restituisce un dizionario del tipo:
        {
            'A': {'<>': [ [p1], [p2], ... ]},
            'B': {'<A:0|C:0>': [...], ...},
            ...
        }
    """
    credal_dict = defaultdict(lambda: defaultdict(list))
    current_var = None

    lines = cn_str.strip().split('\n')

    for line in lines:
        line = line.strip()

        # Identificazione della variabile
        var_match = re.match(r'^([A-Za-z0-9_]+):Range\(\[.*\]\)', line)
        if var_match:
            current_var = var_match.group(1)
            continue

        if current_var is None or not line:
            continue

        # Identificazione di una CPT con intestazione <condizioni>
        cpt_match = re.match(r'^<([^>]*)>\s*:\s*(.*)', line)
        if cpt_match:
            condition = f"<{cpt_match.group(1).strip()}>"
            raw_cpt = cpt_match.group(2)

            # Estrarre tutte le liste interne: [[x,x,x], [x,x,x], ...]
            vectors = re.findall(r'\[\s*([^\[\]]+?)\s*\]', raw_cpt)
            for vec in vectors:
                prob_list = [float(x.strip()) for x in vec.split(',')]
                credal_dict[current_var][condition].append(prob_list)

    return credal_dict


In [90]:
def get_simplex(cn, n: int = None) -> list:

    # Store the CN in form of string
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        print(cn)

    cn_text = buffer.getvalue()

    # Parse CN
    parsed = parse_credal_net(cn_text)

    # Get baseline DAG and init simplex
    dag = gum.BayesNet(cn.current_bn())
    bns = []

    # Compute slots and store variables indexes
    slots = []
    for var in parsed:
        for cond, vectors in parsed[var].items():
            slots.append((var, cond, vectors))
    var_idx = {var:[idx for idx, elem in enumerate(slots) if elem[0] == var] for var in bn.names()}

    # If 'n' is provided...
    if bool(n):
        # Get 'n' random combinations of CPTs
        combinations = [random_product(*[vecs for _, _, vecs in slots]) for _ in range(n)]
        n_combs = len(combinations)
        assert(n_combs == n); print(f"Generating {n_combs} random BNs...")
    else:
        # Get all combinations of CPTs
        combinations = list(itertools.product(*[vecs for _, _, vecs in slots]))
        n_combs = len(combinations); print(f"Generating {n_combs} BNs...")

    # For each combination...
    for combo in combinations:

        # Init BN and ...
        bn_tmp = gum.BayesNet(dag)

        # Fill its CPTs
        for var in bn.names():
            array = np.array([(combo[idx]) for idx in var_idx.get(var)]).flatten()
            bn_tmp.cpt(var).fillWith(array)

        bns.append(bn_tmp)
    
    # Check
    assert(n_combs == len(bns))

    return bns


In [91]:
# TODO: Check code
def are_all_bn_different(bn_list):

    def serialize_bn(bn):
        cpt_data = []
        for var in bn.names():
            cpt = bn.cpt(var)
            flat = [f"{v:.8f}" for v in cpt.toarray().flatten()]
            cpt_data.append(f"{var}:" + ",".join(flat))
        return "|".join(cpt_data)

    signatures = set()
    for bn in bn_list:
        sig = serialize_bn(bn)
        signatures.add(sig)

    if len(signatures) == len(bn_list):
        print(f"✅ All {len(signatures)} BNs are different.")
        return
    else:
        unique_sigs = set()
        unique_bns = []

        for bn in bn_list:
            sig = serialize_bn(bn)
            if sig not in unique_sigs:
                unique_sigs.add(sig)
                unique_bns.append(bn)
        print(f"⚠️ Obtained {len(unique_bns)}/{len(bn_list)} different BNs.")
        return

In [92]:
# Get full simplex
bns = get_simplex(cn)

# Ensure they are all different
are_all_bn_different(bns)

Generating 768 BNs...
✅ All 768 BNs are different.


In [93]:
# Get random subset of simplex
bns = get_simplex(cn, 50)

# Ensure they are all different
are_all_bn_different(bns)

Generating 50 random BNs...
⚠️ Obtained 47/50 different BNs.
