-
Notifications
You must be signed in to change notification settings - Fork 99
/
wrapper_basic_gemm.cpp
215 lines (196 loc) · 10.2 KB
/
wrapper_basic_gemm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// Tiled GEMM kernel built from ck::wrapper primitives.
//
// Memory layouts described to the wrapper (all packed, last dimension contiguous):
//   A : (M, K) with strides (K, 1)
//   B : (N, K) with strides (K, 1)
//   C : (M, N) with strides (N, 1)
//
// Launch configuration expected by the code below:
//   * blockIdx.x selects the tile along M, blockIdx.y selects the tile along N.
//   * The block must contain ck::wrapper::size(thread_layout) threads.
//   * Static LDS usage is one A tile (MPerBlock x KPerBlock) plus one B tile
//     (NPerBlock x KPerBlock) of DataType.
//
// Template parameters:
//   DataType          element type of A, B and C.
//   GemmTraits        traits type forwarded to the wrapper's blockwise XDL GEMM helpers.
//   scalar_per_vector vector width forwarded to blockwise_copy.
//   BlockShape        type of `tile_shape`, a (MPerBlock, NPerBlock, KPerBlock) tuple.
//   ThreadLayout      type of `thread_layout`, the layout of threads inside a block.
//
// The global layouts are padded to the tile shape before use, so M, N and K are
// not required by this code to be multiples of the tile dimensions.
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
                                                        const void* p_b,
                                                        void* p_c,
                                                        const ck::index_t M,
                                                        const ck::index_t N,
                                                        const ck::index_t K,
                                                        const BlockShape tile_shape,
                                                        const ThreadLayout thread_layout)
{
    constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
    constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
    constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);

    // Specify layouts for global memory.
    const auto a_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
    const auto b_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
    const auto c_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));

    // Specify layouts for tiles.
    constexpr auto a_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
    constexpr auto b_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
    constexpr auto c_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));

    // Apply padding for global memory.
    auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
    auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
    auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));

    // Make tensors for global memory.
    auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_a), a_global_layout_padded);
    auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_b), b_global_layout_padded);
    auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<DataType*>(p_c), c_global_layout_padded);

    // Allocate lds memory.
    __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
    __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];

    // Make tensors for lds memory.
    auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_a), a_tile_layout);
    auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_b), b_tile_layout);

    // Specify block index as tuple. The K position is left as a full slice because
    // the K range is selected explicitly inside the loop below.
    const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
                                           static_cast<ck::index_t>(blockIdx.y),
                                           ck::wrapper::slice());

    // Specify access parameters for copy.
    using DimAccessOrder             = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    constexpr ck::index_t vector_dim = 1;

    // Create tile and partition for C. Use specific function for blockwise_gemm to assign the
    // appropriate partitions.
    auto c_global_local_tile = ck::wrapper::make_local_tile(
        c_global_tensor,
        tile_shape,
        block_idxs,
        make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
    auto c_global_local_partition =
        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                               decltype(a_tile_layout),
                                                               decltype(b_tile_layout),
                                                               ck::wrapper::size(thread_layout),
                                                               GemmTraits>(c_global_local_tile);

    // Create C vgpr to accumulate results.
    auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                                  decltype(a_tile_layout),
                                                                  decltype(b_tile_layout),
                                                                  ck::wrapper::size(thread_layout),
                                                                  GemmTraits>();
    // Clear C vgpr.
    ck::wrapper::clear(c_vgpr_reg);

    // Iterate over K with KPerBlock step.
    // NOTE(review): the do/while body runs once even when num_loop is 0 (K == 0) -
    // callers are presumably expected to pass K > 0.
    const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
    ck::index_t i              = 0;
    do
    {
        // Get KPerBlock slice.
        const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
        auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
        auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);

        // Create local tiles for A and B.
        auto a_global_local_tile = ck::wrapper::make_local_tile(
            a_global_tensor_k_slice,
            tile_shape,
            block_idxs,
            make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
        auto b_global_local_tile = ck::wrapper::make_local_tile(
            b_global_tensor_k_slice,
            tile_shape,
            block_idxs,
            make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));

        // Copy from global to lds.
        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
            a_global_local_tile, a_lds_tensor, thread_layout);
        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
            b_global_local_tile, b_lds_tensor, thread_layout);

        // Synchronize lds: every thread's tile writes must be visible before the GEMM reads.
        ck::block_sync_lds();

        // Execute blockwise gemm.
        ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
            a_lds_tensor, b_lds_tensor, c_vgpr_reg);

        // Synchronize again before the next iteration overwrites lds_a / lds_b.
        // Without this barrier a thread that finishes its part of the GEMM early
        // could start writing the next K tile while other threads of the block are
        // still reading the current one. The loop trip count is uniform across the
        // block, so every thread reaches this barrier.
        ck::block_sync_lds();

        ++i;
    } while(i < num_loop);

    // Copy vgpr results to C global memory.
    ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
