File tree Expand file tree Collapse file tree 6 files changed +10
-12
lines changed Expand file tree Collapse file tree 6 files changed +10
-12
lines changed Original file line number Diff line number Diff line change @@ -25,7 +25,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
2525 let moons_classes = TDSL. rebatch ~l: " moons_classes" moons_classes_ndarray () in
2626 let batch_n, bindings = IDX. get_static_symbol ~static_range: n_batches IDX. empty in
2727 let step_n, bindings = IDX. get_static_symbol bindings in
28- let % op mlp x = " b3 " + ( " w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x) )))) in
28+ let % op mlp x = " w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x)))) in
2929 let % op moons_input = moons_flat @| batch_n in
3030 (* Tell shape inference to make a minibatch axis. *)
3131 let () =
@@ -51,9 +51,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
5151 (scalar_loss, 0.0002 )
5252 else
5353 let % op ssq w = (w **. 2 ) ++ " ...|...->... => 0" in
54- let reg_loss =
55- List. map ~f: ssq [ w1; w2; w3; b1; b2; b3 ] |> List. reduce_exn ~f: TDSL.O. ( + )
56- in
54+ let reg_loss = List. map ~f: ssq [ w1; w2; w3; b1; b2 ] |> List. reduce_exn ~f: TDSL.O. ( + ) in
5755 let % op scalar_loss =
5856 ((margin_loss ++ " ...|... => 0" ) /. ! ..batch_size) + (0.0001 *. reg_loss)
5957 in
Original file line number Diff line number Diff line change @@ -23,7 +23,7 @@ let demo () =
2323 let steps = epochs * n_batches in
2424 let weight_decay = 0.0002 in
2525
26- let % op mlp x = " b3 " + ( " w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x) )))) in
26+ let % op mlp x = " w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x)))) in
2727
2828 let config = Datasets.Half_moons.Config. { noise_range = 0.1 ; seed = Some seed } in
2929 let moons_coordinates, moons_labels = Datasets.Half_moons. generate_single_prec ~config ~len () in
Original file line number Diff line number Diff line change @@ -23,8 +23,8 @@ let experiment ~seed ~backend_name ~config () =
2323 let moons_classes_ndarray = Ir.Ndarray. as_array Ir.Ops. Double moons_labels in
2424 let moons_flat = TDSL. rebatch ~l: " moons_flat" moons_flat_ndarray () in
2525 let moons_classes = TDSL. rebatch ~l: " moons_classes" moons_classes_ndarray () in
26- let % op mlp x = " b3 " + ( " w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x) )))) in
27- (* let%op mlp x = "b" + ("w" * x) in *)
26+ let % op mlp x = " w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x)))) in
27+ (* let%op mlp x = ("w" * x) in *)
2828 let % op loss_fn ~output ~expectation = relu (! ..1 - (expectation *. output)) in
2929 (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
3030 computation. *)
Original file line number Diff line number Diff line change @@ -27,7 +27,7 @@ let () =
2727 let step_n, bindings = IDX. get_static_symbol bindings in
2828 let moons_flat = TDSL. rebatch ~l: " moons_flat" moons_flat_ndarray () in
2929 let moons_classes = TDSL. rebatch ~l: " moons_classes" moons_classes_ndarray () in
30- let % op mlp x = " b3 " + (" w3" * relu (" b2" 16 + (" w2" * relu (" b1" 16 + (" w1" * x))))) in
30+ let % op mlp x = 0.5 + (" w3" * relu (" b2" 16 + (" w2" * relu (" b1" 16 + (" w1" * x))))) in
3131 (* Don't decay the learning rate too quickly, it behaves better than in the original. *)
3232 let % op moons_input = moons_flat @| batch_n in
3333 (* THIS IS THE SPECIFIC SHAPE INFERENCE ASPECT OF THE TEST. *)
Original file line number Diff line number Diff line change @@ -16,7 +16,7 @@ let main () =
1616 let len = 200 in
1717 let batch_size = 10 in
1818 let n_batches = 2 * len / batch_size in
19- let epochs = 10 in
19+ let epochs = 50 in
2020 let steps = epochs * 2 * len / batch_size in
2121 let moons_config = Datasets.Half_moons.Config. { noise_range = 0.1 ; seed = Some 5 } in
2222 let moons_coordinates, moons_labels = Datasets.Half_moons. generate ~config: moons_config ~len () in
@@ -26,7 +26,7 @@ let main () =
2626 let step_n, bindings = IDX. get_static_symbol bindings in
2727 let moons_flat = TDSL. rebatch ~l: " moons_flat" moons_flat_ndarray () in
2828 let moons_classes = TDSL. rebatch ~l: " moons_classes" moons_classes_ndarray () in
29- let % op mlp x = " b3 " + (" w3" * relu (" b2" 16 + (" w2" * relu (" b1" 16 + (" w1" * x))))) in
29+ let % op mlp x = 0.5 + (" w3" * relu (" b2" 16 + (" w2" * relu (" b1" 16 + (" w1" * x))))) in
3030 (* Don't decay the learning rate too quickly, it behaves better than in the original. *)
3131 let % op moons_input = moons_flat @| batch_n in
3232 let % op moons_class = moons_classes @| batch_n in
Original file line number Diff line number Diff line change @@ -18,15 +18,15 @@ let main () =
1818 let len = batch_size * 20 in
1919 let init_lr = 0.1 in
2020 (* let epochs = 10 in *)
21- let epochs = 20 in
21+ let epochs = 100 in
2222 (* let epochs = 1 in *)
2323 let moons_config = Datasets.Half_moons.Config. { noise_range = 0.1 ; seed = Some seed } in
2424 let moons_coordinates, moons_labels = Datasets.Half_moons. generate ~config: moons_config ~len () in
2525 let moons_flat_ndarray = Ir.Ndarray. as_array Ir.Ops. Double moons_coordinates in
2626 let moons_classes_ndarray = Ir.Ndarray. as_array Ir.Ops. Double moons_labels in
2727 let moons_flat = TDSL. rebatch ~l: " moons_flat" moons_flat_ndarray () in
2828 let moons_classes = TDSL. rebatch ~l: " moons_classes" moons_classes_ndarray () in
29- let % op mlp x = " b3 " + (" w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x))))) in
29+ let % op mlp x = 0.5 + (" w3" * relu (" b2" hid_dim + (" w2" * relu (" b1" hid_dim + (" w1" * x))))) in
3030 (* let%op mlp x = "b" + ("w" * x) in *)
3131 let % op loss_fn ~output ~expectation = relu (! ..1 - (expectation *. output)) in
3232 (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
You can’t perform that action at this time.
0 commit comments