Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ struct PassThrough
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

template <typename LowIdx, typename UpIdx>
__host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up)
__host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up)
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
Expand Down Expand Up @@ -1708,7 +1709,8 @@ struct Vectorize
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

template <typename LowIdx, typename UpIdx>
__host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
Expand Down
265 changes: 265 additions & 0 deletions composable_kernel/include/tensor_description/static_tensor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP

#include "ignore.hpp"

namespace ck {

// StaticTensor for Scalar
template <AddressSpaceEnum_t AddressSpace,
typename T,
typename TensorDesc,
bool InvalidElementUseNumericalZeroValue,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensor
{
static constexpr auto desc_ = TensorDesc{};
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

__host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}

__host__ __device__ constexpr StaticTensor(T invalid_element_value)
: invalid_element_value_{invalid_element_value}
{
}

// read access
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr const T& operator[](Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

if constexpr(is_valid)
{
return data_[Number<offset>{}];
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return T{0};
}
else
{
return invalid_element_value_;
}
}
}

// write access
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr T& operator()(Idx)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

if constexpr(is_valid)
{
return data_(Number<offset>{});
}
else
{
return ignore;
}
}

StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
T invalid_element_value_ = T{0};
};

// StaticTensor for vector
template <AddressSpaceEnum_t AddressSpace,
typename S,
index_t ScalarPerVector,
typename TensorDesc,
bool InvalidElementUseNumericalZeroValue,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensorTupleOfVectorBuffer
{
static constexpr auto desc_ = TensorDesc{};
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

static constexpr index_t num_of_vector_ =
math::integer_divide_ceil(element_space_size_, ScalarPerVector);

using V = vector_type<S, ScalarPerVector>;

__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}

__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
: invalid_element_value_{invalid_element_value}
{
}

// Get S
// Idx is for S, not V
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr const S& operator[](Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

if constexpr(is_valid)
{
return data_[Number<offset>{}];
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return S{0};
}
else
{
return invalid_element_value_;
}
}
}

// Set S
// Idx is for S, not V
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr S& operator()(Idx)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

if constexpr(is_valid)
{
return data_(Number<offset>{});
}
else
{
return ignore;
}
}

// Get X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

if constexpr(is_valid)
{
return data_.template GetAsType<X>(Number<offset>{});
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
// TODO: is this right way to initialize a vector?
return X{0};
}
else
{
// TODO: is this right way to initialize a vector?
return X{invalid_element_value_};
}
}
}

// Set X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

if constexpr(is_valid)
{
data_.template SetAsType<X>(Number<offset>{}, x);
}
}

// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
template <typename Idx>
__host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

return data_.GetVectorTypeReference(Number<offset>{});
}

// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
template <typename Idx>
__host__ __device__ constexpr V& GetVectorTypeReference(Idx)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

constexpr index_t offset = coord.GetOffset();

return data_.GetVectorTypeReference(Number<offset>{});
}

StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
S invalid_element_value_ = S{0};
};

template <AddressSpaceEnum_t AddressSpace,
typename T,
typename TensorDesc,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
{
return StaticTensor<AddressSpace, T, TensorDesc, true>{};
}

template <
AddressSpaceEnum_t AddressSpace,
typename T,
typename TensorDesc,
typename X,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
{
return StaticTensor<AddressSpace, T, TensorDesc, true>{invalid_element_value};
}

} // namespace ck
#endif
14 changes: 14 additions & 0 deletions composable_kernel/include/tensor_description/tensor_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ struct TensorAdaptor

__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

#if 0 // debug
template <index_t I>
__host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
{
// TODO: not implemented
}

template <index_t I>
__host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
{
// TODO: not implemented
}
#endif

template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, 16>,
MRepeat * NRepeat,
true>
c_thread_buf_;

__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer_v3r2.hpp"

namespace ck {

Expand Down Expand Up @@ -146,22 +146,22 @@ struct BlockwiseTensorSliceTransfer_v4
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;

ThreadwiseTransfer threadwise_transfer_;
};
Expand Down
Loading