Skip to content

Commit 5919417

Browse files
committed
Embed non-linked builtins in OCaml strings for easier availability at runtime
Note: Metal backend still broken
1 parent c0edd14 commit 5919417

File tree

6 files changed, with 15 additions (+15) and 20 deletions (-20).

arrayjit/lib/builtins_large.cu renamed to arrayjit/lib/builtins_cuda_large.ml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
let source = {|
12
#include <cuda_runtime.h>
23
#include <stdint.h>
34

@@ -107,4 +108,6 @@ __device__ uint4x32_t arrayjit_threefry4x32_impl(uint4x32_t key, uint4x32_t coun
107108
return result;
108109
}
109110

110-
__device__ uint4x32_t ( *arrayjit_threefry4x32)(uint4x32_t key, uint4x32_t counter) = arrayjit_threefry4x32_impl;
111+
__device__ uint4x32_t ( *arrayjit_threefry4x32)(uint4x32_t key, uint4x32_t counter) = arrayjit_threefry4x32_impl;
112+
113+
|}

arrayjit/lib/builtins_small.cu renamed to arrayjit/lib/builtins_cuda_small.ml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
1+
let source = {|
22
typedef struct {
33
uint32_t v[4];
44
} uint4x32_t;
@@ -181,4 +181,5 @@ __device__ uint8x16_t uint4x32_to_u8_uniform_vec(uint4x32_t x) {
181181
result.v[i*4 + 3] = (uint8_t)((x.v[i] >> 24) & 0xFF);
182182
}
183183
return result;
184-
}
184+
}
185+
|}

arrayjit/lib/builtins.metal renamed to arrayjit/lib/builtins_metal.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
let source = {|
12
#include <metal_stdlib>
23
using namespace metal;
34

@@ -367,4 +368,5 @@ uint4 half_to_uint4x32(uint16_t x) {
367368

368369
uint4 fp8_to_uint4x32(uint8_t x) {
369370
return uint4(uint32_t(x), 0, 0, 0);
370-
}
371+
}
372+
|}

arrayjit/lib/cuda_backend.ml

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -709,11 +709,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
709709
if Utils.debug_log_from_routines () then
710710
Buffer.add_string b "__device__ int printf (const char * format, ... );\n";
711711
Buffer.add_string b "\n\n";
712-
let builtins_path =
713-
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_small.cu"
714-
in
715-
let builtins_content = Stdio.In_channel.read_all builtins_path in
716-
Buffer.add_string b builtins_content;
712+
Buffer.add_string b Builtins_cuda_small.source;
717713
(* Needs to be after the small builtins, because uses uint4x32_t. *)
718714
Buffer.add_string b builtins_large_header;
719715
Buffer.add_string b "\n\n"

arrayjit/lib/dune

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@
6262
(name cuda_backend)
6363
(public_name arrayjit.cuda_backend)
6464
(optional)
65-
(modules cuda_backend)
65+
(modules cuda_backend builtins_cuda_large builtins_cuda_small)
6666
(libraries base cudajit.cuda cudajit.nvrtc utils ir)
6767
(preprocess
6868
(pps
@@ -78,7 +78,7 @@
7878
(name metal_backend)
7979
(public_name arrayjit.metal_backend)
8080
(optional)
81-
(modules metal_backend)
81+
(modules metal_backend builtins_metal)
8282
(libraries base metal utils ir)
8383
(preprocess
8484
(pps

arrayjit/lib/metal_backend.ml

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -639,21 +639,14 @@ end) : Ir.Backend_impl.Lowered_backend = struct
639639
Stdio.prerr_endline error_msg;
640640
failwith error_msg
641641

642-
let prepend_builtins b =
643-
let builtins_path =
644-
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins.metal"
645-
in
646-
let builtins_content = Stdio.In_channel.read_all builtins_path in
647-
Buffer.add_string b builtins_content;
648-
Buffer.add_string b "\n\n"
649-
650642
let compile ~name bindings lowered =
651643
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
652644
let procs = [| lowered |]
653645
end)) in
654646
let idx_params = Indexing.bound_symbols bindings in
655647
let b = Buffer.create 4096 in
656-
prepend_builtins b;
648+
Buffer.add_string b Builtins_metal.source;
649+
Buffer.add_string b "\n";
657650
let declarations_doc = Syntax.print_declarations () in
658651
(* Add Metal address space qualifiers *)
659652
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in

0 commit comments

Comments
 (0)