Skip to content

Commit

Permalink
Added distributions::InvCDF for weighted sampling of features to make…
Browse files Browse the repository at this point in the history
… the split. Unit tests added as well.
  • Loading branch information
timo.erkkila@gmail.com committed Oct 22, 2012
1 parent 91953f3 commit 6ff5b96
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
COMPILER = g++
COMPILER = g++44
CFLAGS = -O2 -std=c++0x -Wall -Wextra -pedantic -Isrc/
TFLAGS = -pthread
SOURCEFILES = src/progress.cpp src/statistics.cpp src/math.cpp src/stochasticforest.cpp src/rootnode.cpp src/node.cpp src/treedata.cpp src/datadefs.cpp src/utils.cpp src/distributions.cpp
Expand Down
31 changes: 31 additions & 0 deletions src/distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,34 @@ num_t distributions::RandInt::uniform() {

}

distributions::InvCDF::InvCDF(const vector<num_t>& weights) {

num_t sum = 0.0;

for ( size_t i = 0; i < weights.size(); ++i ) {
assert( weights[i] >= 0.0 );
sum += weights[i];
}

num_t cumProb = 0.0;

assert( sum > 0.0 );

for ( size_t i = 0; i < weights.size(); ++i ) {
if ( weights[i] > 0.0 ) {
cumProb += weights[i] / sum;
icdf_[cumProb] = i;
}
}

}

distributions::InvCDF::~InvCDF() {

}

size_t distributions::InvCDF::at(const num_t prob) {
assert(prob >= 0.0 && prob < 1.0);

return( icdf_.upper_bound(prob)->second );
}
11 changes: 9 additions & 2 deletions src/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,15 @@ namespace distributions {
tr1::uniform_int<size_t> rand_;

};

typedef map<datadefs::num_t,size_t> icdf_t;

class InvCDF {
public:
InvCDF(const vector<datadefs::num_t>& weights);
~InvCDF();
size_t at(const datadefs::num_t prob);
private:
map<datadefs::num_t,size_t> icdf_;
};

//icdf_t

Expand Down
54 changes: 53 additions & 1 deletion test/distributions_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class DistributionsTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE( DistributionsTest );
CPPUNIT_TEST( test_randint );
CPPUNIT_TEST( test_uniform );
CPPUNIT_TEST( test_invcdf1 );
CPPUNIT_TEST( test_invcdf2 );
CPPUNIT_TEST( test_invcdf3 );
CPPUNIT_TEST_SUITE_END();

