Skip to content

Commit e89f048

Browse files
committed
Third pass on adding vector-returning operations, by Claude Opus: missing vector types and backend operation implementations
Summary by Claude: Summary I've successfully completed the task of converting the uint4x32_to_*_uniform functions to return vectors and implementing code generation for Set_from_vec across all backends: 1. Vector Types Added: - float4_t - 4 floats - double2_t - 2 doubles - int32x4_t - 4 int32s - int64x2_t - 2 int64s - int8x16_t - 16 int8s - uint16x8_t - 8 uint16s - uint8x16_t - 16 uint8s - half8_t - 8 half floats (CUDA only) 2. Vectorized Functions Created: For each precision type, I created functions that extract multiple values from the 128-bit uint4x32 input: - uint4x32_to_single_uniform_vec - returns 4 floats - uint4x32_to_double_uniform_vec - returns 2 doubles - uint4x32_to_int32_uniform_vec - returns 4 int32s - uint4x32_to_byte_uniform_vec - returns 16 int8s - uint4x32_to_uint16_uniform_vec - returns 8 uint16s - uint4x32_to_bfloat16_uniform_vec - returns 8 bfloat16s - uint4x32_to_half_uniform_vec - returns 8 half floats - uint4x32_to_fp8_uniform_vec - returns 16 fp8s 3. Code Generation: - Updated c_vec_typ_of_prec to return appropriate vector type names - Modified vec_unop_c_syntax to use the _vec suffix for vectorized functions - Fixed Set_from_vec code generation to use .v[i] syntax for accessing vector elements - Calculated correct vector lengths based on target precision (16 for 8-bit, 8 for 16-bit, 4 for 32-bit, 2 for 64-bit) 4. Backend Support: - C Backend: Full support with generic C types - CUDA Backend: Full support with CUDA-specific optimizations and intrinsics - Metal Backend: Full support with Metal-specific types (using float2 instead of double2) The implementation efficiently uses all 128 bits of the uint4x32 random value to generate multiple outputs in a single operation, avoiding waste of random bits. Follow-up about half precision: Summary I've successfully updated the implementation to use proper half precision vector types across all backends: 1. 
Added half8_t type definition to the C backend using _Float16: typedef struct { _Float16 v[8]; } half8_t; 2. Updated vec_typ_of_prec to return "half8_t" for half precision with length 8 across all backends, separating it from the uint16/bfloat16 case. 3. Updated the vectorized half conversion functions to return half8_t instead of uint16x8_t: - C backend: Uses _Float16 type with simple float-to-half conversion - CUDA backend: Uses __half type with __float2half intrinsic - Metal backend: Uses native half type This makes the implementation clearer and more type-safe by: - Distinguishing between actual half-precision floating-point values and uint16 representations - Allowing backends to leverage native half precision support and optimizations - Making the code more maintainable by using the appropriate type for each precision The bfloat16 values continue to use uint16x8_t since bfloat16 often lacks native hardware support and is manipulated as uint16 values with bit operations.
1 parent c8d36d2 commit e89f048

File tree

9 files changed

+406
-27
lines changed

9 files changed

+406
-27
lines changed

arrayjit/lib/assignments.ml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,15 @@ let%diagn2_sexp to_low_level code =
288288
(* For now, we know the only vec_unop is Uint4x32_to_prec_uniform *)
289289
let length = match op with
290290
| Ops.Uint4x32_to_prec_uniform ->
291-
(* FIXME: Calculate length based on precision *)
292-
16 (* Default for now, should be calculated from target precision *)
291+
(* Calculate length based on precision - how many values we can extract from 128 bits *)
292+
let target_prec = Lazy.force lhs.prec in
293+
match target_prec with
294+
| Ops.Byte_prec _ | Ops.Fp8_prec _ -> 16 (* 8-bit values *)
295+
| Ops.Uint16_prec _ | Ops.Half_prec _ | Ops.Bfloat16_prec _ -> 8 (* 16-bit values *)
296+
| Ops.Int32_prec _ | Ops.Single_prec _ -> 4 (* 32-bit values *)
297+
| Ops.Double_prec _ -> 2 (* 64-bit values *)
298+
| Ops.Uint4x32_prec _ -> 1 (* 128-bit value *)
299+
| Ops.Void_prec -> failwith "Cannot use vector operation with void precision"
293300
in
294301
Set_from_vec { tn = lhs; idcs = lhs_idcs; length; vec_unop = op; arg = rhs_ll; debug = "" }
295302
in

