@@ -570,8 +570,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
570570 Row_eq { r1 = cur_sh.input; r2 = sh.input };
571571 Row_eq { r1 = cur_sh.output; r2 = sh.output };
572572 ] )
573- | Transpose (Permute (spec , _dim_refs ), sh ) ->
574- (* FIXME: support dim_refs *)
573+ | Transpose (Permute (spec , dim_refs ), sh ) ->
575574 let ls_rhs, ls_lhs =
576575 match einsum_of_spec spec with
577576 | ls_rhs , None , ls_lhs -> (ls_rhs, ls_lhs)
@@ -590,6 +589,18 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
590589 let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
591590 einsum_slot_spec_to_dims_bio ~generative ~sh_id: cur_sh.id ~row_var_env ~dim_var_env ls_lhs
592591 in
592+ (* Bind delayed_var_refs to the variables after they are created *)
593+ List. iter dim_refs ~f: (fun delayed_ref ->
594+ let label = delayed_ref.var_ref.ref_label in
595+ (* Check if it's in one of the environments *)
596+ match Hashtbl. find dim_var_env label with
597+ | Some var -> delayed_ref.var < - `Dim var
598+ | None -> (
599+ match Hashtbl. find row_var_env label with
600+ | Some var -> delayed_ref.var < - `Row var
601+ | None -> ()
602+ )
603+ );
593604 let proj_env =
594605 let combine ~key :_ _ _ = assert false in
595606 Map. merge_skewed ~combine proj_env_rhs proj_env_lhs
@@ -621,8 +632,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
621632 { numerator = Row. Strided_var { coeff; var; denom = 1 }; divided_by = [] };
622633 };
623634 ] )
624- | Broadcast (Einsum (spec , _dim_refs ), sh1 , sh2 ) ->
625- (* FIXME: support dim_refs *)
635+ | Broadcast (Einsum (spec , dim_refs ), sh1 , sh2 ) ->
626636 let ls_rhs1, ls_rhs2, ls_lhs =
627637 match einsum_of_spec spec with
628638 | ls_rhs1 , Some ls_rhs2 , ls_lhs -> (ls_rhs1, ls_rhs2, ls_lhs)
@@ -643,6 +653,18 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
643653 let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
644654 einsum_slot_spec_to_dims_bio ~generative ~sh_id: cur_sh.id ~row_var_env ~dim_var_env ls_lhs
645655 in
656+ (* Bind delayed_var_refs to the variables after they are created *)
657+ List. iter dim_refs ~f: (fun delayed_ref ->
658+ let label = delayed_ref.var_ref.ref_label in
659+ (* Check if it's in one of the environments *)
660+ match Hashtbl. find dim_var_env label with
661+ | Some var -> delayed_ref.var < - `Dim var
662+ | None -> (
663+ match Hashtbl. find row_var_env label with
664+ | Some var -> delayed_ref.var < - `Row var
665+ | None -> ()
666+ )
667+ );
646668 let proj_env =
647669 let combine ~key :_ _ _ = assert false in
648670 Map. merge_skewed ~combine proj_env_rhs1
@@ -701,6 +723,52 @@ let apply_env_t env sh =
701723 sh.input < - Row. subst_row env sh.input;
702724 sh.output < - Row. subst_row env sh.output
703725
(** [compute_row_product env row] multiplies together the sizes of the entries of [row.dims];
    an empty [dims] list yields [1]. Only [row.dims] is inspected.

    A [Row.Var] entry contributes the size that [Row.get_dim_from_env] reports for it in [env],
    or the neutral factor [1] while that lookup returns [None]. The result can therefore be a
    partial product when some dimension variables are still unsolved. *)
