Skip to content

Commit

Permalink
Fix Issue 5943 - Power expression optimisation: 2^^unsigned ==> 1<<un…
Browse files Browse the repository at this point in the history
…signed
  • Loading branch information
yebblies committed Jun 18, 2013
1 parent 984f155 commit 50a484a
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 75 deletions.
151 changes: 76 additions & 75 deletions src/expression.c
Expand Up @@ -12083,95 +12083,96 @@ Expression *PowExp::semantic(Scope *sc)
}
}

if ( (e1->type->isintegral() || e1->type->isfloating()) &&
(e2->type->isintegral() || e2->type->isfloating()))
if ( !(e1->type->isintegral() || e1->type->isfloating()) ||
!(e2->type->isintegral() || e2->type->isfloating()))
{
// For built-in numeric types, there are several cases.
// TODO: backend support, especially for e1 ^^ 2.
return incompatibleTypes();
}

bool wantSqrt = false;
// For built-in numeric types, there are several cases.
// TODO: backend support, especially for e1 ^^ 2.

// First, attempt to fold the expression.
e = optimize(WANTvalue);
if (e->op != TOKpow)
{
e = e->semantic(sc);
return e;
}
bool wantSqrt = false;

// Determine if we're raising to an integer power.
sinteger_t intpow = 0;
if (e2->op == TOKint64 && ((sinteger_t)e2->toInteger() == 2 || (sinteger_t)e2->toInteger() == 3))
intpow = e2->toInteger();
else if (e2->op == TOKfloat64 && (e2->toReal() == (sinteger_t)(e2->toReal())))
intpow = (sinteger_t)(e2->toReal());
// First, attempt to fold the expression.
e = optimize(WANTvalue);
if (e->op != TOKpow)
{
e = e->semantic(sc);
return e;
}

// Deal with x^^2, x^^3 immediately, since they are of practical importance.
if (intpow == 2 || intpow == 3)
{
// Replace x^^2 with (tmp = x, tmp*tmp)
// Replace x^^3 with (tmp = x, tmp*tmp*tmp)
Identifier *idtmp = Lexer::uniqueId("__powtmp");
VarDeclaration *tmp = new VarDeclaration(loc, e1->type->toBasetype(), idtmp, new ExpInitializer(Loc(), e1));
tmp->storage_class = STCctfe;
Expression *ve = new VarExp(loc, tmp);
Expression *ae = new DeclarationExp(loc, tmp);
/* Note that we're reusing ve. This should be ok.
*/
Expression *me = new MulExp(loc, ve, ve);
if (intpow == 3)
me = new MulExp(loc, me, ve);
e = new CommaExp(loc, ae, me);
e = e->semantic(sc);
return e;
}
// Determine if we're raising to an integer power.
sinteger_t intpow = 0;
if (e2->op == TOKint64 && ((sinteger_t)e2->toInteger() == 2 || (sinteger_t)e2->toInteger() == 3))
intpow = e2->toInteger();
else if (e2->op == TOKfloat64 && (e2->toReal() == (sinteger_t)(e2->toReal())))
intpow = (sinteger_t)(e2->toReal());

static int importMathChecked = 0;
static bool importMath = false;
if (!importMathChecked)
{
importMathChecked = 1;
for (size_t i = 0; i < Module::amodules.dim; i++)
{ Module *mi = Module::amodules[i];
//printf("\t[%d] %s\n", i, mi->toChars());
if (mi->ident == Id::math &&
mi->parent->ident == Id::std &&
!mi->parent->parent)
{
importMath = true;
goto L1;
}
}
error("must import std.math to use ^^ operator");
return new ErrorExp();
// Deal with x^^2, x^^3 immediately, since they are of practical importance.
if (intpow == 2 || intpow == 3)
{
// Replace x^^2 with (tmp = x, tmp*tmp)
// Replace x^^3 with (tmp = x, tmp*tmp*tmp)
Identifier *idtmp = Lexer::uniqueId("__powtmp");
VarDeclaration *tmp = new VarDeclaration(loc, e1->type->toBasetype(), idtmp, new ExpInitializer(Loc(), e1));
tmp->storage_class = STCctfe;
Expression *ve = new VarExp(loc, tmp);
Expression *ae = new DeclarationExp(loc, tmp);
/* Note that we're reusing ve. This should be ok.
*/
Expression *me = new MulExp(loc, ve, ve);
if (intpow == 3)
me = new MulExp(loc, me, ve);
e = new CommaExp(loc, ae, me);
e = e->semantic(sc);
return e;
}

L1: ;
}
else
{
if (!importMath)
static int importMathChecked = 0;
static bool importMath = false;
if (!importMathChecked)
{
importMathChecked = 1;
for (size_t i = 0; i < Module::amodules.dim; i++)
{ Module *mi = Module::amodules[i];
//printf("\t[%d] %s\n", i, mi->toChars());
if (mi->ident == Id::math &&
mi->parent->ident == Id::std &&
!mi->parent->parent)
{
error("must import std.math to use ^^ operator");
return new ErrorExp();
importMath = true;
goto L1;
}
}
error("must import std.math to use ^^ operator");
return new ErrorExp();

e = new IdentifierExp(loc, Id::empty);
e = new DotIdExp(loc, e, Id::std);
e = new DotIdExp(loc, e, Id::math);
if (e2->op == TOKfloat64 && e2->toReal() == 0.5)
{ // Replace e1 ^^ 0.5 with .std.math.sqrt(x)
e = new CallExp(loc, new DotIdExp(loc, e, Id::_sqrt), e1);
}
else
L1: ;
}
else
{
if (!importMath)
{
// Replace e1 ^^ e2 with .std.math.pow(e1, e2)
e = new CallExp(loc, new DotIdExp(loc, e, Id::_pow), e1, e2);
error("must import std.math to use ^^ operator");
return new ErrorExp();
}
e = e->semantic(sc);
return e;
}
return incompatibleTypes();

e = new IdentifierExp(loc, Id::empty);
e = new DotIdExp(loc, e, Id::std);
e = new DotIdExp(loc, e, Id::math);
if (e2->op == TOKfloat64 && e2->toReal() == 0.5)
{ // Replace e1 ^^ 0.5 with .std.math.sqrt(x)
e = new CallExp(loc, new DotIdExp(loc, e, Id::_sqrt), e1);
}
else
{
// Replace e1 ^^ e2 with .std.math.pow(e1, e2)
e = new CallExp(loc, new DotIdExp(loc, e, Id::_pow), e1, e2);
}
e = e->semantic(sc);
return e;
}

