@@ -274,6 +274,24 @@ let not ?(label = []) =
274274 let % cd grad_asn ~t: _ ~g: _ ~t1: _ ~projections: _ = Asgns. empty_comp in
275275 Tensor. unop ~label: (" not" :: label) ~transpose_op: Pointwise_un ~op_asn ~grad_asn
276276
277+ let lt ?(label = [] ) =
278+ let module NTDSL = Initial_NTDSL in
279+ let % cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 < v2) in
280+ let % cd grad_asn ~t: _ ~g: _ ~t1: _ ~t2: _ ~projections: _ = Asgns. empty_comp in
281+ Tensor. binop ~label: (" <" :: label) ~compose_op: Pointwise_bin ~op_asn ~grad_asn
282+
283+ let eq ?(label = [] ) =
284+ let module NTDSL = Initial_NTDSL in
285+ let % cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 = v2) in
286+ let % cd grad_asn ~t: _ ~g: _ ~t1: _ ~t2: _ ~projections: _ = Asgns. empty_comp in
287+ Tensor. binop ~label: (" =" :: label) ~compose_op: Pointwise_bin ~op_asn ~grad_asn
288+
289+ let ne ?(label = [] ) =
290+ let module NTDSL = Initial_NTDSL in
291+ let % cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 <> v2) in
292+ let % cd grad_asn ~t: _ ~g: _ ~t1: _ ~t2: _ ~projections: _ = Asgns. empty_comp in
293+ Tensor. binop ~label: (" <>" :: label) ~compose_op: Pointwise_bin ~op_asn ~grad_asn
294+
277295let fma ?(label = [] ) ~grad_spec t1 t2 t3 =
278296 let module NTDSL = Initial_NTDSL in
279297 let % cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: fma v1 v2 v3 in
@@ -285,6 +303,17 @@ let fma ?(label = []) ~grad_spec t1 t2 t3 =
285303 Tensor. ternop ~label: (" fma" :: label) ~ternary_op: Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
286304 t2 t3
287305
306+ let where ?(label = [] ) ~grad_spec t1 t2 t3 =
307+ let module NTDSL = NTDSL_before_div in
308+ let % cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: where v1 v2 v3 in
309+ (* TODO: introduce a special-case projection for constants *)
310+ let % cd grad_asn ~t: _ ~g ~t1 ~t2 ~t3 ~projections =
311+ g2 =+ where v1 g (t3 - t3);
312+ g3 =+ where v1 (t2 - t2) g
313+ in
314+ Tensor. ternop ~label: (" where" :: label) ~ternary_op: Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
315+ t2 t3
316+
288317let range ?(label = [] ) ?(grad_spec = Tensor. Prohibit_grad ) ?axis_label upto =
289318 let result =
290319 Tensor. term
@@ -375,11 +404,15 @@ module DO = struct
375404 let sin = sin ~grad_spec: If_needed
376405 let cos = cos ~grad_spec: If_needed
377406 let neg = neg ~grad_spec: If_needed
378- let not = not ~grad_spec: If_needed
407+ let not = not ~grad_spec: Prohibit_grad
379408 let sqrt = sqrt ~grad_spec: If_needed
380409 let recip = recip ~grad_spec: If_needed
381410 let recip_sqrt = recip_sqrt ~grad_spec: If_needed
382411 let tanh = tanh ~grad_spec: If_needed
412+ let where = where ~grad_spec: If_needed
413+ let (< ) = lt ~grad_spec: Prohibit_grad
414+ let (= ) = eq ~grad_spec: Prohibit_grad
415+ let (<> ) = ne ~grad_spec: Prohibit_grad
383416end
384417
385418module NDO = struct
@@ -401,6 +434,10 @@ module NDO = struct
401434 let recip = recip ~grad_spec: Prohibit_grad
402435 let recip_sqrt = recip_sqrt ~grad_spec: Prohibit_grad
403436 let tanh = tanh ~grad_spec: Prohibit_grad
437+ let where = where ~grad_spec: Prohibit_grad
438+ let (< ) = lt ~grad_spec: Prohibit_grad
439+ let (= ) = eq ~grad_spec: Prohibit_grad
440+ let (<> ) = ne ~grad_spec: Prohibit_grad
404441end
405442
406443module TDSL = struct
0 commit comments