forked from AnswerDotAI/gpu.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhalf.h
268 lines (228 loc) · 10.1 KB
/
half.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
#ifndef HALF_H
#define HALF_H
#include <cfloat>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>
struct half;
half halfFromFloat(float f);
float halfToFloat(half h);
/**
 * Experimental implementation of half-precision 16-bit floating point numbers.
 *
 * The encoded value is kept as raw bits in `data`. Numeric conversion to and
 * from `float` is delegated to halfFromFloat() and halfToFloat().
 *
 * Copy assignment is explicitly defaulted (not hand-written) so that the type
 * remains trivially copyable and can be copied byte-wise in and out of raw
 * buffers.
 */
struct half {
  /// Raw 16-bit storage of the encoded value.
  uint16_t data;

  /// Default constructor: stores the bit pattern 0.
  half() : data(0) {}

  /// Constructor from float (intentionally implicit); converts with
  /// halfFromFloat().
  half(float f) : data(halfFromFloat(f).data) {}

  /// Constructor from uint16_t: `value` is taken as the raw bit pattern, no
  /// numeric conversion is performed.
  explicit half(uint16_t value) : data(value) {}

  /// Conversion operator to float; converts with halfToFloat().
  operator float() const { return halfToFloat(*this); }

  /// Conversion operator to uint16_t: returns the raw bit pattern.
  operator uint16_t() const { return data; }

  /// Assignment from uint16_t: stores `value` as the raw bit pattern.
  half &operator=(uint16_t value) {
    data = value;
    return *this;
  }

  /// Assignment from another half. Defaulted so it stays trivial.
  half &operator=(const half &other) = default;

  /// Assignment from float; converts with halfFromFloat().
  half &operator=(float value) {
    // Read the bits directly instead of going through an implicit conversion
    // of the temporary half.
    data = halfFromFloat(value).data;
    return *this;
  }
};

static_assert(sizeof(half) == sizeof(uint16_t),
              "half must occupy exactly 16 bits");
static_assert(std::is_trivially_copyable<half>::value,
              "half must be trivially copyable for byte-wise buffer copies");
/**
 * @brief Converts a 32-bit float to a 16-bit half-precision float.
 *
 * Based on Mike Acton's half.c implementation, written with explicit cases
 * instead of branch-free selects.
 *
 * Behaviour by input class:
 *  - NaN: the result keeps the sign bit and the top 10 bits of the float
 *    mantissa. If those bits are all zero the mantissa is forced to 1 so the
 *    result cannot collapse into an infinity encoding.
 *  - Infinity, or a rebiased exponent above 30: signed infinity.
 *  - Biased float exponent <= 112: encoded with a zero exponent field
 *    (subnormal or zero). Inputs too small to leave any mantissa bit become a
 *    signed zero.
 *  - Otherwise: normal number. A mantissa that rounds past its top bit bumps
 *    the exponent by one (which can yield infinity).
 *
 * Rounding: the mantissa is rounded up when float mantissa bit 12 (the first
 * bit discarded for a normal result) is set. The same pre-shift rounding is
 * used for subnormal results; bits shifted out afterwards are discarded.
 *
 * @param f Value to convert.
 * @return The converted value as a `half`.
 */