let compute_row_product env (row : Row.t) : int =
  let dim_size = function
    | Row.Dim { d; _ } -> d
    | Row.Var v ->
        (* A variable that is not resolved yet is counted as 1. *)
        Option.value (Row.get_dim_from_env env v) ~default:1
    | Row.Conv_input _ -> 1 (* TODO: handle convolution input dimensions *)
  in
  (* A single left-to-right pass; no intermediate row records are built. *)
  List.fold row.dims ~init:1 ~f:(fun product dim -> product * dim_size dim)
741+
(** [update_delayed_var_refs env update_step] refreshes the [solved_dim] field of the delayed
    variable references carried by [update_step.logic].

    Only [Transpose (Permute _, _)] and [Broadcast (Einsum _, _, _)] carry such references; any
    other logic is left alone. For each reference:
    - [`Not_set_yet]: nothing is written;
    - [`Dim v]: [solved_dim] is set to the size [Row.get_dim_from_env env v] returns, if any;
    - [`Row v]: [solved_dim] is set to [compute_row_product env row] for the row
      [Row.get_row_from_env env v] returns, if any.

    NOTE(review): [compute_row_product] counts unresolved dimensions as 1, so the value stored
    for a [`Row] reference may be provisional until the row is fully solved -- confirm that
    readers of [solved_dim] tolerate this. *)
let update_delayed_var_refs env update_step =
  let refresh delayed_ref =
    match delayed_ref.var with
    | `Not_set_yet -> () (* Not bound to a variable yet. *)
    | `Dim dim_var ->
        (* Leave [solved_dim] untouched while the variable is unresolved. *)
        Option.iter (Row.get_dim_from_env env dim_var) ~f:(fun d ->
            delayed_ref.var_ref.solved_dim <- Some d)
    | `Row row_var ->
        (* Likewise, only write once the environment knows the row. *)
        Option.iter (Row.get_row_from_env env row_var) ~f:(fun row ->
            delayed_ref.var_ref.solved_dim <- Some (compute_row_product env row))
  in
  match update_step.logic with
  | Transpose (Permute (_, var_refs), _) | Broadcast (Einsum (_, var_refs), _, _) ->
      List.iter var_refs ~f:refresh
  | _ -> ()
767+
(** [apply_env_step env update_step] applies [env] to an update step in two phases: first
    [apply_env_t env] is run on every shape that [iter_shapes] visits for [update_step], then
    the step's delayed variable references are refreshed via [update_delayed_var_refs]. *)
let apply_env_step env update_step =
  let substitute_in_shape = apply_env_t env in
  iter_shapes update_step ~f:substitute_in_shape;
  update_delayed_var_refs env update_step
771+
704772let % debug4_sexp propagate_shapes (update_step : update_step ) : unit =
705773 (* Allow the derivation of constraints to depend on the shapes (currently, only Batch_slice
706774 does). *)
@@ -711,8 +779,7 @@ let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
711779 active_constraints := ineqs @ ! active_constraints;
712780 let ineqs', env = Row. solve_inequalities ~stage: Row. Stage1 ineqs ! state in
713781 let _debug_remaining_constraints : Row.constraint_ list = ineqs' in
714- (* FIXME: call apply_env_step instead *)
715- iter_shapes update_step ~f: (apply_env_t env);
782+ apply_env_step env update_step;
716783 state := env
717784
718785let % debug4_sexp finish_inference (() : unit ) : unit =
@@ -732,8 +799,7 @@ let%debug4_sexp finish_inference (() : unit) : unit =
732799 let unsolved, env = Row. solve_inequalities ~stage: Stage7 unsolved env in
733800 assert (List. is_empty unsolved);
734801 let _active_update_steps : update_step list = ! active_update_steps in
735- (* FIXME: call apply_env_step instead *)
736- List. iter ~f: (iter_shapes ~f: (apply_env_t env)) ! active_update_steps;
802+ List. iter ~f: (apply_env_step env) ! active_update_steps;
737803 let _applied_update_steps : update_step list = ! active_update_steps in
738804 active_constraints := [] ;
739805 active_update_steps := [] ;
0 commit comments