Permalink
Browse files

Merge branch '68-code-generation-for-complex-coefficient-functions' i…

…nto 'master'


Resolve "Code generation for complex coefficient functions"

Closes #68

See merge request !247
  • Loading branch information...
mhochsteger committed Oct 6, 2017
2 parents 3b06dc0 + 8ff1fb1 commit dee0c0b44952fe6b54b778b04b2140c716f0e378
@@ -1129,164 +1129,6 @@ namespace ngcomp
return fes->DefinedOn(trafo.VB(), trafo.GetElementIndex());
}
void GridFunctionCoefficientFunction :: GenerateCode(Code &code, FlatArray<int> inputs, int index) const
{
string mycode_simd = R"CODE_(
STACK_ARRAY(SIMD<double>, {hmem}, mir.Size()*{dim});
AFlatMatrix<double> {values}({dim}, mir.IR().GetNIP(), &{hmem}[0] /* .Data() */);
{
auto gfcf = reinterpret_cast<GridFunctionCoefficientFunction*>({gfcf_ptr});
ProxyUserData * ud = (ProxyUserData*)mir.GetTransformation().userdata;
if (ud && ud->HasMemory(gfcf) && ud->Computed(gfcf))
{
{values} = AFlatMatrix<double> ({dim}, mir.IR().GetNIP(), &ud->GetAMemory(gfcf)(0,0));
}
else
{
const GridFunction & gf = *reinterpret_cast<GridFunction*>({gf_ptr});
const ElementTransformation &trafo = mir.GetTransformation();
auto elnr = trafo.GetElementNr();
// const FESpace &fes = *gf.GetFESpace();
const FESpace & fes = *reinterpret_cast<FESpace*>({fes_ptr});
auto vb = trafo.VB();
ElementId ei(vb, elnr);
// DifferentialOperator * diffop = (DifferentialOperator*){diffop_ptr};
// DifferentialOperator * trace_diffop = (DifferentialOperator*){trace_diffop_ptr};
DifferentialOperator * diffop[3] = { (DifferentialOperator*){diffop_vol_ptr}, (DifferentialOperator*){diffop_bnd_ptr}, (DifferentialOperator*){diffop_bbnd_ptr}};
// BilinearFormIntegrator * bfi = (BilinearFormIntegrator*){bfi_ptr};
if (!trafo.BelongsToMesh((void*)(fes.GetMeshAccess().get()))) {
throw Exception ("SIMD - evaluation not available for different meshes");
} else if(!fes.DefinedOn(vb,trafo.GetElementIndex())){
{values} = 0.0;
} else {
const FiniteElement & fel = fes.GetFE (ei, gridfunction_local_heap);
int dim = fes.GetDimension();
auto &dnums = gridfunction_dnums;
fes.GetDofNrs (ei, dnums);
gridfunction_elu.SetSize(dnums.Size()*dim);
FlatVector<> elu(dnums.Size()*dim, &gridfunction_elu[0]);
gf.GetElementVector ({comp}, dnums, elu);
fes.TransformVec (ei, elu, TRANSFORM_SOL);
/*
if (diffop && vb==VOL)
diffop->Apply (fel, mir, elu, {values});
else if (trace_diffop && vb==BND)
trace_diffop->Apply (fel, mir, elu, {values});
else if (bfi)
throw Exception ("GridFunctionCoefficientFunction: SIMD evaluate not possible 1");
// bfi->CalcFlux (fel, mir, elu, values, true, gridfunction_local_heap);
else if (fes.GetEvaluator(vb))
fes.GetEvaluator(vb) -> Apply (fel, mir, elu, {values}); // , gridfunction_local_heap);
else if (fes.GetIntegrator(vb))
throw Exception ("GridFunctionCoefficientFunction: SIMD evaluate not possible 2");
// fes.GetIntegrator(vb) ->CalcFlux (fel, mir, elu, values, false, gridfunction_local_heap);
*/
if (diffop[vb])
diffop[vb]->Apply (fel, mir, elu, {values});
else
throw Exception ("GridFunctionCoefficientFunction: SIMD: don't know how I shall evaluate");
if (ud)
{
if (ud->HasMemory(gfcf))
{
ud->GetAMemory(gfcf) = BareSliceMatrix<SIMD<double>> ({values});
ud->SetComputed(gfcf);
}
}
}
}
}
)CODE_";
string mycode = R"CODE_(
// Matrix<> {values}{mir.Size(), {dim}};
STACK_ARRAY(double, {hmem}, mir.Size()*{dim});
FlatMatrix<double> {values}(mir.Size(), {dim}, &{hmem}[0]);
{
const GridFunction & gf = *reinterpret_cast<GridFunction*>({gf_ptr});
const ElementTransformation &trafo = mir.GetTransformation();
auto elnr = trafo.GetElementNr();
const FESpace &fes = *gf.GetFESpace();
auto vb = trafo.VB();
ElementId ei(vb, elnr);
// DifferentialOperator * diffop = (DifferentialOperator*){diffop_ptr};
// DifferentialOperator * trace_diffop = (DifferentialOperator*){trace_diffop_ptr};
// DifferentialOperator * trace_diffop[3] = (DifferentialOperator*){diffop_vol,diffop_bnd,diffop_bbnd};
DifferentialOperator * diffop[3] = { (DifferentialOperator*){diffop_vol_ptr}, (DifferentialOperator*){diffop_bnd_ptr}, (DifferentialOperator*){diffop_bbnd_ptr}};
// BilinearFormIntegrator * bfi = (BilinearFormIntegrator*){bfi_ptr};
if (!trafo.BelongsToMesh((void*)(fes.GetMeshAccess().get()))) {
gf.Evaluate(mir, {values});
//for (auto i : Range(mir.Size()))
// gf.Evaluate(mir[i], {values}.Row(i));
} else if(!fes.DefinedOn(vb,trafo.GetElementIndex())){
{values} = 0.0;
} else {
const FiniteElement & fel = fes.GetFE (ei, gridfunction_local_heap);
int dim = fes.GetDimension();
auto &dnums = gridfunction_dnums;
fes.GetDofNrs (ei, dnums);
gridfunction_elu.SetSize(dnums.Size()*dim);
FlatVector<> elu(dnums.Size()*dim, &gridfunction_elu[0]);
gf.GetElementVector ({comp}, dnums, elu);
fes.TransformVec (ei, elu, TRANSFORM_SOL);
/*
if (diffop && vb==VOL)
diffop->Apply (fel, mir, elu, {values}, gridfunction_local_heap);
else if (trace_diffop && vb==BND)
trace_diffop->Apply (fel, mir, elu, {values}, gridfunction_local_heap);
else if (bfi)
bfi->CalcFlux (fel, mir, elu, {values}, true, gridfunction_local_heap);
else if (fes.GetEvaluator(vb))
fes.GetEvaluator(vb) -> Apply (fel, mir, elu, {values}, gridfunction_local_heap);
else if (fes.GetIntegrator(vb))
fes.GetIntegrator(vb) ->CalcFlux (fel, mir, elu, {values}, false, gridfunction_local_heap);
*/
if (diffop[vb])
diffop[vb]->Apply (fel, mir, elu, {values}, gridfunction_local_heap);
else
throw Exception ("don't know how I shall evaluate");
}
}
)CODE_";
std::map<string,string> variables;
auto values = Var("values", index);
variables["values"] = values.S();
variables["gf_ptr"] = code.AddPointer(gf);
variables["gfcf_ptr"] = code.AddPointer(this); // ToString(this);
variables["fes_ptr"] = code.AddPointer(fes.get()); // ToString(fes.get());
// variables["diffop_ptr"] = ToString(diffop[VOL].get());
// variables["trace_diffop_ptr"] = ToString(diffop[BND].get());
// variables["bfi_ptr"] = ToString(trace_diffop.get());
variables["diffop_vol_ptr"] = code.AddPointer(diffop[VOL].get()); // ToString(diffop[VOL].get());
variables["diffop_bnd_ptr"] = code.AddPointer(diffop[BND].get()); // ToString(diffop[BND].get());
variables["diffop_bbnd_ptr"] = code.AddPointer(diffop[BBND].get()); // ToString(diffop[BBND].get());
variables["comp"] = ToString(comp);
variables["dim"] = ToString(Dimension());
variables["hmem"] = Var("hmem", index).S();
if(code.is_simd)
{
code.header += Code::Map(mycode_simd, variables);
TraverseDimensions( Dimensions(), [&](int ind, int i, int j) {
code.body += Var(index,i,j).Assign("SIMD<double>("+values.S()+".Get("+ToString(ind)+",i))");
});
}
else
{
code.header += Code::Map(mycode, variables);
TraverseDimensions( Dimensions(), [&](int ind, int i, int j) {
code.body += Var(index,i,j).Assign(values.S()+"(i,"+ToString(ind)+")");
});
}
}
bool GridFunctionCoefficientFunction::IsComplex() const
{
return gf->GetFESpace()->IsComplex();
@@ -46,7 +46,6 @@ namespace ngcomp
virtual bool DefinedOn (const ElementTransformation & trafo);
void SelectComponent (int acomp) { comp = acomp; }
const GridFunction & GetGridFunction() const { return *gf; }
virtual void GenerateCode(Code &code, FlatArray<int> inputs, int index) const;
virtual double Evaluate (const BaseMappedIntegrationPoint & ip) const;
virtual Complex EvaluateComplex (const BaseMappedIntegrationPoint & ip) const;
@@ -1,9 +1,16 @@
#include "fem.hpp"
#include <algorithm>
namespace ngfem
{
atomic<unsigned> Code::id_counter{0};
void Code::AddLinkFlag(string flag)
{
if(std::find(std::begin(link_flags), std::end(link_flags), flag) == std::end(link_flags))
link_flags.push_back(flag);
}
string Code::AddPointer(const void *p)
{
string name = "compiled_code_pointer" + ToString(id_counter++);
@@ -29,7 +36,7 @@ namespace ngfem
#endif // WIN32
}
void Library::Compile(const std::vector<string> &codes )
void Library::Compile(const std::vector<string> &codes, const std::vector<string> &link_flags )
{
static ngstd::Timer tcompile("CompiledCF::Compile");
static ngstd::Timer tlink("CompiledCF::Link");
@@ -61,6 +68,8 @@ namespace ngfem
string slink = "cmd /C \"ngsld.bat /OUT:" + prefix+".dll " + object_files + "\"";
#else
string slink = "ngsld -shared " + object_files + " -o " + prefix + ".so -lngstd -lngbla -lngfem";
for (auto flag : link_flags)
slink += " "+flag;
#endif
int err = system(slink.c_str());
if (err) throw Exception ("problem calling linker");
@@ -49,11 +49,14 @@ namespace ngfem
string body;
bool is_simd;
int deriv;
std::vector<string> link_flags;
string pointer;
string AddPointer(const void *p );
void AddLinkFlag(string flag);
static atomic<unsigned> id_counter;
static string Map( string code, std::map<string,string> variables ) {
for ( auto mapping : variables ) {
@@ -119,7 +122,7 @@ namespace ngfem
inline CodeExpr Var(Complex val)
{
return ToLiteral(val);
return "Complex"+ToLiteral(val);
}
inline CodeExpr Var(string name, int i, int j=0, int k=0)
@@ -190,7 +193,7 @@ namespace ngfem
public:
Library() : lib(nullptr) {}
// Compile a given string and load the library
void Compile( const std::vector<string> &codes );
void Compile(const std::vector<string> &codes, const std::vector<string> &libraries );
void Load( string alib_name );
Oops, something went wrong.

0 comments on commit dee0c0b

Please sign in to comment.