Skip to content

Commit

Permalink
Optimize Montgomery multiplication for Ed448
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Jun 8, 2022
1 parent dfb16dd commit fe86785
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/ec_ws.c
Expand Up @@ -948,6 +948,7 @@ EXPORT_SYM int ec_ws_new_context(EcContext **pec_ctx,
}
break;
}
case ModulusEd448:
case ModulusGeneric:
break;
}
Expand Down Expand Up @@ -976,6 +977,7 @@ EXPORT_SYM void ec_free_context(EcContext *ec_ctx)
case ModulusP521:
free_g_p521(ec_ctx->prot_g);
break;
case ModulusEd448:
case ModulusGeneric:
break;
}
Expand Down Expand Up @@ -1344,6 +1346,7 @@ EXPORT_SYM int ec_ws_scalar(EcPoint *ecp, const uint8_t *k, size_t len, uint64_t
}
break;
}
case ModulusEd448:
case ModulusGeneric:
break;
}
Expand Down
129 changes: 128 additions & 1 deletion src/mont.c
Expand Up @@ -151,7 +151,8 @@ STATIC void rsquare(uint64_t *r2_mod_n, uint64_t *n, size_t nw)
* @param nw Number of words making up the 3 integers: out, a, and b.
* It also defines R as 2^(64*nw).
*
* Useful read: https://alicebob.cryptoland.net/understanding-the-montgomery-reduction-algorithm/
* Useful read:
* https://web.archive.org/web/20190917203334/https://alicebob.cryptoland.net/understanding-the-montgomery-reduction-algorithm/
*/
#if SCRATCHPAD_NR < 7
#error Scratchpad is too small
Expand Down Expand Up @@ -592,6 +593,123 @@ STATIC void mont_mult_p521(uint64_t *out, const uint64_t *a, const uint64_t *b,
add_mod(out, t, s, n, tmp1, tmp2, nw);
}

STATIC void mont_mult_ed448(uint64_t *out, const uint64_t *a, const uint64_t *b, const uint64_t *n, uint64_t m0, uint64_t *tmp, size_t nw)
{
size_t i;
uint64_t *t, *scratchpad, *t2;
unsigned cond;

assert(nw == 7);
assert(m0 == 1);

/*
* tmp is an array of SCRATCHPAD*nw words
* We carve out 3 values in it:
* - 3*nw words, the value a*b + m*n (we only use 2*nw+1 words)
* - 3*nw words, temporary area for computing the product
* - nw words, the reduced value with a final subtraction by n
*/
t = tmp;
scratchpad = tmp + 3*nw;
t2 = scratchpad + 3*nw;

if (a == b) {
square(t, scratchpad, a, nw);
} else {
product(t, scratchpad, a, b, nw);
}

t[2*nw] = 0; /** MSW **/

/** Clear lower words **/
for (i=0; i<7; i++) {
uint64_t k, k2_lo, k2_hi;
uint64_t carry, j;
uint64_t prod_lo, prod_hi;

k = t[i];
k2_lo = -k;
k2_hi = k - (k!=0);

/* n[0] = 2⁶⁴ - 1 */
prod_lo = k2_lo;
prod_hi = k2_hi;
t[i+0] += prod_lo;
prod_hi += t[i+0] < prod_lo;
carry = prod_hi;

/* n[1] = 2⁶⁴ - 1 */
prod_lo = k2_lo;
prod_hi = k2_hi;
prod_lo += carry;
prod_hi += prod_lo < carry;
t[i+1] += prod_lo;
prod_hi += t[i+1] < prod_lo;
carry = prod_hi;

/* n[2] = 2⁶⁴ - 1 */
prod_lo = k2_lo;
prod_hi = k2_hi;
prod_lo += carry;
prod_hi += prod_lo < carry;
t[i+2] += prod_lo;
prod_hi += t[i+2] < prod_lo;
carry = prod_hi;

/* n[3] = 2⁶⁴ - 2³² - 1 */
DP_MULT(n[3], k, prod_lo, prod_hi);
prod_lo += carry;
prod_hi += prod_lo < carry;
t[i+3] += prod_lo;
prod_hi += t[i+3] < prod_lo;
carry = prod_hi;

/* n[4] = 2⁶⁴ - 1 */
prod_lo = k2_lo;
prod_hi = k2_hi;
prod_lo += carry;
prod_hi += prod_lo < carry;
t[i+4] += prod_lo;
prod_hi += t[i+4] < prod_lo;
carry = prod_hi;

/* n[5] = 2⁶⁴ - 1 */
prod_lo = k2_lo;
prod_hi = k2_hi;
prod_lo += carry;
prod_hi += prod_lo < carry;
t[i+5] += prod_lo;
prod_hi += t[i+5] < prod_lo;
carry = prod_hi;

/* n[6] = 2⁶⁴ - 1 */
prod_lo = k2_lo;
prod_hi = k2_hi;
prod_lo += carry;
prod_hi += prod_lo < carry;
t[i+6] += prod_lo;
prod_hi += t[i+6] < prod_lo;
carry = prod_hi;

for (j=7; carry; j++) {
t[i+j] += carry;
carry = t[i+j] < carry;
}

assert(j <= (15-i));
}

assert(t[2*nw] <= 1); /** MSW **/

/** t[0..nw-1] == 0 **/

/** Divide by R and possibly subtract n **/
sub(t2, &t[nw], n, nw);
cond = (unsigned)(t[2*nw] | (uint64_t)ge(&t[nw], n, nw));
mod_select(out, t2, &t[nw], cond, (unsigned)nw);
}