inline half halfFromFloat(float f) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "float must be 32 bits wide");

  // Constants for bit masks, shifts, and biases
  constexpr uint32_t FLOAT_EXP_MASK = 0x7f800000u;
  constexpr uint32_t FLOAT_MANTISSA_MASK = 0x007fffffu;
  constexpr uint32_t FLOAT_HIDDEN_BIT = 0x00800000u;
  constexpr uint32_t FLOAT_ROUND_BIT = 0x00001000u;
  constexpr uint32_t FLOAT_EXP_POS = 23;
  constexpr uint32_t FLOAT_EXP_FLAGGED_VALUE = 0xffu;
  constexpr uint32_t HALF_SIGN_MASK = 0x8000u;
  constexpr uint32_t HALF_EXP_MASK = 0x7c00u;
  constexpr uint32_t HALF_EXP_POS = 10;
  constexpr uint32_t HALF_EXP_MAX_VALUE = 30;
  constexpr uint32_t HALF_NAN_MIN_MANTISSA = 1;
  constexpr uint32_t FLOAT_HALF_SIGN_POS_OFFSET = 16;     // 31 - 15
  constexpr uint32_t FLOAT_HALF_BIAS_OFFSET = 112;        // 127 - 15
  constexpr uint32_t FLOAT_HALF_MANTISSA_POS_OFFSET = 13; // 23 - 10

  // Copy the bits out with memcpy; reading an inactive union member is not
  // defined behaviour in C++.
  uint32_t float32;
  std::memcpy(&float32, &f, sizeof(float32));

  // Extracting the sign, exponent, and mantissa from the 32-bit float
  const uint32_t halfSign =
      (float32 >> FLOAT_HALF_SIGN_POS_OFFSET) & HALF_SIGN_MASK;
  const uint32_t floatExp = (float32 & FLOAT_EXP_MASK) >> FLOAT_EXP_POS;
  const uint32_t floatMantissa = float32 & FLOAT_MANTISSA_MASK;

  uint32_t packed;
  if (floatExp == FLOAT_EXP_FLAGGED_VALUE) {
    if (floatMantissa == 0) {
      // Infinity
      packed = halfSign | HALF_EXP_MASK;
    } else {
      // NaN: keep the sign and the high mantissa bits, but never let the
      // mantissa become zero, because that bit pattern means infinity.
      uint32_t halfNanMantissa =
          floatMantissa >> FLOAT_HALF_MANTISSA_POS_OFFSET;
      if (halfNanMantissa == 0) {
        halfNanMantissa = HALF_NAN_MIN_MANTISSA;
      }
      packed = halfSign | HALF_EXP_MASK | halfNanMantissa;
    }
  } else {
    // Round up when the first bit below the half mantissa is set.
    const uint32_t floatMantissaRounded =
        floatMantissa + ((floatMantissa & FLOAT_ROUND_BIT) << 1);

    if (floatExp <= FLOAT_HALF_BIAS_OFFSET) {
      // Result has a zero exponent field (subnormal or zero). The hidden bit
      // is added, not OR-ed, so a rounding carry out of the mantissa is kept.
      const uint32_t floatMantissaWithHidden =
          floatMantissaRounded + FLOAT_HIDDEN_BIT;
      // Total right shift: denormalisation amount (1..113) plus the
      // float-to-half mantissa offset. All arithmetic is 32-bit, so the
      // amount cannot wrap.
      const uint32_t shiftAmount = (FLOAT_HALF_BIAS_OFFSET + 1 - floatExp) +
                                   FLOAT_HALF_MANTISSA_POS_OFFSET;
      // Shifting a 32-bit value by 32 or more is undefined; such inputs are
      // far below the smallest subnormal and become zero.
      const uint32_t halfMantissaDenorm =
          (shiftAmount < 32) ? (floatMantissaWithHidden >> shiftAmount) : 0u;
      packed = halfSign | halfMantissaDenorm;
    } else {
      const uint32_t halfExp = floatExp - FLOAT_HALF_BIAS_OFFSET; // >= 1
      if (halfExp > HALF_EXP_MAX_VALUE) {
        // Exponent does not fit: overflow to infinity
        packed = halfSign | HALF_EXP_MASK;
      } else if ((floatMantissaRounded & FLOAT_HIDDEN_BIT) != 0) {
        // Rounding carried out of the mantissa: next exponent, zero mantissa.
        // For halfExp == 30 this is the infinity encoding.
        packed = halfSign | ((halfExp + 1) << HALF_EXP_POS);
      } else {
        // Normalized number
        packed = halfSign | (halfExp << HALF_EXP_POS) |
                 (floatMantissaRounded >> FLOAT_HALF_MANTISSA_POS_OFFSET);
      }
    }
  }

  half result;
  result.data = static_cast<uint16_t>(packed);
  return result;
}
/**
 * @brief Converts a 16-bit half-precision float to a 32-bit float.
 *
 * Based on Mike Acton's half.c implementation, written with explicit cases
 * instead of branch-free selects.
 *
 * Behaviour by input class:
 *  - Exponent field all ones: infinity (mantissa zero) or NaN (mantissa
 *    non-zero); sign and mantissa bits are carried over.
 *  - Exponent field non-zero: normal number, exponent rebiased by 112 and
 *    mantissa shifted into the float mantissa position.
 *  - Exponent and mantissa zero: signed zero.
 *  - Exponent zero, mantissa non-zero: subnormal, renormalised into a normal
 *    float.
 *
 * @param h Value to convert; only `h.data` is read.
 * @return The converted value as a `float`.
 */
