@@ -304,17 +304,72 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
304304 let r = { dims = List. map shared ~f: (fun v -> Var v); bcast = Broadcastable ; id } in
305305 [
306306 Rows_constr
307- { r = [r]; constr = Total_elems { nominator; divided_by = Set. empty (module Dim_var ) } };
307+ {
308+ r = [ r ];
309+ constr = Total_elems { nominator; divided_by = Set. empty (module Dim_var ) };
310+ };
308311 ]
309312 in
310313 let subsum = Set. symmetric_diff vars1 vars2 in
311314 if Sequence. for_all ~f: Either. is_first subsum then Some (extras ~keep_constr1: false , constr2)
312315 else if Sequence. for_all ~f: Either. is_second subsum then
313316 Some (extras ~keep_constr1: true , constr1)
314317 else None
315- | Exact _ , _ | _ , Exact _ ->
316- (* FIXME: NOT IMPLEMENTED YET *)
317- None
318+ | Exact dims1 , Exact dims2 ->
319+ if List. length dims1 <> List. length dims2 then
320+ raise @@ Shape_error (" Exact constraint length mismatch" , [] )
321+ else
322+ let eqs = List. map2_exn dims1 dims2 ~f: (fun d1 d2 -> Dim_eq { d1; d2 }) in
323+ Some (eqs, constr1)
324+ | Total_elems { nominator; divided_by }, Exact dims
325+ | Exact dims , Total_elems { nominator; divided_by } -> (
326+ (* Simple collect_factors logic - handle only basic Dim and Var cases *)
327+ let rec collect_dim_factors (ds , vars ) = function
328+ | Dim { d; _ } -> Some (d :: ds, vars)
329+ | Var v -> Some (ds, v :: vars)
330+ | Conv_input _ -> None (* Too complex, give up *)
331+ in
332+ match
333+ List. fold dims
334+ ~init: (Some ([] , [] ))
335+ ~f: (fun acc dim ->
336+ match acc with None -> None | Some (ds , vars ) -> collect_dim_factors (ds, vars) dim)
337+ with
338+ | None -> None (* Give up on complex cases *)
339+ | Some (ds , vars ) ->
340+ let known_product = List. fold ds ~init: 1 ~f: ( * ) in
341+ if nominator < = 0 then
342+ raise @@ Shape_error ([% string " Invalid Total_elems nominator: %{nominator#Int}" ], [] )
343+ else if known_product = 0 then
344+ raise @@ Shape_error (" Exact constraint has zero dimension" , [] )
345+ else if nominator % known_product <> 0 then
346+ raise
347+ @@ Shape_error
348+ ( [% string
349+ " Total_elems nominator %{nominator#Int} not divisible by Exact dimensions \
350+ product %{known_product#Int}" ],
351+ [] )
352+ else
353+ let reminder = nominator / known_product in
354+ if reminder = 1 then
355+ (* reminder is 1: equate all variables on both sides to 1 *)
356+ let divided_by_eqs =
357+ Set. to_list divided_by
358+ |> List. map ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () })
359+ in
360+ let exact_vars_eqs =
361+ List. map vars ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () })
362+ in
363+ Some (divided_by_eqs @ exact_vars_eqs, Exact dims)
364+ else if Set. is_empty divided_by && List. length vars = 1 && reminder > 0 then
365+ (* divided_by is empty and there is only one dim variable in Exact dims *)
366+ let v = List. hd_exn vars in
367+ Some ([ Dim_eq { d1 = Var v; d2 = get_dim ~d: reminder () } ], Exact dims)
368+ else if List. is_empty vars && Set. length divided_by = 1 && reminder > 0 then
369+ (* Exact dims contain only known dimensions and divided_by has exactly one variable *)
370+ let v = Set. choose_exn divided_by in
371+ Some ([ Dim_eq { d1 = Var v; d2 = get_dim ~d: reminder () } ], Exact dims)
372+ else None (* Cannot handle this case *) )
318373
319374let rec apply_dim_constraint ~(source : source ) ~(stage : stage ) (dim : dim )
320375 (constr : dim_constraint ) (env : environment ) : constraint_ list * dim_constraint =
@@ -506,17 +561,19 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons
506561 (* | v :: _, [] | [], v :: _ when (is_stage4_up stage) -> (Dim_eq { d1 = Var v; d2 = get_dim
507562 ~d:nominator () } :: extras, env) *)
508563 | _ :: _ , _ when stored -> (extras, env)
509- | _ , _ -> (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *) )
564+ | _ , _ ->
565+ (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *) )
510566 with Given_up ->
511567 if stored then (extras, env)
512- else (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *) ))
568+ else
569+ (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *) ))
513570 | { bcast = Row_var _ ; _ } , _ | _ , Total_elems { nominator = _ ; divided_by = _ } ->
514571 if stored then (extras, env)
515- else (Rows_constr { r = [r ]; constr } :: extras, env (* Wait for more shape inference. *) )
572+ else (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *) )
516573 | _ , Exact _ ->
517574 (* FIXME: NOT IMPLEMENTED YET *)
518575 if stored then (extras, env)
519- else (Rows_constr { r = [r ]; constr } :: extras, env (* Wait for more shape inference. *) )
576+ else (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *) )
520577
521578let s_dim_one_in_entry v ~value (in_ : dim_entry ) : _ * dim_entry =
522579 match in_ with
@@ -598,12 +655,13 @@ let s_row_one v ~value:{ dims = more_dims; bcast; id = _ } ~in_ =
598655 })
599656 | _ -> in_
600657
601- let s_row_one_in_row_constr _v ~value :_ ~in_ =
602- match in_ with
658+ let s_row_one_in_row_constr _v ~value :_ ~in_ =
659+ match in_ with
603660 | Unconstrained | Total_elems _ -> in_
604- | Exact _ ->
661+ | Exact _ ->
605662 (* FIXME: NOT IMPLEMENTED YET *)
606663 in_
664+
607665let row_of_var v id = { dims = [] ; bcast = Row_var { v; beg_dims = [] }; id }
608666
609667let s_row_one_in_entry (v : row_var ) ~(value : row ) ~(in_ : row_entry ) :
@@ -1408,9 +1466,8 @@ let%debug5_sexp rec eliminate_row_constraint ~lub (r : row) (constr : row_constr
14081466 | ineq -> [ ineq ])
14091467 | _ , [ v ], _ -> no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ]
14101468 | _ -> [] )
1411- | Exact _ ->
1412- (* FIXME: NOT IMPLEMENTED YET *)
1413- []
1469+ | Exact dims ->
1470+ [ Row_eq { r1; r2 = { dims; bcast = Broadcastable ; id } } ]
14141471 | _ -> [] )
14151472
14161473let % debug5_sexp close_row_terminal ~(stage : stage ) (env : environment )
@@ -1523,8 +1580,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
15231580 let r = subst_row env (List. hd_exn rows) in
15241581 let more_ineqs, env = apply_row_constraint ~stage r constr env in
15251582 (more_ineqs @ ineqs, env)
1526- else
1527- (ineqs, env)
1583+ else (ineqs, env)
15281584 | Terminal_dim d ->
15291585 let more_ineqs = close_dim_terminal ~stage env @@ subst_dim env d in
15301586 (more_ineqs @ ineqs, env)
0 commit comments