Skip to content

Commit

Permalink
Improved NetPramsLner by adding exception mechanism to deal with trai…
Browse files Browse the repository at this point in the history
…ning data that yields un-normalizable potentials

Signed-off-by: rrtucci <tucci@ar-tiste.com>
  • Loading branch information
rrtucci committed Sep 7, 2016
1 parent 0bbadb1 commit 9891186
Show file tree
Hide file tree
Showing 13 changed files with 8,004 additions and 7,824 deletions.
41 changes: 40 additions & 1 deletion MyExceptions.py
Expand Up @@ -5,7 +5,7 @@
class BadGraphStructure(Exception):
"""
An exception class raised when a graph's structure is detected to be
illegal. Thrown when an alleged DAG is detected to contain cycles.
illegal; for instance, when an alleged DAG is detected to contain cycles.
Attributes
----------
Expand Down Expand Up @@ -37,5 +37,44 @@ def __repr__(self):
"""
return self.txt


class UnNormalizablePot(Exception):
"""
An exception class raised when an attempt to normalize a DiscreteCondPot
fails because it leads to division by zero.
Attributes
----------
pa_indices : tuple[int]
"""

def __init__(self, pa_indices):
"""
Constructor
Parameters
----------
pa_indices : tuple[int]
the indices of the parent state pa(C)=y such that
pot(C=x|pa(C)=y) = 0 for all x.
Returns
-------
"""
self.pa_indices = pa_indices

def __repr__(self):
"""
Returns
-------
tuple[int]
"""
return self.pa_indices

if __name__ == "__main__":
print(5)
3 changes: 3 additions & 0 deletions learning/AracneLner.py
Expand Up @@ -43,6 +43,8 @@ class AracneLner(ChowLiuTreeLner):
states_df : pandas.DataFrame
a Pandas DataFrame with training data. column = node and row =
sample. Each row/sample gives the state of the col/node.
ord_nodes : list[DirectedNode]
a list of DirectedNode's named and in the same order as the column
labels of self.states_df.
Expand All @@ -56,6 +58,7 @@ def __init__(self, states_df, vtx_to_states=None):
Parameters
----------
states_df : pandas.DataFrame
vtx_to_states : dict[str, list[str]]
A dictionary mapping each node name to a list of its state names.
This information will be stored in self.bnet. If
Expand Down
3 changes: 2 additions & 1 deletion learning/ChowLiuTreeLner.py
Expand Up @@ -61,7 +61,8 @@ def __init__(self, states_df, vtx_to_states=None):
None
"""
NetStrucLner.__init__(self, False, states_df, vtx_to_states)
NetStrucLner.__init__(self, False,
states_df, vtx_to_states)
self.learn_net_struc()

def learn_net_struc(self):
Expand Down
13 changes: 7 additions & 6 deletions learning/HillClimbingLner.py
Expand Up @@ -32,7 +32,7 @@ class HillClimbingLner(NetStrucLner):
a BayesNet in which we store what is learned
states_df : pandas.DataFrame
a Pandas DataFrame with training data. column = node and row =
sample. Each row/sample gives the state of the col/node.
sample. Each row/sample gives the state of the col/node
ord_nodes : list[DirectedNode]
a list of DirectedNode's named and in the same order as the column
labels of self.states_df.
Expand Down Expand Up @@ -91,11 +91,12 @@ def __init__(self, states_df, score_type, max_num_mtries,

# get vtx_to_states info from self.bnet
vtx_to_states1 = {nd.name: nd.state_names for nd in self.bnet.nodes}
self.scorer = NetStrucScorer(self.states_df,
self.vtx_to_parents,
vtx_to_states1,
score_type,
ess)
self.scorer = NetStrucScorer(
self.states_df,
self.vtx_to_parents,
vtx_to_states1,
score_type,
ess)

self.nx_graph = nx.DiGraph()

Expand Down

0 comments on commit 9891186

Please sign in to comment.