File tree Expand file tree Collapse file tree 4 files changed +96
-48
lines changed Expand file tree Collapse file tree 4 files changed +96
-48
lines changed Original file line number Diff line number Diff line change @@ -25,9 +25,9 @@ let%debug_sexp graph_t () : unit =
2525 let ctx = Backend. make_context stream in
2626 let open Operation.At in
2727 CDSL. virtualize_settings.enable_device_only < - false ;
28- let % op f x = sin x in
29- let size = 100 in
30- let xs = Array. init size ~f: Float. (fun i -> (of_int i / 10. ) - 5. ) in
28+ let % op f x = recip x in
29+ let size = 50 in
30+ let xs = Array. init size ~f: Float. (fun i -> (of_int i / 10. ) + 0.1 ) in
3131 let x_flat =
3232 Tensor. term ~grad_spec: Require_grad ~label: [ " x_flat" ]
3333 ~init_op: (Constant_fill { values = xs; strict = true })
Original file line number Diff line number Diff line change @@ -247,7 +247,7 @@ let sqrt ?(label = []) =
247247let recip ?(label = [] ) =
248248 let module NTDSL = NTDSL_before_div in
249249 let % cd op_asn ~v ~t1 ~projections = v =: recip v1 in
250- let % cd grad_asn ~t ~g ~t1 ~projections = g1 =+ g * (- 1 * (t **. 2 )) in
250+ let % cd grad_asn ~t ~g ~t1 ~projections = g1 =+ g * (- 1 *. (t **. 2 )) in
251251 Tensor. unop ~label: (" recip" :: label) ~transpose_op: Pointwise_un ~op_asn ~grad_asn
252252
253253let recip_sqrt ?(label = [] ) =
Original file line number Diff line number Diff line change @@ -1084,7 +1084,10 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : en
10841084 ([ Row_eq { r1 = cur; r2 = template }; Row_ineq { cur = template; subr } ], env)
10851085 | { bcast = Broadcastable ; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l
10861086 ->
1087- raise @@ Shape_error (" Too many axes in a subtensor" , [ Row_mismatch [ cur; subr ] ])
1087+ raise
1088+ @@ Shape_error
1089+ ( " Too many axes in a subtensor; maybe using * instead of *.?" ,
1090+ [ Row_mismatch [ cur; subr ] ] )
10881091 | { bcast; dims; id }, { bcast = Row_var { v = subr_v; _ }; _ }
10891092 when subr_dims_l < = cur_dims_l && subr_beg_dims_l < = cur_beg_dims_l -> (
10901093 let bcast =
Original file line number Diff line number Diff line change @@ -534,48 +534,48 @@ let%expect_test "sqrt(x)" =
534534 [% expect
535535 {|
536536 ┌─────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
537- │ inf │ │
538- │ │ │
539- │ │ │
540- │ │ │
541- │ │ │
542- │ │ │
543- │ │ │
544- │ │ │
545- │ │ │
546- │ │ │
547- │ │ │
548- │ │ │
549- │ │ │
550- │ │ │
551- │ │ │
552- │ │ │
553- │ │ │
554- │ │ │
555- │f │ │
556- │( │ │
557- │x │ │
558- │) │ │
559- │ │ │
560- │ │ │
561- │ │ │
562- │ │ │
563- │ │ │
564- │ │ │
565- │ │ │
566- │ │ │
567- │ │ │
568- │ │ │
569- │ │ │
570- │ │ │
571- │ │ │
572- │ │ │
573- │ │ │
574- │ │ │
575- │ │ │
576- │ 0.00 │*************************-*************** ********-******************************* **** *******-**** │
537+ │ 2.23 │ # │
538+ │ │ ##### │
539+ │ │ ##### │
540+ │ │ #### │
541+ │ │ ## ## │
542+ │ │ ##### │
543+ │ │ #### │
544+ │ │ ##### │
545+ │ │ #### │
546+ │ │ ## # │
547+ │ │ # ## │
548+ │ │ ### │
549+ │ │* #### │
550+ │ │ #### │
551+ │ │ # # │
552+ │ │ ### │
553+ │ │ ### │
554+ │ │ * ### │
555+ │f │ ## │
556+ │( │ ### │
557+ │x │ * ### │
558+ │) │ ## │
559+ │ │ * ## │
560+ │ │ # ## │
561+ │ │ * ## │
562+ │ │ * ## │
563+ │ │ ** # │
564+ │ │ * # │
565+ │ │ # ** │
566+ │ │ ## *** │
567+ │ │ # * * │
568+ │ │ # ***** │
569+ │ │ # ******** │
570+ │ │ # * ********** │
571+ │ │# * ************* **** │
572+ │ │ * ********************* ************ │
573+ │ │ ***** │
574+ │ │ │
575+ │ │ │
576+ │ 0.00 │- - - - - - - - - - - - - - - - - - - - │
577577 ├─────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
578- │ │0.00 5.00 │
578+ │ │1.00e-1 5.00 │
579579 │ │ x │
580580 └─────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
581581 | }]
@@ -584,8 +584,53 @@ let%expect_test "recip(x)" =
584584 let % op f x = recip x in
585585 let plot_box = plot_unop ~f ~x_min: 0.1 ~x_max: 5.0 () in
586586 PrintBox_text. output Stdio. stdout plot_box;
587- (* FIXME: There's is a bug here in the shape inference. *)
588- [% expect {|| }]
587+ [% expect {|
588+ ┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
589+ │ 1.00e+1 │# │
590+ │ │ │
591+ │ │ ## │
592+ │ │ ######### │
593+ │ │- - - ###-**************-********** *************-****-********************* ***************** │
594+ │ │ ******* │
595+ │ │ ** │
596+ │ │ * │
597+ │ │ * │
598+ │ │ │
599+ │ │ * │
600+ │ │ │
601+ │ │ │
602+ │ │ * │
603+ │ │ │
604+ │ │ │
605+ │ │ │
606+ │ │ │
607+ │f │ │
608+ │( │ │
609+ │x │ * │
610+ │) │ │
611+ │ │ │
612+ │ │ │
613+ │ │ │
614+ │ │ │
615+ │ │ │
616+ │ │ │
617+ │ │ │
618+ │ │ │
619+ │ │ │
620+ │ │ │
621+ │ │ │
622+ │ │ │
623+ │ │ │
624+ │ │ │
625+ │ │ │
626+ │ │ │
627+ │ │ │
628+ │ - 1.00e+2 │* │
629+ ├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
630+ │ │1.00e-1 5.00 │
631+ │ │ x │
632+ └─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
633+ | }]
589634
590635let % expect_test " recip_sqrt(x)" =
591636 let % op f x = recip_sqrt x in
You can’t perform that action at this time.
0 commit comments