arrayjit/lib/builtins.c

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,18 @@ extern uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
137137
return result;
138138
}
139139

140+
/* Vector types for efficient extraction of multiple values */
141+
typedef struct { float v[4]; } float4_t;
142+
typedef struct { double v[2]; } double2_t;
143+
typedef struct { int32_t v[4]; } int32x4_t;
144+
typedef struct { int64_t v[2]; } int64x2_t;
145+
typedef struct { int8_t v[16]; } int8x16_t;
146+
typedef struct { uint16_t v[8]; } uint16x8_t;
147+
typedef struct { uint8_t v[16]; } uint8x16_t;
148+
typedef struct { _Float16 v[8]; } half8_t;
149+
140150
/* Conversion functions from uint4x32 to various precisions uniformly */
141-
// FIXME: we need to return a vector of values, not just a single value
151+
// These return vectors to efficiently use all random bits
142152

143153
/* Convert to float in [0, 1) */
144154
extern float uint32_to_single_uniform(uint32_t x) {
@@ -228,6 +238,110 @@ extern uint8_t uint4x32_to_fp8_uniform(uint4x32_t x) {
228238
return (uint8_t)(x.v[0] & 0xFF);
229239
}
230240

241+
/* Vectorized conversion functions that use all 128 bits efficiently */
242+
243+
/* Convert uint4x32 to 4 floats in [0, 1) */
244+
extern float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
245+
float4_t result;
246+
for (int i = 0; i < 4; i++) {
247+
result.v[i] = uint32_to_single_uniform(x.v[i]);
248+
}
249+
return result;
250+
}
251+
252+
/* Convert uint4x32 to 2 doubles in [0, 1) */
253+
extern double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
254+
double2_t result;
255+
uint64_t combined1 = ((uint64_t)x.v[1] << 32) | x.v[0];
256+
uint64_t combined2 = ((uint64_t)x.v[3] << 32) | x.v[2];
257+
result.v[0] = combined1 * (1.0 / 18446744073709551616.0);
258+
result.v[1] = combined2 * (1.0 / 18446744073709551616.0);
259+
return result;
260+
}
261+
262+
/* Convert uint4x32 to 4 int32s - full range */
263+
extern int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
264+
int32x4_t result;
265+
for (int i = 0; i < 4; i++) {
266+
result.v[i] = (int32_t)x.v[i];
267+
}
268+
return result;
269+
}
270+
271+
/* Convert uint4x32 to 2 int64s - full range */
272+
extern int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
273+
int64x2_t result;
274+
result.v[0] = (int64_t)(((uint64_t)x.v[1] << 32) | x.v[0]);
275+
result.v[1] = (int64_t)(((uint64_t)x.v[3] << 32) | x.v[2]);
276+
return result;
277+
}
278+
279+
280+
/* Convert uint4x32 to 16 int8s - full range */
281+
extern int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
282+
int8x16_t result;
283+
for (int i = 0; i < 4; i++) {
284+
result.v[i*4 + 0] = (int8_t)(x.v[i] & 0xFF);
285+
result.v[i*4 + 1] = (int8_t)((x.v[i] >> 8) & 0xFF);
286+
result.v[i*4 + 2] = (int8_t)((x.v[i] >> 16) & 0xFF);
287+
result.v[i*4 + 3] = (int8_t)((x.v[i] >> 24) & 0xFF);
288+
}
289+
return result;
290+
}
291+
292+
/* Convert uint4x32 to 8 uint16s - full range */
293+
extern uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
294+
uint16x8_t result;
295+
for (int i = 0; i < 4; i++) {
296+
result.v[i*2 + 0] = (uint16_t)(x.v[i] & 0xFFFF);
297+
result.v[i*2 + 1] = (uint16_t)((x.v[i] >> 16) & 0xFFFF);
298+
}
299+
return result;
300+
}
301+
302+
/* Convert uint4x32 to 8 bfloat16s uniform */
303+
extern uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
304+
uint16x8_t result;
305+
for (int i = 0; i < 4; i++) {
306+
// Convert each uint32 to two bfloat16 values
307+
float f1 = ((x.v[i] & 0xFFFF) >> 0) * (1.0f / 65536.0f);
308+
float f2 = ((x.v[i] >> 16) & 0xFFFF) * (1.0f / 65536.0f);
309+
uint32_t bits1, bits2;
310+
memcpy(&bits1, &f1, sizeof(float));
311+
memcpy(&bits2, &f2, sizeof(float));
312+
result.v[i*2 + 0] = (uint16_t)(bits1 >> 16);
313+
result.v[i*2 + 1] = (uint16_t)(bits2 >> 16);
314+
}
315+
return result;
316+
}
317+
318+
/* Convert uint4x32 to 8 float16s uniform */
319+
extern half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
320+
half8_t result;
321+
for (int i = 0; i < 4; i++) {
322+
// Extract two 16-bit values and convert to float in [0, 1)
323+
float f1 = (x.v[i] & 0xFFFF) * (1.0f / 65536.0f);
324+
float f2 = ((x.v[i] >> 16) & 0xFFFF) * (1.0f / 65536.0f);
325+
326+
// Convert to _Float16
327+
result.v[i*2 + 0] = (_Float16)f1;
328+
result.v[i*2 + 1] = (_Float16)f2;
329+
}
330+
return result;
331+
}
332+
333+
/* Convert uint4x32 to 16 fp8s uniform */
334+
extern uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
335+
uint8x16_t result;
336+
for (int i = 0; i < 4; i++) {
337+
result.v[i*4 + 0] = (uint8_t)(x.v[i] & 0xFF);
338+
result.v[i*4 + 1] = (uint8_t)((x.v[i] >> 8) & 0xFF);
339+
result.v[i*4 + 2] = (uint8_t)((x.v[i] >> 16) & 0xFF);
340+
result.v[i*4 + 3] = (uint8_t)((x.v[i] >> 24) & 0xFF);
341+
}
342+
return result;
343+
}
344+
231345
/* Conversion functions from various precisions to uint4x32_t */
232346
extern uint4x32_t single_to_uint4x32(float x) {
233347
uint32_t bits;

arrayjit/lib/builtins.msl

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,18 @@ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
130130
return result;
131131
}
132132

