-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added reporting and metrics to the decoupled, easy-to-execute-via-CLI approach.
- Loading branch information
1 parent
b3d8a13
commit a8e17e0
Showing
10 changed files
with
122 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
__all__ = [ | ||
"utils", | ||
"logclass", | ||
] | ||
__all__ = ["utils", "logclass"] | ||
|
||
from .preprocess import * | ||
from .feature_engineering import * | ||
from .models import * | ||
from .reporting import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__all__ = ["accuracy", "confusion_matrix", "multi_class_acc", "top_k_svm"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .bb_registry import register | ||
from sklearn.metrics import f1_score | ||
|
||
|
||
@register('acc')
def model_accuracy(y, pred):
    """Score the predictions ``pred`` against the true labels ``y``.

    The function is stored in the black box registry under the name ``'acc'``.

    NOTE(review): despite the function name and the registry key, the value
    returned is ``sklearn.metrics.f1_score(y, pred)`` called with its default
    arguments, not an accuracy score -- confirm this is intended.
    """
    score = f1_score(y, pred)
    return score
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Registry for black box reports or metrics.""" | ||
|
||
# Maps a registered name to its black box report/metric callable.
_BB_REPORTS = {}


def register(name):
    """Return a decorator that registers a black box report or metric.

    The decorated function is stored under ``name`` and returned unchanged,
    so it stays directly callable. Registering the same ``name`` twice
    replaces the earlier entry.
    """

    def add_to_dict(func):
        _BB_REPORTS[name] = func
        return func

    return add_to_dict


def get_bb_report(model):
    """Fetch the black box report or metric function registered as ``model``.

    Args:
        model: The name that was passed to ``register``.

    Returns:
        The callable registered under that name.

    Raises:
        KeyError: If nothing is registered under ``model``. The message
            lists the names that are registered to make typos easy to spot.
    """
    if model not in _BB_REPORTS:
        available = ", ".join(sorted(repr(key) for key in _BB_REPORTS))
        raise KeyError(
            "unknown black box report {!r}; registered: {}".format(
                model, available or "<none>"
            )
        )
    return _BB_REPORTS[model]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .bb_registry import register | ||
from sklearn.metrics import confusion_matrix | ||
|
||
|
||
@register('confusion_matrix')
def report(y, pred):
    """Build the confusion matrix of predictions ``pred`` versus labels ``y``.

    The function is stored in the black box registry under the name
    ``'confusion_matrix'`` and returns the result of
    ``sklearn.metrics.confusion_matrix(y, pred)`` unchanged.
    """
    matrix = confusion_matrix(y, pred)
    return matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .bb_registry import register | ||
from sklearn.metrics import accuracy_score | ||
|
||
|
||
@register('multi_acc')
def model_accuracy(y, pred):
    """Compute the accuracy of predictions ``pred`` against labels ``y``.

    The function is stored in the black box registry under the name
    ``'multi_acc'`` and returns the result of
    ``sklearn.metrics.accuracy_score(y, pred)`` unchanged.
    """
    score = accuracy_score(y, pred)
    return score
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from .wb_registry import register | ||
import numpy as np | ||
|
||
|
||
def get_feature_names(params, vocabulary, add_length=True):
    """Return the vocabulary's feature names ordered by their mapped value.

    Args:
        params: Mapping whose ``'features'`` entry lists the enabled feature
            kinds; only consulted for the ``'length'`` entry.
        vocabulary: Mapping of feature name to a sortable value (the names
            are returned in ascending order of that value).
        add_length: When true (the default), ``'LENGTH'`` is appended if
            ``'length'`` is listed in ``params['features']``. Pass ``False``
            to never append it.

    Returns:
        A ``numpy.ndarray`` holding the ordered feature names.
    """
    feature_names = sorted(vocabulary, key=vocabulary.get)
    if add_length and 'length' in params['features']:
        feature_names.append('LENGTH')
    return np.array(feature_names)
|
||
|
||
@register('top_k_svm')
def get_top_k_SVM_features(params, model, vocabulary, **kwargs):
    """Map each target label to the names of its highest-weighted features.

    Args:
        params: Configuration mapping, forwarded to ``get_feature_names``.
        model: Fitted model exposing ``coef_``; row ``i`` is read for the
            ``i``-th entry of ``target_names``.
        vocabulary: Mapping of feature name to its position, forwarded to
            ``get_feature_names``.
        **kwargs: Overrides for the defaults below.
            target_names (default ``[]``): labels to report on.
            top_features (default ``5``): how many features to keep per
            label; ``0`` yields an empty list for every label.

    Returns:
        A dict mapping each processed label to a list of feature names,
        largest coefficient first.
    """
    hparms = {
        'target_names': [],
        'top_features': 5,
    }
    hparms.update(kwargs)

    target_names = hparms['target_names']
    top_k = hparms['top_features']
    # coef_ is unidimensional when there are only two labels, so only its
    # first row is read in that case.
    # NOTE(review): for a two-class linear SVM that single row presumably
    # favours the second class when positive -- confirm that attributing it
    # to target_names[0] is intended.
    single_row = len(target_names) < 3

    top_k_label = {}
    feature_names = get_feature_names(params, vocabulary)
    for i, label in enumerate(target_names):
        if single_row and i == 1:
            break
        # Sort descending and take the head. Slicing the tail with
        # ``[-top_k:]`` would return every feature when top_k is 0.
        ranked = np.argsort(model.coef_[i])[::-1][:top_k]
        top_k_label[label] = list(feature_names[ranked])
    return top_k_label
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Registry for white box reports or metrics.""" | ||
|
||
# Maps a registered name to its white box report/metric callable.
_WB_REPORTS = {}


def register(name):
    """Return a decorator that registers a white box report or metric.

    The decorated function is stored under ``name`` and returned unchanged,
    so it stays directly callable. Registering the same ``name`` twice
    replaces the earlier entry.
    """

    def add_to_dict(func):
        _WB_REPORTS[name] = func
        return func

    return add_to_dict


def get_wb_report(model):
    """Fetch the white box report or metric function registered as ``model``.

    Args:
        model: The name that was passed to ``register``.

    Returns:
        The callable registered under that name.

    Raises:
        KeyError: If nothing is registered under ``model``. The message
            lists the names that are registered to make typos easy to spot.
    """
    if model not in _WB_REPORTS:
        available = ", ".join(sorted(repr(key) for key in _WB_REPORTS))
        raise KeyError(
            "unknown white box report {!r}; registered: {}".format(
                model, available or "<none>"
            )
        )
    return _WB_REPORTS[model]