Skip to content

Commit 582a71f

Browse files
committed
Proper support for half precision, don't use Ctypes.bigarray_start
1 parent 7661321 commit 582a71f

File tree

9 files changed

+128
-128
lines changed

9 files changed

+128
-128
lines changed

CHANGES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
### Added
44

55
- The previously-mocked support for half precision.
6-
- Currently broken because of missing Ctypes coverage.
6+
- We work around the missing Ctypes coverage by not using `Ctypes.bigarray_start`.
77

88
### Changed
99

arrayjit/lib/backend_utils.ml

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ module C_syntax (B : sig
4444
val main_kernel_prefix : string
4545
val kernel_prep_line : string
4646
val extra_include_lines : string list
47+
val typ_of_prec : Ops.prec -> string
4748
end) =
4849
struct
4950
open Types
@@ -150,7 +151,7 @@ struct
150151
let loop_f = pp_float tn.prec in
151152
let loop_debug_f = debug_float tn.prec in
152153
let num_closing_braces = pp_top_locals ppf llv in
153-
let num_typ = Ops.cuda_typ_of_prec tn.prec in
154+
let num_typ = B.typ_of_prec tn.prec in
154155
if Utils.debug_log_from_routines () then (
155156
fprintf ppf "@[<2>{@ @[<2>%s new_set_v =@ %a;@]@ " num_typ loop_f llv;
156157
let v_code, v_idcs = loop_debug_f llv in
@@ -207,7 +208,7 @@ struct
207208
and pp_top_locals ppf (vcomp : Low_level.float_t) : int =
208209
match vcomp with
209210
| Local_scope { id = { scope_id = i; tn = { prec; _ } }; body; orig_indices = _ } ->
210-
let num_typ = Ops.cuda_typ_of_prec prec in
211+
let num_typ = B.typ_of_prec prec in
211212
(* Arrays are initialized to 0 by default. However, there is typically an explicit
212213
initialization for virtual nodes. *)
213214
fprintf ppf "@[<2>{@ %s v%d = 0;@ " num_typ i;
@@ -220,19 +221,19 @@ struct
220221
| Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
221222
| Unop (_, v) -> pp_top_locals ppf v
222223
and pp_float prec ppf value =
223-
let num_typ = Ops.cuda_typ_of_prec prec in
224+
let num_typ = B.typ_of_prec prec in
224225
let loop = pp_float prec in
225226
match value with
226227
| Local_scope { id; _ } ->
227228
(* Embedding of Local_scope is done by pp_top_locals. *)
228229
loop ppf @@ Get_local id
229230
| Get_local id ->
230-
let get_typ = Ops.cuda_typ_of_prec id.tn.prec in
231+
let get_typ = B.typ_of_prec id.tn.prec in
231232
if not @@ String.equal num_typ get_typ then fprintf ppf "(%s)" num_typ;
232233
fprintf ppf "v%d" id.scope_id
233234
| Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
234235
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
235-
fprintf ppf "@[<2>((%s*)merge_buffer)[%a@;<0 -2>]@]" (Ops.cuda_typ_of_prec prec)
236+
fprintf ppf "@[<2>((%s*)merge_buffer)[%a@;<0 -2>]@]" (B.typ_of_prec prec)
236237
pp_array_offset
237238
(idcs, Lazy.force tn.dims)
238239
| Get_global _ -> failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
@@ -255,15 +256,15 @@ struct
255256
(* FIXME: don't recompute v *)
256257
fprintf ppf "@[<1>(%a > 0.0 ?@ %a : 0.0@;<0 -1>)@]" loop v loop v
257258
and debug_float prec (value : Low_level.float_t) : string * 'a list =
258-
let num_typ = Ops.cuda_typ_of_prec prec in
259+
let num_typ = B.typ_of_prec prec in
259260
let loop = debug_float prec in
260261
match value with
261262
| Local_scope { id; _ } ->
262263
(* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug
263264
logs. *)
264265
loop @@ Get_local id
265266
| Get_local id ->
266-
let get_typ = Ops.cuda_typ_of_prec id.tn.prec in
267+
let get_typ = B.typ_of_prec id.tn.prec in
267268
let v =
268269
(if not @@ String.equal num_typ get_typ then "(" ^ num_typ ^ ")" else "")
269270
^ "v" ^ Int.to_string id.scope_id
@@ -315,7 +316,7 @@ struct
315316
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
316317
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
317318
if B.is_in_context node && not (Hash_set.mem is_global tn) then
318-
(Ops.cuda_typ_of_prec tn.Tn.prec ^ " *" ^ get_ident tn, Param_ptr tn) :: params
319+
(B.typ_of_prec tn.Tn.prec ^ " *" ^ get_ident tn, Param_ptr tn) :: params
319320
else params)
320321
in
321322
let idx_params =
@@ -333,7 +334,7 @@ struct
333334
Option.(
334335
to_list
335336
@@ map merge_node ~f:(fun tn ->
336-
("const " ^ Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer", Merge_buffer)))
337+
("const " ^ B.typ_of_prec tn.prec ^ " *merge_buffer", Merge_buffer)))
337338
in
338339
let params = log_file @ merge_param @ idx_params @ params in
339340
let params =
@@ -382,7 +383,7 @@ struct
382383
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
383384
if not (Tn.is_virtual_force tn 333 || B.is_in_context node || Hash_set.mem is_global tn)
384385
then
385-
fprintf ppf "%s %s[%d]%s;@ " (Ops.cuda_typ_of_prec tn.prec) (get_ident tn)
386+
fprintf ppf "%s %s[%d]%s;@ " (B.typ_of_prec tn.prec) (get_ident tn)
386387
(Tn.num_elems tn)
387388
(if node.zero_initialized then " = {0}" else "")
388389
else if (not (Tn.is_virtual_force tn 333)) && node.zero_initialized then pp_zero_out ppf tn);

arrayjit/lib/cc_backend.ml

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,6 @@ let ctx_arrays context = context.arrays
2626

2727
type buffer_ptr = ctx_array [@@deriving sexp_of]
2828

29-
(** Alternative approach:
30-
31-
{[
32-
type buffer_ptr = unit Ctypes_static.ptr
33-
34-
let sexp_of_buffer_ptr ptr = Sexp.Atom (Ops.ptr_to_string ptr Ops.Void_prec)
35-
let buffer_ptr ctx_array = Ndarray.get_voidptr ctx_array
36-
]} *)
37-
3829
let buffer_ptr ctx_array = ctx_array
3930

4031
let alloc_buffer ?old_buffer ~size_in_bytes () =
@@ -112,18 +103,38 @@ let c_compile_and_load ~f_name =
112103
while rc = 0 && (not @@ (Stdlib.Sys.file_exists libname && Stdlib.Sys.file_exists log_fname)) do
113104
Unix.sleepf 0.001
114105
done;
115-
(if rc <> 0 then
116-
let errors =
117-
"Cc_backend.c_compile_and_load: compilation failed with errors:\n"
118-
^ Stdio.In_channel.read_all log_fname
119-
in
120-
Stdio.prerr_endline errors;
121-
invalid_arg errors);
106+
if rc <> 0 then (
107+
let errors =
108+
"Cc_backend.c_compile_and_load: compilation failed with errors:\n"
109+
^ Stdio.In_channel.read_all log_fname
110+
in
111+
Stdio.prerr_endline errors;
112+
invalid_arg errors);
122113
(* Note: RTLD_DEEPBIND not available on MacOS. *)
123114
let result = { lib = Dl.dlopen ~filename:libname ~flags:[ RTLD_NOW ]; libname } in
124115
Stdlib.Gc.finalise (fun lib -> Dl.dlclose ~handle:lib.lib) result;
125116
result
126117

118+
module C_syntax_config (Input : sig
119+
val for_lowereds : Low_level.optimized array
120+
val opt_ctx_arrays : (Tn.t, buffer_ptr, Tn.comparator_witness) Base.Map.t option
121+
end) =
122+
struct
123+
let for_lowereds = Input.for_lowereds
124+
125+
type nonrec ctx_array = ctx_array
126+
127+
let opt_ctx_arrays = Input.opt_ctx_arrays
128+
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
129+
let is_in_context = is_in_context
130+
let host_ptrs_for_readonly = true
131+
let logs_to_stdout = false
132+
let main_kernel_prefix = ""
133+
let kernel_prep_line = ""
134+
let extra_include_lines = []
135+
let typ_of_prec = Ops.c_typ_of_prec
136+
end
137+
127138
let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) =
128139
let opt_ctx_arrays =
129140
Option.map opt_ctx_arrays ~f:(fun ctx_arrays ->
@@ -140,20 +151,10 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
140151
else ctx_arrays
141152
| Some _ -> ctx_arrays))
142153
in
143-
let module Syntax = Backend_utils.C_syntax (struct
154+
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
144155
let for_lowereds = [| lowered |]
145-
146-
type nonrec ctx_array = ctx_array
147-
148156
let opt_ctx_arrays = opt_ctx_arrays
149-
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
150-
let is_in_context = is_in_context
151-
let host_ptrs_for_readonly = true
152-
let logs_to_stdout = false
153-
let main_kernel_prefix = ""
154-
let kernel_prep_line = ""
155-
let extra_include_lines = []
156-
end) in
157+
end)) in
157158
(* FIXME: do we really want all of them, or only the used ones? *)
158159
let idx_params = Indexing.bound_symbols bindings in
159160
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
@@ -183,20 +184,10 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
183184
else ctx_arrays
184185
| Some _ -> ctx_arrays)))
185186
in
186-
let module Syntax = Backend_utils.C_syntax (struct
187+
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
187188
let for_lowereds = for_lowereds
188-
189-
type nonrec ctx_array = ctx_array
190-
191189
let opt_ctx_arrays = opt_ctx_arrays
192-
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
193-
let is_in_context = is_in_context
194-
let host_ptrs_for_readonly = true
195-
let logs_to_stdout = false
196-
let main_kernel_prefix = ""
197-
let kernel_prep_line = ""
198-
let extra_include_lines = []
199-
end) in
190+
end)) in
200191
(* FIXME: do we really want all of them, or only the used ones? *)
201192
let idx_params = Indexing.bound_symbols bindings in
202193
let global_ctx_arrays =
@@ -270,13 +261,11 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
270261
| bs, Log_file_name :: ps ->
271262
Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs))
272263
| bs, Merge_buffer :: ps ->
273-
let get_ptr (buffer, _) = Ndarray.get_voidptr buffer in
264+
let get_ptr (buffer, _) = Ndarray.get_voidptr_not_managed buffer in
274265
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
275266
| bs, Param_ptr tn :: ps ->
276267
let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in
277-
(* let f ba = Ctypes.bigarray_start Ctypes_static.Genarray ba in let c_ptr =
278-
Ndarray.(map { f } nd) in *)
279-
let c_ptr = Ndarray.get_voidptr nd in
268+
let c_ptr = Ndarray.get_voidptr_not_managed nd in
280269
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
281270
in
282271
(* Reverse the input order because [Indexing.apply] will reverse it again. Important:

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -307,26 +307,34 @@ let is_in_context node =
307307
Tnode.default_to_most_local node.Low_level.tn 33;
308308
match node.tn.memory_mode with Some ((Virtual | Local), _) -> false | _ -> true
309309

310-
let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
311-
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
312-
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
313-
let module Syntax = Backend_utils.C_syntax (struct
314-
let for_lowereds = [| lowered |]
310+
module C_syntax_config (Input : sig
311+
val for_lowereds : Low_level.optimized array
312+
end) =
313+
struct
314+
let for_lowereds = Input.for_lowereds
315315

316-
type nonrec ctx_array = buffer_ptr
316+
type nonrec ctx_array = buffer_ptr
317317

318-
let opt_ctx_arrays = None
319-
let hardcoded_context_ptr = None
320-
let is_in_context = is_in_context
321-
let host_ptrs_for_readonly = true
322-
let logs_to_stdout = true
323-
let main_kernel_prefix = "extern \"C\" __global__"
318+
let opt_ctx_arrays = None
319+
let hardcoded_context_ptr = None
320+
let is_in_context = is_in_context
321+
let host_ptrs_for_readonly = true
322+
let logs_to_stdout = true
323+
let main_kernel_prefix = "extern \"C\" __global__"
324324

325-
let kernel_prep_line =
326-
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"
325+
let kernel_prep_line =
326+
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"
327327

328-
let extra_include_lines = [ "#include <cuda_fp16.h>" ]
329-
end) in
328+
let extra_include_lines = [ "#include <cuda_fp16.h>" ]
329+
let typ_of_prec = Ops.cuda_typ_of_prec
330+
end
331+
332+
let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
333+
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
334+
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
335+
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
336+
let for_lowereds = [| lowered |]
337+
end)) in
330338
let idx_params = Indexing.bound_symbols bindings in
331339
let b = Buffer.create 4096 in
332340
let ppf = Stdlib.Format.formatter_of_buffer b in
@@ -339,23 +347,9 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
339347

