# Garsia-Wachs Algorithm
###### (Algorithm Design with Haskell)

Garsia-Wachs is an algorithm for finding a binary tree of minimum cost whose fringe equals the input list of weights. It works in two stages:
- Build a tree from the list of weights whose leaves are labelled with the indices [1..n] of the input. The fringe of this tree is not necessarily in index order.
- Using the depth of each label in the first tree, build a second tree by bracketing the input weights so that each weight sits at the same depth as its label does in the first tree. The first tree is constructed in such a way that this second tree is an optimal one with the input as its fringe.
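
To make "fringe" and "cost" concrete, here is a minimal sketch. It assumes that cost means the weighted path length (each leaf weight times its depth, summed); `fringe`, `cost`, `leftNested` and `rightNested` are illustrative names that do not appear in the notebook, and `Tree` repeats the notebook's definition so the block stands alone.

```haskell
data Tree a = Leaf a | Fork (Tree a) (Tree a)

-- Leaf labels read from left to right.
fringe :: Tree a -> [a]
fringe (Leaf x)   = [x]
fringe (Fork l r) = fringe l ++ fringe r

-- Assumed cost measure: sum over the leaves of weight * depth.
cost :: Tree Int -> Int
cost = go 0 where
  go d (Leaf w)   = d * w
  go d (Fork l r) = go (d + 1) l + go (d + 1) r

-- Two bracketings of the same fringe [10, 20, 30].
leftNested, rightNested :: Tree Int
leftNested  = Fork (Fork (Leaf 10) (Leaf 20)) (Leaf 30)   -- cost 90
rightNested = Fork (Leaf 10) (Fork (Leaf 20) (Leaf 30))   -- cost 110
```

Both trees have the input as their fringe, but only the first has minimum cost; the algorithm has to pick the best bracketing without reordering the weights.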


In [2]:
type Weight = Int
type Label = Int
type Depth = Int
data Tree a = Leaf a | Fork (Tree a) (Tree a)

gwa :: [Weight] -> Tree Weight
gwa ws = rebuild ws (build ws)

build :: [Weight] -> Tree Label
build = undefined

rebuild :: [Weight] -> Tree Label -> Tree Weight
rebuild = undefined

Once we have the output of `build`, the next step is to turn it into a list of pairs in which the first element is a leaf labelled with a weight and the second element is the depth of the corresponding label in the first tree. This list is arranged in fringe order.

We then repeatedly combine adjacent pairs whose depths are the same until a single tree remains. When two trees are combined, the depth of the result is one less than the depth of its parts.

In [3]:
depth = snd

reduce :: [(Tree Weight, Depth)] -> Tree Weight
reduce = extract . until single step where
  extract [(t,_)] = t
  single [x] = True
  single _ = False
  step (x:y:xs) = if depth x == depth y
                      then join x y : xs
                      else x : step (y : xs)
join (t1, d) (t2, _) = (Fork t1 t2, d - 1)

This implementation is inefficient because each step starts from the beginning of the list, even though we may already know that the front of the list cannot be reduced further. We can fix this with a left fold that keeps the pairs processed so far on a stack (most recent first) and a recursive definition of `step`.

In [4]:
import Data.List (foldl')

reduce :: [(Tree Weight, Depth)] -> Tree Weight
reduce = extract . foldl' step [] where
  extract [(t,_)] = t
  step [] y = [y]
  step (x:xs) y = if depth x == depth y
                    then step xs (join x y)
                    else y:x:xs
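
As a small illustration of the stack-based reduction, the sketch below exposes `step` at the top level and uses `scanl` to record the depths on the stack after each pair is consumed. `stackDepths` and `example` are hypothetical names introduced here; `Tree`, `join` and `step` repeat the notebook's definitions so the block stands alone.

```haskell
data Tree a = Leaf a | Fork (Tree a) (Tree a) deriving (Eq, Show)
type Depth = Int

-- Combining two trees of equal depth yields a tree one level higher up.
join :: (Tree a, Depth) -> (Tree a, Depth) -> (Tree a, Depth)
join (t1, d) (t2, _) = (Fork t1 t2, d - 1)

-- The stack holds the pairs seen so far, most recent first.
step :: [(Tree a, Depth)] -> (Tree a, Depth) -> [(Tree a, Depth)]
step [] y = [y]
step (x:xs) y
  | snd x == snd y = step xs (join x y)
  | otherwise      = y : x : xs

-- Depths on the stack after each input pair has been pushed.
stackDepths :: [(Tree a, Depth)] -> [[Depth]]
stackDepths = map (map snd) . scanl step []

-- Weights 10, 20, 30 at depths 1, 2, 2.
example :: [(Tree Int, Depth)]
example = zip (map Leaf [10, 20, 30]) [1, 2, 2]

-- stackDepths example == [[], [1], [2, 1], [0]]
```

The last push triggers two joins in a row: the two depth-2 leaves merge into a depth-1 tree, which then merges with the depth-1 leaf below it on the stack, leaving `Fork (Leaf 10) (Fork (Leaf 20) (Leaf 30))` at depth 0.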

Now we can implement `rebuild`.

In [5]:
import qualified Data.Array as A

rebuild :: [Weight] -> Tree Label -> Tree Weight
rebuild ws = reduce . zip (map Leaf ws) . sortDepths (length ws)

fringeDepth :: Tree a -> [(a, Depth)]
fringeDepth = ($ []) . from 0 where
  from d (Leaf a) = ((a, d) :)
  from d (Fork t1 t2) = from (d + 1) t1 . from (d + 1) t2

sortDepths :: Int -> Tree Label -> [Depth]
sortDepths size t = A.elems $ A.array (1,  size) (fringeDepth t)
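
The following sketch shows why the depths need sorting. `labelTree` is a hand-written example of a label tree whose fringe is not in index order (it is the tree the list-based `build` should produce for the weights `[3, 3, 2, 4, 10]`, worked out by hand); the other definitions repeat the cell above so the block stands alone.

```haskell
import qualified Data.Array as A

data Tree a = Leaf a | Fork (Tree a) (Tree a)
type Label = Int
type Depth = Int

-- Pair each leaf label with its depth, in fringe order.
fringeDepth :: Tree a -> [(a, Depth)]
fringeDepth = ($ []) . from 0 where
  from d (Leaf a)     = ((a, d) :)
  from d (Fork t1 t2) = from (d + 1) t1 . from (d + 1) t2

-- Use the labels as array indices so the depths come out in label order.
sortDepths :: Int -> Tree Label -> [Depth]
sortDepths size t = A.elems (A.array (1, size) (fringeDepth t))

labelTree :: Tree Label
labelTree =
  Fork (Leaf 5) (Fork (Fork (Leaf 2) (Leaf 3)) (Fork (Leaf 1) (Leaf 4)))

-- fringeDepth labelTree  == [(5,1),(2,3),(3,3),(1,3),(4,3)]
-- sortDepths 5 labelTree == [3,3,3,3,1]
```

Zipping `[3,3,3,3,1]` with the weights and reducing gives a tree in which the four light weights sit at depth 3 and the weight 10 sits at depth 1, with the weights back in their original order.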

Now we need the first stage of the algorithm, `build`, which is the complicated part. We start with a quadratic implementation and then move to a linearithmic one.

Begin with a list `[(Leaf 0, w0), (Leaf 1, w1), (Leaf 2, w2), ...]` in which the first component of each pair is a leaf labelled with an index and the second is a weight. The first pair is a sentinel with infinite weight.

At each step we find the pair with the largest index such that the weight of its left neighbor is >= the weight of its right neighbor. We combine this pair with its right neighbor, joining the trees and adding the weights, then move the new pair to the right past all pairs that have a lesser weight.

At the end of this process there are just two pairs left, the sentinel and the pair whose tree is the desired result.

Once a pair has been selected, merged, and moved into place, the possible choices for the next pair are limited to:
- the pair to the right of where the merged pair was moved to, if its right neighbor weighs no more than the merged pair.
- the pair to the right of where the merged pair used to be, if its neighbors fit the condition.
- any of the pairs to the left of the left neighbor of where the merged pair used to be.

In [6]:
build :: [Weight] -> Tree Label
build = extract
      . foldr step []
      . (zip (Leaf <$> [0..]))
      . (maxBound :) -- could use sum of weights instead of maxBound
  where
    extract [_, (t, _)] = t -- the sentinel comes first; keep the finished tree

    step x (y:z:zs)
      | weight x >= weight z = step x $ move (join y z) zs
    step x xs = x : xs

    move x [] = [x]
    move x (y:ys)
      | weight x > weight y = y : move x ys
      | otherwise = step x (y : ys)
    
    weight = snd
join (t1, w1) (t2, w2) = (Fork t1 t2, w1 + w2)
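
To watch the merge-and-move process on a small input, the sketch below lifts `step` and `move` to the top level and records the weights of the working list after each element is folded in. `weightTrace` and `start` are hypothetical helper names, and the expected values in the comments were traced by hand for the weights `[3, 3, 2, 4, 10]`.

```haskell
data Tree a = Leaf a | Fork (Tree a) (Tree a) deriving (Eq, Show)
type Weight = Int
type Label = Int
type Pair = (Tree Label, Weight)

weight :: Pair -> Weight
weight = snd

join :: Pair -> Pair -> Pair
join (t1, w1) (t2, w2) = (Fork t1 t2, w1 + w2)

-- x is the left neighbor of y, and z is the right neighbor of y.
step :: Pair -> [Pair] -> [Pair]
step x (y:z:zs)
  | weight x >= weight z = step x (move (join y z) zs)
step x xs = x : xs

-- Carry the merged pair right past every lighter pair.
move :: Pair -> [Pair] -> [Pair]
move x [] = [x]
move x (y:ys)
  | weight x > weight y = y : move x ys
  | otherwise           = step x (y : ys)

-- Sentinel (Leaf 0, maxBound) followed by the labelled weights.
start :: [Weight] -> [Pair]
start ws = zip (map Leaf [0 ..]) (maxBound : ws)

-- Weights of the working list, from the first element folded in
-- (the rightmost leaf) to the last (the sentinel).
weightTrace :: [Weight] -> [[Weight]]
weightTrace = reverse . map (map weight) . scanr step [] . start

build :: [Weight] -> Tree Label
build ws = case foldr step [] (start ws) of
  [_, (t, _)] -> t
  _           -> error "build: expected the sentinel plus one tree"

-- weightTrace [3,3,2,4,10] ==
--   [[],[10],[4,10],[2,4,10],[3,2,4,10],[3,4,5,10],[maxBound,22]]
-- build [3,3,2,4,10] ==
--   Fork (Leaf 5) (Fork (Fork (Leaf 2) (Leaf 3)) (Fork (Leaf 1) (Leaf 4)))
```

In the step from `[3,2,4,10]` to `[3,4,5,10]`, leaves 2 and 3 are merged into a pair of weight 5, which then moves right past the leaf of weight 4. That move is why the labels in the final tree are out of index order.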

In the worst case `build` is quadratic. To improve this we exploit the fact that the second argument to both `step` and `move` is "2-sorted", meaning that the elements at odd positions (A1, A3, A5, ...) and the elements at even positions (A2, A4, A6, ...) each form a list sorted by weight. This lets each move take logarithmic rather than linear time, which brings the overall running time of the algorithm down to linearithmic.

To do this we will use a data structure that supports efficient splitting at a given value, which is what `move` needs. We'll use a balanced binary search tree whose flattening produces a 2-sorted list. To decide whether to go left or right we need to look at both a node's weight and the weight of the preceding element, so each node stores the previous element alongside its own.
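
The sketch below illustrates the 2-sorted property on plain lists of weights and shows why pairing each element with its predecessor makes a search possible. All names are hypothetical, and `twoSorted` uses a non-strict comparison as an assumption.

```haskell
-- Every element is no greater than the one two places later.
twoSorted :: [Int] -> Bool
twoSorted ws = and (zipWith (<=) ws (drop 2 ws))

sorted :: [Int] -> Bool
sorted ws = and (zipWith (<=) ws (drop 1 ws))

-- Pair each element with its predecessor; the first is paired with itself.
withPrevious :: [Int] -> [(Int, Int)]
withPrevious ws = zip (take 1 ws ++ ws) ws

-- The larger weight of each (previous, current) pair.
pairMaxes :: [Int] -> [Int]
pairMaxes = map (uncurry max) . withPrevious

-- The prefix that a merged pair of weight w moves past.
passed :: Int -> [Int] -> [Int]
passed w = map snd . takeWhile (\(p, x) -> max p x < w) . withPrevious

-- twoSorted [3,2,4,10] == True, although sorted [3,2,4,10] == False
-- pairMaxes [3,2,4,10] == [3,3,4,10]
-- passed 5 [3,2,4,10]  == [3,2,4] == takeWhile (< 5) [3,2,4,10]
```

On a 2-sorted list the pair maxima are in ascending order even when the weights themselves are not, so the test `max p x < w` changes from true to false at most once along the list. A search tree can therefore find the split point by comparing against the pair stored at each node.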

In [22]:
data List a = Null
            | Node Int (List a) (a, a) (List a)
            
height :: List a -> Int
height Null = 0
height (Node h _ _ _) = h
            
emptyL :: List a
emptyL = Null
nullL :: List a -> Bool
nullL Null = True
nullL _ = False

node :: List a -> (a,a) -> List a -> List a
node t1 x t2 = Node (max (height t1) (height t2) + 1) t1 x t2

bias :: List a -> Int
bias Null = 0
bias (Node _ l _ r) = height l - height r

balance :: List a -> (a,a) -> List a -> List a
balance t1 x t2
  | abs (h1 - h2) <= 1 = node t1 x t2
  | h1 == h2 + 2 = rotateR t1 x t2
  | h2 == h1 + 2 = rotateL t1 x t2
  where
    h1 = height t1; h2 = height t2
rotateR t1 x t2
  | 0 <= bias t1 = rotr (node t1 x t2)
  | otherwise = rotr (node (rotl t1) x t2)
rotateL t1 x t2
  | bias t2 <= 0 = rotl (node t1 x t2)
  | otherwise = rotl (node t1 x (rotr t2))
rotr (Node _ (Node _ l y lr) x r) = node l y (node lr x r)
rotl (Node _ l x (Node _ rl y r)) = node (node l x rl) y r

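-- Attach a short tree t2 to the right of a taller tree: walk down the right
-- spine until the heights are close enough, then rebalance on the way back up.
balanceR :: List a -> (a,a) -> List a -> List a
balanceR (Node _ l y r) x t2
  | height r >= height t2 + 2 = balance l y (balanceR r x t2)
  | otherwise = balance l y (node r x t2)

-- Mirror image: attach a short tree t1 to the left of a taller tree.
balanceL :: List a -> (a,a) -> List a -> List a
balanceL t1 x (Node _ l y r)
  | height l >= height t1 + 2 = balance (balanceL t1 x l) y r
  | otherwise = balance (node t1 x l) y r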

gbalance :: List a -> (a, a) -> List a -> List a
gbalance l x r
  | abs (hl - hr) <= 2 = balance l x r
  | hl > hr + 2 = balanceR l x r
  | hl + 2 < hr = balanceL l x r
  where hl = height l; hr = height r

consL :: a -> List a -> List a
consL x Null = node Null (x, x) Null
consL x (Node _ t1 (y, z) t2)
  | nullL t1 = balance (consL x t1) (x, z) t2
  | otherwise = balance (consL x t1) (y, z) t2
  
deconsL :: List a -> (a, List a)
deconsL (Node _ l (x, y) r)
  | nullL l = (y, r)
  | otherwise = (z, balance t (x, y) r)
  where
    (z, t) = deconsL l
    
lastL :: List a -> a
lastL (Node _ t1 (x,y) t2) = if nullL t2 then y else lastL t2
    
concatL :: List a -> List a -> List a
concatL t Null = t
concatL Null t = t
concatL t1 t2 = gbalance t1 (x, y) t3 -- y has been removed from t2, so use t3
  where x = lastL t1
        (y, t3) = deconsL t2
        
type Pair = (Tree Label, Weight)
weight :: Pair -> Weight
weight (t, w) = w
        
splitL :: Pair -> List Pair -> (List Pair, List Pair)
splitL x t = sew (pieces x t)

data Piece a = LP (List a) (a, a)
             | RP (a, a) (List a)
             
pieces :: Pair -> List Pair -> [Piece Pair]
pieces x t = addPiece t [] where
  addPiece Null ps = ps
  addPiece (Node _ t1 (y,z) t2) ps
    | weight x > max (weight y) (weight z)
    = addPiece t2 (LP t1 (y,z) : ps)
    | otherwise = addPiece t1 (RP (y,z) t2 : ps)
    
sew :: [Piece Pair] -> (List Pair, List Pair)
sew = foldl' go (emptyL, emptyL) where
  go (t1, t2) (LP t x) = (gbalance t x t1, t2)
  go (t1, t2) (RP x t) = (t1, gbalance t2 x t)
  
-- now we can define it!
buildL :: [Weight] -> Tree Label
buildL ws = extractL (foldr stepL emptyL (start ws)) where
  start ws = zip (map Leaf [0..]) (maxBound : ws)
  
  extractL xs = t where
    (_, ys) = deconsL xs
    ((t, _), _) = deconsL ys
    
  stepL :: Pair -> List Pair -> List Pair
  stepL x xs
    | nullL xs || nullL ys || weight x < weight z
    = consL x xs
    | otherwise = stepL x (insertL (join y z) zs)
    where
      (y, ys) = deconsL xs
      (z, zs) = deconsL ys
      
  insertL :: Pair -> List Pair -> List Pair
  insertL x xs = concatL ys (stepL x zs) where
    (ys, zs) = splitL x xs
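
To inspect what the tree stores, the sketch below adds a hypothetical `flattenL` helper that lists the (previous, element) pairs in order. The rest is a trimmed copy of the definitions above (with the rotations made local to `balance`) so that the block stands alone; it is meant as a checking aid, not as part of the algorithm.

```haskell
data List a = Null | Node Int (List a) (a, a) (List a)

height :: List a -> Int
height Null = 0
height (Node h _ _ _) = h

nullL :: List a -> Bool
nullL Null = True
nullL _    = False

node :: List a -> (a, a) -> List a -> List a
node t1 x t2 = Node (max (height t1) (height t2) + 1) t1 x t2

bias :: List a -> Int
bias Null = 0
bias (Node _ l _ r) = height l - height r

rotr, rotl :: List a -> List a
rotr (Node _ (Node _ l y lr) x r) = node l y (node lr x r)
rotr t = t
rotl (Node _ l x (Node _ rl y r)) = node (node l x rl) y r
rotl t = t

-- Assumes the two heights differ by at most 2.
balance :: List a -> (a, a) -> List a -> List a
balance t1 x t2
  | h1 == h2 + 2 = if 0 <= bias t1 then rotr (node t1 x t2)
                                   else rotr (node (rotl t1) x t2)
  | h2 == h1 + 2 = if bias t2 <= 0 then rotl (node t1 x t2)
                                   else rotl (node t1 x (rotr t2))
  | otherwise    = node t1 x t2
  where h1 = height t1; h2 = height t2

-- The new first element becomes the "previous" of the old first element.
consL :: a -> List a -> List a
consL x Null = node Null (x, x) Null
consL x (Node _ t1 (y, z) t2)
  | nullL t1  = balance (consL x t1) (x, z) t2
  | otherwise = balance (consL x t1) (y, z) t2

-- Hypothetical helper: in-order list of (previous, element) pairs.
flattenL :: List a -> [(a, a)]
flattenL t = go t [] where
  go Null           = id
  go (Node _ l x r) = go l . (x :) . go r

example :: List Int
example = foldr consL Null [3, 2, 4, 10]

-- flattenL example == [(3,3),(3,2),(2,4),(4,10)]
```

Taking `map snd` of the flattened pairs recovers the list `[3, 2, 4, 10]`, and the larger component of each pair gives `[3, 3, 4, 10]`, which is the ascending key that `pieces` compares against.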