Skip to content

Commit 37815d9

Browse files
committed
Overhaul of ppx_minidebug setup: make it per-file opt-in at compile time; formatting
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 7093aac commit 37815d9

30 files changed

+208
-177
lines changed

arrayjit/lib/assignments.ml

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ module Nd = Ndarray
77

88
let _get_local_debug_runtime = Utils.get_local_debug_runtime
99

10-
[%%global_debug_log_level 9]
11-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
10+
[%%global_debug_log_level 0]
11+
12+
(* export OCANNL_LOG_LEVEL_ASSIGNMENTS=9 to enable debugging logs. *)
13+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_ASSIGNMENTS"]
1214

1315
type init_data =
1416
| Reshape of Ndarray.t
@@ -247,14 +249,14 @@ let%track4_sexp to_low_level code =
247249
}
248250
in
249251
let for_loops = for_loop [] (Array.to_list projections.product_space) in
250-
(* Need initialization if:
251-
- initialize_neutral is true AND
252-
- (not surjective OR not injective)
252+
(* Need initialization if: initialize_neutral is true AND (not surjective OR not injective)
253+
253254
Not surjective: some positions never written (need init to avoid garbage)
255+
254256
Not injective: accumulation needed (need init for first += operation) *)
255-
let needs_init =
256-
initialize_neutral &&
257-
not (Indexing.is_surjective projections && Indexing.is_injective projections)
257+
let needs_init =
258+
initialize_neutral
259+
&& not (Indexing.is_surjective projections && Indexing.is_injective projections)
258260
in
259261
if needs_init then
260262
let dims = lazy projections.lhs_dims in
@@ -341,7 +343,8 @@ let%track4_sexp to_low_level code =
341343
| Fetch { array; fetch_op = Constant c; dims } ->
342344
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs -> set array idcs @@ Constant c)
343345
| Fetch { array; fetch_op = Constant_bits i; dims } ->
344-
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs -> set array idcs @@ Constant_bits i)
346+
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
347+
set array idcs @@ Constant_bits i)
345348
| Fetch { array; fetch_op = Slice { batch_idx = { static_symbol = idx; _ }; sliced }; dims } ->
346349
(* TODO: double-check that this always gets optimized away. *)
347350
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->

arrayjit/lib/backend_impl.ml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ module Lazy = Utils.Lazy
77

88
let _get_local_debug_runtime = Utils.get_local_debug_runtime
99

10-
[%%global_debug_log_level 9]
11-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
10+
[%%global_debug_log_level 0]
11+
12+
(* export OCANNL_LOG_LEVEL_BACKEND_IMPL=9 to enable debugging logs. *)
13+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_BACKEND_IMPL"]
1214

1315
open Backend_intf
1416

@@ -54,11 +56,11 @@ module No_device_buffer_and_copying () :
5456
(* FIXME(#344, #355): This is not efficient, but it won't be used for long. *)
5557
let ptr = alloc_impl ~size_in_bytes in
5658
(* Zero-initialize the allocated memory *)
57-
if size_in_bytes > 0 then (
58-
let arr = Ctypes.from_voidp Ctypes.uint8_t ptr in
59-
for i = 0 to size_in_bytes - 1 do
60-
Ctypes.(arr +@ i <-@ Unsigned.UInt8.zero)
61-
done);
59+
(if size_in_bytes > 0 then
60+
let arr = Ctypes.from_voidp Ctypes.uint8_t ptr in
61+
for i = 0 to size_in_bytes - 1 do
62+
Ctypes.(arr +@ i <-@ Unsigned.UInt8.zero)
63+
done);
6264
ptr
6365

6466
let%track7_sexp alloc_buffer ?(old_buffer : buffer_ptr Backend_intf.buffer option)

arrayjit/lib/backends.ml

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ open Backend_impl
77

88
let _get_local_debug_runtime = Utils.get_local_debug_runtime
99

10-
[%%global_debug_log_level 9]
11-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
10+
[%%global_debug_log_level 0]
11+
12+
(* export OCANNL_LOG_LEVEL_BACKENDS=9 to enable debugging logs. *)
13+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_BACKENDS"]
1214

1315
let check_merge_buffer stream ~code_node =
1416
let name = function Some tn -> Tnode.debug_name tn | None -> "none" in
@@ -40,7 +42,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
4042
|> Option.iter ~f:(fun upd_e ->
4143
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)
4244

43-
let%track2_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
45+
let%track3_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
4446
match (tn, Map.find ctx.ctx_arrays tn) with
4547
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
4648
if Tn.potentially_cross_stream tn then
@@ -87,7 +89,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8789
(* Note: the previous event does not need to be done! *)
8890
s.updating_for_merge_buffer <- Some (tn, Some e)
8991

90-
let%track2_sexp from_host (ctx : Backend.context) tn =
92+
let%track3_sexp from_host (ctx : Backend.context) tn =
9193
match (tn, Map.find ctx.ctx_arrays tn) with
9294
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
9395
wait_for_all ctx ctx.stream.reader_streams tn;
@@ -98,7 +100,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
98100
true
99101
| _ -> false
100102

101-
let%track2_sexp init_from_host (ctx : Backend.context) tn =
103+
let%track3_sexp init_from_host (ctx : Backend.context) tn =
102104
match (tn, Map.find ctx.ctx_arrays tn) with
103105
| { Tn.array = (lazy (Some hosted)); _ }, None ->
104106
let dims = Lazy.force tn.dims in
@@ -120,7 +122,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
120122
("init_from_host: tensor node is not hosted: " ^ Tn.debug_name tn ^ ", for stream "
121123
^ Backend.get_name ctx.stream)
122124

123-
let%track2_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
125+
let%track3_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
124126
~(src : Backend.context) =
125127
let ordinal_of ctx = ctx.stream.device.ordinal in
126128
let name_of ctx = Backend.(get_name ctx.stream) in
@@ -159,7 +161,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
159161
[%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src];
160162
true)
161163

162-
let%track2_sexp init_from_device (tn : Tn.t) ~(dst : Backend.context) ~(src : Backend.context) =
164+
let%track3_sexp init_from_device (tn : Tn.t) ~(dst : Backend.context) ~(src : Backend.context) =
163165
let ordinal_of ctx = ctx.stream.device.ordinal in
164166
let name_of ctx = Backend.(get_name ctx.stream) in
165167
match Map.find src.ctx_arrays tn with
@@ -491,11 +493,11 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
491493
[%log "Backends.alloc_if_needed: failed to add old node to context", (key : Tnode.t)];
492494
raise exn
493495
in
494-
let hash_find_exn ~message tbl =
496+
let hash_find_exn ~message:_msg tbl =
495497
try Hashtbl.find_exn tbl key
496498
with exn ->
497499
[%log
498-
"Backends.alloc_if_needed: failed to find node in hash table", message, (key : Tnode.t)];
500+
"Backends.alloc_if_needed: failed to find node in hash table", _msg, (key : Tnode.t)];
499501
raise exn
500502
in
501503
let device = stream.device in

