Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discrimination tree on multiple args #213

Merged
merged 5 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2395,10 +2395,10 @@ let chose_indexing state predicate l =
| [] -> error ("Wrong indexing for " ^ Symbols.show state predicate)
| 0 :: l -> aux (argno+1) l
| 1 :: l when all_zero l -> MapOn argno
| path_depth :: l when all_zero l -> Trie { argno ; path_depth }
| _ -> Hash l
in
aux 0 l
| _ -> Trie l
(* TODO: @FissoreD we should add some syntax if we don't want to lose the indexing with Hash *)
(* | _ -> Hash l *)
in aux 0 l

let check_rule_pattern_in_clique state clique { D.CHR.pattern; rule_name } =
try
Expand Down
8 changes: 4 additions & 4 deletions src/data.ml
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ and second_lvl_idx =
}
| IndexWithTrie of {
mode : mode;
argno : int; (* position of argument on which the trie is built *)
path_depth : int; (* depth bound at which the term is inspected *)
arg_depths : int list; (* the list of args on which the trie is built *)
time : int; (* time is used to recover the total order *)
args_idx : clause DT.t;
}
Expand All @@ -192,12 +191,13 @@ type suspended_goal = {
P. Indexing is done by hashing all the parameters with a non
zero depth and comparing it with the hashing of the parameters
of the query
- [IndexWithTrie N D] -> N-th argument at D depth
- [Trie L] -> same logic as Hash, except that a Trie is used to discriminate
clauses
*)
type indexing =
| MapOn of int
| Hash of int list
| Trie of { argno : int; path_depth : int }
| Trie of int list
[@@deriving show]

let mkLam x = Lam x [@@inline]
Expand Down
70 changes: 39 additions & 31 deletions src/discrimination_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ let kConstant = 0
let kPrimitive = 1
let kVariable = 2
let kOther = 3

let k_lshift = Sys.int_size - k_bits
let ka_lshift = Sys.int_size - k_bits - arity_bits
let k_mask = ((1 lsl k_bits) - 1) lsl k_lshift
Expand All @@ -25,33 +24,40 @@ let arity_of n =
let mkConstant ~safe c a =
let rc = encode kConstant c a in
if safe && (abs c > data_mask || a >= 1 lsl arity_bits) then
Elpi_util.Util.anomaly (Printf.sprintf "Indexing at depth > 1 is unsupported since constant %d/%d is too large or wide" c a);
Elpi_util.Util.anomaly
(Printf.sprintf "Indexing at depth > 1 is unsupported since constant %d/%d is too large or wide" c a);
rc

let mkVariable = encode kVariable 0 0
let mkOther = encode kOther 0 0
let mkPrimitive c = encode kPrimitive (Elpi_util.Util.CData.hash c lsl k_bits) 0

let () = assert(k_of (mkConstant ~safe:false ~-17 0) == kConstant)
let () = assert(k_of mkVariable == kVariable)
let () = assert(k_of mkOther == kOther)

let isVariable x = k_of x == kVariable
let isOther x = k_of x == kOther
(* Mode marker cells: same kOther kind tag as mkOther, but built with 1
   (input) and 2 (output) instead of 0 as second argument of [encode].
   NOTE(review): [encode] is not visible here; this presumably makes them
   distinct from mkOther and from each other -- confirm. *)
let mkInputMode = encode kOther 1 0
let mkOutputMode = encode kOther 2 0
(* Sanity checks run when the module is initialised: [k_of] must give back
   the kind tag a cell was built with (also for a negative constant). *)
let () = assert (k_of (mkConstant ~safe:false ~-17 0) == kConstant)
let () = assert (k_of mkVariable == kVariable)
let () = assert (k_of mkOther == kOther)
(* Cell predicates: each one compares the whole cell against one specific
   value, not only its kind tag. Cells are ints, so [==] is plain integer
   equality here. *)
let isVariable x = x == mkVariable
let isOther x = x == mkOther
let isInput x = x == mkInputMode
let isOutput x = x == mkOutputMode

