diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 005e15969a88..4428642b281d 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -17,11 +17,132 @@ * under the License. */ -#include -#include #include #include #include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::runtime; + +class TestErrorSwitch { + public: + // Need this so that destructor of temporary objects don't interrupt our + // testing. + TestErrorSwitch(const TestErrorSwitch& other) + : should_fail(other.should_fail) { + const_cast(other).should_fail = false; + } + + TestErrorSwitch(bool fail_flag) : should_fail{fail_flag} {} + bool should_fail{false}; + + ~TestErrorSwitch() { + if (should_fail) { + exit(1); + } + } +}; + +class TestArrayObj : public Object, + public InplaceArrayBase { + public: + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "test.TestArrayObj"; + TVM_DECLARE_FINAL_OBJECT_INFO(TestArrayObj, Object); + uint32_t size; + + size_t GetSize() const { return size; } + + template + void Init(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = 0; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + if (i == 1) { + throw std::bad_alloc(); + } + // Only increment size after the initialization succeeds + this->size++; + } + } + + template + void WrongInit(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = num_elems; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + if (i == 1) { + throw std::bad_alloc(); + } + } + } + + friend class InplaceArrayBase; +}; + +TEST(ADT, Constructor) { + std::vector fields; + auto f1 = ADT::Tuple(fields); + auto f2 = ADT::Tuple(fields); + ADT v1{1, {f1, f2}}; + ASSERT_EQ(f1.tag(), 0); + ASSERT_EQ(f2.size(), 0); + ASSERT_EQ(v1.tag(), 1); + ASSERT_EQ(v1.size(), 2); + ASSERT_EQ(Downcast(v1[0]).tag(), 0); + ASSERT_EQ(Downcast(v1[1]).size(), 0); +} + +TEST(InplaceArrayBase, BadExceptionSafety) { + auto wrong_init = []() { + TestErrorSwitch f1{false}; + // WrongInit will set size to 3 so it will call destructor at index 1, which + // will exit with error status. + TestErrorSwitch f2{true}; + TestErrorSwitch f3{false}; + std::vector fields{f1, f2, f3}; + auto ptr = + make_inplace_array_object(fields.size()); + try { + ptr->WrongInit(fields.begin(), fields.end()); + } catch (...) { + } + // Call ~InplaceArrayBase + ptr.reset(); + // never reaches here. + exit(0); + }; + ASSERT_EXIT(wrong_init(), ::testing::ExitedWithCode(1), ""); +} + +TEST(InplaceArrayBase, ExceptionSafety) { + auto correct_init = []() { + TestErrorSwitch f1{false}; + // Init will fail at index 1, so destrucotr at index 1 should not be called + // since it's not initalized. + TestErrorSwitch f2{true}; + std::vector fields{f1, f2}; + auto ptr = + make_inplace_array_object(fields.size()); + try { + ptr->Init(fields.begin(), fields.end()); + } catch (...) { + } + // Call ~InplaceArrayBase + ptr.reset(); + // Skip the destructors of f1, f2, and fields + exit(0); + }; + ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), ""); +} TEST(Array, Expr) { using namespace tvm; @@ -99,11 +220,12 @@ TEST(Map, Iterator) { using namespace tvm; Expr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2(map1.begin(), map1.end()); + std::unordered_map map2(map1.begin(), + map1.end()); CHECK(map2[a].as()->value == 2); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS();