Skip to content

Commit

Permalink
Merge pull request #213 from FissoreD/discr-tree-multiple-args
Browse files Browse the repository at this point in the history
Discrimination tree on multiple args
  • Loading branch information
gares committed Dec 7, 2023
2 parents 3ebfce2 + 1350da7 commit e151123
Show file tree
Hide file tree
Showing 15 changed files with 300 additions and 149 deletions.
36 changes: 24 additions & 12 deletions src/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,16 @@ end = struct (* {{{ *)
| Some Structured.External -> duplicate_err "external"
| Some _ -> error ~loc "external predicates cannot be indexed"
end
| Index i :: rest ->
| Index(i,index_type) :: rest ->
let it =
match index_type with
| None -> None
| Some "Map" -> Some Map
| Some "Hash" -> Some HashMap
| Some "DTree" -> Some DiscriminationTree
| Some s -> error ~loc ("unknown indexing directive " ^ s) in
begin match r with
| None -> aux (Some (Structured.Index i)) rest
| None -> aux (Some (Structured.Index(i,it))) rest
| Some (Structured.Index _) -> duplicate_err "index"
| Some _ -> error ~loc "external predicates cannot be indexed"
end
Expand All @@ -643,7 +650,7 @@ end = struct (* {{{ *)
let attributes = aux None attributes in
let attributes =
match attributes with
| None -> Structured.Index [1]
| None -> Structured.Index([1],None)
| Some x -> x in
{ Type.attributes; loc; name; ty }

Expand Down Expand Up @@ -2388,17 +2395,22 @@ let compile_clause modes initial_depth state
if morelcs <> 0 then error ~loc "sigma in a toplevel clause is not supported";
state, cl

let chose_indexing state predicate l =
let chose_indexing state predicate l k =
let all_zero = List.for_all ((=) 0) in
let rec aux argno = function
let rec check_map default argno = function
(* TODO: @FissoreD here we should raise an error if n > arity of the predicate? *)
| [] -> error ("Wrong indexing for " ^ Symbols.show state predicate)
| 0 :: l -> aux (argno+1) l
| [] -> error ("Wrong indexing for " ^ Symbols.show state predicate ^ ": no argument selected.")
| 0 :: l -> check_map default (argno+1) l
| 1 :: l when all_zero l -> MapOn argno
| path_depth :: l when all_zero l -> Trie { argno ; path_depth }
| _ -> Hash l
| _ -> default ()
in
aux 0 l
match k with
| Some Ast.Structured.DiscriminationTree -> DiscriminationTree l
| Some HashMap -> Hash l
| None -> check_map (fun () -> DiscriminationTree l) 0 l
| Some Map -> check_map (fun () ->
error ("Wrong indexing for " ^ Symbols.show state predicate ^
": Map indexes exactly one argument at depth 1")) 0 l

let check_rule_pattern_in_clique state clique { D.CHR.pattern; rule_name } =
try
Expand Down Expand Up @@ -2436,8 +2448,8 @@ let run
let mode = try C.Map.find name modes with Not_found -> [] in
let declare_index, index =
match tindex with
| Some (Ast.Structured.Index l) -> true, chose_indexing state name l
| _ -> false, chose_indexing state name [1] in
| Some (Ast.Structured.Index(l,k)) -> true, chose_indexing state name l k
| _ -> false, chose_indexing state name [1] None in
try
let _, old_tindex = C.Map.find name map in
if old_tindex <> index then
Expand Down
11 changes: 6 additions & 5 deletions src/data.ml
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ and second_lvl_idx =
time : int; (* time is used to recover the total order *)
args_idx : (clause * int) list Ptmap.t; (* clause, insertion time *)
}
| IndexWithTrie of {
| IndexWithDiscriminationTree of {
mode : mode;
argno : int; (* position of argument on which the trie is built *)
path_depth : int; (* depth bound at which the term is inspected *)
arg_depths : int list; (* the list of args on which the trie is built *)
time : int; (* time is used to recover the total order *)
args_idx : clause DT.t;
}
Expand All @@ -192,12 +191,14 @@ type suspended_goal = {
P. Indexing is done by hashing all the parameters with a non
zero depth and comparing it with the hashing of the parameters
of the query
- [IndexWithTrie N D] -> N-th argument at D depth
- [DiscriminationTree L] ->
we use the same logic of Hash, except we use DiscriminationTree to discriminate
clauses
*)
type indexing =
| MapOn of int
| Hash of int list
| Trie of { argno : int; path_depth : int }
| DiscriminationTree of int list
[@@deriving show]

let mkLam x = Lam x [@@inline]
Expand Down
70 changes: 39 additions & 31 deletions src/discrimination_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ let kConstant = 0
let kPrimitive = 1
let kVariable = 2
let kOther = 3

let k_lshift = Sys.int_size - k_bits
let ka_lshift = Sys.int_size - k_bits - arity_bits
let k_mask = ((1 lsl k_bits) - 1) lsl k_lshift
Expand All @@ -25,33 +24,40 @@ let arity_of n =
let mkConstant ~safe c a =
let rc = encode kConstant c a in
if safe && (abs c > data_mask || a >= 1 lsl arity_bits) then
Elpi_util.Util.anomaly (Printf.sprintf "Indexing at depth > 1 is unsupported since constant %d/%d is too large or wide" c a);
Elpi_util.Util.anomaly
(Printf.sprintf "Indexing at depth > 1 is unsupported since constant %d/%d is too large or wide" c a);
rc

let mkVariable = encode kVariable 0 0
let mkOther = encode kOther 0 0
let mkPrimitive c = encode kPrimitive (Elpi_util.Util.CData.hash c lsl k_bits) 0

let () = assert(k_of (mkConstant ~safe:false ~-17 0) == kConstant)
let () = assert(k_of mkVariable == kVariable)
let () = assert(k_of mkOther == kOther)

let isVariable x = k_of x == kVariable
let isOther x = k_of x == kOther
let mkInputMode = encode kOther 1 0
let mkOutputMode = encode kOther 2 0
let () = assert (k_of (mkConstant ~safe:false ~-17 0) == kConstant)
let () = assert (k_of mkVariable == kVariable)
let () = assert (k_of mkOther == kOther)
let isVariable x = x == mkVariable
let isOther x = x == mkOther
let isInput x = x == mkInputMode
let isOutput x = x == mkOutputMode

type cell = int

let pp_cell fmt n =
let k = k_of n in
if k == kConstant then
let data = data_mask land n in
let arity = (arity_mask land n) lsr ka_lshift in
Format.fprintf fmt "Constant(%d,%d)" data arity
else if k == kVariable then Format.fprintf fmt "Variable"
else if k == kOther then Format.fprintf fmt "Other"
else if k == kOther then
if isInput n then Format.fprintf fmt "Input"
else if isOutput n then Format.fprintf fmt "Output"
else Format.fprintf fmt "Other"
else if k == kPrimitive then Format.fprintf fmt "Primitive"
else Format.fprintf fmt "%o" k

let show_cell n =
Format.asprintf "%a" pp_cell n
let show_cell n = Format.asprintf "%a" pp_cell n

module Trie = struct
(*
Expand Down Expand Up @@ -84,9 +90,7 @@ module Trie = struct
['a t Ptmap.t]. The empty trie is just the empty map. *)

type key = int list

type 'a t =
| Node of { data : 'a list; other : 'a t option; map : 'a t Ptmap.t }
type 'a t = Node of { data : 'a list; other : 'a t option; map : 'a t Ptmap.t }

let empty = Node { data = []; other = None; map = Ptmap.empty }

Expand Down Expand Up @@ -131,7 +135,12 @@ module Trie = struct
Format.fprintf fmt "} other:{";
(match other with None -> () | Some m -> pp ppelem fmt m);
Format.fprintf fmt "} key:{";
Ptmap.to_list map |> Elpi_util.Util.pplist (fun fmt (k,v) -> pp_cell fmt k; pp ppelem fmt v) "; " fmt;
Ptmap.to_list map
|> Elpi_util.Util.pplist
(fun fmt (k, v) ->
pp_cell fmt k;
pp ppelem fmt v)
"; " fmt;
Format.fprintf fmt "}]"

let show (fmt : Format.formatter -> 'a -> unit) (n : 'a t) : string =
Expand All @@ -146,11 +155,7 @@ let compare x y = x - y

let skip (path : path) : path =
let rec aux arity path =
if arity = 0 then path
else
match path with
| [] -> assert false
| m :: tl -> aux (arity - 1 + arity_of m) tl
if arity = 0 then path else match path with [] -> assert false | m :: tl -> aux (arity - 1 + arity_of m) tl
in
match path with
| [] -> Elpi_util.Util.anomaly "Skipping empty path is not possible"
Expand Down Expand Up @@ -197,31 +202,34 @@ let rec merge (l1 : ('a * int) list) l2 =
| ((_, tx) as x) :: xs, (_, ty) :: _ when tx > ty -> x :: merge xs l2
| _, y :: ys -> y :: merge l1 ys

let get_all_children v mode = isOther v || (isVariable v && Elpi_util.Util.Output == mode)
let get_all_children v mode = isOther v || (isVariable v && isOutput mode)

(* get_all_children returns if a key should be unified with all the values of
the current sub-tree. This key should be either K.to_unfy or K.variable.
In the latter case, the mode boolean to be true (we are in output mode). *)
let rec retrieve_aux mode path = function
let rec retrieve_aux (mode : cell) path = function
| [] -> []
| hd :: tl -> merge (retrieve mode path hd) (retrieve_aux mode path tl)

and retrieve mode path tree =
match (tree, path) with
| Trie.Node { data }, [] -> data
| Trie.Node { other; map }, v :: path when get_all_children v mode ->
retrieve_aux mode path (all_children other map)
| node, hd :: tl when isInput hd || isOutput hd -> retrieve hd tl tree
| Trie.Node { other; map }, v :: path when get_all_children v mode -> retrieve_aux mode path (all_children other map)
| Trie.Node { other = None; map }, node :: sub_path ->
if mode == Elpi_util.Util.Input && isVariable node then []
if isInput mode && isVariable node then []
else
let subtree = try Ptmap.find node map with Not_found -> Trie.empty in
retrieve mode sub_path subtree
| Trie.Node { other = Some other; map }, (node :: sub_path as path) ->
merge
(if mode == Elpi_util.Util.Input && isVariable node then []
else
let subtree = try Ptmap.find node map with Not_found -> Trie.empty in
retrieve mode sub_path subtree)
(if isInput mode && isVariable node then []
else
let subtree = try Ptmap.find node map with Not_found -> Trie.empty in
retrieve mode sub_path subtree)
(retrieve mode (skip path) other)

let retrieve mode path index = retrieve mode path index |> List.map fst
let retrieve path index =
match path with
| [] -> Elpi_util.Util.anomaly "A path should at least of length 2"
| mode :: tl -> retrieve mode tl index |> List.map fst
5 changes: 3 additions & 2 deletions src/parser/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ type raw_attribute =
| Before of string
| Replace of string
| External
| Index of int list
| Index of int list * string option
[@@deriving show]

module Clause = struct
Expand Down Expand Up @@ -307,7 +307,8 @@ and cattribute = {
}
and tattribute =
| External
| Index of int list
| Index of int list * tindex option
and tindex = Map | HashMap | DiscriminationTree
and 'a shorthand = {
iloc : Loc.t;
full_name : 'a;
Expand Down
5 changes: 3 additions & 2 deletions src/parser/ast.mli
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type raw_attribute =
| Before of string
| Replace of string
| External
| Index of int list
| Index of int list * string option
[@@ deriving show]

module Clause : sig
Expand Down Expand Up @@ -213,7 +213,8 @@ and cattribute = {
}
and tattribute =
| External
| Index of int list
| Index of int list * tindex option
and tindex = Map | HashMap | DiscriminationTree
and 'a shorthand = {
iloc : Loc.t;
full_name : 'a;
Expand Down
2 changes: 1 addition & 1 deletion src/parser/grammar.mly
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ attribute:
| BEFORE; s = STRING { Before s }
| REPLACE; s = STRING { Replace s }
| EXTERNAL { External }
| INDEX; LPAREN; l = nonempty_list(indexing) ; RPAREN { Index l }
| INDEX; LPAREN; l = nonempty_list(indexing) ; RPAREN; o = option(STRING) { Index (l,o) }

indexing:
| FRESHUV { 0 }
Expand Down
31 changes: 24 additions & 7 deletions src/parser/test_parser.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ let error s a1 a2 =
let f2 = Filename.temp_file "parser_out" "txt" in
let oc1 = open_out f1 in
let oc2 = open_out f2 in
output_string oc1 "new:\n";
output_string oc1 "\nnew:\n";
output_string oc1 (Program.show a1);
output_string oc2 "reference:\n";
output_string oc2 "\nreference:\n";
output_string oc2 (Program.show a2);
flush_all ();
close_out oc1;
Expand Down Expand Up @@ -43,9 +43,9 @@ let test s x y w z att b =
let p = Parser.program_from ~loc lexbuf in
if p <> exp then
error s p exp
with Parse.ParseError(loc,message) ->
Printf.eprintf "error parsing '%s' at %s\n%s%!" s (Loc.show loc) message;
exit 1
with Parse.ParseError(loc,message) ->
Printf.eprintf "error parsing '%s' at %s\n%s%!" s (Loc.show loc) message;
exit 1

let testR s x y w z attributes to_match to_remove guard new_goal =
let exp = [Program.(Chr { Chr.to_match; to_remove; guard; new_goal; loc=(mkLoc x y w z); attributes })] in
Expand All @@ -55,10 +55,25 @@ let testR s x y w z attributes to_match to_remove guard new_goal =
let p = Parser.program_from ~loc lexbuf in
if p <> exp then
error s p exp
with Parse.ParseError(loc,message) ->
with Parse.ParseError(loc,message) ->
Printf.eprintf "error parsing '%s' at %s\n%s%!" s (Loc.show loc) message;
exit 1

let testT s x y w z attributes () =
let lexbuf = Lexing.from_string s in
let loc = Loc.initial "(input)" in
try
let p = Parser.program_from ~loc lexbuf in
match p with
| [Program.Pred _] -> ()
| [Program.Type _] -> ()
| _ ->
Printf.eprintf "error parsing '%s' at %s\n%s%!" s (Loc.show loc) "not a type declaration";
exit 1
with Parse.ParseError(loc,message) ->
Printf.eprintf "error parsing '%s' at %s\n%s%!" s (Loc.show loc) message;
exit 1

let testF s i msg =
let lexbuf = Lexing.from_string s in
let loc = Loc.initial "(input)" in
Expand Down Expand Up @@ -136,6 +151,8 @@ let _ =
testF "x. x]" 5 "unexpected keyword";
testF "x. +" 4 "unexpected start";
test ":name \"x\" x." 0 11 1 0 [Name "x"] (c"x");
testT ":index (1) \"foobar\" pred x." 0 11 1 0 [Index ([1],Some "foobar")] ();
testT ":index (1) pred x." 0 11 1 0 [Index ([1], None)] ();
testF "p :- g (f x) \\ y." 14 ".*bind.*must follow.*name.*";
testF "foo i:term, o:term. foo A B :- A = [B]." 6 "unexpected keyword";
(* 01234567890123456789012345 *)
Expand Down

0 comments on commit e151123

Please sign in to comment.