Commit 842daaa

Migrate the syntax away from operators for unary primitive ops and relu

1 parent e126bd2

19 files changed: +124 -79 lines
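
In short, call sites migrate from the ?/ prefix operator to plain relu function application. A before/after sketch distilled from the diffs below (taken from bin/micrograd_basic.ml):

    (* before: relu spelled as the prefix operator ?/ *)
    let%op d = d + (d *. 2) + ?/(b + a) in
    (* after: relu spelled as an ordinary function *)
    let%op d = d + (d *. 2) + relu (b + a) in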

arrayjit/lib/ops.ml

Lines changed: 3 additions & 3 deletions

@@ -273,9 +273,9 @@ let binop_cd_fallback_syntax = function
   | Relu_gate -> "relu_gate"
   | Cmplt -> "lt"
   | Cmpne -> "le"
-  | Or -> "orf"
-  | And -> "andf"
-  | Mod -> "modf"
+  | Or -> "or_"
+  | And -> "and_"
+  | Mod -> "mod_"
   | Max -> "max"
   | Min -> "min"
   (* | Shl -> "shlf" *)
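
A plausible reason for the trailing underscores (an inference; the commit message does not say): or, and, and mod are reserved words in OCaml, so the bare names would not parse as identifiers. A standalone sketch with hypothetical toy definitions:

    (* `let or a b = ...` is a syntax error: `or` is a keyword alias of ||;
       `and` and `mod` are likewise reserved, hence the underscore suffix. *)
    let or_ a b = a || b
    let and_ a b = a && b
    let mod_ a b = a mod b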

bin/hello_world.ml

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ let hello3 () =
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   (* Hey is inferred to be a matrix. *)
-  let hey = TDSL.O.(!~"hey") in
+  let hey = Tensor.param "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let y = TDSL.O.(( + ) ~label:[ "y" ] (hey * zero_to_twenty) zero_to_twenty) in
   Train.set_hosted hey.value;
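
The ( !~ ) prefix operator was a thin alias for Tensor.param (its definition is deleted in lib/operation.ml below), so call sites migrate one-for-one:

    (* old: let hey = TDSL.O.(!~"hey") in *)
    let hey = Tensor.param "hey" in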

bin/micrograd_basic.ml

Lines changed: 2 additions & 2 deletions

@@ -43,8 +43,8 @@ let%diagn_sexp _suspended () : unit =
   let%op d = (a *. b) + (b **. 3) in
   let%op c = c + c + 1 in
   let%op c = c + 1 + c + ~-a in
-  let%op d = d + (d *. 2) + ?/(b + a) in
-  let%op d = d + (3 *. d) + ?/(b - a) in
+  let%op d = d + (d *. 2) + relu (b + a) in
+  let%op d = d + (3 *. d) + relu (b - a) in
   let%op e = c - d in
   let%op f = e *. e in
   let%op g = f /. 2 in

bin/micrograd_demo.ml

Lines changed: 2 additions & 2 deletions

@@ -39,7 +39,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let moons_classes = TDSL.init_const ~l:"moons_classes" ?b ~o:[ 1 ] moons_classes in
   let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in
   let step_n, bindings = IDX.get_static_symbol bindings in
-  let%op mlp x = "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in
+  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
   let%op moons_input = moons_flat @| batch_n in
   (* Tell shape inference to make a minibatch axis. *)
   let () =
@@ -56,7 +56,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let losses = ref [] in
   let log_losses = ref [] in
   let learning_rates = ref [] in
-  let%op margin_loss = ?/(1 - (moons_class *. mlp moons_input)) in
+  let%op margin_loss = relu (1 - (moons_class *. mlp moons_input)) in
   (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
      computation. *)
   let scalar_loss, weight_decay =
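
Background note (standard machine learning, not stated in the diff): with class labels y in {-1, +1} and prediction f(x) = mlp moons_input, the margin_loss above is the pointwise hinge loss, since relu t = max(0, t):

    L(x, y) = max(0, 1 - y * f(x))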

bin/moons_benchmark.ml

Lines changed: 5 additions & 3 deletions

@@ -76,11 +76,13 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   let init_time = Time_now.nanoseconds_since_unix_epoch () in
   let%op mlp x =
     "w4"
-    * ?/("b3" hid_dim_3 + ("w3" * ?/("b2" hid_dim_2 + ("w2" * ?/("b1" hid_dim_1 + ("w1" * x))))))
+    * relu
+        ("b3" hid_dim_3
+        + ("w3" * relu ("b2" hid_dim_2 + ("w2" * relu ("b1" hid_dim_1 + ("w1" * x))))))
   in
   (* TINY for debugging: *)
-  (* let%op mlp x = "w2" * ?/("b1" hid_dim + ("w1" * x)) in *)
-  let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in
+  (* let%op mlp x = "w2" * relu("b1" hid_dim + ("w1" * x)) in *)
+  let%op loss_fn ~output ~expectation = relu (!..1 - (expectation *. output)) in
   let start_time = ref None in
   let weight_decay = 0.0002 in
   Arrayjit.Schedulers.sync_suggested_num_streams := num_streams;

bin/moons_demo.ml

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ let demo () =
   let steps = epochs * n_batches in
   let weight_decay = 0.0002 in

-  let%op mlp x = "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in
+  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in

   let noise () = Rand.float_range (-0.1) 0.1 in
   let moons_flat =
@@ -49,7 +49,7 @@ let demo () =
   let%op moons_input = moons_flat @| batch_n in
   let%op moons_class = moons_classes @| batch_n in

-  let%op margin_loss = ?/(1 - (moons_class *. mlp moons_input)) in
+  let%op margin_loss = relu (1 - (moons_class *. mlp moons_input)) in
   let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in

   let update = Train.grad_update scalar_loss in

bin/moons_demo_parallel.ml

Lines changed: 2 additions & 2 deletions

@@ -35,9 +35,9 @@ let experiment ~seed ~backend_name ~config () =
   let moons_flat ~b = TDSL.init_const ~l:"moons_flat" ~b ~o:[ 2 ] moons_flat in
   let moons_classes = Array.init (len * 2) ~f:(fun i -> if i % 2 = 0 then 1. else -1.) in
   let moons_classes ~b = TDSL.init_const ~l:"moons_classes" ~b ~o:[ 1 ] moons_classes in
-  let%op mlp x = "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in
+  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
   (* let%op mlp x = "b" + ("w" * x) in *)
-  let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in
+  let%op loss_fn ~output ~expectation = relu (!..1 - (expectation *. output)) in
   (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
      computation. *)
   let weight_decay = 0.0002 in

docs/OCANNL-ocaml_workshop_2024.tm

Lines changed: 4 additions & 4 deletions

@@ -5,10 +5,10 @@
 <\body>
 <doc-data|<doc-title|OCANNL optimization framework>|<doc-subtitle|Tensor
 shape inference, concise notation, multidevice
-runtime>|<doc-author|<author-data|<author-name|Šukasz
+runtime>|<doc-author|<author-data|<author-name|Łukasz
 Stafiniak>|<\author-note>
 Since April 2024, <hlink|<with|font-family|tt|<with|color|orange|a><with|color|blue|hrefs>>|https://ahrefs.com/>
-sponsors Šukasz's work on OCANNL.
+sponsors Łukasz's work on OCANNL.
 </author-note>>>>

 <abstract-data|<abstract|OCANNL is a Deep Learning framework with
@@ -98,14 +98,14 @@
 fetaures <with|font-shape|italic|parameter punning> (strings become
 let-bindings of tensors) and inline output dimensions specification. Full
 example of a Multi Layer Perceptron with 2 hidden layers and Rectified
-Linear Unit non-linearity <verbatim|(?/)>, defining tensors <verbatim|b1>,
+Linear Unit non-linearity <verbatim|(relu)>, defining tensors <verbatim|b1>,
 <verbatim|w1>, <verbatim|b2>, <verbatim|w2>, <verbatim|b3>, <verbatim|w3>,
 and a tensor-returning function <verbatim|mlp>:

 <\verbatim-code>
 let%op mlp x =

-\ \ "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" *
+\ \ "b3" + ("w3" * relu("b2" hid_dim + ("w2" * relu("b1" hid_dim + ("w1" *
 x)))))
 </verbatim-code>
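For reference, the paper's example in its new spelling, exactly as it now reads in bin/moons_demo.ml above (parameter punning in action: string literals such as "w1" introduce parameter tensors, and a following expression such as hid_dim pins the output dimension):

    let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
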
lib/nn_blocks.ml

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ module NTDSL = Operation.NTDSL

 type mlp_layer_config = { label : string list; hid_dim : int }

-let%op mlp_layer ~config x = ?/(("w" * x) + "b" config.hid_dim)
+let%op mlp_layer ~config x = relu (("w" * x) + "b" config.hid_dim)

 type mlp_config = { label : string list; hid_dims : int list }

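A hypothetical usage sketch for the updated layer (the record fields come from mlp_layer_config above; the exact signature generated by %op is assumed, not shown in this diff):

    let layer = mlp_layer ~config:{ label = [ "layer1" ]; hid_dim = 16 }
    (* apply later to an input tensor: layer x *)
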
lib/operation.ml

Lines changed: 3 additions & 4 deletions

@@ -127,13 +127,13 @@ let relu ?(label = []) =
   let module NTDSL = Initial_NTDSL in
   let%cd op_asn ~v ~t1 ~projections = v =: relu v1 ~projections in
   let%cd grad_asn ~v ~g ~t1 ~projections = g1 =+ v -?/ g in
-  Tensor.unop ~label:("?/" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
+  Tensor.unop ~label:("relu" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn

 module NDO_without_pow = struct
   let ( * ) = matmul ~grad_spec:Prohibit_grad
   let ( *. ) = pointmul ~grad_spec:Prohibit_grad
   let ( + ) = add ~grad_spec:Prohibit_grad
-  let ( ?/ ) = relu ~grad_spec:Prohibit_grad
+  let relu = relu ~grad_spec:Prohibit_grad
   let ( !. ) = Tensor.number ~grad_spec:Prohibit_grad
   let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:Prohibit_grad @@ Float.of_int i
   let ( - ) = sub ~grad_spec:Prohibit_grad
@@ -260,8 +260,7 @@ module DO = struct
   let ( *. ) = pointmul ~grad_spec:If_needed
   let ( + ) = add ~grad_spec:If_needed
   let ( **. ) ?label base exp = pointpow ?label exp base ~grad_spec:If_needed
-  let ( ?/ ) = relu ~grad_spec:If_needed
-  let ( !~ ) label = Tensor.param label
+  let relu = relu ~grad_spec:If_needed
   let ( !. ) = Tensor.number ~grad_spec:If_needed
   let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:If_needed @@ Float.of_int i
   let ( !@ ) = embed_symbol
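
The surrounding modules share one pattern, which relu now follows as an ordinary binding instead of the removed ( ?/ ) and ( !~ ) operators: each DSL module specializes the generic operations by fixing ~grad_spec, shadowing the generic names. Annotated excerpt from the diff above:

    (* inside module DO: the right-hand relu is the generic op defined
       earlier in the file; fixing ~grad_spec shadows the name locally. *)
    let relu = relu ~grad_spec:If_needed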
