Skip to content

Commit 186a2d3

Browse files
committed
Revert the #295 related changes; more debugging
1 parent 789d956 commit 186a2d3

File tree

6 files changed

+69
-76
lines changed

6 files changed

+69
-76
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,71 +37,52 @@ module No_device_buffer_and_copying () :
3737
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
3838
end)
3939

40-
(* let used_memory = Atomic.make 0 *)
40+
let used_memory = Atomic.make 0
41+
let get_used_memory () = Atomic.get used_memory
4142

42-
let get_used_memory () =
43-
(* FIXME(295): alloc_zero_init_array is now using Ndarray. *)
44-
(* Atomic.get used_memory *)
45-
Atomic.get Ndarray.used_memory
46-
47-
let global_arena = Hash_set.create (module Ndarray)
48-
49-
(* FIXME(295): in rare cases, this causes crashes. *)
50-
(* {[
51-
let alloc_impl ~size_in_bytes =
43+
let%track7_l_sexp alloc_impl ~(size_in_bytes : int) : buffer_ptr =
5244
let%track7_l_sexp finalize (_ptr : buffer_ptr) : unit =
5345
ignore (Atomic.fetch_and_add used_memory ~-size_in_bytes : int)
5446
in
5547
let ptr = Ctypes.(to_voidp @@ allocate_n int8_t ~count:size_in_bytes) in
5648
let _ : int = Atomic.fetch_and_add used_memory size_in_bytes in
5749
Stdlib.Gc.finalise finalize ptr;
5850
ptr
59-
]} *)
6051

61-
let alloc_zero_init_array prec ~dims () =
62-
(* FIXME(295): in rare cases, this causes crashes. *)
63-
(* {[
52+
let%track7_l_sexp alloc_zero_init_array (prec : Ops.prec) ~(dims : int array) (() : unit) :
53+
buffer_ptr =
6454
let size_in_bytes =
6555
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
6656
in
6757
alloc_impl ~size_in_bytes
68-
]} *)
69-
(* Alternative: *)
70-
(* let%track7_l_sexp finalize (nd : Ndarray.t) = Hash_set.remove global_arena nd in *)
71-
let nd =
72-
Ndarray.create_array ~debug:[%string "array_%{Hash_set.length global_arena#Int}"] prec ~dims
73-
Ops.(Constant_fill { values = [| 0.0 |]; strict = false })
74-
in
75-
Hash_set.add global_arena nd;
76-
(* Stdlib.Gc.finalise finalize nd; *)
77-
Ndarray.get_voidptr_not_managed nd
7858

79-
let alloc_buffer ?old_buffer ~size_in_bytes () =
59+
let%track7_l_sexp alloc_buffer ?(old_buffer : buffer_ptr Backend_intf.buffer option)
60+
~(size_in_bytes : int) (() : unit) : buffer =
8061
match old_buffer with
8162
| Some ({ size_in_bytes = old_size; _ } as buffer) when size_in_bytes <= old_size -> buffer
82-
| _ ->
83-
(* FIXME(295): in rare cases, this causes crashes. *)
84-
(* { ptr = alloc_impl ~size_in_bytes; size_in_bytes } *)
85-
(* Alternative: *)
86-
(* FIXME: This is not helping. *)
87-
let ptr = alloc_zero_init_array Ops.byte ~dims:[| size_in_bytes |] () in
88-
{ ptr; size_in_bytes }
63+
| _ -> { ptr = alloc_impl ~size_in_bytes; size_in_bytes }
8964

9065
let free_buffer = None
9166

67+
(** Alias for a [Ctypes_ptr.Fat.t] whose element type is [unit Ctypes_static.typ]; it is the
    argument type taken by [memcpy] below. *)
type void_buffer_ptr = (Stdlib.Obj.t option, unit Ctypes_static.typ) Ctypes_ptr.Fat.t
68+
69+
(** Renders the pointer [p] as a single S-expression atom, using the string produced by
    [Ctypes_value_printing_stubs.string_of_pointer].
    NOTE(review): presumably defined so that the [%track7_l_sexp] logging on [memcpy] can print
    its pointer arguments -- confirm. *)
