Skip to content

Commit

Permalink
Merge pull request #19 from arminwitte/algorithm_kwargs
Browse files Browse the repository at this point in the history
entropy tolerance
  • Loading branch information
arminwitte committed May 7, 2023
2 parents 7d6d8bb + 2bbe21b commit 2ec1f46
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 68 deletions.
79 changes: 53 additions & 26 deletions binarybeech/attributehandler.py
Expand Up @@ -12,10 +12,11 @@


class AttributeHandlerBase(ABC):
def __init__(self, y_name, attribute, metrics):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
self.y_name = y_name
self.attribute = attribute
self.metrics = metrics
self.algorithm_kwargs = algorithm_kwargs

self.loss = None
self.split_df = []
Expand Down Expand Up @@ -44,8 +45,8 @@ def check(x):


class NominalAttributeHandler(AttributeHandlerBase):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
    """Initialize a handler for nominal (categorical) attributes.

    Args:
        y_name: Name of the target column.
        attribute: Name of the column this handler splits on.
        metrics: Metrics object forwarded to the base handler.
        algorithm_kwargs: Dict of algorithm tuning options forwarded to
            the base handler (e.g. entropy-tolerance settings).
    """
    super().__init__(y_name, attribute, metrics, algorithm_kwargs)

def split(self, df):
self.loss = np.Inf
Expand Down Expand Up @@ -105,8 +106,8 @@ def check(x):


class DichotomousAttributeHandler(AttributeHandlerBase):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
    """Initialize a handler for dichotomous (two-valued) attributes.

    Args:
        y_name: Name of the target column.
        attribute: Name of the column this handler splits on.
        metrics: Metrics object forwarded to the base handler.
        algorithm_kwargs: Dict of algorithm tuning options forwarded to
            the base handler.
    """
    super().__init__(y_name, attribute, metrics, algorithm_kwargs)

def split(self, df):
self.loss = np.Inf
Expand Down Expand Up @@ -155,8 +156,8 @@ def check(x):


class IntervalAttributeHandler(AttributeHandlerBase):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
    """Initialize a handler for interval-scaled (numeric) attributes.

    Args:
        y_name: Name of the target column.
        attribute: Name of the column this handler splits on.
        metrics: Metrics object forwarded to the base handler.
        algorithm_kwargs: Dict of algorithm tuning options forwarded to
            the base handler.
    """
    super().__init__(y_name, attribute, metrics, algorithm_kwargs)

def split(self, df):
self.loss = np.Inf
Expand Down Expand Up @@ -212,8 +213,8 @@ def check(x):


class NullAttributeHandler(AttributeHandlerBase):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
    """Initialize the no-op handler used for unsplittable attributes.

    Args:
        y_name: Name of the target column.
        attribute: Name of the column this handler is attached to.
        metrics: Metrics object forwarded to the base handler.
        algorithm_kwargs: Dict of algorithm tuning options forwarded to
            the base handler.
    """
    super().__init__(y_name, attribute, metrics, algorithm_kwargs)

def split(self, df):
self.loss = np.Inf
Expand All @@ -240,8 +241,8 @@ def check(x):


class UnsupervisedIntervalAttributeHandler(AttributeHandlerBase):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
    """Initialize a handler for numeric attributes in unsupervised mode.

    Args:
        y_name: Name of the target column.
        attribute: Name of the column this handler splits on.
        metrics: Metrics object forwarded to the base handler.
        algorithm_kwargs: Dict of algorithm tuning options forwarded to the
            base handler; split() reads
            "unsupervised_minimum_relative_entropy_improvement" from it.
    """
    super().__init__(y_name, attribute, metrics, algorithm_kwargs)

def split(self, df):
self.loss = np.Inf
Expand All @@ -258,15 +259,41 @@ def split(self, df):
valleys = math.valley(df[name])
if not valleys:
return success
else:
success = True

