Skip to content

Commit 57d8c49

Browse files
committed
The previously-mocked support for half precision (but missing Ctypes)
Currently broken because of missing Ctypes coverage.
1 parent 1f4a416 commit 57d8c49

File tree

15 files changed

+72
-37
lines changed

15 files changed

+72
-37
lines changed

CHANGES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
## [0.4.1] -- current
22

3+
### Added
4+
5+
- The previously-mocked support for half precision.
6+
- Currently broken because of missing Ctypes coverage.
7+
38
### Changed
49

510
- Removed the `pipes_cc, pipes_gccjit` backends (`Pipes_multicore_backend`) -- I had fixed `Pipes_multicore_backend` by using the `poll` library instead of `Unix.select`, but it turns out to be very very slow.

arrayjit.opam

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ homepage: "https://github.com/lukstafi/ocannl"
1313
doc: "https://github.com/lukstafi/ocannl/blob/master/README.md"
1414
bug-reports: "https://github.com/lukstafi/ocannl/issues"
1515
depends: [
16-
"ocaml" {>= "5.1.0"}
16+
"ocaml" {>= "5.2.0"}
1717
"dune" {>= "3.11"}
1818
"base"
1919
"core"
20-
"ctypes" {>= "0.20"}
21-
"ctypes-foreign" {>= "0.20"}
20+
"ctypes" {>= "0.23"}
21+
"ctypes-foreign" {>= "0.23"}
2222
"printbox"
2323
"printbox-text"
2424
"ocannl_npy"

arrayjit/lib/assignments.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ let fprint_hum ?name ?static_indices () ppf c =
301301
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
302302
fprintf ppf "%s.merge" (ident tn)
303303
| Imported (Ops.External_unsafe { ptr; prec; dims = _ }) ->
304-
fprintf ppf "%s" @@ Ops.ptr_to_string ptr prec
304+
fprintf ppf "%s" @@ Ops.ptr_to_string_hum ptr prec
305305
| Slice { batch_idx; sliced } ->
306306
fprintf ppf "%s @@| %s" (ident sliced) (Indexing.symbol_ident batch_idx.static_symbol)
307307
| Embed_symbol { static_symbol; static_range = _ } ->

arrayjit/lib/backend_utils.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ module C_syntax (B : sig
4343
val logs_to_stdout : bool
4444
val main_kernel_prefix : string
4545
val kernel_prep_line : string
46+
val extra_include_lines : string list
4647
end) =
4748
struct
4849
open Types
@@ -89,7 +90,8 @@ struct
8990
let open Stdlib.Format in
9091
let is_global = Hash_set.create (module Tn) in
9192
fprintf ppf
92-
{|@[<v 0>#include <stdio.h>@,#include <stdlib.h>@,#include <string.h>@,/* Global declarations. */@,|};
93+
{|@[<v 0>#include <stdio.h>@,#include <stdlib.h>@,#include <string.h>%a@,/* Global declarations. */@,|}
94+
(pp_print_list pp_print_string) B.extra_include_lines;
9395
Array.iter B.for_lowereds ~f:(fun l ->
9496
Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
9597
if not @@ Hash_set.mem is_global node.tn then
@@ -114,7 +116,7 @@ struct
114116
| _ -> ()));
115117
fprintf ppf "@,@]";
116118
is_global
117-
119+
118120
let compile_main ~traced_store ppf llc : unit =
119121
let open Stdlib.Format in
120122
let visited = Hash_set.create (module Tn) in

arrayjit/lib/cc_backend.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
147147
let logs_to_stdout = false
148148
let main_kernel_prefix = ""
149149
let kernel_prep_line = ""
150+
let extra_include_lines = []
150151
end) in
151152
(* FIXME: do we really want all of them, or only the used ones? *)
152153
let idx_params = Indexing.bound_symbols bindings in
@@ -189,6 +190,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
189190
let logs_to_stdout = false
190191
let main_kernel_prefix = ""
191192
let kernel_prep_line = ""
193+
let extra_include_lines = []
192194
end) in
193195
(* FIXME: do we really want all of them, or only the used ones? *)
194196
let idx_params = Indexing.bound_symbols bindings in

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,8 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
324324

325325
let kernel_prep_line =
326326
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"
327+
328+
let extra_include_lines = [ "#include <cuda_fp16.h>" ]
327329
end) in
328330
let idx_params = Indexing.bound_symbols bindings in
329331
let b = Buffer.create 4096 in
@@ -351,6 +353,8 @@ let compile_batch ~names bindings lowereds =
351353

352354
let kernel_prep_line =
353355
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"
356+
357+
let extra_include_lines = [ "#include <cuda_fp16.h>" ]
354358
end) in
355359
let idx_params = Indexing.bound_symbols bindings in
356360
let b = Buffer.create 4096 in
@@ -430,8 +434,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~glo
430434
[%log_block
431435
context.label;
432436
Utils.log_trace_tree _output]);
433-
(* if Utils.debug_log_from_routines () then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE
434-
4096; *)
437+
(* if Utils.debug_log_from_routines () then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE 4096; *)
435438
Cu.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.device.stream
436439
args;
437440
[%log "kernel launched"]

