Skip to content

Commit 884d5ca

Browse files
committed
Fix automatic memory mode for tensors changed on host only
1 parent 2eb8cc1 commit 884d5ca

File tree

4 files changed

+72
-122
lines changed

4 files changed

+72
-122
lines changed

arrayjit/lib/low_level.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
356356
specified as virtual by another routine. However, if the memory mode is unspecified, we
357357
assume this will be the first computation involving the tensor node. *)
358358
traced.read_only <- true;
359-
if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Constant) 37
359+
if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Unset_hosted) 37
360360
else if Tn.known_not_materialized tn then (
361361
if Tn.known_non_virtual tn then
362362
raise

arrayjit/lib/tnode.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,7 @@ let get_value tn =
768768
raise @@ Utils.User_error "Tnode.get_value: index out of bounds"
769769

770770
let set_values tn values =
771+
update_memory_mode tn (Hosted Nonconstant) 51;
771772
do_write tn;
772773
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
773774
Nd.(set_flat_values ?padding (Option.value_exn ~here:[%here] @@ Lazy.force tn.array) values)
Lines changed: 65 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,78 @@
11
Retrieving commandline, environment, or config file variable ocannl_log_level
22
Found 0, in the config file
33
Tnode: collecting accessible arrays...
4-
n0 moons_flat as moons_flat: Host-const/37; single prec 40x10x2; mem in bytes: 3_200
5-
n1 moons_classes as moons_classes: Host-const/37; single prec 40x10x1; mem in bytes: 1_600
4+
n0 moons_flat as moons_flat: Host&shared/38039; single prec 40x10x2; mem in bytes: 3_200
5+
n1 moons_classes as moons_classes: Host&shared/38039; single prec 40x10x1; mem in bytes: 1_600
66
n2 range_over_offsets as range_over_offsets: Virt/15; single prec 4; mem in bytes: <not-in-yet>
77
n3 !@self_id as n3: Virt/40; single prec 1; mem in bytes: <not-in-yet>
88
n4 threefry4x32 as threefry4x32: Virt/15; single prec 4; mem in bytes: <not-in-yet>
99
n5 b1 as b1: Host&stream/412410; single prec 16; mem in bytes: <not-in-yet>
1010
n6 grad_b1 as b1.grad: Local/26046; single prec 16; mem in bytes: <not-in-yet>
11-
n7 range_over_offsets as range_over_offsets: Virt/15; single prec 4; mem in bytes: <not-in-yet>
11+
n7 range_over_offsets as range_over_offsets: Virt/15; single prec 8; mem in bytes: <not-in-yet>
1212
n8 !@self_id as n8: Virt/40; single prec 1; mem in bytes: <not-in-yet>
13-
n9 threefry4x32 as threefry4x32: Virt/15; single prec 4; mem in bytes: <not-in-yet>
14-
n10 b2 as b2: Host&stream/412410; single prec 16; mem in bytes: <not-in-yet>
15-
n11 grad_b2 as b2.grad: Local/26046; single prec 16; mem in bytes: <not-in-yet>
16-
n12 range_over_offsets as range_over_offsets: Virt/15; single prec 8; mem in bytes: <not-in-yet>
13+
n9 threefry4x32 as threefry4x32: Virt/15; single prec 8; mem in bytes: <not-in-yet>
14+
n10 w1 as w1: Host&stream/412410; single prec 16x2; mem in bytes: <not-in-yet>
15+
n11 grad_w1 as w1.grad: Local/26046; single prec 16x2; mem in bytes: <not-in-yet>
16+
n12 range_over_offsets as range_over_offsets: Virt/15; single prec 4; mem in bytes: <not-in-yet>
1717
n13 !@self_id as n13: Virt/40; single prec 1; mem in bytes: <not-in-yet>
18-
n14 threefry4x32 as threefry4x32: Virt/15; single prec 8; mem in bytes: <not-in-yet>
19-
n15 w1 as w1: Host&stream/412410; single prec 16x2; mem in bytes: <not-in-yet>
20-
n16 grad_w1 as w1.grad: Local/26046; single prec 16x2; mem in bytes: <not-in-yet>
21-
n17 range_over_offsets as range_over_offsets: Virt/15; single prec 64; mem in bytes: <not-in-yet>
22-
n18 !@self_id as n18: Virt/40; single prec 1; mem in bytes: <not-in-yet>
23-
n19 threefry4x32 as threefry4x32: Virt/15; single prec 64; mem in bytes: <not-in-yet>
24-
n20 w2 as w2: Host&stream/412410; single prec 16x16; mem in bytes: <not-in-yet>
25-
n21 grad_w2 as w2.grad: Local/26046; single prec 16x16; mem in bytes: <not-in-yet>
26-
n22 range_over_offsets as range_over_offsets: Virt/15; single prec 4; mem in bytes: <not-in-yet>
27-
n23 !@self_id as n23: Virt/40; single prec 1; mem in bytes: <not-in-yet>
28-
n24 threefry4x32 as threefry4x32: Virt/15; single prec 4; mem in bytes: <not-in-yet>
29-
n25 w3 as w3: Host&stream/412410; single prec 1x16; mem in bytes: <not-in-yet>
30-
n26 grad_w3 as w3.grad: Local/26046; single prec 1x16; mem in bytes: <not-in-yet>
31-
n27 @|_moons_input as moons_input: Local/1046; single prec 10x2; mem in bytes: <not-in-yet>
32-
n30 @|_moons_class as moons_class: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
33-
n33 * as n33: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
34-
n34 grad_* as n33.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
35-
n35 + as n35: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
36-
n36 grad_+ as n35.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
37-
n37 relu as relu: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
38-
n38 grad_relu as relu.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
39-
n39 * as n39: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
40-
n40 grad_* as n39.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
41-
n41 + as n41: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
42-
n42 grad_+ as n41.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
43-
n43 relu as relu: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
44-
n44 grad_relu as relu.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
45-
n45 * as n45: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
46-
n46 grad_* as n45.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
47-
n47 0.5 as n47: Virt/40; single prec 1; mem in bytes: <not-in-yet>
48-
n48 +_mlp_@|_moons_input as mlp_moons_input: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
49-
n49 grad_+_mlp_@|_moons_input as mlp_moons_input.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
50-
n50 *. as n50: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
51-
n51 grad_*. as n50.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
52-
n52 1 as 1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
53-
n53 - as n53: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
54-
n54 grad_- as n53.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
55-
n55 relu_margin_loss as relu_margin_loss: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
56-
n56 grad_relu_margin_loss as relu_margin_loss.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
57-
n57 10 as 10: Virt/40; single prec 1; mem in bytes: <not-in-yet>
58-
n58 => as n58: Local/1046; single prec 1; mem in bytes: <not-in-yet>
59-
n59 grad_=> as n58.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
60-
n60 /._scalar_loss as scalar_loss: Host&stream/412410; single prec 1; mem in bytes: 4
61-
n61 grad_/._scalar_loss as scalar_loss.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
62-
n62 2 as 2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
63-
n63 **. as n63: Virt/40; single prec 1; mem in bytes: <not-in-yet>
64-
n64 -1 as n64: Virt/40; single prec 1; mem in bytes: <not-in-yet>
65-
n65 *. as n65: Virt/152; single prec 1; mem in bytes: <not-in-yet>
66-
n66 /. as n66: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
67-
n67 1 as 1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
68-
n68 80 as 80: Virt/40; single prec 1; mem in bytes: <not-in-yet>
69-
n69 !@ as n69: Virt/152; single prec 1; mem in bytes: <not-in-yet>
70-
n70 80 as 80: Virt/40; single prec 1; mem in bytes: <not-in-yet>
71-
n71 2 as 2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
72-
n72 *. as n72: Virt/40; single prec 1; mem in bytes: <not-in-yet>
73-
n73 - as n73: Virt/152; single prec 1; mem in bytes: <not-in-yet>
74-
n74 0.1 as n74: Virt/40; single prec 1; mem in bytes: <not-in-yet>
75-
n75 *. as n75: Virt/152; single prec 1; mem in bytes: <not-in-yet>
76-
n76 /._learning_rate as learning_rate: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
77-
n77 sgd_delta_b1 as sgd_delta_b1: Virt/15; single prec 16; mem in bytes: <not-in-yet>
78-
n78 sgd_momentum_b1 as sgd_momentum_b1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
79-
n79 0.0001 as n79: Virt/40; single prec 1; mem in bytes: <not-in-yet>
80-
n80 *. as n80: Virt/15; single prec 16; mem in bytes: <not-in-yet>
81-
n81 sgd_delta_b2 as sgd_delta_b2: Virt/15; single prec 16; mem in bytes: <not-in-yet>
82-
n82 sgd_momentum_b2 as sgd_momentum_b2: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
83-
n83 0.0001 as n83: Virt/40; single prec 1; mem in bytes: <not-in-yet>
84-
n84 *. as n84: Virt/15; single prec 16; mem in bytes: <not-in-yet>
85-
n85 sgd_delta_w1 as sgd_delta_w1: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
86-
n86 sgd_momentum_w1 as sgd_momentum_w1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
87-
n87 0.0001 as n87: Virt/40; single prec 1; mem in bytes: <not-in-yet>
88-
n88 *. as n88: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
89-
n89 sgd_delta_w2 as sgd_delta_w2: Virt/15; single prec 16x16; mem in bytes: <not-in-yet>
90-
n90 sgd_momentum_w2 as sgd_momentum_w2: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
91-
n91 0.0001 as n91: Virt/40; single prec 1; mem in bytes: <not-in-yet>
92-
n92 *. as n92: Virt/15; single prec 16x16; mem in bytes: <not-in-yet>
93-
n93 sgd_delta_w3 as sgd_delta_w3: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
94-
n94 sgd_momentum_w3 as sgd_momentum_w3: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
95-
n95 0.0001 as n95: Virt/40; single prec 1; mem in bytes: <not-in-yet>
96-
n96 *. as n96: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
97-
n97 point as point: Host-const/37; single prec 2; mem in bytes: <not-in-yet>
98-
n98 * as n98: Local/1046; single prec 16; mem in bytes: <not-in-yet>
99-
n99 grad_* as n98.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
100-
n100 + as n100: Virt/15; single prec 16; mem in bytes: <not-in-yet>
101-
n101 grad_+ as n100.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
102-
n102 relu as relu: Local/1046; single prec 16; mem in bytes: <not-in-yet>
103-
n103 grad_relu as relu.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
104-
n104 * as n104: Local/1046; single prec 16; mem in bytes: <not-in-yet>
105-
n105 grad_* as n104.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
106-
n106 + as n106: Virt/15; single prec 16; mem in bytes: <not-in-yet>
107-
n107 grad_+ as n106.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
108-
n108 relu as relu: Virt/15; single prec 16; mem in bytes: <not-in-yet>
109-
n109 grad_relu as relu.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
110-
n110 * as n110: Local/1046; single prec 1; mem in bytes: <not-in-yet>
111-
n111 grad_* as n110.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
112-
n112 0.5 as n112: Virt/40; single prec 1; mem in bytes: <not-in-yet>
113-
n113 +_mlp_point as mlp_point: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
114-
n114 grad_+_mlp_point as mlp_point.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
18+
n14 threefry4x32 as threefry4x32: Virt/15; single prec 4; mem in bytes: <not-in-yet>
19+
n15 w2 as w2: Host&stream/412410; single prec 1x16; mem in bytes: <not-in-yet>
20+
n16 grad_w2 as w2.grad: Local/26046; single prec 1x16; mem in bytes: <not-in-yet>
21+
n17 @|_moons_input as moons_input: Local/1046; single prec 10x2; mem in bytes: <not-in-yet>
22+
n20 @|_moons_class as moons_class: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
23+
n23 * as n23: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
24+
n24 grad_* as n23.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
25+
n25 + as n25: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
26+
n26 grad_+ as n25.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
27+
n27 relu as relu: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
28+
n28 grad_relu as relu.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
29+
n29 *_mlp_@|_moons_input as mlp_moons_input: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
30+
n30 grad_*_mlp_@|_moons_input as mlp_moons_input.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
31+
n31 *. as n31: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
32+
n32 grad_*. as n31.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
33+
n33 1 as 1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
34+
n34 - as n34: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
35+
n35 grad_- as n34.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
36+
n36 relu_margin_loss as relu_margin_loss: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
37+
n37 grad_relu_margin_loss as relu_margin_loss.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
38+
n38 10 as 10: Virt/40; single prec 1; mem in bytes: <not-in-yet>
39+
n39 => as n39: Local/1046; single prec 1; mem in bytes: <not-in-yet>
40+
n40 grad_=> as n39.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
41+
n42 grad_/._scalar_loss as scalar_loss.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
42+
n43 2 as 2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
43+
n44 **. as n44: Virt/40; single prec 1; mem in bytes: <not-in-yet>
44+
n45 -1 as n45: Virt/40; single prec 1; mem in bytes: <not-in-yet>
45+
n46 *. as n46: Virt/152; single prec 1; mem in bytes: <not-in-yet>
46+
n48 1 as 1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
47+
n49 200 as 200: Virt/40; single prec 1; mem in bytes: <not-in-yet>
48+
n50 !@ as n50: Virt/152; single prec 1; mem in bytes: <not-in-yet>
49+
n51 200 as 200: Virt/40; single prec 1; mem in bytes: <not-in-yet>
50+
n52 2 as 2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
51+
n53 *. as n53: Virt/40; single prec 1; mem in bytes: <not-in-yet>
52+
n54 - as n54: Virt/152; single prec 1; mem in bytes: <not-in-yet>
53+
n55 0.1 as n55: Virt/40; single prec 1; mem in bytes: <not-in-yet>
54+
n56 *. as n56: Virt/152; single prec 1; mem in bytes: <not-in-yet>
55+
n58 sgd_delta_b1 as sgd_delta_b1: Virt/15; single prec 16; mem in bytes: <not-in-yet>
56+
n59 sgd_momentum_b1 as sgd_momentum_b1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
57+
n60 0.0001 as n60: Virt/40; single prec 1; mem in bytes: <not-in-yet>
58+
n61 *. as n61: Virt/15; single prec 16; mem in bytes: <not-in-yet>
59+
n62 sgd_delta_w1 as sgd_delta_w1: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
60+
n63 sgd_momentum_w1 as sgd_momentum_w1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
61+
n64 0.0001 as n64: Virt/40; single prec 1; mem in bytes: <not-in-yet>
62+
n65 *. as n65: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
63+
n66 sgd_delta_w2 as sgd_delta_w2: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
64+
n67 sgd_momentum_w2 as sgd_momentum_w2: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
65+
n68 0.0001 as n68: Virt/40; single prec 1; mem in bytes: <not-in-yet>
66+
n69 *. as n69: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
67+
n70 point as point: Host&shared/38039; single prec 2; mem in bytes: <not-in-yet>
68+
n71 * as n71: Local/1046; single prec 16; mem in bytes: <not-in-yet>
69+
n72 grad_* as n71.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
70+
n73 + as n73: Virt/15; single prec 16; mem in bytes: <not-in-yet>
71+
n74 grad_+ as n73.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
72+
n75 relu as relu: Virt/15; single prec 16; mem in bytes: <not-in-yet>
73+
n76 grad_relu as relu.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
74+
n77 *_mlp_point as mlp_point: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
75+
n78 grad_*_mlp_point as mlp_point.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
11576
Tnode: Finished printing headers.
11677
mlp_result's name: mlp_point
11778
(mlp moons_input) name: mlp_moons_input

test/einsum/moons_demo_variant.ml

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ let () =
1717
let len = 200 in
1818
let batch_size = 10 in
1919
let n_batches = 2 * len / batch_size in
20-
let epochs = 2 in
21-
let steps = epochs * 2 * len / batch_size in
2220
let moons_config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some 5 } in
2321
let moons_coordinates, moons_labels =
2422
Datasets.Half_moons.generate_single_prec ~config:moons_config ~len ()
@@ -29,38 +27,28 @@ let () =
2927
let step_n, bindings = IDX.get_static_symbol bindings in
3028
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
3129
let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
32-
let%op mlp x = 0.5 + ("w3" * relu ("b2" 16 + ("w2" * relu ("b1" 16 + ("w1" * x))))) in
30+
(* let%op mlp x = 0.5 + ("w3" * relu ("b2" 16 + ("w2" * relu ("b1" 16 + ("w1" * x))))) in *)
31+
let%op mlp x = "w2" * relu ("b1" 16 + ("w1" * x)) in
3332
(* Don't decay the learning rate too quickly, it behaves better than in the original. *)
3433
let%op moons_input = moons_flat @| batch_n in
3534
(* THIS IS THE SPECIFIC SHAPE INFERENCE ASPECT OF THE TEST. *)
3635
let%cd _ = moons_input =: 0 ++ "i=>2|i" in
3736
let%op moons_class = moons_classes @| batch_n in
3837
let%cd _ = moons_class =: 0 ++ "i=>2|i" in
39-
let losses = Array.create ~len:epochs 0. in
4038
let%op margin_loss = relu (1 - (moons_class *. mlp moons_input)) in
4139
(* We don't need a regression loss formula thanks to weight_decay built into the sgd_update
4240
computation. *)
4341
let weight_decay = 0.0001 in
4442
let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
4543
let update = Train.grad_update scalar_loss in
46-
let%op learning_rate = 0.1 *. ((2 *. !..steps) - !@step_n) /. !..steps in
44+
let%op learning_rate = 0.1 *. ((2 *. !..len) - !@step_n) /. !..len in
4745
Train.set_hosted learning_rate.value;
4846
let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in
4947
let ctx = Train.init_params (module Backend) bindings scalar_loss in
5048
let sgd_routine =
5149
Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ])
5250
in
53-
let step_ref = IDX.find_exn sgd_routine.bindings step_n in
54-
step_ref := 0;
55-
for epoch = 1 to epochs do
56-
Train.sequential_loop sgd_routine.bindings ~f:(fun () ->
57-
Train.run sgd_routine;
58-
(* let batch_ref = IDX.find_exn sgd_jitted.bindings batch_n in Stdio.printf "Epoch=%d,
59-
step=%d, batch=%d, lr=%f, loss=%f\n%!" epoch !step_ref !batch_ref learning_rate.@[0]
60-
scalar_loss.@[0]; *)
61-
losses.(epoch - 1) <- losses.(epoch - 1) +. scalar_loss.@[0];
62-
Int.incr step_ref)
63-
done;
51+
(* Skipping over the training loop, not needed for the test. *)
6452
let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
6553
let classes = Tn.points_1d ~xdim:0 moons_classes.value in
6654
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
@@ -79,7 +67,7 @@ let () =
7967
Float.(mlp_result.@[0] >= 0.)
8068
in
8169
let _plot_moons =
82-
PrintBox_utils.plot ~as_canvas:true
70+
PrintBox_utils.plot ~as_canvas:true ~size:(5, 5)
8371
[
8472
Scatterplot { points = points1; content = PrintBox.line "#" };
8573
Scatterplot { points = points2; content = PrintBox.line "%" };

0 commit comments

Comments
 (0)