Permalink
Browse files

implements some helper functions for creating and manipulating partit… (

#892)

* implements some helper functions for creating and manipulating partitions

* Address various issues wrt performance and clarity

* Base test structure for partitions

* Change merge to union, run ocp-indent on all modified files

* Write a few tests for the partition module

* Write a few tests for the partition module

* Fix parenthesizing bug in tests
  • Loading branch information...
codyroux authored and ivg committed Nov 9, 2018
1 parent d2a18af commit ded29fd17b90e89e70fe34afe5895717b51177ee
Showing with 184 additions and 2 deletions.
  1. +21 −0 lib/graphlib/graphlib.mli
  2. +114 −1 lib/graphlib/graphlib_graph.ml
  3. +4 −0 lib/graphlib/graphlib_graph.mli
  4. +45 −1 lib_test/bap_types/test_graph.ml
@@ -491,6 +491,27 @@ module Std : sig
type 'a t = 'a partition
(** [trivial s] creates the trivial partition with a single
equivalence class containing every member of [s] *)
val trivial : ('a, 'b) Set.t -> 'a t
(** [discrete s] returns the partition with one class per element of [s] *)
val discrete : ('a, 'b) Set.t -> 'a t
(** [refine p ~rel ~comp] takes a partition [p], and refines it
according to the equivalence relation [r], so that the
resulting partition corresponds to the classes of [r], assuming
that those classes are finer that the original [p].
Takes an additional [comp] argument to compare for equality
within the equivalence classes. *)
val refine : 'a t -> equiv:('a -> 'a -> bool) -> cmp:('a -> 'a -> int) -> 'a t
(** [union p x y] returns the partition p with the classes of [x]
and [y] merged. Returns [p] unchanged if either [x] or [y] are
not part of any equivalence class. *)
val union : 'a t -> 'a -> 'a -> 'a t
(** [groups p] returns all partition cells of a partitioning [p] *)
val groups : 'a t -> 'a group seq
@@ -186,10 +186,123 @@ module Partition = struct
let find x = Option.Monad_infix.(Hashtbl.find comps x >>= find_root) in
{roots; groups; find}
let equiv t x y = Option.equal Equiv.equal (t.find x) (t.find y)
(* The trivial partition with a single class, or zero if elts is empty *)
let trivial elts =
let head = Set.choose elts in
match head with
| None ->
let roots = [||] in
let groups = [||] in
let find _ = None in
{roots; groups; find}
| Some h ->
let roots = Array.create ~len:1 h in
let groups = Array.create ~len:1 (create_set elts) in
let find x = if Set.mem elts x then Some 0 else None in
{roots; groups; find}
(* The discrete partition with one class per element *)
let discrete elts =
let comparator = Set.comparator elts in
let {Comparator.compare} = comparator in
(* Produces a sorted array per the spec *)
let roots = Set.to_array elts in
let groups = elts |> Set.to_array |>
Array.map ~f:(fun x ->
object
method enum = Seq.return x
method mem y = compare x y = 0
end)
in
let find x = Array.binary_search roots ~compare `First_equal_to x in
{roots; groups; find}
(* Takes a partition and a congruence and splits each equivalence class into
elements related by the congruence.
Takes in a comparison function to test for membership in each class.
*)
let refine (type elt) t ~equiv ~cmp =
let module T = Comparator.Make(struct
type t = elt
let compare = cmp
let sexp_of_t = sexp_of_opaque
end) in
let comparator = T.comparator in
let refine_group g =
let rec insert elt output input = match input with
| [] -> Set.singleton ~comparator elt :: output
| group :: input ->
if equiv (Set.choose_exn group) elt
then List.rev_append ((Set.add group elt) :: output) input
else insert elt (group::output) input in
Seq.fold g#enum ~init:[] ~f:(fun groups elt ->
insert elt [] groups) |> List.rev_map ~f:create_set in
let groups_list = Array.fold t.groups ~init:[] ~f:(fun seqs g -> refine_group g @ seqs) in
let groups = Array.of_list groups_list in
Array.sort groups ~cmp:(fun s1 s2 ->
let h1 = Seq.hd_exn s1#enum in
let h2 = Seq.hd_exn s2#enum in
cmp h1 h2);
let roots = Array.map ~f:(fun s -> Seq.hd_exn s#enum) groups in
let find x = Array.binary_search roots ~compare:cmp `First_equal_to x in
{roots; groups; find}
(* Take two elements and combine their classes if both have a class,
do nothing otherwise *)
let union t x y =
(* Assuming i < j,
create a new array a', such that
Array.length a' = Array.length a - 1 and
a'[k] = a[k] when k < i
a'[i] = x
a'[j] = a[j+1]
a'[j+1] = a[j+2]
...
*)
let array_replace a i j x =
assert (i < j && Array.length a > 0);
Array.init (Array.length a - 1)
~f:(fun n -> if n < i then a.(n)
else if n = i then x
else if n < j then a.(n)
else a.(n+1))
in
if equiv t x y then t
else
match t.find x, t.find y with
| None, _ | _,None -> t
| Some i_x, Some i_y ->
let g_x, g_y = t.groups.(i_x), t.groups.(i_y) in
let u_g = object
method enum =
let s_x, s_y = g_x#enum, g_y#enum in
Seq.append s_x s_y
method mem x = g_x#mem x || g_y#mem x
end
in
(* min biased root *)
let i = Int.min i_x i_y in
let j = Int.max i_x i_y in
let u_root = t.roots.(i) in
let roots = array_replace t.roots i j u_root in
let groups = array_replace t.groups i j u_g in
let find x = Option.Monad_infix.(
t.find x >>|
(* By cases: if n < i or i < n < j then it is in one one
of the original classes, otherwise n = i, then one
should return i (as the class still contains these
elements) or n = j, in wich case these elements are now
in class i, or n > j, in which case we must left-shift them *)
fun n -> if n < j then n
else if n = j then i
else n - 1) in
{roots; groups; find}
let nth_group t n = Group.create t.roots.(n) t.groups.(n) n
let groups t = Seq.(range 0 (Array.length t.roots) >>| nth_group t)
let group t x = Option.(t.find x >>| nth_group t)
let equiv t x y = Option.equal Equiv.equal (t.find x) (t.find y)
let number_of_groups t = Array.length t.roots
let of_equiv t i =
if i >= 0 && i < Array.length t.roots
@@ -206,6 +206,10 @@ end
module Partition : sig
type 'a t = 'a partition
val trivial : ('a, 'b) Set.t -> 'a t
val discrete : ('a, 'b) Set.t -> 'a t
val refine : 'a t -> equiv:('a -> 'a -> bool) -> cmp:('a -> 'a -> int) -> 'a t
val union : 'a t -> 'a -> 'a -> 'a t
val groups : 'a t -> 'a group Sequence.t
val group : 'a t -> 'a -> 'a group option
val equiv : 'a t -> 'a -> 'a -> bool
@@ -564,7 +564,50 @@ end
module Test_int100 = Construction(Int100)
module Test_partition = struct
module P = Partition
let add x s = Set.add s x
let s = Set.empty Int.comparator
|> add 0
|> add 1
|> add 2
|> add 3
|> add 4
|> add 5
|> add 6
|> add 7
|> add 8
|> add 9
|> add 10
let n = Set.length s
let trivial p _ = assert_bool "failed" (P.number_of_groups p = 1)
let discrete p _ = assert_bool "failed" (P.number_of_groups p = n)
let union p x y _ = assert_bool "failed" (P.equiv p x y)
let refine p equiv _ = assert_bool "failed"
(Seq.for_all (P.groups p)
~f:(fun g ->
let x = Group.top g in
Seq.for_all (Group.enum g) ~f:(fun y -> equiv x y)))
let equiv x y = (x - y) mod 2 = 0
let cmp x y = x - y
let suite () = [
"Trivial invariant" >:: trivial (P.trivial s);
"Discrete invariant" >:: discrete (P.discrete s);
"Union invariant" >:: union (P.union (P.discrete s) 1 2) 1 2;
"Refine invariant" >:: refine (P.refine (P.trivial s) equiv cmp) equiv
]
end
let suite () =
"Graph" >::: [
@@ -573,5 +616,6 @@ let suite () =
let module Test = Test_algo(G) in
Test.suite (sprintf "%d" n));
"Construction" >::: [Test_int100.suite];
"IR" >::: Test_IR.suite ()
"IR" >::: Test_IR.suite ();
"Partition" >::: Test_partition.suite ()
]

0 comments on commit ded29fd

Please sign in to comment.