Skip to content

Commit 12e1d15

Browse files
committed
Operation corresponding to the primitive fma
1 parent 57efb5c commit 12e1d15

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

lib/operation.ml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ let sat01 ?(label = []) =
192192
let%cd grad_asn ~v:_ ~g ~t1 ~projections = g1 =+ sat01_gate (v1, g) in
193193
Tensor.unop ~label:("sat01" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
194194

195+
let fma ?(label = []) ~grad_spec t1 t2 t3 =
196+
let module NTDSL = Initial_NTDSL in
197+
let%cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: fma v1 v2 v3 in
198+
let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~t3 ~projections =
199+
g1 =+ g * v2;
200+
g2 =+ g * v1;
201+
g3 =+ g
202+
in
203+
Tensor.ternop ~label:("fma" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1 t2 t3
204+
195205
let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
196206
let result =
197207
Tensor.term
@@ -268,6 +278,7 @@ module DO = struct
268278
let ( **. ) ?label base exp = pointpow ?label exp base ~grad_spec:If_needed
269279
let relu = relu ~grad_spec:If_needed
270280
let sat01 = sat01 ~grad_spec:If_needed
281+
let fma = fma ~grad_spec:If_needed
271282
let ( !. ) = Tensor.number ~grad_spec:If_needed
272283
let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:If_needed @@ Float.of_int i
273284
let ( !@ ) = embed_symbol
@@ -282,7 +293,8 @@ module NDO = struct
282293

283294
let ( /. ) = pointdiv ~grad_spec:Prohibit_grad
284295
let ( @| ) ?label t1 idx = slice ?label ~grad_spec:Prohibit_grad idx t1
285-
let sat01 = sat01 ~grad_spec:If_needed
296+
let sat01 = sat01 ~grad_spec:Prohibit_grad
297+
let fma = fma ~grad_spec:Prohibit_grad
286298
end
287299

288300
module TDSL = struct

test/primitive_ops.ml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,55 @@ let%expect_test "sat01" =
153153
│ │ x │
154154
└──┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
155155
|}]
156+
157+
let%expect_test "fma(x, 2, 1)" =
158+
let%op f x = fma x !.2. !.1. in
159+
let plot_box = plot_unop ~f in
160+
PrintBox_text.output Stdio.stdout plot_box;
161+
[%expect {|
162+
┌───┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
163+
11│ #│
164+
│ │ # # │
165+
│ │ ### │
166+
│ │ # # │
167+
│ │ ### │
168+
│ │ # # │
169+
│ │ ### │
170+
│ │ # # │
171+
│ │ ### │
172+
│ │ # │
173+
│ │ ### │
174+
│ │ # │
175+
│ │ ### │
176+
│ │ # │
177+
│ │ ### │
178+
│ │ # # │
179+
│ │ ## │
180+
│ │ # # │
181+
│f │* * ** * ***** *** **** ***** **** **** * **** **** * ** * ** * **** **** **** *** **** **** **** **
182+
│( │ # # │
183+
│x │ ## │
184+
│) │ # ## │
185+
│ │- - - - - - - - - ## - - - - - - - - - -
186+
│ │ # ## │
187+
│ │ # # │
188+
│ │ ### │
189+
│ │ # # │
190+
│ │ ### │
191+
│ │ ## # │
192+
│ │ ## │
193+
│ │ # ## │
194+
│ │ ## │
195+
│ │ # ## │
196+
│ │ ## │
197+
│ │ ## # │
198+
│ │ ## │
199+
│ │ # ## │
200+
│ │ # # │
201+
│ │ # ## │
202+
-9│# # │
203+
├───┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
204+
│ │-5 4.9
205+
│ │ x │
206+
└───┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
207+
|}]

0 commit comments

Comments
 (0)