Skip to content

Commit 53fec9b

Browse files
committed
Made Train.sgd_one slightly more thrifty
1 parent 417733d commit 53fec9b

File tree

3 files changed

+58
-55
lines changed

3 files changed

+58
-55
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
- Removed the `pipes_cc` and `pipes_gccjit` backends (`Pipes_multicore_backend`) -- I had fixed `Pipes_multicore_backend` by using the `poll` library instead of `Unix.select`, but it turned out to be very slow.
1616
- Changed the `%cd` block comment syntax `~~` to allow detailed structuring. Rewrote `Train.grad_update` to use the `%cd` syntax.
17+
- Made `Train.sgd_one` slightly more thrifty: `p =- learning_rate *. sgd_delta` --> `p =- learning_rate * sgd_delta ~logic:"."` without the inline tensor expression.
1718

1819
### Fixed
1920

lib/train.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov =
184184
if Float.(momentum > 0.0) then (
185185
"sgd_momentum" =: (!.momentum *. sgd_momentum) + sgd_delta;
186186
if nesterov then sgd_delta =+ !.momentum *. sgd_momentum else sgd_delta =: sgd_momentum);
187-
p =- learning_rate *. sgd_delta)]
187+
p =- learning_rate * sgd_delta ~logic:".")]
188188

189189
let sgd_update ~learning_rate ?momentum ?weight_decay ?nesterov l =
190190
let code =

