diff --git a/nn.c b/nn.c index 2476c06..d9de54a 100644 --- a/nn.c +++ b/nn.c @@ -894,16 +894,24 @@ void AnnTestError(struct Ann *net, float *input, float *desired, int setlen, flo } /* Train the net */ -float AnnTrain(struct Ann *net, float *input, float *desired, float maxerr, int maxepochs, int setlen, int algo) { +float AnnTrainWithAlgoFunc(struct Ann *net, float *input, float *desired, float maxerr, + int maxepochs, int setlen, AnnTrainAlgoFunc algo_func) { int i = 0; float e = maxerr+1; while (i++ < maxepochs && e >= maxerr) { - if (algo == NN_ALGO_BPROP) { - e = AnnResilientBPEpoch(net, input, desired, setlen); - } else if (algo == NN_ALGO_GD) { - e = AnnGDEpoch(net, input, desired, setlen); - } + e = (*algo_func)(net, input, desired, setlen); } return e; } + + +float AnnTrain(struct Ann *net, float *input, float *desired, float maxerr, int maxepochs, + int setlen, int algo) { + AnnTrainAlgoFunc algo_func; + if(algo == NN_ALGO_BPROP) algo_func = AnnResilientBPEpoch; + else if(algo == NN_ALGO_GD) algo_func = AnnGDEpoch; + else return -1; + + return AnnTrainWithAlgoFunc(net, input, desired, maxerr, maxepochs, setlen, algo_func); +} diff --git a/nn.h b/nn.h index 8106daf..db692ec 100644 --- a/nn.h +++ b/nn.h @@ -37,7 +37,6 @@ * an arbitrary number of layers, with arbitrary units for layer. * Only fully connected feed-forward networks are supported. */ struct AnnLayer { - int units; float *output; /* output[i], output of i-th unit */ float *error; /* error[i], output error of i-th unit*/ float *weight; /* weight[(i*units)+j] */ @@ -49,6 +48,7 @@ struct AnnLayer { /* (t-1 sgradient for resilient BP) */ float *delta; /* delta[(i*units)+j] cumulative update */ /* (per-weight delta for RPROP) */ + int units; /*moved to last position for alignment purposes*/ }; /* Feed forward network structure */ @@ -60,9 +60,12 @@ struct Ann { float rprop_maxupdate; float rprop_minupdate; float learn_rate; /* Used for GD training. */ + float _filler_; /*filler for alignment*/ struct AnnLayer *layer; }; +typedef float (*AnnTrainAlgoFunc)(struct Ann *net, float *input, float *desired, int setlen); + /* Raw interface to data structures */ #define OUTPUT(net,l,i) (net)->layer[l].output[i] #define ERROR(net,l,i) (net)->layer[l].error[i] @@ -131,6 +134,7 @@ float AnnBatchGDEpoch(struct Ann *net, float *input, float *desidered, int setle float AnnBatchGDMEpoch(struct Ann *net, float *input, float *desidered, int setlen); void AnnAdjustWeightsResilientBP(struct Ann *net); float AnnResilientBPEpoch(struct Ann *net, float *input, float *desidered, int setlen); +float AnnTrainWithAlgoFunc(struct Ann *net, float *input, float *desidered, float maxerr, int maxepochs, int setlen, AnnTrainAlgoFunc algo_func); float AnnTrain(struct Ann *net, float *input, float *desidered, float maxerr, int maxepochs, int setlen, int algo); void AnnTestError(struct Ann *net, float *input, float *desired, int setlen, float *avgerr, float *classerr); diff --git a/tests/Makefile b/tests/Makefile index 574c80f..0584996 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -1,13 +1,13 @@ all: nn-test-1 nn-test-2 nn-benchmark nn-test-1: nn-test-1.c ../nn.c ../nn.h - $(CC) nn-test-1.c ../nn.c -Wall -W -O2 -o nn-test-1 + $(CC) nn-test-1.c ../nn.c -Wall -W -O2 -o nn-test-1 -lm nn-test-2: nn-test-2.c ../nn.c ../nn.h - $(CC) nn-test-2.c ../nn.c -Wall -W -O2 -o nn-test-2 + $(CC) nn-test-2.c ../nn.c -Wall -W -O2 -o nn-test-2 -lm nn-benchmark: nn-benchmark.c ../nn.c ../nn.h - $(CC) -DUSE_SSE nn-benchmark.c ../nn.c -Wall -W -O3 -o nn-benchmark + $(CC) -DUSE_SSE nn-benchmark.c ../nn.c -Wall -W -O3 -o nn-benchmark -lm clean: rm -f nn-test-1 nn-test-2 nn-benchmark