Skip to content

Commit c8d36d2

Browse files
committed
Second pass on adding vector-returning operations: cleanup and locating unfinished places
1 parent 85eaff9 commit c8d36d2

File tree

15 files changed

+134
-90
lines changed

15 files changed

+134
-90
lines changed

CLAUDE.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ OCANNL (OCaml Compiles Algorithms for Neural Networks Learning) is a from-scratc
1414
The project uses Dune for building and testing:
1515

1616
```bash
17-
# Build all packages
17+
# Build all packages; this triggers running executables for cram-style tests
1818
dune build
1919

20+
# Only compile -- do not link nor run any executable
21+
dune build @check
22+
2023
# Build specific package
2124
dune build -p neural_nets_lib
2225
dune build -p arrayjit

arrayjit/lib/assignments.ml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ let%diagn2_sexp to_low_level code =
277277
| Indexing.Fixed_idx _ as idx -> idx
278278
| Indexing.Iterator s as idx -> Option.value ~default:idx (Map.find subst_map s)
279279
| Indexing.Affine { symbols; offset } ->
280+
(* FIXME: we need to substitute in the affine index, reuse code from
281+
loop_accum *)
280282
Indexing.Affine { symbols; offset }
281283
in
282284
let lhs_idcs = Array.map projections.project_lhs ~f:subst_index in
@@ -286,7 +288,7 @@ let%diagn2_sexp to_low_level code =
286288
(* For now, we know the only vec_unop is Uint4x32_to_prec_uniform *)
287289
let length = match op with
288290
| Ops.Uint4x32_to_prec_uniform ->
289-
(* TODO: Calculate length based on precision *)
291+
(* FIXME: Calculate length based on precision *)
290292
16 (* Default for now, should be calculated from target precision *)
291293
in
292294
Set_from_vec { tn = lhs; idcs = lhs_idcs; length; vec_unop = op; arg = rhs_ll; debug = "" }
@@ -484,17 +486,15 @@ let to_doc ?name ?static_indices () c =
484486
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
485487
else "<not-in-yet>"
486488
in
487-
string (ident lhs)
488-
^^ string " := "
489-
^^ string (Ops.vec_unop_cd_syntax op)
490-
^^ string "("
489+
string (ident lhs) ^^ space
490+
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral:false Arg2) ^^ space
491+
^^ string (Ops.vec_unop_cd_syntax op) ^^ space
491492
^^ string (buffer_ident rhs)
492-
^^ string ")"
493493
^^ (if not (String.equal proj_spec ".") then string (" ~logic:\"" ^ proj_spec ^ "\"")
494494
else empty)
495495
^^ string ";" ^^ break 1
496496
| Fetch { array; fetch_op; dims = _ } ->
497-
string (ident array) ^^ string " := " ^^ doc_of_fetch_op fetch_op ^^ string ";" ^^ break 1
497+
string (ident array) ^^ string " =: " ^^ doc_of_fetch_op fetch_op ^^ string ";" ^^ break 1
498498
in
499499

500500
(* Create the header document *)

arrayjit/lib/builtins.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ extern uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
138138
}
139139

140140
/* Conversion functions from uint4x32 to various precisions uniformly */
141+
// FIXME: we need to return a vector of values, not just a single value
141142

142143
/* Convert to float in [0, 1) */
143144
extern float uint32_to_single_uniform(uint32_t x) {

arrayjit/lib/builtins.msl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
131131
}
132132

133133
/* Conversion functions from uint4x32 to various precisions uniformly */
134+
// FIXME: we need to return a vector of values, not just a single value
134135

135136
/* Convert to float in [0, 1) */
136137
inline float uint32_to_single_uniform(uint32_t x) {

arrayjit/lib/builtins_small.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ typedef struct {
44
} uint4x32_t;
55

66
/* Conversion functions from uint4x32 to various precisions uniformly */
7+
// FIXME: we need to return a vector of values, not just a single value
78

89
/* Convert to float in [0, 1) using CUDA intrinsics */
910
__device__ __forceinline__ float uint32_to_single_uniform(uint32_t x) {

arrayjit/lib/c_syntax.ml

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ module type C_syntax_config = sig
2828
val includes : string list
2929
val extra_declarations : string list
3030
val typ_of_prec : Ops.prec -> string
31+
val vec_typ_of_prec : length:int -> Ops.prec -> string
3132
val ident_blacklist : string list
3233

3334
val float_log_style : string
@@ -46,7 +47,7 @@ module type C_syntax_config = sig
4647

4748
val binop_syntax : Ops.prec -> Ops.binop -> PPrint.document -> PPrint.document -> PPrint.document
4849
val unop_syntax : Ops.prec -> Ops.unop -> PPrint.document -> PPrint.document
49-
val vec_unop_syntax : Ops.prec -> Ops.vec_unop -> string
50+
val vec_unop_syntax : Ops.prec -> Ops.vec_unop -> PPrint.document -> PPrint.document
5051
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
5152

5253
val kernel_log_param : (string * string) option
@@ -152,6 +153,7 @@ struct
152153
]
153154

154155
let typ_of_prec = Ops.c_typ_of_prec
156+
let vec_typ_of_prec = Ops.c_vec_typ_of_prec
155157
let float_log_style = if Input.full_printf_support then "%g" else "%de-3"
156158

157159
let styled_log_arg doc =
@@ -220,10 +222,7 @@ struct
220222
let p, _ = try Ops.unop_c_syntax prec op with Invalid_argument _ -> ("", "") in
221223
if String.is_suffix p ~suffix:"(" then functions := Set.add !functions (remove_paren p));
222224
List.iter
223-
Ops.
224-
[
225-
Uint4x32_to_prec_uniform;
226-
]
225+
Ops.[ Uint4x32_to_prec_uniform ]
227226
~f:(fun op ->
228227
let p, _ = try Ops.vec_unop_c_syntax prec op with Invalid_argument _ -> ("", "") in
229228
if String.is_suffix p ~suffix:"(" then functions := Set.add !functions (remove_paren p)));
@@ -418,9 +417,10 @@ struct
418417
let open PPrint in
419418
group (string op_prefix ^^ v ^^ string op_suffix)
420419

421-
let vec_unop_syntax prec op =
422-
let prefix, _ = Ops.vec_unop_c_syntax prec op in
423-
prefix
420+
let vec_unop_syntax prec op v =
421+
let op_prefix, op_suffix = Ops.vec_unop_c_syntax prec op in
422+
let open PPrint in
423+
group (string op_prefix ^^ v ^^ string op_suffix)
424424

425425
let convert_precision = Ops.c_convert_precision
426426
let kernel_log_param = Some ("const char*", "log_file_name")
@@ -588,25 +588,33 @@ module C_syntax (B : C_syntax_config) = struct
588588
let arg_prec = Ops.uint4x32 in
589589
let local_defs, arg_doc = pp_float arg_prec arg in
590590
(* Generate the function call *)
591-
let func_name = string (Ops.vec_unop_c_syntax prec vec_unop |> fst) in
591+
let result_doc = B.vec_unop_syntax prec vec_unop arg_doc in
592592
(* Generate assignments for each output element *)
593-
let assignments =
593+
let assignments =
594594
let open PPrint in
595595
let vec_var = string "vec_result" in
596-
let vec_typ = string (B.typ_of_prec prec ^ Int.to_string length) in
597-
let vec_decl = vec_typ ^^ space ^^ vec_var ^^ string " = " ^^ func_name ^^ arg_doc ^^ semi in
598-
let elem_assigns =
596+
let vec_typ = string (B.vec_typ_of_prec ~length prec) in
597+
let vec_decl = vec_typ ^^ space ^^ vec_var ^^ string " = " ^^ result_doc ^^ semi in
598+
let elem_assigns =
599599
List.init length ~f:(fun i ->
600-
let elem_idcs = Array.copy idcs in
601-
(match elem_idcs.(Array.length elem_idcs - 1) with
602-
| Fixed_idx idx -> elem_idcs.(Array.length elem_idcs - 1) <- Fixed_idx (idx + i)
603-
| _ -> failwith "Set_from_vec: last index must be Fixed_idx");
604-
let offset_doc = pp_array_offset (elem_idcs, dims) in
605-
ident_doc ^^ brackets offset_doc ^^ string " = " ^^ vec_var ^^ string ("." ^ Printf.sprintf "s%d" i) ^^ semi)
600+
let elem_idcs = Array.copy idcs in
601+
(match elem_idcs.(Array.length elem_idcs - 1) with
602+
| Fixed_idx idx -> elem_idcs.(Array.length elem_idcs - 1) <- Fixed_idx (idx + i)
603+
| _ ->
604+
(* FIXME: NOT IMPLEMENTED YET *)
605+
failwith "FIXME: Set_from_vec: NOT IMPLEMENTED YET general index");
606+
let offset_doc = pp_array_offset (elem_idcs, dims) in
607+
ident_doc ^^ brackets offset_doc ^^ string " = " ^^ vec_var
608+
^^ string ("." ^ Printf.sprintf "s%d" i)
609+
^^ semi)
606610
in
607611
vec_decl ^^ hardline ^^ separate hardline elem_assigns
608612
in
609-
if PPrint.is_empty local_defs then assignments else local_defs ^^ hardline ^^ assignments
613+
if Utils.debug_log_from_routines () then
614+
(* FIXME: NOT IMPLEMENTED YET *)
615+
failwith "FIXME: debug log for Set_from_vec"
616+
else if PPrint.is_empty local_defs then assignments
617+
else local_defs ^^ hardline ^^ assignments
610618
| Set_local ({ scope_id; tn = { prec; _ } }, value) ->
611619
let local_defs, value_doc = pp_float (Lazy.force prec) value in
612620
let assignment =

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ struct
171171

172172
(* Override to add our custom type and conversion support *)
173173
let typ_of_prec = typ_of_prec
174+
let vec_typ_of_prec = vec_typ_of_prec
174175
let extra_declarations = extra_declarations (* Our bfloat16/fp8 conversion functions *)
175176
let convert_precision = convert_precision
176177
end

arrayjit/lib/cuda_backend.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ end) : Ir.Backend_impl.Lowered_backend = struct
322322
| Ops.Double_prec _ -> "double"
323323
| Ops.Void_prec -> "void"
324324

325+
let vec_typ_of_prec ~length prec =
326+
ignore length;
327+
(* FIXME: NOT IMPLEMENTED YET *)
328+
failwith "NOT IMPLEMENTED YET"
329+
325330
let binop_syntax prec v =
326331
(* TODO: consider using binop_syntax inherited from Pure_C_config and overriding only where
327332
different. *)

arrayjit/lib/low_level.ml

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ type t =
3636
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
3737
| Zero_out of Tn.t
3838
| Set of { tn : Tn.t; idcs : Indexing.axis_index array; llsc : scalar_t; mutable debug : string }
39-
| Set_from_vec of { tn : Tn.t; idcs : Indexing.axis_index array; length : int; vec_unop : Ops.vec_unop; arg : scalar_t; mutable debug : string }
39+
| Set_from_vec of {
40+
tn : Tn.t;
41+
idcs : Indexing.axis_index array;
42+
length : int;
43+
vec_unop : Ops.vec_unop;
44+
arg : scalar_t;
45+
mutable debug : string;
46+
}
4047
| Set_local of scope_id * scalar_t
4148
[@@deriving sexp_of, equal]
4249

@@ -255,14 +262,15 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
255262
let traced : traced_array = get_node traced_store tn in
256263
(* Vector operations cannot be scalar constexpr *)
257264
traced.is_scalar_constexpr <- false;
258-
if first_visit then
259-
traced.is_complex <- traced.is_complex || is_complex_comp traced_store arg;
265+
if first_visit then traced.is_complex <- false;
260266
(* Mark all positions that will be written to *)
261267
for i = 0 to length - 1 do
262268
let pos_idcs = Array.copy idcs in
263269
(match pos_idcs.(Array.length pos_idcs - 1) with
264270
| Fixed_idx idx -> pos_idcs.(Array.length pos_idcs - 1) <- Fixed_idx (idx + i)
265-
| _ -> failwith "Set_from_vec: last index must be Fixed_idx");
271+
| _ ->
272+
(* FIXME: NOT IMPLEMENTED YET *)
273+
failwith "FIXME: Set_from_vec: NOT IMPLEMENTED YET general index");
266274
Hash_set.add traced.assignments (lookup env pos_idcs)
267275
done;
268276
Array.iter idcs ~f:(function
@@ -566,10 +574,11 @@ let inline_computation ~id computations_table traced static_indices call_args =
566574
| Set { tn; idcs; llsc; debug = _ } when Tn.equal tn traced.tn ->
567575
assert ([%equal: Indexing.axis_index array option] (Some idcs) def_args);
568576
Some (Set_local (id, loop_float env llsc))
569-
| Set_from_vec { tn; idcs; length = _; vec_unop = _; arg = _; debug = _ } when Tn.equal tn traced.tn ->
577+
| Set_from_vec { tn; idcs; length = _; vec_unop = _; arg = _; debug = _ }
578+
when Tn.equal tn traced.tn ->
570579
assert ([%equal: Indexing.axis_index array option] (Some idcs) def_args);
571580
(* For vector operations, we cannot inline them as scalar operations *)
572-
raise @@ Non_virtual 14
581+
raise @@ Non_virtual 140
573582
| Zero_out _ -> None
574583
| Set _ -> None
575584
| Set_from_vec _ -> None
@@ -649,7 +658,9 @@ let virtual_llc computations_table traced_store reverse_node_map static_indices
649658
| Set_from_vec { tn; idcs; length; vec_unop; arg; debug } ->
650659
let traced : traced_array = get_node traced_store tn in
651660
let next = if Tn.known_non_virtual traced.tn then process_for else Set.add process_for tn in
652-
let result = Set_from_vec { tn; idcs; length; vec_unop; arg = loop_float ~process_for:next arg; debug } in
661+
let result =
662+
Set_from_vec { tn; idcs; length; vec_unop; arg = loop_float ~process_for:next arg; debug }
663+
in
653664
if (not @@ Set.mem process_for tn) && (not @@ Tn.known_non_virtual traced.tn) then
654665
check_and_store_virtual computations_table traced static_indices result;
655666
result
@@ -735,7 +746,9 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
735746
else (
736747
assert (
737748
Array.for_all idcs ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true));
738-
Some (Set_from_vec { tn; idcs; length; vec_unop; arg = loop_float ~balanced ~env_dom arg; debug }))
749+
Some
750+
(Set_from_vec
751+
{ tn; idcs; length; vec_unop; arg = loop_float ~balanced ~env_dom arg; debug }))
739752
| Set_local (id, llsc) ->
740753
assert (not @@ Tn.known_non_virtual id.tn);
741754
Tn.update_memory_mode id.tn Virtual 16;
@@ -867,12 +880,14 @@ let simplify_llc llc =
867880
| Binop (Arg1, llv1, _) -> loop_float llv1
868881
| Binop (Arg2, _, llv2) -> loop_float llv2
869882
| Binop (op, Constant c1, Constant c2) -> Constant (Ops.interpret_binop op c1 c2)
870-
| Binop (Add, llsc, Constant 0.) | Binop (Sub, llsc, Constant 0.) | Binop (Add, Constant 0., llsc)
871-
->
883+
| Binop (Add, llsc, Constant 0.)
884+
| Binop (Sub, llsc, Constant 0.)
885+
| Binop (Add, Constant 0., llsc) ->
872886
loop_float llsc
873887
| Binop (Sub, Constant 0., llsc) -> loop_float @@ Binop (Mul, Constant (-1.), llsc)
874-
| Binop (Mul, llsc, Constant 1.) | Binop (Div, llsc, Constant 1.) | Binop (Mul, Constant 1., llsc)
875-
->
888+
| Binop (Mul, llsc, Constant 1.)
889+
| Binop (Div, llsc, Constant 1.)
890+
| Binop (Mul, Constant 1., llsc) ->
876891
loop_float llsc
877892
| Binop (Mul, _, Constant 0.) | Binop (Div, Constant 0., _) | Binop (Mul, Constant 0., _) ->
878893
Constant 0.
@@ -1130,14 +1145,17 @@ let to_doc_cstyle ?name ?static_indices () llc =
11301145
p.debug <- Buffer.contents b);
11311146
result
11321147
| Set_from_vec p ->
1148+
let prec = Lazy.force p.tn.prec in
1149+
let prefix, postfix = Ops.vec_unop_c_syntax prec p.vec_unop in
1150+
(* TODO: this assumes argument is generated from the high-level code, which means it is
1151+
either Get or Local_scope -- they don't need precision. *)
1152+
let vec_result = string prefix ^^ doc_of_float Ops.Void_prec p.arg ^^ string postfix in
1153+
let length_doc = string ("<" ^ Int.to_string p.length ^ ">") in
11331154
let result =
11341155
group
11351156
(doc_ident p.tn
11361157
^^ brackets (pp_indices p.idcs)
1137-
^^ string " := "
1138-
^^ string (Ops.vec_unop_cd_syntax p.vec_unop)
1139-
^^ string "(" ^^ doc_of_float (Ops.uint4x32) p.arg ^^ string ", "
1140-
^^ int p.length ^^ string ");")
1158+
^^ length_doc ^^ string " := " ^^ vec_result ^^ string ";")
11411159
in
11421160
if not (String.is_empty p.debug) then (
11431161
let b = Buffer.create 100 in
@@ -1215,22 +1233,23 @@ let to_doc ?name ?static_indices () llc =
12151233
p.debug <- Buffer.contents b;
12161234
result
12171235
| Set_from_vec p ->
1236+
let length_doc = string ("<" ^ Int.to_string p.length ^ ">") in
12181237
let result =
12191238
group
12201239
(doc_ident p.tn
12211240
^^ brackets (pp_indices p.idcs)
1222-
^^ string " := "
1241+
^^ length_doc ^^ string " := "
12231242
^^ string (Ops.vec_unop_cd_syntax p.vec_unop)
1224-
^^ string "(" ^^ doc_of_float p.arg ^^ string ", "
1225-
^^ int p.length ^^ string ");")
1243+
^^ string "(" ^^ doc_of_float p.arg ^^ string ", " ^^ length_doc ^^ string ");")
12261244
in
12271245
let b = Buffer.create 100 in
12281246
PPrint.ToBuffer.pretty 0.7 100 b result;
12291247
p.debug <- Buffer.contents b;
12301248
result
12311249
| Comment message -> string ("/* " ^ message ^ " */")
12321250
| Staged_compilation callback -> callback ()
1233-
| Set_local (id, llsc) -> group (doc_local id ^^ string " := " ^^ doc_of_float llsc ^^ string ";")
1251+
| Set_local (id, llsc) ->
1252+
group (doc_local id ^^ string " := " ^^ doc_of_float llsc ^^ string ";")
12341253
and doc_of_float value =
12351254
match value with
12361255
| Local_scope { id; body; _ } ->

arrayjit/lib/metal_backend.ml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,11 @@ end) : Ir.Backend_impl.Lowered_backend = struct
451451
| Ops.Double_prec _ -> "double"
452452
| Ops.Void_prec -> "void"
453453

454+
let vec_typ_of_prec ~length prec =
455+
ignore (length, prec);
456+
(* FIXME: NOT IMPLEMENTED YET *)
457+
failwith "NOT IMPLEMENTED YET"
458+
454459
let metal_prec_suffix_float = function
455460
| Ops.Byte_prec _ -> ""
456461
| Ops.Uint16_prec _ -> ""
@@ -557,9 +562,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
557562
| Not, _ -> fun v -> string "!" ^^ v
558563
(* Logical not *)
559564

560-
let vec_unop_syntax prec op =
561-
match op with
562-
| Ops.Uint4x32_to_prec_uniform -> "uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform("
565+
(* Keep vec_unop_syntax same as in pure C syntax. *)
563566

564567
let convert_precision ~from ~to_ =
565568
if Ops.equal_prec from to_ then ("", "") else ("(" ^ typ_of_prec to_ ^ ")(", ")")

0 commit comments

Comments
 (0)