In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.colors import to_rgb
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML

In [2]:
def bubble(array):
    """Bubble sort ``array`` in place, yielding one animation frame per step.

    Each frame is ``(picks, snapshot)``: ``picks`` is the pair of indices being
    compared (or just swapped) and ``snapshot`` is a copy of the list at that
    moment. The generator finishes with ``([], snapshot)`` for the sorted list.
    """
    # Everything after `unsorted_end` is already in its final position.
    unsorted_end = len(array) - 1
    while unsorted_end > 0:
        for a in range(unsorted_end):
            b = a + 1
            # Frame showing the comparison itself.
            yield ([a, b], array.copy())
            if array[a] > array[b]:
                array[a], array[b] = array[b], array[a]
                # Frame showing the result of the swap.
                yield ([a, b], array.copy())
        unsorted_end -= 1
    yield ([], array.copy())

In [3]:
def insertion(array):
    """Insertion sort ``array`` in place, yielding one animation frame per step.

    Each frame is ``(picks, snapshot)``. While a value is being shifted right,
    the snapshot can briefly contain a duplicated value. The generator finishes
    with ``([], snapshot)`` for the sorted list.
    """
    for pos in range(1, len(array)):
        current = array[pos]
        # `hole` is the slot the current value will eventually drop into.
        hole = pos
        yield ([pos, hole - 1], array.copy())
        while hole > 0 and current < array[hole - 1]:
            # Shift the larger neighbour one slot to the right.
            array[hole] = array[hole - 1]
            yield ([hole - 1, hole], array.copy())
            hole -= 1
        array[hole] = current
        yield ([pos, hole], array.copy())
    yield ([], array.copy())

In [4]:
def quick(array, low=0, high=None, outer=True):
    """Quicksort ``array[low:high + 1]`` in place (Lomuto partition), yielding frames.

    Each frame is ``(picks, snapshot)`` where ``picks`` are indices to highlight
    and ``snapshot`` is a copy of the whole list. Every index in ``picks`` lies
    inside ``[low, high]``, so it can be used directly to index the strips.

    Parameters
    ----------
    array : list
        Sorted in place.
    low, high : int
        Inclusive bounds of the sub-range to sort; ``high=None`` means the end.
    outer : bool
        Only the top-level call emits the closing ``([], snapshot)`` frame;
        recursive calls pass ``outer=False``.
    """
    if high is None:
        high = len(array) - 1
    if low < high:
        pivot = array[high]
        # `i` is the last index of the "<= pivot" region; low - 1 means it is empty.
        i = low - 1
        for j in range(low, high):
            # Don't highlight the sentinel i == low - 1: for low == 0 it is -1,
            # which would wrap to the last strip via negative indexing.
            picks = [i, j, high] if i >= low else [j, high]
            yield (picks, array.copy())
            if array[j] <= pivot:
                i += 1
                array[i], array[j] = array[j], array[i]
                yield ([i, j, high], array.copy())
        pivot_index = i + 1
        # Move the pivot between the two partitions.
        array[pivot_index], array[high] = array[high], array[pivot_index]
        yield ([pivot_index, high], array.copy())
        yield from quick(array, low, pivot_index - 1, outer=False)
        yield from quick(array, pivot_index + 1, high, outer=False)
    if outer:
        yield ([], array.copy())

In [5]:
def cocktail(array):
    """Cocktail-shaker sort ``array`` in place, yielding animation frames.

    Alternates a left-to-right pass and a right-to-left pass over the shrinking
    unsorted window ``[lo, hi]``, stopping as soon as a pass makes no swap.
    Each frame is ``(picks, snapshot)``; the last frame is ``([], snapshot)``.
    """
    lo, hi = 0, len(array) - 1
    while True:
        moved = False
        # Forward pass: bubble the largest remaining value up to `hi`.
        for a in range(lo, hi):
            yield ([a, a + 1], array.copy())
            if array[a] > array[a + 1]:
                array[a], array[a + 1] = array[a + 1], array[a]
                yield ([a, a + 1], array.copy())
                moved = True
        if not moved:
            break
        hi -= 1

        moved = False
        # Backward pass: sink the smallest remaining value down to `lo`.
        for a in reversed(range(lo, hi)):
            yield ([a, a + 1], array.copy())
            if array[a] > array[a + 1]:
                array[a], array[a + 1] = array[a + 1], array[a]
                yield ([a, a + 1], array.copy())
                moved = True
        lo += 1
        if not moved:
            break
    yield ([], array.copy())

