Skip to content

Commit bd6c5ef

Browse files
committed
Row provenance TODOs
1 parent 902a2d5 commit bd6c5ef

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

tensor/row.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ type provenance_origin = { sh_id : int; kind : kind } [@@deriving sexp, compare,
160160
(* List of origins, maintained as deduplicated and sorted *)
161161
type provenance = provenance_origin list [@@deriving sexp, compare, equal, hash]
162162

163+
let empty_provenance = []
163164
let provenance ~sh_id ~kind = [ { sh_id; kind } ]
164165

165166
(* Merge two provenances by combining and deduplicating their origins *)
@@ -171,7 +172,7 @@ type t = { dims : dim list; bcast : bcast; prov : provenance }
171172

172173
type row = t [@@deriving equal, sexp]
173174

174-
let get_row_for_var ?(prov = []) v = { dims = []; bcast = Row_var { v; beg_dims = [] }; prov }
175+
let get_row_for_var prov v = { dims = []; bcast = Row_var { v; beg_dims = [] }; prov }
175176

176177
let dims_label_assoc dims =
177178
let f = function Var { label = Some l; _ } as d -> Some (l, d) | _ -> None in
@@ -400,7 +401,7 @@ let collect_factors dims =
400401

401402
let known_dims_product dims = match collect_factors dims with Some (_, []) -> true | _ -> false
402403

403-
let rec row_conjunction ?(prov = []) ~origin stage constr1 constr2 =
404+
let rec row_conjunction ~prov ~origin stage constr1 constr2 =
404405
let elems_mismatch n1 n2 =
405406
raise
406407
@@ Shape_error

tensor/row.mli

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ val dim_to_string : print_style -> dim -> string
4545
type provenance [@@deriving sexp, compare, equal, hash]
4646

4747
val provenance : sh_id:int -> kind:kind -> provenance
48+
val empty_provenance : provenance
4849
val merge_provenance : provenance -> provenance -> provenance
4950

5051
type row_var [@@deriving sexp, compare, equal, hash]
@@ -59,10 +60,11 @@ type bcast =
5960
| Broadcastable (** The shape does not have more axes of this kind, but is "polymorphic". *)
6061
[@@deriving equal, hash, compare, sexp, variants]
6162

62-
type t = { dims : dim list; bcast : bcast; prov : provenance } [@@deriving equal, hash, compare, sexp]
63+
type t = { dims : dim list; bcast : bcast; prov : provenance }
64+
[@@deriving equal, hash, compare, sexp]
6365

6466
val dims_label_assoc : t -> (string * dim) list
65-
val get_row_for_var : ?prov:provenance -> row_var -> t
67+
val get_row_for_var : provenance -> row_var -> t
6668

6769
type environment [@@deriving sexp_of]
6870

tensor/shape.ml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,7 @@ let set_dim delayed_var_ref dim =
12671267
{
12681268
(* TODO: actually, the Row.provenance should be the one of the shape that the row variable
12691269
is in, should be stored in `Row and in env_row_var. *)
1270-
r = [ Row.get_row_for_var row_var ];
1270+
r = [ Row.get_row_for_var Row.empty_provenance row_var ];
12711271
constr = Total_elems { numerator = Num_elems dim; divided_by = [] };
12721272
origin =
12731273
[
@@ -1283,6 +1283,7 @@ let set_dim delayed_var_ref dim =
12831283
:: !active_constraints
12841284

12851285
let set_equal delayed_ref1 delayed_ref2 =
1286+
(* TODO: use provenance from the row variables once we have it there. *)
12861287
match (delayed_ref1, delayed_ref2) with
12871288
| { var_ref = { solved_dim = Some dim1; _ }; _ }, { var_ref = { solved_dim = Some dim2; _ }; _ }
12881289
->
@@ -1331,8 +1332,8 @@ let set_equal delayed_ref1 delayed_ref2 =
13311332
active_constraints :=
13321333
Row.Row_eq
13331334
{
1334-
r1 = Row.get_row_for_var row_var1;
1335-
r2 = Row.get_row_for_var row_var2;
1335+
r1 = Row.get_row_for_var Row.empty_provenance row_var1;
1336+
r2 = Row.get_row_for_var Row.empty_provenance row_var2;
13361337
origin =
13371338
[
13381339
{
@@ -1351,7 +1352,7 @@ let set_equal delayed_ref1 delayed_ref2 =
13511352
active_constraints :=
13521353
Row.Rows_constr
13531354
{
1354-
r = [ Row.get_row_for_var row_var ];
1355+
r = [ Row.get_row_for_var Row.empty_provenance row_var ];
13551356
constr =
13561357
Total_elems
13571358
{

0 commit comments

Comments
 (0)