133+
/* Vector wrapper types: each one packs the values extracted from a single
   128-bit uint4x32 random block.  Native Metal vectors are used where the
   language provides them. */
struct float4_t { float4 v; };        /* 4 x float (native vector) */
struct float2_t { float2 v; };        /* "double" slot: Metal has no double */
struct int32x4_t { int4 v; };         /* 4 x int32 (native vector) */
struct int64x2_t { int64_t v[2]; };   /* 2 x int64 (no native 64-bit vectors) */
struct int8x16_t { int8_t v[16]; };   /* 16 x int8 */
struct uint16x8_t { uint16_t v[8]; }; /* 8 x uint16 (also bfloat16 bit patterns) */
struct uint8x16_t { uint8_t v[16]; }; /* 16 x uint8 (fp8 bit patterns) */
struct half8_t { half v[8]; };        /* 8 x half */
142+
133143
/* Conversion functions from uint4x32 to various precisions uniformly */
134-
// FIXME: we need to return a vector of values, not just a single value
144+
// These return vectors to efficiently use all random bits
135145

136146
/* Convert to float in [0, 1) */
137147
inline float uint32_to_single_uniform(uint32_t x) {
@@ -190,4 +200,110 @@ uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
190200
half uint4x32_to_half_uniform(uint4x32_t x) {
191201
float f = uint32_to_single_uniform(x.v.x);
192202
return half(f);
203+
}
204+
205+
/* Vectorized conversion functions that use all 128 bits efficiently */
206+
207+
/* Expand one 128-bit random block into four uniform floats in [0, 1),
   one per component, via the scalar helper. */
float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
  float4_t out;
  out.v = float4(uint32_to_single_uniform(x.v.x),
                 uint32_to_single_uniform(x.v.y),
                 uint32_to_single_uniform(x.v.z),
                 uint32_to_single_uniform(x.v.w));
  return out;
}
216+
217+
/* Expand one 128-bit random block into two uniform floats in [0, 1)
   (this fills the "double" slot — Metal has no double precision).
   Only the top 24 bits of each combined 64-bit word are used: converting
   the full 64-bit value to float can round up to 2^64, making the scaled
   result exactly 1.0 and violating the half-open interval.  k * 2^-24 for
   k < 2^24 is exact (24 = float significand width) and strictly below 1. */
