-
Notifications
You must be signed in to change notification settings - Fork 270
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- scan is now in its own submodule/folder - reduction files are now following the new convention - major refactor of reduction procedures
- Loading branch information
1 parent
86ed1bd
commit da26fe5
Showing
11 changed files
with
107 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pandas as pd | ||
|
||
|
||
def correlation(self, | ||
correlation='spearman', | ||
corr_to_drop='neg'): | ||
|
||
'''Correlation Reducers | ||
Note that this set of reducers works only for the continuous | ||
and stepped (e.g. batch size) hyperparameters. | ||
''' | ||
|
||
out = self.param_table.corr(correlation)[self.reduction_metric] | ||
out = out.dropna() | ||
|
||
if len(out) == 0: | ||
self._reduce_keys = None | ||
return self | ||
|
||
out = out[1:].sort_values(ascending=False) | ||
out = out.index[-1], out[-1] | ||
|
||
if abs(out[1]) >= self.reduction_threshold: | ||
dummy_cols = pd.get_dummies(self.param_table[out[0]]) | ||
dummy_cols.insert(0, | ||
self.reduction_metric, | ||
self.param_table[self.reduction_metric]) | ||
|
||
# case where threshold is not met | ||
else: | ||
self._reduce_keys = None | ||
return self | ||
|
||
# all other cases continue | ||
to_drop_temp = dummy_cols.corr(correlation)[self.reduction_metric] | ||
|
||
# pick the drop method based on paramaters | ||
if corr_to_drop == 'neg': | ||
self._reduce_keys = to_drop_temp.sort_values().index[0], out[0] | ||
elif corr_to_drop == 'pos': | ||
self._reduce_keys = to_drop_temp.sort_values().index[-2], out[0] | ||
|
||
return self |
10 changes: 5 additions & 5 deletions
10
talos/reducers/reduce_drop.py → talos/reducers/reduce_finish.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
def reduction_drop(self): | ||
def reduce_finish(self): | ||
|
||
'''Takes input from a Reducer in form of a tuple | ||
where the values the hyperparamater name and the | ||
value to drop. Returns self with a modified param_log.''' | ||
|
||
# get the column index | ||
to_remove_col = self.param_reference[self.out[1]] | ||
to_remove_col = self.param_reference[self._reduce_keys[1]] | ||
|
||
value_to_remove = self.out[0] | ||
value_to_remove = self._reduce_keys[0] | ||
|
||
# pick the index numbers for dropping available permutations | ||
indexs_to_drop = self.param_grid[self.param_grid[:, to_remove_col] == value_to_remove][:,-1] | ||
|
||
# drop the index numbers | ||
param_log = list(set(self.param_log).difference(set(indexs_to_drop))) | ||
self.param_log = list(set(self.param_log).difference(set(indexs_to_drop))) | ||
|
||
return param_log | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pandas as pd | ||
|
||
from ..metrics.names import metric_names | ||
|
||
|
||
def reduce_prepare(self): | ||
|
||
# load the data from the experiment log | ||
self.data = pd.read_csv(self.experiment_name + '.csv') | ||
self.names = metric_names() | ||
|
||
# apply the lookback window | ||
if self.reduction_window is not None: | ||
self.data = self.data.tail(self.reduction_window) | ||
|
||
self.param_columns = [col for col in self.data.columns if col not in self.names] | ||
self.param_table = self.data[self.param_columns] | ||
self.param_table.insert(0, self.reduction_metric, self.data[self.reduction_metric]) | ||
|
||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,16 @@ | ||
from .ReductionTable import ReductionTable | ||
from .Reducers import Reducers | ||
from .reduce_drop import reduction_drop | ||
from .reduce_prepare import reduce_prepare | ||
from .reduce_finish import reduce_finish | ||
from .correlation import correlation | ||
|
||
|
||
def reduce_run(self): | ||
|
||
'''Takes in the Scan object, and returns a modified version | ||
of the self.param_log.''' | ||
self = reduce_prepare(self) | ||
|
||
self._filaname = self.experiment_name + '.csv' | ||
|
||
# create the table for reduction | ||
out = ReductionTable(self._filaname, | ||
self.reduction_metric, | ||
self.reduction_window, | ||
self.reduction_threshold) | ||
|
||
# create the reducer object | ||
out = Reducers(out) | ||
|
||
# apply the reduction | ||
if self.reduction_method == 'correlation': | ||
self.out = out.correlation() | ||
self = correlation(self) | ||
|
||
if self.out is None: | ||
return self.param_log | ||
if self._reduce_keys is None: | ||
return self | ||
else: | ||
return reduction_drop(self) | ||
return reduce_finish(self) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters