Commit 5077531
Use a fixed bias 0.5 in the half-moons examples
(1) the current randomness doesn't work with sizes not divisible by 2 / 4 / 8 / 16 (double / single / half / fp8); (2) a learnable bias is redundant, since the other weights can adapt; (3) a zero bias does not train well with ReLU activations.
1 parent: 1218577
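
On point (3): with a zero bias and small symmetric random weights, a ReLU unit's pre-activation is negative for roughly half the inputs, so the unit outputs zero and passes back no gradient there; a fixed bias of 0.5 shifts the pre-activations into the active region at initialization. A minimal plain-OCaml sketch of this effect, independent of the repo's %op DSL (the weight and input values below are made up purely for illustration):

(* Sketch: a single ReLU unit with bias 0.0 vs the fixed bias 0.5.
   [w] and the inputs are invented illustration values, not from the demos. *)
let relu z = Float.max 0.0 z

let unit_output ~w ~b x = relu ((w *. x) +. b)

let () =
  let w = -0.3 in
  List.iter
    (fun x ->
      Printf.printf "x = %5.2f   b=0.0 -> %4.2f   b=0.5 -> %4.2f\n" x
        (unit_output ~w ~b:0.0 x) (unit_output ~w ~b:0.5 x))
    [ -1.0; -0.5; 0.5; 1.0 ]

With b = 0.0 the unit is dead (zero output, zero gradient) on half of these inputs; with b = 0.5 it stays active on all of them.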

6 files changed: +10 -12 lines changed


bin/micrograd_demo.ml

Lines changed: 2 additions & 4 deletions

@@ -25,7 +25,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
   let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in
   let step_n, bindings = IDX.get_static_symbol bindings in
-  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
+  let%op mlp x = "w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x)))) in
   let%op moons_input = moons_flat @| batch_n in
   (* Tell shape inference to make a minibatch axis. *)
   let () =
@@ -51,9 +51,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
       (scalar_loss, 0.0002)
     else
       let%op ssq w = (w **. 2) ++ "...|...->... => 0" in
-      let reg_loss =
-        List.map ~f:ssq [ w1; w2; w3; b1; b2; b3 ] |> List.reduce_exn ~f:TDSL.O.( + )
-      in
+      let reg_loss = List.map ~f:ssq [ w1; w2; w3; b1; b2 ] |> List.reduce_exn ~f:TDSL.O.( + ) in
       let%op scalar_loss =
         ((margin_loss ++ "...|... => 0") /. !..batch_size) + (0.0001 *. reg_loss)
       in
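
An aside on the ssq helper retained above: the einsum spec "...|...->... => 0" reduces over all axes, so ssq w is the sum of squared entries of w, and the reg_loss change simply drops b3 from the list of regularized parameters, since that tensor no longer exists. A plain-OCaml stand-in for the same computation on flat float arrays (a sketch, not the library's implementation; the parameter values are made up):

(* Sum of squares over all entries of one parameter, then summed across
   parameter tensors, mirroring reg_loss without a b3 entry. *)
let ssq_flat (w : float array) : float =
  Array.fold_left (fun acc x -> acc +. (x *. x)) 0.0 w

let () =
  let params = [ [| 0.1; -0.2 |]; [| 0.3 |] ] in
  let reg_loss = List.fold_left (fun acc w -> acc +. ssq_flat w) 0.0 params in
  Printf.printf "reg_loss = %.2f\n" reg_loss (* 0.01 + 0.04 + 0.09 = 0.14 *)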

bin/moons_demo.ml

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ let demo () =
   let steps = epochs * n_batches in
   let weight_decay = 0.0002 in

-  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
+  let%op mlp x = "w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x)))) in

   let config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some seed } in
   let moons_coordinates, moons_labels = Datasets.Half_moons.generate_single_prec ~config ~len () in

bin/moons_demo_parallel.ml

Lines changed: 2 additions & 2 deletions

@@ -23,8 +23,8 @@ let experiment ~seed ~backend_name ~config () =
   let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
   let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
   let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
-  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
-  (* let%op mlp x = "b" + ("w" * x) in *)
+  let%op mlp x = "w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x)))) in
+  (* let%op mlp x = ("w" * x) in *)
   let%op loss_fn ~output ~expectation = relu (!..1 - (expectation *. output)) in
   (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
      computation. *)

test/einsum/moons_demo_variant.ml

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ let () =
   let step_n, bindings = IDX.get_static_symbol bindings in
   let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
   let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
-  let%op mlp x = "b3" + ("w3" * relu ("b2" 16 + ("w2" * relu ("b1" 16 + ("w1" * x))))) in
+  let%op mlp x = 0.5 + ("w3" * relu ("b2" 16 + ("w2" * relu ("b1" 16 + ("w1" * x))))) in
   (* Don't decay the learning rate too quickly, it behaves better than in the original. *)
   let%op moons_input = moons_flat @| batch_n in
   (* THIS IS THE SPECIFIC SHAPE INFERENCE ASPECT OF THE TEST. *)

test/training/moons_demo.ml

Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@ let main () =
   let len = 200 in
   let batch_size = 10 in
   let n_batches = 2 * len / batch_size in
-  let epochs = 10 in
+  let epochs = 50 in
   let steps = epochs * 2 * len / batch_size in
   let moons_config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some 5 } in
   let moons_coordinates, moons_labels = Datasets.Half_moons.generate ~config:moons_config ~len () in
@@ -26,7 +26,7 @@ let main () =
   let step_n, bindings = IDX.get_static_symbol bindings in
   let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
   let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
-  let%op mlp x = "b3" + ("w3" * relu ("b2" 16 + ("w2" * relu ("b1" 16 + ("w1" * x))))) in
+  let%op mlp x = 0.5 + ("w3" * relu ("b2" 16 + ("w2" * relu ("b1" 16 + ("w1" * x))))) in
   (* Don't decay the learning rate too quickly, it behaves better than in the original. *)
   let%op moons_input = moons_flat @| batch_n in
   let%op moons_class = moons_classes @| batch_n in

test/training/moons_demo_parallel.ml

Lines changed: 2 additions & 2 deletions

@@ -18,15 +18,15 @@ let main () =
   let len = batch_size * 20 in
   let init_lr = 0.1 in
   (* let epochs = 10 in *)
-  let epochs = 20 in
+  let epochs = 100 in
   (* let epochs = 1 in *)
   let moons_config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some seed } in
   let moons_coordinates, moons_labels = Datasets.Half_moons.generate ~config:moons_config ~len () in
   let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_coordinates in
   let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
   let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
   let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
-  let%op mlp x = "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
+  let%op mlp x = 0.5 + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x))))) in
   (* let%op mlp x = "b" + ("w" * x) in *)
   let%op loss_fn ~output ~expectation = relu (!..1 - (expectation *. output)) in
   (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
