Commit 58c7187

Merge pull request #369 from ahrefs/feature/record-syntax
Replace string-based inline tensor definitions with record syntax
2 parents 7fe3406 + 96f754c; commit 58c7187

29 files changed: +494 −208 lines
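At a glance, the change swaps the old string-literal declaration form for a record form. A before/after sketch, using a line taken from the `bin/hello_world.ml` hunk below:

    (* Before: the tensor [hey] is introduced by a string literal. *)
    let%op y = ("hey" * 'q' 2.0) + 'p' 1.0 in

    (* After: the same tensor introduced with record syntax; punning on
       [hey] keeps the default initialization. *)
    let%op y = ({ hey } * 'q' 2.0) + 'p' 1.0 in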

CLAUDE.md

Lines changed: 11 additions & 1 deletion
@@ -151,9 +151,19 @@ opam install cudajit # for CUDA backend
 - `%cd` requires `NTDSL` module in scope (from `Operation.NTDSL`)
 - `%op` requires `TDSL` module in scope (from `Operation.TDSL`)
-- Inline tensor declarations using string literals
+- Record syntax for inline tensor declarations: `{ tensor_name }` or `{ tensor_name = init_expr }`
 - Generalized einsum notation for complex tensor operations
+
+**Key differences between %op and %cd**:
+- `%op` allows initialization expressions (`{ x = uniform () }`), used for model parameters
+- `%cd` is self-referential only (`{ x }`), used in computation graphs where tensors are defined by operations
+
+**Record syntax features**:
+- OCaml punning: `{ x }` expands to default initialization (`uniform ()` for parameters in `%op`)
+- Shorthand field names: `o` → `output_dims`, `i` → `input_dims`, `b` → `batch_dims`
+- Additional fields map to labeled arguments of tensor creation functions
+- Dimension specification: lists `[...]` for output, tuples `(...)` for input, arrays `[|...|]` for batch
 
 ## Common Development Tasks
 
 ### Adding New Operations
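A minimal sketch of these `%op` features (the layer shape and the dimension value 16 are illustrative; the `uniform ()` and `o = [ ... ]` forms mirror the `lib/nn_blocks.ml` hunk below):

    (* Punning: under %op, { w } gets the default uniform () initialization. *)
    let%op f x = { w } * x in

    (* Explicit initialization expression, plus the [o] shorthand for output_dims. *)
    let%op layer x = relu (({ w = uniform () } * x) + { b = 0.; o = [ 16 ] }) in

Under `%cd`, only the self-referential `{ x }` form applies, since those tensors are defined by the operations of the computation graph.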

bin/compilation_speed.ml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ let benchmark_overhead backend () =
   CDSL.disable_all_debugs ();
   Stdio.prerr_endline @@ "\n\n****** Benchmarking " ^ Backend.name ^ " ******";
   let init_time = Time_now.nanoseconds_since_unix_epoch () in
-  let%op f = (3 *. ("x" [ 5 ] **. 2)) - (4 *. x) + 5 in
+  let%op f = (3 *. ({ x; o = [ 5 ] } **. 2)) - (4 *. x) + 5 in
   Train.set_hosted f.value;

   (* Train.every_non_literal_on_host f; *)

bin/hello_world.ml

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ let hello1 () =
 let hello2 () =
   let module Backend = (val Backends.fresh_backend ()) in
   (* Hey is inferred to be a matrix. *)
-  let%op y = ("hey" * 'q' 2.0) + 'p' 1.0 in
+  let%op y = ({ hey } * 'q' 2.0) + 'p' 1.0 in
   (* Punning for ["hey"] above introduced the [hey] identifier. *)
   Train.every_non_literal_on_host y;
   ignore (Train.forward_once (module Backend) y);

@@ -106,7 +106,7 @@ let hello6 () =
   in

   (* "Hey" is inferred to be a scalar. *)
-  let%op y = 2 *. "hey" in
+  let%op y = 2 *. { hey } in
   ignore (Train.forward_once backend y);
   (* Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey; *)
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false y

bin/hello_world_op.ml

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@ let%track2_sexp _Pointwise_multiplication_dims_1 (() : unit) : unit =
   in

   (* "Hey" is inferred to be a scalar. *)
-  let%op ya = 2 *. "hey" 7.0 in
+  let%op ya = 2 *. { hey = 7.0 } in
   ignore (Train.forward_once backend ya);
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false ya

@@ -44,7 +44,7 @@ let%track2_sexp _Matrix_multiplication_dims_1x1 (() : unit) : unit =
   in

   (* Hey is inferred to be a matrix because of matrix multiplication [*]. *)
-  let%op yb = ("hey" 7.0 * 'q' 2.0) + 'p' 1.0 in
+  let%op yb = ({ hey = 7.0 } * 'q' 2.0) + 'p' 1.0 in
   ignore (Train.forward_once backend yb);
   (* Punning for ["hey"] above introduced the [hey] identifier. *)
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;

@@ -172,7 +172,7 @@ let%track2_sexp _Matrix_multiplication_dims_2x3 (() : unit) : unit =
   in

   (* Hey is inferred to be a matrix. *)
-  let%op yc = ("hey" 7.0 * [ 2; 3 ]) + [ 4; 5; 6 ] in
+  let%op yc = ({ hey = 7.0 } * [ 2; 3 ]) + [ 4; 5; 6 ] in
   ignore (Train.forward_once backend yc);
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false yc

bin/micrograd_basic.ml

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ let _get_local_debug_runtime = Utils.get_local_debug_runtime

 let%diagn_sexp () =
   let module Backend = (val Backends.fresh_backend ~backend_name:"multicore_cc" ()) in
-  let%op c = "a" [ -4 ] + "b" [ 2 ] in
+  let%op c = { a = [ -4 ] } + { b = [ 2 ] } in
   let%op d = c + c + 1 in
   (* let%op c = c + 1 + c + ~-a in *)
   (* Uncomment just the first "fully on host" line to see which arrays can be virtual, and just the

@@ -25,7 +25,7 @@ let%diagn_sexp () =
   Train.printf ~here:[%here] ~with_code:false ~with_grad:true b

 let%diagn_sexp _suspended () : unit =
-  let%op c = "a" [ -4 ] + "b" [ 2 ] in
+  let%op c = { a = [ -4 ] } + { b = [ 2 ] } in
   let%op d = (a *. b) + (b **. 3) in
   let%op c = c + c + 1 in
   let%op c = c + 1 + c + ~-a in

bin/micrograd_demo_logging.ml

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ module type Backend = Ir.Backend_intf.Backend
 let () =
   Tensor.unsafe_reinitialize ();
   let module Backend = (val Backends.fresh_backend ()) in
-  let%op c = "a" [ -4 ] + "b" [ 2 ] in
+  let%op c = { a = [ -4 ] } + { b = [ 2 ] } in
   let%op d = (a *. b) + (b **. 3) in
   let%op c = c + c + 1 in
   let%op c = c + 1 + c + ~-a in

@@ -33,7 +33,7 @@ let () =
 let _suspended () =
   Tensor.unsafe_reinitialize ();
   let module Backend = (val Backends.fresh_backend ()) in
-  let%op c = "a" [ -4 ] + "b" [ 2 ] in
+  let%op c = { a = [ -4 ] } + { b = [ 2 ] } in
   let%op d = (a *. b) + (b **. 3) in
   let%op c = c + c + 1 in
   let%op c = c + 1 + c + ~-a in

bin/moons_benchmark.ml

Lines changed: 7 additions & 4 deletions
@@ -64,13 +64,16 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b

   let init_time = Time_now.nanoseconds_since_unix_epoch () in
   let%op mlp x =
-    "w4"
+    { w4 }
     * relu
-        ("b3" hid_dim_3
-        + ("w3" * relu ("b2" hid_dim_2 + ("w2" * relu ("b1" hid_dim_1 + ("w1" * x))))))
+        ({ b3; o = [ hid_dim_3 ] }
+        + { w3 }
+          * relu
+              ({ b2; o = [ hid_dim_2 ] }
+              + ({ w2 } * relu ({ b1; o = [ hid_dim_1 ] } + ({ w1 } * x)))))
   in
   (* TINY for debugging: *)
-  (* let%op mlp x = "w2" * relu("b1" hid_dim + ("w1" * x)) in *)
+  (* let%op mlp x = { w2 } * relu({ b1; o = [ hid_dim ] } + ({ w1 } * x)) in *)
   let%op loss_fn ~output ~expectation = relu (!..1 - (expectation *. output)) in
   let start_time = ref None in
   let weight_decay = 0.0002 in
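Note how the bias declarations migrate: the old syntax passed the hidden dimension as a bare trailing argument, while the record syntax names it via the `o` (output_dims) shorthand. Side by side, from the hunk above:

    (* old *) "b3" hid_dim_3
    (* new *) { b3; o = [ hid_dim_3 ] }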

bin/zero2hero_1of7.ml

Lines changed: 4 additions & 4 deletions
@@ -17,7 +17,7 @@ let _get_local_debug_runtime = Utils.get_local_debug_runtime

 let _suspended () =
   let module Backend = (val Backends.fresh_backend ()) in
-  let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
+  let%op v = ({ w = [ (-3, 1) ] } * { x = [ 2; 0 ] }) + { b = [ 6.7 ] } in
   Train.every_non_literal_on_host v;
   let code = Train.grad_update v in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in

@@ -131,9 +131,9 @@ let _suspended () =
   ()

 let _suspended () =
-  let%op e = "a" [ 2 ] *. "b" [ -3 ] in
-  let%op d = e + "c" [ 10 ] in
-  let%op l = d *. "f" [ -2 ] in
+  let%op e = { a = [ 2 ] } *. { b = [ -3 ] } in
+  let%op d = e + { c = [ 10 ] } in
+  let%op l = d *. { f = [ -2 ] } in
   Train.every_non_literal_on_host l;
   let module Backend = (val Backends.fresh_backend ()) in
   let ctx = Train.update_once (module Backend) ~hosted:true l in

lib/nn_blocks.ml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ module NTDSL = Operation.NTDSL

 type mlp_layer_config = { label : string list; hid_dim : int }

-let%op mlp_layer ~config x = relu (("w" * x) + "b" config.hid_dim)
+let%op mlp_layer ~config x = relu (({ w = uniform () } * x) + { b = 0.; o = [ config.hid_dim ] })

 type mlp_config = { label : string list; hid_dims : int list }
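A hedged usage sketch of the updated layer; the call shape is assumed from the signature in the hunk, and the consumer of the `label` field is not shown in this diff:

    let config = { label = [ "layer1" ]; hid_dim = 16 } in
    (* [x] is an input tensor; [w] and [b] are created by mlp_layer itself. *)
    let y = mlp_layer ~config x in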

lib/operation.ml

Lines changed: 7 additions & 6 deletions
@@ -639,13 +639,14 @@ module TDSL = struct
   (** The default initialization operation for {!param} calls. *)
   let default_param_init = ref (uniform ~grad_spec:Require_grad)

-  let param ?value ?values =
+  let param ?value ?values ?param_init =
     let t =
-      match (value, values) with
-      | Some _, Some _ -> invalid_arg "TDSL.param: both value and values are set"
-      | Some value, None -> Tensor.term_init ~grad_spec:Require_grad [| value |]
-      | None, Some values -> Tensor.term_init ~grad_spec:Require_grad values
-      | None, None -> !default_param_init ()
+      match (value, values, param_init) with
+      | Some value, None, None -> Tensor.term_init ~grad_spec:Require_grad [| value |]
+      | None, Some values, None -> Tensor.term_init ~grad_spec:Require_grad values
+      | None, None, Some param_init -> param_init
+      | None, None, None -> !default_param_init ()
+      | _ -> invalid_arg "TDSL.param: at most one of value, values, and param_init can be set"
     in
     Tensor.param ~t
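The new `?param_init` argument lets a caller supply a ready-made initialization, with at most one of the three optional arguments permitted. A sketch of the resulting call patterns (call sites are hypothetical, and `TDSL.param` may take further arguments not visible in this hunk):

    param ~value:0.5                   (* init from a single value *)
    param ~values:[| 0.1; 0.2 |]       (* init from an array *)
    param ~param_init:(uniform ~grad_spec:Require_grad ())
                                       (* caller-supplied init, as with the
                                          { w = uniform () } record field *)
    param                              (* falls back to !default_param_init *)
    (* Combining two or more of these raises Invalid_argument. *)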
