Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/automl/fanova.git
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronkl committed Jun 17, 2015
2 parents 861d3d4 + ca7b249 commit cd933ee
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
15 changes: 13 additions & 2 deletions fanova/src/net/aeatk/fanova/model/FunctionalANOVARunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ public static void decomposeVariance(RandomForest existingForest, List<Algorithm
double[][] allObservations = new double[numDim][];
double[][] allIntervalSizes = new double[numDim][];

//=== Loop over trees to get number of trees with positive variance.
double[] treeTotalVariance = new double[forest.Trees.length];
int numTreesWithPositiveVariance = 0;
for(int numTree=0; numTree<forest.Trees.length; numTree++){
double thisTreeTotalVariance = RandomForestPreprocessor.computeTotalVarianceOfRegressionTree(forest.Trees[numTree], configSpace);
treeTotalVariance[numTree] = thisTreeTotalVariance;
if (thisTreeTotalVariance > 0.0){
numTreesWithPositiveVariance++;
}
}

//=== Loop over trees.
for(int numTree=0; numTree<forest.Trees.length; numTree++){
HashSet<Integer> allVariableIndices = new HashSet<Integer>();
Expand All @@ -75,7 +86,7 @@ public static void decomposeVariance(RandomForest existingForest, List<Algorithm
}

//=== Get the tree's total variance (only works for marginal trees, i.e., in the absence of instance features).
double thisTreeTotalVariance = RandomForestPreprocessor.computeTotalVarianceOfRegressionTree(forest.Trees[numTree], configSpace);
double thisTreeTotalVariance = treeTotalVariance[numTree];
if (thisTreeTotalVariance == 0.0){
s = "Tree " + numTree + " has no variance -> skipping.";
log.info(s);
Expand Down Expand Up @@ -271,7 +282,7 @@ public static void decomposeVariance(RandomForest existingForest, List<Algorithm
log.debug("ThisTreeVarianceContributions of index" + indexSet.toString() + " for Tree" + numTree + " : " + thisTreeVarianceContributions.get(indexSet));
log.debug("ThisTreeTotalVariance for Tree" + numTree + " : " + thisTreeTotalVariance);
tmpThisTreeFractionExplained += thisFractionExplained;
totalFractionsExplained.put(indexSet, previousFractionExplained + 1.0/forest.Trees.length * thisFractionExplained);
totalFractionsExplained.put(indexSet, previousFractionExplained + 1.0/numTreesWithPositiveVariance * thisFractionExplained);
log.debug("TotalFractionExplained for Tree" + numTree + " : " + totalFractionsExplained.get(indexSet));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public class FunctionalANOVAVarianceDecompose {
private Vector <HashMap<Integer,Double>> singleVarianceContributions;
private HashMap<HashSet<Integer>,Double> thisTreeVarianceContributions = new HashMap<HashSet<Integer>,Double>();
private HashMap<HashSet<Integer>,Double> totalFractionsExplained = new HashMap<HashSet<Integer>,Double>();

private int numTreesWithPositiveVariance;

public FunctionalANOVAVarianceDecompose(RandomForest existingForest, List<AlgorithmRunResult> testRuns,
ParameterConfigurationSpace configSpace, Random rand,
boolean compareToDef, double quantileToCompareTo, boolean logModel) throws IOException, InterruptedException
Expand All @@ -42,11 +43,13 @@ public FunctionalANOVAVarianceDecompose(RandomForest existingForest, List<Algori

//=== Initialize variables to be incrementally updated.
String s;



int numDim = configSpace.getCategoricalSize().length;
allObservations = new double[forest.Trees.length][numDim][];
allIntervalSizes = new double[forest.Trees.length][numDim][];
thisTreeTotalVariance = new double[forest.Trees.length];
this.numTreesWithPositiveVariance = 0;
//=== Loop over trees.
for(int numTree=0; numTree<forest.Trees.length; numTree++){
HashSet<Integer> allVariableIndices = new HashSet<Integer>();
Expand All @@ -61,6 +64,9 @@ public FunctionalANOVAVarianceDecompose(RandomForest existingForest, List<Algori
log.info(s);
continue;
}
else{
this.numTreesWithPositiveVariance++;
}
s = "Tree " + numTree + ": Total variance of predictor: " + thisTreeTotalVariance[numTree];
log.info(s);

Expand Down Expand Up @@ -130,6 +136,10 @@ public double getMarginal(int dim)

for(int numTree=0; numTree<forest.Trees.length; numTree++){

if (thisTreeTotalVariance[numTree] == 0.0){
continue;
}

int[] indicesOfObservations = new int[1];
//=== Compute marginal predictions for each instantiation of this categorical parameter.
indicesOfObservations[0] = dim;
Expand Down Expand Up @@ -170,7 +180,7 @@ public double getMarginal(int dim)
previousFractionExplained = totalFractionsExplained.get(set);
}
double thisFractionExplained = thisTreeVarianceContributions.get(set)/thisTreeTotalVariance[numTree]*100;
totalFractionsExplained.put(set, previousFractionExplained + 1.0/forest.Trees.length * thisFractionExplained);
totalFractionsExplained.put(set, previousFractionExplained + 1.0/this.numTreesWithPositiveVariance* thisFractionExplained);
}
return totalFractionsExplained.get(set);
}
Expand All @@ -193,6 +203,10 @@ public double getPairwiseMarginal(int dim1, int dim2)

for(int numTree=0; numTree<forest.Trees.length; numTree++){

if (thisTreeTotalVariance[numTree] == 0.0){
continue;
}

int[] indicesOfObservations = new int[2];
indicesOfObservations[0] = dim1;
ArrayList<Double> as = new ArrayList<Double>();
Expand Down Expand Up @@ -234,7 +248,7 @@ public double getPairwiseMarginal(int dim1, int dim2)
previousFractionExplained = totalFractionsExplained.get(set);
}
double thisFractionExplained = thisTreeVarianceContributions.get(set)/thisTreeTotalVariance[numTree]*100;
totalFractionsExplained.put(set, previousFractionExplained + 1.0/forest.Trees.length * thisFractionExplained);
totalFractionsExplained.put(set, previousFractionExplained + 1.0/this.numTreesWithPositiveVariance * thisFractionExplained);

}
return totalFractionsExplained.get(set);
Expand Down
Binary file modified pyfanova/fanova/fanova.jar
Binary file not shown.
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def check_java_exists():
error_msg = """
Java not found!
pysmac needs java in order to work. You can download java from:
Fanova needs java in order to work. You can download java from:
http://java.com/getjava
"""
raise RuntimeError(error_msg)
Expand All @@ -46,7 +46,15 @@ def check_java_exists():
name = "pyfanova",
version = "1.0",
packages = find_packages(),
install_requires = ['numpy', 'docutils>=0.3', 'setuptools', 'matplotlib>=1.4.2'],
install_requires = [
'numpy',
'docutils>=0.3',
'setuptools',
'matplotlib>=1.4.2',
'ParameterConfigSpace'],

dependency_links=['https://github.com/automl/ParameterConfigSpace/archive/master.zip'],

author = "Tobias Domhan, Aaron Klein (python wrapper). Frank Hutter (FANOVA)",
author_email = "kleinaa@cs.uni-freiburg.de",
description = "Functional ANOVA: an implementation of the ICML 2014 paper 'An Efficient Approach for Assessing Hyperparameter Importance' by Frank Hutter, Holger Hoos and Kevin Leyton-Brown.",
Expand Down

0 comments on commit cd933ee

Please sign in to comment.