# Codensity Transformation
###### (https://www.janis-voigtlaender.eu/papers/AsymptoticImprovementOfComputationsOverFreeMonads.pdf)

The codensity monad of a functor `f` is the right Kan extension of `f` along itself:

In [80]:
:ext RankNTypes
:ext DeriveFunctor
newtype Codensity f a = Codensity { runCodensity :: forall b. (a -> f b) -> f b } deriving Functor

This is useful for optimizing the bind operation (`>>=`) of a monad, in a similar fashion to how difference lists optimize the repeated concatenation of lists.

In [32]:
newtype DList a = DList { runDList :: [a] -> [a] }

fromList :: [a] -> DList a
fromList = DList . (++)

toList :: DList a -> [a]
toList = ($ []) . runDList

instance Semigroup (DList a) where
  DList a <> DList b = DList (a . b)
  
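-- (((x ++ x) ++ x) ++ x) ++ x is bad: each (++) re-traverses its whole left argument.
-- x ++ (x ++ (x ++ (x ++ x))) is good: every element is copied only once.
-- DList's (<>) is function composition, so any nesting yields the right-nested list.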
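A small sketch of the difference in practice, restating the `DList` above so the block compiles on its own. `leftNestedList` and `leftNestedDList` are hypothetical helper names: both fold a list of chunks with left-nested appends, but only the plain-list version pays for re-traversing the growing prefix on every step.

```haskell
-- Restates the DList from the cell above so this block stands alone.
newtype DList a = DList { runDList :: [a] -> [a] }

fromList :: [a] -> DList a
fromList = DList . (++)

toList :: DList a -> [a]
toList = ($ []) . runDList

instance Semigroup (DList a) where
  DList a <> DList b = DList (a . b)

-- Left-nested appends on plain lists: ((c1 ++ c2) ++ c3) ++ ...
-- Each (++) copies the whole accumulated prefix again, so the cost grows
-- quadratically with the number of chunks.
leftNestedList :: [[a]] -> [a]
leftNestedList = foldl (++) []

-- The same left-nested shape on DLists only composes functions; the single
-- toList at the end applies them right-nested, copying each element once.
leftNestedDList :: [[a]] -> [a]
leftNestedDList = toList . foldl (<>) (DList id) . map fromList
```

Both helpers return the same list; only the amount of copying differs.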

To see how the optimization works, we can look at how Monad can be implemented for a binary tree. Note that `subst`, our implementation of monadic bind, replaces each leaf of the tree by grafting in a new tree produced from the value stored at that leaf; this is the classic substitution monad on trees.

In [23]:
:ext DeriveFunctor
data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Eq, Show, Functor)

subst :: Tree a -> (a -> Tree b) -> Tree b
subst (Leaf a) f = f a
subst (Node a b) f = Node (subst a f) (subst b f)

instance Applicative Tree where
  pure a = Leaf a
  Leaf f <*> Leaf a = Leaf $ f a
  Node fa fb <*> Leaf a = Node (fa <*> Leaf a) (fb <*> Leaf a)
  Leaf f <*> Node a b = Node (f <$> a) (f <$> b)
  Node fa fb <*> Node a b = Node (fa <*> Node a b) (fb <*> Node a b)
  
instance Monad Tree where
  (>>=) = subst
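As a small illustration of the grafting (the input tree and the function `split10` are made-up examples), `subst` walks down to each leaf and replaces it with the tree produced from its value, rebuilding the `Node` spine above it:

```haskell
-- Restates Tree and subst from the cell above so this block stands alone.
data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Eq, Show)

subst :: Tree a -> (a -> Tree b) -> Tree b
subst (Leaf a) f = f a
subst (Node a b) f = Node (subst a f) (subst b f)

-- Hypothetical example: turn a leaf x into a node with children x and 10 * x.
split10 :: Int -> Tree Int
split10 x = Node (Leaf x) (Leaf (10 * x))

-- Both leaves of Node (Leaf 1) (Leaf 2) are replaced by split10 1 and split10 2.
grafted :: Tree Int
grafted = subst (Node (Leaf 1) (Leaf 2)) split10
```

`grafted` evaluates to `Node (Node (Leaf 1) (Leaf 10)) (Node (Leaf 2) (Leaf 20))`, and `subst t Leaf` gives back `t` unchanged, which is the right-identity monad law.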

Now consider a function such as:

In [26]:
fullTree :: Int -> Tree Int
fullTree 1 = Leaf 1
fullTree n = do
  i <- fullTree (n - 1)
  Node (Leaf $ n - 1 - i) (Leaf $ i + 1)

Under the Monad instance, the second equation unfolds to:

In [31]:
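-- fullTree with its do-block desugared into an explicit call to subst
-- (the base case is kept so the definition terminates)
fullTree' :: Int -> Tree Int
fullTree' 1 = Leaf 1
fullTree' n = subst (fullTree' $ n - 1) (\i -> Node (Leaf $ n - 1 - i) (Leaf $ i + 1))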

And so `fullTree 4` ends up as three left-nested calls, `subst (subst (subst (Leaf 1) ..) ..) ..`, one for each of the levels n = 2, 3, 4.
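The nesting can be checked directly. This sketch restates `Tree` and `subst`, names the continuation used at level n `step` (a hypothetical helper introduced only for this illustration), and compares the hand-unfolded expression with `fullTree 4`:

```haskell
data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Eq, Show)

subst :: Tree a -> (a -> Tree b) -> Tree b
subst (Leaf a) f = f a
subst (Node a b) f = Node (subst a f) (subst b f)

-- Hypothetical name for the continuation bound at level n in fullTree.
step :: Int -> Int -> Tree Int
step n i = Node (Leaf $ n - 1 - i) (Leaf $ i + 1)

-- fullTree in its desugared form.
fullTree :: Int -> Tree Int
fullTree 1 = Leaf 1
fullTree n = subst (fullTree (n - 1)) (step n)

-- fullTree 4 unfolded by hand: three left-nested substs around Leaf 1.
fullTree4Unfolded :: Tree Int
fullTree4Unfolded = subst (subst (subst (Leaf 1) (step 2)) (step 3)) (step 4)
```

Both sides evaluate to the tree shown in the `fullTree 4` output of the notebook.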

Recall that `subst` calls itself recursively until it reaches the leaves, so each `subst` traverses the entire structure produced by the `subst` inside it: prefix fragments of the overall tree are traversed again and again. The overall runtime is still O(2^n), which is to be expected since the tree itself has 2^(n-1) leaves, so when the whole tree is demanded the repeated traversals cost at most a constant factor over the size of the result.
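A rough count of that work, assuming `subst` visits each constructor of its first argument exactly once and counting only those visits: `fullTree k` has 2^(k-1) leaves and therefore 2^k - 1 constructors, so if T(n) is the number of constructors visited by all `subst` calls while building `fullTree n`,

```latex
T(1) = 0, \qquad T(n) = T(n-1) + \left(2^{n-1} - 1\right)
\;\Longrightarrow\;
T(n) = \sum_{k=1}^{n-1} \left(2^{k} - 1\right) = 2^{n} - n - 1 = O(2^{n})
```

For comparison, the finished tree has 2^n - 1 constructors, so the repeated traversals never exceed the size of the output.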

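However, if we instead write a function that only traces a single path from the root to a leaf, the cost becomes quadratic in n. Laziness means only the nodes along that path are ever forced, but forcing each of them has to pass back through up to n layers of left-nested `subst` calls, so a path of length n - 1 costs O(n^2) rather than O(n):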

In [35]:
zigZag :: Tree Int -> Int
zigZag = zig where
  zig (Leaf n) = n
  zig (Node t _) = zag t
  zag (Leaf n) = n
  zag (Node _ t) = zig t
  
zigZag (fullTree 4)

2

In [48]:
fullTree 4

Node (Node (Node (Leaf 1) (Leaf 3)) (Node (Leaf 2) (Leaf 2))) (Node (Node (Leaf 3) (Leaf 1)) (Node (Leaf 0) (Leaf 4)))
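To make the sizes concrete, a sketch with two hypothetical helpers, `leaves` and `depth`, over the same `Tree` type (with `fullTree` written in its desugared `subst` form): `fullTree n` should have 2^(n-1) leaves, and every root-to-leaf path, including the one `zigZag` follows, has length n - 1.

```haskell
data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Eq, Show)

subst :: Tree a -> (a -> Tree b) -> Tree b
subst (Leaf a) f = f a
subst (Node a b) f = Node (subst a f) (subst b f)

fullTree :: Int -> Tree Int
fullTree 1 = Leaf 1
fullTree n = subst (fullTree (n - 1)) (\i -> Node (Leaf $ n - 1 - i) (Leaf $ i + 1))

-- Number of leaves in a tree.
leaves :: Tree a -> Int
leaves (Leaf _) = 1
leaves (Node l r) = leaves l + leaves r

-- Length of the longest root-to-leaf path.
depth :: Tree a -> Int
depth (Leaf _) = 0
depth (Node l r) = 1 + max (depth l) (depth r)
```

For example, `leaves (fullTree 4)` is 8 and `depth (fullTree 4)` is 3, matching the printed tree above.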

Just as DList abstracts over the end of the list, we can make a representation of Tree that abstracts over its leaves.

In [69]:
import Control.Monad (ap)

newtype CTree a = CTree { runCTree :: forall b. (a -> Tree b) -> Tree b } deriving Functor

rep :: Tree a -> CTree a
rep t = CTree (subst t)

abst :: CTree a -> Tree a
abst (CTree f) = f Leaf

instance Applicative CTree where
  pure a = CTree $ \h -> h a
  (<*>) = ap

instance Monad CTree where
  CTree p >>= f = CTree $ \h -> p $ \a -> runCTree (f a) h
  
-- Allow abstracting over the two Tree representations
class Monad t => TreeLike t where
  node :: t a -> t a -> t a
  
instance TreeLike Tree where
  node = Node
  
instance TreeLike CTree where
  node (CTree a) (CTree b) = CTree $ \h -> Node (a h) (b h)
  
leaf :: TreeLike t => a -> t a
leaf = pure

fullTree' :: TreeLike t => Int -> t Int
fullTree' 1 = leaf 1
fullTree' n = do
  i <- fullTree' (n - 1)
  node (leaf $ n - 1 - i) (leaf $ i + 1)
  
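-- We use this function to switch to the asymptotically more efficient version.
-- It is eta-expanded because GHC's simplified subsumption (GHC 9.0 and later)
-- rejects the point-free `improve = abst` against this rank-2 type.
improve :: (forall t. TreeLike t => t a) -> Tree a
improve m = abst m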

zigZag (improve (fullTree' 4))

2

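`zigZag (improve (fullTree' 4))` runs in time linear in n rather than quadratic. Instead of consuming a tree and rebuilding a bigger tree at each step, the `CTree` version builds up a single continuation (a function from leaf values to trees), starting from the `Leaf` that `abst` supplies, and the base case `leaf 1` applies that continuation to `1`. The tree is therefore constructed in one pass instead of being re-traversed by every `subst`.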
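A standalone sketch to check that switching representations does not change the result. It restates the definitions above with a few assumptions of our own: the `Functor CTree` instance is written by hand instead of derived, `pure` stands in for `leaf`, `Tree`'s Applicative uses `ap`, and `improve` is eta-expanded. `improve (fullTree' n)` and `fullTree' n :: Tree Int` should produce equal trees, and `abst . rep` should be the identity.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad (ap)

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Eq, Show)

subst :: Tree a -> (a -> Tree b) -> Tree b
subst (Leaf a) f = f a
subst (Node a b) f = Node (subst a f) (subst b f)

instance Functor Tree where
  fmap f t = subst t (Leaf . f)

instance Applicative Tree where
  pure = Leaf
  (<*>) = ap

instance Monad Tree where
  (>>=) = subst

-- Tree with its leaves abstracted into a continuation.
newtype CTree a = CTree { runCTree :: forall b. (a -> Tree b) -> Tree b }

instance Functor CTree where
  fmap f (CTree p) = CTree (\h -> p (h . f))

instance Applicative CTree where
  pure a = CTree (\h -> h a)
  (<*>) = ap

instance Monad CTree where
  CTree p >>= f = CTree (\h -> p (\a -> runCTree (f a) h))

rep :: Tree a -> CTree a
rep t = CTree (subst t)

abst :: CTree a -> Tree a
abst (CTree p) = p Leaf

class Monad t => TreeLike t where
  node :: t a -> t a -> t a

instance TreeLike Tree where
  node = Node

instance TreeLike CTree where
  node (CTree a) (CTree b) = CTree (\h -> Node (a h) (b h))

fullTree' :: TreeLike t => Int -> t Int
fullTree' 1 = pure 1
fullTree' n = do
  i <- fullTree' (n - 1)
  node (pure (n - 1 - i)) (pure (i + 1))

improve :: (forall t. TreeLike t => t a) -> Tree a
improve m = abst m
```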

Now we can introduce the more general version of this method, using the free monad:

In [76]:
data Free f a = Return a | Wrap (f (Free f a)) deriving Functor

instance Functor f => Applicative (Free f) where
  pure = Return
  (<*>) = ap

instance Functor f => Monad (Free f) where
  Return a >>= f = f a
  Wrap fa >>= f = Wrap $ (>>= f) <$> fa

What was abstraction over the leaves in the previous section is now abstraction over the `Return` values of the free monad.
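To see that `Tree` is an instance of this pattern, here is a sketch with a two-child functor `Pair` (an assumption: the notebook does not define it). `toFree` and `fromFree` are hypothetical conversion helpers: `Leaf` maps to `Return` and `Node` to `Wrap`, and `>>=` on `Free Pair` then pushes the continuation down to every `Return`, just as `subst` does with leaves.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Control.Monad (ap)

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Eq, Show)

subst :: Tree a -> (a -> Tree b) -> Tree b
subst (Leaf a) f = f a
subst (Node a b) f = Node (subst a f) (subst b f)

data Free f a = Return a | Wrap (f (Free f a)) deriving Functor

instance Functor f => Applicative (Free f) where
  pure = Return
  (<*>) = ap

instance Functor f => Monad (Free f) where
  Return a >>= f = f a
  Wrap fa >>= f = Wrap ((>>= f) <$> fa)

-- Hypothetical two-child functor, so that Free Pair plays the role of Tree.
data Pair x = Pair x x deriving (Show, Functor)

toFree :: Tree a -> Free Pair a
toFree (Leaf a) = Return a
toFree (Node l r) = Wrap (Pair (toFree l) (toFree r))

fromFree :: Free Pair a -> Tree a
fromFree (Return a) = Leaf a
fromFree (Wrap (Pair l r)) = Node (fromFree l) (fromFree r)
```

Converting a tree to `Free Pair`, binding there, and converting back gives the same result as calling `subst` directly.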

In [91]:
:ext MultiParamTypeClasses
:ext FlexibleInstances

rep :: Monad m => m a -> Codensity m a
rep m = Codensity (m >>=)

abst :: Monad m => Codensity m a -> m a
abst = (`runCodensity` pure)

instance Applicative (Codensity f) where
  pure a = Codensity $ \h -> h a
  (<*>) = ap
  
instance Monad (Codensity f) where
  Codensity p >>= f = Codensity $ \h -> p $ \a -> runCodensity (f a) h
  
-- We need support for constructing the non-return values in both `Free f` and `Codensity (Free f)`
class (Functor f, Monad m) => FreeLike f m where
  wrap :: f (m a) -> m a
  
instance Functor f => FreeLike f (Free f) where
  wrap = Wrap

instance FreeLike f m => FreeLike f (Codensity m) where
  wrap t = Codensity $ \h -> wrap $ fmap (($ h) . runCodensity) t
  
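-- Our new magic function, eta-expanded because GHC's simplified subsumption
-- (GHC 9.0 and later) rejects the point-free `improve = abst` against this rank-2 type
improve :: Functor f => (forall m. FreeLike f m => m a) -> Free f a
improve m = abst m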

Now we can represent our Tree example above more generally: with the functor `data Pair x = Pair x x`, `Tree` corresponds to `Free Pair`, `CTree` to `Codensity (Free Pair)`, and `TreeLike` to `FreeLike Pair`.
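A standalone sketch of that correspondence, under the same assumed `Pair` functor. `fullTreeF`, `nodeF` and `leavesF` are hypothetical names: `fullTreeF` is `fullTree'` rewritten against `FreeLike Pair`, and `improve` runs it through `Codensity (Free Pair)`. The `Functor (Codensity m)` instance is written by hand here, and `FlexibleContexts` is enabled for the `FreeLike Pair m` constraints.

```haskell
{-# LANGUAGE RankNTypes, DeriveFunctor, MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts #-}
import Control.Monad (ap)

data Free f a = Return a | Wrap (f (Free f a)) deriving Functor

instance Functor f => Applicative (Free f) where
  pure = Return
  (<*>) = ap

instance Functor f => Monad (Free f) where
  Return a >>= f = f a
  Wrap fa >>= f = Wrap ((>>= f) <$> fa)

newtype Codensity m a = Codensity { runCodensity :: forall b. (a -> m b) -> m b }

-- No constraint on m is needed for any of these instances.
instance Functor (Codensity m) where
  fmap f (Codensity p) = Codensity (\h -> p (h . f))

instance Applicative (Codensity m) where
  pure a = Codensity (\h -> h a)
  (<*>) = ap

instance Monad (Codensity m) where
  Codensity p >>= f = Codensity (\h -> p (\a -> runCodensity (f a) h))

abst :: Monad m => Codensity m a -> m a
abst c = runCodensity c pure

class (Functor f, Monad m) => FreeLike f m where
  wrap :: f (m a) -> m a

instance Functor f => FreeLike f (Free f) where
  wrap = Wrap

instance FreeLike f m => FreeLike f (Codensity m) where
  wrap t = Codensity (\h -> wrap (fmap (\c -> runCodensity c h) t))

improve :: Functor f => (forall m. FreeLike f m => m a) -> Free f a
improve m = abst m

-- Hypothetical two-child functor: Free Pair corresponds to Tree.
data Pair x = Pair x x deriving Functor

-- Plays the role of `node` from TreeLike.
nodeF :: FreeLike Pair m => m a -> m a -> m a
nodeF l r = wrap (Pair l r)

-- fullTree' from the TreeLike section, written against FreeLike Pair.
fullTreeF :: FreeLike Pair m => Int -> m Int
fullTreeF 1 = pure 1
fullTreeF n = do
  i <- fullTreeF (n - 1)
  nodeF (pure (n - 1 - i)) (pure (i + 1))

-- Leaves from left to right, so the two runs can be compared.
leavesF :: Free Pair a -> [a]
leavesF (Return a) = [a]
leavesF (Wrap (Pair l r)) = leavesF l ++ leavesF r
```

`leavesF (improve (fullTreeF 4))` should list the leaves of the tree printed earlier, `[1,3,2,2,3,1,0,4]`, the same as running `fullTreeF 4` directly in `Free Pair`.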