Skip to content

Commit 1511626

Browse files
committed
Differentiable conditional -> piecewise-defined functions
1 parent f6ea375 commit 1511626

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

lib/operation.ml

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,24 @@ let not ?(label = []) =
274274
let%cd grad_asn ~t:_ ~g:_ ~t1:_ ~projections:_ = Asgns.empty_comp in
275275
Tensor.unop ~label:("not" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
276276

277+
let lt ?(label = []) =
278+
let module NTDSL = Initial_NTDSL in
279+
let%cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 < v2) in
280+
let%cd grad_asn ~t:_ ~g:_ ~t1:_ ~t2:_ ~projections:_ = Asgns.empty_comp in
281+
Tensor.binop ~label:("<" :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn
282+
283+
let eq ?(label = []) =
284+
let module NTDSL = Initial_NTDSL in
285+
let%cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 = v2) in
286+
let%cd grad_asn ~t:_ ~g:_ ~t1:_ ~t2:_ ~projections:_ = Asgns.empty_comp in
287+
Tensor.binop ~label:("=" :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn
288+
289+
let ne ?(label = []) =
290+
let module NTDSL = Initial_NTDSL in
291+
let%cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 <> v2) in
292+
let%cd grad_asn ~t:_ ~g:_ ~t1:_ ~t2:_ ~projections:_ = Asgns.empty_comp in
293+
Tensor.binop ~label:("<>" :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn
294+
277295
let fma ?(label = []) ~grad_spec t1 t2 t3 =
278296
let module NTDSL = Initial_NTDSL in
279297
let%cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: fma v1 v2 v3 in
@@ -285,6 +303,17 @@ let fma ?(label = []) ~grad_spec t1 t2 t3 =
285303
Tensor.ternop ~label:("fma" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
286304
t2 t3
287305

306+
let where ?(label = []) ~grad_spec t1 t2 t3 =
307+
let module NTDSL = NTDSL_before_div in
308+
let%cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: where v1 v2 v3 in
309+
(* TODO: introduce a special-case projection for constants *)
310+
let%cd grad_asn ~t:_ ~g ~t1 ~t2 ~t3 ~projections =
311+
g2 =+ where v1 g (t3 - t3);
312+
g3 =+ where v1 (t2 - t2) g
313+
in
314+
Tensor.ternop ~label:("where" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
315+
t2 t3
316+
288317
let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
289318
let result =
290319
Tensor.term
@@ -375,11 +404,15 @@ module DO = struct
375404
let sin = sin ~grad_spec:If_needed
376405
let cos = cos ~grad_spec:If_needed
377406
let neg = neg ~grad_spec:If_needed
378-
let not = not ~grad_spec:If_needed
407+
let not = not ~grad_spec:Prohibit_grad
379408
let sqrt = sqrt ~grad_spec:If_needed
380409
let recip = recip ~grad_spec:If_needed
381410
let recip_sqrt = recip_sqrt ~grad_spec:If_needed
382411
let tanh = tanh ~grad_spec:If_needed
412+
let where = where ~grad_spec:If_needed
413+
let (<) = lt ~grad_spec:Prohibit_grad
414+
let (=) = eq ~grad_spec:Prohibit_grad
415+
let (<>) = ne ~grad_spec:Prohibit_grad
383416
end
384417

385418
module NDO = struct
@@ -401,6 +434,10 @@ module NDO = struct
401434
let recip = recip ~grad_spec:Prohibit_grad
402435
let recip_sqrt = recip_sqrt ~grad_spec:Prohibit_grad
403436
let tanh = tanh ~grad_spec:Prohibit_grad
437+
let where = where ~grad_spec:Prohibit_grad
438+
let (<) = lt ~grad_spec:Prohibit_grad
439+
let (=) = eq ~grad_spec:Prohibit_grad
440+
let (<>) = ne ~grad_spec:Prohibit_grad
404441
end
405442

406443
module TDSL = struct

test/primitive_ops.ml

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,8 @@ let%expect_test "recip(x)" =
584584
let%op f x = recip x in
585585
let plot_box = plot_unop ~f ~x_min:0.1 ~x_max:5.0 () in
586586
PrintBox_text.output Stdio.stdout plot_box;
587-
[%expect {|
587+
[%expect
588+
{|
588589
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
589590
1.00e+1 │# │
590591
│ │ │
@@ -737,3 +738,55 @@ let%expect_test "tanh(x)" =
737738
│ │ x │
738739
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
739740
|}]
741+
742+
let%expect_test "where(x < 0, sin(x), cos(x))" =
743+
let%op f x = where (x < !.0.) (sin x) (cos x) in
744+
let plot_box = plot_unop ~f () in
745+
PrintBox_text.output Stdio.stdout plot_box;
746+
[%expect {|
747+
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
748+
9.99e-1 │ # *
749+
│ │#### ## *** ## ** ***
750+
│ │ # ** ## *
751+
│ │ # * # *
752+
│ │ # *
753+
│ │ # * # *
754+
│ │ * # │
755+
│ │ # * # *
756+
│ │ # *
757+
│ │ * # │
758+
│ │ # * # *
759+
│ │ │
760+
│ │ # * # *
761+
│ │ # *
762+
│ │* * # #│
763+
│ │ # *
764+
│ │ * * # # │
765+
│ │ # *
766+
│f │ * * # # │
767+
│( │ # *
768+
│x │- * - - - - - - *- - - - - - # - - - - - -# │
769+
│) │ # * *
770+
│ │ * * # # │
771+
│ │ # # * *
772+
│ │ * * # # │
773+
│ │ # # * *
774+
│ │ * * # # │
775+
│ │ # # * *
776+
│ │ * * # # │
777+
│ │ # * # * # *
778+
│ │ * # # * * # │
779+
│ │ * * # # │
780+
│ │ # # * *
781+
│ │ * * # * # # │
782+
│ │ * * * # │
783+
│ │ * # # * * # │
784+
│ │ * * # # * * # # │
785+
│ │ * * # # * * # # │
786+
│ │ ** * # # * * # ## │
787+
-9.99e-1* ***** ####### ******* #### # │
788+
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
789+
│ │-5.00 5.00
790+
│ │ x │
791+
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
792+
|}]

0 commit comments

Comments
 (0)