This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
/
half.h
374 lines (336 loc) · 14 KB
/
half.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file half.h
* \brief definition of half (float16) type.
*
* \author Junyuan Xie
*/
#ifndef MSHADOW_HALF_H_
#define MSHADOW_HALF_H_
#include "./base.h"
#if MSHADOW_USE_F16C
#include <x86intrin.h>
#endif // MSHADOW_USE_F16C
// This flag dictates rounding for the float2half() routine only (used generally on Windows),
// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
//   1 (default)       : round to nearest, ties to even.
//   any other value   : the fp32 fraction bits that do not fit in fp16 are simply dropped
//                       (truncation); magnitudes too large for fp16 are still mapped to infinity.
// Define the macro before including this header to override the default.
#ifndef MSHADOW_HALF_ROUND_TO_NEAREST
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif
// Native CUDA half-precision support is enabled only for CUDA builds with CUDA_VERSION >= 7050
// (CUDA 7.5). MSHADOW_CUDA_HALF records the outcome (1 or 0) and guards every use of __half below.
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
#define MSHADOW_CUDA_HALF 1
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__)
/*!
 * \brief __half2float_warp: convert a volatile-qualified __half to float (device code only).
 *
 * The value is first copied into a non-volatile local so that it can be passed to
 * __half2float(). It backs half_t's `operator T() const volatile` in device builds.
 * \param h volatile half value to read
 * \return h widened to float
 */
MSHADOW_XINLINE float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
__half val;
#if CUDA_VERSION >= 9000
// Copy the whole object after casting away const and volatile; presumably because the raw
// member `x` used in the branch below is not accessible from CUDA 9 on -- TODO confirm.
// NOTE(review): after the const_cast this is no longer a volatile access, so the compiler may
// reuse a previously loaded value of h. Confirm that is acceptable for the volatile callers
// (reading the 16 bits through a `const volatile unsigned short*` would keep the volatile read).
val = const_cast<__half&>(h);
#else
// Pre-CUDA-9 path: copy the raw 16-bit member `x`; h.x is read through the volatile reference.
val.x = h.x;
#endif
return __half2float(val);
}
#endif
#else
// No native CUDA half type: half_t stores only its uint16_t bits (no __half alias/constructor).
#define MSHADOW_CUDA_HALF 0
#endif
/*! \brief namespace for mshadow */
namespace mshadow {
/*! \brief namespace for host/device portable half-precision floats */
namespace half {
/*!
 * \brief Define binary operator BIN_OP for half_t as three overloads:
 *        (half_t, half_t), (T, half_t) and (half_t, T), where T must be convertible to float.
 *        Both operands are widened to float, BIN_OP is evaluated in float, and the result is
 *        converted to RESULT_T (half_t for arithmetic operators, bool for comparisons).
 */
#define MSHADOW_HALF_OPERATOR(RESULT_T, BIN_OP) \
  MSHADOW_XINLINE RESULT_T operator BIN_OP (half_t lhs, half_t rhs) { \
    return static_cast<RESULT_T>(static_cast<float>(lhs) BIN_OP static_cast<float>(rhs)); \
  } \
  template<typename T> \
  MSHADOW_XINLINE RESULT_T operator BIN_OP (T lhs, half_t rhs) { \
    return static_cast<RESULT_T>(static_cast<float>(lhs) BIN_OP static_cast<float>(rhs)); \
  } \
  template<typename T> \
  MSHADOW_XINLINE RESULT_T operator BIN_OP (half_t lhs, T rhs) { \
    return static_cast<RESULT_T>(static_cast<float>(lhs) BIN_OP static_cast<float>(rhs)); \
  }
/*!
 * \brief Declare compound assignment COMPOUND_OP (e.g. +=) as a half_t member template, for both
 *        plain and volatile objects. The new value is float(*this) PLAIN_OP float(rhs), stored
 *        back through half_t's operator=, and the result of that assignment is returned by value.
 */
#define MSHADOW_HALF_ASSIGNOP(COMPOUND_OP, PLAIN_OP) \
  template<typename T> \
  MSHADOW_XINLINE half_t operator COMPOUND_OP (const T& rhs) { \
    const float updated = static_cast<float>(*this) PLAIN_OP static_cast<float>(rhs); \
    return *this = half_t(updated); \
  } \
  template<typename T> \
  MSHADOW_XINLINE half_t operator COMPOUND_OP (const volatile T& rhs) volatile { \
    const float updated = static_cast<float>(*this) PLAIN_OP static_cast<float>(rhs); \
    return *this = half_t(updated); \
  }
/*!
 * \brief Generate `operator T() const` and `operator T() const volatile` for half_t.
 *
 * The stored half is widened to float and then converted to T. The widening routine is chosen
 * at preprocessing time:
 *  - device code with native CUDA half: __half2float (plain) / __half2float_warp (volatile);
 *  - otherwise, when MSHADOW_USE_F16C is set: the _cvtsh_ss intrinsic for both;
 *  - otherwise: half_t's portable half2float() for both.
 */
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const volatile { \
    return static_cast<T>(__half2float_warp(cuhalf_)); \
  } \
  MSHADOW_XINLINE operator T() const { \
    return static_cast<T>(__half2float(cuhalf_)); \
  }
#elif MSHADOW_USE_F16C
#define MSHADOW_HALF_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const volatile { \
    return static_cast<T>(_cvtsh_ss(half_)); \
  } \
  MSHADOW_XINLINE operator T() const { \
    return static_cast<T>(_cvtsh_ss(half_)); \
  }
#else
#define MSHADOW_HALF_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const volatile { \
    return static_cast<T>(half2float(half_)); \
  } \
  MSHADOW_XINLINE operator T() const { \
    return static_cast<T>(half2float(half_)); \
  }
