@@ -291,7 +291,9 @@ type source = Direct | Equation | Cur | Subr [@@deriving equal, sexp]
291291let rec interleave l1 l2 =
292292 match (l1, l2) with [] , l | l , [] -> l | h1 :: t1 , h2 :: t2 -> h1 :: h2 :: interleave t1 t2
293293
294- let merge_origins o1 o2 = List. take (interleave o1 o2) 6
294+ let merge_origins o1 o2 =
295+ let o = List. dedup_and_sort ~compare: compare_constraint_origin @@ interleave o1 o2 in
296+ List. take o 10
295297
296298let dim_to_int_exn = function
297299 | Dim { d; _ } -> d
@@ -382,7 +384,7 @@ let collect_factors dims =
382384
383385let known_dims_product dims = match collect_factors dims with Some (_ , [] ) -> true | _ -> false
384386
385- let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
387+ let rec row_conjunction ?(id = phantom_row_id) ~ origin stage constr1 constr2 =
386388 let elems_mismatch n1 n2 =
387389 raise
388390 @@ Shape_error
@@ -415,7 +417,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
415417 (* So var = n * denom / coeff *)
416418 if n * denom % coeff_val = 0 then
417419 Some
418- ( [ Dim_eq { d1 = Var var; d2 = get_dim ~d: (n * denom / coeff_val) () ; origin = [] } ],
420+ ( [ Dim_eq { d1 = Var var; d2 = get_dim ~d: (n * denom / coeff_val) () ; origin } ],
419421 constr1 )
420422 else
421423 (* n * denom is not divisible by coeff - this is a mismatch *)
@@ -448,7 +450,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
448450 d2 =
449451 Conv_input
450452 { stride = k; output = Var v2; dilation = 0 ; kernel = get_dim ~d: 0 () };
451- origin = [] ;
453+ origin;
452454 };
453455 ],
454456 constr2 )
@@ -464,7 +466,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
464466 d2 =
465467 Conv_input
466468 { stride = k; output = Var v1; dilation = 0 ; kernel = get_dim ~d: 0 () };
467- origin = [] ;
469+ origin;
468470 };
469471 ],
470472 constr1 )
@@ -481,15 +483,15 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
481483 (* Variable appears in both numerator and denominator, they cancel out *)
482484 (* (c1 * v1 / d1) / (... * v1 * ...) = c1 / (d1 * ... * ...) *)
483485 let vars1' = remove_var v1 vars1 in
484- row_conjunction ~id stage
486+ row_conjunction ~id ~origin stage
485487 (Total_elems { numerator = Num_elems (Utils. safe_force c1); divided_by = vars1' })
486488 constr2
487489 | ( constr2,
488490 Total_elems
489491 { numerator = Strided_var { coeff = c1; var = v1; denom = _ }; divided_by = vars1 } )
490492 when List. mem vars1 v1 ~equal: equal_dim_var && late ->
491493 let vars1' = remove_var v1 vars1 in
492- row_conjunction ~id stage
494+ row_conjunction ~id ~origin stage
493495 (Total_elems { numerator = Num_elems (Utils. safe_force c1); divided_by = vars1' })
494496 constr2
495497 | ( Total_elems { numerator = n1; divided_by = vars1 },
@@ -522,12 +524,12 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
522524 | None -> elems_mismatch n1 n2
523525 | Some v ->
524526 if quotient < = 0 then elems_mismatch n1 n2;
525- [ Dim_eq { d1 = Var v; d2 = get_dim ~d: quotient () ; origin = [] } ])
527+ [ Dim_eq { d1 = Var v; d2 = get_dim ~d: quotient () ; origin } ])
526528 else if quotient < = 0 && Option. is_none num_var then elems_mismatch n1 n2
527529 else if quotient = 1 && Option. is_none num_var then
528530 (* The difference variables must all be 1 *)
529531 List. map diff_vars ~f: (fun v ->
530- Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin = [] })
532+ Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin })
531533 else
532534 (* The product of difference variables equals the quotient *)
533535 let r = { dims = List. map diff_vars ~f: (fun v -> Var v); bcast = Broadcastable ; id } in
@@ -540,7 +542,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
540542 in
541543 [
542544 Rows_constr
543- { r = [ r ]; constr = Total_elems { numerator; divided_by = [] }; origin = [] };
545+ { r = [ r ]; constr = Total_elems { numerator; divided_by = [] }; origin };
544546 ]
545547 in
546548 let lazy_extras ~keep_constr1 ~num_var ?(extra_var = [] ) ~coeff ~denom () =
@@ -554,7 +556,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
554556 let constr =
555557 Total_elems { numerator = Strided_var { coeff; var = num_var; denom }; divided_by = [] }
556558 in
557- [ Rows_constr { r = [ r ]; constr; origin = [] } ]
559+ [ Rows_constr { r = [ r ]; constr; origin } ]
558560 in
559561 let extras ~keep_constr1 : _ option =
560562 match (n1, n2) with
@@ -612,7 +614,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
612614 ];
613615 ] )
614616 else
615- let eqs = List. map2_exn dims1 dims2 ~f: (fun d1 d2 -> Dim_eq { d1; d2; origin = [] }) in
617+ let eqs = List. map2_exn dims1 dims2 ~f: (fun d1 d2 -> Dim_eq { d1; d2; origin }) in
616618 Some (eqs, constr1)
617619 | Total_elems { numerator; divided_by }, Exact dims
618620 | Exact dims , Total_elems { numerator; divided_by } -> (
@@ -638,24 +640,24 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
638640 (* reminder is 1: equate all variables on both sides to 1 *)
639641 let divided_by_eqs =
640642 List. map divided_by ~f: (fun v ->
641- Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin = [] })
643+ Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin })
642644 in
643645 let exact_vars_eqs =
644646 List. map vars ~f: (fun v ->
645- Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin = [] })
647+ Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin })
646648 in
647649 Some (divided_by_eqs @ exact_vars_eqs, Exact dims)
648650 else if List. is_empty divided_by && List. length vars = 1 && reminder > 0 then
649651 (* divided_by is empty and there is only one dim variable in Exact dims *)
650652 let v = List. hd_exn vars in
651653 Some
652- ([ Dim_eq { d1 = Var v; d2 = get_dim ~d: reminder () ; origin = [] } ], Exact dims)
654+ ([ Dim_eq { d1 = Var v; d2 = get_dim ~d: reminder () ; origin } ], Exact dims)
653655 else if List. is_empty vars && List. length divided_by = 1 && reminder > 0 then
654656 (* Exact dims contain only known dimensions and divided_by has exactly one
655657 variable *)
656658 let v = List. hd_exn divided_by in
657659 Some
658- ([ Dim_eq { d1 = Var v; d2 = get_dim ~d: reminder () ; origin = [] } ], Exact dims)
660+ ([ Dim_eq { d1 = Var v; d2 = get_dim ~d: reminder () ; origin } ], Exact dims)
659661 else None
660662 | Strided_var { coeff; var; denom } ->
661663 if known_product = 0 then
@@ -666,7 +668,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
666668 let coeff_val = Utils. safe_force coeff in
667669 if known_product * denom % coeff_val = 0 then
668670 let d = known_product * denom / coeff_val in
669- Some ([ Dim_eq { d1 = Var var; d2 = get_dim ~d () ; origin = [] } ], Exact dims)
671+ Some ([ Dim_eq { d1 = Var var; d2 = get_dim ~d () ; origin } ], Exact dims)
670672 else elems_mismatch numerator (Num_elems known_product)
671673 else if
672674 late
@@ -695,7 +697,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
695697 let coefficient = known_product * denom / coeff_val in
696698 if coefficient = 1 then
697699 (* Simple equality *)
698- Some ([ Dim_eq { d1 = Var var; d2 = Var single_var; origin = [] } ], Exact dims)
700+ Some ([ Dim_eq { d1 = Var var; d2 = Var single_var; origin } ], Exact dims)
699701 else if coefficient > 1 then
700702 (* Use Conv_input with stride and dilation=0 *)
701703 Some
@@ -711,7 +713,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
711713 dilation = 0 ;
712714 kernel = get_dim ~d: 0 () ;
713715 };
714- origin = [] ;
716+ origin;
715717 };
716718 ],
717719 Exact dims )
@@ -758,7 +760,7 @@ let%track5_sexp rec apply_dim_constraint ~(source : source) ~(stage : stage) (di
758760 (* FIXME: *)
759761 (* match (dim, constr, stage) with
760762 | Var _, At_least_dim d, Stage4 ->
761- (Dim_eq { d1 = dim; d2 = get_dim ~d () ; origin = [] } :: extras, Unconstrained_dim)
763+ (Dim_eq { d1 = dim; d2 = get_dim ~d () ; origin } :: extras, Unconstrained_dim)
762764 | _ -> *)
763765 (extras, constr)
764766
@@ -886,7 +888,7 @@ let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id)
886888
887889let row_of_var v id = { dims = [] ; bcast = Row_var { v; beg_dims = [] }; id }
888890
889- let check_empty_row r =
891+ let check_empty_row ~ origin r =
890892 if not (List. is_empty r.dims) then
891893 raise @@ Shape_error (" check_empty_row: row is not empty" , [ Row_mismatch [ r ] ]);
892894 match r.bcast with
@@ -898,7 +900,7 @@ let check_empty_row r =
898900 {
899901 r1 = row_of_var v r.id;
900902 r2 = { dims = [] ; bcast = Broadcastable ; id = r.id };
901- origin = [] ;
903+ origin;
902904 };
903905 ]
904906 else raise @@ Shape_error (" check_empty_row: row is not empty" , [ Row_mismatch [ r ] ])
@@ -1203,15 +1205,15 @@ let%track5_sexp rec apply_rows_constraint ~depth ~stage origin (rows : row list)
12031205 r2 = { dims = [ single_dim ]; bcast = Broadcastable ; id };
12041206 origin;
12051207 }
1206- :: List. concat_map ~f: check_empty_row more_rows,
1208+ :: List. concat_map ~f: ( check_empty_row ~origin ) more_rows,
12071209 env )
12081210 | { dims = _; bcast = Row_var { v = _; beg_dims = _ }; id = { kind = `Output ; _ } } :: _
12091211 ->
12101212 assert false
12111213 | _ ->
12121214 raise @@ Shape_error (" apply_rows_constraint: shape too big" , [ Row_mismatch rows ])
12131215 )
1214- | _ -> ([ Rows_constr { r = rows; constr; origin = [] } ], env))
1216+ | _ -> ([ Rows_constr { r = rows; constr; origin } ], env))
12151217
12161218and apply_row_constraint ~depth stage origin (r : row ) (constr : row_constraint ) env :
12171219 constraint_ list * _ =
@@ -1261,8 +1263,9 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
12611263 true ,
12621264 false )
12631265 | Some (Bounds_row bounds ) -> (
1266+ let origin = merge_origins origin bounds.origin in
12641267 match
1265- row_conjunction ~id: r.id stage (reduce constr ~beg_dims ~dims ) bounds.constr
1268+ row_conjunction ~id: r.id ~origin stage (reduce constr ~beg_dims ~dims ) bounds.constr
12661269 with
12671270 | None -> ([] , constr, env, false , false )
12681271 | Some (extras , constr ) ->
@@ -1325,7 +1328,7 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
13251328 (Dim_eq { d1 = Var v; d2 = get_dim ~d: n () ; origin } :: extras, env)
13261329 | Num_elems 1 , vs1 , vs2 ->
13271330 ( List. map
1328- ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin = [] })
1331+ ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () ; origin })
13291332 (vs1 @ vs2)
13301333 @ extras,
13311334 env )
@@ -1334,7 +1337,7 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
13341337 (* Total = (coeff * v / denom) / v = coeff / denom *)
13351338 if Utils. safe_force coeff % denom = 0 then
13361339 ( Dim_eq
1337- { d1 = Var v; d2 = get_dim ~d: (Utils. safe_force coeff / denom) () ; origin = [] }
1340+ { d1 = Var v; d2 = get_dim ~d: (Utils. safe_force coeff / denom) () ; origin }
13381341 :: extras,
13391342 env )
13401343 else if
@@ -1380,7 +1383,7 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
13801383 List. filter lub_dims ~f: (function Dim { d; _ } -> d > 1 | _ -> false )
13811384 in
13821385 if List. length greater_than_one < = 1 then
1383- (Row_eq { r1 = row_of_var v r.id; r2 = lub; origin = [] } :: extras, env)
1386+ (Row_eq { r1 = row_of_var v r.id; r2 = lub; origin } :: extras, env)
13841387 else if stored then (extras, env)
13851388 else (Rows_constr { r = [ r ]; constr; origin } :: extras, env)
13861389 | _ ->
@@ -1614,7 +1617,7 @@ let%debug5_sexp rec unify_row ~stage origin (eq : t * t) (env : environment) :
16141617 id;
16151618 };
16161619 r2;
1617- origin = [] ;
1620+ origin;
16181621 };
16191622 ],
16201623 env )
@@ -1632,9 +1635,9 @@ let%debug5_sexp rec unify_row ~stage origin (eq : t * t) (env : environment) :
16321635 ineqs := Row_ineq { cur = row_of_var cur value.id; subr = r2; origin } :: ! ineqs);
16331636 List. iter subr ~f: (fun subr ->
16341637 ineqs :=
1635- Row_ineq { subr = row_of_var subr value.id; cur = r2; origin = [] } :: ! ineqs);
1638+ Row_ineq { subr = row_of_var subr value.id; cur = r2; origin } :: ! ineqs);
16361639 Option. iter lub ~f: (fun lub ->
1637- ineqs := Row_ineq { cur = lub; subr = r2; origin = [] } :: ! ineqs);
1640+ ineqs := Row_ineq { cur = lub; subr = r2; origin } :: ! ineqs);
16381641 let extras, env = apply_row_constraint ~depth: 0 stage origin value constr env in
16391642 ineqs := extras @ ! ineqs;
16401643 env)
@@ -1649,7 +1652,7 @@ let%debug5_sexp rec unify_row ~stage origin (eq : t * t) (env : environment) :
16491652 raise @@ Shape_error (" Mismatching number of axes" , [ Row_mismatch [ r1; r2 ] ])
16501653 | Ok eqs ->
16511654 List. fold ~init: ([] , env)
1652- ~f: (fun acc (d1 , d2 ) -> solve acc (Dim_eq { d1; d2; origin = [] }))
1655+ ~f: (fun acc (d1 , d2 ) -> solve acc (Dim_eq { d1; d2; origin }))
16531656 eqs)
16541657
16551658let % track5_sexp solve_dim_ineq ~(stage : stage ) origin ~(cur : dim ) ~(subr : dim )
@@ -1679,14 +1682,14 @@ let%track5_sexp solve_dim_ineq ~(stage : stage) origin ~(cur : dim) ~(subr : dim
16791682 (" dimension comparison for axis: different labels" , [ Dim_mismatch [ cur; subr ] ])
16801683 | Dim { d = d1 ; _ } , Dim { d = d2 ; _ } when d1 = d2 -> ([] , env)
16811684 | _ , Dim { d = 1 ; _ } -> ([] , env)
1682- | (Dim { d = 1 ; _ } as cur ), _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin = [] } ], env)
1683- | Conv_input _ , _ | _ , Conv_input _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin = [] } ], env)
1685+ | (Dim { d = 1 ; _ } as cur ), _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
1686+ | Conv_input _ , _ | _ , Conv_input _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
16841687 | Var cur_v , Var subr_v -> (
16851688 match (Map. find env.dim_env cur_v, Map. find env.dim_env subr_v) with
16861689 | Some (Bounds_dim { cur = cur1 ; _ } ), _ when List. mem ~equal: equal_dim_var cur1 subr_v ->
1687- ([ Dim_eq { d1 = cur; d2 = subr; origin = [] } ], env)
1690+ ([ Dim_eq { d1 = cur; d2 = subr; origin } ], env)
16881691 | _ , Some (Bounds_dim { subr = subr2 ; _ } ) when List. mem ~equal: equal_dim_var subr2 cur_v ->
1689- ([ Dim_eq { d1 = cur; d2 = subr; origin = [] } ], env)
1692+ ([ Dim_eq { d1 = cur; d2 = subr; origin } ], env)
16901693 | None , None ->
16911694 ( [] ,
16921695 {
@@ -1915,7 +1918,7 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) origin ~(cur : t) ~(subr : t) (e
19151918 when is_stage6_up stage ->
19161919 ( Row_ineq { cur; subr; origin }
19171920 :: Row_eq
1918- { r1 = row_of_var v id; r2 = { dims = [] ; bcast = Broadcastable ; id }; origin = [] }
1921+ { r1 = row_of_var v id; r2 = { dims = [] ; bcast = Broadcastable ; id }; origin }
19191922 :: ineqs,
19201923 env )
19211924 | cur , subr when equal_row cur subr -> ([] , env)
@@ -2086,8 +2089,8 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) origin ~(cur : t) ~(subr : t) (e
20862089 (* We don't need to add any dimension inequalities, because they'll be captured by the extra
20872090 row inequalities. *)
20882091 ( [
2089- Row_eq { r1 = cur; r2 = template; origin = [] };
2090- Row_ineq { cur = template; subr; origin = [] };
2092+ Row_eq { r1 = cur; r2 = template; origin };
2093+ Row_ineq { cur = template; subr; origin };
20912094 ],
20922095 env )
20932096 | { bcast = Broadcastable ; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l
@@ -2293,7 +2296,7 @@ and eliminate_row_constraint ~depth stage origin ~terminal ~(lub : row option) (
22932296 ( no_further_axes
22942297 :: List. map vs ~f: (fun v ->
22952298 let d2 = get_dim ~d: 1 () in
2296- Dim_eq { d1 = Var v; d2; origin = [] }),
2299+ Dim_eq { d1 = Var v; d2; origin }),
22972300 env )
22982301 | Num_elems d , [] , None when d <> 1 && is_stage3_up stage ->
22992302 let dim = get_dim ~d () in
@@ -2373,7 +2376,7 @@ and eliminate_row_constraint ~depth stage origin ~terminal ~(lub : row option) (
23732376 let _denom : int = denom in
23742377 keep_constr ()
23752378 | _ -> keep_constr () )
2376- | Exact dims -> ([ Row_eq { r1; r2 = { dims; bcast = Broadcastable ; id }; origin = [] } ], env)
2379+ | Exact dims -> ([ Row_eq { r1; r2 = { dims; bcast = Broadcastable ; id }; origin } ], env)
23772380 | Unconstrained -> ([] , env))
23782381
23792382let % track5_sexp close_row_terminal ~(stage : stage ) origin (env : environment )
0 commit comments