Permalink
Browse files

make binop constant folding work on floats as well, plus tests

  • Loading branch information...
1 parent dee0f8d commit 919ceb41703b75bf36cff5e04cc4776f98b02ccc @aras-p committed Jan 31, 2013
@@ -18,55 +18,61 @@ TIntermConstant* FoldBinaryConstantExpression(TOperator op, TIntermConstant* nod
if (nodeA->getType() != nodeB->getType())
return NULL;
- // for now, only support integers; we really only need constant folding for array sizes
- if (nodeA->getBasicType() != EbtInt)
+ // for now, only support integers and floats
+ if (nodeA->getBasicType() != EbtInt && nodeA->getBasicType() != EbtFloat)
return NULL;
TIntermConstant* newNode = new TIntermConstant(nodeA->getType());
+
+#define DO_FOLD_OP(oper) \
+ if (nodeA->getBasicType() == EbtInt) \
+ for (unsigned i = 0; i < newNode->getCount(); ++i) \
+ newNode->setValue(i, nodeA->getValue(i).asInt oper nodeB->getValue(i).asInt); \
+ else \
+ for (unsigned i = 0; i < newNode->getCount(); ++i) \
+ newNode->setValue(i, nodeA->getValue(i).asFloat oper nodeB->getValue(i).asFloat)
+
+#define DO_FOLD_OP_INT(oper) \
+ if (nodeA->getBasicType() == EbtInt) \
+ for (unsigned i = 0; i < newNode->getCount(); ++i) \
+ newNode->setValue(i, nodeA->getValue(i).asInt oper nodeB->getValue(i).asInt); \
+ else { \
+ delete newNode; \
+ return NULL; \
+ }
+
+#define DO_FOLD_OP_ZERO(oper) \
+ if (nodeA->getBasicType() == EbtInt) \
+ for (unsigned i = 0; i < newNode->getCount(); ++i) \
+ newNode->setValue(i, nodeB->getValue(i).asInt ? nodeA->getValue(i).asInt oper nodeB->getValue(i).asInt : 0); \
+ else \
+ for (unsigned i = 0; i < newNode->getCount(); ++i) \
+ newNode->setValue(i, nodeB->getValue(i).asInt ? nodeA->getValue(i).asFloat oper nodeB->getValue(i).asFloat : 0)
switch (op)
{
- case EOpAdd:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt + nodeB->getValue(i).asInt);
- break;
- case EOpSub:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt - nodeB->getValue(i).asInt);
- break;
- case EOpMul:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt * nodeB->getValue(i).asInt);
- break;
- case EOpDiv:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeB->getValue(i).asInt ? nodeA->getValue(i).asInt / nodeB->getValue(i).asInt : 0);
- break;
+ case EOpAdd: DO_FOLD_OP(+); break;
+ case EOpSub: DO_FOLD_OP(-); break;
+ case EOpMul: DO_FOLD_OP(*); break;
+ case EOpDiv: DO_FOLD_OP_ZERO(/); break;
case EOpMod:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeB->getValue(i).asInt ? nodeA->getValue(i).asInt % nodeB->getValue(i).asInt : 0);
- break;
- case EOpRightShift:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt >> nodeB->getValue(i).asInt);
- break;
- case EOpLeftShift:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt << nodeB->getValue(i).asInt);
- break;
- case EOpAnd:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt & nodeB->getValue(i).asInt);
- break;
- case EOpInclusiveOr:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt | nodeB->getValue(i).asInt);
- break;
- case EOpExclusiveOr:
- for (unsigned i = 0; i < newNode->getCount(); ++i)
- newNode->setValue(i, nodeA->getValue(i).asInt ^ nodeB->getValue(i).asInt);
+ if (nodeA->getBasicType() == EbtInt)
+ {
+ for (unsigned i = 0; i < newNode->getCount(); ++i)
+ newNode->setValue(i, nodeB->getValue(i).asInt ? nodeA->getValue(i).asInt % nodeB->getValue(i).asInt : 0);
+ }
+ else
+ {
+ delete newNode;
+ return NULL;
+ }
break;
+ case EOpRightShift: DO_FOLD_OP_INT(>>); break;
+ case EOpLeftShift: DO_FOLD_OP_INT(<<); break;
+ case EOpAnd: DO_FOLD_OP_INT(&); break;
+ case EOpInclusiveOr: DO_FOLD_OP_INT(|); break;
+ case EOpExclusiveOr: DO_FOLD_OP_INT(^); break;
default:
delete newNode;
return NULL;
@@ -0,0 +1,15 @@
+static const int kInt1 = 3;
+const int kInt2 = 7;
+static int kInt3 = 11;
+
+float size1 = 512.0;
+float size2 = 1.0 / 512.0;
+float size3 = kInt1;
+//float size4 = kInt2 / 3.0;
+//float size5 = 5 + kInt3;
+//float size6 = kInt2 * 512.0;
+
+float4 main() : POSITION
+{
+ return float4(kInt1, kInt2, kInt3, size2);
+}
@@ -0,0 +1,23 @@
+#version 120
+const int kInt1 = 3;
+const int kInt2 = 7;
+#line 3
+int kInt3 = 11;
+uniform float size1 = 512.0;
+uniform float size2 = 0.00195313;
+#line 7
+uniform float size3 = 3.0;
+#line 12
+vec4 xlat_main( );
+#line 12
+vec4 xlat_main( ) {
+ return vec4( 3.0, 7.0, float(kInt3), size2);
+}
+void main() {
+ vec4 xl_retval;
+ xl_retval = xlat_main( );
+ gl_Position = vec4(xl_retval);
+}
+
+// uniforms:
+// size2:<none> type 9 arrsize 0
@@ -0,0 +1,23 @@
+#version 120
+const int kInt1 = 3;
+const int kInt2 = 7;
+#line 3
+int kInt3 = 11;
+uniform float size1 = 512.0;
+uniform float size2 = 0.00195313;
+#line 7
+uniform float size3 = 3.0;
+#line 12
+vec4 xlat_main( );
+#line 12
+vec4 xlat_main( ) {
+ return vec4( 3.0, 7.0, float(kInt3), size2);
+}
+void main() {
+ vec4 xl_retval;
+ xl_retval = xlat_main( );
+ gl_Position = vec4(xl_retval);
+}
+
+// uniforms:
+// size2:<none> type 9 arrsize 0
@@ -0,0 +1,15 @@
+static const int kInt1 = 3;
+const int kInt2 = 7;
+static int kInt3 = 11;
+
+float size1 = 512.0;
+float size2 = 1.0 / 512.0;
+float size3 = kInt1;
+//float size4 = kInt2 / 3.0;
+//float size5 = 5 + kInt3;
+//float size6 = kInt2 * 512.0;
+
+float4 main() : POSITION
+{
+ return float4(kInt1, kInt2, kInt3, size2);
+}
@@ -0,0 +1,22 @@
+const int kInt1 = 3;
+const int kInt2 = 7;
+#line 3
+int kInt3;
+uniform float size1;
+uniform float size2;
+#line 7
+uniform float size3;
+#line 12
+vec4 xlat_main( );
+#line 12
+vec4 xlat_main( ) {
+ return vec4( 3.0, 7.0, float(kInt3), size2);
+}
+void main() {
+ vec4 xl_retval;
+ xl_retval = xlat_main( );
+ gl_Position = vec4(xl_retval);
+}
+
+// uniforms:
+// size2:<none> type 9 arrsize 0
@@ -0,0 +1,22 @@
+const highp int kInt1 = 3;
+const highp int kInt2 = 7;
+#line 3
+highp int kInt3;
+uniform highp float size1;
+uniform highp float size2;
+#line 7
+uniform highp float size3;
+#line 12
+highp vec4 xlat_main( );
+#line 12
+highp vec4 xlat_main( ) {
+ return vec4( 3.0, 7.0, float(kInt3), size2);
+}
+void main() {
+ highp vec4 xl_retval;
+ xl_retval = xlat_main( );
+ gl_Position = vec4(xl_retval);
+}
+
+// uniforms:
+// size2:<none> type 9 arrsize 0

0 comments on commit 919ceb4

Please sign in to comment.