#endif  // MSHADOW_HALF_CONVERSIONOP variants
/*!
 * \brief Host/device portable IEEE 754 binary16 (half-precision) floating-point value.
 *
 * The 16 raw bits live in `half_`; when MSHADOW_CUDA_HALF is enabled they are also visible as
 * a CUDA `__half` (`cuhalf_`) through an anonymous union. Conversions to/from float use,
 * depending on the build, CUDA intrinsics (__float2half/__half2float), the F16C instructions
 * (_cvtss_sh/_cvtsh_ss), or the portable integer routines in the private section below.
 * Arithmetic is carried out in float and rounded back to half.
 */
class MSHADOW_ALIGNED(2) half_t {
 public:
  union {
    uint16_t half_;
#if MSHADOW_CUDA_HALF
    __half cuhalf_;
#endif  // MSHADOW_CUDA_HALF
  };

  /*!
   * \brief Create a half_t from its raw binary16 bit pattern, without numeric conversion.
   * \param value raw 16-bit encoding (e.g. 0x7BFF is the largest finite half)
   * \return a half_t whose bits equal value
   */
  static MSHADOW_XINLINE half_t Binary(uint16_t value) {
    half_t res;
    res.half_ = value;
    return res;
  }

  /*! \brief Default constructor; intentionally leaves the stored bits uninitialized. */
  MSHADOW_XINLINE half_t() {}

#if MSHADOW_CUDA_HALF
  /*! \brief Wrap an existing CUDA __half value; its bits are copied as-is. */
  MSHADOW_XINLINE explicit half_t(const __half& value) {
    cuhalf_ = value;
  }
#endif  // MSHADOW_CUDA_HALF