let sexp_of_void_buffer_ptr (p : void_buffer_ptr) =
70+
Sexp.Atom (Ctypes_value_printing_stubs.string_of_pointer p)
71+
72+
(** [memcpy ~dst ~src ~size_in_bytes] copies [size_in_bytes] bytes from [src] to [dst] by calling
    [Ctypes_memory_stubs.memcpy]. When [dst] and [src] compare equal under [Ctypes_ptr.Fat.compare]
    the call is skipped and nothing is copied. Only exact pointer equality is tested here; partially
    overlapping ranges are not detected. *)
let%track7_l_sexp memcpy ~(dst : void_buffer_ptr) ~(src : void_buffer_ptr) ~(size_in_bytes : int)
73+
: unit =
74+
(* Guard against a self-copy: identical source and destination need no work. *)
if Ctypes_ptr.Fat.compare dst src <> 0 then
75+
Ctypes_memory_stubs.memcpy ~dst ~src ~size:size_in_bytes
76+
9277
let buffer_to_buffer ~dst:Ctypes_static.(CPointer dst) ~src:Ctypes_static.(CPointer src)
9378
~size_in_bytes =
94-
Ctypes_memory_stubs.memcpy ~dst ~src ~size:size_in_bytes
79+
memcpy ~dst ~src ~size_in_bytes
9580

9681
let host_to_buffer src ~dst:Ctypes_static.(CPointer dst) =
97-
Ctypes_memory_stubs.memcpy ~dst
98-
~src:(Ndarray.get_fatptr_not_managed src)
99-
~size:(Ndarray.size_in_bytes src)
82+
memcpy ~dst ~src:(Ndarray.get_fatptr_not_managed src) ~size_in_bytes:(Ndarray.size_in_bytes src)
10083

10184
let buffer_to_host dst ~src:Ctypes_static.(CPointer src) =
102-
Ctypes_memory_stubs.memcpy
103-
~dst:(Ndarray.get_fatptr_not_managed dst)
104-
~src ~size:(Ndarray.size_in_bytes dst)
85+
memcpy ~dst:(Ndarray.get_fatptr_not_managed dst) ~src ~size_in_bytes:(Ndarray.size_in_bytes dst)
10586
end
10687

10788
module Device_types (Device_config : Device_config) = struct

arrayjit/lib/c_syntax.ml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,20 +321,24 @@ struct
321321
| p_name, Merge_buffer ->
322322
if B.logs_to_stdout then
323323
fprintf ppf
324-
{|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)merge_buffer);@]@ |}
324+
{|@[<7>printf(@[<h>"%s%%d: %s &[%d] = %%p\n",@] log_id, (void*)merge_buffer);@]@ |}
325325
!Utils.captured_log_prefix p_name
326+
(Tnode.num_elems @@ Option.value_exn merge_node)
326327
else
327328
fprintf ppf
328-
{|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)merge_buffer);@]@ |} p_name
329+
{|@[<7>fprintf(log_file,@ @[<h>"%s &[%d] = %%p\n",@] (void*)merge_buffer);@]@ |}
330+
p_name
331+
(Tnode.num_elems @@ Option.value_exn merge_node)
329332
| _, Log_file_name -> ()
330333
| p_name, Param_ptr tn ->
331334
if B.logs_to_stdout then
332-
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)%s);@]@ |}
333-
!Utils.captured_log_prefix p_name
335+
fprintf ppf
336+
{|@[<7>printf(@[<h>"%s%%d: %s &[%d] = %%p\n",@] log_id, (void*)%s);@]@ |}
337+
!Utils.captured_log_prefix p_name (Tnode.num_elems tn)
334338
@@ get_ident tn
335339
else
336-
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)%s);@]@ |} p_name
337-
@@ get_ident tn
340+
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s &[%d] = %%p\n",@] (void*)%s);@]@ |}
341+
p_name (Tnode.num_elems tn) (get_ident tn)
338342
| p_name, Static_idx s ->
339343
if B.logs_to_stdout then
340344
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%d\n",@] log_id, %s);@]@ |}

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
200200
let same_device = dev.ordinal = src.stream.device.ordinal in
201201
let size_in_bytes = Tn.size_in_bytes tn in
202202
let memcpy ~dst_ptr =
203-
if same_device then
203+
if same_device && Cu.Deviceptr.equal dst_ptr src_ptr then ()
204+
else if same_device then
204205
Cu.Stream.memcpy_D_to_D ~size_in_bytes ~dst:dst_ptr ~src:src_ptr dst.stream.runner
205206
else
206207
Cu.Stream.memcpy_peer ~size_in_bytes ~dst:dst_ptr ~dst_ctx:(ctx_of dst) ~src:src_ptr

