Skip to content

Commit

Permalink
feat: Refactor step function to handle limit and return values
Browse files Browse the repository at this point in the history
The step function in `fuzzylogic/functions.py` has been refactored to handle the limit and return values more accurately. It now returns the left argument when coming from the left, the average of the left and right arguments at the limit, and the right argument when coming from the right. This change improves the functionality and clarity of the step function.

Note: The `njit` import has been commented out due to unresolved issues.

Co-authored-by: dependabot[bot] <support@github.com>
  • Loading branch information
amogorkon and dependabot[bot] committed May 23, 2024
1 parent 7ce45ed commit db2c45c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 51 deletions.
91 changes: 47 additions & 44 deletions src/fuzzylogic/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
adding logical operaitons for easier handling.
"""

from logging import warn
from typing import Callable

import matplotlib.pyplot as plt
Expand All @@ -17,7 +16,6 @@

class FuzzyWarning(UserWarning):
"""Extra Exception so that user code can filter exceptions specific to this lib."""

pass


Expand Down Expand Up @@ -71,19 +69,11 @@ def __init__(
self._res = res
self._sets = {} if sets is None else sets # Name: Set(Function())

def __call__(self, X):
def __call__(self, x):
"""Pass a value to all sets of the domain and return a dict with results."""
if isinstance(X, np.ndarray):
if any(not (self._low <= x <= self._high) for x in X):
raise FuzzyWarning("Value in array is outside of defined range!")
res = {}
for s in self._sets.values():
vector = np.vectorize(s.func, otypes=[float])
res[s] = vector(X)
return res
if not (self._low <= X <= self._high):
warn(f"{X} is outside of domain!")
return {s: s.func(X) for name, s in self._sets.items()}
if not (self._low <= x <= self._high):
raise FuzzyWarning(f"{x} is outside of domain!")
return {name: s.func(x) for name, s in self._sets.items()}

def __str__(self):
"""Return a string to print()."""
Expand Down Expand Up @@ -195,7 +185,7 @@ class Set:
name = None # these are set on assignment to the domain! DO NOT MODIFY
domain = None

def __init__(self, func: Callable, *, name=None, domain=None):
def __init__(self, func: Callable, *, name:str|None=None, domain:Domain|None=None):
self.func = func
self.domain = domain
self.name = name
Expand Down Expand Up @@ -334,7 +324,7 @@ def dilated(self):

def multiplied(self, n):
"""Multiply with a constant factor, changing all membership values."""
return Set(lambda x: self.func(x) * n, domain=self)
return Set(lambda x: self.func(x) * n, domain=self.domain)

def plot(self):
"""Graph the set in the given domain."""
Expand All @@ -350,19 +340,18 @@ def array(self):
raise FuzzyWarning("No domain assigned.")
return np.fromiter((self.func(x) for x in self.domain.range), float)

@property
def center_of_gravity(self):
"""Return the center of gravity for this distribution, within the given domain."""
if self.__center_of_gravity is not None:
return self.__center_of_gravity

assert self.domain is not None, "No center of gravity with no domain."
weights = self.array()
if sum(weights) == 0:
return 0
cog = np.average(np.arange(len(weights)), weights=weights)
cog = np.average(self.domain.range, weights=weights)
self.__center_of_gravity = cog
return cog


def __repr__(self):
"""
Return a string representation of the Set that reconstructs the set with eval().
Expand Down Expand Up @@ -410,6 +399,7 @@ class Rule:
"""

def __init__(self, conditions, func=None):
print("ohalala")
self.conditions = {frozenset(C): oth for C, oth, in conditions.items()}
self.func = func

Expand Down Expand Up @@ -444,30 +434,43 @@ def __call__(self, args: "dict[Domain, float]", method="cog"):
assert isinstance(
args, dict
), "Please make sure to pass in the values as a dictionary."
if method == "cog":
assert (
len({C.domain for C in self.conditions.values()}) == 1
), "For CoG, all conditions must have the same target domain."
actual_values = {
f: f(args[f.domain]) for S in self.conditions.keys() for f in S
}

weights = []
for K, v in self.conditions.items():
x = min((actual_values[k] for k in K if k in actual_values), default=0)
if x > 0:
weights.append((v, x))

if not weights:
return None
target_domain = list(self.conditions.values())[0].domain
index = sum(v.center_of_gravity * x for v, x in weights) / sum(
x for v, x in weights
)
return (target_domain._high - target_domain._low) / len(
target_domain.range
) * index + target_domain._low

match method:
case "cog":
assert (
len({C.domain for C in self.conditions.values()}) == 1
), "For CoG, all conditions must have the same target domain."
actual_values = {
f: f(args[f.domain]) for S in self.conditions.keys() for f in S
}

weights = []
for K, v in self.conditions.items():
x = min((actual_values[k] for k in K if k in actual_values), default=0)
if x > 0:
weights.append((v, x))

if not weights:
return None
target_domain = list(self.conditions.values())[0].domain
index = sum(v.center_of_gravity * x for v, x in weights) / sum(
x for v, x in weights
)
return (target_domain._high - target_domain._low) / len(
target_domain.range
) * index + target_domain._low

case "centroid":
raise NotImplementedError("Centroid method not implemented yet.")
case "bisector":
raise NotImplementedError("Bisector method not implemented yet.")
case "mom":
raise NotImplementedError("Middle of max method not implemented yet.")
case "som":
raise NotImplementedError("Smallest of max method not implemented yet.")
case "lom":
raise NotImplementedError("Largest of max method not implemented yet.")
case _:
raise ValueError("Invalid method.")

def rule_from_table(table: str, references: dict):
"""Turn a (2D) string table into a Rule of fuzzy sets.
Expand Down
4 changes: 3 additions & 1 deletion src/fuzzylogic/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ def F(z):
from collections.abc import Callable
from functools import reduce

from fuzzylogic.functions import noop # noqa
from numpy import multiply

try:
from numba import njit
raise ImportError
# from numba import njit # still not ready for prime time :(
except ImportError:

def njit(func):
Expand Down
22 changes: 16 additions & 6 deletions src/fuzzylogic/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from math import exp, isinf, isnan, log
from typing import Optional



try:
from numba import njit
raise ImportError
# from numba import njit # still not ready for prime time :(

except ImportError:

def njit(func):
Expand Down Expand Up @@ -206,19 +206,29 @@ def f(x) -> float:
return f


def step(limit, /, *, left=0, right=1):
def step(limit:float|int, /, *, left:float|int=0, right:float|int=1, at_lmt:None|float|int=None) -> Callable:
"""A step function.
Coming from left, the function returns the *left* argument.
At the limit, it returns *at_lmt* or the average of left and right.
After the limit, it returns the *right* argument.
>>> f = step(2)
>>> f(1)
0
>>> f(2)
0.5
>>> f(3)
1
"""
assert 0 <= left <= 1 and 0 <= right <= 1

def f(x):
return left if x < limit else right
def f(x:float|int) -> float|int:
if x < limit:
return left
elif x > limit:
return right
else:
return at_lmt if at_lmt is not None else (left + right)/2

return f

Expand Down

0 comments on commit db2c45c

Please sign in to comment.