  /*! \brief Implicit conversion from float (rounding path depends on the build, see constructor()). */
  MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
  /*! \brief Explicit conversions from other arithmetic types; each is routed through float. */
  MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }

  /*! \brief operator float() for plain and volatile objects. */
  MSHADOW_HALF_CONVERSIONOP(float)
  /*! \brief Compound assignments with any float-convertible operand, evaluated in float. */
  MSHADOW_HALF_ASSIGNOP(+=, +)
  MSHADOW_HALF_ASSIGNOP(-=, -)
  MSHADOW_HALF_ASSIGNOP(*=, *)
  MSHADOW_HALF_ASSIGNOP(/=, /)

  /*!
   * \brief Unary plus: returns a copy of this value.
   * const-qualified so it also applies to const objects (otherwise `+h` on a const half_t
   * falls back to the built-in float operator and yields float instead of half_t).
   */
  MSHADOW_XINLINE half_t operator+() const {
    return *this;
  }
  /*!
   * \brief Unary minus, computed as -float(*this) and converted back to half_t.
   * const-qualified for the same reason as operator+().
   */
  MSHADOW_XINLINE half_t operator-() const {
    return half_t(-float(*this));  // NOLINT(*)
  }

  /*!
   * \brief Copy the raw bits of a into this object.
   * \return a copy of a (returned by value, not a reference to *this)
   */
  MSHADOW_XINLINE half_t operator=(const half_t& a) {
    half_ = a.half_;
    return a;
  }
  /*! \brief Assign from any type T by first converting it to half_t. */
  template<typename T>
  MSHADOW_XINLINE half_t operator=(const T& a) {
    return *this = half_t(a); /* NOLINT(*)*/
  }
  /*! \brief Volatile counterpart of the raw-bit copy assignment. */
  MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
    half_ = a.half_;
    return a;
  }
  /*! \brief Volatile counterpart of the converting assignment. */
  template<typename T>
  MSHADOW_XINLINE half_t operator=(const T& a) volatile {
    return *this = half_t(a); /* NOLINT(*)*/
  }

 private:
  /*! \brief View of one 32-bit word as float, signed integer or unsigned integer. */
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Bit-layout constants used by the portable float <-> half routines below.
  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);  // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;  // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;  // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;  // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
  static int32_t const infN = 0x7F800000;  // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;  // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;  // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;  // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit
  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;  // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  // flt16 sign bit; only its low 16 bits (0x8000) matter, since it is ANDed with a 16-bit value.
  static int32_t const signC = signN >> shiftSign;
  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN (not referenced by this class)
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift)), i.e. 2^-24 as a float
  static int32_t const subC = 0x003FF;  // largest flt16 subnormal bit pattern (fraction all ones)
  static int32_t const norC = 0x00400;  // smallest flt16 normal bit pattern
  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  /*!
   * \brief Portable float -> binary16 conversion using only integer and bit operations.
   * \param value float to convert
   * \return binary16 bit pattern. With MSHADOW_HALF_ROUND_TO_NEAREST == 1 the result is rounded
   *         to nearest-even; otherwise the discarded fraction bits are truncated. Magnitudes
   *         above maxN become infinity, and NaNs stay NaNs.
   * Both float2half() overloads delegate here so the algorithm exists in exactly one place.
   */
  static MSHADOW_XINLINE uint16_t Float2HalfBits(float value) {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;  // grab sign bit
    v.si ^= sign;  // clear sign bit from v
    sign >>= shiftSign;  // logical shift sign to fp16 position
    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms. Since maxZ < v.si < minN here, exp16 is in [-10, 0] and vshift in [1, 11].
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
      // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
      // bits to the right of the lsb are 1000... (including flt32 significand bits
      // that may be lost during the above vshift). The first term below will always
      // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
      // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
      // the proper test of the flt32 significand bits, including those lost during the vshift.
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent to 1, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
    } else if (v.si <= maxN) {
      // Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Add half an fp16 ulp unless this is an exact tie whose fp16 lsb is already even.
      // Rounding may increase the exponent, possibly creating an inf, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      // Too large for fp16 (or already inf): saturate to infinity.
      v.si = infN;
    } else if (v.si < nanN) {
      // NaN whose payload would vanish after the shift: force the smallest fp16 NaN.
      v.si = nanN;
    }
    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  /*!
   * \brief Portable binary16 -> float conversion using integer and bit operations.
   * Every binary16 value is representable in float, so no rounding is involved.
   * \param value binary16 bit pattern
   * \return the corresponding float (zeros, subnormals, normals, infinities and NaNs)
   */
  static MSHADOW_XINLINE float HalfBits2Float(uint16_t value) {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;  // isolate the fp16 sign bit
    v.si ^= sign;  // clear it from v
    sign <<= shiftSign;  // move it to the fp32 sign position
    // Branch-free exponent re-bias; the two steps must stay in this order because the second
    // tests the already-adjusted value. Step 1: normals/inf/NaN get (127 - 15) added to the exponent.
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    // Step 2: inf/NaN, now above maxC, get a further adjustment so their exponent becomes all ones.
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    // Subnormal candidate computed in float arithmetic: fraction * 2^-24.
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);  // all ones iff the input was zero or subnormal
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;  // pick the float-computed value for zero/subnormal inputs
    v.si |= sign;
    return v.f;
  }

  /*! \brief float -> binary16 bits via the portable routine. */
  MSHADOW_XINLINE uint16_t float2half(const float& value) const {
    return Float2HalfBits(value);
  }
  /*! \brief Volatile overload of float2half(); reads value once, same algorithm. */
  MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile {  // NOLINT(*)
    return Float2HalfBits(value);
  }
  /*! \brief binary16 bits -> float via the portable routine. */
  MSHADOW_XINLINE float half2float(const uint16_t& value) const {
    return HalfBits2Float(value);
  }
  /*! \brief Volatile overload of half2float(); reads value once, same algorithm. */
  MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile {  // NOLINT(*)
    return HalfBits2Float(value);
  }

  /*!
   * \brief Shared body of the converting constructors: value is converted to float, then to
   *        half using the conversion available for this build (CUDA intrinsic, F16C, or portable).
   */
  template<typename T>
  MSHADOW_XINLINE void constructor(const T& value) {
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
    cuhalf_ = __float2half(float(value));  // NOLINT(*)
#elif(MSHADOW_USE_F16C)
    // Rounding-control immediate 0 selects round-to-nearest-even.
    half_ = _cvtss_sh(static_cast<float>(value), 0);
#else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
    half_ = float2half(float(value));  // NOLINT(*)
#endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
  }
};
/*!
 * \brief Arithmetic operators for half_t (+, -, *, /): both operands are widened to float,
 *        the operation is evaluated in float, and the result is converted back to half_t.
 */
