Skip to content

Commit

Permalink
Merge pull request #42 from arminwitte/adaboost
Browse files Browse the repository at this point in the history
Adaboost
  • Loading branch information
arminwitte committed Jul 5, 2023
2 parents 1415c5b + 360bd6d commit 1885766
Show file tree
Hide file tree
Showing 24 changed files with 935 additions and 966 deletions.
75 changes: 52 additions & 23 deletions binarybeech/attributehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,18 @@ def split(self, df):
]
N = len(df.index)
n = [len(df_.index) for df_ in split_df]
val = [self.metrics.node_value(df_[self.y_name]) for df_ in split_df]
loss = n[0] / N * self.metrics.loss(split_df[0][self.y_name], val[0]) + n[
1
] / N * self.metrics.loss(split_df[1][self.y_name], val[1])

if "__weights__" in df:
w = [df_["__weights__"].values for df_ in split_df]
else:
w = [None for df_ in split_df]
val = [
self.metrics.node_value(df_[self.y_name], w[i])
for i, df_ in enumerate(split_df)
]
loss = n[0] / N * self.metrics.loss(
split_df[0][self.y_name], val[0], w[0]
) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], w[1])
if loss < self.loss:
success = True
self.loss = loss
Expand Down Expand Up @@ -153,10 +161,18 @@ def fun(x):
n = [len(df_.index) for df_ in split_df]
if min(n) == 0:
return np.Inf
val = [self.metrics.node_value(df_[self.y_name]) for df_ in split_df]
return n[0] / N * self.metrics.loss(split_df[0][self.y_name], val[0]) + n[
1
] / N * self.metrics.loss(split_df[1][self.y_name], val[1])

if "__weights__" in df:
w = [df_["__weights__"].values for df_ in split_df]
else:
w = [None for df_ in split_df]
val = [
self.metrics.node_value(df_[self.y_name], w[i])
for i, df_ in enumerate(split_df)
]
return n[0] / N * self.metrics.loss(
split_df[0][self.y_name], val[0], w[0]
) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], w[1])

return fun

Expand Down Expand Up @@ -197,10 +213,18 @@ def split(self, df):
]
N = len(df.index)
n = [len(df_.index) for df_ in self.split_df]
val = [self.metrics.node_value(df_[self.y_name]) for df_ in self.split_df]

if "__weights__" in df:
w = [df_["__weights__"].values for df_ in self.split_df]
else:
w = [None for df_ in self.split_df]
val = [
self.metrics.node_value(df_[self.y_name], w[i])
for i, df_ in enumerate(self.split_df)
]
self.loss = n[0] / N * self.metrics.loss(
self.split_df[0][self.y_name], val[0]
) + n[1] / N * self.metrics.loss(self.split_df[1][self.y_name], val[1])
self.split_df[0][self.y_name], val[0], w[0]
) + n[1] / N * self.metrics.loss(self.split_df[1][self.y_name], val[1], w[1])

return success

Expand Down Expand Up @@ -270,10 +294,18 @@ def _opt_fun(self, df):
def fun(x):
split_df = [df[df[split_name] < x], df[df[split_name] >= x]]
n = [len(df_.index) for df_ in split_df]
val = [self.metrics.node_value(df_[self.y_name]) for df_ in split_df]
return n[0] / N * self.metrics.loss(split_df[0][self.y_name], val[0]) + n[
1
] / N * self.metrics.loss(split_df[1][self.y_name], val[1])

if "__weights__" in df:
w = [df_["__weights__"].values for df_ in split_df]
else:
w = [None for df_ in split_df]
val = [
self.metrics.node_value(df_[self.y_name], w[i])
for i, df_ in enumerate(split_df)
]
return n[0] / N * self.metrics.loss(
split_df[0][self.y_name], val[0], w[0]
) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], w[1])

return fun

Expand Down Expand Up @@ -409,9 +441,6 @@ def split(self, df):
df[df[name].isin(threshold)],
df[~df[name].isin(threshold)],
]
# N = len(df.index)
# n = [len(df_.index) for df_ in split_df]
# val = [self.metrics.node_value(None) for df_ in split_df]
loss = math.shannon_entropy(df[c])
if loss < self.loss:
success = True
Expand Down Expand Up @@ -442,8 +471,8 @@ def register_method_group(self, method_group):

def register_handler(self, attribute_handler_class, method_group="default"):
self.attribute_handlers[method_group].append(attribute_handler_class)
def __getitem__(self,name):

def __getitem__(self, name):
ahc = None
for val in self.attribute_handlers.values():
for a in val:
Expand Down Expand Up @@ -471,9 +500,9 @@ def create_attribute_handlers(

if method_group not in self.attribute_handlers.keys():
# raise ValueError(f"{method} is not a registered method_group")
print(
                f"WARNING: '{method_group}' is not a registered method group. Choosing 'default'."
)
# print(
            #     f"WARNING: '{method_group}' is not a registered method group. Choosing 'default'."
# )
method_group = "default"

ahc = self.get_attribute_handler_class(df[y_name], method_group=method_group)
Expand Down

0 comments on commit 1885766

Please sign in to comment.