type cell = int

let pp_cell fmt n =
let k = k_of n in
if k == kConstant then
let data = data_mask land n in
let arity = (arity_mask land n) lsr ka_lshift in
Format.fprintf fmt "Constant(%d,%d)" data arity
else if k == kVariable then Format.fprintf fmt "Variable"
else if k == kOther then Format.fprintf fmt "Other"
else if k == kOther then
if isInput n then Format.fprintf fmt "Input"
else if isOutput n then Format.fprintf fmt "Output"
else Format.fprintf fmt "Other"
else if k == kPrimitive then Format.fprintf fmt "Primitive"
else Format.fprintf fmt "%o" k

let show_cell n =
Format.asprintf "%a" pp_cell n
let show_cell n = Format.asprintf "%a" pp_cell n

module Trie = struct
(*
Expand Down Expand Up @@ -84,9 +90,7 @@ module Trie = struct
['a t Ptmap.t]. The empty trie is just the empty map. *)

type key = int list

type 'a t =
| Node of { data : 'a list; other : 'a t option; map : 'a t Ptmap.t }
type 'a t = Node of { data : 'a list; other : 'a t option; map : 'a t Ptmap.t }

let empty = Node { data = []; other = None; map = Ptmap.empty }

Expand Down Expand Up @@ -131,7 +135,12 @@ module Trie = struct
Format.fprintf fmt "} other:{";
(match other with None -> () | Some m -> pp ppelem fmt m);
Format.fprintf fmt "} key:{";
Ptmap.to_list map |> Elpi_util.Util.pplist (fun fmt (k,v) -> pp_cell fmt k; pp ppelem fmt v) "; " fmt;
Ptmap.to_list map
|> Elpi_util.Util.pplist
(fun fmt (k, v) ->
pp_cell fmt k;
pp ppelem fmt v)
"; " fmt;
Format.fprintf fmt "}]"

let show (fmt : Format.formatter -> 'a -> unit) (n : 'a t) : string =
Expand All @@ -146,11 +155,7 @@ let compare x y = x - y

let skip (path : path) : path =
let rec aux arity path =
if arity = 0 then path
else
match path with
| [] -> assert false
| m :: tl -> aux (arity - 1 + arity_of m) tl
if arity = 0 then path else match path with [] -> assert false | m :: tl -> aux (arity - 1 + arity_of m) tl
in
match path with
| [] -> Elpi_util.Util.anomaly "Skipping empty path is not possible"
Expand Down Expand Up @@ -197,31 +202,34 @@ let rec merge (l1 : ('a * int) list) l2 =
| ((_, tx) as x) :: xs, (_, ty) :: _ when tx > ty -> x :: merge xs l2
| _, y :: ys -> y :: merge l1 ys

let get_all_children v mode = isOther v || (isVariable v && Elpi_util.Util.Output == mode)
let get_all_children v mode = isOther v || (isVariable v && isOutput mode)

(* get_all_children tells whether a key should be unified with all the values
of the current sub-tree. This is the case when the key is either mkOther or
mkVariable. In the latter case, the mode must also be output. *)
let rec retrieve_aux mode path = function
let rec retrieve_aux (mode : cell) path = function
| [] -> []
| hd :: tl -> merge (retrieve mode path hd) (retrieve_aux mode path tl)

and retrieve mode path tree =
match (tree, path) with
| Trie.Node { data }, [] -> data
| Trie.Node { other; map }, v :: path when get_all_children v mode ->
retrieve_aux mode path (all_children other map)
| node, hd :: tl when isInput hd || isOutput hd -> retrieve hd tl tree
| Trie.Node { other; map }, v :: path when get_all_children v mode -> retrieve_aux mode path (all_children other map)
| Trie.Node { other = None; map }, node :: sub_path ->
if mode == Elpi_util.Util.Input && isVariable node then []
if isInput mode && isVariable node then []
else
let subtree = try Ptmap.find node map with Not_found -> Trie.empty in
retrieve mode sub_path subtree
| Trie.Node { other = Some other; map }, (node :: sub_path as path) ->
merge
(if mode == Elpi_util.Util.Input && isVariable node then []
else
let subtree = try Ptmap.find node map with Not_found -> Trie.empty in
retrieve mode sub_path subtree)
(if isInput mode && isVariable node then []
else
let subtree = try Ptmap.find node map with Not_found -> Trie.empty in
retrieve mode sub_path subtree)
(retrieve mode (skip path) other)

let retrieve mode path index = retrieve mode path index |> List.map fst
(** [retrieve path index] looks [path] up in [index] and returns the matching
    values.
    The first cell of [path] is used as the mode under which the remaining
    cells are looked up. The underlying lookup yields pairs; only their first
    component is returned.
    Raises an anomaly when [path] is empty, because there is then no leading
    mode cell. A path made of the mode cell alone is accepted. *)
let retrieve path index =
  match path with
  (* The check only rejects the empty path, so the message states a minimum
     length of 1 (the leading mode cell). *)
  | [] -> Elpi_util.Util.anomaly "A path should be at least of length 1"
  | mode :: tl -> retrieve mode tl index |> List.map fst
141 changes: 78 additions & 63 deletions src/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2429,57 +2429,74 @@ let hash_clause_arg_list = hash_arg_list false
let hash_goal_arg_list = hash_arg_list true

(**
[arg_to_trie_path_aux ~depth t_list path_depth]
Takes a list of terms and builds the path representing this list with
height limited to [depth].
[arg_to_trie_path ~safe ~depth is_goal args arg_depths mode]
returns the path representation of a term to be used in indexing with a trie.
[args], [arg_depths] and [mode] are the lists of, respectively, the arguments,
the depths and the modes of the current term to be indexed.
[is_goal] tells whether we are encoding the path for instance retrieval or
for clause insertion in the trie.
In the former case, we add a special mkInputMode/mkOutputMode node before
each argument to be indexed. This special node is used during instance
retrieval to deal with the input/output mode of the considered argument.
*)
let rec arg_to_trie_path_aux ~safe ~depth t_list path_depth : Discrimination_tree.path =
if path_depth = 0 then []
else
match t_list with
| [] -> []
| hd :: tl ->
let hd_path = arg_to_trie_path ~safe ~depth hd path_depth in
let tl_path = arg_to_trie_path_aux ~safe ~depth tl path_depth in
hd_path @ tl_path
(**
[arg_to_trie_path ~depth t path_depth]
Takes a [term] and returns it path representation with height bound by [path_depth]
*)
and arg_to_trie_path ~safe ~depth t path_depth : Discrimination_tree.path =
let arg_to_trie_path ~safe ~depth is_goal args arg_depths mode : Discrimination_tree.path =
let open Discrimination_tree in
if path_depth = 0 then []
else
let path_depth = path_depth - 1 in
match deref_head ~depth t with
| Const k when k == Global_symbols.uvarc -> [mkVariable]
| Const k when safe -> [mkConstant ~safe k 0]
| Const k -> [mkConstant ~safe k 0]
| CData d -> [mkPrimitive d]
| App (k,_,_) when k == Global_symbols.uvarc -> [mkVariable]
| App (k,a,_) when k == Global_symbols.asc -> arg_to_trie_path ~safe ~depth a (path_depth+1)
| Nil -> [mkConstant ~safe Global_symbols.nilc 0]
| Lam _ -> [mkOther] (* loose indexing to enable eta *)
| Arg _ | UVar _ | AppArg _ | AppUVar _ | Discard -> [mkVariable]
| Builtin (k,tl) ->
let path = arg_to_trie_path_aux ~safe ~depth tl path_depth in
mkConstant ~safe k (if path_depth = 0 then 0 else List.length tl) :: path
| App (k, x, xs) ->
let arg_length = if path_depth = 0 then 0 else List.length xs + 1 in
let hd_path = arg_to_trie_path ~safe ~depth x path_depth in
let tl_path = arg_to_trie_path_aux ~safe ~depth xs path_depth in
mkConstant ~safe k arg_length :: hd_path @ tl_path
| Cons (x,xs) ->
let hd_path = arg_to_trie_path ~safe ~depth x path_depth in
let tl_path = arg_to_trie_path ~safe ~depth xs path_depth in
mkConstant ~safe Global_symbols.consc (if path_depth = 0 then 0 else 2) :: hd_path @ tl_path

(**
[arg_to_trie_path ~path_depth ~depth t]
Take a term and returns its path representation up to path_depth
*)
let arg_to_trie_path ~safe ~path_depth ~depth t =
arg_to_trie_path ~safe ~depth t path_depth
(** prepend the mode of the current argument if we are "pathifing" a goal *)
let prepend_mode is_goal mode tl = if is_goal then mode :: tl else tl in
(** gives the path representation of a list of sub-terms *)
let rec arg_to_trie_path_aux ~safe ~depth t_list path_depth : Discrimination_tree.path =
if path_depth = 0 then []
else
match t_list with
| [] -> []
| hd :: tl ->
let hd_path = arg_to_trie_path ~safe ~depth hd path_depth in
let tl_path = arg_to_trie_path_aux ~safe ~depth tl path_depth in
hd_path @ tl_path
(** gives the path representation of a term *)
and arg_to_trie_path ~safe ~depth t path_depth : Discrimination_tree.path =
let open Discrimination_tree in
if path_depth = 0 then []
else
let path_depth = path_depth - 1 in
match deref_head ~depth t with
| Const k when k == Global_symbols.uvarc -> [mkVariable]
| Const k when safe -> [mkConstant ~safe k 0]
| Const k -> [mkConstant ~safe k 0]
| CData d -> [mkPrimitive d]
| App (k,_,_) when k == Global_symbols.uvarc -> [mkVariable]
| App (k,a,_) when k == Global_symbols.asc -> arg_to_trie_path ~safe ~depth a (path_depth+1)
| Nil -> [mkConstant ~safe Global_symbols.nilc 0]
| Lam _ -> [mkOther] (* loose indexing to enable eta *)
| Arg _ | UVar _ | AppArg _ | AppUVar _ | Discard -> [mkVariable]
| Builtin (k,tl) ->
let path = arg_to_trie_path_aux ~safe ~depth tl path_depth in
mkConstant ~safe k (if path_depth = 0 then 0 else List.length tl) :: path
| App (k, x, xs) ->
let arg_length = if path_depth = 0 then 0 else List.length xs + 1 in
let hd_path = arg_to_trie_path ~safe ~depth x path_depth in
let tl_path = arg_to_trie_path_aux ~safe ~depth xs path_depth in
mkConstant ~safe k arg_length :: hd_path @ tl_path
| Cons (x,xs) ->
let hd_path = arg_to_trie_path ~safe ~depth x path_depth in
let tl_path = arg_to_trie_path ~safe ~depth xs path_depth in
mkConstant ~safe Global_symbols.consc (if path_depth = 0 then 0 else 2) :: hd_path @ tl_path
(** builds the sub-path of a sublist of arguments of the current clause *)
and make_sub_path arg_hd arg_tl arg_depth_hd arg_depth_tl mode_hd mode_tl =
let tl = arg_to_trie_path ~safe ~depth arg_hd arg_depth_hd @
aux ~safe ~depth is_goal arg_tl arg_depth_tl mode_tl in
prepend_mode is_goal (match mode_hd with Input -> mkInputMode | _ -> mkOutputMode) tl
(** main function: build the path of the arguments received in entry *)
and aux ~safe ~depth is_goal args arg_depths mode : Discrimination_tree.path =
match args, arg_depths, mode with
| _, [], _ -> []
| arg_hd :: arg_tl, arg_depth_hd :: arg_depth_tl, [] ->
make_sub_path arg_hd arg_tl arg_depth_hd arg_depth_tl Output []
| arg_hd :: arg_tl, arg_depth_hd :: arg_depth_tl, mode_hd :: mode_tl ->
make_sub_path arg_hd arg_tl arg_depth_hd arg_depth_tl mode_hd mode_tl
| _, _ :: _,_ -> anomaly "Invalid Index length" in
if args == [] then prepend_mode is_goal mkOutputMode []
else aux ~safe ~depth is_goal args arg_depths mode

let add1clause ~depth m (predicate,clause) =
match Ptmap.find predicate m with
Expand Down Expand Up @@ -2528,11 +2545,11 @@ let add1clause ~depth m (predicate,clause) =
time = time + 1;
args_idx = Ptmap.add hash ((clause,time) :: clauses) args_idx
}) m
| IndexWithTrie {mode; argno; args_idx; time; path_depth } ->
let path = arg_to_trie_path ~safe:true ~depth ~path_depth (match clause.args with [] -> Discard | l -> List.nth l argno) in
| IndexWithTrie {mode; arg_depths; args_idx; time } ->
let path = arg_to_trie_path ~depth ~safe:true false clause.args arg_depths mode in
let dt = DT.index args_idx path clause ~time in
Ptmap.add predicate (IndexWithTrie {
mode; argno; path_depth;
mode; arg_depths;
time = time+1;
args_idx = dt
}) m
Expand Down Expand Up @@ -2583,8 +2600,8 @@ let make_index ~depth ~indexing ~clauses_rev:p =
flex_arg_clauses = [];
arg_idx = Ptmap.empty;
}
| Trie { argno; path_depth } -> IndexWithTrie {
argno; path_depth; mode;
| Trie arg_depths -> IndexWithTrie {
arg_depths; mode;
args_idx = DT.empty;
time = min_int;
}
Expand Down Expand Up @@ -2641,10 +2658,9 @@ let rec nth_not_bool_default l n = match l with
| x :: _ when n = 0 -> x
| _ :: l -> nth_not_bool_default l (n - 1)

let trie_goal_args goal argno : term = match goal with
| Const a when argno = 0 -> goal
| App(k, x, _) when argno = 0 -> x
| App (_, _, xs) -> nth_not_found xs (argno - 1)
(** [trie_goal_args goal] returns the list of arguments of [goal]:
    the empty list for a constant, and the first argument followed by the
    remaining ones for an application.
    Any other shape of term hits [assert false]. *)
let trie_goal_args goal : term list =
  match goal with
  | App (_, first, rest) -> first :: rest
  | Const _ -> []
  | _ -> assert false

let get_clauses ~depth predicate goal { index = m } =
Expand All @@ -2662,14 +2678,13 @@ let get_clauses ~depth predicate goal { index = m } =
let hash = hash_goal_args ~depth mode args goal in
let cl = List.flatten (Ptmap.find_unifiables hash args_idx) in
List.(map fst (sort (fun (_,cl1) (_,cl2) -> cl2 - cl1) cl))
| IndexWithTrie {argno; path_depth; mode; args_idx} ->
let mode_arg = nth_not_bool_default mode argno in
let path = arg_to_trie_path ~safe:false ~depth ~path_depth (trie_goal_args goal argno) in
| IndexWithTrie {arg_depths; mode; args_idx} ->
let path = arg_to_trie_path ~safe:false ~depth true (trie_goal_args goal) arg_depths mode in
[%spy "dev:disc-tree:path" ~rid
Discrimination_tree.pp_path path
pp_int path_depth
(pplist pp_int ";") arg_depths
(*Discrimination_tree.(pp pp_clause) args_idx*)];
let candidates = DT.retrieve mode_arg path args_idx in
let candidates = DT.retrieve path args_idx in
[%spy "dev:disc-tree:candidates" ~rid
pp_int (List.length candidates)];
candidates
Expand Down