Elton Lin edited this page Sep 11, 2019 · 4 revisions

Recursion

A lot of people have trouble getting into the recursive mindset. If this describes you, don't worry; this is quite normal. In this tutorial, we will start by exploring recursive algorithms on linked lists and other similar structures, and then move on to more complex structures, like trees.

What is Recursion?

We call a function recursive when it calls itself. For example, the following C++ function is recursive:

int f(int x) {
    if (x == 0) {
        return 0;
    }
    return 2 + f(x-1);
}

Since I'm a weirdo who likes Haskell, here is an equivalent version in Haskell:

f :: Int -> Int
f 0 = 0
f x = 2 + f (x - 1)

Can you figure out what this function does? (It's quite useless.) Let's work through some examples:

f(0) == 0
f(1) == 2 + f(0) == 2
f(2) == 2 + f(1) == 2 + (2 + f(0)) == 4
f(3) == 2 + f(2) == 2 + (2 + f(1)) == 2 + (2 + (2 + f(0))) == 6

Can you spot a pattern? This is just a really stupid way to double the input. Now, for fun, let's prove our algorithm correct using induction.

We have already proven the base case, f(0) == 0. Now, let's assume that f(x) == 2*x for all 0 <= x < k. All we have to do now is prove that f(k) == 2*k, where k > 0 because 0 is the base case.

f(k) == 2 + f(k - 1) == 2 + 2*(k-1) == 2 + 2*k - 2 == 2*k

Here, we've used the fact that 0 <= k-1 < k to apply the inductive hypothesis. We've now proven our first recursive algorithm correct; that's pretty neat!
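The proof only covers x >= 0. For a negative x, each call moves further away from the base case, so the recursion never stops on its own (in practice the program typically crashes once the call stack runs out). Here is a minimal C++ sketch that makes the precondition explicit; the name f_checked and the choice to throw std::invalid_argument are illustrative assumptions, not part of the original:

```cpp
#include <stdexcept>

// Doubles x by recursion, like f above, but rejects inputs the proof does not cover.
int f_checked(int x) {
    if (x < 0) {
        // Without this guard, x-1 would move away from the base case forever.
        throw std::invalid_argument("f_checked: x must be non-negative");
    }
    if (x == 0) {
        return 0;  // base case
    }
    return 2 + f_checked(x - 1);  // inductive case: x-1 is closer to 0
}
```

With the guard in place, f_checked(3) still returns 6, while f_checked(-1) fails loudly instead of recursing without end.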

Why Recursion

You might be thinking: why do we even want to use recursion, when we have perfectly good constructs like for and while loops? First, once you get the hang of it, recursive algorithms can be much simpler to reason about than their iterative counterparts. Second, a lot of interview questions involve manipulating recursive data structures such as linked lists and trees, so it's good to know how they work! Third, some languages don't have loops at all! For example, Haskell has no loops, only recursion, and you can implement loops using recursion. Last but not least, the second 280 project is all about recursion, and is perhaps the most fun of them all.
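To make the claim that loops can be implemented with recursion concrete, here is a small C++ sketch. The helper loop is hypothetical (not a standard function); it runs body(i) for each i from start up to end - 1, just like a for loop, using a base case for "no iterations left" and a recursive call for the remaining iterations:

```cpp
#include <functional>

// Hypothetical helper: behaves like
//     for (int i = start; i < end; ++i) body(i);
// but uses recursion instead of a loop construct.
void loop(int start, int end, const std::function<void(int)>& body) {
    if (start >= end) {
        return;  // base case: no iterations left
    }
    body(start);                 // run one iteration
    loop(start + 1, end, body);  // recursive case: the remaining iterations
}
```

For example, calling loop(0, 5, ...) with a body that adds i to a running total leaves that total at 0 + 1 + 2 + 3 + 4 = 10. Each iteration adds a stack frame, and C++ does not guarantee tail-call elimination, so this is an illustration rather than a replacement for long loops.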

The Linked List

The linked list is one of the more important data structures to grasp conceptually, even if it isn't used much in practice because of its poor cache locality. The linked list is actually a recursive data structure! What is a recursive data structure, you ask? It is a data structure whose definition refers to itself: a linked list is either empty, or an element followed by another linked list. This recursive nature is best seen in the Haskell definition of a linked list:

data [a] = a:[a] | []

What this code says is that a list is either a single element (of type a) followed by another list, or empty. (This is a conceptual definition: list syntax is built into Haskell, so this exact line isn't something you can write in your own program.) In case you don't know Haskell (probably most of you), here is an approximately equivalent C++ construct. I will only be talking about lists of integers in this tutorial, but the exact same algorithms apply to lists of any type.

struct list {
    int data;
    list* next = nullptr;
};

In the C++ version, nullptr represents an empty list, and a list structure is the recursive part, containing a single element and a pointer to the rest of the list. How could we construct some simple lists?

empty = []
one = 1:[]
twoOnes = 1:1:[]

countDown :: Int -> [Int]
countDown 0 = 0:[]
countDown x = x:(countDown (x-1))

countUp :: Int -> [Int]
countUp x = x:(countUp (x+1))

naturals = countUp 0

Above are several lists defined in Haskell; some of them are defined recursively. Side note: Haskell is lazily evaluated, which means it doesn't evaluate an expression until its result is needed, so it doesn't choke on infinite lists such as naturals.

I will define some of the above lists in C++, excluding the infinite ones for obvious reasons, and I will be lazy about copying: assume that copy makes a deep copy of a linked list. We will define it later.

// These statements (up to countDown) belong inside a function body, such as main().
list* empty = nullptr;

list* one = new list;
one->data = 1;
one->next = nullptr;

list* twoOnes = new list;
twoOnes->data = 1;
twoOnes->next = copy(one);

list* countDown(int x) {
    if (x < 0) {
        return nullptr;
    }
    list* result = new list;
    result->data = x;
    result->next = countDown(x-1);
    return result;
}
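Functions on a list usually mirror the two cases of its definition: one case for the empty list and one for an element followed by a list. As a sketch, here are two hypothetical helpers, length and sum (they are not from the original page), exercised on the countDown function above, which is repeated so the block stands alone:

```cpp
struct list {
    int data;
    list* next = nullptr;
};

// Same countDown as above: builds x, x-1, ..., 0.
list* countDown(int x) {
    if (x < 0) {
        return nullptr;
    }
    list* result = new list;
    result->data = x;
    result->next = countDown(x - 1);
    return result;
}

// Hypothetical helper: empty list -> 0; element followed by a list -> 1 + its length.
int length(const list* head) {
    if (head == nullptr) {
        return 0;
    }
    return 1 + length(head->next);
}

// Hypothetical helper: empty list -> 0; otherwise first element + sum of the rest.
int sum(const list* head) {
    if (head == nullptr) {
        return 0;
    }
    return head->data + sum(head->next);
}
```

For example, countDown(3) builds 3, 2, 1, 0, so length returns 4 and sum returns 6.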

Copying a Linked List

Now we will consider a simple recursive algorithm for making a deep copy of a linked list. In Haskell, copying is trivial (in fact, you rarely need to do it) because all data structures are immutable, which means that you can never modify a data structure, only create new ones! Memory efficiency is obtained by sharing unmodified parts of data structures, which is safe because of immutability.
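C++ can imitate this sharing if we give up mutation. The rough sketch below is not the immutable list promised later on this page; it assumes nodes are reached only through std::shared_ptr<const node>, and the names node, ilist, and cons are hypothetical. Two lists built with cons can share one tail without copying it, which is only safe because nothing can modify the shared nodes:

```cpp
#include <memory>
#include <utility>

// Hypothetical node type; it is only ever reached through a pointer-to-const.
struct node {
    int data;
    std::shared_ptr<const node> next;  // nullptr represents the empty list
};

using ilist = std::shared_ptr<const node>;

// Like Haskell's (:): a new head element in front of an existing (shared) tail.
ilist cons(int x, ilist tail) {
    return std::make_shared<node>(node{x, std::move(tail)});
}
```

For example, with one = cons(1, nullptr), a = cons(2, one) and b = cons(3, one), both a and b point at the same node holding 1, and one.use_count() is 3. With the mutable list struct from earlier, the same sharing would be risky, since changing one list would silently change the other.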

However, in C++ we need to worry about making deep copies because we are allowed to modify any part of the data structure willy-nilly. I will try to write up an implementation of an immutable linked list sometime in the future. Now, onto copying: we could use a boring iterative algorithm to copy our linked list, but that's no fun at all! For your reference, here is a boring iterative copy:

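list* copy(list* head) {
    list* result = nullptr;  // head of the copy
    list* tail = nullptr;    // last node of the copy so far
    while (head != nullptr) {
        list* tmp = new list;
        tmp->data = head->data;
        // tmp->next is already nullptr thanks to the default member initializer.

        if (tail == nullptr) {
            result = tmp;      // first node becomes the head of the copy
        } else {
            tail->next = tmp;  // append after the current last node
        }
        tail = tmp;
        head = head->next;
    }
    return result;
}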

Boring! Before you read on, you should consider how you would implement this using a recursive algorithm. What is the base case? What is the recursive case?

Now, we will convert this dull while loop into a recursive function. Notice how we always check whether we are looking at a nullptr, the end of the list. This is our base case, which makes sense because it is also the base case in the definition of a list. In the recursive (or inductive) case, we copy the first element and then go on to copy the rest of the list. Putting it all together, we get:

list* copy(list* head) {
    if (head == nullptr) {
        return nullptr;
    }

    list* result = new list;
    result->data = head->data;
    result->next = copy(head->next);
    return result;
}

See how much easier that is to parse! If the list isn't empty, we copy the first element, then recursively copy the rest of the list. If it is empty, we just return an empty list. Easy-peasy!
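Every node the copy creates comes from new, so someone eventually has to free it. Here is a minimal sketch, assuming a hypothetical destroy helper that is not part of the original page; it frees a list with the same base case and recursive case as copy:

```cpp
struct list {
    int data;
    list* next = nullptr;
};

// Recursive deep copy, as above.
list* copy(list* head) {
    if (head == nullptr) {
        return nullptr;
    }
    list* result = new list;
    result->data = head->data;
    result->next = copy(head->next);
    return result;
}

// Hypothetical helper: frees every node of the list.
void destroy(list* head) {
    if (head == nullptr) {
        return;  // base case: nothing to free
    }
    destroy(head->next);  // free the rest of the list first...
    delete head;          // ...then this node
}
```

Because the copy is deep, changing it leaves the original alone: if a holds 1, 2 and b = copy(a), setting b->next->data = 42 leaves a->next->data equal to 2, and destroy(a) and destroy(b) free the two lists independently.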

Reversing a Linked List

I dare you to write an iterative algorithm to reverse a linked list; it is a mess! Not only is it complicated, it is really hard to reason about what it does, and don't get me started on the edge cases.

Now, we need to think about reversing a linked list recursively. What does reversing a list really do? It turns out to be a recursive operation!

Where does the first element of the list go? Easy: to the end of the list. Where does the second element go? To the second-to-last position. And where does the last element go (lastly, or should that be firstly)? To the beginning of the list.

I see a pattern here, so let's work out the cases. The base case is easy: an empty list reversed is still an empty list. A one-element list is an element followed by an empty list; reversed, it is an empty list followed by that element, which is the same thing. We could go on, but I will leave that as an exercise for the reader. Can you think of the general recursive algorithm? You should really try to think of it before you read the next paragraph.

You got it! You reverse the rest of the list, and then put the first element at the end. Now let's see a concrete implementation, first in Haskell:

import Prelude hiding (reverse)  -- avoid clashing with the Prelude's own reverse

reverse :: [a] -> [a]
reverse []     = []
reverse (x:xs) = reverse xs ++ [x]

Now, the same thing in C++:

list* reverse(list* head) {
    // Empty list
    if (head == nullptr) {
        return nullptr;
    }
    // List has only one node. This case matters because below we use head->next->next, which requires head->next != nullptr.
    if (head->next == nullptr) {
        return head;
    }
    list* result = reverse(head->next); // Reverse the rest of the list.

    // Put the first element at the end of the rest of the list.
    head->next->next = head;

    // Set the end of the linked list to nullptr.
    // Remember that head now points to the last element in the list!
    head->next = nullptr;

    return result;
}

The C++ version is more complicated than the Haskell version because we have to deal with raw pointers: instead of building a new list, it relinks the existing nodes in place.

Proof

We will prove the correctness of our reversal algorithm through induction. If you don't understand the algorithm, I highly recommend you try to follow the proof, because you might find it helpful in understanding the algorithm.

First, the base case is the empty list, and our algorithm returns the empty list, which is correct. Next, we assume that our algorithm works correctly for any list of length less than k. Finally, take a list x:xs of length exactly k, where k > 0; we will prove our procedure is correct for it. Here x:xs denotes the element x followed by the list xs, and we write appending two lists as ys ++ zs.

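reverse(x:xs) == reverse(xs) ++ (x:[])

The list xs has length k - 1 < k, so by the inductive hypothesis reverse(xs) is correct: it contains the elements of xs in reverse order. Appending x after them puts the first element of x:xs last, which is exactly the reverse of x:xs. So the algorithm itself is pretty trivially correct; it's just C++ that makes it ugly.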
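Another way to write reverse, often called accumulator passing, carries the already-reversed part of the list along as an extra argument. The sketch below is an alternative to the version above, with hypothetical names reverse_onto and reverse_acc. It relinks the existing nodes in place and needs no special case for one-node lists; in Haskell, the same idea avoids the repeated ++, which makes the version above take quadratic time because each ++ walks its whole left argument.

```cpp
struct list {
    int data;
    list* next = nullptr;
};

// Hypothetical helper: moves the nodes of head, one at a time, onto the
// front of acc, which holds the already-reversed prefix.
list* reverse_onto(list* head, list* acc) {
    if (head == nullptr) {
        return acc;  // nothing left to move: acc is the whole reversed list
    }
    list* rest = head->next;
    head->next = acc;                 // put this node in front of the reversed part
    return reverse_onto(rest, head);  // continue with the rest of the list
}

// Hypothetical entry point: start with an empty accumulator.
list* reverse_acc(list* head) {
    return reverse_onto(head, nullptr);
}
```

For example, reversing 1, 2, 3 calls reverse_onto with accumulators [], [1], [2, 1] and finally returns [3, 2, 1].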

Remove Duplicates

The next task is to remove duplicates from a sorted linked list. This is a classic interview question, so you should show off your recursive prowess by not using the puny iterative version!

What are some base cases? Well, a list of one or fewer elements cannot possibly have duplicates! What is the recursive case? If the first two elements are the same, delete the first element; if they aren't, keep it. Then remove duplicates from the rest of the list, starting at the second element. See how easy that is? Now compare that to the ugly iterative version:

void remove_duplicates(list* head) {
    if (head == nullptr) {
        return;
    }

    while (head->next != nullptr) {
        if (head->data == head->next->data) {
            list* tmp = head->next;         // the duplicate node
            head->next = head->next->next;  // unlink it
            delete tmp;                     // then free it
        } else {
            head = head->next;
        }
    }
}

How long does it take you to convince yourself that the edge cases work? Does it work for an empty list and a one-element list? How about two or more elements?

Now, compare that to the recursive version we described above, both in Haskell and C++:

nub :: [Int] -> [Int]
nub []      = []
nub (x:[])  = x:[]
nub (x:y:zs)
    | x == y    = nub (y:zs)
    | otherwise = x : nub (y:zs)

list* nub(list* head) {
    if (head == nullptr || head->next == nullptr) {
        return head;
    }

    if (head->data == head->next->data) {
        list* rest = head->next;
        delete head;
        return nub(rest);
    } else {
        head->next = nub(head->next);
        return head;
    }
}

Proof

In this algorithm we have two base cases: when the list is empty and when it contains one element. In both cases our algorithm is just the identity (it returns its argument unchanged). Our inductive hypothesis is that our algorithm works for any list of length less than k, where k > 1. Now we need to prove the inductive case, that it works for a list x:y:zs of length k. Here x:y:zs is a list whose first two elements are x and y respectively, and the rest of the list is zs. There are two cases: x == y and x != y.

If x == y, then our algorithm says

nub(x:y:zs) == nub(y:zs)

This is correct: y:zs has length k - 1 < k, so by our inductive hypothesis nub(y:zs) contains no duplicates and contains a y, and because y == x it contains exactly one instance of x. Dropping the first x loses nothing, since y carries the same value.

If x != y, then our algorithm says

nub(x:y:zs) == x : nub(y:zs)

This is correct, because nub(y:zs) contains no duplicates (by our inductive hypothesis) and no copy of x: the list is sorted and x != y, so x cannot appear anywhere in y:zs. Therefore the result contains exactly one copy of x, and exactly one copy of everything else in y:zs.
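To try nub on concrete input, it helps to convert between lists and std::vector. Here is a sketch with two hypothetical helpers, from_vector and to_vector (not from the original page), plus the sorted-list nub from above:

```cpp
#include <cstddef>
#include <vector>

struct list {
    int data;
    list* next = nullptr;
};

// nub for sorted lists, as above.
list* nub(list* head) {
    if (head == nullptr || head->next == nullptr) {
        return head;
    }
    if (head->data == head->next->data) {
        list* rest = head->next;
        delete head;
        return nub(rest);
    }
    head->next = nub(head->next);
    return head;
}

// Hypothetical helper: builds a list from v[i], v[i+1], ..., recursively.
list* from_vector(const std::vector<int>& v, std::size_t i = 0) {
    if (i == v.size()) {
        return nullptr;
    }
    list* node = new list;
    node->data = v[i];
    node->next = from_vector(v, i + 1);
    return node;
}

// Hypothetical helper: appends the elements of head to out, front to back.
void to_vector(const list* head, std::vector<int>& out) {
    if (head == nullptr) {
        return;
    }
    out.push_back(head->data);
    to_vector(head->next, out);
}
```

For example, the sorted input 1, 1, 2, 3, 3, 3, 5 comes back as 1, 2, 3, 5.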

Removing Duplicates Again

We want to generalize our solution to work for unsorted lists. We could always just sort the list first, but that is no fun (and slower)! Instead, we will thread some state through our recursive function.

The function looks the same from the outside, but this time it calls a helper function that keeps track of the elements we've seen.

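#include <unordered_set>

// Forward declaration so nub can call the helper defined below.
list* nub_helper(list* head, std::unordered_set<int>& seen);

list* nub(list* head) {
    std::unordered_set<int> seen;  // values kept so far
    return nub_helper(head, seen);
}

list* nub_helper(list* head, std::unordered_set<int>& seen) {
    if (head == nullptr) {
        return nullptr;
    }
    if (seen.count(head->data)) {
        // Duplicate: free this node and continue with the rest of the list.
        list* rest = head->next;
        delete head;
        return nub_helper(rest, seen);
    } else {
        seen.insert(head->data);
        head->next = nub_helper(head->next, seen);
        return head;
    }
}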

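The helper passes the set seen by reference through every recursive call, so all calls share a single record of the values kept so far. The first occurrence of each value is kept and every later occurrence is deleted. Hash set lookups and insertions take expected constant time, so the whole pass takes expected O(n) time and O(n) extra space, compared with O(n log n) time for sorting first.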
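If you would rather not carry a set around, there is also a stateless recursive version, sketched below with hypothetical names remove_all and nub_unsorted. It follows the same idea as the classic Haskell definition nub (x:xs) = x : nub (filter (/= x) xs), and it trades the set for time: each kept element scans the rest of the list, so it takes O(n²) time instead of expected O(n).

```cpp
struct list {
    int data;
    list* next = nullptr;
};

// Hypothetical helper: deletes every node whose value equals x and
// returns what is left of the list.
list* remove_all(list* head, int x) {
    if (head == nullptr) {
        return nullptr;
    }
    if (head->data == x) {
        list* rest = head->next;
        delete head;
        return remove_all(rest, x);
    }
    head->next = remove_all(head->next, x);
    return head;
}

// Keeps the first occurrence of each value: keep head, strip its value from
// the rest of the list, then remove duplicates from what remains.
list* nub_unsorted(list* head) {
    if (head == nullptr) {
        return nullptr;
    }
    head->next = nub_unsorted(remove_all(head->next, head->data));
    return head;
}
```

For example, 3, 1, 3, 2, 1 becomes 3, 1, 2. Like the set-based version, it keeps the first occurrence of each value and deletes the later ones.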

TODO

  • Add more algorithms with linked lists.
  • Add algorithms for trees.
  • Add other assorted algorithms, like palindrome.
  • Check for typos
  • Actually compile the code; it has only been proven correct, not tested.