arrayjit/lib/c_syntax.ml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ open Backend_intf
44

55
let _get_local_debug_runtime = Utils.get_local_debug_runtime
66

7-
[%%global_debug_log_level 9]
8-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
7+
[%%global_debug_log_level 0]
8+
9+
(* export OCANNL_LOG_LEVEL_C_SYNTAX=9 to enable debugging logs. *)
10+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_C_SYNTAX"]
911

1012
module Tn = Tnode
1113

@@ -1042,8 +1044,7 @@ module C_syntax (B : C_syntax_config) = struct
10421044
^^ hardline
10431045
^^ separate_map hardline
10441046
(fun ident ->
1045-
Option.value_exn ~here:[%here] B.local_heap_dealloc ~ident_doc:(string ident)
1046-
^^ semi)
1047+
Option.value_exn ~here:[%here] B.local_heap_dealloc ~ident_doc:(string ident) ^^ semi)
10471048
!heap_allocated
10481049
^^ hardline;
10491050

arrayjit/lib/cc_backend.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ open Ir
44

55
let _get_local_debug_runtime = Utils.get_local_debug_runtime
66

7-
[%%global_debug_log_level 9]
8-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
7+
[%%global_debug_log_level 0]
8+
9+
(* export OCANNL_LOG_LEVEL_CC_BACKEND=9 to enable debugging logs. *)
10+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_CC_BACKEND"]
911

1012
include Backend_impl.No_device_buffer_and_copying ()
1113
open Backend_intf

arrayjit/lib/cuda_backend.ml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ open Backend_intf
77

88
let _get_local_debug_runtime = Utils.get_local_debug_runtime
99

10-
[%%global_debug_log_level 9]
11-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
10+
[%%global_debug_log_level 0]
11+
12+
(* export OCANNL_LOG_LEVEL_CUDA_BACKEND=9 to enable debugging logs. *)
13+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_CUDA_BACKEND"]
1214

