In [1]:
import pandas as pd
import numpy as np

# Hela ML libraries 
from hela import hmm
import hela.generation.hmm as hmm_gen

# Viz libraries
import altair as alt
import hela.visualization.hmm as hmmplot 
import matplotlib.pyplot as plt
from hela.visualization.hmm import TU_COLORS
%matplotlib inline

# Utility Libraries
from datetime import datetime
from dask.distributed import Client
from scipy.special import logsumexp
from scipy import stats
import itertools
from IPython.display import Image

# PGMPy
from hela.hmm.graphical_models import DynamicBayesianNetwork as dbn
from hela.hmm.graphical_models.ContinuousFactor import ContinuousFactor
# from pgmpy.factors.discrete import TabularCPD
import networkx as nx


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Simulate a discrete HMM (3 hidden states, 2 categorical features, no Gaussian
# features) and build a trainable model from the generator's training spec.
SEED = 42
# NOTE(review): assumes hela's generative model draws from NumPy's global RNG;
# if it uses its own generator this seed will not make the run reproducible.
np.random.seed(SEED)

n = 500  # number of simulated time steps
generative_model = hmm_gen.DiscreteHMMGenerativeModel(
    n_hidden_states=3,
    n_gaussian_features=0,
    n_categorical_features=2,
    n_gmm_components=None,
)

hidden_states = generative_model.generate_hidden_state_sequence(n_observations=n)
hmm_data = generative_model.generate_observations(hidden_states)
hmm_training_spec = generative_model.generative_model_to_discrete_hmm_training_spec()
model_config = hmm.DiscreteHMMConfiguration.from_spec(hmm_training_spec)
model = model_config.to_model()

In [3]:
# Convert the HMM into hela's dynamic-Bayesian-network graph representation.
graph = dbn.hmm_model_to_graph(model)
# NOTE(review): presumably sets up the slice-0 (initial-state) factors — confirm.
graph.initialize_initial_state()

In [4]:
# Assign each distinct categorical observation vector an integer code, keyed by
# its string form (e.g. '[0, 1]'), in the row order of finite_values.
# NOTE(review): assumes finite_values is a pandas DataFrame (len == row count).
categorical_dict = {
    str(list(row)): code
    for code, row in enumerate(model.categorical_model.finite_values.values)
}
categorical_dict

{'[0, 0]': 0, '[0, 1]': 1, '[0, 2]': 2, '[1, 0]': 3, '[1, 1]': 4, '[1, 2]': 5}

In [5]:
# Collapse each multi-feature observation row to its single integer code,
# keeping the time index of hmm_data.
hmm_flattened_data = pd.Series(
    [categorical_dict[str(list(row))] for row in np.array(hmm_data)],
    index=hmm_data.index,
)

In [6]:
# One evidence variable per time step: ('cat_obs', t) for t = 0 .. n-1.
ev_keys = [('cat_obs', t) for t in range(n)]

In [7]:
# Pair each evidence variable with the observed code at the same time step.
ev_dict = {key: code for key, code in zip(ev_keys, hmm_flattened_data.values[:n])}

In [8]:
# Hidden-state variables to query.
# NOTE(review): starts at time slice 1, so ('hs', 0) is excluded — confirm intended.
variables = [('hs', t) for t in range(1, n)]

In [9]:
from collections import defaultdict
from itertools import tee, chain, combinations

from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.models import BayesianModel
from pgmpy.factors import factor_product
from pgmpy.inference import Inference, BeliefPropagation
from hela.hmm.graphical_models import structured_inference as dbn_inf




In [10]:
# Rebuild, step by step, the structures DBNInference uses for slice-wise
# inference so every intermediate can be inspected in the cells below.
# NOTE(review): the assignments below overwrite attributes on an already
# constructed DBNInference; presumably they mirror its constructor — confirm.
inference = dbn_inf.DBNInference(graph)
# Interface nodes of time slices 0 and 1 (slice 1's print as [('hs', 1)] below).
inference.interface_nodes_0 = graph.get_interface_nodes(time_slice=0)
inference.interface_nodes_1 = graph.get_interface_nodes(time_slice=1)
# Bayesian network for slice 0 alone, built from its intra-slice edges.
inference.start_bayesian_model = BayesianModel(graph.get_intra_edges(0))
# NOTE(review): get_factors(time_slice=...) appears to yield containers whose
# first element is the factor (hence cpd[0]) — confirm against the graph API.
flattened_factors_0 = [cpd[0] for cpd in graph.get_factors(time_slice=0)]
flattened_factors_1 = [cpd[0] for cpd in graph.get_factors(time_slice=1)]
inference.start_bayesian_model.add_cpds(*flattened_factors_0)
# Factors of the slice-1 interface nodes (presumably the ones whose parents
# live in slice 0).
cpd_inter = [graph.get_factors(node)[0] for node in graph.get_interface_nodes(1)]
# Presumably the same nodes as interface_nodes_0 (positional arg looks like time_slice).
inference.interface_nodes = graph.get_interface_nodes(0)
# "1.5-slice" network: inter-slice edges plus slice 1's intra-slice edges.
inference.one_and_half_model = BayesianModel(
    graph.get_inter_edges() + graph.get_intra_edges(1)
)

inference.one_and_half_model.add_cpds(
    *(flattened_factors_1 + cpd_inter)
)
# Undirected (Markov) versions of both networks.
start_markov_model = inference.start_bayesian_model.to_markov_model()
one_and_half_markov_model = inference.one_and_half_model.to_markov_model()
# Connect every pair of interface nodes within a slice so they form a clique.
# tee() duplicates the slice-0 pair iterator because both models consume it.
combinations_slice_0 = tee(combinations(inference.interface_nodes_0, 2), 2)
combinations_slice_1 = combinations(inference.interface_nodes_1, 2)
start_markov_model.add_edges_from(combinations_slice_0[0])
one_and_half_markov_model.add_edges_from(
    chain(combinations_slice_0[1], combinations_slice_1)
)

inference.one_and_half_junction_tree = one_and_half_markov_model.to_junction_tree()
inference.start_junction_tree = start_markov_model.to_junction_tree()

# Locate the cliques that contain the interface nodes: the slice-0 clique of
# the start tree, and the in/out cliques of the 1.5-slice tree.
inference.start_interface_clique = inference._get_clique(
    inference.start_junction_tree, inference.interface_nodes_0
)
inference.in_clique = inference._get_clique(
    inference.one_and_half_junction_tree, inference.interface_nodes_0
)
# NOTE(review): cell [17] shows this resolves to (('hs', 1), ('hs', 0)), a
# clique spanning both slices, not the slice-1-only clique.
inference.out_clique = inference._get_clique(
    inference.one_and_half_junction_tree, inference.interface_nodes_1
)



In [11]:
# Edges of the 1.5-slice Markov model after the interface-node edges were added.
one_and_half_markov_model.edges()

EdgeView([(('hs', 0), ('hs', 1)), (('hs', 1), ('cat_obs', 1))])

In [12]:
# Variables (nodes) of the 1.5-slice Markov model.
one_and_half_markov_model.nodes()

NodeView((('hs', 0), ('hs', 1), ('cat_obs', 1)))

In [13]:
# Junction-tree edges; each endpoint is a clique given as a tuple of variables.
inference.one_and_half_junction_tree.edges()

EdgeView([((('hs', 1), ('hs', 0)), (('hs', 1), ('cat_obs', 1)))])

In [14]:
# Raw values of the first clique potential in the 1.5-slice junction tree.
# NOTE(review): its scope is (('hs', 1), ('hs', 0)) per the factor listing, so
# this is presumably the hidden-state transition potential — confirm axis order.
inference.one_and_half_junction_tree.factors[0].values

array([[0.98442657, 0.01557343, 0.        ],
       [0.        , 0.98579456, 0.01420544],
       [0.01527483, 0.        , 0.98472517]])

In [15]:
# All clique potentials attached to the 1.5-slice junction tree.
inference.one_and_half_junction_tree.factors

[<DiscreteFactor representing phi(('hs', 1):3, ('hs', 0):3) at 0x7febb34856a0>,
 <DiscreteFactor representing phi(('hs', 1):3, ('cat_obs', 1):6) at 0x7febb3485550>]

In [16]:
# Transposed values of the last factor registered on the graph.
# NOTE(review): the displayed array is 3 x 6 with rows summing to ~1, presumably
# P(cat_obs | hs) with one row per hidden state — confirm.
graph.factors[-1].values.T

array([[5.92394114e-02, 7.67759015e-01, 4.86698534e-02, 3.25128510e-02,
        3.49195038e-02, 5.68993659e-02],
       [9.49789641e-01, 2.91747775e-04, 1.20145520e-02, 1.12286595e-02,
        1.25541161e-02, 1.41212836e-02],
       [2.57349237e-02, 4.35271010e-02, 6.59570844e-03, 3.56859266e-02,
        7.99425980e-03, 8.80462080e-01]])

EdgeView([((('hs', 1), ('cat_obs', 1)), (('hs', 1), ('hs', 0)))])

In [63]:
junction_tree = inference.one_and_half_junction_tree
nodes = inference.interface_nodes_1
print(nodes)
# Candidate out-clique: it must contain every slice-1 interface node AND have
# all of its variables in one time slice. Variables are (name, time_slice)
# tuples, so var[1] is the slice index. Checking the whole clique, not just its
# first two variables, is correct for cliques of any size (including size 1).
single_slice_cliques = [
    clique
    for clique in junction_tree.nodes()
    if set(nodes).issubset(clique) and len({var[1] for var in clique}) == 1
]
assert single_slice_cliques, f"No single-time-slice clique contains {nodes}"
single_slice_cliques[0]

[('hs', 1)]


(('hs', 1), ('cat_obs', 1))

In [62]:
# Self-contained: define the inputs here rather than relying on another cell.
junction_tree = inference.one_and_half_junction_tree
nodes = inference.interface_nodes_1
# Step 1: every clique containing all of `nodes`.
cliques = [clique for clique in junction_tree.nodes() if set(nodes).issubset(clique)]
# Step 2: keep cliques whose variables all share one time slice (var[1] of each
# (name, time_slice) variable). This works for cliques of any size.
single_slice_cliques = [clique for clique in cliques if len({var[1] for var in clique}) == 1]
assert single_slice_cliques, f"No single-time-slice clique contains {nodes}"
print(single_slice_cliques[0])

(('hs', 1), ('cat_obs', 1))


In [29]:
# Every clique of the 1.5-slice junction tree that holds all slice-0 interface nodes.
junction_tree = inference.one_and_half_junction_tree
nodes = inference.interface_nodes_0
[clique for clique in junction_tree.nodes() if all(node in clique for node in nodes)]

[(('hs', 1), ('hs', 0))]

In [17]:
# The clique _get_clique picked for the slice-1 interface nodes.
# NOTE(review): it spans both slices here; compare with the single-slice search above.
inference.out_clique

(('hs', 1), ('hs', 0))

In [18]:
one_and_half_markov_model.check_model()

# Make the 1.5-slice Markov graph chordal, then collect its maximal cliques
# as tuples.
triangulated_graph = one_and_half_markov_model.triangulate()
cliques = [tuple(clique) for clique in nx.find_cliques(triangulated_graph)]
cliques

[(('hs', 1), ('hs', 0)), (('hs', 1), ('cat_obs', 1))]

In [19]:
# Number of maximal cliques found in the triangulated graph.
len(cliques)

2

In [20]:
# Maximal cliques exactly as nx.find_cliques yields them (lists, not tuples).
list(nx.find_cliques(triangulated_graph))

[[('hs', 1), ('hs', 0)], [('hs', 1), ('cat_obs', 1)]]

In [24]:
# Every unordered pair of maximal cliques, printed one pair per line.
edges = list(itertools.combinations(cliques, 2))
for clique_pair in edges:
    print(clique_pair)

((('hs', 1), ('hs', 0)), (('hs', 1), ('cat_obs', 1)))


In [21]:
# Find maximal cliques in the chordal graph
# cliques = list(map(tuple, nx.find_cliques(triangulated_graph)))

# # If there is only 1 clique, then the junction tree formed is just a
# # clique tree with that single clique as the node
# if len(cliques) == 1:
#     clique_trees = JunctionTree()
#     clique_trees.add_node(cliques[0])

# # Else if the number of cliques is more than 1 then create a complete
# # graph with all the cliques as nodes and weight of the edges being
# # the length of sepset between two cliques
# elif len(cliques) >= 2:
#     complete_graph = UndirectedGraph()
#     edges = list(itertools.combinations(cliques, 2))
#     weights = list(map(lambda x: len(set(x[0]).intersection(set(x[1]))), edges))
#     for edge, weight in zip(edges, weights):
#         complete_graph.add_edge(*edge, weight=-weight)

#     # Create clique trees by minimum (or maximum) spanning tree method
#     clique_trees = JunctionTree(
#         nx.minimum_spanning_tree(complete_graph).edges()
#     )

# # Check whether the factors are defined for all the random variables or not
# all_vars = itertools.chain(*[factor.scope() for factor in self.factors])
# if set(all_vars) != set(self.nodes()):
#     ValueError("DiscreteFactor for all the random variables not specified")

# # Dictionary stating whether the factor is used to create clique
# # potential or not
# # If false, then it is not used to create any clique potential
# is_used = {factor: False for factor in self.factors}

# for node in clique_trees.nodes():
#     clique_factors = []
#     for factor in self.factors:
#         # If the factor is not used in creating any clique potential as
#         # well as has any variable of the given clique in its scope,
#         # then use it in creating clique potential
#         if not is_used[factor] and set(factor.scope()).issubset(node):
#             clique_factors.append(factor)
#             is_used[factor] = True

#     # To compute clique potential, initially set it as unity factor
#     var_card = [self.get_cardinality()[x] for x in node]
#     clique_potential = DiscreteFactor(
#         node, var_card, np.ones(np.product(var_card))
#     )
#     # multiply it with the factors associated with the variables present
#     # in the clique (or node)
#     # Checking if there's clique_factors, to handle the case when clique_factors
#     # is empty, otherwise factor_product with throw an error [ref #889]
#     if clique_factors:
#         clique_potential *= factor_product(*clique_factors)
#     clique_trees.add_factors(clique_potential)

# if not all(is_used.values()):
#     raise ValueError(
#         "All the factors were not used to create Junction Tree."
#         "Extra factors are defined."
#     )

# return clique_trees