# 相关依赖库
* opencv （实时渲染）
* numpy （矩阵与数组操作）
* random （生成随机数）
* time （排序耗时计时）

In [2]:
import cv2
import numpy as np
import time
import random

# 基础流程渲染类DataSeq

In [3]:
class DataSeq:
    """Integer sequence plus an OpenCV bar-chart view of it.

    Each value is drawn as a vertical bar whose height grows with the value.
    Mutations made through ``SetVal`` or ``Swap`` repaint the affected bars
    and show a new frame. A sorting routine that only touches the data via
    those two methods is therefore animated automatically.
    """

    # Colour tuples in OpenCV's BGR channel order.
    WHITE = (255,255,255)
    RED = (0,0,255)
    BLACK = (0,0,0)
    YELLOW = (0,127,255)

    def __init__(self, Length, time_interval=1, sort_title="Figure", repeatition=False):
        """Create ``Length`` values drawn from ``range(Length)`` and show them.

        Args:
            Length: number of elements (and bars).
            time_interval: delay passed to ``cv2.waitKey`` after every frame.
            sort_title: window name and caption prefix.
            repeatition: if true, sample with replacement (duplicates are
                possible); otherwise use a random permutation.
        """
        pool = list(range(Length))
        self.data = pool
        if not repeatition:
            self.Shuffle()
        else:
            self.data = random.choices(pool, k=Length)
        self.length = Length

        self.SetTimeInterval(time_interval)
        self.SetSortType(sort_title)
        self.Getfigure()
        self.InitTime()

        self.Visualize()

    # ---- stopwatch -------------------------------------------------------

    def InitTime(self):
        """Zero the elapsed time and leave the stopwatch stopped."""
        self.start = time.time()
        self.time = 0
        self.StopTimer()

    def StartTimer(self):
        """Start (or restart) measuring from now."""
        self.start_flag = True
        self.start = time.time()

    def StopTimer(self):
        """Freeze ``self.time`` at its last refreshed value."""
        self.start_flag = False

    def GetTime(self):
        """Refresh ``self.time`` while the stopwatch is running."""
        if not self.start_flag:
            return
        self.time = time.time() - self.start

    # ---- configuration ---------------------------------------------------

    def SetTimeInterval(self, time_interval):
        self.time_interval = time_interval

    def SetSortType(self, sort_title):
        self.sort_title = sort_title

    def Shuffle(self):
        """Randomly permute ``self.data`` in place."""
        random.shuffle(self.data)

    # ---- drawing ---------------------------------------------------------

    def Getfigure(self):
        """Render every bar onto a fresh white square canvas."""
        width = 5
        side = self.length * width
        canvas = np.full((side, side, 3), 255, dtype=np.uint8)
        for col in range(self.length):
            height = self.data[col]
            left = col * width
            canvas[-1 - height * width:, left:left + width] = self.GetColor(height, self.length)
        self._bar_width = width
        self.figure = canvas

    @staticmethod
    def GetColor(val, TOTAL):
        """Bar colour for ``val``: a blue-green gradient based on val/TOTAL."""
        shade = val * 255 // (2 * TOTAL)
        return (120 + shade, 255 - shade, 0)

    def _bar_columns(self, idx):
        """Column slice of the canvas occupied by bar ``idx``."""
        left = idx * self._bar_width
        return slice(left, left + self._bar_width)

    def _set_figure(self, idx, val):
        """Erase bar ``idx`` and redraw it with height/colour for ``val``."""
        cols = self._bar_columns(idx)
        self.figure[:, cols] = self.WHITE
        self.figure[-1 - val * self._bar_width:, cols] = self.GetColor(val, self.length)

    def SetColor(self, img, marks, color):
        """Paint the bars at indices ``marks`` on ``img`` with ``color``."""
        for idx in marks:
            top = -1 - self.data[idx] * self._bar_width
            img[top:, self._bar_columns(idx)] = color

    def Mark(self, img, marks, color):
        self.SetColor(img, marks, color)

    # ---- animated mutations ----------------------------------------------

    def SetVal(self, idx, val):
        """Store ``val`` at ``idx``, repaint that bar and show a frame."""
        self.data[idx] = val
        self._set_figure(idx, val)

        self.Visualize((idx,))

    def Swap(self, idx1, idx2):
        """Exchange two elements, repaint both bars and show a frame."""
        d = self.data
        d[idx1], d[idx2] = d[idx2], d[idx1]
        for pos in (idx1, idx2):
            self._set_figure(pos, d[pos])

        self.Visualize((idx1, idx2))

    def Visualize(self, mark1=None, mark2=None):
        """Show the current figure with ``mark1`` in red and ``mark2`` in YELLOW.

        ``mark2`` is painted first so ``mark1`` wins where both overlap.
        """
        frame = self.figure.copy()
        for marks, colour in ((mark2, self.YELLOW), (mark1, self.RED)):
            if marks:
                self.Mark(frame, marks, colour)
        if frame.shape[1] > 500:
            frame = cv2.resize(frame, (500, 500))

        self.GetTime()
        caption = self.sort_title + " Time:%02.2fs" % self.time
        cv2.putText(frame, caption, (20, 20), cv2.FONT_HERSHEY_PLAIN, 1, self.YELLOW, 1)
        cv2.imshow(self.sort_title, frame)
        cv2.waitKey(self.time_interval)

