
Commit 29a8d50

Update the moons_demo test and signal shortcoming in consume_forward_code
1 parent 6f9d38d commit 29a8d50

2 files changed: +22 −12 lines changed

lib/tensor.ml

Lines changed: 11 additions & 8 deletions
@@ -205,8 +205,8 @@ let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1
 type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equal, variants]
 
 let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
-    ?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un) ?terminal_op
-    ~op_asn ~grad_asn ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
+    ?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un) ?terminal_op ~op_asn
+    ~grad_asn ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
   (* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
   let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in
   let id = session_state.next_id in
@@ -250,8 +250,7 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
     | Some (Shape.Data (Asgns.Padded { data; padding = padding_spec; padded_value })) ->
         let padding = Some (padding_spec, padded_value) in
         Tn.create_from_padded ~id ~label ~ndarray:data ~padding ()
-    | Some (Shape.Fetch _) | None ->
-        Tn.create ~default_prec ~id ~label ~dims ~padding ()
+    | Some (Shape.Fetch _) | None -> Tn.create ~default_prec ~id ~label ~dims ~padding ()
   in
   let embedded_nodes = ref @@ Set.singleton (module Tn) v in
   let children =
@@ -358,7 +357,7 @@ let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
 let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
     ?deduced ?init_data ?fetch_op () =
   let terminal_op =
-    match init_data, fetch_op with
+    match (init_data, fetch_op) with
     | Some _, Some _ -> invalid_arg "Tensor.term: both init_data and fetch_op are provided"
     | Some init_data, None -> Some (Shape.Data init_data)
     | None, Some fetch_op -> Some (Shape.Fetch fetch_op)
@@ -369,16 +368,18 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
     let dims = lazy (Lazy.force projections).Idx.lhs_dims in
     match fetch_op with
     | None -> Asgns.empty_comp
-    | Some (( Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _ ) as fetch_op) ->
+    | Some
+        ((Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _) as fetch_op)
+      ->
         Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
   in
   let grad_asn ~t:_ ~g:_ ~projections:_ = Asgns.empty_comp in
   let make_shape =
     Shape.make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
   in
   (* Note: terminal_op is used for both tensor creation and shape inference. *)
-  op ~label ?compose_op:None ?transpose_op:None ?terminal_op ~op_asn ~grad_asn ~grad_spec
-    make_shape []
+  op ~label ?compose_op:None ?transpose_op:None ?terminal_op ~op_asn ~grad_asn ~grad_spec make_shape
+    []
 
 let float_to_label v = Float.to_string v
@@ -467,6 +468,8 @@ let consume_forward_code t =
     @@ Session_error
          ( "Tensor.consume_forward_code: tensor is not a root for tnode: " ^ Tn.debug_name t.value,
            Some t );
+  (* FIXME(#321): this is too aggressive; instead, we should check whether the code contains any
+     non-embedded nodes that are embedded nodes of the other roots. *)
   let unsafe_roots =
     Map.data session_state.forward_roots
     |> List.filter ~f:(fun r -> not (List.is_empty r.children || r.id = t.id))
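
The FIXME above flags that invalidating every other forward root is coarser than necessary. A minimal sketch of the narrower check it proposes, not the committed code: non_embedded and embedded below are hypothetical accessors standing in for whatever exposes a computation's node sets (the diff shows tensors carrying an embedded_nodes set over the Tn module).

(* Sketch only: a root conflicts with the consumed code exactly when a node
   the code reads without embedding is one of the root's embedded nodes.
   [non_embedded] and [embedded] are hypothetical helpers, not library API. *)
let conflicting_roots ~code other_roots =
  let consumed : Set.M(Tn).t = non_embedded code in
  List.filter other_roots ~f:(fun r ->
      not (Set.is_empty (Set.inter consumed (embedded r))))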

test/training/moons_demo.ml

Lines changed: 11 additions & 4 deletions
@@ -42,15 +42,21 @@ let main () =
      computation. *)
   let weight_decay = 0.0001 in
   let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
+  let init_params = Tensor.init_params scalar_loss in
   let update = Train.grad_update scalar_loss in
+  (* TODO(#321): Define learning_rate above the call to grad_update to test the consume_forward_code
+     fix *)
   let%op learning_rate = 0.1 *. ((2 *. !..steps) - !@step_n) /. !..steps in
+  (* TODO: is set_hosted needed? *)
   Train.set_hosted learning_rate.value;
   let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in
+  let init_routine = Train.to_routine (module Backend) ctx bindings init_params in
   let sgd_routine =
-    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ])
+    Train.to_routine (module Backend) init_routine.context bindings (Asgns.sequence [ update; sgd ])
   in
   let step_ref = IDX.find_exn sgd_routine.bindings step_n in
   step_ref := 0;
+  Train.run init_routine;
   for _epoch = 1 to epochs do
     Train.sequential_loop sgd_routine.bindings ~f:(fun () ->
         Train.run sgd_routine;
@@ -65,7 +71,8 @@ let main () =
   let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
   let classes = Tn.points_1d ~xdim:0 moons_classes.value in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
-  let%op mlp_result = mlp "point" in
+  (* %cd instead of %op, to avoid complaints about the uninitialized point tensor node. *)
+  let%cd mlp_result = mlp "point" in
   Train.set_on_host mlp_result.value;
   let result_routine =
     Train.to_routine
@@ -114,7 +121,7 @@ let main () =
   Stdio.printf "mlp_result's name: %s\n%!" @@ Tensor.debug_name mlp_result;
   (* Note: mlp_result is not included in the resulting tensor's label, because the identifier label
      does not propagate across function calls. *)
-  (Stdio.printf "(mlp moons_input) name: %s\n%!"
+  Stdio.printf "(mlp moons_input) name: %s\n%!"
   @@ Tensor.debug_name
   @@
   match margin_loss.children with
@@ -126,6 +133,6 @@ let main () =
         };
       ] ->
       subtensor
-  | _ -> assert false)
+  | _ -> assert false
 
 let () = main ()
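
Taken together, the moons_demo additions amount to one pattern: parameter initialization is compiled as its own routine, its resulting context seeds the SGD routine, and it runs once before the training loop. A condensed sketch of that flow, using only calls that already appear in the diff above (the names are the demo's, not new API):

(* Compile initialization separately, thread its context into training,
   and run it exactly once before the epoch loop. *)
let init_params = Tensor.init_params scalar_loss in
let init_routine = Train.to_routine (module Backend) ctx bindings init_params in
let sgd_routine =
  Train.to_routine (module Backend) init_routine.context bindings
    (Asgns.sequence [ update; sgd ])
in
Train.run init_routine
(* ... then Train.run sgd_routine inside the epoch loop, as above ... *)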
