In [1]:
# Notebook setup: imports plus a few process-wide configuration statements.
import itertools
import logging
# Only let pgmpy log records at WARNING level or above through.
logging.getLogger('pgmpy').setLevel(logging.WARNING)
import warnings
# Suppress RuntimeWarnings issued from modules matching "networkx.utils.backends".
warnings.filterwarnings("ignore", category=RuntimeWarning, module="networkx.utils.backends")
# NOTE(review): duplicate of the `import logging` above; harmless but redundant.
import logging

import pysnooper
import pandas as pd
import pyagrum as gum
import concurrent.futures
import os
import random
# NOTE(review): wildcard imports hide where names come from; Decom_Tree
# presumably provides Graph_Decom, which a later cell uses — confirm.
from time import *
from Decom_Tree import *
from functools import reduce
import numpy as np
import networkx as nx
from copy import deepcopy
from pgmpy.utils import get_example_model
from pgmpy.sampling import BayesianModelSampling
from pgmpy.factors.discrete import DiscreteFactor, TabularCPD
from pgmpy.factors import factor_product,factor_divide
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination, BeliefPropagation
from pgmpy.models import DiscreteBayesianNetwork,JunctionTree
# NOTE(review): this runs after pandas and numpy have already been imported
# above; if numexpr reads the variable when it is first imported, the setting
# may come too late to take effect — confirm.
os.environ["NUMEXPR_MAX_THREADS"] = "32"


In [2]:
def generate_state_names(g):
    """Assign the two state labels ``'A'`` and ``'B'`` to every node of ``g``.

    Args:
        g: Object exposing an iterable ``nodes`` attribute.

    Returns:
        dict: Maps each node, in iteration order, to its own ``['A', 'B']`` list.
    """
    labels_by_node = {}
    for vertex in g.nodes:
        # A fresh list per node, so mutating one entry never affects another.
        labels_by_node[vertex] = ['A', 'B']
    return labels_by_node

def get_random_cpds_with_labels(model, state_names, inplace=False, seed=None):
    """Draw one random two-state CPD per node of ``model``.

    Every CPD is requested from ``TabularCPD.get_random`` with the node's
    predecessors in ``model`` as evidence and a cardinality of 2 for every
    node of the model.

    Args:
        model: Network exposing ``nodes()``, ``predecessors(node)`` and
            ``add_cpds(*cpds)``.
        state_names: Passed through unchanged to ``TabularCPD.get_random``.
        inplace: If true, the CPDs are attached to ``model`` and nothing is
            returned; otherwise the list of CPDs is returned.
        seed: Forwarded as-is to every ``TabularCPD.get_random`` call, i.e.
            all nodes receive the same seed value.

    Returns:
        list | None: The generated CPDs in node order, or ``None`` when
        ``inplace`` is true.
    """
    def _draw(target):
        # Build the random CPD of a single node given its parents.
        return TabularCPD.get_random(
            variable=target,
            evidence=list(model.predecessors(target)),
            cardinality=dict.fromkeys(model.nodes(), 2),
            state_names=state_names,
            seed=seed,
        )

    random_cpds = [_draw(target) for target in model.nodes()]
    if not inplace:
        return random_cpds
    model.add_cpds(*random_cpds)

In [15]:
# Experiment: for each benchmark network structure, repeatedly
#   1. attach random two-state CPDs to a copy of the structure,
#   2. forward-sample a data set of the requested size,
#   3. re-estimate the CPDs atom by atom using the graph decomposition,
#   4. compare the generating and the learned network with pyAgrum's
#      GibbsBNdistance and record the accepted distances in `result`.
# NOTE(review): no random seed is passed to the CPD generator or the sampler,
# so every execution of this cell yields different numbers.

# The experiment grid is defined once so the result columns and the loops
# below cannot drift apart.
DATASETS = ["asia", "child", "alarm", "insurance", "survey", "pathfinder",
            "hailfinder", "mildew", "hepar2", "win95pts"]