inline float halfToFloat(half h) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "float must be 32 bits wide");

  // Constants for bit masks, shifts, and biases
  constexpr uint32_t HALF_SIGN_MASK = 0x8000u;
  constexpr uint32_t HALF_EXP_MASK = 0x7c00u;
  constexpr uint32_t HALF_MANTISSA_MASK = 0x03ffu;
  constexpr uint32_t HALF_HIDDEN_BIT = 0x0400u;
  constexpr uint32_t HALF_EXP_POS = 10;
  constexpr uint32_t FLOAT_EXP_MASK = 0x7f800000u;
  constexpr uint32_t FLOAT_EXP_POS = 23;
  constexpr uint32_t HALF_FLOAT_SIGN_POS_OFFSET = 16;     // 31 - 15
  constexpr uint32_t HALF_FLOAT_MANTISSA_POS_OFFSET = 13; // 23 - 10
  constexpr uint32_t HALF_FLOAT_BIAS_OFFSET = 112;        // 127 - 15

  // Extracting the sign, exponent, and mantissa from the 16-bit float
  const uint32_t halfBits = h.data;
  const uint32_t floatSign = (halfBits & HALF_SIGN_MASK)
                             << HALF_FLOAT_SIGN_POS_OFFSET;
  const uint32_t halfExpMasked = halfBits & HALF_EXP_MASK;
  const uint32_t halfMantissa = halfBits & HALF_MANTISSA_MASK;

  uint32_t result;
  if (halfExpMasked == HALF_EXP_MASK) {
    // Infinity (mantissa == 0) or NaN (mantissa != 0)
    result = floatSign | FLOAT_EXP_MASK |
             (halfMantissa << HALF_FLOAT_MANTISSA_POS_OFFSET);
  } else if (halfExpMasked != 0) {
    // Normalized number
    const uint32_t floatExp =
        (halfExpMasked >> HALF_EXP_POS) + HALF_FLOAT_BIAS_OFFSET;
    result = floatSign | (floatExp << FLOAT_EXP_POS) |
             (halfMantissa << HALF_FLOAT_MANTISSA_POS_OFFSET);
  } else if (halfMantissa == 0) {
    // Zero: only the sign bit survives
    result = floatSign;
  } else {
    // Subnormal: shift the mantissa left until its leading one reaches the
    // hidden-bit position, lowering the exponent once per shift. The mantissa
    // is non-zero here, so this runs at most 10 times. It replaces a
    // count-leading-zeros builtin that is undefined for a zero argument.
    uint32_t mantissa = halfMantissa;
    uint32_t floatExp = HALF_FLOAT_BIAS_OFFSET + 1;
    while ((mantissa & HALF_HIDDEN_BIT) == 0) {
      mantissa <<= 1;
      --floatExp;
    }
    // Drop the now-implicit leading one before packing.
    result = floatSign | (floatExp << FLOAT_EXP_POS) |
             ((mantissa & HALF_MANTISSA_MASK)
              << HALF_FLOAT_MANTISSA_POS_OFFSET);
  }

  // Reinterpret the uint32_t result as a float. memcpy is used because
  // reading an inactive union member is not defined behaviour in C++.
  float value;
  std::memcpy(&value, &result, sizeof(value));
  return value;
}
#endif // HALF_H