@@ -355,42 +355,43 @@ let%expect_test "Simple gradients virtual" =
355355 Tensor. print_tree ~with_grad: true ~depth: 9 l;
356356 [% expect
357357 {|
358- #123 *. _l < (Hosted Changed_on_devices ) 41 >
358+ #119 *. _l < (Hosted Changed_on_devices ) 41 >
359359 < not - in - yet>
360- #124 grad_*. _l < waiting>
360+ #120 grad_*. _l < waiting>
361361 < not - in - yet>
362- #119 + _d < waiting> │#121 f < (Hosted Nonconstant ) 24 >
362+ #115 + _d < waiting> │#117 f < (Hosted Nonconstant ) 24 >
363363 < not - in - yet> │< not - in - yet>
364- #120 grad_+ _d < waiting> │#122 grad_f < Materialized 28 >
364+ #116 grad_+ _d < waiting> │#118 grad_f < Materialized 28 >
365365 < not - in - yet> │< not - in - yet>
366- #115 *. _e < waiting> │#117 c < (Hosted Nonconstant ) 24 > │
366+ #111 *. _e < waiting> │#113 c < (Hosted Nonconstant ) 24 > │
367367 < not - in - yet> │< not - in - yet> │
368- #116 grad_*. _e < waiting> │#118 grad_c < Materialized 28 > │
368+ #112 grad_*. _e < waiting> │#114 grad_c < Materialized 28 > │
369369 < not - in - yet> │< not - in - yet> │
370- #111 a < (Hosted Nonconstant ) 24 > │#113 b < (Hosted Nonconstant ) 24 > │ │
370+ #107 a < (Hosted Nonconstant ) 24 > │#109 b < (Hosted Nonconstant ) 24 > │ │
371371 < not - in - yet> │< not - in - yet> │ │
372- #112 grad_a < Materialized 28 > │#114 grad_b < Materialized 28 > │ │
373- < not - in - yet> │< not - in - yet> │ │ | }];
372+ #108 grad_a < Materialized 28 > │#110 grad_b < Materialized 28 > │ │
373+ < not - in - yet> │< not - in - yet> │ │
374+ | }];
374375 let grad_routine = Backend. (link ctx @@ compile IDX. empty grad.fwd_bprop) in
375376 (* Check out the state without running a forward pass or compiling the SGD update. *)
376377 Tensor. print_tree ~with_grad: true ~depth: 9 l;
377378 [% expect
378379 {|
379- #123 *. _l < (Hosted Changed_on_devices ) 41 >
380+ #119 *. _l < (Hosted Changed_on_devices ) 41 >
380381 < not - in - yet>
381- #124 grad_*. _l < Virtual 40 >
382+ #120 grad_*. _l < Virtual 40 >
382383 < not - in - yet>
383- #119 + _d < Local 33 > │#121 f < (Hosted Nonconstant ) 24 >
384+ #115 + _d < Local 33 > │#117 f < (Hosted Nonconstant ) 24 >
384385 < not - in - yet> │< not - in - yet>
385- #120 grad_+ _d < Virtual 40 > │#122 grad_f < On_device 33 >
386+ #116 grad_+ _d < Virtual 40 > │#118 grad_f < On_device 33 >
386387 < not - in - yet> │< not - in - yet>
387- #115 *. _e < Virtual 152 > │#117 c < (Hosted Nonconstant ) 24 > │
388+ #111 *. _e < Virtual 152 > │#113 c < (Hosted Nonconstant ) 24 > │
388389 < not - in - yet> │< not - in - yet> │
389- #116 grad_*. _e < Virtual 40 > │#118 grad_c < On_device 33 > │
390+ #112 grad_*. _e < Virtual 40 > │#114 grad_c < On_device 33 > │
390391 < not - in - yet> │< not - in - yet> │
391- #111 a < (Hosted Nonconstant ) 24 > │#113 b < (Hosted Nonconstant ) 24 > │ │
392+ #107 a < (Hosted Nonconstant ) 24 > │#109 b < (Hosted Nonconstant ) 24 > │ │
392393 < not - in - yet> │< not - in - yet> │ │
393- #112 grad_a < On_device 33 > │#114 grad_b < On_device 33 > │ │
394+ #108 grad_a < On_device 33 > │#110 grad_b < On_device 33 > │ │
394395 < not - in - yet> │< not - in - yet> │ │
395396 | }];
396397 (* Do not update the params: all values and gradients will be at initial points, which are
@@ -399,21 +400,21 @@ let%expect_test "Simple gradients virtual" =
399400 Tensor. print_tree ~with_grad: true ~depth: 9 l;
400401 [% expect
401402 {|
402- #123 *. _l
403+ #119 *. _l
403404 - 8.00e+0
404- #124 grad_*. _l < Virtual 40 >
405+ #120 grad_*. _l < Virtual 40 >
405406 < not - in - yet>
406- #119 + _d < Local 33 > │#121 f
407+ #115 + _d < Local 33 > │#117 f
407408 < not - in - yet> │ - 2.00e+0
408- #120 grad_+ _d < Virtual 40 > │#122 grad_f < On_device 33 >
409+ #116 grad_+ _d < Virtual 40 > │#118 grad_f < On_device 33 >
409410 < not - in - yet> │< void>
410- #115 *. _e < Virtual 152 > │#117 c │
411+ #111 *. _e < Virtual 152 > │#113 c │
411412 < not - in - yet> │ 1.00e+1 │
412- #116 grad_*. _e < Virtual 40 > │#118 grad_c < On_device 33 > │
413+ #112 grad_*. _e < Virtual 40 > │#114 grad_c < On_device 33 > │
413414 < not - in - yet> │< void> │
414- #111 a │#113 b │ │
415+ #107 a │#109 b │ │
415416 2.00e+0 │ - 3.00e+0 │ │
416- #112 grad_a < On_device 33 > │#114 grad_b < On_device 33 > │ │
417+ #108 grad_a < On_device 33 > │#110 grad_b < On_device 33 > │ │
417418 < void> │< void> │ │
418419 | }];
419420 (* Only now compile the SGD update. *)
@@ -425,21 +426,21 @@ let%expect_test "Simple gradients virtual" =
425426 Tensor. print_tree ~with_grad: true ~depth: 9 l;
426427 [% expect
427428 {|
428- #123 *. _l
429+ #119 *. _l
429430 - 8.00e+0
430- #124 grad_*. _l < Virtual 40 >
431+ #120 grad_*. _l < Virtual 40 >
431432 < not - in - yet>
432- #119 + _d < Local 33 > │#121 f
433+ #115 + _d < Local 33 > │#117 f
433434 < not - in - yet> │ - 2.40e+0
434- #120 grad_+ _d < Virtual 40 > │#122 grad_f < On_device 33 >
435+ #116 grad_+ _d < Virtual 40 > │#118 grad_f < On_device 33 >
435436 < not - in - yet> │< void>
436- #115 *. _e < Virtual 152 > │#117 c │
437+ #111 *. _e < Virtual 152 > │#113 c │
437438 < not - in - yet> │ 1.02e+1 │
438- #116 grad_*. _e < Virtual 40 > │#118 grad_c < On_device 33 > │
439+ #112 grad_*. _e < Virtual 40 > │#114 grad_c < On_device 33 > │
439440 < not - in - yet> │< void> │
440- #111 a │#113 b │ │
441+ #107 a │#109 b │ │
441442 1.40e+0 │ - 2.60e+0 │ │
442- #112 grad_a < On_device 33 > │#114 grad_b < On_device 33 > │ │
443+ #108 grad_a < On_device 33 > │#110 grad_b < On_device 33 > │ │
443444 < void> │< void> │ │
444445 | }];
445446 (* Now the params will remain as above, but both param gradients and the values and gradients of
@@ -448,21 +449,21 @@ let%expect_test "Simple gradients virtual" =
448449 Tensor. print_tree ~with_grad: true ~depth: 9 l;
449450 [% expect
450451 {|
451- #123 *. _l
452+ #119 *. _l
452453 - 1.57e+1
453- #124 grad_*. _l < Virtual 40 >
454+ #120 grad_*. _l < Virtual 40 >
454455 < not - in - yet>
455- #119 + _d < Local 33 > │#121 f
456+ #115 + _d < Local 33 > │#117 f
456457 < not - in - yet> │ - 2.40e+0
457- #120 grad_+ _d < Virtual 40 > │#122 grad_f < On_device 33 >
458+ #116 grad_+ _d < Virtual 40 > │#118 grad_f < On_device 33 >
458459 < not - in - yet> │< void>
459- #115 *. _e < Virtual 152 > │#117 c │
460+ #111 *. _e < Virtual 152 > │#113 c │
460461 < not - in - yet> │ 1.02e+1 │
461- #116 grad_*. _e < Virtual 40 > │#118 grad_c < On_device 33 > │
462+ #112 grad_*. _e < Virtual 40 > │#114 grad_c < On_device 33 > │
462463 < not - in - yet> │< void> │
463- #111 a │#113 b │ │
464+ #107 a │#109 b │ │
464465 1.40e+0 │ - 2.60e+0 │ │
465- #112 grad_a < On_device 33 > │#114 grad_b < On_device 33 > │ │
466+ #108 grad_a < On_device 33 > │#110 grad_b < On_device 33 > │ │
466467 < void> │< void> │ │
467468 | }]
468469
@@ -484,18 +485,19 @@ let%expect_test "2D neuron hosted" =
484485 Tensor. print_tree ~with_grad: true ~depth: 9 v;
485486 [% expect
486487 {|
487- #155 + _v
488+ #147 + _v
488489 7.00e-1
489- #156 grad_+ _v
490+ #148 grad_+ _v
490491 1.00e+0
491- #153 * │#147 b
492+ #145 * │#139 b
492493 - 6.00e+0 │ 6.70e+0
493- #154 grad_* │#148 grad_b
494+ #146 grad_* │#140 grad_b
494495 1.00e+0 │ 1.00e+0
495- #149 w │#151 x │
496+ #141 w │#143 x │
496497 - 3.00e+0 1.00e+0 │ 2.00e+0 0.00e+0 │
497- #150 grad_w │#152 grad_x │
498- 2.00e+0 0.00e+0 │ - 3.00e+0 1.00e+0 │ | }]
498+ #142 grad_w │#144 grad_x │
499+ 2.00e+0 0.00e+0 │ - 3.00e+0 1.00e+0 │
500+ | }]
499501
500502let % expect_test " 2D neuron virtual" =
501503 Rand. init 0 ;
@@ -510,16 +512,16 @@ let%expect_test "2D neuron virtual" =
510512 Tensor. print_tree ~with_grad: true ~depth: 9 v;
511513 [% expect
512514 {|
513- #166 + _v
515+ #158 + _v
514516 7.00e-1
515- #167 grad_+ _v < Virtual 40 >
517+ #159 grad_+ _v < Virtual 40 >
516518 < not - in - yet>
517- #164 * < Local 33 > │#158 b
519+ #156 * < Local 33 > │#150 b
518520 < not - in - yet> │ 6.70e+0
519- #165 grad_* < Virtual 40 > │#159 grad_b < Local 33 >
521+ #157 grad_* < Virtual 40 > │#151 grad_b < Local 33 >
520522 < not - in - yet> │< not - in - yet>
521- #160 w │#162 x │
523+ #152 w │#154 x │
522524 - 3.00e+0 1.00e+0 │ 2.00e+0 0.00e+0 │
523- #161 grad_w < Local 33 > │#163 grad_x < Local 33 > │
525+ #153 grad_w < Local 33 > │#155 grad_x < Local 33 > │
524526 < not - in - yet> │< not - in - yet> │
525527 | }]
0 commit comments