@@ -192,6 +192,16 @@ let sat01 ?(label = []) =
192192 let % cd grad_asn ~v: _ ~g ~t1 ~projections = g1 =+ sat01_gate (v1, g) in
193193 Tensor. unop ~label: (" sat01" :: label) ~transpose_op: Pointwise_un ~op_asn ~grad_asn
194194
195+ let fma ?(label = [] ) ~grad_spec t1 t2 t3 =
196+ let module NTDSL = Initial_NTDSL in
197+ let % cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: fma v1 v2 v3 in
198+ let % cd grad_asn ~v: _ ~g ~t1 ~t2 ~t3 ~projections =
199+ g1 =+ g * v2;
200+ g2 =+ g * v1;
201+ g3 =+ g
202+ in
203+ Tensor. ternop ~label: (" fma" :: label) ~ternary_op: Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1 t2 t3
204+
195205let range ?(label = [] ) ?(grad_spec = Tensor. Prohibit_grad ) ?axis_label upto =
196206 let result =
197207 Tensor. term
@@ -268,6 +278,7 @@ module DO = struct
268278 let ( **. ) ?label base exp = pointpow ?label exp base ~grad_spec: If_needed
269279 let relu = relu ~grad_spec: If_needed
270280 let sat01 = sat01 ~grad_spec: If_needed
281+ let fma = fma ~grad_spec: If_needed
271282 let ( ! . ) = Tensor. number ~grad_spec: If_needed
272283 let ( !.. ) ?label i = Tensor. number ?label ~grad_spec: If_needed @@ Float. of_int i
273284 let ( ! @ ) = embed_symbol
@@ -282,7 +293,8 @@ module NDO = struct
282293
283294 let ( /. ) = pointdiv ~grad_spec: Prohibit_grad
284295 let ( @| ) ?label t1 idx = slice ?label ~grad_spec: Prohibit_grad idx t1
285- let sat01 = sat01 ~grad_spec: If_needed
296+ let sat01 = sat01 ~grad_spec: Prohibit_grad
297+ let fma = fma ~grad_spec: Prohibit_grad
286298end
287299
288300module TDSL = struct
0 commit comments