float2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
  float2_t result;
  uint64_t combined1 = (uint64_t(x.v.y) << 32) | x.v.x;
  uint64_t combined2 = (uint64_t(x.v.w) << 32) | x.v.z;
  /* 16777216.0f == 2^24; (combined >> 40) keeps the top 24 random bits. */
  result.v.x = float(combined1 >> 40) * (1.0f / 16777216.0f);
  result.v.y = float(combined2 >> 40) * (1.0f / 16777216.0f);
  return result;
}
226+
227+
/* Reinterpret the four 32-bit lanes as signed int32s (full range);
   the int4 conversion wraps modulo 2^32, i.e. keeps the bit pattern. */
int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
  int32x4_t out = { int4(x.v) };
  return out;
}
233+
234+
/* Combine pairs of 32-bit lanes into two full-range int64s. */
int64x2_t uint4x32_to_i64_uniform_vec(uint4x32_t x) {
  int64x2_t out;
  for (int i = 0; i < 2; i++) {
    /* Component 2i is the low word, component 2i+1 the high word. */
    out.v[i] = (int64_t(x.v[2 * i + 1]) << 32) | x.v[2 * i];
  }
  return out;
}
241+
242+
243+
/* Split the 128 random bits into sixteen full-range signed bytes. */
int8x16_t uint4x32_to_i8_uniform_vec(uint4x32_t x) {
  int8x16_t out;
  for (int lane = 0; lane < 4; lane++) {
    uint32_t w = x.v[lane];
    for (int b = 0; b < 4; b++) {
      /* Byte b of lane `lane`, least significant first. */
      out.v[lane * 4 + b] = int8_t((w >> (8 * b)) & 0xFF);
    }
  }
  return out;
}
256+
257+
/* Split the 128 random bits into eight full-range uint16s. */
uint16x8_t uint4x32_to_u16_uniform_vec(uint4x32_t x) {
  uint16x8_t out;
  for (int lane = 0; lane < 4; lane++) {
    uint32_t w = x.v[lane];
    out.v[lane * 2]     = uint16_t(w & 0xFFFF); /* low half  */
    out.v[lane * 2 + 1] = uint16_t(w >> 16);    /* high half */
  }
  return out;
}
268+
269+
/* Expand one 128-bit random block into eight bfloat16 samples in [0, 1),
   returned as raw uint16 bit patterns (bfloat16 is the top half of an
   IEEE-754 float, so truncating float bits keeps the value below 1). */
uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
  uint16x8_t out;
  for (int lane = 0; lane < 4; lane++) {
    uint32_t w = x.v[lane];
    /* Two 16-bit draws per lane, each mapped to [0, 1) as a float. */
    float lo = float(w & 0xFFFF) * (1.0f / 65536.0f);
    float hi = float((w >> 16) & 0xFFFF) * (1.0f / 65536.0f);
    out.v[lane * 2]     = uint16_t(as_type<uint32_t>(lo) >> 16);
    out.v[lane * 2 + 1] = uint16_t(as_type<uint32_t>(hi) >> 16);
  }
  return out;
}
282+
283+
/* Expand one 128-bit random block into eight uniform halfs in [0, 1).
   Each sample consumes 11 random bits: scaling a 16-bit draw by 2^-16
   produces floats as large as 0.9999847, which round to exactly 1.0h when
   converted to half (half spacing below 1 is 2^-11), violating the
   half-open interval.  k * 2^-11 for k in [0, 2047] is exactly
   representable in half (11 = half significand width) and always < 1. */
half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
  half8_t result;
  for (int i = 0; i < 4; i++) {
    uint32_t val = x.v[i];
    /* 2048.0f == 2^11; draw 11 bits from each 16-bit half of the lane. */
    result.v[i*2 + 0] = half(float(val & 0x7FF) * (1.0f / 2048.0f));
    result.v[i*2 + 1] = half(float((val >> 16) & 0x7FF) * (1.0f / 2048.0f));
  }
  return result;
}
296+
297+
/* Split the 128 random bits into sixteen full-range uint8s (used as fp8
   bit patterns). */
uint8x16_t uint4x32_to_u8_uniform_vec(uint4x32_t x) {
  uint8x16_t out;
  for (int lane = 0; lane < 4; lane++) {
    uint32_t w = x.v[lane];
    for (int b = 0; b < 4; b++) {
      /* Byte b of lane `lane`, least significant first. */
      out.v[lane * 4 + b] = uint8_t((w >> (8 * b)) & 0xFF);
    }
  }
  return out;
}

0 commit comments

Comments
 (0)