Commit 8a26b97

Changed the `%cd` block comment syntax `~~` to allow detailed structuring.
Rewrote `Train.grad_update` to use the `%cd` syntax.
1 parent 72d12ba commit 8a26b97

4 files changed: +74 -62 lines

CHANGES.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 ### Changed
 
 - Removed the `pipes_cc, pipes_gccjit` backends (`Pipes_multicore_backend`) -- I had fixed `Pipes_multicore_backend` by using the `poll` library instead of `Unix.select`, but it turns out to be very very slow.
+- Changed the `%cd` block comment syntax `~~` to allow detailed structuring. Rewrote `Train.grad_update` to use the `%cd` syntax.
 
 ### Fixed
 
lib/ppx_cd.ml

Lines changed: 20 additions & 20 deletions
@@ -738,6 +738,26 @@ let translate (expr : expression) : result =
           Ast_builder.Default.pexp_extension ~loc
           @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: repeated .merge not allowed";
       })
+  | [%expr
+      ~~([%e? { pexp_desc = Pexp_apply (expr, exprs); pexp_loc; _ }];
+         [%e? expr2])] ->
+      let elements =
+        expr :: List.map ~f:snd exprs
+        |> List.map ~f:(function
+             | { pexp_desc = Pexp_constant (Pconst_string _); _ } as s -> s
+             | [%expr [%e? t].value] -> [%expr Arrayjit.Tnode.debug_name [%e t].value]
+             | [%expr [%e? t].grad] -> [%expr Arrayjit.Tnode.debug_name [%e t].value ^ ".grad"]
+             | t -> [%expr Arrayjit.Tnode.debug_name [%e t].value])
+      in
+      let res2 = loop ~proj_in_scope expr2 in
+      {
+        res2 with
+        expr =
+          [%expr
+            Arrayjit.Assignments.Block_comment
+              ( String.concat_array ~sep:" " [%e Ast_helper.Exp.array ~loc:pexp_loc elements],
+                [%e res2.expr] )];
+      }
   | [%expr
       [%e? accu_op]
        [%e? lhs]
@@ -916,26 +936,6 @@ let translate (expr : expression) : result =
           @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: for-downto: low-level code embeddings not supported yet";
       }
-  | [%expr
-      ~~[%e? { pexp_desc = Pexp_apply (expr, exprs); pexp_loc; _ }];
-      [%e? expr2]] ->
-      let elements =
-        expr :: List.map ~f:snd exprs
-        |> List.map ~f:(function
-             | { pexp_desc = Pexp_constant (Pconst_string _); _ } as s -> s
-             | [%expr [%e? t].value] -> [%expr Arrayjit.Tnode.debug_name [%e t].value]
-             | [%expr [%e? t].grad] -> [%expr Arrayjit.Tnode.debug_name [%e t].value ^ ".grad"]
-             | t -> [%expr Arrayjit.Tnode.debug_name [%e t].value])
-      in
-      let res2 = loop ~proj_in_scope expr2 in
-      {
-        res2 with
-        expr =
-          [%expr
-            Arrayjit.Assignments.Block_comment
-              ( String.concat_array ~sep:" " [%e Ast_helper.Exp.array ~loc:pexp_loc elements],
-                [%e res2.expr] )];
-      }
   | [%expr
       [%e? expr1];
       [%e? expr2]]
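
For orientation, here is a rough, hand-written approximation of what the new `~~(...)` case produces, not the literal ppx output; the tensor `p` and the assignments value `body` are hypothetical, and the sketch assumes the same module context as code that uses `%cd` (Base opened, OCANNL and Arrayjit available, as in lib/train.ml):

```ocaml
(* Approximate expansion of [%cd ~~("merging gradient of" p; <body>)]:
   string items pass through unchanged, tensor items become their debug names,
   and the scope's assignments are wrapped in a Block_comment. *)
let commented (p : Tensor.t) (body : Arrayjit.Assignments.t) : Arrayjit.Assignments.t =
  Arrayjit.Assignments.Block_comment
    ( String.concat_array ~sep:" "
        [| "merging gradient of"; Arrayjit.Tnode.debug_name p.value |],
      body )
```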

lib/syntax_extensions.md

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ type Assignments.t =
 ...
 ```
 
-Schematic example: `~~("space" "separated" "comment" "tensor p debug_name:" p); <scope of the comment>`. The content of the comment uses application syntax, must be composed of strings, `<tensor>`, `<tensor>.value` (equivalent to `<tensor>`), `<tensor>.grad` components, where `<tensor>` is any tensor expression or tensor identifier.
+Schematic example: `~~("space" "separated" "comment" "tensor p debug_name:" p; <scope of the comment>)`. The content of the comment uses application syntax, must be composed of strings, `<tensor>`, `<tensor>.value` (equivalent to `<tensor>`), `<tensor>.grad` components, where `<tensor>` is any tensor expression or tensor identifier.
 
 ## Further features of the syntax extension `%op` {#features-of-syntax-op}
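
To illustrate the detailed structuring the new placement of the scope enables, here is a sketch of nested `~~` comments inside `%cd`, modeled directly on the `Train.grad_update` rewrite in this commit; the identifiers `loss`, `fwd`, `zero_grads`, `init_grad`, and `bprop` are assumed to be in scope:

```ocaml
(* Each ~~(<comment items>; <scope>) wraps its scope in an
   Assignments.Block_comment, and comments nest, so the generated code is
   labeled both at the whole-update level and per phase. *)
[%cd
  ~~(loss "gradient update";
     ~~(loss "fwd";
        fwd);
     ~~(loss "zero grads";
        zero_grads);
     init_grad;
     ~~(loss "bprop";
        bprop))]
```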

lib/train.ml

Lines changed: 52 additions & 41 deletions
@@ -111,27 +111,41 @@ let set_hosted (a : Tn.t) =
   if Tn.known_constant a then Tn.update_memory_mode a (Hosted Constant) 41
   else Tn.update_memory_mode a (Hosted Changed_on_devices) 41
 
-let label_suffix label =
-  (* FIXME: this should be label prefix, as most valuable label components come first. *)
-  Option.value ~default:"unknown"
-  @@ List.find ~f:(String.for_all ~f:(fun c -> Char.is_alphanum c || equal_char '_' c))
-  @@ List.rev label
-
 (** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
     label-derived comment. *)
 let forward ?(disable_rootness_check = false) t =
   let fwd = if disable_rootness_check then t.Tensor.forward else Tensor.consume_forward_code t in
   set_hosted t.Tensor.value;
-  let label = label_suffix t.Tensor.value.label in
+  let label = Tn.debug_name t.value in
   Asgns.Block_comment (label ^ " fwd", fwd)
 
 type updaten = {
   loss : Tensor.t;
-  label : string;
   params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
   fwd_bprop : Asgns.t;
 }
 
+let diff_or_error t provenance =
+  Option.value_or_thunk t.Tensor.diff ~default:(fun () ->
+      raise @@ Tensor.Session_error (provenance ^ ": tensor is not differentiable", Some t))
+
+let grad_update_nochecks loss =
+  let params = get_params loss in
+  let diff = diff_or_error loss "Train.grad_update_nochecks" in
+  let fwd_bprop =
+    let%cd init_grad = loss.grad =: 1 in
+    [%cd
+      ~~(loss "gradient update";
+         ~~(loss "fwd";
+            loss.forward);
+         ~~(loss "zero grads";
+            diff.zero_grads);
+         init_grad;
+         ~~(loss "bprop";
+            diff.backprop))]
+  in
+  { loss; params; fwd_bprop }
+
 (** Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived
     comments. Sets the tensor's value as "fully on host". If [setup_for_parallel] is true (false by
     default), sets the parameters and their gradients as "non-local" (on-device). *)
@@ -140,52 +154,49 @@ let grad_update ?(disable_rootness_check = false) ?(setup_for_parallel = false)
   let params = get_params loss in
   if setup_for_parallel then
     Set.iter params ~f:(fun p -> set_materialized (Option.value_exn ~here:[%here] p.diff).grad);
-  let label = label_suffix loss.value.label in
   let fwd =
     if disable_rootness_check then loss.Tensor.forward else Tensor.consume_forward_code loss
   in
+  let diff = diff_or_error loss "Train.grad_update" in
   let fwd_bprop =
-    match loss.Tensor.diff with
-    | Some diff ->
-        let zero_grads, bprop =
-          if disable_rootness_check then (diff.zero_grads, diff.backprop)
-          else Tensor.consume_backprop_code loss
-        in
-        (* Note: the %cd syntax for [loss.grad] does not modify roots. *)
-        let%cd init_grad = loss.grad =: 1 in
-        Asgns.(
-          Block_comment
-            ( label ^ " gradient update",
-              sequential
-                [
-                  Block_comment (label ^ " fwd", fwd);
-                  Block_comment (label ^ " zero grads", zero_grads);
-                  init_grad;
-                  Block_comment (label ^ " bprop", bprop);
-                ] ))
-    | None ->
-        raise @@ Tensor.Session_error ("Train.grad_update: tensor is not differentiable", Some loss)
+    let zero_grads, bprop =
+      if disable_rootness_check then (diff.zero_grads, diff.backprop)
+      else Tensor.consume_backprop_code loss
+    in
+    (* Note: the %cd syntax for [loss.grad] does not modify roots. *)
+    let%cd init_grad = loss.grad =: 1 in
+    [%cd
+      ~~(loss "gradient update";
+         ~~(loss "fwd";
+            fwd);
+         ~~(loss "zero grads";
+            zero_grads);
+         init_grad;
+         ~~(loss "bprop";
+            bprop))]
   in
-  { loss; label; params; fwd_bprop }
+  { loss; params; fwd_bprop }
 
 (** See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py *)
 let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
   if not @@ is_param p then raise @@ Tensor.Session_error ("Train.sgd_one: not a parameter", Some p);
   [%cd
-    ~~(p "param sgd step");
-    "sgd_delta" =: p.grad + (!.weight_decay *. p);
-    if Float.(momentum > 0.0) then (
-      "sgd_momentum" =: (!.momentum *. sgd_momentum) + sgd_delta;
-      if nesterov then sgd_delta =+ !.momentum *. sgd_momentum else sgd_delta =: sgd_momentum);
-    p =- learning_rate *. sgd_delta]
+    ~~(p "param sgd step";
+       "sgd_delta" =: p.grad + (!.weight_decay *. p);
+       if Float.(momentum > 0.0) then (
+         "sgd_momentum" =: (!.momentum *. sgd_momentum) + sgd_delta;
+         if nesterov then sgd_delta =+ !.momentum *. sgd_momentum else sgd_delta =: sgd_momentum);
+       p =- learning_rate *. sgd_delta)]
 
 let sgd_update ~learning_rate ?momentum ?weight_decay ?nesterov l =
   let code =
     l.params |> Set.to_list
     |> List.map ~f:(sgd_one ~learning_rate ?momentum ?weight_decay ?nesterov)
     |> Asgns.sequential
   in
-  Asgns.Block_comment (l.label ^ " sgd update", code)
+  [%cd
+    ~~(l.loss "sgd update";
+       code)]
 
 (** All and only bindings with associated ranges are iterated, with the binding's initial value
     lost. Bindings without ranges remain at their initial values. *)
@@ -328,8 +339,8 @@ let%track3_sexp parallel_update (type context)
   let grad_merges : Asgns.t array =
     Array.map all_params ~f:(fun p ->
         [%cd
-          ~~("merging gradient of" p);
-          p.grad =+ p.grad.merge])
+          ~~("merging gradient of" p;
+             p.grad =+ p.grad.merge)])
   in
   let grad_merges_to : Backend.routine option array array =
     (* For now, we need all params on all devices. *)
@@ -346,8 +357,8 @@ let%track3_sexp parallel_update (type context)
         link ~from_prior_context:(needs_prior_context updaten.loss) sgd_update.context
         @@ compile Idx.Empty
              [%cd
-               ~~("merging" updaten.loss);
-               updaten.loss.value =+ updaten.loss.value.merge])
+               ~~("merging" updaten.loss;
+                  updaten.loss.value =+ updaten.loss.value.merge)])
   in
   let into_merge_buffer = if copy_to_merge then BT.Copy else BT.Streaming in
   (* Since each device has its own queue, we can iterate over devices in the outer loop. *)
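
A minimal caller-side sketch of how the rewritten pieces compose, assuming the usual OCANNL setup where `loss` is a differentiable tensor and `learning_rate` is a scalar tensor, and using the `Asgns` alias for `Arrayjit.Assignments` as in lib/train.ml; compiling and running the resulting assignments on a backend is out of scope here:

```ocaml
(* Build forward + zero-grads + backprop for [loss] (now wrapped in the new
   nested block comments), then the SGD step over its parameters, and
   sequence them into one assignments program. *)
let training_step ~loss ~learning_rate : Asgns.t =
  let update = Train.grad_update loss in
  let sgd_step = Train.sgd_update ~learning_rate ~momentum:0.9 update in
  Asgns.sequential [ update.fwd_bprop; sgd_step ]
```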
