Skip to content

Commit

Permalink
Add cpp unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Nov 28, 2019
1 parent 29e7c6a commit 20261ce
Showing 1 changed file with 126 additions and 4 deletions.
130 changes: 126 additions & 4 deletions tests/cpp/container_test.cc
Expand Up @@ -17,11 +17,132 @@
* under the License.
*/

#include <vector>
#include <unordered_map>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/container.h>
#include <new>
#include <unordered_map>
#include <vector>

using namespace tvm;
using namespace tvm::runtime;

class TestErrorSwitch {
public:
// Need this so that destructor of temporary objects don't interrupt our
// testing.
TestErrorSwitch(const TestErrorSwitch& other)
: should_fail(other.should_fail) {
const_cast<TestErrorSwitch&>(other).should_fail = false;
}

TestErrorSwitch(bool fail_flag) : should_fail{fail_flag} {}
bool should_fail{false};

~TestErrorSwitch() {
if (should_fail) {
exit(1);
}
}
};

class TestArrayObj : public Object,
public InplaceArrayBase<TestArrayObj, TestErrorSwitch> {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.TestArrayObj";
TVM_DECLARE_FINAL_OBJECT_INFO(TestArrayObj, Object);
uint32_t size;

size_t GetSize() const { return size; }

template <typename Iterator>
void Init(Iterator begin, Iterator end) {
size_t num_elems = std::distance(begin, end);
this->size = 0;
auto it = begin;
for (size_t i = 0; i < num_elems; ++i) {
InplaceArrayBase::EmplaceInit(i, *it++);
if (i == 1) {
throw std::bad_alloc();
}
// Only increment size after the initialization succeeds
this->size++;
}
}

template <typename Iterator>
void WrongInit(Iterator begin, Iterator end) {
size_t num_elems = std::distance(begin, end);
this->size = num_elems;
auto it = begin;
for (size_t i = 0; i < num_elems; ++i) {
InplaceArrayBase::EmplaceInit(i, *it++);
if (i == 1) {
throw std::bad_alloc();
}
}
}

friend class InplaceArrayBase;
};

TEST(ADT, Constructor) {
std::vector<ObjectRef> fields;
auto f1 = ADT::Tuple(fields);
auto f2 = ADT::Tuple(fields);
ADT v1{1, {f1, f2}};
ASSERT_EQ(f1.tag(), 0);
ASSERT_EQ(f2.size(), 0);
ASSERT_EQ(v1.tag(), 1);
ASSERT_EQ(v1.size(), 2);
ASSERT_EQ(Downcast<ADT>(v1[0]).tag(), 0);
ASSERT_EQ(Downcast<ADT>(v1[1]).size(), 0);
}

TEST(InplaceArrayBase, BadExceptionSafety) {
auto wrong_init = []() {
TestErrorSwitch f1{false};
// WrongInit will set size to 3 so it will call destructor at index 1, which
// will exit with error status.
TestErrorSwitch f2{true};
TestErrorSwitch f3{false};
std::vector<TestErrorSwitch> fields{f1, f2, f3};
auto ptr =
make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
try {
ptr->WrongInit(fields.begin(), fields.end());
} catch (...) {
}
// Call ~InplaceArrayBase
ptr.reset();
// never reaches here.
exit(0);
};
ASSERT_EXIT(wrong_init(), ::testing::ExitedWithCode(1), "");
}

TEST(InplaceArrayBase, ExceptionSafety) {
auto correct_init = []() {
TestErrorSwitch f1{false};
// Init will fail at index 1, so destrucotr at index 1 should not be called
// since it's not initalized.
TestErrorSwitch f2{true};
std::vector<TestErrorSwitch> fields{f1, f2};
auto ptr =
make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
try {
ptr->Init(fields.begin(), fields.end());
} catch (...) {
}
// Call ~InplaceArrayBase
ptr.reset();
// Skip the destructors of f1, f2, and fields
exit(0);
};
ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), "");
}

TEST(Array, Expr) {
using namespace tvm;
Expand Down Expand Up @@ -99,11 +220,12 @@ TEST(Map, Iterator) {
using namespace tvm;
Expr a = 1, b = 2;
Map<Expr, Expr> map1{{a, b}};
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> map2(map1.begin(), map1.end());
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> map2(map1.begin(),
map1.end());
CHECK(map2[a].as<IntImm>()->value == 2);
}

int main(int argc, char ** argv) {
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
Expand Down

0 comments on commit 20261ce

Please sign in to comment.