// Allocates device buffers for an M x N x K GEMM, launches and times DeviceGemm,
// and prints the average time together with derived throughput figures.
//
// Parameters:
//   M, N, K        problem sizes; A holds M*K elements, B holds K*N, C holds M*N.
//   tile_shape     (MPerBlock, NPerBlock, KPerBlock) tuple forwarded to the kernel;
//                  its first two entries determine the grid size.
//   thread_layout  layout of threads in a block; its size is the block size.
//
// The buffers are neither filled before the launch nor read back afterwards:
// this function only measures kernel time, it does not verify results.
// May throw if constructing a SimpleDeviceMem throws.
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout>
void PerformGemm(const ck::index_t M,
                 const ck::index_t N,
                 const ck::index_t K,
                 const BlockShape& tile_shape,
                 const ThreadLayout& thread_layout)
{
    // Widen the extents before multiplying. ck::index_t is 32-bit, so a product
    // such as M * K evaluated in that type overflows for large problems before it
    // is ever converted to std::size_t.
    const auto m_ext = static_cast<std::size_t>(M);
    const auto n_ext = static_cast<std::size_t>(N);
    const auto k_ext = static_cast<std::size_t>(K);

    const std::size_t a_bytes = m_ext * k_ext * sizeof(DataType);
    const std::size_t b_bytes = k_ext * n_ext * sizeof(DataType);
    const std::size_t c_bytes = m_ext * n_ext * sizeof(DataType);

    // Global memory buffers
    SimpleDeviceMem a_mem(a_bytes);
    SimpleDeviceMem b_mem(b_bytes);
    SimpleDeviceMem c_mem(c_bytes);

    // One block per (MPerBlock x NPerBlock) output tile, rounding up so that
    // partial tiles at the edges are covered as well.
    const ck::index_t grid_size_x =
        ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
    const ck::index_t grid_size_y =
        ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));

    const auto kernel =
        DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout>;

    // NOTE(review): StreamConfig{nullptr, true} presumably selects the default
    // stream with timing enabled - confirm against StreamConfig's definition.
    const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                  kernel,
                                                  dim3(grid_size_x, grid_size_y, 1),
                                                  dim3(ck::wrapper::size(thread_layout)),
                                                  0,
                                                  a_mem.GetDeviceBuffer(),
                                                  b_mem.GetDeviceBuffer(),
                                                  c_mem.GetDeviceBuffer(),
                                                  M,
                                                  N,
                                                  K,
                                                  tile_shape,
                                                  thread_layout);

    // 2 * M * N * K operations (one multiply and one add per inner-product term).
    std::size_t flop = std::size_t(2) * m_ext * n_ext * k_ext;
    // Each of A, B and C is counted once.
    std::size_t num_btype = a_bytes + b_bytes + c_bytes;

    // avg_time is printed as milliseconds, hence the 1e9 / 1e6 scale factors.
    float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
    float gb_per_sec = num_btype / 1.E6 / avg_time;

    std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[])
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8>(
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}