self.threshold = valleys[0]
self.split_df = [
df[df[self.attribute] < self.threshold],
df[df[self.attribute] >= self.threshold],
]
self.loss = math.shannon_entropy_histogram(df[name])
loss = np.Inf
for v in valleys:
threshold_candidate = v
split_df_candidate = [
df[df[self.attribute] < threshold_candidate],
df[df[self.attribute] >= threshold_candidate],
]
H = math.shannon_entropy_histogram(df[name], normalized=False)
H_ = [
math.shannon_entropy_histogram(df_[name], normalized=False)
for df_ in split_df_candidate
]
loss_candidate = (-np.sum(H_) + H) / np.abs(H)
if loss_candidate < loss:
loss = loss_candidate
split_df = split_df_candidate
threshold = threshold_candidate

# loss = math.shannon_entropy_histogram(df[name], normalized=True)

print(f"{self.attribute} loss: {loss}")

tol = self.algorithm_kwargs.get(
"unsupervised_minimum_relative_entropy_improvement"
)

if tol is not None and loss > tol:
return success

success = True

self.threshold = threshold
self.split_df = split_df
self.loss = loss
return success

def handle_missings(self, df):
Expand All @@ -284,8 +311,8 @@ def check(x):


class UnsupervisedNominalAttributeHandler(AttributeHandlerBase):
def __init__(self, y_name, attribute, metrics, algorithm_kwargs):
    """Initialize a handler for nominal attributes in unsupervised mode.

    Args:
        y_name: Name of the target column.
        attribute: Name of the column this handler splits on.
        metrics: Metrics object forwarded to the base handler.
        algorithm_kwargs: Dict of algorithm tuning options forwarded to
            the base handler.
    """
    super().__init__(y_name, attribute, metrics, algorithm_kwargs)

def split(self, df):
self.loss = np.Inf
Expand Down Expand Up @@ -362,21 +389,21 @@ def get_attribute_handler_class(self, arr, group_name="default"):

raise ValueError("no data handler class for this type of data")

def create_attribute_handlers(self, training_data, metrics, algorithm_kwargs):
    """Build an attribute handler for the target and every feature column.

    Args:
        training_data: Object exposing ``df`` (tabular data), ``y_name``
            (target column name) and ``X_names`` (feature column names).
        metrics: Metrics object; its ``attribute_handler_group()`` selects
            which registered group of handler classes to draw from.
        algorithm_kwargs: Dict of algorithm tuning options passed through
            to every handler instance.

    Returns:
        Dict mapping column name to its attribute handler instance; the
        target column's handler is stored under ``y_name`` as well.
    """
    df = training_data.df
    y_name = training_data.y_name
    X_names = training_data.X_names

    # Pick the handler class that matches the target column's data type.
    ahc = self.get_attribute_handler_class(
        df[y_name], group_name=metrics.attribute_handler_group()
    )
    d = {y_name: ahc(y_name, y_name, metrics, algorithm_kwargs)}

    for name in X_names:
        # Each feature may need a different handler class (nominal,
        # interval, dichotomous, ...), so look it up per column.
        ahc = self.get_attribute_handler_class(
            df[name], group_name=metrics.attribute_handler_group()
        )
        d[name] = ahc(y_name, name, metrics, algorithm_kwargs)

    return d

Expand Down
6 changes: 3 additions & 3 deletions binarybeech/datamanager.py
Expand Up @@ -12,17 +12,17 @@ def __init__(self, training_data, method, attribute_handlers, algorithm_kwargs):

if method is None:
metrics_type, metrics = metrics_factory.from_data(
training_data.df[training_data.y_name]
training_data.df[training_data.y_name], self.algorithm_kwargs
)
else:
metrics = metrics_factory.create_metrics(method)
metrics = metrics_factory.create_metrics(method, self.algorithm_kwargs)
metrics_type = method
self.metrics = metrics
self.metrics_type = metrics_type

if attribute_handlers is None:
attribute_handlers = attribute_handler_factory.create_attribute_handlers(
training_data, self.metrics
training_data, self.metrics, self.algorithm_kwargs
)
self.attribute_handlers = attribute_handlers
self.items = self.attribute_handlers.items
Expand Down
22 changes: 17 additions & 5 deletions binarybeech/math.py
Expand Up @@ -122,16 +122,28 @@ def ambiguity(X):