# 基本入口方法

In [6]:
def execute_sort(sortFuncName, sort_title=None, length=64, time_interval=1):
    """Animate ``sortFuncName`` on a fresh shuffled DataSeq and time it.

    Args:
        sortFuncName: callable that sorts a DataSeq in place.
        sort_title: window name / caption prefix; defaults to the sort
            function's ``__name__``.
        length: number of elements to sort.
        time_interval: value passed to ``cv2.waitKey`` after every frame.
    """
    if sort_title is None:
        sort_title = getattr(sortFuncName, "__name__", "Figure")
    # The constructor already shows the initial frame.
    ds = DataSeq(length, time_interval, sort_title)
    ds.StartTimer()
    sortFuncName(ds)
    ds.StopTimer()
    # waitKey(0) blocks until a key is pressed, keeping the final frame up.
    ds.SetTimeInterval(0)
    ds.Visualize()

# 排序

## 冒泡排序

In [19]:
def bubble_sort(ds):
    """Bubble sort ``ds`` in place, animating every adjacent swap.

    Each pass over ``data[0..end]`` carries the largest remaining value to
    index ``end``, so every subsequent pass scans one element fewer.
    """
    assert isinstance(ds, DataSeq), 'Type Error'
    for end in reversed(range(ds.length)):
        for k in range(end):
            left, right = ds.data[k], ds.data[k + 1]
            if left > right:
                ds.Swap(k, k + 1)

## 桶排序
依赖：
* copy

In [21]:
def bucket_sort(ds):
    """Counting ("bucket") sort of non-negative integers, in place.

    One counter is allocated per possible value (0..max), so this only pays
    off when the largest value is not much larger than the element count.
    Values are written back in ascending order through ``ds.SetVal``, so
    each write is animated. Negative values are not supported.
    """
    assert isinstance(ds, DataSeq), "Type Error"

    Length = ds.length
    if Length == 0:
        return
    values = ds.data[:Length]
    # Size buckets by the largest value rather than by Length, so values
    # >= Length do not raise IndexError.
    bucket = [0] * (max(values) + 1)
    for val in values:
        bucket[val] += 1

    j = 0
    for val, count in enumerate(bucket):
        for _ in range(count):
            ds.SetVal(j, val)
            j += 1

## 环排序

In [24]:
def cycle_sort(ds):
    """Sort ``ds`` in place; values must be integers in ``[0, N-1]``.

    Phase 1 follows permutation cycles. The value at ``i`` is swapped into
    its home index (the index equal to the value) until the value now at
    ``i`` has a home that already holds an equal value. For a true
    permutation this alone sorts the data. With a few duplicates, some slots
    remain out of place.

    Phase 2 finishes with an insertion sort, which is cheap because the
    data is nearly sorted after phase 1. The algorithm is not stable.
    """
    assert isinstance(ds, DataSeq), "Type Error"

    Length = ds.length
    # Phase 1: place every value at its home index where possible.
    for i in range(Length):
        target = ds.data[i]
        while ds.data[target] != target:
            ds.Swap(i, target)
            target = ds.data[i]

    # Phase 2: insertion sort to fix the few slots left by duplicates.
    for p in range(Length):
        tmp = ds.data[p]
        j = p
        while j >= 1 and ds.data[j - 1] > tmp:
            ds.SetVal(j, ds.data[j - 1])
            j -= 1
        ds.SetVal(j, tmp)

## 堆排序