arrayjit/lib/low_level.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,9 +872,9 @@ let fprint_hum ?name ?static_indices () ppf llc =
872872
| Get_global (Ops.C_function s, None) -> fprintf ppf "%s()" s
873873
| Get_global (Ops.C_function s, Some idcs) -> fprintf ppf "%s(%a)" s pp_indices idcs
874874
| Get_global (Ops.External_unsafe { ptr; prec; dims = _ }, None) ->
875-
fprintf ppf "%s" @@ Ops.ptr_to_string ptr prec
875+
fprintf ppf "%s" @@ Ops.ptr_to_string_hum ptr prec
876876
| Get_global (Ops.External_unsafe { ptr; prec; dims = _ }, Some idcs) ->
877-
fprintf ppf "%s[%a]" (Ops.ptr_to_string ptr prec) pp_indices idcs
877+
fprintf ppf "%s[%a]" (Ops.ptr_to_string_hum ptr prec) pp_indices idcs
878878
| Get_global (Ops.Merge_buffer { source_node_id }, None) ->
879879
let tn = Option.value_exn ~here:[%here] @@ Tnode.find ~id:source_node_id in
880880
fprintf ppf "%a.merge" pp_ident tn

arrayjit/lib/ndarray.ml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ let precision_to_bigarray_kind (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.pre
4949
(ocaml, elt_t) Bigarray.kind =
5050
match prec with
5151
| Byte -> Bigarray.Char
52-
| Half -> Bigarray.Float32
52+
| Half -> Bigarray.Float16
5353
| Single -> Bigarray.Float32
5454
| Double -> Bigarray.Float64
5555

@@ -368,7 +368,12 @@ let retrieve_flat_values arr =
368368

369369
let c_ptr_to_string nd =
370370
let prec = get_prec nd in
371-
let f arr = Ops.ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
371+
let f arr = Ops.c_ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
372+
map { f } nd
373+
374+
let ptr_to_string_hum nd =
375+
let prec = get_prec nd in
376+
let f arr = Ops.ptr_to_string_hum (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
372377
map { f } nd
373378

374379
(** {2 *** Creating ***} *)
@@ -379,8 +384,8 @@ let create_array ~debug:_debug prec ~dims init_op =
379384
[%debug2_sexp
380385
[%log_block
381386
"create_array";
382-
[%log _debug, c_ptr_to_string result]]];
383-
let%debug2_sexp debug_finalizer _result = [%log "Deleting", _debug, c_ptr_to_string _result] in
387+
[%log _debug, ptr_to_string_hum result]]];
388+
let%debug2_sexp debug_finalizer _result = [%log "Deleting", _debug, ptr_to_string_hum _result] in
384389
if Utils.settings.log_level > 1 then Stdlib.Gc.finalise debug_finalizer result;
385390
result
386391

arrayjit/lib/ops.ml

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,10 @@ module Lazy = Utils.Lazy
66
(** {2 *** Precision ***} *)
77