In [6]:
def merge(array, start=0, end=None, outer=True):
    """Top-down merge sort of ``array[start:end]`` in place, yielding frames.

    Each frame is ``(picks, snapshot)`` where ``picks`` holds the single index
    just written and ``snapshot`` is a copy of the whole list. During a merge
    the snapshot can briefly contain duplicated values.

    The top-level call (``outer=True``) always finishes with exactly one
    ``([], snapshot)`` frame, including for empty or single-element input. This
    matches the other sort generators in this notebook. Recursive calls pass
    ``outer=False`` so they do not emit premature "finished" frames.
    """
    if end is None:
        end = len(array)
    if end - start > 1:
        mid = (start + end) // 2
        yield from merge(array, start, mid, outer=False)
        yield from merge(array, mid, end, outer=False)

        # Copies of both sorted halves; `array[start:end]` is overwritten below.
        left = array[start:mid]
        right = array[mid:end]
        i = j = 0
        k = start
        while i < len(left) and j < len(right):
            # `<=` keeps the sort stable: ties are taken from the left half.
            if left[i] <= right[j]:
                array[k] = left[i]
                i += 1
            else:
                array[k] = right[j]
                j += 1
            yield ([k], array.copy())
            k += 1
        # At most one half still has elements; copy its tail across.
        for value in left[i:] + right[j:]:
            array[k] = value
            yield ([k], array.copy())
            k += 1
    if outer:
        yield ([], array.copy())

In [7]:
def highlight(colour, factor):
    """Return ``colour`` blended toward white as an ``(r, g, b)`` tuple.

    ``factor`` is the blend weight: 0 returns the colour's RGB unchanged and
    1 returns pure white. Any alpha channel in ``colour`` is dropped by
    ``to_rgb``.
    """
    base = to_rgb(colour)
    return tuple(channel + (1 - channel) * factor for channel in base)

In [8]:
def anim_sort(sort_func, num_strips, fps=10, save=False, seed=None):
    """Animate ``sort_func`` sorting ``num_strips`` shuffled, colour-coded strips.

    Parameters
    ----------
    sort_func : callable
        Generator function that takes a list of heights, sorts it in place and
        yields ``(picks, state)`` frames. ``picks`` are indices to highlight and
        ``state`` is a full snapshot of the heights.
    num_strips : int
        Number of strips; heights are ``2, 4, ..., 2 * num_strips``.
    fps : int or float
        Playback speed in frames per second. It is used for both the inline
        animation and the saved GIF. Must be positive.
    save : bool
        If True, also write ``{sort_func.__name__}_{num_strips}.gif`` to the
        current working directory.
    seed : int or None
        Seed for the initial shuffle. ``None`` (default) uses the global
        ``random`` module state, as before.

    Returns
    -------
    IPython.display.HTML
        The animation rendered as an interactive JavaScript player.

    Raises
    ------
    ValueError
        If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError("fps must be positive")

    # NOTE: this changes the matplotlib style globally, not just for this figure.
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(9, 6))
    ax.set_xlim(0, 3 * num_strips)
    ax.set_ylim(0, 2 * num_strips)
    ax.axis('off')

    # Colour is keyed by height, so each strip keeps its colour as it moves.
    # The shortest strip maps to rainbow(1.0) and the tallest to rainbow(0.0).
    colours = plt.cm.rainbow(np.linspace(1, 0, num_strips))
    heights = [2 * i + 2 for i in range(num_strips)]
    colour_dict = dict(zip(heights, colours))

    rng = random if seed is None else random.Random(seed)
    rng.shuffle(heights)

    strips = []
    for i, h in enumerate(heights):
        strip = Rectangle((3 * i, 0), 3, h, color=colour_dict[h])
        ax.add_patch(strip)
        strips.append(strip)

    # sort_func sorts `heights` in place; materialise every frame up front.
    steps = list(sort_func(heights))

    def update(frame):
        """Redraw all strips for one ``(picks, state)`` frame; return artists for blitting."""
        picks, state = frame
        for strip, h in zip(strips, state):
            strip.set_height(h)
            strip.set_color(colour_dict[h])
        # Lighten the strips currently being compared/moved.
        for pick in picks:
            strips[pick].set_color(highlight(colour_dict[state[pick]], 0.8))
        return strips

    anim = FuncAnimation(
        fig,
        update,
        frames=steps,
        interval=1000 / fps,  # milliseconds between frames
        blit=True
    )

    if save:
        anim.save(
            f'{sort_func.__name__}_{num_strips}.gif',
            writer=PillowWriter(fps=fps)
        )

    # Close the figure so Jupyter doesn't also display the static first frame.
    plt.close(fig)

    return HTML(anim.to_jshtml())

In [9]:
# Cocktail-shaker sort on 10 strips, displayed inline (no GIF written).
anim_sort(sort_func=cocktail, num_strips=10, save=False)

In [10]:
# Merge sort on 20 strips, displayed inline (no GIF written).
anim_sort(sort_func=merge, num_strips=20, save=False)