Skip to content

Commit

Permalink
Correct storage of TVMByteArray
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Jun 18, 2024
1 parent 08d081b commit 85200c4
Showing 1 changed file with 53 additions and 14 deletions.
67 changes: 53 additions & 14 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,14 @@ class TVMPODValue_ {
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
operator TVMByteArray() const {
if (type_code_ == kTVMBytes) {
return *static_cast<TVMByteArray*>(value_.v_handle);
} else {
LOG(FATAL) << "Expected "
<< "TVMByteArray but got " << ArgTypeCode2Str(type_code_);
}
}

inline operator DLDataType() const;

Expand Down Expand Up @@ -692,6 +700,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator std::string;
using TVMPODValue_::operator TVMByteArray;
using TVMPODValue_::operator DLDataType;
using TVMPODValue_::operator DataType;
using TVMPODValue_::operator DLTensor*;
Expand Down Expand Up @@ -734,6 +743,7 @@ class TVMMovableArgValue_ : public TVMPODValue_ {
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator std::string;
using TVMPODValue_::operator TVMByteArray;
using TVMPODValue_::operator DLDataType;
using TVMPODValue_::operator DataType;
using TVMPODValue_::operator DLTensor*;
Expand Down Expand Up @@ -838,6 +848,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator std::string;
using TVMPODValue_::operator TVMByteArray;
using TVMPODValue_::operator DLDataType;
using TVMPODValue_::operator DataType;
using TVMPODValue_::operator DLTensor*;
Expand Down Expand Up @@ -903,11 +914,49 @@ class TVMRetValue : public TVMPODValue_ {
return *this;
}
TVMRetValue& operator=(std::string value) {
this->SwitchToString(kTVMStr, std::move(value));
this->Clear();

std::string* container = new std::string(std::move(value));
f_deleter_ = [](void* arg) { delete static_cast<std::string*>(arg); };
f_deleter_arg_ = container;

type_code_ = kTVMStr;
value_.v_str = container->c_str();

return *this;
}
TVMRetValue& operator=(TVMByteArray value) {
this->SwitchToString(kTVMBytes, std::string(value.data, value.size));
this->Clear();

/* \brief Container for owned data
*
* For consistency with TVMArgValue, kTVMBytes should store a
* `TVMByteArray*` in `value_.v_handle`. However, `TVMRetValue`
* must own its backing allocation, where `TVMByteArray` does not
* own the data to which it points.
*
* This struct provides both ownership over an allocation, and a
* `TVMByteArray` with a view into the owned allocation.
*/
struct OwnedArray {
OwnedArray(std::vector<char> arg)
: backing_vector(std::move(arg)), array{backing_vector.data(), backing_vector.size()} {}
OwnedArray(const OwnedArray&) = delete;

// The backing allocation
std::vector<char> backing_vector;

// The TVMByteArray, referencing the backing allocation
TVMByteArray array;
};

OwnedArray* container = new OwnedArray(std::vector<char>(value.data, value.data + value.size));
f_deleter_ = [](void* arg) { delete static_cast<OwnedArray*>(arg); };
f_deleter_arg_ = container;

type_code_ = kTVMBytes;
value_.v_handle = &container->array;

return *this;
}
TVMRetValue& operator=(NDArray other) {
Expand Down Expand Up @@ -996,11 +1045,11 @@ class TVMRetValue : public TVMPODValue_ {
void Assign(const T& other) {
switch (other.type_code()) {
case kTVMStr: {
SwitchToString(kTVMStr, other);
*this = other.operator std::string();
break;
}
case kTVMBytes: {
SwitchToString(kTVMBytes, other);
*this = other.operator TVMByteArray();
break;
}
case kTVMPackedFuncHandle: {
Expand Down Expand Up @@ -1039,16 +1088,6 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ = type_code;
}
}
void SwitchToString(int type_code, std::string value) {
this->Clear();

std::string* container = new std::string(std::move(value));
f_deleter_ = [](void* arg) { delete static_cast<std::string*>(arg); };
f_deleter_arg_ = container;

type_code_ = type_code;
value_.v_str = container->c_str();
}
void SwitchToObject(int type_code, ObjectPtr<Object> other) {
if (other.data_ != nullptr) {
this->Clear();
Expand Down

0 comments on commit 85200c4

Please sign in to comment.