Skip to content

Commit

Permalink
refactored ks for speed, added appropriate unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
austin-marcus committed May 17, 2023
1 parent 66ced49 commit add64a2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
19 changes: 19 additions & 0 deletions cana/boolean_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
flip_bit
)
from cana.utils import ncr, input_monotone
import functools

class BooleanNode(object):
"""
Expand Down Expand Up @@ -314,6 +315,24 @@ def input_symmetry(self, aggOp="mean", kernel="numDots", sameSymbol=False):
kernFunc = lambda x: strToKern[kernel](x, sameSymbol=sameSymbol)
return self._input_symmetry(strToOp[aggOp], kernFunc)

# refactor ks for speed, avg op only
def input_symmetry_mean(self):
"""compute the input symmetry (k_s) of the boolean node.
Specifically, computes it using the avg operator for the summand.
Refactoring of input_symmetry for speed.
Returns:
(float)
"""
summand = 0
# fTheta = a list of TS
for fTheta in self._ts_coverage.values():
inner = 0
for ts in fTheta:
inner += sum(len(i) for i in ts[1]) # assumes that indicies will ever only be in at most 1 group
summand += inner / len(fTheta)
return summand / 2**self.k

def look_up_table(self):
""" Returns the Look Up Table (LUT)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_boolean_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ def test_input_symmetry_AND():
assert (k_s == true_k_s), f"Input symmetry: AND (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots"), 3.0/2
assert (k_s == true_k_s), f"Input symmetry: AND (max): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry_mean(), 3.0/2
assert (k_s == true_k_s), f"Input symmetry simp: AND (mean): returned {k_s}, true value is {true_k_s}"

# k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots", sameSymbol=True), 2.0
# assert (k_s == true_k_s), f"Input symmetry: AND (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
Expand All @@ -407,46 +409,62 @@ def test_input_symmetry_XOR():
n = XOR()
k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots"), 1.0
assert (k_s == true_k_s), f"Input symmetry: XOR (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry_mean(), 1.0
assert (k_s == true_k_s), f"Input symmetry simp: XOR (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots"), 1.0
assert (k_s == true_k_s), f"Input symmetry: XOR (max): returned {k_s}, true value is {true_k_s}"

# k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots", sameSymbol=True), 2.0
# assert (k_s == true_k_s), f"Input symmetry: XOR (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry_mean(), 2.0
# assert (k_s == true_k_s), f"Input symmetry simp: XOR (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots", sameSymbol=True), 2.0
# assert (k_s == true_k_s), f"Input symmetry: XOR (max, sameSymbol): returned {k_s}, true value is {true_k_s}"

def test_input_symmetry_COPYx1():
n = COPYx1()
k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots"), 0
assert (k_s == true_k_s), f"Input symmetry: COPYx1 (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry_mean(), 0
assert (k_s == true_k_s), f"Input symmetry simp: COPYx1 (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots"), 0
assert (k_s == true_k_s), f"Input symmetry: COPYx1 (max): returned {k_s}, true value is {true_k_s}"

# k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots", sameSymbol=True), 0.0
# assert (k_s == true_k_s), f"Input symmetry: COPYx1 (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry_mean(), 0.0
# assert (k_s == true_k_s), f"Input symmetry simp: COPYx1 (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots", sameSymbol=True), 0.0
# assert (k_s == true_k_s), f"Input symmetry: COPYx1 (max, sameSymbol): returned {k_s}, true value is {true_k_s}"

def test_input_symmetry_RULE90():
n = RULE90()
k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots"), 1.0
assert (k_s == true_k_s), f"Input symmetry: RULE90 (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry_mean(), 1.0
assert (k_s == true_k_s), f"Input symmetry simp: RULE90 (mean): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots", sameSymbol=True), 1.0
assert (k_s == true_k_s), f"Input symmetry: RULE90 (max): returned {k_s}, true value is {true_k_s}"

# k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots", sameSymbol=True), 2.0
# assert (k_s == true_k_s), f"Input symmetry: RULE90 (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry_mean(), 2.0
# assert (k_s == true_k_s), f"Input symmetry simp: RULE90 (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots", sameSymbol=True), 2.0
# assert (k_s == true_k_s), f"Input symmetry: RULE90 (max, sameSymbol): returned {k_s}, true value is {true_k_s}"

def test_input_symmetry_SBF():
n = BooleanNode(outputs=list("0111" + "0"*12), k=4)
k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots"), 1.6875
assert (k_s == true_k_s), f"Input symmetry: SBF (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry_mean(), 1.6875
assert (k_s == true_k_s), f"Input symmetry simp: SBF (mean): returned {k_s}, true value is {true_k_s}"
k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots"), 1.875
assert (k_s == true_k_s), f"Input symmetry: SBF (max): returned {k_s}, true value is {true_k_s}"

# k_s, true_k_s = n.input_symmetry(aggOp="mean", kernel="numDots", sameSymbol=True), 4.0
# assert (k_s == true_k_s), f"Input symmetry: SBF (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry_mean(), 4.0
# assert (k_s == true_k_s), f"Input symmetry simp: SBF (mean, sameSymbol): returned {k_s}, true value is {true_k_s}"
# k_s, true_k_s = n.input_symmetry(aggOp="max", kernel="numDots", sameSymbol=True), 4.0
# assert (k_s == true_k_s), f"Input symmetry: SBF (max, sameSymbol): returned {k_s}, true value is {true_k_s}"

0 comments on commit add64a2

Please sign in to comment.