Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

[enhance] qmlPatternAnalysis: reduces the complexity of class_merge

the class_merge function (which can be called a number of time proportional to the sum type size),
was for each field of the class :
-flattening the main type,
-then linear searching the right sum case,
-then searching the field type,

now each operation is done once when the information (ty,class,field) is available and the linear search has been replace by index intersection

reduction is at least proportional to the size of the sum type
  • Loading branch information...
commit df9bea6e07c34475614a889e868df0a49d57898d 1 parent f758244
@OpaOnWindowsNow OpaOnWindowsNow authored
Showing with 115 additions and 44 deletions.
  1. +1 −0  libqmlcompil/_tags
  2. +114 −44 libqmlcompil/qmlPatternAnalysis.ml
View
1  libqmlcompil/_tags
@@ -27,5 +27,6 @@
<typer_w/w_Unify.ml> : with_mlstate_debug
<qmlEffects.ml>: with_mlstate_debug
<qmlMoreTypes.ml>: with_mlstate_debug
+<qmlPatternAnalysis.ml>: with_mlstate_debug
<qml2opa.{byte,native}>: thread, use_unix, use_ulex
View
158 libqmlcompil/qmlPatternAnalysis.ml
@@ -79,6 +79,60 @@ let (|>) = InfixOperator.(|>)
module Format = Base.Format
module List = Base.List
+module RecordIndex =
+struct
+ open Loop.Deprecated
+
+ module Array = BaseArray
+ type ('a,'b) index =
+ {
+ keys : 'a array;
+ values : 'b array;
+ map : IntSet.t StringMap.t;
+ size : IntSet.t IntMap.t;
+ }
+
+ let (|?) a b = match a with
+ | Some a -> a
+ | None -> b
+
+ let create entities key value =
+ let entities = Array.of_list entities in
+ let keys = Array.map key entities in
+ let values = Array.map value entities in
+ let map,size = Array.fold_left_i (fun (map,size) fields id ->
+ let size = IntMap.update_default (List.length fields) (IntSet.add id) (IntSet.singleton id) size in
+ let map = l_fold(fields,map)(fun field map ->
+ StringMap.update_default field (IntSet.add id) (IntSet.singleton id) map
+ ) in map,size
+ ) (StringMap.empty,IntMap.empty) keys
+ in
+ (*
+ StringMap.iter (fun field v ->
+ Format.printf "Field %s -> {%a}\n" field (Format.pp_list "," Format.pp_print_int) (IntSet.elements v)
+ ) map;
+ IntMap.iter (fun d v ->
+ Format.printf "Size %d : {%a}\n" d (Format.pp_list "," Format.pp_print_int) (IntSet.elements v)
+ ) size;*)
+ { keys; values; map; size}
+
+ (* is ok with slow and stupid intersection but could benefit from better intersection than a fold *)
+ let get_case strict fields index =
+ let map = index.map in
+ let size = index.size in
+ let to_list index set = IntSet.fold (fun id list -> index.values.(id)::list) set [] in
+ try match fields with
+ | [] -> if strict then to_list index (IntMap.find 0 size) else Array.to_list index.values
+ | hd :: tl ->
+ let len = List.length tl + 1 in
+ let set = l_fold(tl,StringMap.find hd map)(fun field set->
+ IntSet.inter set (StringMap.find field map)
+ ) in
+ let set = if strict then IntSet.inter set (IntMap.find len size) else set in
+ to_list index set
+ with Not_found -> []
+end
+
(* refactoring in progress *)
(* shorthands *)
@@ -699,7 +753,14 @@ module Normalize = struct
exception Local_no_recur
(* TODO Document *)
- let class_merge new_ident _is_joint strict ty (c,l) =
+ let class_merge new_ident _is_joint strict ty =
+ let get_field_type = match ty with
+ | None -> (fun _ _ -> None)
+ | Some ty ->
+ let get = L.strict_get_field_type ty in
+ (fun fields -> let get = get fields in fun field -> Some(get field))
+ in
+ fun (c,l) ->
let strip_names_1 o =
match o with
| Term t -> Term { t with bind = {t.bind with ident =[]}}
@@ -721,17 +782,14 @@ module Normalize = struct
| _ -> debug "ASSERT get_recur: %a\n" (Format.pp_list " |@\n " print) initl; raise Local_no_recur
) in Recur(Option.get id,ty, if List.length l > 1 then Or { cases = List.rev l ; default=None ; ty=ty} else List.hd l)
in
- let get_field_type fields = match ty with
- | None -> (fun _ -> None)
- | Some ty -> fun field -> Some (L.strict_get_field_type ty fields field)
- in
debug "class_merge %a\n" print_class c;
let recurs = try Some(get_recur l) with Local_no_recur -> None in
let ext_ty tys = match ty with None -> tys | Some ty -> ty::tys in
match c with
| CRECORD fields when fields<>[] ->
+ let get_field_type = get_field_type fields in
let id = new_ident "record_to_recurse" in
- let fields_idents_types = l_map(fields)(fun field -> field, new_ident field, get_field_type fields field) in
+ let fields_idents_types = l_map(fields)(fun field -> field, new_ident field, get_field_type field) in
let fields = l_map(fields_idents_types)(fun (f,name,ty)->
let names,tys = get_names_and_types l f in
let tys = match ty with None -> tys | Some ty -> ty::tys in
@@ -804,7 +862,7 @@ let rec class_layer ol =
with Not_found -> None
in
let class_o = l_map(ol)(fun o->PatternClass.from_pattern o,o) in
- let classes = Hashtbl.create 6 in
+ let classes = Hashtbl.create (List.length class_o) in
l_iter(class_o)(hash_add classes);
let classes_patterns = l_filter_map(class_o)(hash_pop classes) in
classes_patterns
@@ -910,7 +968,7 @@ and or_cases ~path ?(recurse_todo=[]) ty new_ident cases =
let any =
debug "TODO : I should do completion verification here\n";
(* completion verification *)
- let missing = if is_const_cases then
+ let missing() = if is_const_cases then
let missings= if strict = [] then (assert (any<>[]); [] ) else L.get_missing (List.map (function _,Term {pat = Cst c} -> c | _ -> assert false) strict) in
if missings <> [] || strict = [] then (
let exc = Missing_const_case(path,missings) in
@@ -937,10 +995,10 @@ and or_cases ~path ?(recurse_todo=[]) ty new_ident cases =
Some exc)
else None
) in
- match missing with
+ if any <> [] then any
+ else match missing() with
| Some exc ->
- debug "Appart 'any', incomplete pattern\n";
- if any = [] then (
+ debug "Appart 'any', incomplete pattern\n";
raise_public exc;
debug "Adding failure branch\n";
let buf = Buffer.create 10 in
@@ -950,7 +1008,6 @@ and or_cases ~path ?(recurse_todo=[]) ty new_ident cases =
Format.pp_print_flush fmt ();
let str = Buffer.contents buf in
[max_int,term ~e:(Failure str) ()]
- ) else any
| None -> any
in
(* when any is alone, it should not be strictified, otherwise you can make infinite pattern for recursives types *)
@@ -1183,25 +1240,13 @@ struct
in aux first (List.tl l) |> List.rev
+ let add_opt opt list =
+ match opt with
+ | Some e -> e :: list
+ | None -> list
- let rec get_type_cases ty fields =
- let row _unstrict (fields:string list) (field_ty:(string*ty) list) :ty list =
- (* SLOW : see if it matters *)
-
- assert(ordered2 (List.sort compare field_ty));
- assert(ordered (List.sort compare fields));
- try
- let fields_of_type = l_map_sort(field_ty)(fst) in
- let field_type_from_fields = l_map_sort(fields)(fun f->f,Private.raw_get_field_type field_ty f) in
- match fields_completion (Private.raw_get_field_type field_ty) fields_of_type field_type_from_fields [] with
- | Some fields ->
- let fields = List.sort compare fields in
- assert(ordered2 fields);
- [Q.TypeRecord(Q.TyRow(fields,_unstrict) )]
- | None -> []
- with Failure _ -> []
- in
- let rec self ty fields =
+ let get_filtered_type_record ty filtered =
+ let rec self ty acc=
match ty with
(* trivial cases *)
| Q.TypeVar _
@@ -1211,25 +1256,44 @@ struct
-> []
(* border line cases *)
- | Q.TypeForall(_,_,_,ty) -> self ty fields
+ | Q.TypeForall(_,_,_,ty) -> self ty acc
(* gamma aware cases *)
| Q.TypeName(params, ident) ->
let gamma = (typer_env ()).gamma in
let ty = QmlTypesUtils.Inspect.find_and_specialize gamma ident params in
- self ty fields
+ self ty acc
(* interesting cases *)
| Q.TypeRecord( Q.TyRow(field_ty_l,_unstrict) ) ->
(* assert( _unstrict=None ); *)
- row _unstrict fields field_ty_l
+ add_opt (filtered _unstrict field_ty_l) acc
+
+ | Q.TypeSum( Q.TyCol(row_l,_unstrict) ) -> l_fold(row_l,acc)(fun field_ty_l acc->add_opt (filtered None field_ty_l) acc )
- | Q.TypeSum( Q.TyCol(row_l,_unstrict) ) -> l_map_flat(row_l)(row None fields)
+ | Q.TypeSumSugar ty_l -> l_fold(ty_l,acc)(self)
- | Q.TypeSumSugar ty_l -> l_map_flat(ty_l)(fun ty->self ty fields)
+ in self ty []
- in self ty fields
+ let rec get_type_record ty fields =
+ let row (fields:string list) _unstrict (field_ty:(string*ty) list) :ty option =
+ (* SLOW : see if it matters *)
+
+ assert(ordered2 (List.sort compare field_ty));
+ assert(ordered (List.sort compare fields));
+ try
+ let fields_of_type = l_map_sort(field_ty)(fst) in
+ let field_type_from_fields = l_map_sort(fields)(fun f->f,Private.raw_get_field_type field_ty f) in
+ match fields_completion (Private.raw_get_field_type field_ty) fields_of_type field_type_from_fields [] with
+ | Some fields ->
+ let fields = List.sort compare fields in
+ assert(ordered2 fields);
+ Some(Q.TypeRecord(Q.TyRow(fields,_unstrict) ))
+ | None -> None
+ with Failure _ -> None
+ in
+ get_filtered_type_record ty (row fields)
let strictify_record_ty (ty:ty) (fields:string list) =
debug "Strictify record %a {%a} @\n%!" print_ty ty (Format.pp_list "," Format.pp_print_string) fields;
@@ -1243,14 +1307,20 @@ struct
in
l_map_sort(field_ty_l)(fun (f,_)->f), strict
| _ -> assert false
- ) (get_type_cases ty fields)
-
- (* SLOW TODO : add a strict get_type_cases for better perfs *)
- let strict_get_field_type (ty:ty) (fields:string list) (field:string):ty=
- match Private.filter_ambiguous fields (get_type_cases ty fields) with
- | [ Q.TypeRecord( Q.TyRow(field_ty_l,_)) ] -> Private.raw_get_field_type field_ty_l field
- | tys ->
- debug "get_field_type %a {%a} %s in <<%a>>@\n%!" print_ty ty (Format.pp_list "/" Format.pp_print_string) fields field (Format.pp_list "," print_ty) tys ;
+ ) (get_type_record ty fields)
+
+
+ let case_selector (ty:ty) =
+ let records :(string*ty) list list = get_filtered_type_record ty (fun _ x-> Some x) in
+ let index = RecordIndex.create records ((List.map fst):(string*ty) list->string list) (fun x->x) in
+ fun (fields:string list) -> RecordIndex.get_case true fields index
+
+ let strict_get_field_type (ty:ty) =
+ let case_selector = case_selector ty in
+ fun fields ->
+ match case_selector fields with
+ | [field_ty_l] -> (Private.raw_get_field_type field_ty_l) (* field *)
+ | tys -> Format.printf "strict_get_field_type ambiguous selection (%d) %a {%a}" (List.length tys) print_ty ty (Format.pp_list "/" Format.pp_print_string) fields ;
(* typer specification make this case an assert false case *)
assert false
Please sign in to comment.
Something went wrong with that request. Please try again.