MSHADOW_HALF_OPERATOR(half_t, +)
MSHADOW_HALF_OPERATOR(half_t, -)
MSHADOW_HALF_OPERATOR(half_t, *)
MSHADOW_HALF_OPERATOR(half_t, /)
/*!
 * \brief Ordering comparisons for half_t (<, <=, >, >=): the float values of both operands
 *        are compared and the bool result is returned.
 */
MSHADOW_HALF_OPERATOR(bool, <)
MSHADOW_HALF_OPERATOR(bool, <=)
MSHADOW_HALF_OPERATOR(bool, >)
MSHADOW_HALF_OPERATOR(bool, >=)
// NOTE: these macros deliberately have no trailing ';' so they can be used inside expressions
// and argument lists, e.g. f(MSHADOW_HALF_MAX) or (x < MSHADOW_HALF_MAX ? ... : ...).
/*! \brief Lowest (most negative) finite half_t value, -65504 (bit pattern 0xFBFF). */
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF)
/*! \brief Largest finite half_t value, +65504 (bit pattern 0x7BFF). */
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF)
/*! \brief Mask selecting the sign bit of a binary16 bit pattern. */
#define MSHADOW_HALF_SIGN_BIT 0x8000
/*! \brief Mask selecting the five exponent bits of a binary16 bit pattern. */
#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
} // namespace half
} // namespace mshadow
#endif // MSHADOW_HALF_H_