-
Notifications
You must be signed in to change notification settings - Fork 33
/
MergeSort.hs
95 lines (82 loc) · 4.06 KB
/
MergeSort.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
-----------------------------------------------------------------------------
-- |
-- Module : Data.SBV.Examples.BitPrecise.MergeSort
-- Copyright : (c) Levent Erkok
-- License : BSD3
-- Maintainer : erkokl@gmail.com
-- Stability : experimental
-- Portability : portable
--
-- Symbolic implementation of merge-sort and its correctness.
-----------------------------------------------------------------------------
module Data.SBV.Examples.BitPrecise.MergeSort where
import Data.SBV
-----------------------------------------------------------------------------
-- * Implementing Merge-Sort
-----------------------------------------------------------------------------
-- | Element type of the lists we'd like to sort. For simplicity we use
-- 'SWord8' here. Any other symbolic type that supports the comparisons
-- used below ('.<', '.<=', '.==') could be substituted.
type E = SWord8
-- | Merge two lists that are each already sorted into a single sorted list.
--
-- At every step the smaller of the two heads is emitted first. When the
-- heads compare equal (@x .< y@ is false), the head of the /second/ list is
-- emitted first. Both alternatives handed to 'ite' are lists of the same
-- length, namely the combined length of the two inputs.
merge :: [E] -> [E] -> [E]
merge xs ys = case (xs, ys) of
  ([], _)          -> ys
  (_, [])          -> xs
  (l : ls, r : rs) -> ite (l .< r) (l : merge ls ys) (r : merge xs rs)
-- | Simple merge-sort implementation. As long as the input has at least two
-- elements, we split it into two halves, sort each half on its own, and then
-- combine the sorted halves with 'merge'. Empty and singleton lists are
-- already sorted and are returned unchanged.
mergeSort :: [E] -> [E]
mergeSort []  = []
mergeSort [x] = [x]
mergeSort xs  = merge (mergeSort th) (mergeSort bh)
  where
    -- First and second halves. 'splitAt' with @length xs `div` 2@ puts
    -- the extra element of an odd-length list into the second half.
    (th, bh) = splitAt (length xs `div` 2) xs
-----------------------------------------------------------------------------
-- * Proving correctness
-- ${props}
-----------------------------------------------------------------------------
{- $props
There are two main parts to proving that a sorting algorithm is correct:

  * Prove that the output is non-decreasing.

  * Prove that the output is a permutation of the input.
-}
-- | Check whether a given sequence is non-decreasing. Every element must be
-- less than or equal to its immediate successor. Lists with fewer than two
-- elements have no adjacent pairs and are trivially non-decreasing.
nonDecreasing :: [E] -> SBool
nonDecreasing xs = foldr (&&&) true adjacentOk
  where
    -- One symbolic comparison per adjacent pair (x_i, x_{i+1}).
    adjacentOk = zipWith (.<=) xs (drop 1 xs)
-- | Check whether two given sequences are permutations of each other.
--
-- Each list is checked to be contained in the other when both are viewed as
-- multisets. Every element of one list has to be matched against a distinct,
-- not-yet-used occurrence in the other, which is how duplicated elements are
-- accounted for.
isPermutationOf :: [E] -> [E] -> SBool
isPermutationOf as bs = covers as bs &&& covers bs as
  where
    -- Every element of the first list finds its own unused match in the
    -- second. Each candidate in the pool carries a symbolic
    -- "still available" flag, which starts out as 'true'.
    covers xs ys = consume xs [(y, true) | y <- ys]

    -- Strike off one match per element, threading the updated pool along.
    consume []         _    = true
    consume (x : rest) pool = hit &&& consume rest pool'
      where (hit, pool') = strike x pool

    -- Mark off an available entry equal to 'x', if there is one. Both
    -- branches of 'ite' must return pools of the same length so that the
    -- two results can be merged.
    strike _ [] = (false, [])
    strike x (entry@(y, avail) : more) =
      ite (avail &&& x .== y)
          (true, (y, bnot avail) : more)
          (let (found, more') = strike x more in (found, entry : more'))
-- | Prove merge-sort correct for all input lists of length @n@. The sorted
-- result must be non-decreasing and must be a permutation of the input.
--
-- Only fixed-size lists can be checked this way. The problem handed to the
-- backend SMT solver gets harder as @n@ grows. Sizes around 5 or 6 should be
-- fairly easy to prove. For instance:
--
-- >>> correctness 5
-- Q.E.D.
correctness :: Int -> IO ThmResult
correctness n = prove (fmap sortsCorrectly (mkFreeVars n))
  where
    sortsCorrectly xs =
      let sorted = mergeSort xs
      in nonDecreasing sorted &&& isPermutationOf xs sorted
-----------------------------------------------------------------------------
-- * Generating C code
-----------------------------------------------------------------------------
-- | Generate C code for merge-sorting an array of size @n@. Again, we're
-- restricted to fixed-size inputs. The output is not how one would code merge
-- sort in C by hand. It is, however, a faithful rendering of all the
-- operations merge-sort performs, as described by its Haskell counterpart
-- 'mergeSort'.
--
-- The generated C function is named @mergeSort@. It reads an input array
-- @xs@ of @n@ elements and writes the sorted result to the output array
-- @ys@. The string @\"mergeSort\" ++ show n@ (e.g. @mergeSort5@) is passed
-- as the optional first argument of 'compileToC'. NOTE(review): this is
-- presumably the output directory; confirm against the SBV version in use.
codeGen :: Int -> IO ()
codeGen n = compileToC (Just ("mergeSort" ++ show n)) "mergeSort" $ do
  xs <- cgInputArr n "xs"
  cgOutputArr "ys" (mergeSort xs)