1315
let () =
1416
Cu.cuda_call_hook :=
@@ -351,7 +353,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
351353
(string "hexp2(hlog2(" ^^ v1 ^^ string "),"
352354
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
353355
^^ string ")")
354-
| ToPowOf, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Int64_prec _ | Fp8_prec _ | Uint4x32_prec _) ->
356+
| ( ToPowOf,
357+
(Byte_prec _ | Uint16_prec _ | Int32_prec _ | Int64_prec _ | Fp8_prec _ | Uint4x32_prec _)
358+
) ->
355359
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for integer precisions"
356360
| ToPowOf, Bfloat16_prec _ ->
357361
fun v1 v2 ->
@@ -666,7 +670,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
666670
let vec_unop_syntax prec op v =
667671
let open PPrint in
668672
match (op, prec) with
669-
| Ops.Uint4x32_to_prec_uniform, _ ->
673+
| Ops.Uint4x32_to_prec_uniform, _ ->
670674
group (string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform_vec(") ^^ v ^^ rparen)
671675

672676
let ternop_syntax prec v =
@@ -698,7 +702,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
698702
| Single_prec _, Half_prec _ -> ("__float2half(", ")")
699703
| Byte_prec _, Half_prec _ -> ("__ushort2half_rn((unsigned short int)", ")")
700704
| Double_prec _, Uint4x32_prec _ -> ("{(unsigned int)(", "), 0, 0, 0}")
701-
| Single_prec _, Uint4x32_prec _ -> ("{(unsigned int)(", "), 0, 0, 0}")
705+
| Single_prec _, Uint4x32_prec _ -> ("{(unsigned int)(", "), 0, 0, 0}")
702706
| Int32_prec _, Uint4x32_prec _ -> ("{(unsigned int)(", "), 0, 0, 0}")
703707
| Int64_prec _, Uint4x32_prec _ -> ("int64_to_uint4x32(", ")")
704708
| Uint4x32_prec _, _ -> ("", ".v[0]")

arrayjit/lib/indexing.ml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
open Base
22

3+
[%%global_debug_log_level 0]
4+
5+
(* export OCANNL_LOG_LEVEL_INDEXING=9 to enable debugging logs. *)
6+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_INDEXING"]
7+
38
type symbol = Symbol of int [@@deriving compare, equal, sexp, hash, variants]
49

510
let unique_id = ref 1
@@ -218,7 +223,7 @@ let is_surjective proj =
218223

219224
let is_injective proj =
220225
let product_iterator_set = Set.of_array (module Symbol) proj.product_iterators in
221-
226+
222227
(* Check each LHS index for injectivity *)
223228
let lhs_symbols, is_injective_mapping =
224229
Array.fold proj.project_lhs ~init:([], true) ~f:(fun (syms, still_injective) idx ->
@@ -229,19 +234,17 @@ let is_injective proj =
229234
| Fixed_idx _ -> (syms, true)
230235
| Affine { symbols; _ } ->
231236
(* Filter for symbols that are product iterators *)
232-
let product_symbols =
233-
List.filter symbols ~f:(fun (_coeff, s) ->
234-
Set.mem product_iterator_set s)
237+
let product_symbols =
238+
List.filter symbols ~f:(fun (_coeff, s) -> Set.mem product_iterator_set s)
235239
in
236240
(* If more than one product iterator in this Affine index, not injective *)
237-
if List.length product_symbols > 1 then
238-
(syms, false)
241+
if List.length product_symbols > 1 then (syms, false)
239242
else
240243
(* (coefficients don't matter for injectivity) *)
241244
(List.map product_symbols ~f:snd @ syms, true)
242245
| Sub_axis -> (syms, true))
243246
in
244-
247+
245248
if not is_injective_mapping then false
246249
else
247250
let lhs_symbol_set = Set.of_list (module Symbol) lhs_symbols in

arrayjit/lib/low_level.ml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ module Tn = Tnode
55

66
let _get_local_debug_runtime = Utils.get_local_debug_runtime
77

8-
[%%global_debug_log_level 9]
9-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
8+
[%%global_debug_log_level 0]
9+
10+
(* export OCANNL_LOG_LEVEL_LOW_LEVEL=9 to enable debugging logs. *)
11+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_LOW_LEVEL"]
1012

1113
module Scope_id = struct
1214
type t = { tn : Tn.t; scope_id : int } [@@deriving sexp_of, equal, hash, compare]
@@ -221,8 +223,9 @@ let is_complex_comp traced_store llsc =
221223
let is_scalar_dims tn = Array.for_all ~f:(( = ) 1) @@ Lazy.force tn.Tn.dims
222224

223225
let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
226+
(* FIXME(#351): avoid excessive inlining while CSE is not implemented *)
224227
let inline_complex_computations =
225-
Utils.get_global_flag ~default:true ~arg_name:"inline_complex_computations"
228+
Utils.get_global_flag ~default:false ~arg_name:"inline_complex_computations"
226229
in
227230
let is_too_many = function Visits i -> i > max_visits | Recurrent -> true in
228231
(* FIXME: migrate hashtable to use offsets instead of indices *)

arrayjit/lib/lowered_backend_missing.ml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
open Base
22
open Ir
33

4-
let _get_local_debug_runtime = Utils.get_local_debug_runtime
5-
6-
[%%global_debug_log_level 9]
7-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
8-
94
module Missing (Config : sig
105
val name : string
116
end) =

arrayjit/lib/metal_backend.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ module Impl = Backend_impl (* Alias for Backend_impl *)
88

99
let _get_local_debug_runtime = Utils.get_local_debug_runtime
1010

11-
[%%global_debug_log_level 9]
12-
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
11+
[%%global_debug_log_level 0]
12+
13+
(* export OCANNL_LOG_LEVEL_METAL_BACKEND=9 to enable debugging logs. *)
14+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL_METAL_BACKEND"]
1315

1416
type ullong = Unsigned.ULLong.t
1517

0 commit comments

Comments
 (0)