Skip to content

Commit

Permalink
Bug fix and inplace operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Jul 14, 2023
1 parent cb184f8 commit b017099
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 6 deletions.
100 changes: 100 additions & 0 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,36 @@ namespace librapid {
/// \param value The value to write to the array's storage
LIBRAPID_ALWAYS_INLINE void write(size_t index, const Scalar &value);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator+=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator-=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator*=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator/=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator%=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator&=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator|=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator^=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator<<=(const T &other);

template<typename T>
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator>>=(const T &other);

/// \brief Return an iterator to the beginning of the array container
/// \return Iterator
LIBRAPID_INLINE Iterator begin() const noexcept;
Expand Down Expand Up @@ -635,6 +665,76 @@ namespace librapid {
m_storage[index] = value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator+=(const T &value)
-> ArrayContainer & {
*this = *this + value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator-=(const T &value)
-> ArrayContainer & {
*this = *this - value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator*=(const T &value)
-> ArrayContainer & {
*this = *this * value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator/=(const T &value)
-> ArrayContainer & {
*this = *this / value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator%=(const T &value)
-> ArrayContainer & {
*this = *this % value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator&=(const T &value)
-> ArrayContainer & {
*this = *this & value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator|=(const T &value)
-> ArrayContainer & {
*this = *this | value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator^=(const T &value)
-> ArrayContainer & {
*this = *this ^ value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator<<=(const T &value)
-> ArrayContainer & {
*this = *this << value;
}

template<typename ShapeType_, typename StorageType_>
template<typename T>
auto ArrayContainer<ShapeType_, StorageType_>::operator>>=(const T &value)
-> ArrayContainer & {
*this = *this >> value;
}

template<typename ShapeType_, typename StorageType_>
auto ArrayContainer<ShapeType_, StorageType_>::begin() const noexcept -> Iterator {
return Iterator(ArrayView(*this), 0);
Expand Down
8 changes: 2 additions & 6 deletions librapid/include/librapid/array/linalg/arrayMultiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ namespace librapid {

return MatmulClass::GEMV;
} else if (shapeA.ndim() == 2 && shapeB.ndim() == 1) {
LIBRAPID_ASSERT(shapeA[int(m_transA)] == shapeB[0],
LIBRAPID_ASSERT(shapeA[int(!m_transA)] == shapeB[0],
"Rows of OP(A) must match elements of B. Expected: {} -- Got: {}",
shapeA[int(m_transA)],
shapeB[0]);
Expand Down Expand Up @@ -326,11 +326,7 @@ namespace librapid {
return {1};
}
case MatmulClass::GEMV: {
if (shapeA.ndim() == 1) {
return {shapeA[0]};
} else {
return {shapeA[int(!m_transA)]};
}
return {shapeA[int(m_transA)]};
}
case MatmulClass::GEMM: {
return {shapeA[int(m_transA)], shapeB[int(!m_transB)]};
Expand Down

0 comments on commit b017099

Please sign in to comment.