@@ -215,6 +215,9 @@ type environment = { dim_env : dim_env; row_env : row_env } [@@deriving sexp_of]
215215 appear elsewhere in the environment. In particular, per-dim and per-row constraints might not
216216 have been applied. *)
217217
218+ let get_dim_val env var =
219+ match Map. find env.dim_env var with Some (Solved_dim (Dim { d; _ } )) -> Some d | _ -> None
220+
218221type constraint_ =
219222 | Dim_eq of { d1 : dim ; d2 : dim }
220223 | Row_eq of { r1 : t ; r2 : t }
@@ -796,7 +799,15 @@ let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id)
796799 match row_vars with
797800 | [] ->
798801 (* No row variables found *)
799- let first_id = match rows with [] -> phantom_row_id | first_row :: _ -> first_row.id in
802+ let first_id =
803+ match
804+ List. find_map rows ~f: (function
805+ | { id = { kind = `Output ; _ } as id ; _ } -> Some id
806+ | _ -> None )
807+ with
808+ | None -> phantom_row_id
809+ | Some id -> id
810+ in
800811 Either. First { dims = all_dims; bcast = Broadcastable ; id = first_id }
801812 | [ (v, id) ] ->
802813 (* Exactly one row variable - reconstruct the proper row structure *)
@@ -972,6 +983,13 @@ and apply_row_constraint stage (r : row) (constr : row_constraint) env : constra
972983 match (r, constr) with
973984 | _ when stored && not updated -> (extras, env)
974985 | _ , Unconstrained -> assert false
986+ | _, Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by = [] }
987+ when is_stage2_up stage && Option. is_some (get_dim_val env var) ->
988+ let tot = Option. value_exn (get_dim_val env var) in
989+ let tot = Utils. safe_force coeff * tot / denom in
990+ apply_row_constraint stage r
991+ (Total_elems { numerator = Num_elems tot; divided_by = [] })
992+ env
975993 | ( { dims; bcast = Broadcastable ; _ },
976994 Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by = [] } )
977995 when is_stage2_up stage && known_dims_product dims ->
0 commit comments