In [1]:
from pgmpy.readwrite import BIFReader
from pathlib import Path
from src.utils import adj_df_from_BIF, get_terminal_connection_nodes, encode_data, perturb_adj_matrix
from src.constants import alarm_target
from src.data import AlarmDataset
from src.models.BNNet import BNNet

import pandas as pd
from scipy.stats import bernoulli

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import yaml

from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse, to_torch_coo_tensor

In [2]:
fpath_bif = Path("/home/gaurang/bayesian_network/data/alarm/alarm.bif")
fpath_data = Path("/home/gaurang/bayesian_network/data/alarm/ALARM10k.csv")
fpath_config = Path("/home/gaurang/bayesian_network/code/src/config.yaml")

with open(fpath_config, 'r') as f:
    config = yaml.safe_load(f)
config

{'embedding_dim': 64,
 'gnn_hidden_dim': 64,
 'gnn_out_dim': 32,
 'fc1_out_dim': 16,
 'num_epoch': 100,
 'patience': 20,
 'min_delta': '1e-5'}

In [3]:
adj_df = adj_df_from_BIF(BIFReader(fpath_bif))
adj_df

Unnamed: 0,HISTORY,CVP,PCWP,HYPOVOLEMIA,LVEDVOLUME,LVFAILURE,STROKEVOLUME,ERRLOWOUTPUT,HRBP,HREKG,...,MINVOLSET,VENTMACH,VENTTUBE,VENTLUNG,VENTALV,ARTCO2,CATECHOL,HR,CO,BP
HISTORY,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CVP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
PCWP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
HYPOVOLEMIA,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
LVEDVOLUME,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
LVFAILURE,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
STROKEVOLUME,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
ERRLOWOUTPUT,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
HRBP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
HREKG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
new_adj_df = perturb_adj_matrix(adj_df, 0.)
new_adj_df

Unnamed: 0,HISTORY,CVP,PCWP,HYPOVOLEMIA,LVEDVOLUME,LVFAILURE,STROKEVOLUME,ERRLOWOUTPUT,HRBP,HREKG,...,MINVOLSET,VENTMACH,VENTTUBE,VENTLUNG,VENTALV,ARTCO2,CATECHOL,HR,CO,BP
HISTORY,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CVP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
PCWP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
HYPOVOLEMIA,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
LVEDVOLUME,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
LVFAILURE,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
STROKEVOLUME,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
ERRLOWOUTPUT,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
HRBP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
HREKG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [11]:
float(False)

0.0

In [3]:
df_data = pd.read_csv(fpath_data, dtype=str)
reader = BIFReader(fpath_bif)

In [4]:
data, oe = encode_data(df_data, reader)

In [5]:
dataset = AlarmDataset(data, reader)

In [6]:
dataset.terminal_node_ids

[35, 14]

In [7]:
edge_index, edge_weights = dense_to_sparse(torch.as_tensor(dataset.adj_mat))
edge_index.shape

torch.Size([2, 44])

In [11]:
model = BNNet(
        config=config,
        num_nodes= len(dataset.input_nodes),
        node_states=dataset.input_states,
        edge_index=edge_index,
        terminal_node_ids=dataset.terminal_node_ids,
        target_node_states=dataset.target_states
        )

In [12]:
model

BNNet(
  (gnn): GNN(
    (layer1): GCNConv(64, 64)
    (layer2): GCNConv(64, 32)
  )
  (MLP): Sequential(
    (0): Linear(in_features=64, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=16, out_features=3, bias=True)
  )
)

In [13]:
loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [14]:
it = iter(loader)
batch = next(it)

In [15]:
X, y = batch
X = X.long()
X.shape

torch.Size([4, 36])

In [17]:
model(X)

tensor([[ 0.0684, -0.0261, -0.1909],
        [ 0.4400, -0.3564, -0.3023],
        [ 0.3328, -0.1367, -0.1090],
        [ 0.2361, -0.1292, -0.0636]], grad_fn=<AddmmBackward0>)

In [11]:
num_embeddings_list = [len(state) for state in dataset.input_states]
node_embedding_layers = [
    nn.Embedding(num_emdeddings, 7)
    for num_emdeddings in num_embeddings_list
]

In [13]:
gnn_input = []

for i, node_embedding_layer in enumerate(node_embedding_layers):
    gnn_input.append(node_embedding_layer(X[:, i]))

len(gnn_input)


36

In [30]:
gnn_input = torch.stack(gnn_input, dim=1)
gnn_input.shape

torch.Size([4, 36, 7])

In [24]:
gnn = GCNConv(7, 10)

In [33]:
x = gnn(gnn_input, edge_index)
x.shape

torch.Size([4, 36, 10])

In [37]:
x = x.view(4, -1)
x.shape

torch.Size([4, 360])

In [48]:
pmf = bernoulli(0.0)
pmf.rvs(size=1)[0]

0

In [51]:
int(not 1)

0