test/zero2hero_1of7.ml

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -355,42 +355,43 @@ let%expect_test "Simple gradients virtual" =
355355
Tensor.print_tree ~with_grad:true ~depth:9 l;
356356
[%expect
357357
{|
358-
#123 *._l <(Hosted Changed_on_devices) 41>
358+
#119 *._l <(Hosted Changed_on_devices) 41>
359359
<not-in-yet>
360-
#124 grad_*._l <waiting>
360+
#120 grad_*._l <waiting>
361361
<not-in-yet>
362-
#119 +_d <waiting> │#121 f <(Hosted Nonconstant) 24>
362+
#115 +_d <waiting> │#117 f <(Hosted Nonconstant) 24>
363363
<not-in-yet><not-in-yet>
364-
#120 grad_+_d <waiting> │#122 grad_f <Materialized 28>
364+
#116 grad_+_d <waiting> │#118 grad_f <Materialized 28>
365365
<not-in-yet><not-in-yet>
366-
#115 *._e <waiting> │#117 c <(Hosted Nonconstant) 24>
366+
#111 *._e <waiting> │#113 c <(Hosted Nonconstant) 24>
367367
<not-in-yet><not-in-yet>
368-
#116 grad_*._e <waiting> │#118 grad_c <Materialized 28>
368+
#112 grad_*._e <waiting> │#114 grad_c <Materialized 28>
369369
<not-in-yet><not-in-yet>
370-
#111 a <(Hosted Nonconstant) 24>│#113 b <(Hosted Nonconstant) 24>│ │
370+
#107 a <(Hosted Nonconstant) 24>│#109 b <(Hosted Nonconstant) 24>│ │
371371
<not-in-yet><not-in-yet> │ │
372-
#112 grad_a <Materialized 28> │#114 grad_b <Materialized 28> │ │
373-
<not-in-yet><not-in-yet> │ │ |}];
372+
#108 grad_a <Materialized 28> │#110 grad_b <Materialized 28> │ │
373+
<not-in-yet><not-in-yet> │ │
374+
|}];
374375
let grad_routine = Backend.(link ctx @@ compile IDX.empty grad.fwd_bprop) in
375376
(* Check out the state without running a forward pass or compiling the SGD update. *)
376377
Tensor.print_tree ~with_grad:true ~depth:9 l;
377378
[%expect
378379
{|
379-
#123 *._l <(Hosted Changed_on_devices) 41>
380+
#119 *._l <(Hosted Changed_on_devices) 41>
380381
<not-in-yet>
381-
#124 grad_*._l <Virtual 40>
382+
#120 grad_*._l <Virtual 40>
382383
<not-in-yet>
383-
#119 +_d <Local 33> │#121 f <(Hosted Nonconstant) 24>
384+
#115 +_d <Local 33> │#117 f <(Hosted Nonconstant) 24>
384385
<not-in-yet><not-in-yet>
385-
#120 grad_+_d <Virtual 40> │#122 grad_f <On_device 33>
386+
#116 grad_+_d <Virtual 40> │#118 grad_f <On_device 33>
386387
<not-in-yet><not-in-yet>
387-
#115 *._e <Virtual 152> │#117 c <(Hosted Nonconstant) 24>
388+
#111 *._e <Virtual 152> │#113 c <(Hosted Nonconstant) 24>
388389
<not-in-yet><not-in-yet>
389-
#116 grad_*._e <Virtual 40> │#118 grad_c <On_device 33>
390+
#112 grad_*._e <Virtual 40> │#114 grad_c <On_device 33>
390391
<not-in-yet><not-in-yet>
391-
#111 a <(Hosted Nonconstant) 24>│#113 b <(Hosted Nonconstant) 24>│ │
392+
#107 a <(Hosted Nonconstant) 24>│#109 b <(Hosted Nonconstant) 24>│ │
392393
<not-in-yet><not-in-yet> │ │
393-
#112 grad_a <On_device 33> │#114 grad_b <On_device 33> │ │
394+
#108 grad_a <On_device 33> │#110 grad_b <On_device 33> │ │
394395
<not-in-yet><not-in-yet> │ │
395396
|}];
396397
(* Do not update the params: all values and gradients will be at initial points, which are
@@ -399,21 +400,21 @@ let%expect_test "Simple gradients virtual" =
399400
Tensor.print_tree ~with_grad:true ~depth:9 l;
400401
[%expect
401402
{|
402-
#123 *._l
403+
#119 *._l
403404
-8.00e+0
404-
#124 grad_*._l <Virtual 40>
405+
#120 grad_*._l <Virtual 40>
405406
<not-in-yet>
406-
#119 +_d <Local 33> │#121 f
407+
#115 +_d <Local 33> │#117 f
407408
<not-in-yet>-2.00e+0
408-
#120 grad_+_d <Virtual 40> │#122 grad_f <On_device 33>
409+
#116 grad_+_d <Virtual 40> │#118 grad_f <On_device 33>
409410
<not-in-yet><void>
410-
#115 *._e <Virtual 152> │#117 c │
411+
#111 *._e <Virtual 152> │#113 c │
411412
<not-in-yet>1.00e+1
412-
#116 grad_*._e <Virtual 40> │#118 grad_c <On_device 33>
413+
#112 grad_*._e <Virtual 40> │#114 grad_c <On_device 33>
413414
<not-in-yet><void>
414-
#111 a │#113 b │ │
415+
#107 a │#109 b │ │
415416
2.00e+0-3.00e+0 │ │
416-
#112 grad_a <On_device 33>│#114 grad_b <On_device 33>│ │
417+
#108 grad_a <On_device 33>│#110 grad_b <On_device 33>│ │
417418
<void><void> │ │
418419
|}];
419420
(* Only now compile the SGD update. *)
@@ -425,21 +426,21 @@ let%expect_test "Simple gradients virtual" =
425426
Tensor.print_tree ~with_grad:true ~depth:9 l;
426427
[%expect
427428
{|
428-
#123 *._l
429+
#119 *._l
429430
-8.00e+0
430-
#124 grad_*._l <Virtual 40>
431+
#120 grad_*._l <Virtual 40>
431432
<not-in-yet>
432-
#119 +_d <Local 33> │#121 f
433+
#115 +_d <Local 33> │#117 f
433434
<not-in-yet>-2.40e+0
434-
#120 grad_+_d <Virtual 40> │#122 grad_f <On_device 33>
435+
#116 grad_+_d <Virtual 40> │#118 grad_f <On_device 33>
435436
<not-in-yet><void>
436-
#115 *._e <Virtual 152> │#117 c │
437+
#111 *._e <Virtual 152> │#113 c │
437438
<not-in-yet>1.02e+1
438-
#116 grad_*._e <Virtual 40> │#118 grad_c <On_device 33>
439+
#112 grad_*._e <Virtual 40> │#114 grad_c <On_device 33>
439440
<not-in-yet><void>
440-
#111 a │#113 b │ │
441+
#107 a │#109 b │ │
441442
1.40e+0-2.60e+0 │ │
442-
#112 grad_a <On_device 33>│#114 grad_b <On_device 33>│ │
443+
#108 grad_a <On_device 33>│#110 grad_b <On_device 33>│ │
443444
<void><void> │ │
444445
|}];
445446
(* Now the params will remain as above, but both param gradients and the values and gradients of
@@ -448,21 +449,21 @@ let%expect_test "Simple gradients virtual" =
448449
Tensor.print_tree ~with_grad:true ~depth:9 l;
449450
[%expect
450451
{|
451-
#123 *._l
452+
#119 *._l
452453
-1.57e+1
453-
#124 grad_*._l <Virtual 40>
454+
#120 grad_*._l <Virtual 40>
454455
<not-in-yet>
455-
#119 +_d <Local 33> │#121 f
456+
#115 +_d <Local 33> │#117 f
456457
<not-in-yet>-2.40e+0
457-
#120 grad_+_d <Virtual 40> │#122 grad_f <On_device 33>
458+
#116 grad_+_d <Virtual 40> │#118 grad_f <On_device 33>
458459
<not-in-yet><void>
459-
#115 *._e <Virtual 152> │#117 c │
460+
#111 *._e <Virtual 152> │#113 c │
460461
<not-in-yet>1.02e+1
461-
#116 grad_*._e <Virtual 40> │#118 grad_c <On_device 33>
462+
#112 grad_*._e <Virtual 40> │#114 grad_c <On_device 33>
462463
<not-in-yet><void>
463-
#111 a │#113 b │ │
464+
#107 a │#109 b │ │
464465
1.40e+0-2.60e+0 │ │
465-
#112 grad_a <On_device 33>│#114 grad_b <On_device 33>│ │
466+
#108 grad_a <On_device 33>│#110 grad_b <On_device 33>│ │
466467
<void><void> │ │
467468
|}]
468469

@@ -484,18 +485,19 @@ let%expect_test "2D neuron hosted" =
484485
Tensor.print_tree ~with_grad:true ~depth:9 v;
485486
[%expect
486487
{|
487-
#155 +_v
488+
#147 +_v
488489
7.00e-1
489-
#156 grad_+_v
490+
#148 grad_+_v
490491
1.00e+0
491-
#153 * │#147 b
492+
#145 * │#139 b
492493
-6.00e+06.70e+0
493-
#154 grad_* │#148 grad_b
494+
#146 grad_* │#140 grad_b
494495
1.00e+01.00e+0
495-
#149 w │#151 x │
496+
#141 w │#143 x │
496497
-3.00e+0 1.00e+02.00e+0 0.00e+0
497-
#150 grad_w │#152 grad_x │
498-
2.00e+0 0.00e+0-3.00e+0 1.00e+0|}]
498+
#142 grad_w │#144 grad_x │
499+
2.00e+0 0.00e+0-3.00e+0 1.00e+0
500+
|}]
499501

500502
let%expect_test "2D neuron virtual" =
501503
Rand.init 0;
@@ -510,16 +512,16 @@ let%expect_test "2D neuron virtual" =
510512
Tensor.print_tree ~with_grad:true ~depth:9 v;
511513
[%expect
512514
{|
513-
#166 +_v
515+
#158 +_v
514516
7.00e-1
515-
#167 grad_+_v <Virtual 40>
517+
#159 grad_+_v <Virtual 40>
516518
<not-in-yet>
517-
#164 * <Local 33> │#158 b
519+
#156 * <Local 33> │#150 b
518520
<not-in-yet>6.70e+0
519-
#165 grad_* <Virtual 40> │#159 grad_b <Local 33>
521+
#157 grad_* <Virtual 40> │#151 grad_b <Local 33>
520522
<not-in-yet><not-in-yet>
521-
#160 w │#162 x │
523+
#152 w │#154 x │
522524
-3.00e+0 1.00e+02.00e+0 0.00e+0
523-
#161 grad_w <Local 33>│#163 grad_x <Local 33>
525+
#153 grad_w <Local 33>│#155 grad_x <Local 33>
524526
<not-in-yet><not-in-yet>
525527
|}]

0 commit comments

Comments
 (0)