## IDM Learning of a Credal Network

In [61]:
import random
import numpy as np
import math
from statsmodels.distributions.empirical_distribution import ECDF
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from scipy.stats import norm

import pyagrum as gum
import pyagrum.lib.notebook as gnb

In [62]:
# Fix the seeds of both Python's `random` module and aGrUM's generator so that
# sampling and Monte Carlo inference are reproducible across runs.
SEED = 42
random.seed(SEED)
gum.initRandom(seed=SEED)

In [63]:
# Two-node network A -> B (A binary, B ternary) with randomly generated CPTs.
bn = gum.fastBN("A[2]->B[3]")

# Draw a synthetic population of `gpop_ss` rows from the network and load it as a DataFrame.
gpop_ss = 1_000
g = gum.BNDatabaseGenerator(bn)
g.drawSamples(gpop_ss)
# NOTE(review): no variable of this network is discretized, so this label mode
# appears to have no effect here — confirm before relying on it.
g.setDiscretizedLabelModeRandom()
gpop = g.to_pandas()

# One row per drawn sample.
assert len(gpop) == gpop_ss

In [64]:
# Overwrite every CPT with raw event counts taken from the sampled population `gpop`:
# CredalNet.idmLearning (used below) learns from a BN whose CPTs store counts.
for node in bn.names():
    cpt = bn.cpt(node)
    # Variables of this CPT (the node and its parents), in the CPT's own order.
    # Cells are addressed by variable name below, so the result does not depend on
    # that order, which need not match the iteration order of bn.parents(node)
    # (a set) once a node has more than one parent.
    family = list(cpt.names)
    missing = [name for name in family if name not in gpop.columns]
    if missing:
        # Previously a missing column was silently skipped, leaving an all-zero CPT.
        raise KeyError(f"gpop lacks column(s) {missing} needed for the CPT of {node!r}")

    cpt.fillWith(0.0)  # configurations never observed keep a count of 0
    for labels, count in gpop.groupby(family).size().items():
        if not isinstance(labels, tuple):  # grouping on a single column yields scalar keys
            labels = (labels,)
        # Data labels are strings ("0", "1", ...); their integer value is used as the
        # index into each variable's domain, which holds when labels are 0..n-1 as
        # they are for this network. TODO revisit for variables with other labels.
        cpt[{name: int(label) for name, label in zip(family, labels)}] = float(count)

In [65]:
# Show the Bayesian network (not yet the credal network, which is built later)
# next to its CPTs, which at this point hold raw event counts, not probabilities.
gnb.flow.row(
    bn,
    bn.cpt("A"),
    bn.cpt("B"),
    captions=["BN", "CPT (A)", "CPT (B | A)"],
)

A,A
0,1
98.0,902.0

Unnamed: 0_level_0,B,B,B
A,0,1,2
0,37.0,17.0,44.0
1,488.0,349.0,65.0


In [66]:
# Sanity check: the counts stored in the CPTs must equal direct counts on the data.
a0_rows = gpop[gpop["A"] == "0"]
print(f"Counts of A=0: {len(a0_rows)}")
assert len(a0_rows) == bn.cpt("A")[0]

a0_b2_rows = a0_rows[a0_rows["B"] == "2"]
print(f"Counts of B=2 | A=0: {len(a0_b2_rows)}")
assert len(a0_b2_rows) == bn.cpt("B")[0, 2]

Counts of A=0: 98
Counts of B=2 | A=0: 44


In [67]:
# Build a credal network from the count-holding BN, then learn its imprecise
# parameters from those counts with the Imprecise Dirichlet Model (IDM).
IDM_S = 2  # IDM hyperparameter `s`; pyAgrum requires an integer here
cn = gum.CredalNet(bn)
cn.idmLearning(s=IDM_S)

In [68]:
# Run Monte Carlo sampling inference on the credal network and display the
# network beside the per-variable inference results.
ie_mc = gum.CNMonteCarloSampling(cn)
cn_inference_view = gnb.getInference(cn, engine=ie_mc)
gnb.sideBySide(cn, cn_inference_view)

0,1
G A A B B A->B,"structs Inference in 12.55ms A  2025-04-15T11:40:08.223563  image/svg+xml  Matplotlib v3.10.1, https://matplotlib.org/  B  2025-04-15T11:40:08.259435  image/svg+xml  Matplotlib v3.10.1, https://matplotlib.org/  A->B"