340348
let compile_batch ~names bindings lowereds =
341349
let for_lowereds = Array.filter_map ~f:Fn.id lowereds in
342-
let module Syntax = Backend_utils.C_syntax (struct
350+
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
343351
let for_lowereds = for_lowereds
344-
345-
type nonrec ctx_array = buffer_ptr
346-
347-
let opt_ctx_arrays = None
348-
let hardcoded_context_ptr = None
349-
let is_in_context = is_in_context
350-
let host_ptrs_for_readonly = true
351-
let logs_to_stdout = true
352-
let main_kernel_prefix = "extern \"C\" __global__"
353-
354-
let kernel_prep_line =
355-
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"
356-
357-
let extra_include_lines = [ "#include <cuda_fp16.h>" ]
358-
end) in
352+
end)) in
359353
let idx_params = Indexing.bound_symbols bindings in
360354
let b = Buffer.create 4096 in
361355
let ppf = Stdlib.Format.formatter_of_buffer b in

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type buffer_ptr = ctx_array [@@deriving sexp_of]
3535
type buffer_ptr = unit Ctypes_static.ptr
3636
3737
let sexp_of_buffer_ptr ptr = Sexp.Atom (Ops.ptr_to_string ptr Ops.Void_prec)
38-
let buffer_ptr ctx_array = Ndarray.get_voidptr ctx_array
38+
let buffer_ptr ctx_array = Ndarray.get_voidptr_not_managed ctx_array
3939
]} *)
4040

