// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed
#include <ceed.h>
#include <ceed/backend.h>
#include <immintrin.h>
#include <stdbool.h>
// Precision-dependent SIMD aliases. CEED_F64_H is presumably the include
// guard of the double-precision scalar header pulled in by <ceed.h>, so it is
// defined exactly when CeedScalar is double — TODO confirm. Both branches use
// 4-wide vectors (4 doubles in __m256d, 4 floats in __m128), which is why the
// contraction kernels below tile uniformly in multiples of 4.
#ifdef CEED_F64_H
#define rtype __m256d
#define loadu _mm256_loadu_pd
#define storeu _mm256_storeu_pd
#define set _mm256_set_pd
#define set1 _mm256_set1_pd
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
// Explicit mul+add intrinsics: `(c) += _mm256_mul_pd(...)` relies on the
// GCC/Clang vector-extension `+=` operator, which MSVC rejects
#define fmadd(c, a, b) (c) = _mm256_add_pd((c), _mm256_mul_pd((a), (b)))
#endif
#else
#define rtype __m128
#define loadu _mm_loadu_ps
#define storeu _mm_storeu_ps
#define set _mm_set_ps
#define set1 _mm_set1_ps
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
#else
// Portable ISO C form of the fallback (see note on the double branch)
#define fmadd(c, a, b) (c) = _mm_add_ps((c), _mm_mul_ps((a), (b)))
#endif
#endif
//------------------------------------------------------------------------------
// Blocked Tensor Contract
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Contract over the b index: v[a, j, c] += sum_b t[j, b] * u[a, b, c],
  // vectorizing over c (4 scalars per SIMD register) and holding a JJ x CC
  // output tile in registers. Only full CC-wide column blocks are handled
  // here; the column remainder (C % CC) is the job of
  // CeedTensorContract_Avx_Remainder.
  // NOTE(review): `add` is unused — callers pre-zero v when not accumulating
  // (see CeedTensorContractApply_Avx), so accumulation is always in place.
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    // Transposed t is indexed as t[b, j] instead of t[j, b]
    t_stride_0 = 1;
    t_stride_1 = J;
  }
  for (CeedInt a = 0; a < A; a++) {
    // Blocks of JJ rows
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        // Load the current v tile
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        // Accumulate t * u into the tile, one b at a time
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        // Store the finished tile back
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
    // Remainder of rows: same tiling with only J - j (< JJ) live rows; the
    // jj loop bound is no longer a compile-time constant, so it won't unroll
    const CeedInt j = (J / JJ) * JJ;
    if (j < J) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------
// Serial Tensor Contract Remainder
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Handle the column remainder c in [(C / CC) * CC, C) that the blocked
  // kernel skipped: v[a, j, c] += sum_b t[j, b] * u[a, b, c].
  // Tail loads of u are built zero-padded so they never read past the row.
  // Loads/stores of v stay a full 4 wide and may touch up to 3 entries of the
  // following row, but those lanes accumulate exactly zero and are written
  // back unchanged. (The C == 1 case is dispatched to the Single kernel, so
  // C >= 2 here and the spill stays inside v.)
  // NOTE(review): `add` is unused — callers pre-zero v when not accumulating.
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }
  // J_break < J always holds, so the scalar tail below owns at least one row;
  // this guarantees the 4-wide v accesses in the vector loop stay in bounds
  const CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;
  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
      // Blocks of JJ rows
      for (CeedInt j = 0; j < J_break; j += JJ) {
        rtype vv[JJ];  // Output tile to be held in registers
        for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
        for (CeedInt b = 0; b < B; b++) {
          // Zero-pad u lanes beyond the end of the row
          rtype tqu;
          if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
          else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else tqu = loadu(&u[(a * B + b) * C + c]);
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
      }
    }
    // Remainder of rows, all tail columns — plain scalar loop
    for (CeedInt j = J_break; j < J; j++) {
      for (CeedInt b = 0; b < B; b++) {
        const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
        for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------
// Serial Tensor Contract C=1
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
                                                const CeedInt AA, const CeedInt JJ) {
  // Serial (C == 1) contraction: v[a, j] += sum_b t[j, b] * u[a, b].
  // With no c index available, vectorize over j instead (4 scalars per SIMD
  // register) and hold an AA x JJ output tile in registers.
  // NOTE(review): `add` is unused — callers pre-zero v when not accumulating.
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    // Transposed t is indexed as t[b, j] instead of t[j, b]
    t_stride_0 = 1;
    t_stride_1 = J;
  }
  // Blocks of AA rows, full JJ-wide column blocks
  for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      rtype vv[AA][JJ / 4];  // Output tile to be held in registers
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
      }
      for (CeedInt b = 0; b < B; b++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
          // Gather 4 consecutive j entries of t into one vector
          rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                          t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
          for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
            fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
          }
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
      }
    }
  }
  // Remainder of rows: same tiling with only A - a (< AA) live rows
  const CeedInt a = (A / AA) * AA;
  for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
    rtype vv[AA][JJ / 4];  // Output tile to be held in registers
    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
    }
    for (CeedInt b = 0; b < B; b++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
        rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                        t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
        for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
          fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
        }
      }
    }
    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
    }
  }
  // Column remainder. A_break excludes the final AA rows from the vectorized
  // tail so the 4-wide v accesses below never run past the end of v; the
  // excluded rows are handled by the scalar loop at the bottom.
  // NOTE(review): for very small J (J <= 2) combined with A % AA == 1 the
  // 4-wide loads may still read a couple of elements past v — confirm callers
  // never hit that combination.
  const CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;
  // Blocks of 4 columns
  for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
    // Blocks of AA rows
    for (CeedInt a = 0; a < A_break; a += AA) {
      rtype vv[AA];  // Output tile to be held in registers
      for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
      for (CeedInt b = 0; b < B; b++) {
        // Zero-pad t lanes beyond the end of the row; the padded lanes then
        // accumulate zero, so any v entries of the next row that get read and
        // rewritten below are left unchanged
        rtype tqv;
        if (J - j == 1) {
          tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - j == 2) {
          tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - j == 3) {  // was `J - 3 == j`; equivalent, normalized for consistency
          tqv =
              set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else {
          tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
                    t[(j + 0) * t_stride_0 + b * t_stride_1]);
        }
        for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
          fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
    }
  }
  // Remainder of rows, all tail columns — plain scalar loop
  for (CeedInt b = 0; b < B; b++) {
    for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
      const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
      for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
    }
  }
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------
// Tensor Contract - Common Sizes
//------------------------------------------------------------------------------
// Blocked contraction specialized to a 4-row x 8-column register tile
static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt row_block = 4, col_block = 8;

  return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, row_block, col_block);
}
// Column-remainder contraction specialized to an 8-row tile and 8-column blocking
static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt row_block = 8, col_block = 8;

  return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, row_block, col_block);
}
// Serial (C == 1) contraction specialized to a 4 x 8 register tile
static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt row_block = 4, col_block = 8;

  return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, row_block, col_block);
}
//------------------------------------------------------------------------------
// Tensor Contract Apply
//------------------------------------------------------------------------------
// Top-level dispatch for the AVX tensor contraction.
// Zeroes v when not accumulating, then routes to the specialized kernels,
// which always accumulate (hence `true` for their add argument).
static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                       CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt col_block = 8;

  // Overwrite mode: clear the output first so the kernels can accumulate in place
  if (!add) {
    for (CeedInt i = 0; i < A * J * C; i++) v[i] = (CeedScalar)0.0;
  }
  // Serial C=1 case gets its own j-vectorized kernel
  if (C == 1) {
    CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
    return CEED_ERROR_SUCCESS;
  }
  // Full blocks of col_block columns
  if (C >= col_block) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
  // Remainder of columns
  if (C % col_block) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------
// Tensor Contract Create
//------------------------------------------------------------------------------
// Register the AVX "Apply" implementation for this tensor contraction with
// the owning Ceed backend; called once at contraction creation.
int CeedTensorContractCreate_Avx(CeedTensorContract contract) {
  CeedCallBackend(CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------