Skip to content

Commit

Permalink
Add IndexType (#139)
Browse files Browse the repository at this point in the history
* Add IndexType

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add more BoxArray features

* Add checking of index values

* Bug fix for flip

* Add tests/test_indextype.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Clean up in test_indextype.py

* Remove unneeded const

* More clean up

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dpgrote and pre-commit-ci[bot] committed Jul 5, 2023
1 parent 1195802 commit 0354504
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ foreach(D IN LISTS AMReX_SPACEDIM)
DistributionMapping.cpp
FArrayBox.cpp
Geometry.cpp
IndexType.cpp
IntVect.cpp
RealVect.cpp
MultiFab.cpp
Expand Down
123 changes: 123 additions & 0 deletions src/Base/IndexType.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/* Copyright 2021-2022 The AMReX Community
*
* Authors: David Grote
* License: BSD-3-Clause-LBNL
*/
#include <AMReX_Config.H>
#include <AMReX_Dim3.H>
#include <AMReX_IntVect.H>
#include <AMReX_IndexType.H>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <array>
#include <sstream>
#include <string>

namespace py = pybind11;
using namespace amrex;

namespace {
int check_index(const int i)
{
const int ii = (i >= 0) ? i : AMREX_SPACEDIM + i;
if ((ii < 0) || (ii >= AMREX_SPACEDIM))
throw py::index_error( "IndexType index " + std::to_string(i) + " out of bounds");
return ii;
}
}

void init_IndexType(py::module &m) {
py::class_< IndexType > index_type(m, "IndexType");
index_type.def("__repr__",
[](py::object& obj) {
py::str py_name = obj.attr("__class__").attr("__name__");
const std::string name = py_name;
const auto iv = obj.cast<IndexType>();
std::stringstream s;
s << iv;
return "<amrex." + name + " " + s.str() + ">";
}
)
.def("__str",
[](const IndexType& iv) {
std::stringstream s;
s << iv;
return s.str();
})

.def(py::init<>())
.def(py::init<IndexType>())
#if (AMREX_SPACEDIM > 1)
.def(py::init<AMREX_D_DECL(IndexType::CellIndex, IndexType::CellIndex, IndexType::CellIndex)>())
#endif

.def("__getitem__",
[](const IndexType& v, const int i) {
const int ii = check_index(i);
return v[ii];
})

.def("__len__", [](IndexType const &) { return AMREX_SPACEDIM; })
.def("__eq__",
py::overload_cast<const IndexType&>(&IndexType::operator==, py::const_))
.def("__ne__",
py::overload_cast<const IndexType&>(&IndexType::operator!=, py::const_))
.def("__lt__", &IndexType::operator<)

.def("set", [](IndexType& v, int i) {
const int ii = check_index(i);
v.set(ii);
})
.def("unset", [](IndexType& v, int i) {
const int ii = check_index(i);
v.unset(ii);
})
.def("test", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.test(ii);
})
.def("setall", &IndexType::setall)
.def("clear", &IndexType::clear)
.def("any", &IndexType::any)
.def("ok", &IndexType::ok)
.def("flip", [](IndexType& v, int i) {
const int ii = check_index(i);
v.flip(ii);
})

.def("cell_centered", py::overload_cast<>(&IndexType::cellCentered, py::const_))
.def("cell_centered", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.cellCentered(ii);
})
.def("node_centered", py::overload_cast<>(&IndexType::nodeCentered, py::const_))
.def("node_centered", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.nodeCentered(ii);
})

.def("set_type", [](IndexType& v, int i, IndexType::CellIndex t) {
const int ii = check_index(i);
v.setType(ii, t);
})
.def("ix_type", py::overload_cast<>(&IndexType::ixType, py::const_))
.def("ix_type", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.ixType(ii);
})
.def("to_IntVect", &IndexType::toIntVect)

.def_static("cell_type", &IndexType::TheCellType)
.def_static("node_type", &IndexType::TheNodeType)

;

py::enum_<IndexType::CellIndex>(index_type, "CellIndex")
.value("CELL", IndexType::CellIndex::CELL)
.value("NODE", IndexType::CellIndex::NODE)
.export_values();

}
3 changes: 3 additions & 0 deletions src/pyAMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void init_Dim3(py::module&);
void init_DistributionMapping(py::module&);
void init_FArrayBox(py::module&);
void init_Geometry(py::module&);
void init_IndexType(py::module &);
void init_IntVect(py::module &);
void init_RealVect(py::module &);
void init_AmrMesh(py::module &);
Expand Down Expand Up @@ -70,6 +71,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) {
Dim3
FArrayBox
IntVect
IndexType
RealVect
MultiFab
ParallelDescriptor
Expand All @@ -88,6 +90,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) {
init_Arena(m);
init_Dim3(m);
init_IntVect(m);
init_IndexType(m);
init_RealVect(m);
init_Periodicity(m);
init_Array4(m);
Expand Down
77 changes: 77 additions & 0 deletions tests/test_indextype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-

import pytest

import amrex.space3d as amr


@pytest.mark.skipif(amr.Config.spacedim != 1, reason="Requires AMREX_SPACEDIM = 1")
def test_indextype_1d():
obj = amr.IndexType(amr.IndexType.CellIndex.NODE)
assert obj.node_centered()
assert not obj.cell_centered()
with pytest.raises(IndexError):
obj[-2]
with pytest.raises(IndexError):
obj[1]


@pytest.mark.skipif(amr.Config.spacedim != 2, reason="Requires AMREX_SPACEDIM = 2")
def test_indextype_2d():
obj = amr.IndexType(amr.IndexType.CellIndex.NODE, amr.IndexType.CellIndex.CELL)
assert obj.node_centered(0)
assert obj.cell_centered(1)
assert obj.node_centered(-2)
assert obj.cell_centered(-1)

with pytest.raises(IndexError):
obj[-3]
with pytest.raises(IndexError):
obj[2]


@pytest.mark.skipif(amr.Config.spacedim != 3, reason="Requires AMREX_SPACEDIM = 3")
def test_indextype_3d():
obj = amr.IndexType(
amr.IndexType.CellIndex.NODE,
amr.IndexType.CellIndex.CELL,
amr.IndexType.CellIndex.NODE,
)

# Check indexing
assert obj.node_centered(0)
assert obj.cell_centered(1)
assert obj.node_centered(2)
assert obj.node_centered(-3)
assert obj.cell_centered(-2)
assert obj.node_centered(-1)
with pytest.raises(IndexError):
obj[-4]
with pytest.raises(IndexError):
obj[3]

# Check methods
obj.set(1)
assert obj.node_centered()
obj.unset(1)
assert not obj.node_centered()


def test_indextype_static():
cell = amr.IndexType.cell_type()
for i in range(amr.Config.spacedim):
assert not cell.test(i)

node = amr.IndexType.node_type()
for i in range(amr.Config.spacedim):
assert node[i]

assert cell == amr.IndexType.cell_type()
assert node == amr.IndexType.node_type()
assert cell < node


def test_indextype_conversions():
node = amr.IndexType.node_type()
assert node.ix_type() == amr.IntVect(1)
assert node.to_IntVect() == amr.IntVect(1)

0 comments on commit 0354504

Please sign in to comment.