/************************************************************/
Expand Down
15 changes: 15 additions & 0 deletions src/optimize.c
Expand Up @@ -924,6 +924,21 @@ Expression *PowExp::optimize(int result, bool keepLvalue)
}
e = this;
}

if (e1->op == TOKint64 && e1->toInteger() > 0 &&
!((e1->toInteger() - 1) & e1->toInteger()) && // is power of two
e2->type->isintegral() && e2->type->isunsigned())
{
dinteger_t i = e1->toInteger();
dinteger_t mul = 1;
while ((i >>= 1) > 1)
mul++;
Expression *shift = new MulExp(loc, e2, new IntegerExp(loc, mul, e2->type));
shift->type = e2->type;
e = new ShlExp(loc, new IntegerExp(loc, 1, e1->type), shift);
e->type = type;
}

return e;
}

Expand Down
77 changes: 77 additions & 0 deletions test/runnable/test5943.d
@@ -0,0 +1,77 @@
// test that the import of std.math is not needed

__gshared uint x0 = 0;
__gshared uint x1 = 1;
__gshared uint x2 = 2;
__gshared uint x3 = 3;
__gshared uint x4 = 4;
__gshared uint x5 = 5;
__gshared uint x6 = 6;
__gshared uint x7 = 7;
__gshared uint x10 = 10;
__gshared uint x15 = 15;
__gshared uint x31 = 31;
__gshared uint x32 = 32;

void main()
{
assert(2 ^^ x0 == 1);
assert(2 ^^ x1 == 2);
assert(2 ^^ x31 == 0x80000000);
assert(4 ^^ x0 == 1);
assert(4 ^^ x1 == 4);
assert(4 ^^ x15 == 0x40000000);
assert(8 ^^ x0 == 1);
assert(8 ^^ x1 == 8);
assert(8 ^^ x10 == 0x40000000);
assert(16 ^^ x0 == 1);
assert(16 ^^ x1 == 16);
assert(16 ^^ x7 == 0x10000000);
assert(32 ^^ x0 == 1);
assert(32 ^^ x1 == 32);
assert(32 ^^ x6 == 0x40000000);
assert(64 ^^ x0 == 1);
assert(64 ^^ x1 == 64);
assert(64 ^^ x5 == 0x40000000);
assert(128 ^^ x0 == 1);
assert(128 ^^ x1 == 128);
assert(128 ^^ x4 == 0x10000000);
assert(256 ^^ x0 == 1);
assert(256 ^^ x1 == 256);
assert(256 ^^ x3 == 0x1000000);
assert(512 ^^ x0 == 1);
assert(512 ^^ x1 == 512);
assert(512 ^^ x3 == 0x8000000);
assert(1024 ^^ x0 == 1);
assert(1024 ^^ x1 == 1024);
assert(1024 ^^ x3 == 0x40000000);
assert(2048 ^^ x0 == 1);
assert(2048 ^^ x1 == 2048);
assert(2048 ^^ x2 == 0x400000);
assert(4096 ^^ x0 == 1);
assert(4096 ^^ x1 == 4096);
assert(4096 ^^ x2 == 0x1000000);
assert(8192 ^^ x0 == 1);
assert(8192 ^^ x1 == 8192);
assert(8192 ^^ x2 == 0x4000000);
assert(16384 ^^ x0 == 1);
assert(16384 ^^ x1 == 16384);
assert(16384 ^^ x2 == 0x10000000);
assert(32768 ^^ x0 == 1);
assert(32768 ^^ x1 == 32768);
assert(32768 ^^ x2 == 0x40000000);
assert(65536 ^^ x0 == 1);
assert(65536 ^^ x1 == 65536);
assert(131072 ^^ x0 == 1);
assert(131072 ^^ x1 == 131072);
assert(262144 ^^ x0 == 1);
assert(262144 ^^ x1 == 262144);
assert(524288 ^^ x0 == 1);
assert(524288 ^^ x1 == 524288);
assert(1048576 ^^ x0 == 1);
assert(1048576 ^^ x1 == 1048576);
assert(2097152 ^^ x0 == 1);
assert(2097152 ^^ x1 == 2097152);
assert(4194304 ^^ x0 == 1);
assert(4194304 ^^ x1 == 4194304);
}

0 comments on commit 50a484a

Please sign in to comment.