
Commit 2611a1b

Broken: Constant_fill via unrolling, fix Tensor.params field typing, final round of refactoring / plumbing (mostly by Claude Sonnet)

It's broken because most of the new functionality is not implemented yet. Also, neither parameter optimization nor input optimization is done properly anywhere yet!

1 parent 874fa31 · commit 2611a1b

21 files changed: +249 -163 lines

arrayjit/lib/c_syntax.ml (+74 -7)

@@ -267,7 +267,7 @@ struct
       match op with
       | Ops.Satur01_gate -> (
          match prec with
-         | Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ ->
+         | Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ | Ops.Uint4x32_prec _ ->
              let open PPrint in
              group
                (parens
@@ -592,16 +592,44 @@ module C_syntax (B : C_syntax_config) = struct
         let prefix, postfix = B.convert_precision ~from:scope_prec ~to_:prec in
         let expr = string prefix ^^ string ("v" ^ Int.to_string id.scope_id) ^^ string postfix in
         (empty, expr)
-    | Access (Ops.Merge_buffer { source_node_id }, Some idcs) ->
-        let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
+    | Access (Low_level.Merge_buffer { source }, Some idcs) ->
+        let tn = source in
         let from_prec = Lazy.force tn.prec in
         let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
         let offset_doc = pp_array_offset (idcs, Lazy.force tn.dims) in
         let expr =
           string prefix ^^ string "merge_buffer" ^^ brackets offset_doc ^^ string postfix
         in
         (empty, expr)
-    | Access _ -> failwith "C_syntax: Access / FFI NOT IMPLEMENTED YET"
+    | Access (Low_level.C_function f_name, None) ->
+        let expr = string (f_name ^ "()") in
+        (empty, expr)
+    | Access (Low_level.External_unsafe { ptr; prec = source_prec; dims }, Some idcs) ->
+        let dims_val = Lazy.force dims in
+        let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
+        let offset_doc = pp_array_offset (idcs, dims_val) in
+        let ptr_str = Ops.c_rawptr_to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr) source_prec in
+        let expr =
+          string prefix ^^ string ("(*(" ^ ptr_str ^ " + ") ^^ offset_doc ^^ string "))" ^^ string postfix
+        in
+        (empty, expr)
+    | Access (Low_level.File_mapped (file, source_prec), Some idcs) ->
+        let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
+        let expr =
+          string prefix ^^ string ("file_mapped_data_" ^ file ^ "[") ^^ pp_array_offset (idcs, [||]) ^^ string "]" ^^ string postfix
+        in
+        (empty, expr)
+    | Access (Low_level.Uint4x32_to_prec_uniform { source; prec = source_prec }, Some idcs) ->
+        let tn = source in
+        let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
+        let offset_doc = pp_array_offset (idcs, Lazy.force tn.dims) in
+        let source_ident = string (get_ident tn) in
+        let expr =
+          string prefix ^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
+          ^^ source_ident ^^ brackets offset_doc ^^ string ")" ^^ string postfix
+        in
+        (empty, expr)
+    | Access _ -> failwith "C_syntax: Access cases with wrong indices / FFI NOT IMPLEMENTED YET"
     | Get (tn, idcs) ->
         let ident_doc = string (get_ident tn) in
         let from_prec = Lazy.force tn.prec in
@@ -665,8 +693,8 @@ module C_syntax (B : C_syntax_config) = struct
         let prefix, postfix = B.convert_precision ~from:scope_prec ~to_:prec in
         let v_doc = string prefix ^^ string ("v" ^ Int.to_string id.scope_id) ^^ string postfix in
         (v_doc ^^ braces (string ("=" ^ B.float_log_style)), [ `Value v_doc ])
-    | Access (Ops.Merge_buffer { source_node_id }, Some idcs) ->
-        let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
+    | Access (Low_level.Merge_buffer { source }, Some idcs) ->
+        let tn = source in
         let from_prec = Lazy.force tn.prec in
         let dims = Lazy.force tn.dims in
         let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
@@ -681,7 +709,46 @@ module C_syntax (B : C_syntax_config) = struct
           ^^ braces (string ("=" ^ B.float_log_style))
         in
         (expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
-    | Access _ -> failwith "C_syntax: Access / FFI NOT IMPLEMENTED YET"
+    | Access (Low_level.C_function f_name, None) ->
+        let expr_doc = string (f_name ^ "()") in
+        (expr_doc, [])
+    | Access (Low_level.External_unsafe { ptr; prec = source_prec; dims }, Some idcs) ->
+        let dims_val = Lazy.force dims in
+        let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
+        let offset_doc = pp_array_offset (idcs, dims_val) in
+        let ptr_str = Ops.c_rawptr_to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr) source_prec in
+        let access_doc =
+          string prefix ^^ string ("(*(" ^ ptr_str ^ " + ") ^^ offset_doc ^^ string "))" ^^ string postfix
+        in
+        let expr_doc =
+          string prefix ^^ string ("external[%u]{=" ^ B.float_log_style ^ "}") ^^ string postfix
+        in
+        (expr_doc, [ `Accessor (idcs, dims_val); `Value access_doc ])
+    | Access (Low_level.File_mapped (file, source_prec), Some idcs) ->
+        let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
+        let access_doc =
+          string prefix ^^ string ("file_mapped_data_" ^ file ^ "[") ^^ pp_array_offset (idcs, [||]) ^^ string "]" ^^ string postfix
+        in
+        let expr_doc =
+          string prefix ^^ string ("file_mapped_" ^ file ^ "[%u]{=" ^ B.float_log_style ^ "}") ^^ string postfix
+        in
+        (expr_doc, [ `Accessor (idcs, [||]); `Value access_doc ])
+    | Access (Low_level.Uint4x32_to_prec_uniform { source; prec = source_prec }, Some idcs) ->
+        let tn = source in
+        let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
+        let dims = Lazy.force tn.dims in
+        let offset_doc = pp_array_offset (idcs, dims) in
+        let source_ident = string (get_ident tn) in
+        let access_doc =
+          string prefix ^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
+          ^^ source_ident ^^ brackets offset_doc ^^ string ")" ^^ string postfix
+        in
+        let expr_doc =
+          string prefix ^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
+          ^^ source_ident ^^ brackets (string "%u") ^^ string "){=" ^^ string B.float_log_style ^^ string "}" ^^ string postfix
+        in
+        (expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
+    | Access _ -> failwith "C_syntax: Access cases with wrong indices / FFI NOT IMPLEMENTED YET"
     | Get (tn, idcs) ->
         let ident_doc = string (get_ident tn) in
         let from_prec = Lazy.force tn.prec in
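For orientation: the new match arms above pin down the shape of the Low_level access payload that replaces the old Ops.Merge_buffer node-id lookup. A minimal sketch of what that variant presumably looks like, reconstructed purely from the patterns matched here; it is not the authoritative definition, and the type name and field types are guesses:

(* Hypothetical reconstruction, inferred solely from the match arms above. *)
type memory_access =
  | Merge_buffer of { source : Tn.t }  (* a tnode held directly, no Tn.find lookup *)
  | C_function of string  (* nullary FFI call, rendered as "f_name()" *)
  | External_unsafe of {
      ptr : unit Ctypes.ptr;  (* raw pointer baked into the generated C *)
      prec : Ops.prec;
      dims : int array Lazy.t;
    }
  | File_mapped of string * Ops.prec  (* file path and element precision *)
  | Uint4x32_to_prec_uniform of { source : Tn.t; prec : Ops.prec }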

arrayjit/lib/metal_backend.ml (+3 -0)

@@ -443,6 +443,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
     | Ops.Byte_prec _ -> "uchar"
     | Ops.Uint16_prec _ -> "ushort"
     | Ops.Int32_prec _ -> "int"
+    | Ops.Uint4x32_prec _ -> "uint4" (* Metal's uint4 type - 128-bit *)
     | Ops.Half_prec _ -> "half"
     | Ops.Bfloat16_prec _ -> "bfloat" (* Metal supports bfloat16 natively *)
     | Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
@@ -454,6 +455,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
     | Ops.Byte_prec _ -> ""
     | Ops.Uint16_prec _ -> ""
     | Ops.Int32_prec _ -> ""
+    | Ops.Uint4x32_prec _ -> "" (* No specific suffix for uint4 *)
     | Ops.Half_prec _ -> "h"
     | Ops.Bfloat16_prec _ -> "bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
     | Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
@@ -523,6 +525,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
              ^^ space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
              ^^ string ("0.0" ^ s)))
     | ToPowOf, _ -> func "pow"
+    | Threefry4x32, _ -> func "threefry4x32" (* Metal implementation of Threefry4x32 *)
     | Arg1, _ | Arg2, _ -> invalid_arg "Metal C_syntax_config: Arg1/Arg2 not operators"

   let unop_syntax prec op =
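The threefry4x32 helper named here is a counter-based RNG from the Random123 family: it maps a 128-bit counter and key to four fresh 32-bit words, which the uint4x32_to_*_uniform helpers from c_syntax.ml above then squash into uniform floats. As a sketch of the per-lane conversion such a helper presumably performs, assuming plain unsigned scaling by 2^-32 (the actual helper may use bit manipulation instead):

(* Hypothetical sketch: turn one 32-bit lane of a Threefry4x32 output
   block into a single-precision uniform in [0, 1). *)
let uniform_of_uint32 (u : int32) : float =
  (* Reinterpret the int32 as unsigned before scaling. *)
  let unsigned = Int64.logand (Int64.of_int32 u) 0xFFFF_FFFFL in
  Int64.to_float unsigned *. (1.0 /. 4294967296.0)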

arrayjit/test/test_numerical_types.ml (+0 -3)

@@ -16,7 +16,6 @@ let test_bfloat16_conversions () =
   (* Test round-trip through ndarray *)
   let arr =
     Ndarray.create_array ~debug:"test" Ops.bfloat16 ~dims:[| 3; 2 |] ~padding:None
-      (Assignments.Constant_fill [| 1.0; 2.0; 3.14; -1.5; 0.125; 1000.0 |])
   in

   Stdio.printf "\nBFloat16 array values:\n";
@@ -37,7 +36,6 @@ let test_fp8_conversions () =
   (* Test round-trip through ndarray *)
   let arr =
     Ndarray.create_array ~debug:"test" Ops.fp8 ~dims:[| 2; 2 |] ~padding:None
-      (Ops.Constant_fill { values = [| 1.0; 0.5; 2.0; -1.0 |]; strict = true })
   in

   Stdio.printf "\nFP8 array values:\n";
@@ -56,7 +54,6 @@ let test_padding () =
   let arr =
     Ndarray.create_array ~debug:"padded_test" Ops.single ~dims:padded_dims
       ~padding:(Some (padding_config, padding_value))
-      (Ops.Constant_fill { values = [| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0 |]; strict = true })
   in

   Stdio.printf "Padded array (dims 4x6, unpadded region 2x3):\n";

bin/compilation_speed.ml (+1 -1)

@@ -36,7 +36,7 @@ let benchmark_overhead backend () =
     Train.to_routine (module Backend) ctx ~name:"init_assign_x" IDX.empty mock_update_x
   in
   let f_routine =
-    Train.to_routine (module Backend) init_assign_x.context IDX.empty update_f.fwd_bprop
+    Train.to_routine (module Backend) init_assign_x.context IDX.empty update_f
   in
   Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;

bin/micrograd_basic.ml (+2 -2)

@@ -25,7 +25,7 @@ let%diagn_sexp _suspended () =
   (* List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff
      ]; *)
   let update = Train.grad_update d in
-  let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
+  let routine = Train.to_routine (module Backend) ctx IDX.empty update in
   Train.run routine;
   Tensor.print_tree ~with_grad:true ~depth:9 d;
   Stdio.print_endline "\n";
@@ -52,7 +52,7 @@ let%diagn_sexp () : unit =
   List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff ];
   (* Train.every_non_literal_on_host g; *)
   let update = Train.grad_update g in
-  let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
+  let routine = Train.to_routine (module Backend) ctx IDX.empty update in
   Utils.capture_stdout_logs @@ fun () ->
   Train.run routine;
   (* Tensor.print_tree ~with_grad:true ~depth:9 g; *)

bin/micrograd_demo.ml (+2 -2)

@@ -77,13 +77,13 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let update = Train.grad_update scalar_loss in
   let%op learning_rate = 0.1 *. (!..steps - !@step_n) /. !..steps in
   Train.set_hosted learning_rate.value;
-  let sgd = Train.sgd_update ~learning_rate ~weight_decay update in
+  let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in

   let module Backend = (val Backends.fresh_backend ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   let routine =
-    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
+    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ])
   in
   (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true
      ~with_grad:false ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate
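The pattern repeated across these drivers outlines the refactored Train API: Train.grad_update now returns a value that to_routine and Asgns.sequence consume directly (callers no longer project out .fwd_bprop), and Train.sgd_update takes the loss tensor itself rather than the gradient-update value. A hedged sketch of the apparent signatures, inferred only from these call sites; comp is a placeholder name, and per zero2hero_1of7.ml below it carries an asgns field:

(* Hypothetical interface sketch; the real Train module may differ. *)
module type Train_sketch = sig
  type comp  (* has an [asgns] field, cf. [code.asgns] in zero2hero_1of7.ml *)
  val grad_update : Tensor.t -> comp
  (* weight_decay is optional: zero2hero_1of7.ml calls sgd_update without it. *)
  val sgd_update : learning_rate:Tensor.t -> ?weight_decay:float -> Tensor.t -> comp
end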

bin/micrograd_demo_logging.ml (+1 -1)

@@ -31,7 +31,7 @@ let () =
   let%op g = g + (10. /. f) in
   List.iter ~f:(Option.iter ~f:(fun diff -> Train.set_hosted diff.Tensor.grad)) [ a.diff; b.diff ];
   let update = Train.grad_update g in
-  let step = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
+  let step = Train.to_routine (module Backend) ctx IDX.empty update in
   Utils.capture_stdout_logs @@ fun () ->
   Train.run step;
   Tensor.print ~with_code:false ~with_grad:false `Default g;

bin/moons_demo.ml (+2 -2)

@@ -55,13 +55,13 @@ let demo () =
   let update = Train.grad_update scalar_loss in
   let%op learning_rate = 0.1 *. (!..steps - !@step_n) /. !..steps in
   Train.set_hosted learning_rate.value;
-  let sgd = Train.sgd_update ~learning_rate ~weight_decay update in
+  let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in

   let module Backend = (val Backends.fresh_backend ~backend_name:"cuda" ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   let routine =
-    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
+    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ])
   in

   let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in

bin/primitive_ops.ml (+2 -2)

@@ -27,7 +27,7 @@ let%debug_sexp graph_t () : unit =
   let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) + 0.1) in
   let x_flat =
     Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
-      ~fetch_op:(Constant_fill { values = xs; strict = true })
+      ~fetch_op:(fun ~v:_ -> Constant_fill xs)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
@@ -37,7 +37,7 @@ let%debug_sexp graph_t () : unit =
   Train.set_hosted x_flat.value;
   Train.set_hosted (Option.value_exn ~here:[%here] xkcd.diff).grad;
   let update = Train.grad_update fx in
-  let fx_routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in
+  let fx_routine = Train.to_routine (module Backend) ctx bindings update in
   let step_ref = IDX.find_exn fx_routine.bindings step_sym in
   Tensor.print_tree ~with_shape:true ~with_grad:true ~depth:9 xkcd;
   let ys, dys =
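The fetch_op rewrite here is the "Constant_fill via unrolling" piece of the commit title: Constant_fill now carries a bare float array (the strict flag is gone), and ~fetch_op takes a callback that receives the tensor's value node. A sketch of the apparent new shape, with all names inferred from the call sites and labeled hypothetical:

(* Hypothetical: the callback gets the value node ~v, letting a fetch op
   depend on the tensor it initializes; these call sites ignore it. *)
type fetch_op = Constant_fill of float array (* | ... other cases *)
type fetch_callback = v:Tn.t -> fetch_op

let example : fetch_callback = fun ~v:_ -> Constant_fill [| 0.1; 0.2 |]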

bin/zero2hero_1of7.ml (+9 -12)

@@ -22,12 +22,12 @@ let _suspended () =
   let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
   Train.every_non_literal_on_host v;
   let code = Train.grad_update v in
-  let routine = Train.to_routine (module Backend) ctx IDX.empty code.fwd_bprop in
+  let routine = Train.to_routine (module Backend) ctx IDX.empty code in
   Train.run routine;
   Stdio.printf "\n%!";
   Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 v;
   Stdio.printf "\nHigh-level code:\n%!";
-  Ir.Assignments.to_doc () code.fwd_bprop.asgns |> PPrint.ToChannel.pretty 0.7 100 Stdio.stdout;
+  Ir.Assignments.to_doc () code.asgns |> PPrint.ToChannel.pretty 0.7 100 Stdio.stdout;
   Stdio.printf "\n%!"

 let _suspended () =
@@ -57,7 +57,7 @@ let _suspended () =
   let x_flat =
     Tensor.term ~grad_spec:Tensor.Require_grad
       ~label:[ "x_flat" ] (* ~input_dims:[] ~output_dims:[ 1 ] *)
-      ~fetch_op:(Constant_fill { values; strict = true })
+      ~fetch_op:(fun ~v:_ -> Constant_fill values)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
@@ -70,7 +70,7 @@ let _suspended () =
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   let update = Train.grad_update fx in
-  let routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in
+  let routine = Train.to_routine (module Backend) ctx bindings update in
   let step_ref = IDX.find_exn routine.bindings step_sym in
   let ys = Array.create ~len:size 0. and dys = Array.create ~len:size 0. in
   let open Operation.At in
@@ -111,7 +111,7 @@ let _suspended () =
   (* Yay, the whole shape gets inferred! *)
   let x_flat =
     Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
-      ~fetch_op:(Constant_fill { values = xs; strict = true })
+      ~fetch_op:(fun ~v:_ -> Constant_fill xs)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
@@ -120,7 +120,7 @@ let _suspended () =
   Train.set_hosted x.value;
   Train.set_hosted (Option.value_exn ~here:[%here] x.diff).grad;
   let update = Train.grad_update fx in
-  let fx_routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in
+  let fx_routine = Train.to_routine (module Backend) ctx bindings update in
   let step_ref = IDX.find_exn fx_routine.bindings step_sym in
   let%track_sexp () =
     let ys, dys =
@@ -155,9 +155,7 @@ let () =
   let module Backend = (val Backends.fresh_backend ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let update = Train.grad_update l in
-  let routine =
-    Train.to_routine (module Backend) (Backend.make_context stream) IDX.empty update.fwd_bprop
-  in
+  let routine = Train.to_routine (module Backend) (Backend.make_context stream) IDX.empty update in
   Train.run routine;
   (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
      Backend.await stream; *)
@@ -168,8 +166,7 @@ let () =
   Tensor.print_tree ~with_grad:true ~depth:9 l;
   let%op learning_rate = 0.1 in
   let routine =
-    Train.to_routine (module Backend) routine.context IDX.empty
-    @@ Train.sgd_update ~learning_rate update
+    Train.to_routine (module Backend) routine.context IDX.empty @@ Train.sgd_update ~learning_rate l
   in
   (* learning_rate is virtual so this will not print anything. *)
   Stdio.print_endline
@@ -185,7 +182,7 @@ let () =
   Tensor.print_tree ~with_grad:true ~depth:9 l;
   (* We could reuse the jitted code if we did not use `jit_and_run`. *)
   let update = Train.grad_update l in
-  let routine = Train.to_routine (module Backend) routine.context IDX.empty update.fwd_bprop in
+  let routine = Train.to_routine (module Backend) routine.context IDX.empty update in
   Train.run routine;
   (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
      Backend.await stream; *)
