# 什么是单元测试

In [1]:
import unittest

In [4]:
# 将要被测试的排序函数
def sort(arr):
    l = len(arr)
    for i in range(0, l):
        for j in range(i + 1, l):
            if arr[i] >= arr[j]:
                tmp = arr[i]
                arr[i] = arr[j]
                arr[j] = tmp

In [5]:
# 编写子类继承unittest.TestCase
class TestSort(unittest.TestCase):
    
    # 以test开头的函数会被测试
    def test_sort(self):
        arr = [3, 4, 1, 5, 6]
        sort(arr)
        # assert 结果根我们期待的一样
        self.assertEqual(arr, [1, 3, 4, 5, 6])

In [6]:
if __name__ == '__main__':
    ## 如果在Jupyter下，请用如下方式运行单元测试
    unittest.main(argv=['first-arg-is-ignored'], exit=False)
    ## 如果是命令行下运行，则：
    ## unittest.main()

.
----------------------------------------------------------------------
Ran 1 test in 0.012s

OK


# 单元测试的几个技巧

## mock

In [7]:
import unittest
from unittest.mock import MagicMock

class A(unittest.TestCase):
    def m1(self):
        val = self.m2()
        self.m3(val)
        
    def m2(self):
        pass
    
    def m3(self, val):
        pass
    
    def test_m1(self):
        a = A()
        a.m2 = MagicMock(return_value='custom_val')
        a.m3 = MagicMock()
        a.m1()
        self.assertTrue(a.m2.called) # 验证m2被call过
        a.m3.assert_called_with("custom_val") # 验证m3被指定参数call过
        
        
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

..
----------------------------------------------------------------------
Ran 2 tests in 0.004s

OK


## Mock Side Effect

In [8]:
from unittest.mock import MagicMock

def side_effect(arg):
    if arg < 0:
        return 1
    else:
        return 2
    

mock = MagicMock()
mock.side_effect = side_effect

In [9]:
mock(-1)

1

In [10]:
mock(1)

2

## patch

In [12]:
from unittest.mock import patch

@patch('%s.sort')
def test_sort(self, mock_sort):
    ...
    ...

# 高质量单元测试的关键

## Test Coverage

## 模块化

In [13]:
def work(arr):
    # pre process
    ...
    ...
    # sort
    l = len(arr)
    for i in range(0, l):
        for j in range(i+1, l):
            if arr[i] >= arr[j]:
                tmp = arr[i]
                arr[i] = arr[j]
                arr[j] = tmp
                
    # post process
    ...
    ...
    return arr

In [14]:
def preprocess(arr):
    ...
    ...
    return arr

def sort(arr):
    ...
    ...
    return arr

def postprocess(arr):
    ...
    return arr

def work(self):
    arr = preprocess(arr)
    arr = sort(arr)
    arr = postprocess(arr)
    return arr

In [15]:
from unittest.mock import patch

def test_preprocess(self):
    ...
    
def test_sort(self):
    ...
    
def test_postprocess(self):
    ...
    
@patch('%s.preprocess')
@patch('%s.sort')
@patch('%s.postprocess')
def test_work(self, mock_post_process, mock_sort, mock_preprocess):
    work()
    self.assertTrue(mock_post_process.called)
    self.assertTrue(mock_sort.called)
    self.assertTrue(mock_preprocess.called)
    
    
if __name__ == '__main__':
    unittest.main(argv=['fisrt-arg-is-ignored'], exit=False)

.F
FAIL: test_sort (__main__.TestSort)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-5-bf8afc76cdf1>", line 9, in test_sort
    self.assertEqual(arr, [1, 3, 4, 5, 6])
AssertionError: Lists differ: [3, 4, 1, 5, 6] != [1, 3, 4, 5, 6]

First differing element 0:
3
1

- [3, 4, 1, 5, 6]
?        ---

+ [1, 3, 4, 5, 6]
?  +++


----------------------------------------------------------------------
Ran 2 tests in 0.019s

FAILED (failures=1)
