@@ -92,13 +92,19 @@ let row_of_kind = function `Batch -> batch | `Input -> input | `Output -> output
9292type deduce_within_shape = Not_constrained | Input_equals_output
9393[@@ deriving compare , sexp , variants ]
9494
95- type compose_type = Pointwise_bin | Compose | Einsum of string * Idx .variable_ref list
95+ type delayed_var_ref = {
96+ var_ref : Ir.Indexing .variable_ref ;
97+ mutable var : [ `Row of Row .row_var | `Dim of Row .dim_var | `Not_set_yet ];
98+ }
99+ [@@ deriving equal , sexp_of ]
100+
101+ type compose_type = Pointwise_bin | Compose | Einsum of string * delayed_var_ref list
96102[@@ deriving sexp_of , equal ]
97103
98104type transpose_type =
99105 | Transpose
100106 | Pointwise_un
101- | Permute of string * Idx .variable_ref list
107+ | Permute of string * delayed_var_ref list
102108 | Batch_slice of Idx .static_symbol
103109 | Uint4x32_to_prec of Ir.Ops .prec Lazy .t
104110[@@ deriving equal , sexp_of ]
@@ -705,6 +711,7 @@ let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
705711 active_constraints := ineqs @ ! active_constraints;
706712 let ineqs', env = Row. solve_inequalities ~stage: Row. Stage1 ineqs ! state in
707713 let _debug_remaining_constraints : Row.constraint_ list = ineqs' in
714+ (* FIXME: call apply_env_step instead *)
708715 iter_shapes update_step ~f: (apply_env_t env);
709716 state := env
710717
@@ -725,6 +732,7 @@ let%debug4_sexp finish_inference (() : unit) : unit =
725732 let unsolved, env = Row. solve_inequalities ~stage: Stage7 unsolved env in
726733 assert (List. is_empty unsolved);
727734 let _active_update_steps : update_step list = ! active_update_steps in
735+ (* FIXME: call apply_env_step instead *)
728736 List. iter ~f: (iter_shapes ~f: (apply_env_t env)) ! active_update_steps;
729737 let _applied_update_steps : update_step list = ! active_update_steps in
730738 active_constraints := [] ;
0 commit comments