SAMPLE_SIZES = [500, 1000, 2500, 5000, 7500, 10000]
METRICS = ["klPQ", "hellinger", "bhattacharya", "jensen-shannon"]
RUNS_PER_SETTING = 100          # accepted runs recorded per (dataset, sample size)
MAX_CONSECUTIVE_FAILURES = 200  # stop retrying a setting that keeps failing

run_log = logging.getLogger(__name__)

columns = pd.MultiIndex.from_product([DATASETS, SAMPLE_SIZES, METRICS])

result = pd.DataFrame(columns=columns)

for file in DATASETS:
    print(file)
    save_file = f'{file}.bif'
    learn_file = f'{file}1.bif'
    model = get_example_model(file)

    # Only the structure of the example model is reused; the CPDs are redrawn
    # randomly for every run.
    G = nx.DiGraph()
    G.add_nodes_from(model.nodes)
    G.add_edges_from(model.edges)
    state_names = generate_state_names(G)
    decom = Graph_Decom(G)
    atoms = decom.Decom()

    for sample_size in SAMPLE_SIZES:
        print(sample_size)
        succ_num = RUNS_PER_SETTING
        row = 0
        consecutive_failures = 0
        while succ_num:
            # Without this guard a systematic failure (e.g. a file that can
            # never be loaded) would make the loop spin forever.
            if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                run_log.error(
                    "Giving up on %s (sample size %d) after %d consecutive "
                    "failed runs; %d of %d runs recorded.",
                    file, sample_size, consecutive_failures, row, RUNS_PER_SETTING,
                )
                break

            # Ground-truth network with random CPDs.
            bn = DiscreteBayesianNetwork()
            bn.add_nodes_from(list(G.nodes))
            bn.add_edges_from(list(G.edges))

            get_random_cpds_with_labels(bn, state_names, inplace=True)

            df = BayesianModelSampling(bn).forward_sample(size=sample_size, show_progress=False)

            # Network with the same structure whose CPDs are learned from df.
            learn_bn = DiscreteBayesianNetwork()
            learn_bn.add_nodes_from(list(G.nodes))
            learn_bn.add_edges_from(list(G.edges))

            learn_bn.cpds = []
            learned_nodes = set()
            for atom in atoms:
                sub_model = DiscreteBayesianNetwork(list(G.subgraph(list(atom)).edges))
                sub_model.add_nodes_from(atom)
                sub_model.cpds = []
                estimator = MaximumLikelihoodEstimator(sub_model, df[list(atom)])
                for node in atom:
                    if node in learned_nodes:
                        continue
                    # Estimate a node's CPD only from an atom in which it has
                    # as many parents as in the full network.
                    if len(model.get_parents(node)) == len(sub_model.get_parents(node)):
                        learned_nodes.add(node)
                        learn_bn.add_cpds(estimator.estimate_cpd(node))

            # Both networks are exchanged with pyAgrum through BIF files.
            bn.save(save_file, filetype='bif')
            learn_bn.save(learn_file, filetype='bif')
            try:
                bn_g = gum.loadBN(save_file)
                bn_g1 = gum.loadBN(learn_file)

                g1 = gum.GibbsBNdistance(bn_g, bn_g1)
                value = g1.compute()

                # Only runs with a strictly positive Bhattacharya distance are
                # recorded; all others are retried with a fresh network.
                if value["bhattacharya"] > 0:
                    for metric_name in METRICS:
                        result.loc[row + 1, (file, sample_size, metric_name)] = round(value[metric_name], 4)
                    succ_num -= 1
                    row += 1
                    consecutive_failures = 0
                else:
                    consecutive_failures += 1

            except Exception as e:
                # A failed run is still retried, but it is reported instead of
                # being swallowed silently.
                consecutive_failures += 1
                run_log.warning(
                    "Run failed for %s (sample size %d), retrying: %r",
                    file, sample_size, e,
                )

result.to_csv("discrete.csv")
print(result.head())