arrayjit/lib/ndarray.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,9 @@ let hash_t nd = Nativeint.hash @@ to_native nd
389389

390390
let used_memory = Atomic.make 0
391391

392-
let create_array ~debug:_debug prec ~dims init_op =
393-
let size_in_bytes =
392+
let%track7_l_sexp create_array ~debug:(_debug : string) (prec : Ops.prec) ~(dims : int array)
393+
init_op =
394+
let size_in_bytes : int =
394395
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
395396
in
396397
let%track7_l_sexp finalizer (_result : t) =

bin/dune

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
(executable
2121
(name hello_world_op)
2222
(modules hello_world_op)
23-
(libraries ocannl base stdio)
23+
(libraries ocannl base stdio ppx_minidebug.runtime)
2424
(preprocess
25-
(pps ppx_minidebug ppx_ocannl))
25+
(pps ppx_minidebug ppx_ocannl ppx_sexp_conv))
2626
(modes exe))
2727

2828
(executable

bin/hello_world_op.ml

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
1515
[%%global_debug_log_level 9]
1616
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
1717

18-
let setup () =
18+
let setup (() : unit) : unit =
1919
Arrayjit.Utils.settings.output_debug_files_in_build_directory <- true;
2020
Arrayjit.Utils.settings.debug_log_from_routines <- true
2121

22-
let%track2_sexp _Pointwise_multiplication_dims_1 () =
22+
let%track2_sexp _Pointwise_multiplication_dims_1 (() : unit) : unit =
2323
Tensor.unsafe_reinitialize ();
2424
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
2525
let backend =
@@ -33,11 +33,11 @@ let%track2_sexp _Pointwise_multiplication_dims_1 () =
3333
let ctx = Backend.make_context stream in
3434
Rand.init 0;
3535
(* "Hey" is inferred to be a scalar. *)
36-
let%op y = 2 *. "hey" 7.0 in
37-
Train.forward_and_forget backend ctx y;
38-
Tensor.print ~with_code:false ~with_grad:false `Default @@ y
36+
let%op ya = 2 *. "hey" 7.0 in
37+
Train.forward_and_forget backend ctx ya;
38+
Tensor.print ~with_code:false ~with_grad:false `Default @@ ya
3939

40-
let%track2_sexp _Matrix_multiplication_dims_1x1 () =
40+
let%track2_sexp _Matrix_multiplication_dims_1x1 (() : unit) : unit =
4141
Tensor.unsafe_reinitialize ();
4242
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
4343
let backend =
@@ -51,13 +51,13 @@ let%track2_sexp _Matrix_multiplication_dims_1x1 () =
5151
let ctx = Backend.make_context stream in
5252
Rand.init 0;
5353
(* Hey is inferred to be a matrix because of matrix multiplication [*]. *)
54-
let%op y = ("hey" 7.0 * 'q' 2.0) + 'p' 1.0 in
55-
Train.forward_and_forget backend ctx y;
54+
let%op yb = ("hey" 7.0 * 'q' 2.0) + 'p' 1.0 in
55+
Train.forward_and_forget backend ctx yb;
5656
(* Punning for ["hey"] above introduced the [hey] identifier. *)
5757
Tensor.print ~with_code:false ~with_grad:false `Default @@ hey;
58-
Tensor.print ~with_code:false ~with_grad:false `Default @@ y
58+
Tensor.print ~with_code:false ~with_grad:false `Default @@ yb
5959

60-
let%track2_sexp _Print_constant_tensor () =
60+
let%track2_sexp _Print_constant_tensor (() : unit) : unit =
6161
Tensor.unsafe_reinitialize ();
6262
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
6363
let backend =
@@ -153,7 +153,7 @@ let%track2_sexp _Print_constant_tensor () =
153153
Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo4;
154154
Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo4
155155

156-
let%track2_sexp _Matrix_multiplication_dims_2x3 () =
156+
let%track2_sexp _Matrix_multiplication_dims_2x3 (() : unit) : unit =
157157
Tensor.unsafe_reinitialize ();
158158
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
159159
let backend =
@@ -167,12 +167,12 @@ let%track2_sexp _Matrix_multiplication_dims_2x3 () =
167167
let ctx = Backend.make_context stream in
168168
Rand.init 0;
169169
(* Hey is inferred to be a matrix. *)
170-
let%op y = ("hey" 7.0 * [ 2; 3 ]) + [ 4; 5; 6 ] in
171-
Train.forward_and_forget backend ctx y;
170+
let%op yc = ("hey" 7.0 * [ 2; 3 ]) + [ 4; 5; 6 ] in
171+
Train.forward_and_forget backend ctx yc;
172172
Tensor.print ~with_code:false ~with_grad:false `Default @@ hey;
173-
Tensor.print ~with_code:false ~with_grad:false `Default @@ y
173+
Tensor.print ~with_code:false ~with_grad:false `Default @@ yc
174174

175-
let%track2_sexp _Big_matrix () =
175+
let%track2_sexp _Big_matrix (() : unit) : unit =
176176
Tensor.unsafe_reinitialize ();
177177
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
178178
let backend =
@@ -188,14 +188,14 @@ let%track2_sexp _Big_matrix () =
188188
(* Hey is inferred to be a matrix. *)
189189
let hey = Tensor.param ~values:[| 0.5 |] "hey" in
190190
let zero_to_twenty = TDSL.range 20 in
191-
let%op yb = (hey * zero_to_twenty) + zero_to_twenty in
192-
Train.forward_and_forget backend ctx yb;
191+
let%op yd = (hey * zero_to_twenty) + zero_to_twenty in
192+
Train.forward_and_forget backend ctx yd;
193193
Tensor.print ~with_code:false ~with_grad:false `Inline zero_to_twenty;
194194
Tensor.print ~with_code:false ~with_grad:false `Default zero_to_twenty;
195195
Tensor.print ~with_code:false ~with_grad:false `Default hey;
196-
Tensor.print ~with_code:false ~with_grad:false `Default yb
196+
Tensor.print ~with_code:false ~with_grad:false `Default yd
197197

198-
let%track2_sexp _Very_big_tensor () =
198+
let%track2_sexp _Very_big_tensor (() : unit) : unit =
199199
Tensor.unsafe_reinitialize ();
200200
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
201201
let backend =
@@ -209,23 +209,29 @@ let%track2_sexp _Very_big_tensor () =
209209
let ctx = Backend.make_context stream in
210210
Rand.init 0;
211211
let hey =
212-
TDSL.range_of_shape ~batch_dims:[ 6 ] ~input_dims:[ 7; 8; 9 ] ~output_dims:[ 10; 11 ] ()
212+
TDSL.range_of_shape ~batch_dims:[ 6 ] ~input_dims:[ 7; 8 ] ~output_dims:[ 9 ] ()
213213
in
214-
let%op hoo = (hey * (1 + 1)) - 10 in
215-
Train.forward_and_forget backend ctx hoo;
214+
let%op ye = (hey * (1 + 1)) - 10 in
215+
Train.forward_and_forget backend ctx ye;
216216
Tensor.print ~with_code:false ~with_grad:false `Default hey;
217-
Tensor.print ~with_code:false ~with_grad:false `Default hoo
217+
Tensor.print ~with_code:false ~with_grad:false `Default ye
218218

219-
let _suspended () =
219+
let _suspended (() : unit) : unit =
220220
setup ();
221221
_Matrix_multiplication_dims_2x3 ();
222222
_Big_matrix ()
223223

224-
let () =
224+
let _suspended (() : unit) : unit =
225225
setup ();
226226
_Pointwise_multiplication_dims_1 ();
227227
_Matrix_multiplication_dims_1x1 ();
228228
_Print_constant_tensor ();
229229
_Matrix_multiplication_dims_2x3 ();
230230
_Big_matrix ();
231231
_Very_big_tensor ()
232+
233+
let (() : unit) : unit =
234+
setup ();
235+
_Matrix_multiplication_dims_2x3 ();
236+
_Big_matrix ();
237+
_Very_big_tensor ()

0 commit comments

Comments
 (0)