Merge pull request #15 from mingodad/simd-generalization
Generalize the simd code to manage sse, avx, neon
antirez committed Jun 21, 2018
2 parents 812cbc4 + 2a8b209 commit ec681c0
Showing 2 changed files with 123 additions and 63 deletions.
Makefile (11 changes: 9 additions & 2 deletions)
@@ -14,19 +14,26 @@ endif

all:
@echo ""
@echo "Make avx -- Faster if you have a modern CPU."
@echo "Make neon -- Faster if you have a modern ARM CPU."
@echo "Make sse -- Faster if you have a modern CPU."
@echo "Make avx -- Even faster if you have a modern CPU."
@echo "Make generic -- Works everywhere."
@echo ""
@echo "The avx code uses AVX2, it requires Haswell (Q2 2013) or better."
@echo ""

generic: neuralredis.so

neon:
	make neuralredis.so CFLAGS=-DUSE_NEON

sse:
	make neuralredis.so CFLAGS=-DUSE_SSE SSE="-msse3"

avx:
	make neuralredis.so CFLAGS=-DUSE_AVX AVX="-mavx2 -mfma"

.c.xo:
	$(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) $(AVX) -fPIC -c $< -o $@
	$(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) $(AVX) $(SSE) -fPIC -c $< -o $@

nn.c: nn.h

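The build selects an ISA only when a target is named explicitly. A hypothetical alternative, not part of this commit, would let the preprocessor pick a backend from the compiler's own predefined macros, keeping the manual USE_* flags as overrides:

/* Hypothetical auto-detection sketch. __AVX512F__, __AVX2__, __SSE3__ and
 * __ARM_NEON are predefined by GCC/Clang whenever the matching -m flags
 * (or the target architecture) enable those instruction sets. */
#if !defined(USE_AVX512) && !defined(USE_AVX) && \
    !defined(USE_SSE) && !defined(USE_NEON)
#if defined(__AVX512F__)
#define USE_AVX512
#elif defined(__AVX2__)
#define USE_AVX
#elif defined(__SSE3__)
#define USE_SSE
#elif defined(__ARM_NEON)
#define USE_NEON
#endif
#endif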
nn.c (175 changes: 114 additions & 61 deletions)
@@ -35,10 +35,88 @@
#include <time.h>
#include <string.h>

#ifdef USE_AVX
#if defined(USE_AVX512)
#define USING_SIMD
#include <immintrin.h>

typedef __m512 simdf_t;
#define SIMDF_SIZE 16

#define simdf_zero() _mm512_setzero_ps()
#define simdf_set1f(x) _mm512_set1_ps(x)
#define simdf_loadu(x) _mm512_loadu_ps(x)
#define simdf_mul(a,b) _mm512_mul_ps(a,b)
#define simdf_add(a,b) _mm512_add_ps(a,b)
#define simdf_storeu(a,b) _mm512_storeu_ps(a,b)

// let the compiler optimize this
#define simdf_sum(x) (x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + \
x[8] + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15])

#define simdf_show(x) printf("%d : %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", \
__LINE__, x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], \
x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]);
#endif

#if defined(USE_AVX)
#define USING_SIMD
#include <immintrin.h>

typedef __m256 simdf_t;
#define SIMDF_SIZE 8

#define simdf_zero() _mm256_setzero_ps()
#define simdf_set1f(x) _mm256_set1_ps(x)
#define simdf_loadu(x) _mm256_loadu_ps(x)
#define simdf_mul(a,b) _mm256_mul_ps(a,b)
#define simdf_add(a,b) _mm256_add_ps(a,b)
#define simdf_storeu(a,b) _mm256_storeu_ps(a,b)

// let the compiler optimize this
#define simdf_sum(x) (x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7])

#define simdf_show(x) printf("%d : %f, %f, %f, %f, %f, %f, %f, %f\n", \
__LINE__, x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]);
#endif

#if defined(USE_SSE)
#define USING_SIMD
#include <xmmintrin.h>
#include <pmmintrin.h>
#include <immintrin.h>
typedef __m128 simdf_t;
#define SIMDF_SIZE 4

#define simdf_zero() _mm_setzero_ps()
#define simdf_set1f(x) _mm_set1_ps(x)
#define simdf_loadu(x) _mm_loadu_ps(x)
#define simdf_mul(a,b) _mm_mul_ps(a,b)
#define simdf_add(a,b) _mm_add_ps(a,b)
#define simdf_storeu(a,b) _mm_storeu_ps(a,b)

// let the compiler optimize this
#define simdf_sum(x) (x[0] + x[1] + x[2] + x[3])

#define simdf_show(x) printf("%d : %f, %f, %f, %f\n", __LINE__, x[0], x[1], x[2], x[3]);
#endif

#if defined(USE_NEON)
#define USING_SIMD
#include <arm_neon.h>

typedef float32x4_t simdf_t;
#define SIMDF_SIZE 4

#define simdf_zero() vdupq_n_f32(0.0f)
#define simdf_set1f(x) vdupq_n_f32(x)
#define simdf_loadu(x) vld1q_f32(x)
#define simdf_mul(a,b) vmulq_f32(a,b)
#define simdf_add(a,b) vaddq_f32(a,b)
#define simdf_storeu(a,b) vst1q_f32((float32_t*)a,b)

// let the compiler optimize this
#define simdf_sum(x) (x[0] + x[1] + x[2] + x[3])

#define simdf_show(x) printf("%d : %f, %f, %f, %f\n", __LINE__, x[0], x[1], x[2], x[3]);
#endif

#include "nn.h"
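Taken together, the blocks above reduce every ISA to one vector type, one lane count, and the same handful of operations, so each kernel below is written once. A minimal standalone sketch of the pattern (hypothetical file name and build line; assumes GCC/Clang, whose vector extensions allow the subscripting that the simdf_sum macros rely on):

/* simd_dot.c -- hypothetical standalone demo of the simdf_* pattern.
 * Build: cc -O2 -DUSE_SSE -msse3 simd_dot.c
 * (with no -D flag only the scalar tail loop runs). */
#include <stdio.h>

#if defined(USE_SSE)
#define USING_SIMD
#include <pmmintrin.h>
typedef __m128 simdf_t;
#define SIMDF_SIZE 4
#define simdf_zero() _mm_setzero_ps()
#define simdf_loadu(x) _mm_loadu_ps(x)
#define simdf_mul(a,b) _mm_mul_ps(a,b)
#define simdf_add(a,b) _mm_add_ps(a,b)
#define simdf_sum(x) (x[0] + x[1] + x[2] + x[3])
#endif

float dot(const float *w, const float *o, int units) {
    float A = 0;
    int k = 0;
#ifdef USING_SIMD
    simdf_t sumA = simdf_zero();
    int psteps = units/SIMDF_SIZE;      /* full SIMDF_SIZE-wide steps */
    for (int x = 0; x < psteps; x++) {
        sumA = simdf_add(sumA, simdf_mul(simdf_loadu(w), simdf_loadu(o)));
        w += SIMDF_SIZE;
        o += SIMDF_SIZE;
    }
    A += simdf_sum(sumA);               /* one horizontal reduction */
    k += SIMDF_SIZE*psteps;
#endif
    for (; k < units; k++) A += (*w++) * (*o++);   /* scalar tail */
    return A;
}

int main(void) {
    float w[7] = {1, 1, 1, 1, 1, 1, 1};
    float o[7] = {1, 2, 3, 4, 5, 6, 7};
    printf("%f\n", dot(w, o, 7));       /* prints 28.000000 */
    return 0;
}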
@@ -272,31 +350,6 @@ struct Ann *AnnCreateNet2(int iunits, int ounits) {
return AnnCreateNet(2, units);
}

/* Simulate the net one time. */
#ifdef USE_AVX
/* Provided on Stack Overflow by user Marat Dukhan. */
float avx_horizontal_sum(__m256 x) {
// hiQuad = ( x7, x6, x5, x4 )
const __m128 hiQuad = _mm256_extractf128_ps(x, 1);
// loQuad = ( x3, x2, x1, x0 )
const __m128 loQuad = _mm256_castps256_ps128(x);
// sumQuad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);
// loDual = ( -, -, x1 + x5, x0 + x4 )
const __m128 loDual = sumQuad;
// hiDual = ( -, -, x3 + x7, x2 + x6 )
const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
// sumDual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
const __m128 sumDual = _mm_add_ps(loDual, hiDual);
// lo = ( -, -, -, x0 + x2 + x4 + x6 )
const __m128 lo = sumDual;
// hi = ( -, -, -, x1 + x3 + x5 + x7 )
const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
// sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
const __m128 sum = _mm_add_ss(lo, hi);
return _mm_cvtss_f32(sum);
}
#endif
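This shuffle-based reduction is being deleted because the generalized code reduces with plain subscripting (simdf_sum), which GCC and Clang accept on vector types. A hypothetical standalone check (assumes GCC/Clang with -mavx, plus the function above) that the two approaches agree:

/* check_hsum.c -- hypothetical sanity check that the shuffle reduction
 * matches the subscript-based sum the commit switches to. */
#include <assert.h>
#include <math.h>
#include <immintrin.h>

float avx_horizontal_sum(__m256 x);   /* the helper shown above */

int main(void) {
    /* _mm256_set_ps takes lanes high-to-low, so v[0] == 1.0f here. */
    __m256 v = _mm256_set_ps(8, 7, 6, 5, 4, 3, 2, 1);
    float shuffled = avx_horizontal_sum(v);
    float subscripted = v[0]+v[1]+v[2]+v[3]+v[4]+v[5]+v[6]+v[7];
    assert(fabsf(shuffled - subscripted) < 1e-6f);   /* both are 36 */
    return 0;
}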

void AnnSimulate(struct Ann *net) {
int i, j, k;
@@ -312,24 +365,24 @@ void AnnSimulate(struct Ann *net) {

k = 0;

#ifdef USE_AVX
int psteps = units/8;
#ifdef USING_SIMD
int psteps = units/SIMDF_SIZE;
simdf_t sumA = simdf_zero();
for (int x = 0; x < psteps; x++) {
__m256 weights = _mm256_loadu_ps(w);
__m256 outputs = _mm256_loadu_ps(o);
__m256 prod = _mm256_mul_ps(weights,outputs);
A += avx_horizontal_sum(prod);
w += 8;
o += 8;
simdf_t weights = simdf_loadu(w);
simdf_t outputs = simdf_loadu(o);
simdf_t prod = simdf_mul(weights,outputs);
sumA = simdf_add(sumA, prod);
w += SIMDF_SIZE;
o += SIMDF_SIZE;
}
k += 8*psteps;
A += simdf_sum(sumA);
k += SIMDF_SIZE*psteps;
#endif

/* Handle final piece shorter than 16 bytes. */
/* Handle final piece shorter than SIMDF_SIZE floats. */
for (; k < units; k++) {
float W = *w++;
float O = *o++;
A += W*O;
A += (*w++) * (*o++);
}
OUTPUT(net, i-1, j) = sigmoid(A);
}
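Besides renaming intrinsics, this hunk changes the reduction strategy: the old AVX loop collapsed each product vector to a scalar on every iteration, while the new loop accumulates lane-wise and performs a single horizontal sum after the loop. Schematically:

/* before: one horizontal reduction per iteration */
A += avx_horizontal_sum(prod);
/* after: cheap lane-wise add per iteration ... */
sumA = simdf_add(sumA, prod);
/* ... and one horizontal reduction in total */
A += simdf_sum(sumA);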
@@ -586,39 +639,39 @@ void AnnCalculateGradients(struct Ann *net, float *desired) {

/* 1. Calculate the gradient */
k = 0;
#ifdef USE_AVX
__m256 es = _mm256_set1_ps(error_signal);

int psteps = prevunits/8;
#ifdef USING_SIMD
simdf_t es = simdf_set1f(error_signal);

int psteps = prevunits/SIMDF_SIZE;
for (int x = 0; x < psteps; x++) {
__m256 outputs = _mm256_loadu_ps(o);
__m256 gradients = _mm256_mul_ps(es,outputs);
_mm256_storeu_ps(g,gradients);
o += 8;
g += 8;
simdf_t outputs = simdf_loadu(o);
simdf_storeu(g, simdf_mul(es, outputs));
o += SIMDF_SIZE;
g += SIMDF_SIZE;
}
k += 8*psteps;
k += SIMDF_SIZE*psteps;
#endif
/* Handle final piece shorter than SIMDF_SIZE floats. */
for (; k < prevunits; k++) *g++ = error_signal*(*o++);
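/* Note: step 1 is an element-wise scale. For every weight feeding this
 * unit, gradient_k = error_signal * output_k, which is why one broadcast
 * (simdf_set1f) plus a vector multiply covers SIMDF_SIZE weights per
 * iteration. */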

/* 2. And back-propagate the error to the previous layer */
k = 0;
#ifdef USE_AVX
psteps = prevunits/8;
#ifdef USING_SIMD
for (int x = 0; x < psteps; x++) {
__m256 weights = _mm256_loadu_ps(w);
__m256 errors = _mm256_loadu_ps(e);
__m256 prod = _mm256_fmadd_ps(es,weights,errors);
_mm256_storeu_ps(e,prod);
e += 8;
w += 8;
simdf_t weights = simdf_loadu(w);
simdf_t errors = simdf_loadu(e);
simdf_storeu(e, simdf_add(simdf_mul(es, weights), errors));
e += SIMDF_SIZE;
w += SIMDF_SIZE;
}
k += 8*psteps;
k += SIMDF_SIZE*psteps;
#endif
/* Handle final piece shorter than SIMDF_SIZE floats. */
for (; k < prevunits; k++) {
*e += error_signal * (*w);
e++;
w++;
(*e++) += error_signal * (*w++);
}
}
}
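One thing the generalization gives up: the old AVX path fused the multiply and add with _mm256_fmadd_ps, while simdf_mul followed by simdf_add is two separate operations (the compiler may still re-fuse them under -mfma or -ffp-contract). A hypothetical per-ISA wrapper, not part of this commit, could restore explicit fusion:

/* Hypothetical simdf_fmadd(a,b,c) == a*b + c; not in this commit. */
#if defined(USE_AVX)
#define simdf_fmadd(a,b,c) _mm256_fmadd_ps(a,b,c)       /* needs -mfma */
#elif defined(USE_NEON)
#define simdf_fmadd(a,b,c) vfmaq_f32(c,a,b)  /* NEON: accumulator first; ARMv8/VFPv4 */
#else
#define simdf_fmadd(a,b,c) simdf_add(simdf_mul(a,b),c)  /* generic fallback */
#endif

/* The back-propagation store above would then read:
 *     simdf_storeu(e, simdf_fmadd(es, weights, errors)); */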