def valley(x):
    """Locate histogram valleys (local minima) of a 1-D sample.

    The sample is binned with the Sturges rule and local minima of the
    histogram are found by peak-picking on the negated counts. Each valley
    is reported as the center of its histogram bin.

    Args:
        x: 1-D array-like of numeric values.

    Returns:
        List of bin-center positions of the histogram valleys; empty when
        the histogram has no interior local minimum.
    """
    # Fixed-rule binning keeps results stable; bins="auto" can switch
    # estimators depending on the data and move the valley positions.
    hist, bin_edges = np.histogram(x, bins="sturges", density=False)
    # A peak of -hist is a valley of hist.
    valley_ind, _ = scipy.signal.find_peaks(-hist)
    return [(bin_edges[i] + bin_edges[i + 1]) * 0.5 for i in valley_ind]


def shannon_entropy_histogram(x: np.ndarray, normalized: bool = False):
    """Shannon-entropy-like score of a sample's histogram.

    The sample is binned with the Sturges rule and ``-sum(c * log2(c))``
    is computed over the raw bin counts ``c``. Note this uses counts, not
    probabilities, so the value is non-positive for counts >= 1 and scales
    with the sample size.

    Args:
        x: 1-D array of numeric values.
        normalized: When True, divide by the reference value
            ``n * log2(n / n_bins)`` (n = sample size) so scores are
            comparable across differently sized samples.

    Returns:
        The (optionally normalized) histogram entropy score as a float.
    """
    hist, bin_edges = np.histogram(x, bins="sturges", density=False)
    # Clamp empty bins so log2 is defined; 1e-12 * log2(1e-12) is ~0 and
    # leaves the sum effectively unchanged.
    hist = np.maximum(hist, 1e-12)
    s = -np.sum(hist * np.log2(hist))

    if normalized:
        n_bins = bin_edges.size - 1
        n_samples = x.size
        # Reference entropy used to make the score size-independent.
        s_ref = n_samples * np.log2(n_samples / n_bins)
        s /= s_ref

    return s


# =====================================
Expand Down
4 changes: 2 additions & 2 deletions binarybeech/metrics.py
Expand Up @@ -287,13 +287,13 @@ def __init__(self):
def register(self, metrics_type, metrics_class):
    """Register *metrics_class* in the factory under the key *metrics_type*.

    Registered classes can later be instantiated via create_metrics.
    Re-registering an existing key silently replaces the previous class.
    """
    self.metrics[metrics_type] = metrics_class

def create_metrics(self, metrics_type, algorithm_kwargs):
    """Instantiate the metrics class registered under *metrics_type*.

    Args:
        metrics_type: Key the metrics class was registered with.
        algorithm_kwargs: Accepted for interface symmetry with from_data;
            not used in the visible body (the metrics constructor takes
            no arguments).

    Returns:
        A new instance of the registered metrics class.

    Raises:
        ValueError: If *metrics_type* is not registered.
    """
    if metrics_type in self.metrics:
        return self.metrics[metrics_type]()
    raise ValueError("Invalid metrics type")

def from_data(self, y):
def from_data(self, y, algorithm_kwargs):
for name, cls in self.metrics.items():
if cls.check_data_type(y):
return cls(), name
Expand Down
14 changes: 14 additions & 0 deletions binarybeech/utils.py
Expand Up @@ -82,3 +82,17 @@ def model_missings(df, y_name, X_names=None, cart_settings={}):
df_.loc[df[x_name].isnull(), x_name] = mod.predict(df[df[x_name].isnull()])

return df_


# def plot_areas(df):
# x, y = np.meshgrid(np.linspace(1,7,101),np.linspace(0,2.5,101))
# col = []
# for i in range(len(x.ravel())):
# d = df_iris.iloc[120].copy()
# d["petal_length"] = x.ravel()[i]
# d["petal_width"] = y.ravel()[i]
# col.append(c_iris.tree.traverse(d).value)
# unique = [u for u in np.unique(col)]
# for i, c in enumerate(col):
# col[i] = unique.index(c)
# z = np.array(col).reshape(x.shape)

0 comments on commit 2ec1f46

Please sign in to comment.