88
type uint8_elt = Bigarray.int8_unsigned_elt
9-
10-
(* FIXME: Upcoming in OCaml 5.2.0. See:
11-
https://github.com/ocaml/ocaml/pull/10775/commits/ba6a2c378056c8669fb1bb99bf07b12d69bd4a12 *)
12-
type float16_elt = Bigarray.float32_elt
9+
type float16_elt = Bigarray.float16_elt
1310
type float32_elt = Bigarray.float32_elt
1411
type float64_elt = Bigarray.float64_elt
1512

16-
let float16 : (float, float16_elt) Bigarray.kind = Bigarray.float32
17-
1813
type ('ocaml, 'impl) precision =
1914
| Byte : (char, uint8_elt) precision
2015
| Half : (float, float16_elt) precision
@@ -97,15 +92,28 @@ let map_prec ?default { f } = function
9792
| Void_prec ->
9893
Option.value_or_thunk default ~default:(fun () -> invalid_arg "map_prec: Void_prec")
9994
| Byte_prec Byte -> f Byte
100-
| Half_prec (Half | Single) -> f Half
101-
| Single_prec (Single | Half) -> f Single
95+
| Half_prec Half -> f Half
96+
| Single_prec Single -> f Single
10297
| Double_prec Double -> f Double
10398
| _ -> .
10499

105100
let cuda_typ_of_prec = function
106101
| Byte_prec _ -> "unsigned char"
107-
(* TODO: or should it be uint8, or uint8_t? *)
108-
| Half_prec _ -> (* FIXME: *) "float"
102+
| Half_prec _ -> "__half"
103+
| Single_prec _ -> "float"
104+
| Double_prec _ -> "double"
105+
| Void_prec -> "void"
106+
107+
let c_typ_of_prec = function
108+
| Byte_prec _ -> "unsigned char"
109+
| Half_prec _ -> "_Float16"
110+
| Single_prec _ -> "float"
111+
| Double_prec _ -> "double"
112+
| Void_prec -> "void"
113+
114+
let hum_typ_of_prec = function
115+
| Byte_prec _ -> "byte"
116+
| Half_prec _ -> "half"
109117
| Single_prec _ -> "float"
110118
| Double_prec _ -> "double"
111119
| Void_prec -> "void"
@@ -220,10 +228,18 @@ let sexp_of_voidptr p = Sexp.Atom Ctypes.(string_of (ptr void) p)
220228
let compare_voidptr = Ctypes.ptr_compare
221229
let equal_voidptr : voidptr -> voidptr -> bool = phys_equal
222230

223-
let ptr_to_string (type elem) (ptr : elem Ctypes.ptr) prec =
231+
let cuda_ptr_to_string (type elem) (ptr : elem Ctypes.ptr) prec =
224232
"(" ^ cuda_typ_of_prec prec ^ "*)"
225233
^ Nativeint.Hex.to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr)
226234

235+
let c_ptr_to_string (type elem) (ptr : elem Ctypes.ptr) prec =
236+
"(" ^ c_typ_of_prec prec ^ "*)"
237+
^ Nativeint.Hex.to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr)
238+
239+
let ptr_to_string_hum (type elem) (ptr : elem Ctypes.ptr) prec =
240+
"(" ^ hum_typ_of_prec prec ^ "*)"
241+
^ Nativeint.Hex.to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr)
242+
227243
type global_identifier =
228244
| C_function of string (** Calls a no-argument or indices-arguments C function. *)
229245
| External_unsafe of {

arrayjit/lib/tnode.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ let header tn =
295295
| (lazy None) -> "<not-hosted>"
296296
| (lazy (Some nd)) ->
297297
let size = Int.to_string_hum @@ Nd.size_in_bytes nd in
298-
if Utils.settings.log_level > 0 then size ^ " @ " ^ Nd.c_ptr_to_string nd else size
298+
if Utils.settings.log_level > 0 then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size
299299
else "<not-in-yet>"
300300
in
301301
let repeating_nograd_idents = Hashtbl.create ~size:1 (module String) in

0 commit comments

Comments
 (0)