Skip to content

Commit

Permalink
updated examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Krxsy committed Jun 2, 2017
1 parent bc41f6e commit cf10c17
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/fanova_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
f = fANOVA(X = features, Y = responses, config_space=cs)

# marginal of particular parameter:
dims = list([1])
dims = (1, )
res = f.quantify_importance(dims)
print(res)

Expand All @@ -39,6 +39,6 @@
# first create an instance of the visualizer with fanova object and configspace
vis = visualizer.Visualizer(f, cs)
# creating the plot of pairwise marginal:
vis.plot_pairwise_marginal(list([0,2]), resolution=20)
vis.plot_pairwise_marginal((0,2), resolution=20)
# creating all plots in the directory
vis.create_all_plots(plot_dir)
12 changes: 8 additions & 4 deletions examples/onlineLDA_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from smac.configspace import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter
import csv
from fanova import fANOVA
import fanova.visualizer

Expand All @@ -17,17 +16,22 @@
f = open(param_file, 'rb')
cs = ConfigurationSpace()
for row in f:
cs.add_hyperparameter(UniformFloatHyperparameter("%s" %row[0:4], np.float(row[6:9]), np.float(row[10:13]),np.float(row[18:21])))
cs.add_hyperparameter(UniformFloatHyperparameter("%s" %row[0:4].decode('utf-8'), np.float(row[6:9]), np.float(row[10:13]),np.float(row[18:21])))
param = cs.get_hyperparameters()

# create an instance of fanova with data for the random forest and the configSpace
f = fANOVA(X = X, Y = Y)
f = fANOVA(X = X, Y = Y, config_space = cs)

# marginal for first parameter
p_list = (0, )
p_list = (1, )
res = f.quantify_importance(p_list)
print(res)

print(cs)
p2_list = ('Col1', 'Col2')
res2 = f.quantify_importance(p2_list)
print(res2)

# getting the most important pairwise marginals sorted by importance
best_margs = f.get_most_important_pairwise_marginals(n=3)
print(best_margs)
Expand Down

0 comments on commit cf10c17

Please sign in to comment.