In [32]:
import pandas as pd
import numpy as np
from pgmpy.estimators import K2Score
from pgmpy.models import BayesianModel
from pgmpy.estimators import HillClimbSearch, BayesianEstimator
import random

def load_data(DATA_CSV):
    """Read a CSV file and summarize its columns.

    Prints the loaded DataFrame and its column index, then returns the
    pieces the rest of the script works with.

    Args:
        DATA_CSV: Path (or any source accepted by ``pd.read_csv``) of the
            CSV file to load.

    Returns:
        A 4-tuple ``(D, V, N, V_CARD)``: the DataFrame, its column index,
        the number of rows, and a dict mapping each column name to the
        number of distinct values found in that column.
    """
    frame = pd.read_csv(DATA_CSV)
    columns = frame.columns
    n_rows = frame.shape[0]

    # Cardinality of each variable = count of distinct values in its column.
    cardinality = {}
    for name in columns:
        cardinality[name] = len(frame[name].unique())

    print(f'ARQUIVO: {frame}')
    print(f'VARIÁVEIS: {columns}')
    return frame, columns, n_rows, cardinality

def calcular_k2(D):
    """Return a ``K2Score`` scorer constructed from the dataset ``D``."""
    return K2Score(D)

def estimar_modelo(D, scoring_method):
    """Learn a network structure from ``D`` by hill climbing and print its K2 score.

    Args:
        D: DataFrame of samples, one column per variable.
        scoring_method: Scoring method forwarded to
            ``HillClimbSearch.estimate`` (for example ``'k2score'``).

    Returns:
        The structure returned by ``HillClimbSearch.estimate``.
    """
    # Derive the variable count from the argument itself rather than from a
    # module-level global, so the function works on any DataFrame passed in.
    n_vars = len(D.columns)

    # Number of unordered variable pairs, used as the iteration budget.
    # n * (n - 1) is always even, so integer division is exact and keeps
    # max_iter an int instead of a float.
    max_possible_edges = n_vars * (n_vars - 1) // 2
    max_iter = min(max_possible_edges, 1000)

    searcher = HillClimbSearch(D)
    best_model = searcher.estimate(
        scoring_method=scoring_method,
        tabu_length=50,
        max_indegree=4,
        max_iter=max_iter,
    )

    # The reported value is always the K2 score (matching the printed label),
    # independent of which scoring method guided the search.
    k2_score = calcular_k2(D).score(best_model)
    print(f'Valor do score K2: {k2_score}')
    return best_model

def tabular_cpd(best_model, D):
    """Estimate one CPD per node of the learned structure.

    Builds a ``BayesianModel`` from ``best_model`` and uses a
    ``BayesianEstimator`` over ``D`` to estimate a CPD for every node.
    The CPDs are returned alongside the network; this function does not
    attach them to it.

    Args:
        best_model: Learned structure passed to ``BayesianModel``.
        D: DataFrame of samples used for parameter estimation.

    Returns:
        A tuple ``(cpds, bayesian_network)``: the list of estimated CPDs in
        node iteration order, and the constructed network.
    """
    network = BayesianModel(best_model)
    cpd_estimator = BayesianEstimator(network, D)
    cpd_list = [cpd_estimator.estimate_cpd(var) for var in network.nodes()]
    return cpd_list, network

def generate_shuffled_csv(DATA_CSV, output_csv='data_shuffled.csv', random_state=None):
    """Write a copy of a CSV file with its columns in random order.

    Only the column order is permuted; rows and values are left unchanged.

    Args:
        DATA_CSV: Path of the CSV file to read.
        output_csv: Path of the CSV file to write. Defaults to
            ``'data_shuffled.csv'`` in the current working directory.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            column permutation can be reproduced. ``None`` gives a fresh
            random order on every call.
    """
    D = pd.read_csv(DATA_CSV)
    # frac=1 with axis=1 samples every column exactly once, i.e. a permutation.
    D_shuffled_columns = D.sample(frac=1, axis=1, random_state=random_state)
    D_shuffled_columns.to_csv(output_csv, index=False)

# Path of the original CSV file
DATA_CSV = 'alarm.csv'

# Write a new CSV file ('data_shuffled.csv') whose columns are in random order
generate_shuffled_csv(DATA_CSV)

# Load the shuffled copy written by the previous step
D, V, N, V_CARD = load_data('data_shuffled.csv')

# Estimate the model structure using the K2 score
best_model = estimar_modelo(D, 'k2score')
print(f'Ordem das variáveis: {V}')
print(f'Melhor modelo: {best_model}')

# Show the structure (edges) of the learned model
structure = (best_model.edges)
print(f'Estrutura da rede: {structure}')

# Estimate the CPDs and build the Bayesian network from the learned structure
cpds, bayesian_network = tabular_cpd(best_model, D)
print(f'CPDs: {cpds}')
print(f'Bayesian Network: {bayesian_network}')


ARQUIVO:          PAP  HISTORY VENTMACH  INSUFFANESTH MINVOL  ERRLOWOUTPUT ARTCO2  \
0     NORMAL    False   NORMAL         False   ZERO         False   HIGH   
1     NORMAL    False   NORMAL         False   ZERO         False   HIGH   
2     NORMAL    False     ZERO         False   ZERO         False   HIGH   
3     NORMAL    False   NORMAL          True   ZERO         False   HIGH   
4     NORMAL    False   NORMAL         False   ZERO         False   HIGH   
...      ...      ...      ...           ...    ...           ...    ...   
9995    HIGH    False     ZERO         False   HIGH         False    LOW   
9996  NORMAL    False     HIGH         False   HIGH         False   HIGH   
9997  NORMAL    False   NORMAL          True   ZERO         False   HIGH   
9998  NORMAL    False   NORMAL          True   ZERO         False   HIGH   
9999  NORMAL    False   NORMAL         False   ZERO          True   HIGH   

       HRSAT  EXPCO2 VENTTUBE  ...   SHUNT    CO    FIO2  SAO2 LVEDVOLUME  \
0

 10%|▉         | 65/666 [00:18<02:46,  3.61it/s]


Valor do score K2: -106412.82447895246
Ordem das variáveis: Index(['PAP', 'HISTORY', 'VENTMACH', 'INSUFFANESTH', 'MINVOL', 'ERRLOWOUTPUT',
       'ARTCO2', 'HRSAT', 'EXPCO2', 'VENTTUBE', 'TPR', 'HREKG', 'ERRCAUTER',
       'PRESS', 'HRBP', 'PVSAT', 'VENTALV', 'HR', 'ANAPHYLAXIS', 'INTUBATION',
       'DISCONNECT', 'CVP', 'VENTLUNG', 'CATECHOL', 'PULMEMBOLUS', 'LVFAILURE',
       'MINVOLSET', 'SHUNT', 'CO', 'FIO2', 'SAO2', 'LVEDVOLUME', 'BP',
       'KINKEDTUBE', 'HYPOVOLEMIA', 'PCWP', 'STROKEVOLUME'],
      dtype='object')
Melhor modelo: DAG with 37 nodes and 61 edges
Estrutura da rede: [('VENTMACH', 'MINVOLSET'), ('MINVOL', 'VENTALV'), ('MINVOL', 'VENTLUNG'), ('MINVOL', 'INTUBATION'), ('ERRLOWOUTPUT', 'HRBP'), ('ARTCO2', 'EXPCO2'), ('ARTCO2', 'HR'), ('ARTCO2', 'CATECHOL'), ('ARTCO2', 'TPR'), ('VENTTUBE', 'VENTMACH'), ('VENTTUBE', 'MINVOL'), ('VENTTUBE', 'PRESS'), ('VENTTUBE', 'VENTALV'), ('VENTTUBE', 'VENTLUNG'), ('TPR', 'ANAPHYLAXIS'), ('TPR', 'BP'), ('ERRCAUTER', 'HRSAT'), ('ERRCAUT

-106412.82447895245
-106412.82447895246
-106412.82447895246
