Skip to content

Commit 5286521

Browse files
committed
In Total_elems with stride row constraint, also solve for the row side by substituting the stride variable
1 parent 7633562 commit 5286521

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

lib/row.ml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ type environment = { dim_env : dim_env; row_env : row_env } [@@deriving sexp_of]
215215
appear elsewhere in the environment. In particular, per-dim and per-row constraints might not
216216
have been applied. *)
217217

218+
let get_dim_val env var =
219+
match Map.find env.dim_env var with Some (Solved_dim (Dim { d; _ })) -> Some d | _ -> None
220+
218221
type constraint_ =
219222
| Dim_eq of { d1 : dim; d2 : dim }
220223
| Row_eq of { r1 : t; r2 : t }
@@ -796,7 +799,15 @@ let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id)
796799
match row_vars with
797800
| [] ->
798801
(* No row variables found *)
799-
let first_id = match rows with [] -> phantom_row_id | first_row :: _ -> first_row.id in
802+
let first_id =
803+
match
804+
List.find_map rows ~f:(function
805+
| { id = { kind = `Output; _ } as id; _ } -> Some id
806+
| _ -> None)
807+
with
808+
| None -> phantom_row_id
809+
| Some id -> id
810+
in
800811
Either.First { dims = all_dims; bcast = Broadcastable; id = first_id }
801812
| [ (v, id) ] ->
802813
(* Exactly one row variable - reconstruct the proper row structure *)
@@ -972,6 +983,13 @@ and apply_row_constraint stage (r : row) (constr : row_constraint) env : constra
972983
match (r, constr) with
973984
| _ when stored && not updated -> (extras, env)
974985
| _, Unconstrained -> assert false
986+
| _, Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by = [] }
987+
when is_stage2_up stage && Option.is_some (get_dim_val env var) ->
988+
let tot = Option.value_exn (get_dim_val env var) in
989+
let tot = Utils.safe_force coeff * tot / denom in
990+
apply_row_constraint stage r
991+
(Total_elems { numerator = Num_elems tot; divided_by = [] })
992+
env
975993
| ( { dims; bcast = Broadcastable; _ },
976994
Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by = [] } )
977995
when is_stage2_up stage && known_dims_product dims ->

0 commit comments

Comments
 (0)