Skip to content

Commit

Permalink
Merge pull request #16 from mingodad/alignment
Browse files Browse the repository at this point in the history
Small changes for alignment and function split
  • Loading branch information
antirez committed Jun 22, 2018
2 parents ec681c0 + 9ee4d19 commit fef3d1b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
20 changes: 14 additions & 6 deletions nn.c
Expand Up @@ -894,16 +894,24 @@ void AnnTestError(struct Ann *net, float *input, float *desired, int setlen, flo
}

/* Train the net */
float AnnTrain(struct Ann *net, float *input, float *desired, float maxerr, int maxepochs, int setlen, int algo) {
float AnnTrainWithAlgoFunc(struct Ann *net, float *input, float *desired, float maxerr,
int maxepochs, int setlen, AnnTrainAlgoFunc algo_func) {
int i = 0;
float e = maxerr+1;

while (i++ < maxepochs && e >= maxerr) {
if (algo == NN_ALGO_BPROP) {
e = AnnResilientBPEpoch(net, input, desired, setlen);
} else if (algo == NN_ALGO_GD) {
e = AnnGDEpoch(net, input, desired, setlen);
}
e = (*algo_func)(net, input, desired, setlen);
}
return e;
}


float AnnTrain(struct Ann *net, float *input, float *desired, float maxerr, int maxepochs,
int setlen, int algo) {
AnnTrainAlgoFunc algo_func;
if(algo == NN_ALGO_BPROP) algo_func = AnnResilientBPEpoch;
else if(algo == NN_ALGO_GD) algo_func = AnnGDEpoch;
else return -1;

return AnnTrainWithAlgoFunc(net, input, desired, maxerr, maxepochs, setlen, algo_func);
}
6 changes: 5 additions & 1 deletion nn.h
Expand Up @@ -37,7 +37,6 @@
* an arbitrary number of layers, with arbitrary units for layer.
* Only fully connected feed-forward networks are supported. */
struct AnnLayer {
int units;
float *output; /* output[i], output of i-th unit */
float *error; /* error[i], output error of i-th unit*/
float *weight; /* weight[(i*units)+j] */
Expand All @@ -49,6 +48,7 @@ struct AnnLayer {
/* (t-1 sgradient for resilient BP) */
float *delta; /* delta[(i*units)+j] cumulative update */
/* (per-weight delta for RPROP) */
int units; /*moved to last position for alignment purposes*/
};

/* Feed forward network structure */
Expand All @@ -60,9 +60,12 @@ struct Ann {
float rprop_maxupdate;
float rprop_minupdate;
float learn_rate; /* Used for GD training. */
float _filler_; /*filler for alignment*/
struct AnnLayer *layer;
};

typedef float (*AnnTrainAlgoFunc)(struct Ann *net, float *input, float *desired, int setlen);

/* Raw interface to data structures */
#define OUTPUT(net,l,i) (net)->layer[l].output[i]
#define ERROR(net,l,i) (net)->layer[l].error[i]
Expand Down Expand Up @@ -131,6 +134,7 @@ float AnnBatchGDEpoch(struct Ann *net, float *input, float *desidered, int setle
float AnnBatchGDMEpoch(struct Ann *net, float *input, float *desidered, int setlen);
void AnnAdjustWeightsResilientBP(struct Ann *net);
float AnnResilientBPEpoch(struct Ann *net, float *input, float *desidered, int setlen);
float AnnTrainWithAlgoFunc(struct Ann *net, float *input, float *desidered, float maxerr, int maxepochs, int setlen, AnnTrainAlgoFunc algo_func);
float AnnTrain(struct Ann *net, float *input, float *desidered, float maxerr, int maxepochs, int setlen, int algo);
void AnnTestError(struct Ann *net, float *input, float *desired, int setlen, float *avgerr, float *classerr);

Expand Down
6 changes: 3 additions & 3 deletions tests/Makefile
@@ -1,13 +1,13 @@
all: nn-test-1 nn-test-2 nn-benchmark

nn-test-1: nn-test-1.c ../nn.c ../nn.h
$(CC) nn-test-1.c ../nn.c -Wall -W -O2 -o nn-test-1
$(CC) nn-test-1.c ../nn.c -Wall -W -O2 -o nn-test-1 -lm

nn-test-2: nn-test-2.c ../nn.c ../nn.h
$(CC) nn-test-2.c ../nn.c -Wall -W -O2 -o nn-test-2
$(CC) nn-test-2.c ../nn.c -Wall -W -O2 -o nn-test-2 -lm

nn-benchmark: nn-benchmark.c ../nn.c ../nn.h
$(CC) -DUSE_SSE nn-benchmark.c ../nn.c -Wall -W -O3 -o nn-benchmark
$(CC) -DUSE_SSE nn-benchmark.c ../nn.c -Wall -W -O3 -o nn-benchmark -lm

clean:
rm -f nn-test-1 nn-test-2 nn-benchmark

0 comments on commit fef3d1b

Please sign in to comment.