Skip to content

Commit

Permalink
Feature: Smart STK Field Types (#1209)
Browse files Browse the repository at this point in the history
* Add prototype for smartptr

* Attempt some new things

* Working concept on device

* clean some things up

* Start working on the actual class implementation

* Make unit test fixture

* Generalize code a little

* Trying a new design

* Redesign again: templates for space and access type

* Add more tests

* Add tests to ensure sync count is correct

* Rename file

* Add some fixes for host

* Style

* Change scope to access and create creator obj

* Add host specialization for bucket loops

* Add host READ access overloads

* Format

* Refactor for three different MEMSPACE's

* Update some comments

* Style

* Passing device tests

* Style

* Rename to SmartField and add explicit instantiation

* Code comments and things

* Tweaks

* Prep FieldManager for interface

* Add some partial template specializations

* FieldManager interface

* Test FieldManager iFace

* Style

* Style

* Start using, and improve interface for legacy case

* More conversions in unit-tests

* Style

* Add additional accessor functions

* Fix unit tests

* Split classes

* Test on device and format

* Add some convenience type names
  • Loading branch information
psakievich committed Sep 27, 2023
1 parent caf36a5 commit 73c36ff
Show file tree
Hide file tree
Showing 12 changed files with 807 additions and 94 deletions.
4 changes: 2 additions & 2 deletions include/ElemDataRequestsGPU.h
Expand Up @@ -258,7 +258,7 @@ auto&
ElemDataRequestsGPU::get_coord_ptr(const T& fieldMgr, const U& iter) const
{
if constexpr (std::is_same_v<T, nalu::FieldManager>)
return fieldMgr.get_ngp_field_ptr(iter.second->name());
return fieldMgr.template get_ngp_field_ptr<double>(iter.second->name());
else
return fieldMgr.template get_field<double>(
iter.second->mesh_meta_data_ordinal());
Expand Down Expand Up @@ -297,7 +297,7 @@ ElemDataRequestsGPU::get_field_ptr(
const T& fieldMgr, const FieldInfo& finfo) const
{
if constexpr (std::is_same_v<T, nalu::FieldManager>)
return fieldMgr.get_ngp_field_ptr(finfo.field->name());
return fieldMgr.template get_ngp_field_ptr<double>(finfo.field->name());
else
return fieldMgr.template get_field<double>(
finfo.field->mesh_meta_data_ordinal());
Expand Down
42 changes: 30 additions & 12 deletions include/FieldManager.h
Expand Up @@ -11,19 +11,17 @@
#define FIELDMANAGER_H_

#include "FieldRegistry.h"
#include "SmartField.h"
#include <stk_mesh/base/FieldState.hpp>
#include "stk_mesh/base/GetNgpField.hpp"
#include <string>
#include <type_traits>

namespace stk {
namespace mesh {
namespace stk::mesh {
class MetaData;
}
} // namespace stk
} // namespace stk::mesh

namespace sierra {
namespace nalu {
namespace sierra::nalu {

class FieldManager
{
Expand Down Expand Up @@ -132,23 +130,43 @@ class FieldManager

/// Given the named field that has already been registered on the CPU
/// return the GPU version of the same field.
stk::mesh::NgpField<double>& get_ngp_field_ptr(std::string name) const
template <typename T>
stk::mesh::NgpField<T>& get_ngp_field_ptr(
std::string name,
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone) const
{
FieldDefTypes fieldDef =
FieldRegistry::query(numDimensions_, numStates_, name);
const stk::mesh::FieldBase& stkField = std::visit(
[&](auto def) -> stk::mesh::FieldBase& {
return meta_
.get_field<typename decltype(def)::FieldType>(def.rank, name)
->field_of_state(stk::mesh::FieldState::StateNone);
->field_of_state(state);
},
fieldDef);
stk::mesh::NgpField<double>& tmp =
stk::mesh::get_updated_ngp_field<double>(stkField);
stk::mesh::NgpField<T>& tmp = stk::mesh::get_updated_ngp_field<T>(stkField);
return tmp;
}

template <typename T, typename ACCESS>
SmartField<stk::mesh::NgpField<T>, tags::DEVICE, ACCESS>
get_device_smart_field(
std::string name,
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone) const
{
return MakeSmartField<tags::DEVICE, ACCESS>().template operator()<T>(
get_ngp_field_ptr<T>(name, state));
}

template <typename T, typename ACCESS>
SmartField<T, tags::LEGACY, ACCESS> get_legacy_smart_field(
std::string name,
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone) const
{
return MakeSmartField<tags::LEGACY, ACCESS>().template operator()<T>(
get_field_ptr<T>(name, state));
}
};
} // namespace nalu
} // namespace sierra
} // namespace sierra::nalu

#endif /* FIELDMANAGER_H_ */

0 comments on commit 73c36ff

Please sign in to comment.