
Commit 81a2d5e (1 parent: 504f15a)

New no-op shape inference spec Defined_by_cd_logic

5 files changed, +93 -27 lines

tensor/operation.ml

Lines changed: 5 additions & 4 deletions
@@ -287,11 +287,12 @@ let interleave =
   let%cd op_asn ~t ~t1 ~t2 ~projections:_ =
     t =:+ id t1 ~logic:"... | ... -> ..., i => ... | ... -> ..., 2*i";
     t =+ id t2 ~logic:"... | ... -> ..., i => ... | ... -> ..., 2*i + 1"
-  in
+  in
   let%cd grad_asn ~t ~g:_ ~t1 ~t2 ~projections:_ =
-    t1.grad =+ id t.grad ~logic:"... | ... -> ..., 2*i => ... | ... -> ..., i";
-    t2.grad =+ id t.grad ~logic:"... | ... -> ..., 2*i + 1 => ... | ... -> ..., i" in
-  Tensor.binop ~op_label:"<>" ~compose_op:Pointwise_bin ~op_asn ~grad_asn
+    t1.grad =+ id t.grad ~logic:"... | ... -> ..., 2*i => ... | ... -> ..., i";
+    t2.grad =+ id t.grad ~logic:"... | ... -> ..., 2*i + 1 => ... | ... -> ..., i"
+  in
+  Tensor.binop ~op_label:"interleave" ~compose_op:Defined_by_cd_logic ~op_asn ~grad_asn
 
 let threefry4x32_crypto =
   let module NTDSL = Initial_NTDSL in
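
The ~logic strings are extended einsum specs in OCANNL's "rhs => lhs" form: read loosely, output axis i of t1 maps to axis 2*i of t, axis i of t2 to axis 2*i + 1, and the grad_asn specs are the same maps inverted. A minimal plain-OCaml sketch of that element layout, using ordinary arrays rather than OCANNL tensors (interleave_arrays is a hypothetical illustration, not library code):

  (* Illustrative only: the layout the two %cd assignments encode --
     t1 fills the even positions (2*i), t2 the odd ones (2*i + 1). *)
  let interleave_arrays t1 t2 =
    let n = Array.length t1 in
    Array.init (2 * n) (fun j -> if j mod 2 = 0 then t1.(j / 2) else t2.(j / 2))

  let () =
    let t3 = interleave_arrays [| 1.0; 2.0; 3.0 |] [| 4.0; 5.0; 6.0 |] in
    Array.iter (fun x -> Printf.printf "%g " x) t3;
    (* prints: 1 4 2 5 3 6 *)
    print_newline ()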

tensor/shape.ml

Lines changed: 38 additions & 15 deletions
@@ -102,7 +102,11 @@ type delayed_var_ref = {
 let get_variable_ref ref_label =
   { var_ref = Ir.Indexing.{ ref_label; solved_dim = None }; var = `Not_set_yet }
 
-type compose_type = Pointwise_bin | Compose | Einsum of string * delayed_var_ref list
+type compose_type =
+  | Pointwise_bin
+  | Compose
+  | Einsum of string * delayed_var_ref list
+  | Defined_by_cd_logic
 [@@deriving sexp_of, equal]
 
 type transpose_type =
@@ -111,12 +115,14 @@ type transpose_type =
   | Permute of string * delayed_var_ref list
   | Batch_slice of Idx.static_symbol
   | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
+  | Defined_by_cd_logic
 [@@deriving equal, sexp_of]
 
 type terminal_type = Data of Ir.Assignments.init_data | Fetch of Ir.Assignments.fetch_op
 [@@deriving equal, sexp_of]
 
-type ternary_type = Pointwise_tern | Compose_accumulate [@@deriving sexp, equal]
+type ternary_type = Pointwise_tern | Compose_accumulate | Defined_by_cd_logic
+[@@deriving sexp, equal]
 
 let identifier ~multichar =
   let open Angstrom in
@@ -277,6 +283,10 @@ let logic_to_spec = function
   | Transpose (Transpose, _) -> "T"
   | Transpose (Batch_slice _, _) -> "@|"
   | Transpose (Uint4x32_to_prec _, _) -> "U4x32"
+  | Broadcast (Defined_by_cd_logic, _, _)
+  | Transpose (Defined_by_cd_logic, _)
+  | Broadcast_tern (Defined_by_cd_logic, _, _, _) ->
+      "<cd_logic>"
   | Terminal _ -> "<terminal>"
 
 module Update_id = struct
@@ -371,8 +381,7 @@ let axes_spec_to_dims_bio ~sh_id ~row_var_env ~dim_var_env:_ ~f labels =
   let output = to_row `Output labels.bcast_output o_dims beg_o_dims in
   (batch, input, output)
 
-let einsum_slot_spec_to_dims_bio ~original_spec ~sh_id ~row_var_env ~dim_var_env labels
-    =
+let einsum_slot_spec_to_dims_bio ~original_spec ~sh_id ~row_var_env ~dim_var_env labels =
   let proj_env_update = ref @@ Row.dim_map_empty in
   let extras = ref [] in
   let f kind = function
@@ -840,6 +849,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
             subr = sh3.output;
           };
         ] )
+  | Broadcast (Defined_by_cd_logic, _, _)
+  | Transpose (Defined_by_cd_logic, _)
+  | Broadcast_tern (Defined_by_cd_logic, _, _, _) ->
+      (Row.dim_map_empty, [])
   | Transpose (Batch_slice { static_range; static_symbol }, sh) ->
       Hash_set.remove unused_shapes sh.id;
       let slice_v = get_var () in
@@ -915,12 +928,12 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
       let dim_var_env = Hashtbl.create (module String) in
 
       let extras_rhs, proj_env_rhs, (b_rhs, i_rhs, o_rhs) =
-        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:sh.id ~row_var_env
-          ~dim_var_env ls_rhs
+        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:sh.id ~row_var_env ~dim_var_env
+          ls_rhs
       in
       let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
-        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:cur_sh.id ~row_var_env
-          ~dim_var_env ls_lhs
+        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:cur_sh.id ~row_var_env ~dim_var_env
+          ls_lhs
      in
      (* Bind delayed_var_refs to the variables after they are created *)
      let extras_dim_refs =
@@ -1134,16 +1147,16 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
      let row_var_env = Hashtbl.create (module String) in
      let dim_var_env = Hashtbl.create (module String) in
      let extras_rhs1, proj_env_rhs1, (b_rhs1, i_rhs1, o_rhs1) =
-        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:sh1.id ~row_var_env
-          ~dim_var_env ls_rhs1
+        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:sh1.id ~row_var_env ~dim_var_env
+          ls_rhs1
      in
      let extras_rhs2, proj_env_rhs2, (b_rhs2, i_rhs2, o_rhs2) =
-        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:sh2.id ~row_var_env
-          ~dim_var_env ls_rhs2
+        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:sh2.id ~row_var_env ~dim_var_env
+          ls_rhs2
      in
      let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
-        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:cur_sh.id ~row_var_env
-          ~dim_var_env ls_lhs
+        einsum_slot_spec_to_dims_bio ~original_spec:spec ~sh_id:cur_sh.id ~row_var_env ~dim_var_env
+          ls_lhs
      in
      (* Bind delayed_var_refs to the variables after they are created *)
      (* TODO: refactor to avoid duplication with the one for unary einsum *)
@@ -1735,6 +1748,10 @@ let fresh_proj_ids update =
   in
   fresh_shape update.shape;
   (match update.logic with
+  | Broadcast (Defined_by_cd_logic, _, _)
+  | Transpose (Defined_by_cd_logic, _)
+  | Broadcast_tern (Defined_by_cd_logic, _, _, _) ->
+      ()
   | Terminal _ -> ()
   | Transpose (_, sh) -> fresh_shape sh
   | Broadcast (_, sh1, sh2) ->
@@ -1777,6 +1794,12 @@ let%debug4_sexp derive_projections (update_step : update_step) : Idx.projections
   let lhs = update_step.shape in
   let rhs =
     match update_step.logic with
+    | Broadcast (Defined_by_cd_logic, _, _)
+    | Transpose (Defined_by_cd_logic, _)
+    | Broadcast_tern (Defined_by_cd_logic, _, _, _) ->
+        raise
+        @@ Utils.User_error
+             "Defined_by_cd_logic: use explicit ~logic annotations when defining this operation"
    | Terminal _ -> []
    | Transpose (_, sh) -> [ sh ]
    | Broadcast (_, sh1, sh2) -> [ sh1; sh2 ]
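
Taken together, the three shape.ml match sites give Defined_by_cd_logic its no-op character: get_inequalities contributes an empty substitution and no constraints, fresh_proj_ids touches nothing, and derive_projections refuses outright with Utils.User_error. A simplified, self-contained sketch of that dispatch, using stand-in types rather than OCANNL's (only the error message is taken verbatim from the diff):

  (* Stand-in types; a rough sketch of the dispatch above, not OCANNL code. *)
  exception User_error of string

  type logic = Pointwise | Einsum of string | Defined_by_cd_logic

  let inequalities = function
    | Pointwise -> [ "lhs broadcasts against rhs" ] (* inequalities *)
    | Einsum spec -> [ "equations from " ^ spec ] (* exact equations *)
    | Defined_by_cd_logic -> [] (* contributes nothing *)

  let derive_projections = function
    | Pointwise | Einsum _ -> [ (* per-axis projections would go here *) ]
    | Defined_by_cd_logic ->
        raise
          (User_error
             "Defined_by_cd_logic: use explicit ~logic annotations when defining this operation")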
@@ -1830,7 +1853,7 @@ let make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_ax
       match kind with
       | `Batch | `Input -> get_dim ~d ()
       | `Output ->
-          if not known_no_batch && num_dim1_output = 1 && d = 1 then
+          if (not known_no_batch) && num_dim1_output = 1 && d = 1 then
            let label = debug_name ^ "_output" in
            get_dim ~d ~label ()
          else get_dim ~d ()

tensor/shape.mli

Lines changed: 14 additions & 8 deletions
@@ -113,20 +113,22 @@ type compose_type =
          multiply). *)
   | Einsum of string * delayed_var_ref list
      (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
-          OCANNL's extended einsum notation supports both axis variables and row variables.
-          The [axis_labels] use pseudo-labels local to the notation, to line up the axes and row
+          OCANNL's extended einsum notation supports both axis variables and row variables. The
+          [axis_labels] use pseudo-labels local to the notation, to line up the axes and row
          variables. The symmetric difference / disjunctive union of RHS1 and RHS2's pseudo-labels
          should be equal to LHS pseudo-labels.
 
          Unlike [Pointwise_bin] and [Compose], einsum operations use equations only (not
-          inequalities), so they do NOT permit broadcasting. This makes einsum more restrictive
-          but also more precise for operations where exact shape matching is required.
+          inequalities), so they do NOT permit broadcasting. This makes einsum more restrictive but
+          also more precise for operations where exact shape matching is required.
 
          The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
          corresponding to the specification labels equal to [ref_label] of a reference.
 
          Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs",
          "rhs1;rhs2=>lhs". *)
+  | Defined_by_cd_logic
+      (** A placeholder for operations where the shape logic is defined by the %cd extension. *)
 [@@deriving sexp_of, equal]
 
 type transpose_type =
@@ -136,8 +138,8 @@ type transpose_type =
      (** The unary "einsum" syntax: RHS1=>LHS.
 
          Unlike [Pointwise_un], permute operations use equations only (not inequalities), so they
-          do NOT permit broadcasting. This makes permute more restrictive but also more precise
-          for operations where exact shape matching is required.
+          do NOT permit broadcasting. This makes permute more restrictive but also more precise for
+          operations where exact shape matching is required.
 
          The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
          corresponding to the specification labels equal to [ref_label] of a reference. *)
@@ -146,13 +148,17 @@
      (** Converts precision in a bit-effient way, with a corresponding conversion in total number
          of elements. Currently, assumes the incoming tensor (RHS) has just a single axis to not
          force unnecessary minimum sizes on output axes. *)
+  | Defined_by_cd_logic
+      (** A placeholder for operations where the shape logic is defined by the %cd extension. *)
 [@@deriving equal, sexp_of]
 
 (** If you miss expressivity here, leave a note on
     {{:https://github.com/ahrefs/ocannl/issues/305}issue 305}. *)
 type ternary_type =
   | Pointwise_tern (** As in the operation [Where]. *)
   | Compose_accumulate (** As in the operation [FMA]. *)
+  | Defined_by_cd_logic
+      (** A placeholder for operations where the shape logic is defined by the %cd extension. *)
 [@@deriving equal, sexp_of]
 
 (** Extracts any available shape information from the initialization or fetch. *)
@@ -203,8 +209,8 @@ type logic =
         [s1], hence the name. *)
   | Broadcast_tern of ternary_type * t * t * t (** Matches the shapes for a ternary operation. *)
   | Terminal of { is_param : bool; logic : terminal_type }
-      (** Extracts any available shape information from the initialization.
-          The [is_param] field indicates if this is a parameter tensor that requires gradients. *)
+      (** Extracts any available shape information from the initialization. The [is_param] field
+          indicates if this is a parameter tensor that requires gradients. *)
 [@@deriving equal, sexp_of]
 
 type update_id [@@deriving equal, compare, hash, sexp]
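
Because compose_type, transpose_type, and ternary_type are public variants, adding Defined_by_cd_logic makes every exhaustive match over them fail to compile until the new case is handled, which is exactly what the shape.ml hunks above do. A hedged sketch of what downstream code might add (the Shape module path and the wording of each case are assumptions, not taken from the library):

  (* Hypothetical client code; relies only on the constructors exported by shape.mli. *)
  let describe_compose : Shape.compose_type -> string = function
    | Pointwise_bin -> "pointwise; shapes related by broadcasting inequalities"
    | Compose -> "composition; matches one operand's input axes against the other's output axes"
    | Einsum (spec, _) -> "binary einsum, equations from: " ^ spec
    | Defined_by_cd_logic -> "no constraints here; the %cd ~logic specs carry the shape logic"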

test/operations/dune

Lines changed: 11 additions & 0 deletions
@@ -389,3 +389,14 @@
  (preprocess
   (pps ppx_here ppx_expect ppx_inline_test ppx_ocannl))
  (modes best))
+
+(test
+ (name test_interleave)
+ (package neural_nets_lib)
+ (deps
+  ocannl_config
+  (env_var OCANNL_BACKEND))
+ (modules test_interleave)
+ (libraries base ocannl stdio)
+ (preprocess
+  (pps ppx_here ppx_ocannl ppx_sexp_conv)))

test/operations/test_interleave.ml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+open Base
+open Ocannl
+open Nn_blocks.DSL_modules
+
+let () =
+  Tensor.unsafe_reinitialize ();
+  let%op t1 = [ 1.0; 2.0; 3.0 ] in
+  let%op t2 = [ 4.0; 5.0; 6.0 ] in
+  let t3 = Operation.interleave t1 t2 () in
+
+  (* t3 should be [1.0; 4.0; 2.0; 5.0; 3.0; 6.0] *)
+  let ctx = Context.auto () in
+
+  try
+    let _ctx = Train.forward_once ctx t3 in
+    Stdio.printf "Test failed! Expected error was not raised.\n";
+    Stdlib.exit 1
+  with Utils.User_error msg ->
+    if
+      String.equal msg
+        "Defined_by_cd_logic: use explicit ~logic annotations when defining this operation"
+    then Stdio.printf "Test passed! Caught expected error: %s\n" msg
+    else (
+      Stdio.printf "Test failed! Caught unexpected error: %s\n" msg;
+      Stdlib.exit 1)
