@@ -80,8 +80,6 @@ let matmul ?(label = []) =
8080 let % cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in
8181 mul Compose ~op_asn ~label: (" *" :: label)
8282
83- let capture_dims_to_refs = List. map ~f: (fun var_ref -> { Shape. var_ref; var = `Not_set_yet })
84-
8583(* * Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of
8684 matrix multiplication, inner and outer products, etc.
8785
@@ -94,9 +92,7 @@ let einsum ?(label = []) ?(capture_dims = []) spec =
9492 g1 =+ g * v2;
9593 g2 =+ v1 * g
9694 in
97- Tensor. binop ~label: (" ;=>" :: label)
98- ~compose_op: (Einsum (spec, capture_dims_to_refs capture_dims))
99- ~op_asn ~grad_asn
95+ Tensor. binop ~label: (" ;=>" :: label) ~compose_op: (Einsum (spec, capture_dims)) ~op_asn ~grad_asn
10096
10197(* * Like [einsum], but adds instead than multiplying the resulting values. *)
10298let outer_sum ?(label = [] ) ?(capture_dims = [] ) spec =
@@ -106,9 +102,7 @@ let outer_sum ?(label = []) ?(capture_dims = []) spec =
106102 g1 =+ g;
107103 g2 =+ g
108104 in
109- Tensor. binop ~label: (" ;=>+" :: label)
110- ~compose_op: (Einsum (spec, capture_dims_to_refs capture_dims))
111- ~op_asn ~grad_asn
105+ Tensor. binop ~label: (" ;=>+" :: label) ~compose_op: (Einsum (spec, capture_dims)) ~op_asn ~grad_asn
112106
113107(* * Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract
114108 diagonals, compute traces etc.
@@ -120,7 +114,7 @@ let einsum1 ?(label = []) ?(capture_dims = []) spec =
120114 let % cd op_asn ~v ~t1 ~projections = v =:+ v1 in
121115 let % cd grad_asn ~t: _ ~g ~t1 ~projections = g1 =+ g in
122116 Tensor. unop
123- ~transpose_op: (Shape. Permute (spec, capture_dims_to_refs capture_dims))
117+ ~transpose_op: (Shape. Permute (spec, capture_dims))
124118 ~op_asn ~grad_asn ~label: (" =>" :: label)
125119
126120module NDO_before_pow = struct
@@ -471,8 +465,8 @@ let embed_self_id ?grad_spec ?(label = []) () =
471465 ~input_dims: [] ~output_dims: [ 1 ] ()
472466
473467let embed_dim ?grad_spec ?(label = [] ) variable_ref =
474- Tensor. term ~fetch_op: (Embed_dim variable_ref) ?grad_spec ~label: ( " !@self_id " :: label)
475- ~batch_dims: [] ~input_dims: [] ~output_dims: [ 1 ] ()
468+ Tensor. term ~fetch_op: (Embed_dim variable_ref. Shape. var_ref ) ?grad_spec
469+ ~label: ( " !@self_id " :: label) ~ batch_dims:[] ~input_dims: [] ~output_dims: [ 1 ] ()
476470
477471let uniform ?grad_spec () =
478472 uint4x32_to_prec_uniform ?grad_spec
674668module DSL_modules = struct
675669 module Shape = Shape
676670 module Tensor = Tensor
671+
677672 module TDSL = Make_DSL (struct
678673 let grad_spec = Tensor. If_needed
679674 end )
0 commit comments