Skip to content

Commit

Permalink
Merge pull request #36 from jussiviinikka/feature/param-refactoring-a…
Browse files Browse the repository at this point in the history
…nd-adjustment

Feature/param refactoring and adjustment
  • Loading branch information
jussiviinikka committed Sep 21, 2023
2 parents 835f555 + 4d7c818 commit c9c8fda
Show file tree
Hide file tree
Showing 23 changed files with 1,586 additions and 1,202 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ dc29f6da14c3a11e9f523001ef7c79be8832ad9c

# PR 11: Use isort (and a bit of black)
8fe07a70bd5dd104bee44405cf9a35ec388420a5

# Use tabs for c++ indentation consistently
c46d5467b3109c93a103795080a20416f7fdb0a9
19 changes: 13 additions & 6 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@ version = attr: sumu.__version__
max-line-length = 79
target-version = ['py37']
ignore=
E731, # do not assign a lambda expression, use a def
E203, # space before : (needed for how black formats slicing)
W605, # invalid escape sequence '\m' (m is any letter) :
# (needed for TeX)
W503, # line break before binary operator
E741, # do not use variables named ‘l’, ‘O’, or ‘I’
# do not assign a lambda expression, use a def
E731,
# space before : (needed for how black formats slicing)
E203,
# invalid escape sequence '\m' (m is any letter) :
# (needed for TeX)
W605,
# line break before binary operator
W503,
# line break after binary operator
W504,
# do not use variables named ‘l’, ‘O’, or ‘I’
E741,
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ def run(self):
author_email="jussi.viinikka@helsinki.fi",
license="BSD",
packages=["sumu", "sumu.utils", "sumu.scores"],
install_requires=["numpy", "scipy>=1.6"],
install_requires=["numpy", "scipy>=1.6", "psutil"],
package_data={"sumu": ["sources.bib"]},
include_package_data=True,
extras_require={"plotext": ["plotext==4.1.3"], "test": ["psutil"]},
extras_require={"plotext": ["plotext==4.1.3"]},
cmdclass=cmdclass,
ext_modules=cythonize(
exts, language_level="3", compiler_directives=COMPILER_DIRECTIVES
Expand Down
2 changes: 1 addition & 1 deletion sumu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.dev5"
__version__ = "0.2.dev6"

from . import bnet, gadget
from .aps import aps
Expand Down
6 changes: 3 additions & 3 deletions sumu/_mcmc_moves.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from .bnet import transitive_closure
from .utils.math_utils import comb, subsets


def DAG_edgerev(**kwargs):
def DAG_edge_reversal(**kwargs):

DAG = kwargs["DAG"]
R = kwargs["R"]
Expand Down Expand Up @@ -77,7 +77,7 @@ def DAG_edgerev(**kwargs):
return DAG, scoreratio, edge


def R_basic_move(**kwargs):
def R_split_merge(**kwargs):
"""Splits or merges a root-partition :footcite:`kuipers:2017`.
Args:
Expand Down Expand Up @@ -138,7 +138,7 @@ def R_basic_move(**kwargs):
return R_prime, q, q_prime, rescore


def R_swap_any(**kwargs):
def R_swap_node_pair(**kwargs):

R = kwargs["R"]
m = len(R)
Expand Down
16 changes: 8 additions & 8 deletions sumu/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,18 @@ def greedy(
K,
*,
scores,
params={"k": 6, "t_budget": None, "criterion": "score"},
params={"K_f": 6, "t_budget": None, "association_measure": "score"},
**kwargs,
):
k = params.get("k")
k = params.get("K_f")
if k is not None:
k = min(k, K)
t_budget = params.get("t_budget")
criterion = params.get("criterion")
assert criterion in ("score", "gain")
if criterion == "score":
association_measure = params.get("association_measure")
assert association_measure in ("score", "gain")
if association_measure == "score":
goodness = lambda v, S, u: scores._local(v, np.array(S + (u,)))
elif criterion == "gain":
elif association_measure == "gain":
goodness = lambda v, S, u: scores._local(
v, np.array(S + (u,))
) - scores._local(v, np.array(S))
Expand Down Expand Up @@ -136,7 +136,7 @@ def get_k(t_budget):
C[v] = tuple(C[v])
scores.clear_cache()

return C, {"k": k}
return C, {"K_f": k}


candidate_parent_algorithm = {"opt": opt, "rnd": rnd, "greedy": greedy}
candidate_parent_algorithm = {"optimal": opt, "random": rnd, "greedy": greedy}
4 changes: 4 additions & 0 deletions sumu/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def n(self):
def N(self):
return self.data.shape[0]

@property
def shape(self):
return self.data.shape

@property
def arities(self):
return np.count_nonzero(np.diff(np.sort(self.data.T)), axis=1) + 1
Expand Down
Loading

0 comments on commit c9c8fda

Please sign in to comment.