Skip to content

Commit

Permalink
MSPtr counter takes on ownership of the pointer record creation and l…
Browse files Browse the repository at this point in the history
…ifetime.
  • Loading branch information
mdavis36 committed Apr 9, 2024
1 parent c18723a commit 75749a2
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 175 deletions.
156 changes: 6 additions & 150 deletions src/chai/ManagedSharedPtr.hpp
Expand Up @@ -12,117 +12,6 @@ namespace chai {



class msp_counted_base {
public:
msp_counted_base() noexcept : m_use_count(1) {}

virtual ~msp_counted_base() noexcept {}

virtual void m_dispose() noexcept = 0;
virtual void m_destroy() noexcept { delete this; }

void m_add_ref_copy() noexcept { ++m_use_count; }

void m_release() noexcept {
if(--m_use_count == 0) {
m_dispose();
m_destroy();
}
}

long m_get_use_count() const noexcept { return m_use_count; }
private:
msp_counted_base(msp_counted_base const&) = delete;
msp_counted_base& operator=(msp_counted_base const&) = delete;

long m_use_count = 0;
};

template<typename Ptr, typename Record>
class msp_counted_ptr final : public msp_counted_base {
public:
msp_counted_ptr(Record p) noexcept : m_record(p) {}
//virtual void m_dispose() noexcept { delete (m_record.get_pointer(chai::CPU)); }// TODO : Other Exec spaces...
virtual void m_dispose() noexcept { delete (Ptr)m_record->m_pointers[chai::CPU]; }// TODO : Other Exec spaces...
virtual void m_destroy() noexcept { delete this; }
msp_counted_ptr(msp_counted_ptr const&) = delete;
msp_counted_ptr& operator=(msp_counted_ptr const&) = delete;
private:
Record m_record;
};

template<typename Ptr, typename Record, typename Deleter>
class msp_counted_deleter final : public msp_counted_base {

class impl {
public:
impl(Record p, Deleter d) : m_record(p), m_deleter(std::move(d)) {}
Deleter& m_del() noexcept { return m_deleter; }
Record m_record;
Deleter m_deleter;
};

public:
msp_counted_deleter(Record p, Deleter d) noexcept : m_impl(p, std::move(d)) {}
virtual void m_dispose() noexcept {
printf("Delete GPU Memory Here...\n");
m_impl.m_del()((Ptr)m_impl.m_record->m_pointers[chai::CPU]);
}
virtual void m_destroy() noexcept { this->~msp_counted_deleter(); }
msp_counted_deleter(msp_counted_deleter const&) = delete;
msp_counted_deleter& operator=(msp_counted_deleter const&) = delete;
private:
impl m_impl;
};


class msp_shared_count {
public:
constexpr msp_shared_count() noexcept : m_pi(0) {}

template<typename Ptr, typename Record>
explicit msp_shared_count(Ptr, Record p)
: m_pi( new msp_counted_ptr<Ptr, Record>(p) ) {}

template<typename Ptr, typename Record, typename Deleter>
explicit msp_shared_count(Ptr, Record p, Deleter d)
: m_pi( new msp_counted_deleter<Ptr, Record, Deleter>(p, d) ) {}

~msp_shared_count() noexcept
{ if (m_pi) m_pi->m_release(); }

msp_shared_count(msp_shared_count const& rhs) noexcept : m_pi(rhs.m_pi)
{ if (m_pi) m_pi->m_add_ref_copy(); }

msp_shared_count& operator=(msp_shared_count const& rhs) noexcept {
msp_counted_base* temp = rhs.m_pi;
if (temp != m_pi)
{
if (temp) temp->m_add_ref_copy();
if (m_pi) m_pi->m_release();
m_pi = temp;
}
return *this;
}

void m_swap(msp_shared_count& rhs) noexcept {
msp_counted_base* temp = rhs.m_pi;
rhs.m_pi = m_pi;
m_pi = temp;
}

long m_get_use_count() const noexcept
{ return m_pi ? m_pi->m_get_use_count() : 0; }

friend inline bool
operator==(msp_shared_count const& a, msp_shared_count const& b) noexcept
{ return a.m_pi == b.m_pi; }

msp_counted_base* m_pi;

};





Expand Down Expand Up @@ -166,39 +55,13 @@ class ManagedSharedPtr {
/*
* Constructors
*/
constexpr ManagedSharedPtr() noexcept : m_ref_count() {}
constexpr ManagedSharedPtr() noexcept : m_record_count() {}

//// *Default* Ctor with convertible type Yp -> Tp
//template<typename Yp, typename = SafeConv<Yp>>
//explicit ManagedSharedPtr(Yp* host_p) :
// m_pointer_record(new msp_pointer_record<Tp>(host_p)),
// m_ref_count(host_p, m_pointer_record),
// m_active_pointer(static_cast<Yp*>(m_pointer_record->m_pointers[chai::CPU]))
// //m_resource_manager(SharedPtrManager::getInstance())
//{}

//template<typename Yp, typename = SafeConv<Yp>>
//explicit ManagedSharedPtr(Yp* host_p, Yp* device_p) :
// m_pointer_record(new msp_pointer_record<Yp>(host_p, device_p)),
// m_ref_count(host_p, m_pointer_record),
// m_active_pointer(static_cast<Yp*>(m_pointer_record->m_pointers[chai::CPU]))
// //m_resource_manager(SharedPtrManager::getInstance())
//{}

//template<typename Yp, typename Deleter, typename = SafeConv<Yp>>
//ManagedSharedPtr(Yp* host_p, Deleter d) :
// m_pointer_record(new msp_pointer_record<Yp>(host_p)),
// m_ref_count(host_p, m_pointer_record, std::move(d)),
// m_active_pointer(static_cast<Yp*>(m_pointer_record->m_pointers[chai::CPU]))
// //m_resource_manager(SharedPtrManager::getInstance())
//{}

template<typename Yp, typename Deleter, typename = SafeConv<Yp>>
ManagedSharedPtr(Yp* host_p, Yp* device_p, Deleter d) :
m_pointer_record(new msp_pointer_record<Yp>(host_p, device_p)),
m_ref_count(host_p, m_pointer_record, std::move(d)),
m_active_pointer(static_cast<Yp*>(m_pointer_record->m_pointers[chai::CPU]))
//m_resource_manager(SharedPtrManager::getInstance())
m_record_count(host_p, device_p, std::move(d)),
m_active_pointer(m_record_count.getPointer<Yp>(chai::CPU))
{}

/*
Expand All @@ -208,11 +71,9 @@ class ManagedSharedPtr {

template<typename Yp, typename = Compatible<Yp>>
ManagedSharedPtr(ManagedSharedPtr<Yp> const& rhs) noexcept :
m_ref_count(rhs.m_ref_count),
m_record_count(rhs.m_record_count),
m_active_pointer(rhs.m_active_pointer)
{
// TODO : Is this safe??
m_pointer_record = reinterpret_cast<msp_pointer_record<Tp>*>(rhs.m_pointer_record);
}


Expand All @@ -230,7 +91,7 @@ class ManagedSharedPtr {


public:
long use_count() const noexcept { return m_ref_count.m_get_use_count(); }
long use_count() const noexcept { return m_record_count.m_get_use_count(); }

/*
* Private Members
Expand All @@ -239,13 +100,8 @@ class ManagedSharedPtr {
template<typename Tp1>
friend class ManagedSharedPtr;

//template<typename Yp, typename... Args>
//friend ManagedSharedPtr<Yp> make_managed(Args... args);

mutable msp_pointer_record<Tp>* m_pointer_record = nullptr;
msp_shared_count m_ref_count;
msp_record_count m_record_count;
mutable element_type* m_active_pointer = nullptr;

//mutable SharedPtrManager* m_resource_manager = nullptr;
};

Expand Down
144 changes: 121 additions & 23 deletions src/chai/SharedPointerRecord.hpp
Expand Up @@ -19,7 +19,7 @@ namespace chai
/*!
* \brief Struct holding details about each pointer.
*/
template<typename Tp>
//template<typename Tp>
struct msp_pointer_record {

// Using NUM_EXECUTION_SPACES for the time being, this will help with logical
Expand All @@ -35,18 +35,6 @@ struct msp_pointer_record {

int m_allocators[NUM_EXECUTION_SPACES];

//template<typename Yp>
//msp_pointer_record(Yp* host_p = nullptr, Yp* device_p = nullptr) : m_last_space(NONE) {
// for (int space = 0; space < NUM_EXECUTION_SPACES; ++space ) {
// m_pointers[space] = nullptr;
// m_touched[space] = false;
// m_owned[space] = true;
// m_allocators[space] = 0;
// }
// m_pointers[CPU] = host_p;
// m_pointers[GPU] = device_p;
//}


msp_pointer_record(void* host_p = nullptr, void* device_p = nullptr) : m_last_space(NONE) {
for (int space = 0; space < NUM_EXECUTION_SPACES; ++space ) {
Expand All @@ -59,23 +47,133 @@ struct msp_pointer_record {
m_pointers[GPU] = device_p;
}

//Tp* get_pointer(ExecutionSpace space) noexcept { return m_pointers[space]; }
//template<typename Yp>
//msp_pointer_record(msp_pointer_record<Yp> const& rhs) :
// m_pointers(rhs.m_pointers),
// m_touched(rhs.m_touched),
// m_owned(rhs.m_owned),
// m_last_space(rhs.m_last_space),
// m_allocators(rhs.m_allocators)
//{}
};


class msp_counted_base {
public:
msp_counted_base() noexcept : m_use_count(1) {}

virtual ~msp_counted_base() noexcept {}

virtual void m_dispose() noexcept = 0;
virtual void m_destroy() noexcept { delete this; }

void m_add_ref_copy() noexcept { ++m_use_count; }

void m_release() noexcept {
if(--m_use_count == 0) {
m_dispose();
m_destroy();
}
}

long m_get_use_count() const noexcept { return m_use_count; }

virtual msp_pointer_record& getPointerRecord() noexcept = 0;

private:
msp_counted_base(msp_counted_base const&) = delete;
msp_counted_base& operator=(msp_counted_base const&) = delete;

long m_use_count = 0;
};

template<typename Ptr>
class msp_counted_ptr final : public msp_counted_base {
public:
msp_counted_ptr(Ptr h_p, Ptr d_p) noexcept : m_record(h_p, d_p) {}
virtual void m_dispose() noexcept { delete (Ptr)m_record.m_pointers[chai::CPU]; }// TODO : Other Exec spaces...
virtual void m_destroy() noexcept { delete this; }
msp_counted_ptr(msp_counted_ptr const&) = delete;
msp_counted_ptr& operator=(msp_counted_ptr const&) = delete;

msp_pointer_record& getPointerRecord() noexcept { return m_record; }
private:
msp_pointer_record m_record;
};

template<typename Ptr, typename Deleter>
class msp_counted_deleter final : public msp_counted_base {

class impl {
public:
impl(Ptr h_p, Ptr d_p, Deleter d) : m_record(h_p, d_p), m_deleter(std::move(d)) {}
Deleter& m_del() noexcept { return m_deleter; }
msp_pointer_record m_record;
Deleter m_deleter;
};

public:
msp_counted_deleter(Ptr h_p, Ptr d_p, Deleter d) noexcept : m_impl(h_p, d_p, std::move(d)) {}
virtual void m_dispose() noexcept {
printf("Delete GPU Memory Here...\n");
m_impl.m_del()((Ptr)m_impl.m_record.m_pointers[chai::CPU]);
}
virtual void m_destroy() noexcept { this->~msp_counted_deleter(); }
msp_counted_deleter(msp_counted_deleter const&) = delete;
msp_counted_deleter& operator=(msp_counted_deleter const&) = delete;

msp_pointer_record& getPointerRecord() noexcept { return m_impl.m_record; }
private:
impl m_impl;
};

} // end of namespace chai

class msp_record_count {
public:
constexpr msp_record_count() noexcept : m_pi(0) {}

template<typename Ptr>
explicit msp_record_count(Ptr h_p, Ptr d_p)
: m_pi( new msp_counted_ptr<Ptr>(h_p, d_p) ) {}

template<typename Ptr, typename Deleter>
explicit msp_record_count(Ptr h_p, Ptr d_p, Deleter d)
: m_pi( new msp_counted_deleter<Ptr, Deleter>(h_p, d_p, d) ) {}

~msp_record_count() noexcept
{ if (m_pi) m_pi->m_release(); }

msp_record_count(msp_record_count const& rhs) noexcept : m_pi(rhs.m_pi)
{ if (m_pi) m_pi->m_add_ref_copy(); }

msp_record_count& operator=(msp_record_count const& rhs) noexcept {
msp_counted_base* temp = rhs.m_pi;
if (temp != m_pi)
{
if (temp) temp->m_add_ref_copy();
if (m_pi) m_pi->m_release();
m_pi = temp;
}
return *this;
}

void m_swap(msp_record_count& rhs) noexcept {
msp_counted_base* temp = rhs.m_pi;
rhs.m_pi = m_pi;
m_pi = temp;
}

long m_get_use_count() const noexcept
{ return m_pi ? m_pi->m_get_use_count() : 0; }

friend inline bool
operator==(msp_record_count const& a, msp_record_count const& b) noexcept
{ return a.m_pi == b.m_pi; }

msp_pointer_record& getPointerRecord() noexcept { return m_pi->getPointerRecord(); }

template<typename Ptr>
Ptr* getPointer(chai::ExecutionSpace space) noexcept { return static_cast<Ptr*>(getPointerRecord().m_pointers[space]); }

msp_counted_base* m_pi;

};





} // end of namespace chai
#endif // CHAI_SharedPointerRecord_HPP
4 changes: 2 additions & 2 deletions tests/integration/managed_ptr_tests.cpp
Expand Up @@ -172,12 +172,12 @@ TEST(managed_ptr, shared_ptr)
//chai::ManagedSharedPtr<TestBase> sptr(new TestDerived());

//chai::ManagedSharedPtr<TestDerived> sptr = chai::make_shared<TestDerived>();
chai::ManagedSharedPtr<TestDerived> sptr = chai::make_shared_deleter<TestDerived>(
chai::ManagedSharedPtr<TestBase> sptr = chai::make_shared_deleter<TestDerived>(
[](TestDerived* p){ printf("Custom Deleter Call\n"); p->~TestDerived(); });

std::cout << "use_count : " << sptr.use_count() << std::endl;

auto sptr2 = sptr;
chai::ManagedSharedPtr<TestBase> sptr2 = sptr;
sptr2->doSomething();
std::cout << "use_count : " << sptr.use_count() << std::endl;

Expand Down

0 comments on commit 75749a2

Please sign in to comment.