Skip to content

Commit

Permalink
fixes #498
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Jun 14, 2023
1 parent 6b05db1 commit 64ed862
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 8 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

## 0.37.2 (TBD)

### new:
- N/A

### changed:
- N/A

### fixed:
- fix `validate` to support multilabel classification problems (#498)
- `TransformerPreprocessor.get_classifier` now issues a warning recommending `metrics=["binary_accuracy"]` when called for a multilabel problem with the default `metrics=["accuracy"]` (#498)


## 0.37.1 (2023-06-05)

### new:
Expand Down
21 changes: 16 additions & 5 deletions ktrain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,7 @@ class integer IDs.
#'to manually validate.')
# return
pass

if U.is_multilabel(val) or multilabel:
warnings.warn("multilabel confusion matrices not yet supported")
return
is_multilabel = U.is_multilabel(val) or multilabel
y_pred = self.predict(val_data=val)
y_true = self.ground_truth(val_data=val)
y_pred = np.squeeze(y_pred)
Expand All @@ -191,9 +188,14 @@ class integer IDs.
if len(y_pred.shape) == 1:
y_pred = np.where(y_pred > 0.5, 1, 0)
y_true = np.where(y_true > 0.5, 1, 0)
elif is_multilabel:
from sklearn.preprocessing import binarize

y_pred = binarize(y_pred, threshold=0.5)
else:
y_pred = np.argmax(y_pred, axis=1)
y_true = np.argmax(y_true, axis=1)

if print_report or save_path is not None:
if class_names:
try:
Expand All @@ -208,7 +210,10 @@ class integer IDs.
)
else:
report = classification_report(
y_true, y_pred, output_dict=not print_report
y_true,
y_pred,
output_dict=not print_report,
zero_division=0,
)
if print_report:
print(report)
Expand All @@ -217,6 +222,12 @@ class integer IDs.
df.to_csv(save_path)
print("classification report saved to: %s" % (save_path))
cm_func = confusion_matrix
if is_multilabel:
warnings.warn(
"Confusion matrices do not currently support multilabel classification, so returning None"
)
return

cm = confusion_matrix(y_true, y_pred)
return cm

Expand Down
5 changes: 5 additions & 0 deletions ktrain/text/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,11 @@ def get_classifier(self, fpath=None, multilabel=None, metrics=["accuracy"]):
+ "this is a multilabel problem (labels are not mutually-exclusive). Using multilabel=False anyways."
)

if multilabel and metrics == ["accuracy"]:
warnings.warn(
'For multilabel problems, we recommend you supply the following argument to this method: metrics=["binary_accuracy"]'
)

# setup model
num_labels = len(self.get_classes())
mname = fpath if fpath is not None else self.model_name
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ["__version__"]
__version__ = "0.37.1"
__version__ = "0.37.2"
3 changes: 1 addition & 2 deletions tests/test_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


def synthetic_multilabel():

# data
X = [
[1, 0, 0, 0, 0, 0, 0],
Expand Down Expand Up @@ -93,7 +92,7 @@ def test_multilabel(self):
# use loss instead of accuracy due to: https://github.com/tensorflow/tensorflow/issues/41114
hist = learner.fit(0.001, 200)
learner.view_top_losses(n=5)
learner.validate()
print(learner.validate())
# final_acc = hist.history[VAL_ACC_NAME][-1]
# print('final_accuracy:%s' % (final_acc))
# self.assertGreater(final_acc, 0.97)
Expand Down

0 comments on commit 64ed862

Please sign in to comment.