Skip to content

Commit 8b6a6fa

Browse files
committed
Fix bug in grad formula for recip, update tests
1 parent d63fdf0 commit 8b6a6fa

File tree

4 files changed

+96
-48
lines changed

4 files changed

+96
-48
lines changed

bin/primitive_ops.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ let%debug_sexp graph_t () : unit =
2525
let ctx = Backend.make_context stream in
2626
let open Operation.At in
2727
CDSL.virtualize_settings.enable_device_only <- false;
28-
let%op f x = sin x in
29-
let size = 100 in
30-
let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) - 5.) in
28+
let%op f x = recip x in
29+
let size = 50 in
30+
let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) + 0.1) in
3131
let x_flat =
3232
Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
3333
~init_op:(Constant_fill { values = xs; strict = true })

lib/operation.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ let sqrt ?(label = []) =
247247
let recip ?(label = []) =
248248
let module NTDSL = NTDSL_before_div in
249249
let%cd op_asn ~v ~t1 ~projections = v =: recip v1 in
250-
let%cd grad_asn ~t ~g ~t1 ~projections = g1 =+ g * (-1 * (t **. 2)) in
250+
let%cd grad_asn ~t ~g ~t1 ~projections = g1 =+ g * (-1 *. (t **. 2)) in
251251
Tensor.unop ~label:("recip" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
252252

253253
let recip_sqrt ?(label = []) =

lib/row.ml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,10 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : en
10841084
([ Row_eq { r1 = cur; r2 = template }; Row_ineq { cur = template; subr } ], env)
10851085
| { bcast = Broadcastable; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l
10861086
->
1087-
raise @@ Shape_error ("Too many axes in a subtensor", [ Row_mismatch [ cur; subr ] ])
1087+
raise
1088+
@@ Shape_error
1089+
( "Too many axes in a subtensor; maybe using * instead of *.?",
1090+
[ Row_mismatch [ cur; subr ] ] )
10881091
| { bcast; dims; id }, { bcast = Row_var { v = subr_v; _ }; _ }
10891092
when subr_dims_l <= cur_dims_l && subr_beg_dims_l <= cur_beg_dims_l -> (
10901093
let bcast =

test/primitive_ops.ml

Lines changed: 88 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -534,48 +534,48 @@ let%expect_test "sqrt(x)" =
534534
[%expect
535535
{|
536536
┌─────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
537-
inf
538-
│ │
539-
│ │
540-
│ │
541-
│ │
542-
│ │
543-
│ │
544-
│ │
545-
│ │
546-
│ │
547-
│ │
548-
│ │
549-
│ │
550-
│ │
551-
│ │
552-
│ │
553-
│ │
554-
│ │
555-
│f │
556-
│( │
557-
│x │
558-
│) │
559-
│ │
560-
│ │
561-
│ │
562-
│ │
563-
│ │
564-
│ │
565-
│ │
566-
│ │
567-
│ │
568-
│ │
569-
│ │
570-
│ │
571-
│ │
572-
│ │
573-
│ │
574-
│ │ │
575-
│ │ │
576-
0.00*************************-*************** ********-******************************* **** *******-****
537+
2.23#
538+
│ │ #####
539+
│ │ #####
540+
│ │ ####
541+
│ │ ## ##
542+
│ │ #####
543+
│ │ ####
544+
│ │ #####
545+
│ │ ####
546+
│ │ ## #
547+
│ │ # ##
548+
│ │ ###
549+
│ │* ####
550+
│ │ ####
551+
│ │ # #
552+
│ │ ###
553+
│ │ ###
554+
│ │ * ###
555+
│f │ ##
556+
│( │ ###
557+
│x │ * ###
558+
│) │ ##
559+
│ │ * ##
560+
│ │ # ##
561+
│ │ * ##
562+
│ │ * ##
563+
│ │ ** #
564+
│ │ *#
565+
│ │ # **
566+
│ │ ## ***
567+
│ │ # * *
568+
│ │ # *****
569+
│ │ # ********
570+
│ │ # * **********
571+
│ │# * ************* ****
572+
│ │ * ********************* ************
573+
│ │ *****
574+
│ │ │
575+
│ │ │
576+
0.00- - - - - - - - - - - - - - - - - - - -
577577
├─────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
578-
│ │0.00 5.00
578+
│ │1.00e-1 5.00
579579
│ │ x │
580580
└─────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
581581
|}]
@@ -584,8 +584,53 @@ let%expect_test "recip(x)" =
584584
let%op f x = recip x in
585585
let plot_box = plot_unop ~f ~x_min:0.1 ~x_max:5.0 () in
586586
PrintBox_text.output Stdio.stdout plot_box;
587-
(* FIXME: There's is a bug here in the shape inference. *)
588-
[%expect {||}]
587+
[%expect {|
588+
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
589+
1.00e+1 │# │
590+
│ │ │
591+
│ │ ## │
592+
│ │ ######### │
593+
│ │- - - ###-**************-********** *************-****-********************* *****************
594+
│ │ *******
595+
│ │ **
596+
│ │ *
597+
│ │ *
598+
│ │ │
599+
│ │ *
600+
│ │ │
601+
│ │ │
602+
│ │ *
603+
│ │ │
604+
│ │ │
605+
│ │ │
606+
│ │ │
607+
│f │ │
608+
│( │ │
609+
│x │ *
610+
│) │ │
611+
│ │ │
612+
│ │ │
613+
│ │ │
614+
│ │ │
615+
│ │ │
616+
│ │ │
617+
│ │ │
618+
│ │ │
619+
│ │ │
620+
│ │ │
621+
│ │ │
622+
│ │ │
623+
│ │ │
624+
│ │ │
625+
│ │ │
626+
│ │ │
627+
│ │ │
628+
-1.00e+2*
629+
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
630+
│ │1.00e-1 5.00
631+
│ │ x │
632+
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
633+
|}]
589634

590635
let%expect_test "recip_sqrt(x)" =
591636
let%op f x = recip_sqrt x in

0 commit comments

Comments
 (0)