Skip to content

Commit

Permalink
Fix pow() optimization inconsistencies.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Pall committed Jan 24, 2022
1 parent c18acfe commit 9512d5c
Show file tree
Hide file tree
Showing 15 changed files with 104 additions and 195 deletions.
7 changes: 2 additions & 5 deletions src/lj_asm.c
Expand Up @@ -1670,7 +1670,6 @@ static void asm_loop(ASMState *as)
#if !LJ_SOFTFP32
#if !LJ_TARGET_X86ORX64
#define asm_ldexp(as, ir) asm_callid(as, ir, IRCALL_ldexp)
#define asm_fppowi(as, ir) asm_callid(as, ir, IRCALL_lj_vm_powi)
#endif

static void asm_pow(ASMState *as, IRIns *ir)
Expand All @@ -1681,10 +1680,8 @@ static void asm_pow(ASMState *as, IRIns *ir)
IRCALL_lj_carith_powu64);
else
#endif
if (irt_isnum(IR(ir->op2)->t))
asm_callid(as, ir, IRCALL_pow);
else
asm_fppowi(as, ir);
asm_callid(as, ir, irt_isnum(IR(ir->op2)->t) ? IRCALL_lj_vm_pow :
IRCALL_lj_vm_powi);
}

static void asm_div(ASMState *as, IRIns *ir)
Expand Down
13 changes: 0 additions & 13 deletions src/lj_asm_x86.h
Expand Up @@ -2017,19 +2017,6 @@ static void asm_ldexp(ASMState *as, IRIns *ir)
asm_x87load(as, ir->op2);
}

static void asm_fppowi(ASMState *as, IRIns *ir)
{
/* The modified regs must match with the *.dasc implementation. */
RegSet drop = RSET_RANGE(RID_XMM0, RID_XMM1+1)|RID2RSET(RID_EAX);
if (ra_hasreg(ir->r))
rset_clear(drop, ir->r); /* Dest reg handled below. */
ra_evictset(as, drop);
ra_destreg(as, ir, RID_XMM0);
emit_call(as, lj_vm_powi_sse);
ra_left(as, RID_XMM0, ir->op1);
ra_left(as, RID_EAX, ir->op2);
}