/* ---- PUBLIC FUNCTIONS ---- */

void mont_context_free(MontContext *ctx)
Expand Down Expand Up @@ -824,6 +942,9 @@ int mont_mult(uint64_t* out, const uint64_t* a, const uint64_t *b, uint64_t *tmp
case ModulusP521:
mont_mult_p521(out, a, b, ctx->modulus, ctx->m0, tmp, ctx->words);
break;
case ModulusEd448:
mont_mult_ed448(out, a, b, ctx->modulus, ctx->m0, tmp, ctx->words);
break;
case ModulusGeneric:
mont_mult_generic(out, a, b, ctx->modulus, ctx->m0, tmp, ctx->words);
break;
Expand Down Expand Up @@ -1005,6 +1126,7 @@ int mont_context_init(MontContext **out, const uint8_t *modulus, size_t mod_len)
const uint8_t p256_mod[32] = "\xff\xff\xff\xff\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff";
const uint8_t p384_mod[48] = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff";
const uint8_t p521_mod[66] = "\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff";
const uint8_t ed448_mod[56] = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff";
uint64_t *scratchpad = NULL;
MontContext *ctx;
int res;
Expand Down Expand Up @@ -1051,6 +1173,11 @@ int mont_context_init(MontContext **out, const uint8_t *modulus, size_t mod_len)
ctx->modulus_type = ModulusP521;
}
break;
case sizeof(ed448_mod):
if (0 == cmp_modulus(modulus, mod_len, ed448_mod, sizeof(ed448_mod))) {
ctx->modulus_type = ModulusEd448;
}
break;
}

ctx->words = ((unsigned)mod_len + 7) / 8;
Expand Down
2 changes: 1 addition & 1 deletion src/mont.h
Expand Up @@ -8,7 +8,7 @@
*/
#define SCRATCHPAD_NR 7

typedef enum _ModulusType { ModulusGeneric, ModulusP256, ModulusP384, ModulusP521 } ModulusType;
typedef enum _ModulusType { ModulusGeneric, ModulusP256, ModulusP384, ModulusP521, ModulusEd448 } ModulusType;

typedef struct mont_context {
ModulusType modulus_type;
Expand Down

0 comments on commit fe86785

Please sign in to comment.