Merge pull request #506 from cothan/master
Fix type confusion in Falcon-512/1024 in aarch64 implementation
thomwiggers committed Sep 12, 2023
2 parents c086189 + d704640 commit 8e220a8
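
Background on the fix: the ACLE NEON comparison intrinsics return unsigned masks (for example, vcltzq_s16 yields a uint16x8_t), so the old code that fed those masks straight into signed intrinsics such as vandq_s16 mixed vector types. The patch keeps the bitwise masking in the unsigned domain and converts explicitly with vreinterpretq_s16_u16 before any signed arithmetic. A minimal sketch of that pattern follows; the helper name is illustrative and not part of the patch, while FALCON_Q is the Falcon modulus 12289:

#include <arm_neon.h>

#define FALCON_Q 12289 /* Falcon modulus */

/* Illustrative helper: add 2*Q to every negative coefficient in one vector.
 * The comparison produces an unsigned mask; the AND stays unsigned and the
 * result is reinterpreted to signed before the add, instead of passing the
 * unsigned mask directly to a signed intrinsic. */
static inline int16x8_t add_2q_if_negative(int16x8_t a) {
    uint16x8_t neon_2q = vdupq_n_u16(FALCON_Q << 1);
    uint16x8_t mask = vcltzq_s16(a);            /* 0xFFFF where a < 0 */
    uint16x8_t sel = vandq_u16(mask, neon_2q);  /* 2*Q where negative, else 0 */
    return vaddq_s16(a, vreinterpretq_s16_u16(sel));
}

The same reinterpret-before-mixing pattern recurs in every hunk of the diff below.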
Showing 2 changed files with 140 additions and 134 deletions.
137 changes: 70 additions & 67 deletions crypto_sign/falcon-1024/aarch64/poly_int.c
@@ -198,12 +198,13 @@ uint16_t PQCLEAN_FALCON1024_AARCH64_poly_compare_with_zero(int16_t f[FALCON_N])
  * If coefficient is larger than Q, it is subtracted with Q
  */
 void PQCLEAN_FALCON1024_AARCH64_poly_convert_to_unsigned(int16_t f[FALCON_N]) {
-    // Total SIMD registers: 26 = 8 + 16 + 2
+    // Total SIMD registers: 26 = 8 + 16 + 1 + 1
     uint16x8x4_t b0, b1; // 8
     int16x8x4_t a0, a1, c0, c1; // 16
-    uint16x8_t neon_q, neon_2q; // 2
+    int16x8_t neon_q; // 1
+    uint16x8_t neon_2q; // 1
 
-    neon_q = vdupq_n_u16(FALCON_Q);
+    neon_q = vdupq_n_s16(FALCON_Q);
     neon_2q = vdupq_n_u16(FALCON_Q << 1);
 
     for (int i = 0; i < FALCON_N; i += 64) {
@@ -222,15 +223,15 @@ void PQCLEAN_FALCON1024_AARCH64_poly_convert_to_unsigned(int16_t f[FALCON_N]) {
         b1.val[2] = vcltzq_s16(a1.val[2]);
         b1.val[3] = vcltzq_s16(a1.val[3]);
 
-        c0.val[0] = vandq_s16(b0.val[0], neon_2q);
-        c0.val[1] = vandq_s16(b0.val[1], neon_2q);
-        c0.val[2] = vandq_s16(b0.val[2], neon_2q);
-        c0.val[3] = vandq_s16(b0.val[3], neon_2q);
+        c0.val[0] = vreinterpretq_s16_u16(vandq_u16(b0.val[0], neon_2q));
+        c0.val[1] = vreinterpretq_s16_u16(vandq_u16(b0.val[1], neon_2q));
+        c0.val[2] = vreinterpretq_s16_u16(vandq_u16(b0.val[2], neon_2q));
+        c0.val[3] = vreinterpretq_s16_u16(vandq_u16(b0.val[3], neon_2q));
 
-        c1.val[0] = vandq_s16(b1.val[0], neon_2q);
-        c1.val[1] = vandq_s16(b1.val[1], neon_2q);
-        c1.val[2] = vandq_s16(b1.val[2], neon_2q);
-        c1.val[3] = vandq_s16(b1.val[3], neon_2q);
+        c1.val[0] = vreinterpretq_s16_u16(vandq_u16(b1.val[0], neon_2q));
+        c1.val[1] = vreinterpretq_s16_u16(vandq_u16(b1.val[1], neon_2q));
+        c1.val[2] = vreinterpretq_s16_u16(vandq_u16(b1.val[2], neon_2q));
+        c1.val[3] = vreinterpretq_s16_u16(vandq_u16(b1.val[3], neon_2q));
 
         vadd_x4(a0, a0, c0);
         vadd_x4(a1, a1, c1);
@@ -248,15 +249,15 @@ void PQCLEAN_FALCON1024_AARCH64_poly_convert_to_unsigned(int16_t f[FALCON_N]) {
 
         // Conditional subtraction with FALCON_Q
 
-        c0.val[0] = vandq_s16(b0.val[0], neon_q);
-        c0.val[1] = vandq_s16(b0.val[1], neon_q);
-        c0.val[2] = vandq_s16(b0.val[2], neon_q);
-        c0.val[3] = vandq_s16(b0.val[3], neon_q);
+        c0.val[0] = vandq_s16(vreinterpretq_s16_u16(b0.val[0]), neon_q);
+        c0.val[1] = vandq_s16(vreinterpretq_s16_u16(b0.val[1]), neon_q);
+        c0.val[2] = vandq_s16(vreinterpretq_s16_u16(b0.val[2]), neon_q);
+        c0.val[3] = vandq_s16(vreinterpretq_s16_u16(b0.val[3]), neon_q);
 
-        c1.val[0] = vandq_s16(b1.val[0], neon_q);
-        c1.val[1] = vandq_s16(b1.val[1], neon_q);
-        c1.val[2] = vandq_s16(b1.val[2], neon_q);
-        c1.val[3] = vandq_s16(b1.val[3], neon_q);
+        c1.val[0] = vandq_s16(vreinterpretq_s16_u16(b1.val[0]), neon_q);
+        c1.val[1] = vandq_s16(vreinterpretq_s16_u16(b1.val[1]), neon_q);
+        c1.val[2] = vandq_s16(vreinterpretq_s16_u16(b1.val[2]), neon_q);
+        c1.val[3] = vandq_s16(vreinterpretq_s16_u16(b1.val[3]), neon_q);
 
         vsub_x4(a0, a0, c0);
         vsub_x4(a1, a1, c1);
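
For readability, here is a scalar sketch of what PQCLEAN_FALCON1024_AARCH64_poly_convert_to_unsigned computes, assuming input coefficients in (-Q, Q). The function name and loop below are illustrative, not the PQClean reference code, and the exact comparison that drives the conditional subtraction sits outside the visible hunks:

#include <stdint.h>

#define FALCON_N 1024   /* Falcon-1024 */
/* FALCON_Q = 12289, as in the sketch above */

/* Scalar sketch: lift negative coefficients by 2*Q, then conditionally
 * subtract Q, mirroring the masked vadd/vsub in the hunks above. */
static void convert_to_unsigned_ref(int16_t f[FALCON_N]) {
    for (int i = 0; i < FALCON_N; i++) {
        int32_t a = f[i];
        if (a < 0) {
            a += 2 * FALCON_Q;      /* mask & neon_2q, then vadd_x4 */
        }
        if (a >= FALCON_Q) {
            a -= FALCON_Q;          /* mask & neon_q, then vsub_x4 */
        }
        f[i] = (int16_t)a;
    }
}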
@@ -271,17 +272,19 @@ void PQCLEAN_FALCON1024_AARCH64_poly_convert_to_unsigned(int16_t f[FALCON_N]) {
  */
 int PQCLEAN_FALCON1024_AARCH64_poly_int16_to_int8(int8_t G[FALCON_N], const int16_t t[FALCON_N]) {
     // Total SIMD registers: 32
-    int16x8x4_t a, f; // 8
-    uint16x8x4_t c0, c1, d0, d1; // 16
-    uint16x8x2_t e; // 2
-    int8x16x4_t g; // 4
-    int16x8_t neon_127, neon__127, neon_q_2, neon__q_2, neon_q; // 5
+    int16x8x4_t a, f; // 8
+    int16x8x4_t d0, d1; // 8
+    uint16x8x4_t c0, c1, x0, x1; // 16
+    uint16x8x2_t e; // 2
+    int8x16x4_t g; // 4
+    int16x8_t neon_127, neon__127, neon_q_2, neon__q_2; // 4
+    uint16x8_t neon_q; // 1
     neon_127 = vdupq_n_s16(127);
     neon__127 = vdupq_n_s16(-127);
-    neon_q = vdupq_n_s16(FALCON_Q);
     neon_q_2 = vdupq_n_s16(FALCON_Q >> 1);
     neon__q_2 = vdupq_n_s16(-(FALCON_Q >> 1));
 
+    neon_q = vdupq_n_u16(FALCON_Q);
     e.val[1] = vdupq_n_u16(0);
 
     for (int i = 0; i < FALCON_N; i += 64) {
@@ -301,40 +304,40 @@ int PQCLEAN_FALCON1024_AARCH64_poly_int16_to_int8(int8_t G[FALCON_N], const int16_t t[FALCON_N]) {
         c1.val[3] = vcgeq_s16(f.val[3], neon_q_2);
 
         // Perform subtraction with Q
-        c0.val[0] = vandq_s16(c0.val[0], neon_q);
-        c0.val[1] = vandq_s16(c0.val[1], neon_q);
-        c0.val[2] = vandq_s16(c0.val[2], neon_q);
-        c0.val[3] = vandq_s16(c0.val[3], neon_q);
+        d0.val[0] = vreinterpretq_s16_u16(vandq_u16(c0.val[0], neon_q));
+        d0.val[1] = vreinterpretq_s16_u16(vandq_u16(c0.val[1], neon_q));
+        d0.val[2] = vreinterpretq_s16_u16(vandq_u16(c0.val[2], neon_q));
+        d0.val[3] = vreinterpretq_s16_u16(vandq_u16(c0.val[3], neon_q));
 
-        c1.val[0] = vandq_s16(c1.val[0], neon_q);
-        c1.val[1] = vandq_s16(c1.val[1], neon_q);
-        c1.val[2] = vandq_s16(c1.val[2], neon_q);
-        c1.val[3] = vandq_s16(c1.val[3], neon_q);
+        d1.val[0] = vreinterpretq_s16_u16(vandq_u16(c1.val[0], neon_q));
+        d1.val[1] = vreinterpretq_s16_u16(vandq_u16(c1.val[1], neon_q));
+        d1.val[2] = vreinterpretq_s16_u16(vandq_u16(c1.val[2], neon_q));
+        d1.val[3] = vreinterpretq_s16_u16(vandq_u16(c1.val[3], neon_q));
 
-        vsub_x4(a, a, c0);
-        vsub_x4(f, f, c1);
+        vsub_x4(a, a, d0);
+        vsub_x4(f, f, d1);
 
         // -Q/2 > a ? 1: 0
-        d0.val[0] = vcgtq_s16(neon__q_2, a.val[0]);
-        d0.val[1] = vcgtq_s16(neon__q_2, a.val[1]);
-        d0.val[2] = vcgtq_s16(neon__q_2, a.val[2]);
-        d0.val[3] = vcgtq_s16(neon__q_2, a.val[3]);
+        c0.val[0] = vcgtq_s16(neon__q_2, a.val[0]);
+        c0.val[1] = vcgtq_s16(neon__q_2, a.val[1]);
+        c0.val[2] = vcgtq_s16(neon__q_2, a.val[2]);
+        c0.val[3] = vcgtq_s16(neon__q_2, a.val[3]);
 
-        d1.val[0] = vcgtq_s16(neon__q_2, f.val[0]);
-        d1.val[1] = vcgtq_s16(neon__q_2, f.val[1]);
-        d1.val[2] = vcgtq_s16(neon__q_2, f.val[2]);
-        d1.val[3] = vcgtq_s16(neon__q_2, f.val[3]);
+        c1.val[0] = vcgtq_s16(neon__q_2, f.val[0]);
+        c1.val[1] = vcgtq_s16(neon__q_2, f.val[1]);
+        c1.val[2] = vcgtq_s16(neon__q_2, f.val[2]);
+        c1.val[3] = vcgtq_s16(neon__q_2, f.val[3]);
 
         // Perform addition with Q
-        d0.val[0] = vandq_s16(d0.val[0], neon_q);
-        d0.val[1] = vandq_s16(d0.val[1], neon_q);
-        d0.val[2] = vandq_s16(d0.val[2], neon_q);
-        d0.val[3] = vandq_s16(d0.val[3], neon_q);
+        d0.val[0] = vreinterpretq_s16_u16(vandq_u16(c0.val[0], neon_q));
+        d0.val[1] = vreinterpretq_s16_u16(vandq_u16(c0.val[1], neon_q));
+        d0.val[2] = vreinterpretq_s16_u16(vandq_u16(c0.val[2], neon_q));
+        d0.val[3] = vreinterpretq_s16_u16(vandq_u16(c0.val[3], neon_q));
 
-        d1.val[0] = vandq_s16(d1.val[0], neon_q);
-        d1.val[1] = vandq_s16(d1.val[1], neon_q);
-        d1.val[2] = vandq_s16(d1.val[2], neon_q);
-        d1.val[3] = vandq_s16(d1.val[3], neon_q);
+        d1.val[0] = vreinterpretq_s16_u16(vandq_u16(c1.val[0], neon_q));
+        d1.val[1] = vreinterpretq_s16_u16(vandq_u16(c1.val[1], neon_q));
+        d1.val[2] = vreinterpretq_s16_u16(vandq_u16(c1.val[2], neon_q));
+        d1.val[3] = vreinterpretq_s16_u16(vandq_u16(c1.val[3], neon_q));
 
         vadd_x4(a, a, d0);
         vadd_x4(f, f, d1);
@@ -358,30 +361,30 @@ int PQCLEAN_FALCON1024_AARCH64_poly_int16_to_int8(int8_t G[FALCON_N], const int16_t t[FALCON_N]) {
         c1.val[3] = vcgtq_s16(a.val[3], neon_127);
 
         // -127 > f ? 1 : 0
-        d0.val[0] = vcgtq_s16(neon__127, f.val[0]);
-        d0.val[1] = vcgtq_s16(neon__127, f.val[1]);
-        d0.val[2] = vcgtq_s16(neon__127, f.val[2]);
-        d0.val[3] = vcgtq_s16(neon__127, f.val[3]);
+        x0.val[0] = vcgtq_s16(neon__127, f.val[0]);
+        x0.val[1] = vcgtq_s16(neon__127, f.val[1]);
+        x0.val[2] = vcgtq_s16(neon__127, f.val[2]);
+        x0.val[3] = vcgtq_s16(neon__127, f.val[3]);
         // f > 127 ? 1 : 0
-        d1.val[0] = vcgtq_s16(f.val[0], neon_127);
-        d1.val[1] = vcgtq_s16(f.val[1], neon_127);
-        d1.val[2] = vcgtq_s16(f.val[2], neon_127);
-        d1.val[3] = vcgtq_s16(f.val[3], neon_127);
+        x1.val[0] = vcgtq_s16(f.val[0], neon_127);
+        x1.val[1] = vcgtq_s16(f.val[1], neon_127);
+        x1.val[2] = vcgtq_s16(f.val[2], neon_127);
+        x1.val[3] = vcgtq_s16(f.val[3], neon_127);
 
         c0.val[0] = vorrq_u16(c0.val[0], c1.val[0]);
         c0.val[1] = vorrq_u16(c0.val[1], c1.val[1]);
         c0.val[2] = vorrq_u16(c0.val[2], c1.val[2]);
         c0.val[3] = vorrq_u16(c0.val[3], c1.val[3]);
 
-        d0.val[0] = vorrq_u16(d0.val[0], d1.val[0]);
-        d0.val[1] = vorrq_u16(d0.val[1], d1.val[1]);
-        d0.val[2] = vorrq_u16(d0.val[2], d1.val[2]);
-        d0.val[3] = vorrq_u16(d0.val[3], d1.val[3]);
+        x0.val[0] = vorrq_u16(x0.val[0], x1.val[0]);
+        x0.val[1] = vorrq_u16(x0.val[1], x1.val[1]);
+        x0.val[2] = vorrq_u16(x0.val[2], x1.val[2]);
+        x0.val[3] = vorrq_u16(x0.val[3], x1.val[3]);
 
-        c0.val[0] = vorrq_u16(c0.val[0], d0.val[0]);
-        c0.val[2] = vorrq_u16(c0.val[2], d0.val[2]);
-        c0.val[1] = vorrq_u16(c0.val[1], d0.val[1]);
-        c0.val[3] = vorrq_u16(c0.val[3], d0.val[3]);
+        c0.val[0] = vorrq_u16(c0.val[0], x0.val[0]);
+        c0.val[1] = vorrq_u16(c0.val[1], x0.val[1]);
+        c0.val[2] = vorrq_u16(c0.val[2], x0.val[2]);
+        c0.val[3] = vorrq_u16(c0.val[3], x0.val[3]);
 
         c0.val[0] = vorrq_u16(c0.val[0], c0.val[2]);
         c0.val[1] = vorrq_u16(c0.val[1], c0.val[3]);
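
And a scalar sketch of the second routine, PQCLEAN_FALCON1024_AARCH64_poly_int16_to_int8, following the same operation order as the SIMD code. The nonzero-on-failure return convention is an assumption based on the accumulated out-of-range masks, since the final reduction of those masks is outside the visible hunks; the function name is illustrative:

/* Scalar sketch (not the PQClean code), reusing FALCON_Q and FALCON_N from
 * the sketches above: recenter each coefficient around 0 modulo Q, then
 * reject the input if any coefficient falls outside [-127, 127]. */
static int int16_to_int8_ref(int8_t G[FALCON_N], const int16_t t[FALCON_N]) {
    for (int i = 0; i < FALCON_N; i++) {
        int16_t a = t[i];
        if (a >= (FALCON_Q >> 1)) {
            a -= FALCON_Q;          /* vcgeq_s16 mask, masked subtract */
        }
        if (a < -(FALCON_Q >> 1)) {
            a += FALCON_Q;          /* vcgtq_s16 mask, masked add */
        }
        if (a > 127 || a < -127) {
            return 1;               /* coefficient does not fit in int8_t */
        }
        G[i] = (int8_t)a;
    }
    return 0;
}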