Permalink
Browse files

Merge branch '53-compilation-of-domain-wise-cf' into 'master'

Resolve "compilation of domain-wise CF"

Closes #53

See merge request !242
  • Loading branch information...
JSchoeberl committed Sep 22, 2017
2 parents 42c3ac6 + 3f61815 commit 696819d96f4ab475e48feeb9652b1e125c58e5c4
Showing with 100 additions and 49 deletions.
  1. +23 −18 fem/code_generation.hpp
  2. +45 −19 fem/coefficient.cpp
  3. +1 −1 fem/intrule.hpp
  4. +4 −4 fem/python_fem.cpp
  5. +7 −7 fem/symbolicintegrator.cpp
  6. +20 −0 tests/pytest/test_coefficient.py
@@ -17,6 +17,23 @@
namespace ngfem
{
template <typename T>
string ToLiteral(const T & val)
{
stringstream ss;
#if (defined __cpp_hex_float) && (__cpp_hex_float <= __cplusplus)
ss << std::hexfloat;
ss << val;
ss << " /* (" << std::setprecision(16) << std::scientific;
ss << val << ") */";
#else
ss << std::setprecision(16);
ss << val;
#endif
return ss.str();
}
struct Code
{
string top;
@@ -63,7 +80,7 @@ namespace ngfem
void operator /=(CodeExpr other) { code = "(" + S()+Op('/')+other.S() + ')'; }
operator string () { return code; }
CodeExpr operator ()(int i) { return CodeExpr( S() + '(' + ToString(i) + ')' ); }
CodeExpr operator ()(int i) { return CodeExpr( S() + '(' + ToLiteral(i) + ')' ); }
CodeExpr Func(string s) { return CodeExpr( s + "(" + S() + ")" ); }
CodeExpr Call(string s, string args="") { return CodeExpr( S()+'.'+ s + "(" + args + ")"); }
string Assign (CodeExpr other, bool declare = true)
@@ -83,40 +100,28 @@ namespace ngfem
template<typename TVal>
string Declare(string type, TVal value )
{
return type + " " + code + "("+ToString(value)+");\n";
return type + " " + code + "("+ToLiteral(value)+");\n";
}
};
inline CodeExpr Var(double val)
{
stringstream ss;
ss << ToString(val);
ss << " /* (" << std::setprecision(16);
ss << val << ") */";
return ss.str();
return ToLiteral(val);
}
inline CodeExpr Var(Complex val)
{
stringstream ss;
ss << "Complex(";
ss << ToString(val.real());
ss << ",";
ss << ToString(val.imag());
ss << ")";
ss << " /* (" << std::setprecision(16);
ss << val.real() << ", " << val.imag() << ") */";
return ss.str();
return ToLiteral(val);
}
inline CodeExpr Var(string name, int i, int j=0, int k=0)
{
return CodeExpr(name + '_' + ToString(i) + '_' + ToString(j) + '_' + ToString(k));
return CodeExpr(name + '_' + ToLiteral(i) + '_' + ToLiteral(j) + '_' + ToLiteral(k));
}
inline CodeExpr Var(int i, int j=0, int k=0)
{
return CodeExpr("var_" + ToString(i) + '_' + ToString(j) + '_' + ToString(k));
return CodeExpr("var_" + ToLiteral(i) + '_' + ToLiteral(j) + '_' + ToLiteral(k));
}
template<typename TFunc>
@@ -301,15 +301,15 @@ namespace ngfem
void DomainConstantCoefficientFunction :: GenerateCode(Code &code, FlatArray<int> inputs, int index) const
{
code.header += "double tmp_" + ToString(index) + "["+ToString(val.Size())+"] = {";
code.header += "double tmp_" + ToLiteral(index) + "["+ToLiteral(val.Size())+"] = {";
for (auto i : Range(val))
{
code.header += ToString(val[i]);
code.header += ToLiteral(val[i]);
if(i<val.Size()-1)
code.header += ", ";
}
code.header += "};\n";
code.header += Var(index).Assign("tmp_"+ToString(index) + "[mir.GetTransformation().GetElementIndex()]");
code.header += Var(index).Assign("tmp_"+ToLiteral(index) + "[mir.GetTransformation().GetElementIndex()]");
}
@@ -2082,8 +2082,8 @@ class MultMatMatCoefficientFunction : public CoefficientFunction
throw Exception("Mult of non-matrices called");
if (dims_c1[1] != dims_c2[0])
throw Exception(string("Matrix dimensions don't fit: m1 is ") +
ToString(dims_c1[0]) + " x " + ToString(dims_c1[1]) +
", m2 is " + ToString(dims_c2[0]) + " x " + ToString(dims_c2[1]) );
ToLiteral(dims_c1[0]) + " x " + ToLiteral(dims_c1[1]) +
", m2 is " + ToLiteral(dims_c2[0]) + " x " + ToLiteral(dims_c2[1]) );
// dims = { dims_c1[0], dims_c2[1] };
SetDimensions( Array<int> ({ dims_c1[0], dims_c2[1] }));
inner_dim = dims_c1[1];
@@ -2486,7 +2486,7 @@ class MultMatVecCoefficientFunction : public CoefficientFunction
throw Exception("Not a mat-vec multiplication");
if (dims_c1[1] != dims_c2[0])
throw Exception(string ("Matrix dimensions don't fit: mat is ") +
ToString(dims_c1[0]) + " x " + ToString(dims_c1[1]) + ", vec is " + ToString(dims_c2[0]));
ToLiteral(dims_c1[0]) + " x " + ToLiteral(dims_c1[1]) + ", vec is " + ToLiteral(dims_c2[0]));
// dims = Array<int> ({ dims_c1[0] });
SetDimensions (Array<int> ({ dims_c1[0] }));
inner_dim = dims_c1[1];
@@ -3595,7 +3595,29 @@ class DomainWiseCoefficientFunction : public T_CoefficientFunction<DomainWiseCoe
virtual void GenerateCode(Code &code, FlatArray<int> inputs, int index) const
{
code.body += "// DomainWiseCoefficientFunction: not implemented\n;";
code.body += "// DomainWiseCoefficientFunction:\n";
string type = "decltype(0.0";
for(int in : inputs)
type += "+decltype("+Var(in).S()+")()";
type += ")";
TraverseDimensions( Dimensions(), [&](int ind, int i, int j) {
code.body += Var(index,i,j).Declare(type);
});
code.body += "switch(domain_index) {\n";
for(int domain : Range(inputs))
{
code.body += "case " + ToLiteral(domain) + ": \n";
TraverseDimensions( Dimensions(), [&](int ind, int i, int j) {
code.body += " "+Var(index, i, j).Assign(Var(inputs[domain], i, j), false);
});
code.body += " break;\n";
}
code.body += "default: \n";
TraverseDimensions( Dimensions(), [&](int ind, int i, int j) {
code.body += " "+Var(index, i, j).Assign(string("0.0"), false);
});
code.body += " break;\n";
code.body += "}\n";
}
virtual void TraverseTree (const function<void(CoefficientFunction&)> & func)
@@ -4539,7 +4561,7 @@ class VectorialCoefficientFunction : public T_CoefficientFunction<VectorialCoeff
case 0: dirname = "x"; break;
case 1: dirname = "y"; break;
case 2: dirname = "z"; break;
default: dirname = ToString(dir);
default: dirname = ToLiteral(dir);
}
return string("coordinate ")+dirname;
}
@@ -4572,8 +4594,8 @@ class VectorialCoefficientFunction : public T_CoefficientFunction<VectorialCoeff
virtual void GenerateCode(Code &code, FlatArray<int> inputs, int index) const {
auto v = Var(index);
// code.body += v.Assign(CodeExpr(string("mir.GetPoints()(i,")+ToString(dir)+")"));
code.body += v.Assign(CodeExpr(string("points(i,")+ToString(dir)+")"));
// code.body += v.Assign(CodeExpr(string("mir.GetPoints()(i,")+ToLiteral(dir)+")"));
code.body += v.Assign(CodeExpr(string("points(i,")+ToLiteral(dir)+")"));
}
template <typename T>
@@ -4702,8 +4724,8 @@ shared_ptr<CoefficientFunction> MakeCoordinateCoefficientFunction (int comp)
TraverseDimensions( cf->Dimensions(), [&](int ind, int i, int j) {
code.body += Var(steps.Size(),i,j).Declare(res_type);
code.body += Var(steps.Size(),i,j).Assign(Var(steps.Size()-1,i,j),false);
string sget = "(i," + ToString(ii) + ") =";
if(simd) sget = "(" + ToString(ii) + ",i) =";
string sget = "(i," + ToLiteral(ii) + ") =";
if(simd) sget = "(" + ToLiteral(ii) + ",i) =";
for (auto ideriv : Range(deriv+1))
{
@@ -4746,6 +4768,7 @@ shared_ptr<CoefficientFunction> MakeCoordinateCoefficientFunction (int comp)
s << " ) {" << endl;
s << code.header << endl;
s << "auto points = mir.GetPoints();" << endl;
s << "auto domain_index = mir.GetTransformation().GetElementIndex();" << endl;
s << "for ( auto i : Range(mir)) {" << endl;
s << "auto & ip = mir[i];" << endl;
s << code.body << endl;
@@ -4761,8 +4784,8 @@ shared_ptr<CoefficientFunction> MakeCoordinateCoefficientFunction (int comp)
pointer_code += "}\n";
codes.push_back(pointer_code);
}
std::thread thread{ [this, codes, maxderiv] () {
try {
auto compile_func = [this, codes, maxderiv] () {
library.Compile( codes );
compiled_function_simd = library.GetFunction<lib_function_simd>("CompiledEvaluateSIMD");
compiled_function = library.GetFunction<lib_function>("CompiledEvaluate");
@@ -4777,14 +4800,17 @@ shared_ptr<CoefficientFunction> MakeCoordinateCoefficientFunction (int comp)
compiled_function_dderiv = library.GetFunction<lib_function_dderiv>("CompiledEvaluateDDeriv");
}
cout << IM(7) << "Compilation done" << endl;
};
if(wait)
compile_func();
else
{
try {
std::thread( compile_func ).detach();
} catch (const std::exception &e) {
cerr << IM(3) << "Compilation of CoefficientFunction failed: " << e.what() << endl;
}
}};
if(wait)
thread.join();
else
thread.detach();
}
}
}
@@ -282,7 +282,7 @@ namespace ngfem
FlatVector<Complex> GetPointComplex() const;
FlatMatrix<Complex> GetJacobianComplex() const;
// dimension of range
int Dim() const;
NGS_DLL_HEADER int Dim() const;
VorB VB() const;
bool IsComplex() const { return is_complex; }
void SetOwnsTrafo (bool aowns_trafo = true) { owns_trafo = aowns_trafo; }
@@ -391,9 +391,9 @@ struct GenericPow {
virtual void GenerateCode(Code &code, FlatArray<int> inputs, int index) const {
string miptype;
if(code.is_simd)
miptype = "SIMD<DimMappedIntegrationPoint<"+ToString(D)+">>*";
miptype = "SIMD<DimMappedIntegrationPoint<"+ToLiteral(D)+">>*";
else
miptype = "DimMappedIntegrationPoint<"+ToString(D)+">*";
miptype = "DimMappedIntegrationPoint<"+ToLiteral(D)+">*";
auto nv_expr = CodeExpr("static_cast<const "+miptype+">(&ip)->GetNV()");
auto nv = Var("tmp", index);
code.body += nv.Assign(nv_expr);
@@ -453,9 +453,9 @@ struct GenericPow {
virtual void GenerateCode(Code &code, FlatArray<int> inputs, int index) const {
string miptype;
if(code.is_simd)
miptype = "SIMD<DimMappedIntegrationPoint<"+ToString(D)+">>*";
miptype = "SIMD<DimMappedIntegrationPoint<"+ToLiteral(D)+">>*";
else
miptype = "DimMappedIntegrationPoint<"+ToString(D)+">*";
miptype = "DimMappedIntegrationPoint<"+ToLiteral(D)+">*";
auto tv_expr = CodeExpr("static_cast<const "+miptype+">(&ip)->GetTV()");
auto tv = Var("tmp", index);
code.body += tv.Assign(tv_expr);
@@ -94,11 +94,11 @@ namespace ngfem
header += Var("comp", index,i,j).Declare("{scal_type}", 0.0);
if(!testfunction && code.deriv==2)
{
header += "if(( ({ud}->trialfunction == {this}) && ({ud}->trial_comp=="+ToString(ind)+"))\n"+
" || (({ud}->testfunction == {this}) && ({ud}->test_comp=="+ToString(ind)+")))\n";
header += "if(( ({ud}->trialfunction == {this}) && ({ud}->trial_comp=="+ToLiteral(ind)+"))\n"+
" || (({ud}->testfunction == {this}) && ({ud}->test_comp=="+ToLiteral(ind)+")))\n";
}
else
header += "if({ud}->{comp_string}=="+ToString(ind)+" && {ud}->{func_string} == {this})\n";
header += "if({ud}->{comp_string}=="+ToLiteral(ind)+" && {ud}->{func_string} == {this})\n";
header += Var("comp", index,i,j).S() + string("{get_component}") + " = 1.0;\n";
});
string body = "";
@@ -116,9 +116,9 @@ namespace ngfem
var += ".Value()";
string values = "{values}";
if(code.is_simd)
values += "(" + ToString(ind) + ",i)";
values += "(" + ToLiteral(ind) + ",i)";
else
values += "(i," + ToString(ind) + ")";
values += "(i," + ToLiteral(ind) + ")";
body += var + " = " + values + ";\n";
});
@@ -144,11 +144,11 @@ namespace ngfem
string func_string = testfunction ? "testfunction" : "trialfunction";
string comp_string = testfunction ? "test_comp" : "trial_comp";
std::map<string,string> variables;
variables["ud"] = "tmp_"+ToString(index)+"_0";
variables["ud"] = "tmp_"+ToLiteral(index)+"_0";
variables["this"] = "reinterpret_cast<ProxyFunction*>("+code.AddPointer(this)+")";
variables["func_string"] = testfunction ? "testfunction" : "trialfunction";
variables["comp_string"] = testfunction ? "test_comp" : "trial_comp";
variables["testfunction"] = ToString(testfunction);
variables["testfunction"] = ToLiteral(testfunction);
variables["flatmatrix"] = code.is_simd ? "FlatMatrix<SIMD<double>>" : "FlatMatrix<double>";
@@ -2,6 +2,9 @@
from netgen.geom2d import unit_square
from netgen.csg import unit_cube
from ngsolve import *
from netgen.geom2d import SplineGeometry
def test_ParameterCF():
p = Parameter(23)
@@ -26,6 +29,23 @@ def test_mesh_size_cf():
for val in v:
assert abs(val) < 1e-14
def test_domainwise_cf():
geo = SplineGeometry()
geo.AddCircle ( (0, 0), r=1, leftdomain=1, rightdomain=2)
geo.AddCircle ( (0, 0), r=1.4, leftdomain=2, rightdomain=0)
mesh = Mesh(geo.GenerateMesh(maxh=0.1))
c_vec = CoefficientFunction([(x,y),(1,3)])
c = c_vec[0]*c_vec[1]
c_false = c.Compile(False);
error_false = Integrate((c-c_false)*(c-c_false), mesh)
assert abs(error_false) < 1e-14
c_true = c.Compile(True, wait=True);
error_true = Integrate((c-c_true)*(c-c_true), mesh)
assert abs(error_true) < 1e-14
if __name__ == "__main__":
test_ParameterCF()
test_mesh_size_cf()
test_domainwise_cf()

0 comments on commit 696819d

Please sign in to comment.