static int asm_swapops(ASMState *as, IRIns *ir)
{
IRIns *irl = IR(ir->op1);
Expand Down
2 changes: 1 addition & 1 deletion src/lj_dispatch.h
Expand Up @@ -44,7 +44,7 @@ extern double __divdf3(double a, double b);
#define GOTDEF(_) \
_(floor) _(ceil) _(trunc) _(log) _(log10) _(exp) _(sin) _(cos) _(tan) \
_(asin) _(acos) _(atan) _(sinh) _(cosh) _(tanh) _(frexp) _(modf) _(atan2) \
_(pow) _(fmod) _(ldexp) _(lj_vm_modi) \
_(lj_vm_pow) _(fmod) _(ldexp) _(lj_vm_modi) \
_(lj_dispatch_call) _(lj_dispatch_ins) _(lj_dispatch_stitch) \
_(lj_dispatch_profile) _(lj_err_throw) \
_(lj_ffh_coroutine_wrap_err) _(lj_func_closeuv) _(lj_func_newL_gc) \
Expand Down
2 changes: 1 addition & 1 deletion src/lj_ircall.h
Expand Up @@ -218,7 +218,7 @@ typedef struct CCallInfo {
_(ANY, log, 1, N, NUM, XA_FP) \
_(ANY, lj_vm_log2, 1, N, NUM, XA_FP) \
_(ANY, lj_vm_powi, 2, N, NUM, XA_FP) \
_(ANY, pow, 2, N, NUM, XA2_FP) \
_(ANY, lj_vm_pow, 2, N, NUM, XA2_FP) \
_(ANY, atan2, 2, N, NUM, XA2_FP) \
_(ANY, ldexp, 2, N, NUM, XA_FP) \
_(SOFTFP, lj_vm_tobit, 1, N, INT, XA_FP32) \
Expand Down
27 changes: 0 additions & 27 deletions src/lj_opt_fold.c
Expand Up @@ -1143,33 +1143,6 @@ LJFOLDF(simplify_numpow_xkint)
return ref;
}

LJFOLD(POW any KNUM)
LJFOLDF(simplify_numpow_xknum)
{
if (knumright == 0.5) /* x ^ 0.5 ==> sqrt(x) */
return emitir(IRTN(IR_FPMATH), fins->op1, IRFPM_SQRT);
return NEXTFOLD;
}

LJFOLD(POW KNUM any)
LJFOLDF(simplify_numpow_kx)
{
lua_Number n = knumleft;
if (n == 2.0 && irt_isint(fright->t)) { /* 2.0 ^ i ==> ldexp(1.0, i) */
#if LJ_TARGET_X86ORX64
/* Different IR_LDEXP calling convention on x86/x64 requires conversion. */
fins->o = IR_CONV;
fins->op1 = fins->op2;
fins->op2 = IRCONV_NUM_INT;
fins->op2 = (IRRef1)lj_opt_fold(J);
#endif
fins->op1 = (IRRef1)lj_ir_knum_one(J);
fins->o = IR_LDEXP;
return RETRYFOLD;
}
return NEXTFOLD;
}

/* -- Simplify conversions ------------------------------------------------ */

LJFOLD(CONV CONV IRCONV_NUM_INT) /* _NUM */
Expand Down
12 changes: 3 additions & 9 deletions src/lj_opt_narrow.c
Expand Up @@ -590,20 +590,14 @@ TRef lj_opt_narrow_pow(jit_State *J, TRef rb, TRef rc, TValue *vb, TValue *vc)
rb = conv_str_tonum(J, rb, vb);
rb = lj_ir_tonum(J, rb); /* Left arg is always treated as an FP number. */
rc = conv_str_tonum(J, rc, vc);
/* Narrowing must be unconditional to preserve (-x)^i semantics. */
if (tvisint(vc) || numisint(numV(vc))) {
int checkrange = 0;
/* pow() is faster for bigger exponents. But do this only for (+k)^i. */
if (tref_isk(rb) && (int32_t)ir_knum(IR(tref_ref(rb)))->u32.hi >= 0) {
int32_t k = numberVint(vc);
if (!(k >= -65536 && k <= 65536)) goto force_pow_num;
checkrange = 1;
}
int32_t k = numberVint(vc);
if (!(k >= -65536 && k <= 65536)) goto force_pow_num;
if (!tref_isinteger(rc)) {
/* Guarded conversion to integer! */
rc = emitir(IRTGI(IR_CONV), rc, IRCONV_INT_NUM|IRCONV_CHECK);
}
if (checkrange && !tref_isk(rc)) { /* Range guard: -65536 <= i <= 65536 */
if (!tref_isk(rc)) { /* Range guard: -65536 <= i <= 65536 */
TRef tmp = emitir(IRTI(IR_ADD), rc, lj_ir_kint(J, 65536));
emitir(IRTGI(IR_ULE), tmp, lj_ir_kint(J, 2*65536));
}
Expand Down
7 changes: 3 additions & 4 deletions src/lj_vm.h
Expand Up @@ -83,10 +83,6 @@ LJ_ASMF int32_t LJ_FASTCALL lj_vm_modi(int32_t, int32_t);
LJ_ASMF void lj_vm_floor_sse(void);
LJ_ASMF void lj_vm_ceil_sse(void);
LJ_ASMF void lj_vm_trunc_sse(void);
LJ_ASMF void lj_vm_powi_sse(void);
#define lj_vm_powi NULL
#else
LJ_ASMF double lj_vm_powi(double, int32_t);
#endif
#if LJ_TARGET_PPC || LJ_TARGET_ARM64
#define lj_vm_trunc trunc
Expand All @@ -102,6 +98,9 @@ LJ_ASMF int lj_vm_errno(void);
LJ_ASMF TValue *lj_vm_next(GCtab *t, uint32_t idx);
#endif

LJ_ASMF double lj_vm_powi(double, int32_t);
LJ_ASMF double lj_vm_pow(double, double);

/* Continuations for metamethods. */
LJ_ASMF void lj_cont_cat(void); /* Continue with concatenation. */
LJ_ASMF void lj_cont_ra(void); /* Store result in RA from instruction. */
Expand Down
82 changes: 45 additions & 37 deletions src/lj_vmmath.c
Expand Up @@ -30,11 +30,51 @@ LJ_FUNCA double lj_wrap_sinh(double x) { return sinh(x); }
LJ_FUNCA double lj_wrap_cosh(double x) { return cosh(x); }
LJ_FUNCA double lj_wrap_tanh(double x) { return tanh(x); }
LJ_FUNCA double lj_wrap_atan2(double x, double y) { return atan2(x, y); }
LJ_FUNCA double lj_wrap_pow(double x, double y) { return pow(x, y); }
LJ_FUNCA double lj_wrap_fmod(double x, double y) { return fmod(x, y); }
#endif

/* -- Helper functions for generated machine code ------------------------- */
/* -- Helper functions ---------------------------------------------------- */

/* Unsigned x^k. */
static double lj_vm_powui(double x, uint32_t k)
{
double y;
lj_assertX(k != 0, "pow with zero exponent");
for (; (k & 1) == 0; k >>= 1) x *= x;
y = x;
if ((k >>= 1) != 0) {
for (;;) {
x *= x;
if (k == 1) break;
if (k & 1) y *= x;
k >>= 1;
}
y *= x;
}
return y;
}

/* Signed x^k. */
double lj_vm_powi(double x, int32_t k)
{
if (k > 1)
return lj_vm_powui(x, (uint32_t)k);
else if (k == 1)
return x;
else if (k == 0)
return 1.0;
else
return 1.0 / lj_vm_powui(x, (uint32_t)-k);
}

double lj_vm_pow(double x, double y)
{
int32_t k = lj_num2int(y);
if ((k >= -65536 && k <= 65536) && y == (double)k)
return lj_vm_powi(x, k);
else
return pow(x, y);
}

double lj_vm_foldarith(double x, double y, int op)
{
Expand All @@ -44,7 +84,7 @@ double lj_vm_foldarith(double x, double y, int op)
case IR_MUL - IR_ADD: return x*y; break;
case IR_DIV - IR_ADD: return x/y; break;
case IR_MOD - IR_ADD: return x-lj_vm_floor(x/y)*y; break;
case IR_POW - IR_ADD: return pow(x, y); break;
case IR_POW - IR_ADD: return lj_vm_pow(x, y); break;
case IR_NEG - IR_ADD: return -x; break;
case IR_ABS - IR_ADD: return fabs(x); break;
#if LJ_HASJIT
Expand All @@ -56,6 +96,8 @@ double lj_vm_foldarith(double x, double y, int op)
}
}

/* -- Helper functions for generated machine code ------------------------- */

#if (LJ_HASJIT && !(LJ_TARGET_ARM || LJ_TARGET_ARM64 || LJ_TARGET_PPC)) || LJ_TARGET_MIPS
int32_t LJ_FASTCALL lj_vm_modi(int32_t a, int32_t b)
{
Expand All @@ -80,40 +122,6 @@ double lj_vm_log2(double a)
}
#endif

#if !LJ_TARGET_X86ORX64
/* Unsigned x^k. */
static double lj_vm_powui(double x, uint32_t k)
{
double y;
lj_assertX(k != 0, "pow with zero exponent");
for (; (k & 1) == 0; k >>= 1) x *= x;
y = x;
if ((k >>= 1) != 0) {
for (;;) {
x *= x;
if (k == 1) break;
if (k & 1) y *= x;
k >>= 1;
}
y *= x;
}
return y;
}

/* Signed x^k. */
double lj_vm_powi(double x, int32_t k)
{
if (k > 1)
return lj_vm_powui(x, (uint32_t)k);
else if (k == 1)
return x;
else if (k == 0)
return 1.0;
else
return 1.0 / lj_vm_powui(x, (uint32_t)-k);
}
#endif

/* Computes fpm(x) for extended math functions. */
double lj_vm_foldfpm(double x, int fpm)
{
Expand Down
13 changes: 8 additions & 5 deletions src/vm_arm.dasc
Expand Up @@ -1477,11 +1477,11 @@ static void build_subroutines(BuildCtx *ctx)
|.endif
|.endmacro
|
|.macro math_extern2, func
|.macro math_extern2, name, func
|.if HFABI
| .ffunc_dd math_ .. func
| .ffunc_dd math_ .. name
|.else
| .ffunc_nn math_ .. func
| .ffunc_nn math_ .. name
|.endif
| .IOS mov RA, BASE
| bl extern func
Expand All @@ -1492,6 +1492,9 @@ static void build_subroutines(BuildCtx *ctx)
| b ->fff_restv
|.endif
|.endmacro
|.macro math_extern2, func
| math_extern2 func, func
|.endmacro
|
|.if FPU
| .ffunc_d math_sqrt
Expand Down Expand Up @@ -1537,7 +1540,7 @@ static void build_subroutines(BuildCtx *ctx)
| math_extern sinh
| math_extern cosh
| math_extern tanh
| math_extern2 pow
| math_extern2 pow, lj_vm_pow
| math_extern2 atan2
| math_extern2 fmod
|
Expand Down Expand Up @@ -3203,7 +3206,7 @@ static void build_ins(BuildCtx *ctx, BCOp op, int defop)
break;
case BC_POW:
| // NYI: (partial) integer arithmetic.
| ins_arithfp extern, extern pow
| ins_arithfp extern, extern lj_vm_pow
break;

case BC_CAT:
Expand Down
11 changes: 7 additions & 4 deletions src/vm_arm64.dasc
Expand Up @@ -1387,11 +1387,14 @@ static void build_subroutines(BuildCtx *ctx)
| b ->fff_resn
|.endmacro
|
|.macro math_extern2, func
| .ffunc_nn math_ .. func
|.macro math_extern2, name, func
| .ffunc_nn math_ .. name
| bl extern func
| b ->fff_resn
|.endmacro
|.macro math_extern2, func
| math_extern2 func, func
|.endmacro
|
|.ffunc_n math_sqrt
| fsqrt d0, d0
Expand Down Expand Up @@ -1420,7 +1423,7 @@ static void build_subroutines(BuildCtx *ctx)
| math_extern sinh
| math_extern cosh
| math_extern tanh
| math_extern2 pow
| math_extern2 pow, lj_vm_pow
| math_extern2 atan2
| math_extern2 fmod
|
Expand Down Expand Up @@ -2674,7 +2677,7 @@ static void build_ins(BuildCtx *ctx, BCOp op, int defop)
| ins_arithload FARG1, FARG2
| ins_arithfallback ins_arithcheck_num
|.if "fpins" == "fpow"
| bl extern pow
| bl extern lj_vm_pow
|.else
| fpins FARG1, FARG1, FARG2
|.endif
Expand Down
11 changes: 7 additions & 4 deletions src/vm_mips.dasc
Expand Up @@ -1623,14 +1623,17 @@ static void build_subroutines(BuildCtx *ctx)
|. nop
|.endmacro
|
|.macro math_extern2, func
| .ffunc_nn math_ .. func
|.macro math_extern2, name, func
| .ffunc_nn math_ .. name
|. load_got func
| call_extern
|. nop
| b ->fff_resn
|. nop
|.endmacro
|.macro math_extern2, func
| math_extern2 func, func
|.endmacro
|
|// TODO: Return integer type if result is integer (own sf implementation).
|.macro math_round, func
Expand Down Expand Up @@ -1684,7 +1687,7 @@ static void build_subroutines(BuildCtx *ctx)
| math_extern sinh
| math_extern cosh
| math_extern tanh
| math_extern2 pow
| math_extern2 pow, lj_vm_pow
| math_extern2 atan2
| math_extern2 fmod
|
Expand Down Expand Up @@ -3689,7 +3692,7 @@ static void build_ins(BuildCtx *ctx, BCOp op, int defop)
| sltiu AT, SFARG1HI, LJ_TISNUM
| sltiu TMP0, SFARG2HI, LJ_TISNUM
| and AT, AT, TMP0
| load_got pow
| load_got lj_vm_pow
| beqz AT, ->vmeta_arith
|. addu RA, BASE, RA
|.if FPU
Expand Down

0 comments on commit 9512d5c

Please sign in to comment.