public:
Expand All @@ -17,6 +20,9 @@ class DistributionsTest : public CppUnit::TestFixture {

void test_randint();
void test_uniform();
void test_invcdf1();
void test_invcdf2();
void test_invcdf3();

private:

Expand Down Expand Up @@ -84,7 +90,7 @@ void DistributionsTest::test_uniform() {

for ( size_t i = 0; i < 100000; ++i ) {
datadefs::num_t r = randInt_.uniform();
CPPUNIT_ASSERT( 0.0 <= r && r <= 1.0 );
CPPUNIT_ASSERT( 0.0 <= r && r < 1.0 );

//cout << " " << r;

Expand All @@ -101,7 +107,53 @@ void DistributionsTest::test_uniform() {

}

void DistributionsTest::test_invcdf1() {

distributions::InvCDF icdf({0.2,0.2,0.2,0.2,0.2});

CPPUNIT_ASSERT( icdf.at(0.00) == 0 );
CPPUNIT_ASSERT( icdf.at(0.05) == 0 );
CPPUNIT_ASSERT( icdf.at(0.15) == 0 );
CPPUNIT_ASSERT( icdf.at(0.25) == 1 );
CPPUNIT_ASSERT( icdf.at(0.35) == 1 );
CPPUNIT_ASSERT( icdf.at(0.45) == 2 );
CPPUNIT_ASSERT( icdf.at(0.55) == 2 );
CPPUNIT_ASSERT( icdf.at(0.65) == 3 );
CPPUNIT_ASSERT( icdf.at(0.75) == 3 );
CPPUNIT_ASSERT( icdf.at(0.85) == 4 );
CPPUNIT_ASSERT( icdf.at(0.95) == 4 );
CPPUNIT_ASSERT( icdf.at(0.999999999999) == 4 );

}

void DistributionsTest::test_invcdf2() {

distributions::InvCDF icdf({1,2,3,0,4});

CPPUNIT_ASSERT( icdf.at(0.00) == 0 );
CPPUNIT_ASSERT( icdf.at(0.05) == 0 );
CPPUNIT_ASSERT( icdf.at(0.15) == 1 );
CPPUNIT_ASSERT( icdf.at(0.25) == 1 );
CPPUNIT_ASSERT( icdf.at(0.35) == 2 );
CPPUNIT_ASSERT( icdf.at(0.45) == 2 );
CPPUNIT_ASSERT( icdf.at(0.55) == 2 );
CPPUNIT_ASSERT( icdf.at(0.65) == 4 );
CPPUNIT_ASSERT( icdf.at(0.75) == 4 );
CPPUNIT_ASSERT( icdf.at(0.85) == 4 );
CPPUNIT_ASSERT( icdf.at(0.95) == 4 );
CPPUNIT_ASSERT( icdf.at(0.99999999999) == 4 );

}

void DistributionsTest::test_invcdf3() {

distributions::InvCDF icdf({0,0,1,0,0});

for ( size_t i = 0; i < 100; ++i ) {
CPPUNIT_ASSERT( icdf.at(static_cast<datadefs::num_t>(0.01*i)) == 2 );
}

}


// Registers the fixture into the test 'registry'
Expand Down
4 changes: 2 additions & 2 deletions test/stochasticforest_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void StochasticForestTest::test_treeDataPercolation() {
Treedata::Treedata testData("testdata.tsv",&parameters_);

// Next load a test predictor
StochasticForest SF(&parameters_);
StochasticForest SF(parameters_);

vector<num_t> prediction,confidence;

Expand Down Expand Up @@ -109,7 +109,7 @@ void StochasticForestTest::test_CART() {

parameters_.nMaxLeaves = 2;

StochasticForest CART(&treeData,&parameters_);
StochasticForest CART(&treeData,parameters_);

CPPUNIT_ASSERT( CART.rootNodes_[0]->splitter_.name == "N:input" );

Expand Down
2 changes: 1 addition & 1 deletion test/treedata_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ void treeDataTest::test_fullSplitterSweep() {
//size_t targetIdx = treeData->getFeatureIdx("N:output");
size_t minSamples = 3;

vector<string> sweepFiles = {"test_fullSplitterSweep.txt","test_fullSplitterSweep_C:class.txt"};
vector<string> sweepFiles = {"test_fullSplitterSweep.txt","test_fullSplitterSweep_class.txt"};

for ( size_t f = 0; f < sweepFiles.size(); ++f ) {

Expand Down
102 changes: 102 additions & 0 deletions test_fullSplitterSweep_class.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
C:class N:output 8.277247e-02 132 146
C:class N:input 1.080478e-02 273 4
C:class N:noise_1 1.186730e-02 145 131
C:class N:noise_2 9.047567e-03 149 131
C:class N:noise_3 5.632269e-03 37 240
C:class N:noise_4 2.012800e-02 101 176
C:class N:noise_5 4.715754e-03 144 132
C:class C:noise_6 1.915083e-03 93 178
C:class N:noise_7 8.512985e-03 263 19
C:class N:noise_8 4.338452e-03 17 260
C:class N:noise_9 7.568886e-03 157 128
C:class N:noise_10 5.742623e-03 212 67
C:class N:noise_11 8.089438e-03 267 7
C:class N:noise_12 6.164748e-03 162 107
C:class N:noise_13 5.167658e-03 144 132
C:class N:noise_14 7.158856e-03 265 13
C:class N:noise_15 8.629035e-03 31 245
C:class N:noise_16 1.294779e-02 270 6
C:class N:noise_17 3.829856e-03 43 233
C:class C:noise_18 1.939052e-02 89 183
C:class N:noise_19 7.300695e-03 13 262
C:class N:noise_20 1.300216e-02 88 188
C:class N:noise_21 7.683229e-03 227 43
C:class N:noise_22 1.820842e-02 8 268
C:class C:noise_23 4.497780e-03 77 195
C:class N:noise_24 1.553144e-02 31 240
C:class N:noise_25 7.985333e-03 28 244
C:class N:noise_26 1.021783e-02 213 64
C:class N:noise_27 5.208401e-03 32 239
C:class N:noise_28 8.072889e-03 250 28
C:class N:noise_29 7.635241e-03 10 265
C:class N:noise_30 1.439265e-02 225 50
C:class N:noise_31 9.134501e-03 269 10
C:class N:noise_32 1.537890e-02 185 89
C:class N:noise_33 7.579281e-03 3 278
C:class N:noise_34 1.615210e-02 41 231
C:class N:noise_35 9.442723e-03 52 227
C:class N:noise_36 9.529369e-03 205 70
C:class N:noise_37 8.874968e-03 182 90
C:class N:noise_38 6.306706e-03 83 191
C:class N:noise_39 6.690488e-03 13 257
C:class N:noise_40 6.732478e-03 156 123
C:class N:noise_41 6.800112e-03 258 16
C:class N:noise_42 1.330721e-02 5 271
C:class N:noise_43 6.459754e-03 251 22
C:class N:noise_44 1.463259e-02 86 190
C:class N:noise_45 1.074883e-02 57 221
C:class N:noise_46 4.948310e-03 249 27
C:class N:noise_47 1.185683e-02 74 201
C:class N:noise_48 5.228776e-03 203 68
C:class N:noise_49 8.015075e-03 270 3
C:class N:noise_50 3.827547e-03 245 26
C:class N:noise_51 5.623361e-03 8 260
C:class N:noise_52 1.132672e-02 264 11
C:class N:noise_53 1.474747e-02 198 72
C:class N:noise_54 1.464262e-02 184 93
C:class N:noise_55 1.061787e-02 57 218
C:class N:noise_56 9.102942e-03 275 7
C:class N:noise_57 8.660064e-03 4 270
C:class N:noise_58 1.468407e-02 7 273
C:class N:noise_59 4.682519e-03 254 14
C:class N:noise_60 8.742796e-03 87 189
C:class N:noise_61 6.415251e-03 276 3
C:class C:noise_62 3.877685e-03 91 182
C:class N:noise_63 9.286570e-03 252 26
C:class N:noise_64 7.245304e-03 6 272
C:class N:noise_65 8.578843e-03 274 4
C:class C:noise_66 3.314331e-03 100 172
C:class N:noise_67 8.746950e-03 268 4
C:class N:noise_68 8.593788e-03 63 215
C:class C:noise_69 4.556471e-03 83 198
C:class N:noise_70 1.385411e-02 269 5
C:class N:noise_71 7.371832e-03 45 234
C:class N:noise_72 5.581014e-03 266 7
C:class N:noise_73 5.526464e-03 222 50
C:class C:noise_74 1.235055e-03 86 186
C:class C:noise_75 2.357321e-03 81 190
C:class N:noise_76 5.492676e-03 231 42
C:class N:noise_77 1.658936e-02 207 66
C:class N:noise_78 9.358209e-03 38 239
C:class N:noise_79 5.969206e-03 185 95
C:class N:noise_80 1.487118e-02 17 260
C:class N:noise_81 5.329469e-03 268 5
C:class N:noise_82 7.679314e-03 246 26
C:class N:noise_83 9.127313e-03 7 274
C:class N:noise_84 4.509829e-03 137 143
C:class N:noise_85 4.530825e-03 110 169
C:class C:noise_86 1.859930e-03 78 199
C:class N:noise_87 1.187371e-02 30 244
C:class N:noise_88 5.670918e-03 200 80
C:class N:noise_89 7.086722e-03 21 256
C:class N:noise_90 1.036187e-02 207 59
C:class N:noise_91 1.046930e-02 4 274
C:class N:noise_92 1.057821e-02 239 38
C:class N:noise_93 7.872152e-03 45 233
C:class N:noise_94 8.085322e-03 23 250
C:class C:noise_95 1.761905e-03 100 180
C:class N:noise_96 9.914971e-03 144 131
C:class N:noise_97 1.001201e-02 17 261
C:class N:noise_98 7.998670e-03 274 3
C:class N:noise_99 6.701586e-03 221 60
C:class N:noise_100 1.181143e-02 191 82

0 comments on commit 6ff5b96

Please sign in to comment.