Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ enable_cppcheck(
)

add_subdirectory(host)
add_subdirectory(device_operation)
add_subdirectory(example)
add_subdirectory(profiler)
add_subdirectory(test)
195 changes: 58 additions & 137 deletions composable_kernel/include/tensor_operation/element_wise_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,178 +7,99 @@ namespace element_wise {

// Identity element-wise operation: writes the source value into the
// destination unchanged. Overloads are enumerated per supported type
// (rather than a template) so that unsupported types fail at overload
// resolution instead of instantiating silently.
struct PassThrough
{
    __host__ __device__ void operator()(float& y, const float& x) const
    {
        y = x;
    }

    __host__ __device__ void operator()(half_t& y, const half_t& x) const
    {
        y = x;
    }
};

// Fused add + ReLU: y = max(x0 + x1, 0), with the max written as a
// conditional expression so no library/device max intrinsic is needed.
struct AddRelu
{
    __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
    {
        const float sum = x0 + x1;
        y = sum > 0 ? sum : 0;
    }

    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x0, const half_t& x1) const
    {
        const half_t sum = x0 + x1;
        y = sum > 0 ? sum : 0;
    }
};

// TODO remove this
template <typename T1>
__device__ constexpr float operator()(float v0, T1 v1) const
// Fused add + hardswish: with s = x0 + x1,
//   y = s * relu6(s + 3) * (1/6)
// where 0.166667f approximates 1/6 and (b > 0) is a branchless
// bool->float factor that zeroes the result when s + 3 <= 0.
struct AddHardswish
{
    __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
    {
        const float a = x0 + x1;
        const float b = a + 3.0f;
        // FIX: use f-suffixed literals instead of float{0.166667}:
        // brace-initializing a float from a double constant that is not
        // exactly representable is a narrowing conversion and ill-formed.
        // 0.166667f has the identical value the old code produced.
        const float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
        y = c;
    }

    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x0, const half_t& x1) const
    {
        // Intermediate math is carried out in float; the result narrows
        // back to half_t on the final store (as in the float overload,
        // but with half_t inputs).
        const float a = x0 + x1;
        const float b = a + 3.0f;
        const float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
        y = c;
    }
};

// Fused add + ReLU + add: y = max(x0 + x1, 0) + x2.
// A mixed-precision overload (float accumulator x0 with half_t x1/x2)
// is provided in addition to the uniform float and half_t forms.
struct AddReluAdd
{
    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
    {
        const half_t sum01     = x0 + x1;
        const half_t rectified = sum01 > 0 ? sum01 : 0;
        y                      = rectified + x2;
    }

    __host__ __device__ constexpr void
    operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        const float sum01     = x0 + x1;
        const float rectified = sum01 > 0 ? sum01 : 0;
        const float out       = rectified + x2;
        y                     = out;
    }

    __host__ __device__ constexpr void
    operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
    {
        // Accumulate in float, then narrow to half_t on the final store.
        const float sum01     = x0 + x1;
        const float rectified = sum01 > 0 ? sum01 : 0;
        const float out       = rectified + x2;
        y                     = out;
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct AddLeakyReluAdd
// Fused add + hardswish + add: with s = x0 + x1,
//   y = s * relu6(s + 3) * (1/6) + x2
// where 0.166667f approximates 1/6 and (b > 0) is a branchless
// bool->float factor that zeroes the activation when s + 3 <= 0.
struct AddHardswishAdd
{
    __host__ __device__ constexpr void
    operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        const float a = x0 + x1;
        const float b = a + 3.0f;
        // FIX: use f-suffixed literals instead of float{0.166667}:
        // brace-initializing a float from a double constant that is not
        // exactly representable is a narrowing conversion and ill-formed.
        // 0.166667f has the identical value the old code produced.
        const float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
        const float d = c + x2;
        y = d;
    }

    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
    {
        // Intermediate math is carried out in float; the result narrows
        // back to half_t on the final store.
        const float a = x0 + x1;
        const float b = a + 3.0f;
        const float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
        const float d = c + x2;
        y = d;
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);

// apply element-wise operation and type convert
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>(dst_element_op_(src_buf[Number<src_offset>{}]));
SrcData dst_v;

// apply element-wise operation
dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);

// apply type convert
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
});

const bool is_dst_valid =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ struct ThreadwiseTensorSliceTransfer_v1r4
dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
#else
// apply element-wise operation in DstData type
const DstData dst_v = dst_element_op_(src_v, dst0_v, dst1_v);
DstData dst_v;

dst_element_op_(dst_v, src_v, dst0_v, dst1_v);

dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1

// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
src_vector_container.template AsType<SrcData>()(i) =
src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
SrcData src_v;

src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);

src_vector_container.template AsType<SrcData>()(i) = src_v;
});

// copy data from src_vector_container into src_thread_scratch_
Expand Down Expand Up @@ -452,10 +455,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};

// apply DstElementwiseOperation on dst_vector_container
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
dst_vector_container.template AsType<DstData>()(i) =
dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
DstData dst_v;

// apply DstElementwiseOperation
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);

dst_vector_container.template AsType<DstData>()(i) = dst_v;
});

// copy data from dst_vector_container to dst_buf
Expand Down
Loading