Skip to content

Commit c8d0142

Browse files
committed
Major refactor of the Tensor API to share parameter signatures and reduce boilerplate in configurable operation definitions; some cleanup
1 parent 43cc8d2 commit c8d0142

26 files changed

+382
-330
lines changed

CLAUDE.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ opam install cudajit # for CUDA backend
8080
- Tests are implemented either as inline expectations using `ppx_expect`; or as cram-style tests where an `.ml` file is compiled, executed, and its output compared against an `.expected` file
8181
- Tutorial files in `test/` serve as both documentation and integration tests
8282
- Use `dune promote` to accept test output changes
83+
- **Test Placement Guidelines**:
84+
* Always add tests under one of the test subdirectories
85+
* Default location is `test/operations`
86+
* Use `test/einsum` for tests involving complex einsum specifications
87+
* Use `test/training` for tests involving training loops
88+
* When adding a test, update the corresponding test stanza
89+
* Add an `.expected` file for test results (can initially be empty)
8390

8491
### Configuration
8592

@@ -134,4 +141,4 @@ opam install cudajit # for CUDA backend
134141
- Virtual nodes are inlined automatically (controlled by `virtualize_max_visits`)
135142
- Scalar constants can be inlined via `inline_scalar_constexprs=true`
136143
- Memory sharing optimizations through cross-stream tensor nodes
137-
- Backend-specific optimization levels configurable per backend
144+
- Backend-specific optimization levels configurable per backend

bin/hello_world.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ let hello2 () =
3333
let hello3 () =
3434
let module Backend = (val Backends.fresh_backend ()) in
3535
(* Hey is inferred to be a matrix. *)
36-
let hey = TDSL.param "hey" in
36+
let hey = TDSL.param "hey" () in
3737
let zero_to_twenty = TDSL.range 20 in
3838
let y = TDSL.O.(( + ) ~label:[ "y" ] (hey * zero_to_twenty) zero_to_twenty) in
3939
Train.set_hosted hey.value;
@@ -66,7 +66,7 @@ let hello4 () =
6666
let%op tj = rj ++ "j=>j1" in
6767
let rk = TDSL.range 5 in
6868
let%op tk = rk ++ "k=>k2" in
69-
let positions = TDSL.outer_sum "ijl;kl=>ijkl" (TDSL.outer_sum "il;jl=>ijl" ti tj) tk in
69+
let positions = TDSL.outer_sum "ijl;kl=>ijkl" (TDSL.outer_sum "il;jl=>ijl" ti tj ()) tk () in
7070
Train.set_hosted ti.value;
7171
Train.set_hosted tk.value;
7272
ignore (Train.forward_once backend positions);

bin/hello_world_op.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ let%track2_sexp _Big_matrix (() : unit) : unit =
188188
in
189189

190190
(* Hey is inferred to be a matrix. *)
191-
let hey = TDSL.param ~value:0.5 "hey" in
191+
let hey = TDSL.param ~value:0.5 "hey" () in
192192
let zero_to_twenty = TDSL.range 20 in
193193
let%op yd = (hey * zero_to_twenty) + zero_to_twenty in
194194
ignore (Train.forward_once backend yd);

bin/micrograd_demo.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
2121
let moons_coordinates, moons_labels = Datasets.Half_moons.generate ~config:moons_config ~len () in
2222
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_coordinates in
2323
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
24-
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray in
25-
let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray in
24+
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
25+
let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
2626
let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in
2727
let step_n, bindings = IDX.get_static_symbol bindings in
2828
let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in

bin/moons_benchmark.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
5757
in
5858
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_coordinates in
5959
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
60-
let moons_flat ~b:_ = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray in
61-
let moons_classes ~b:_ = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray in
60+
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
61+
let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
6262

6363
let init_time = Time_now.nanoseconds_since_unix_epoch () in
6464
let%op mlp x =

bin/moons_demo.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ let demo () =
2828
let config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some seed } in
2929
let moons_coordinates, moons_labels = Datasets.Half_moons.generate_single_prec ~config ~len () in
3030
let moons_flat =
31-
TDSL.rebatch ~l:"moons_flat" (Ir.Ndarray.as_array Ir.Ops.Single moons_coordinates)
31+
TDSL.rebatch ~l:"moons_flat" (Ir.Ndarray.as_array Ir.Ops.Single moons_coordinates) ()
3232
in
3333
let moons_classes =
34-
TDSL.rebatch ~l:"moons_classes" (Ir.Ndarray.as_array Ir.Ops.Single moons_labels)
34+
TDSL.rebatch ~l:"moons_classes" (Ir.Ndarray.as_array Ir.Ops.Single moons_labels) ()
3535
in
3636

3737
let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in

bin/moons_demo_parallel.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ let experiment ~seed ~backend_name ~config () =
2121
let moons_coordinates, moons_labels = Datasets.Half_moons.generate ~config:moons_config ~len () in
2222
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_coordinates in
2323
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
24-
let moons_flat ~b:_ = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray in
25-
let moons_classes ~b:_ = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray in
24+
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
25+
let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
2626
let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
2727
(* let%op mlp x = "b" + ("w" * x) in *)
2828
let%op loss_fn ~output ~expectation = relu (!..1 - (expectation *. output)) in

bin/primitive_ops.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ let graph_t () : unit =
1919
let open Operation.At in
2020
CDSL.virtualize_settings.enable_device_only <- false;
2121
(* let%op f x = sin x in *)
22-
let%op f x = uint4x32_to_prec_uniform x in
22+
let%op f x = sin x in
2323
let size = 50 in
2424
let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) + 0.1) in
2525
let x_flat = Tensor.term_init xs ~label:[ "x_flat" ] ~grad_spec:Require_grad () in

0 commit comments

Comments
 (0)