4141
let buffer_ptr ctx_array = ctx_array
@@ -171,6 +171,7 @@ let zero_out ctx block node =
171171
]
172172

173173
let get_c_ptr ctx num_typ ba =
174+
(* FIXME(#284): half precision support breaks here. *)
174175
Gccjit.(RValue.ptr ctx (Type.pointer num_typ) @@ Ctypes.bigarray_start Ctypes_static.Genarray ba)
175176

176177
let prepare_node ~debug_log_zero_out ~get_ident ctx nodes traced_store ctx_nodes initializations
@@ -840,12 +841,10 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
840841
Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs))
841842
| bs, Param_ptr tn :: ps ->
842843
let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in
843-
(* let f ba = Ctypes.bigarray_start Ctypes_static.Genarray ba in let c_ptr =
844-
Ndarray.(map { f } nd) in *)
845-
let c_ptr = Ndarray.get_voidptr nd in
844+
let c_ptr = Ndarray.get_voidptr_not_managed nd in
846845
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
847846
| bs, Merge_buffer :: ps ->
848-
let get_ptr (buffer, _) = Ndarray.get_voidptr buffer in
847+
let get_ptr (buffer, _) = Ndarray.get_voidptr_not_managed buffer in
849848
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
850849
in
851850
(* Folding by [link] above reverses the input order. Important: [code.bindings] are traversed

arrayjit/lib/ndarray.ml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@ module A = Bigarray.Genarray
1414

1515
type ('ocaml, 'elt_t) bigarray = ('ocaml, 'elt_t, Bigarray.c_layout) A.t
1616

17-
let big_ptr_to_string arr =
18-
"@"
19-
^ Nativeint.Hex.to_string
20-
(Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp
21-
@@ Ctypes.bigarray_start Ctypes_static.Genarray arr)
17+
let bigarray_start_not_managed (arr : ('ocaml, 'elt_t) bigarray) =
18+
Ctypes_bigarray.unsafe_address arr
19+
20+
let big_ptr_to_string arr = "@" ^ Nativeint.Hex.to_string (bigarray_start_not_managed arr)
2221

2322
let sexp_of_bigarray (arr : ('a, 'b) bigarray) =
2423
let dims = A.dims arr in
@@ -179,14 +178,15 @@ let map2 { f2 } x1 x2 =
179178

180179
let dims = map { f = A.dims }
181180

182-
let get_voidptr =
181+
let get_voidptr_not_managed nd : unit Ctypes.ptr =
183182
let f arr =
184-
let open Ctypes in
185-
coerce
186-
(ptr @@ typ_of_bigarray_kind @@ Bigarray.Genarray.kind arr)
187-
(ptr void) (bigarray_start genarray arr)
183+
Ctypes_static.CPointer
184+
(Ctypes_memory.make_unmanaged ~reftyp:Ctypes_static.void @@ bigarray_start_not_managed arr)
185+
(* This doesn't work because Ctypes.bigarray_start doesn't support half precision: *)
186+
(* let open Ctypes in coerce (ptr @@ typ_of_bigarray_kind @@ Bigarray.Genarray.kind arr) (ptr
187+
void) (bigarray_start genarray arr) *)
188188
in
189-
map { f }
189+
map { f } nd
190190

191191
let set_from_float arr idx v =
192192
match arr with
@@ -368,12 +368,12 @@ let retrieve_flat_values arr =
368368

369369
let c_ptr_to_string nd =
370370
let prec = get_prec nd in
371-
let f arr = Ops.c_ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
371+
let f arr = Ops.c_rawptr_to_string (bigarray_start_not_managed arr) prec in
372372
map { f } nd
373373

374374
let ptr_to_string_hum nd =
375375
let prec = get_prec nd in
376-
let f arr = Ops.ptr_to_string_hum (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
376+
let f arr = Ops.rawptr_to_string_hum (bigarray_start_not_managed arr) prec in
377377
map { f } nd
378378

379379
(** {2 *** Creating ***} *)

0 commit comments

Comments
 (0)