In [26]:
def CheckMaxHeap(data, size, child):
    """Debug helper: report max-heap violations in a 1-based heap.

    ``data[1..size]`` holds the heap. The root (index 1) is compared against
    ``data[0]``, so ``data[0]`` must be a sentinel at least as large as
    every element. Nodes are visited in pre-order starting at ``child``.
    "error found" is printed for every node larger than its parent.
    """
    if child > size:
        return
    if data[child] > data[child // 2]:
        print("error found")
    for kid in (2 * child, 2 * child + 1):
        CheckMaxHeap(data, size, kid)


def heap_sort(ds, time_interval=1):
    """Heap sort ``ds`` in place into ascending order, animating each write.

    Phase 1 turns ``ds.data`` into a 0-based max-heap by sifting each
    element up. The children of ``k`` are ``2k+1`` and ``2k+2``. Phase 2
    repeatedly moves the root (the maximum) to the end of the shrinking
    heap, then sifts the displaced last element down from the root.

    ``time_interval`` is not used; it is kept for call compatibility.
    """
    assert isinstance(ds, DataSeq), "Type Error"

    Length = ds.length
    # Phase 1: build the max-heap by sift-up insertion.
    for i in range(1, Length):
        X = ds.data[i]
        child = i
        father = (child + 1) // 2 - 1
        # Test father >= 0 first. At the root, father is -1, and evaluating
        # ds.data[-1] would read the last element instead of stopping.
        while father >= 0 and X > ds.data[father]:
            ds.SetVal(child, ds.data[father])
            ds.Visualize((child, father))
            child = father
            father = (child + 1) // 2 - 1
        ds.SetVal(child, X)
        ds.Visualize((child,))

    # To verify phase 1 while debugging (1-based copy with a large sentinel):
    #     CheckMaxHeap([100000] + ds.data, ds.length, 1)

    # Phase 2: move the max into slot p; the heap shrinks to data[0:p].
    p = ds.length - 1
    while p > 0:
        maxval = ds.data[0]
        last = ds.data[p]
        father = 0
        child = (father + 1) * 2 - 1
        while child < p:
            # Pick the larger child; child + 1 is in the heap only if child != p - 1.
            if child != (p - 1) and ds.data[child] < ds.data[child + 1]:
                child += 1
            if ds.data[child] < last:
                break
            ds.SetVal(father, ds.data[child])
            ds.Visualize((father, child))
            father = child
            child = (father + 1) * 2 - 1
        ds.SetVal(father, last)
        ds.SetVal(p, maxval)
        ds.Visualize((p,))
        p -= 1

## 插入排序

In [28]:
def insertion_sort(ds, time_interval=1):
    """Insertion sort ``ds`` in place; ``time_interval`` is unused.

    Larger elements are shifted one slot to the right (each shift is
    animated) until the hole reaches the position of the current key. The
    key is then written there.
    """
    assert isinstance(ds, DataSeq), "Type Error"

    for pos in range(ds.length):
        key = ds.data[pos]
        hole = pos
        while hole > 0 and ds.data[hole - 1] > key:
            ds.SetVal(hole, ds.data[hole - 1])
            hole -= 1
        ds.SetVal(hole, key)

## 归并排序

In [29]:
def Merge(ds, L, R, RightEnd, time_interval):
    """Merge the sorted runs ``ds.data[L:R]`` and ``ds.data[R:RightEnd+1]``.

    The merged result is written back into ``ds.data[L:RightEnd+1]`` through
    ``ds.SetVal``, so each placement is animated. The merge is stable: on
    equal keys the element from the left run goes first.
    ``time_interval`` is unused.
    """
    # Snapshot only the span being merged, since the writes below overwrite
    # ds.data in place.
    snapshot = ds.data[L:RightEnd + 1]
    left, left_end = 0, R - L
    right, right_end = R - L, RightEnd - L + 1
    k = L
    while left < left_end and right < right_end:
        # Take from the right run only when strictly smaller (stability).
        if snapshot[right] < snapshot[left]:
            ds.SetVal(k, snapshot[right])
            right += 1
        else:
            ds.SetVal(k, snapshot[left])
            left += 1
        k += 1
    # At most one run still has elements; copy them over in order.
    for val in snapshot[left:left_end] + snapshot[right:right_end]:
        ds.SetVal(k, val)
        k += 1

def Sort(ds, L, RightEnd, time_interval):
    """Recursive top-down merge sort of ``ds.data[L:RightEnd+1]``.

    Sorts the left half, then the right half, then merges them.
    """
    if RightEnd <= L:
        return  # zero or one element: nothing to do
    mid = (L + RightEnd) // 2
    for lo, hi in ((L, mid), (mid + 1, RightEnd)):
        Sort(ds, lo, hi, time_interval)
    Merge(ds, L, mid + 1, RightEnd, time_interval)



def merge_sort(ds, time_interval=1):
    """Merge sort the whole of ``ds`` in place (writes are animated)."""
    assert isinstance(ds, DataSeq), "Type Error"

    last_index = ds.length - 1
    Sort(ds, 0, last_index, time_interval)

## 快排

In [30]:
def GetPivot(ds, Left, Right):
    """Median-of-three pivot selection used by ``Qsort``.

    Swaps until ``data[Left] <= data[Mid] <= data[Right]``, then parks the
    median at ``Right - 1`` and returns its value. ``data[Left]`` (<= pivot)
    and ``data[Right - 1]`` (== pivot) then act as sentinels that stop the
    partition scans.
    """
    Mid = (Left + Right) // 2
    for a, b in ((Left, Right), (Left, Mid), (Mid, Right)):
        if ds.data[a] > ds.data[b]:
            ds.Swap(a, b)
    ds.Swap(Mid, Right - 1)
    return ds.data[Right - 1]

def Qsort(ds, Left, Right):
    """Quicksort ``ds.data[Left:Right+1]`` in place.

    Ranges with ``Right - Left >= Cutoff`` are partitioned around a
    median-of-three pivot (see ``GetPivot``) and recursed on. Smaller ranges
    are finished with an insertion sort restricted to ``[Left, Right]``.
    """
    Cutoff = 10
    if Cutoff <= Right - Left:
        Pivot = GetPivot(ds, Left, Right)
        # data[Left] <= Pivot and data[Right-1] == Pivot act as sentinels,
        # so the scans below cannot run off the range.
        low = Left + 1
        high = Right - 2
        while True:
            while ds.data[low] < Pivot:
                low += 1
            while ds.data[high] > Pivot:
                high -= 1
            if low < high:
                ds.Swap(low, high)
                low += 1
                high -= 1
            else:
                break
        # Move the pivot from Right-1 into its final slot.
        ds.Swap(low, Right - 1)
        Qsort(ds, Left, low - 1)
        Qsort(ds, low + 1, Right)

    else:
        # Too few elements to be worth partitioning: insertion sort.
        # Bound by Left (not index 0) so nothing outside the range is moved.
        for p in range(Left, Right + 1):
            tmp = ds.data[p]
            i = p
            while i > Left and ds.data[i - 1] > tmp:
                ds.SetVal(i, ds.data[i - 1])
                i -= 1
            ds.SetVal(i, tmp)

def quick_sort(ds):
    """Quicksort all of ``ds`` in place (median-of-three, insertion cutoff)."""
    assert isinstance(ds, DataSeq), "Type Error"

    last_index = ds.length - 1
    Qsort(ds, 0, last_index)

## 选择排序

In [31]:
def selection_sort(ds):
    """Selection sort ``ds`` in place, with at most one swap per position.

    For each position ``i``, the index of the smallest value in
    ``data[i:]`` is found first. That value is then swapped into ``i``
    (skipped if it is already there), so at most N-1 swaps are animated.
    """
    assert isinstance(ds, DataSeq), "Type Error"

    Length = ds.length
    for i in range(Length):
        smallest = i
        for j in range(i + 1, Length):
            if ds.data[j] < ds.data[smallest]:
                smallest = j
        if smallest != i:
            ds.Swap(i, smallest)

## 希尔排序

In [33]:
def shell_sort(ds):
    """Shell sort ``ds`` in place using the gap sequence N//2, N//4, ..., 1.

    For each gap ``D``, every element from index ``D`` onward is
    insertion-sorted into its gap-``D`` chain. All ``D`` interleaved chains
    are therefore sorted before the gap shrinks, and the final ``D == 1``
    pass is a plain insertion sort.
    """
    assert isinstance(ds, DataSeq), "Type Error"

    Length = ds.length
    D = Length // 2
    while D > 0:
        for i in range(D, Length):
            tmp = ds.data[i]
            j = i
            # j >= D keeps j - D non-negative (a negative index would
            # silently wrap around to the end of the list).
            while j >= D and ds.data[j - D] > tmp:
                ds.SetVal(j, ds.data[j - D])
                j -= D
            ds.SetVal(j, tmp)
        D //= 2

In [36]:
# Demo: animate quick sort; the final frame stays up until a key is pressed.
execute_sort(sortFuncName=quick_sort)