<a href="https://colab.research.google.com/github/CT-LU/Notes-of-Clean-Code-in-Python/blob/main/Clean_Code_in_Python.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# unittest 

In [None]:
import unittest

def division(a, b):
    """Return the true-division quotient of ``a`` by ``b``.

    Raises:
        ZeroDivisionError: when numeric ``b`` is zero.
    """
    quotient = a / b
    return quotient

class MyTest(unittest.TestCase):
    """Introductory examples of the basic ``unittest`` assertion helpers."""

    def test_upper(self):
        """``str.upper`` upper-cases every character."""
        self.assertEqual('foo'.upper(), 'FOO')

    def test_isupper(self):
        """``str.isupper`` is true only when all cased characters are upper case."""
        self.assertTrue('FOO'.isupper())
        self.assertFalse('Foo'.isupper())

    def test_split(self):
        """``str.split`` splits on whitespace and rejects a non-string separator."""
        s = 'hello world'
        self.assertEqual(s.split(), ['hello', 'world'])
        # s.split(2) fails with "TypeError: must be str or None, not int"
        with self.assertRaises(TypeError):
            s.split(2)

    def test_raise(self):
        """Dividing by zero raises ``ZeroDivisionError`` whose message matches 'by zero'."""
        # Assert the specific exception type: asserting the catch-all
        # ``Exception`` would also pass for unrelated failures such as a
        # TypeError or NameError inside ``division``.
        self.assertRaises(ZeroDivisionError, division, 1, 0)
        # assertRaisesRegex additionally searches the message with a regex.
        self.assertRaisesRegex(ZeroDivisionError, "by zero", division, 1, 0)


if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

In [None]:
import unittest
import time
from unittest import TestCase
from unittest.mock import patch


def sum(a, b):
    """Return ``a + b`` after sleeping for 100 seconds.

    NOTE(review): this name shadows the builtin ``sum``; it is kept because
    the test in this cell patches it by the name ``'__main__.sum'``.
    """
    time.sleep(100)  # makes the function very slow, which is why the test patches it
    return a + b

class ClassName1:
    """Empty placeholder class; ``test_patch`` replaces it with a mock."""
class ClassName2:
    """Empty placeholder class; ``test_patch`` replaces it with a mock."""

class TestCalculator(TestCase):
    """Demonstrates replacing module-level names with ``unittest.mock.patch``."""

    @patch('__main__.sum', return_value=5)  # swap sum for a mock that always returns 5
    def test_sum(self, mock_sum):
        """The patched function returns 5 regardless of its arguments."""
        # ``mock_sum`` is the mock injected by @patch; naming it differently
        # from the patched function avoids shadowing the module-level name.
        self.assertEqual(mock_sum(2, 3), 5)
        self.assertEqual(mock_sum(333, 3), 5)

    @patch('__main__.ClassName2')
    @patch('__main__.ClassName1')
    def test_patch(self, MockClass1, MockClass2):
        """Two stacked patches inject two mocks, one per patched class."""
        # Decorators apply bottom-up, so the innermost patch (ClassName1)
        # supplies the first mock argument.
        MockClass1.return_value = 1
        MockClass2.return_value = 2
        print(ClassName1())
        print(ClassName2())
        # unittest assertions are used instead of bare ``assert`` so the checks
        # survive ``python -O`` and report useful failure messages.
        self.assertIs(MockClass1, ClassName1)
        self.assertIs(MockClass2, ClassName2)
        self.assertTrue(MockClass1.called)
        self.assertTrue(MockClass2.called)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

In [None]:
import unittest
import time
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import Mock 

class TestCalculator(TestCase):
    """Shows three ways of configuring a ``Mock``.

    The statements below run once, while the class body is executed; the
    class defines no test methods.
    """

    m = Mock(return_value=3)  # same as: m = Mock(); m.return_value = 3
    m.configure_mock(foo=42)  # same as: m.foo = 42
    m.bar = 'baz'  # same as: m.configure_mock(bar='baz')

    print(m())  # 3
    print(m.foo)  # 42
    print(m.bar)  # baz

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

In [None]:
from unittest.mock import Mock

myMethod = Mock()   # create a mock that stands in for a function
myMethod.return_value = 3 # every call to the mock returns 3
print(myMethod(1, 'a', foo='bar'))  # call it with two positional and one keyword argument; prints 3

myMethod.assert_called_with(1, 'a', foo='bar')  # passes silently because the most recent call used exactly these arguments; raises AssertionError otherwise
print(myMethod())  # call it once more, this time without arguments
print(myMethod.call_count)  # 2: the mock has now been called twice

In [None]:

# An exception instance assigned to side_effect is raised whenever the mock is called.
myMethod.side_effect = KeyError("Hi Hi Key Error")
#myMethod("abc") # KeyError: 'Hi Hi Key Error'

# side_effect can also be supplied through the Mock constructor.
new_mock = Mock(side_effect=KeyboardInterrupt("Error by ctrl + c"))
#new_mock() # KeyboardInterrupt: Error by ctrl + c

def for_side(*args, **kwargs):
    """Print the positional and keyword arguments this function received."""
    for label, received in (('args: ', args), ('kwargs: ', kwargs)):
        print(label, received)

# A callable side_effect is invoked with the same arguments the mock was called with.
myMethod.side_effect = for_side
# ``**`` unpacks the dict into keyword arguments.
myMethod('dsf', **{"a": 1, "bn": 2})
# args:  ('dsf',)
# kwargs:  {'a': 1, 'bn': 2}

# Without ``**`` the dict is passed as a second positional argument.
myMethod('dsf', {"a": 1, "bn": 2})
#args:  ('dsf', {'a': 1, 'bn': 2})
#kwargs:  {}

In [None]:
import unittest

class Person:
    """Small example class whose methods are replaced by mocks in the tests."""

    def __init__(self):
        # The double underscore triggers name mangling (stored as _Person__age).
        self.__age = 10

    def get_fullname(self, first_name, last_name):
        """Return ``first_name`` and ``last_name`` joined by a single space."""
        with_separator = first_name + ' '
        return with_separator + last_name

    def get_age(self):
        """Return the stored age."""
        age = self.__age
        return age

    @staticmethod
    def get_class_name():
        """Return the name of the ``Person`` class."""
        return Person.__name__

class PersonTest(unittest.TestCase):
    """Shows how ``Mock.side_effect`` can script the results of a method."""

    def setUp(self):
        self.p = Person()

    def test_should_get_age(self):
        """An iterable ``side_effect`` supplies one result per call, in order."""
        self.p.get_age = Mock(side_effect=[10, 11, 12])
        for expected_age in (10, 11, 12):
            self.assertEqual(self.p.get_age(), expected_age)

    def test_should_get_fullname(self):
        """A callable ``side_effect`` computes the result from the call arguments.

        Here a lambda takes the two name arguments and looks the full name up
        in a dict keyed by ``(first, last)``.
        """
        full_names = {
            ('James', 'Harden'): 'James Harden',
            ('Tracy', 'Grady'): 'Tracy Grady',
        }
        self.p.get_fullname = Mock(
            side_effect=lambda first, last: full_names[(first, last)]
        )
        for first, last, expected in (
            ('James', 'Harden', 'James Harden'),
            ('Tracy', 'Grady', 'Tracy Grady'),
        ):
            self.assertEqual(self.p.get_fullname(first, last), expected)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

# **Chapter 2 Pythonic Code**

## Indexes and slices 

In [None]:
"""Indexes and slices
Getting elements by an index or range
"""
import doctest


def index_last():
    """Doctest examples of negative indexes, which count from the end.

    >>> my_numbers = (4, 5, 3, 9)
    >>> my_numbers[-1]
    9
    >>> my_numbers[-3]
    5
    """


def get_slices():
    """Doctest examples of slicing with ``start:stop:step`` and ``slice`` objects.

    >>> my_numbers = (1, 1, 2, 3, 5, 8, 13, 21)
    >>> my_numbers[2:5]
    (2, 3, 5)
    >>> my_numbers[:3]
    (1, 1, 2)
    >>> my_numbers[3:]
    (3, 5, 8, 13, 21)
    >>> my_numbers[::]
    (1, 1, 2, 3, 5, 8, 13, 21)
    >>> my_numbers[1:7:2]
    (1, 3, 8)

    >>> interval = slice(1, 7, 2)
    >>> my_numbers[interval]
    (1, 3, 8)

    >>> interval = slice(None, 3)
    >>> my_numbers[interval] == my_numbers[:3]
    True
    """


def main():
    """Call both example functions, then run this module's doctests verbosely."""
    for example in (index_last, get_slices):
        example()
    # testmod returns (failed, attempted); the result is not used here.
    # To make failures fatal, raise SystemExit with the failure count.
    doctest.testmod(verbose=True)


if __name__ == "__main__":
    main()

## Creating your own sequences 

In [None]:
"""Clean Code in Python - Chapter 2: Pythonic Code
可以使用pythonic的統一存取方式(magic function), 
實作自己的class也應該要先思考，實作magic function來存取
"""


class Items:
    """Sequence-like wrapper that delegates ``len()`` and indexing to a list."""

    def __init__(self, *values):
        self._values = [*values]

    def __len__(self):
        size = len(self._values)
        return size

    def __getitem__(self, item):
        # Supports whatever the underlying list supports: ints, negatives, slices.
        return self._values[item]

def main():
    """Build an ``Items`` and print its last element using a negative index."""
    sample = Items(10, 1, 'hello item !!')
    last = sample[-1]
    print(last)

if __name__ == "__main__":
    main()

## Context Managers 
三種with obj: ... 的實作方式

In [None]:
import contextlib


run = print


def stop_database():
    """Pass the PostgreSQL stop command to ``run`` (bound to ``print`` in this cell)."""
    command = "systemctl stop postgresql.service"
    run(command)


def start_database():
    """Pass the PostgreSQL start command to ``run`` (bound to ``print`` in this cell)."""
    command = "systemctl start postgresql.service"
    run(command)

'''
method 1: 利用__magic__ ，__enter__會return給as，__exit__是結束with
'''
class DBHandler:
    """Context manager that stops the database on entry and starts it on exit."""

    def __enter__(self):
        stop_database()
        # The returned object is what ``with DBHandler() as x`` binds to ``x``.
        return self

    def __exit__(self, exc_type, ex_value, ex_traceback):
        start_database()
        # Returning None (falsy) lets an exception from the with-body propagate.
        return None


def db_backup():
    """Pass the ``pg_dump`` backup command to ``run``."""
    command = "pg_dump database"
    run(command)


'''
method 2: contextlib.contextmanager,裝飾yield產生器函式
'''
@contextlib.contextmanager
def db_handler():
    """Generator-based context manager: stop the database, then start it again.

    Everything before ``yield`` runs on entering the ``with`` block and
    everything after it runs on leaving. The ``yield`` is wrapped in
    ``try/finally`` so the database is started again even when the with-body
    raises; without it an exception would leave the database stopped.
    """
    stop_database()
    try:
        yield
    finally:
        start_database()


'''
method 3: 實作contextlib.ContextDecorator裝飾器class, 就可以不用with
但是，無法在環境管理器中就拿不到 as obj，原則上裝飾器不曉得發生什麼事
'''
class dbhandler_decorator(contextlib.ContextDecorator):
    """Context manager usable as a decorator: stop the database around a call.

    Because it derives from ``ContextDecorator``, an instance can decorate a
    function so the function body runs between ``__enter__`` and ``__exit__``.
    """

    def __enter__(self):
        stop_database()
        # Nothing is handed back, so ``with ... as x`` would bind None.
        return None

    def __exit__(self, ext_type, ex_value, ex_traceback):
        start_database()
        # Falsy result: exceptions raised by the wrapped code propagate.
        return None


@dbhandler_decorator()
def offline_backup():
    """Pass the ``pg_dump`` command to ``run`` inside ``dbhandler_decorator``."""
    command = "pg_dump database"
    run(command)


def main():
    """Run the backup through each context-manager variant, then demo ``suppress``."""
    # Variant 1: class implementing __enter__/__exit__.
    with DBHandler():
        db_backup()

    # Variant 2: generator-based context manager.
    with db_handler():
        db_backup()

    # Variant 3: the decorator wraps the call, so no ``with`` statement is needed.
    offline_backup()

    # Ignore the error when the file does not exist instead of failing.
    with contextlib.suppress(FileNotFoundError), open("1.txt") as f:
        for line in f:
            print(line)


if __name__ == "__main__":
    main()


## Properties 
約定俗成的private，想要存取它，要實作property

In [None]:
import re

# local-part, "@", domain containing at least one dot; "@" may appear only once
EMAIL_FORMAT = re.compile(r"[^@]+@[^@]+\.[^@]+")


def is_valid_email(potentially_valid_email: str) -> bool:
    """Return True when the whole string matches ``EMAIL_FORMAT``.

    ``fullmatch`` is used rather than ``match``: ``match`` only anchors the
    start of the string, so a value such as ``"a@b.co@evil"`` would be
    accepted because its prefix ``"a@b.co"`` matches.
    """
    return EMAIL_FORMAT.fullmatch(potentially_valid_email) is not None


class User:
    """User whose e-mail address is validated whenever it is assigned."""

    def __init__(self, username):
        self.username = username
        # Non-public by convention; read and written through the ``email`` property.
        self._email = None

    @property
    def email(self):
        """The stored e-mail address, or None when none has been set."""
        return self._email

    @email.setter
    def email(self, new_email):
        """Store ``new_email``; raise ValueError if ``is_valid_email`` rejects it."""
        if is_valid_email(new_email):
            self._email = new_email
            return
        raise ValueError(f"Can't set {new_email} as it's not a valid email")

def main():
    """Assign a valid e-mail through the property and print it back."""
    user = User('jsmith')
    user.email = 'jsmith@g.co'
    print(user.email)
    # Assigning an invalid address such as 'jsmith@' would raise ValueError.

if __name__ == "__main__":
    main()

## Iterable objects

In [None]:
'''
想要for in，有兩種可能，
1 __len__ and __getitem__
2 __iter__ (回傳的iterator需實作 __next__)

'''
from datetime import timedelta, date

'''
這種作法只能使用一次loop
'''
class DateRangeIterable:
    """An iterable that contains its own iterator object.

    Because ``__iter__`` returns the instance itself, the position is shared
    state: once one loop has consumed the range, later loops yield nothing.
    """

    def __init__(self, start_date, end_date):
        self.start_date = start_date
        self.end_date = end_date
        self._present_day = start_date

    def __iter__(self):
        # The object acts as its own iterator.
        return self

    def __next__(self):
        """Return the current day and advance by one; stop before ``end_date``."""
        current = self._present_day
        if current >= self.end_date:
            raise StopIteration
        self._present_day = current + timedelta(days=1)
        return current

'''
在__iter__中使用yeild生出generator, 可以重新loop
'''
class DateRangeContainerIterable:
    """A range that builds its iteration through a generator.

    Every call to ``__iter__`` creates a fresh generator, so the range can be
    looped over any number of times.
    """

    def __init__(self, start_date, end_date):
        self.start_date = start_date
        self.end_date = end_date

    def __iter__(self):
        one_day = timedelta(days=1)
        day = self.start_date
        while day < self.end_date:
            yield day  # ``yield`` takes the place of a hand-written __next__
            day = day + one_day

'''
就最直觀的create your own sequences
'''
class DateRangeSequence:
    """A range created by wrapping a sequence.

    All days are computed once in the constructor and kept in a list, which
    ``__getitem__`` and ``__len__`` delegate to.
    """

    def __init__(self, start_date, end_date):
        self.start_date = start_date
        self.end_date = end_date
        self._range = self._create_range()

    def _create_range(self):
        """Return the list of days from ``start_date`` up to, excluding, ``end_date``."""
        collected = []
        step = timedelta(days=1)
        cursor = self.start_date
        while cursor < self.end_date:
            collected.append(cursor)
            cursor = cursor + step
        return collected

    def __getitem__(self, day_no):
        return self._range[day_no]

    def __len__(self):
        return len(self._range)

def main():
    """Exercise the three date-range implementations and print their days."""
    start, end = date(2018, 1, 1), date(2018, 1, 5)

    for day in DateRangeIterable(start, end):
        print(day)

    # Drive the iterator by hand: the range holds exactly four days.
    r = DateRangeIterable(start, end)
    for _ in range(4):
        next(r)
    # A fifth next(r) would raise StopIteration.

    r1 = DateRangeIterable(start, end)
    print(" , ".join(map(str, r1)))
    # r1 is now exhausted, so max(r1) would fail on an empty iterator.

    # An object that is its own iterator cannot start over once it reaches
    # the end, so switch to the next implementation.
    r1 = DateRangeContainerIterable(start, end)
    print(" , ".join(map(str, r1)))
    print(max(r1))

    # Each ``for`` obtains a brand-new generator through __iter__.
    s1 = DateRangeSequence(start, end)
    for day in s1:
        print(day)
    # Iterators and generators trade time for space; a pre-built list gives
    # O(1) indexing at the cost of memory.

    
if __name__ == "__main__":
    main()
    

## Container objects 

In [None]:
'''
作一個mask地圖
'''
import numpy as np


class Boundaries:
    """Rectangular limits supporting ``(x, y) in boundaries`` checks."""

    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __contains__(self, coord):
        """Return True when 0 <= x < width and 0 <= y < height."""
        x, y = coord
        if not 0 <= x < self.width:
            return False
        return 0 <= y < self.height


class Grid:
    """Mask map backed by a NumPy array, with pythonic membership and item assignment."""

    def __init__(self, width, height):
        self.width, self.height = width, height
        # Array of zeros with shape (width, height), so map[x, y] addresses one cell.
        self.map = np.zeros((width, height))
        self.limits = Boundaries(width, height)

    def __contains__(self, coord):
        # Reads naturally because Boundaries implements __contains__ as well.
        inside = coord in self.limits
        return inside

    def __setitem__(self, coord, value):
        self.map[coord] = value

        
def mark_coordinate(grid, coord):
    """Set ``grid[coord]`` to 1 when ``coord`` lies inside ``grid``; otherwise do nothing.

    The membership test relies on the grid's ``__contains__``. The earlier
    version also repeated the bounds check by hand through ``grid.width`` and
    ``grid.height`` and assigned the cell a second time; that duplicate is
    removed.
    """
    if coord in grid:
        grid[coord] = 1

def main():
    """Create a 640x480 grid and mark one coordinate.

    ``element in container`` is evaluated as ``container.__contains__(element)``;
    implementing that magic method keeps the calling code consistent and readable.
    """
    width, height = 640, 480
    grid = Grid(width, height)
    mark_coordinate(grid, (1, 2))

    
if __name__ == "__main__":
    main()

## Dynamic attributes for objects 

In [None]:
class DynamicAttributes:
    """Resolves missing ``fallback_*`` attributes dynamically through ``__getattr__``."""

    def __init__(self, attribute):
        self.attribute = attribute

    def __getattr__(self, attr):
        """Return a placeholder for ``fallback_<name>``; raise AttributeError otherwise.

        Python calls this only after normal attribute lookup has failed.
        """
        prefix = "fallback_"
        if attr.startswith(prefix):
            # Strip only the leading prefix. ``str.replace`` would also remove
            # later occurrences, turning ``fallback_a_fallback_b`` into ``a_b``.
            name = attr[len(prefix):]
            return f"[fallback resolved] {name}"
        raise AttributeError(
            f"{self.__class__.__name__} has no attribute {attr}"
        )

def main():
    """Show regular lookup, the ``__getattr__`` fallback and ``getattr`` defaults."""
    dyn = DynamicAttributes("value")
    print(dyn.attribute)  # 'value'
    print(dyn.fallback_test)  # '[fallback resolved] test'

    # Same effect as running: dyn.fallback_new = "new value"
    vars(dyn)["fallback_new"] = "new value"
    print(dyn.fallback_new)  # 'new value'

    # getattr(object, name[, default]) is equivalent to object.name, except
    # that the default is returned when the lookup raises AttributeError.
    print(getattr(dyn, "something", "default"))  # 'default'
    
if __name__ == "__main__":
    main()

## Callable objects 

In [None]:
from collections import defaultdict, namedtuple


class CallCount:
    """Callable object counting how many times it was called with each argument."""

    def __init__(self):
        # defaultdict(int) gives every unseen key a starting count of 0.
        self._counts = defaultdict(int)

    def __call__(self, argument):
        """Increment and return the number of calls made with ``argument``."""
        updated = self._counts[argument] + 1
        self._counts[argument] = updated
        return updated


def main():
    """Call a ``CallCount`` with repeated arguments and print each running count."""
    cc = CallCount()
    # Expected output, one per line: 1, 1, 2, 3, 1, 2
    for argument in (1, 2, 1, 1, "something", "something"):
        print(cc(argument))

    print(callable(cc))  # True

    
    
if __name__ == "__main__":
    main()

## Caveats in Python

In [None]:
import unittest
from datetime import datetime


class LoginEventSerializer:
    """Turns a login event into a dict, hiding the password."""

    def __init__(self, event):
        self.event = event

    def serialize(self) -> dict:
        """Return username, redacted password, ip and minute-precision timestamp."""
        event = self.event
        serialized = {"username": event.username}
        serialized["password"] = "**redacted**"
        serialized["ip"] = event.ip
        serialized["timestamp"] = event.timestamp.strftime("%Y-%m-%d %H:%M")
        return serialized


class LoginEvent:
    """Login event that delegates serialization to its ``SERIALIZER`` class."""

    SERIALIZER = LoginEventSerializer

    def __init__(self, username, password, ip, timestamp):
        self.username, self.password = username, password
        self.ip, self.timestamp = ip, timestamp

    def serialize(self) -> dict:
        """Return the dict produced by ``SERIALIZER`` for this event."""
        serializer = self.SERIALIZER(self)
        return serializer.serialize()


class TestLoginEventSerialized(unittest.TestCase):
    """Checks the dict produced by ``LoginEvent.serialize``."""

    def test_serializetion(self):
        moment = datetime(2016, 7, 20, 15, 45)
        event = LoginEvent("username", "password", "127.0.0.1", moment)
        self.assertEqual(
            event.serialize(),
            {
                "username": "username",
                "password": "**redacted**",
                "ip": "127.0.0.1",
                "timestamp": "2016-07-20 15:45",
            },
        )


if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

In [None]:
from collections import UserList


def wrong_user_display(user_metadata: dict = {"name": "John", "age": 30}):
    """Deliberately flawed example of a mutable default argument.

    The default dict is created once, when the function is defined, and the
    ``pop`` calls empty it, so a second call that relies on the default
    raises KeyError.
    """
    name, age = user_metadata.pop("name"), user_metadata.pop("age")
    return "{} ({})".format(name, age)


def user_display(user_metadata: dict = None):
    """Return ``"<name> (<age>)"`` for the given metadata.

    Args:
        user_metadata: mapping with ``"name"`` and ``"age"`` keys. When it is
            None (or empty), a fresh default for John, aged 30, is used.

    Raises:
        KeyError: if a non-empty mapping lacks ``"name"`` or ``"age"``.
    """
    # A new default dict is built on every call, so nothing is shared between calls.
    user_metadata = user_metadata or {"name": "John", "age": 30}

    # Read with indexing rather than ``pop`` so the caller's dict is not mutated.
    name = user_metadata["name"]
    age = user_metadata["age"]

    return f"{name} ({age})"


class BadList(list):
    """Example of the pitfall of subclassing the builtin ``list`` directly.

    ``list`` is implemented in C in CPython and its own methods (iteration,
    for example) do not go through this overridden ``__getitem__``.
    """

    def __getitem__(self, index):
        """Return the item prefixed with "[even]" or "[odd]" according to ``index``."""
        value = super().__getitem__(index)
        prefix = "even" if index % 2 == 0 else "odd"
        return f"[{prefix}] {value}"


class GoodList(UserList):
    """Customised list built on ``UserList``, so iteration honours the override."""

    def __getitem__(self, index):
        """Return the item prefixed with "[even]" or "[odd]" according to ``index``."""
        value = super().__getitem__(index)
        prefix = "even" if index % 2 == 0 else "odd"
        return f"[{prefix}] {value}"

def main():
    """Demonstrate the mutable-default pitfall and the builtin-subclassing pitfall."""
    # --- Mutable default arguments ---
    print(wrong_user_display())
    print(wrong_user_display({"name": "Jane", "age": 25}))
    # A further wrong_user_display() call would raise KeyError here.
    print(user_display())
    print(user_display({"name": "Jane", "age": 25}))
    print(user_display())  # still works

    # --- Extending built-in types ---
    b1 = BadList((0, 1, 2, 3, 4, 5))
    for position in (0, 1):
        print(b1[position])
    # "".join(b1) would raise TypeError: join iterates the list expecting
    # strings, but iterating a list subclass bypasses the overridden
    # __getitem__ and yields the original ints.

    g1 = GoodList((0, 1, 2))
    for position in (0, 1):
        print(g1[position])
    print(";".join(g1))

if __name__ == "__main__":
    main()

# Chapter 3  General Traits of Good Code

## Handle exceptions at the right level of abstraction
* 例外不要當成處理商業邏輯的go to
* 函式只應做一件事情，這條原則也包括例外

In [None]:
import logging
import unittest
from unittest.mock import Mock, patch
import time

logger = logging.getLogger(__name__)


class Connector:
    """Abstract the connection to a database."""

    def connect(self):
        """Connect to a data source; this stub returns the connector itself."""
        connection = self
        return connection

    @staticmethod
    def send(data):
        """Return ``data`` unchanged."""
        return data


class Event:
    """Event wrapping a payload that can be decoded to text."""

    def __init__(self, payload):
        self._payload = payload

    def decode(self):
        """Return the payload rendered after the word "decoded"."""
        return "decoded {}".format(self._payload)

#主要是deliver_event, 看它
class DataTransport:
    """An example of an object badly handling exceptions of different levels."""

    retry_threshold: int = 5  # seconds slept after each failed connection attempt
    retry_n_times: int = 3  # maximum number of connection attempts

    def __init__(self, connector):
        self._connector = connector
        self.connection = None

    def deliver_event(self, event):
        """Connect, decode ``event`` and send the decoded data.

        Two unrelated kinds of exception are handled here, which is the flaw
        this example illustrates.
        """
        try:
            self.connect()
            self.send(event.decode())  # sends the text returned by event.decode()
        except ConnectionError as error:
            # Reasonable at this level: the error originates from connect().
            logger.info("connection error detected: %s", error)
            raise
        except ValueError as error:
            # Belongs to decoding; it should not be dealt with in this method.
            logger.error("%r contains incorrect data: %s", event, error)
            raise

    def connect(self):
        """Try to connect up to ``retry_n_times`` times, sleeping after each failure.

        Returns:
            The connection obtained from the connector.

        Raises:
            ConnectionError: when every attempt has failed.
        """
        for _attempt in range(self.retry_n_times):
            try:
                # The composed connector object performs the actual connect.
                self.connection = self._connector.connect()
            except ConnectionError as error:
                logger.info(
                    "%s: attempting new connection in %is",
                    error,
                    self.retry_threshold,
                )
                time.sleep(self.retry_threshold)  # wait before the next attempt
                continue
            return self.connection
        raise ConnectionError(
            f"Couldn't connect after {self.retry_n_times} times"
        )

    def send(self, data):
        """Send ``data`` through the current connection and return its result."""
        return self.connection.send(data)


class FailsAfterNTimes:
    """Fake connector whose first ``n_times`` calls to ``connect`` raise.

    The number of failures and the exception to raise are both given to the
    constructor.
    """

    def __init__(self, n_times: int, with_exception) -> None:
        self._remaining_failures = n_times
        self._exception = with_exception

    def connect(self):
        """Raise the configured exception while failures remain, then return self."""
        self._remaining_failures = self._remaining_failures - 1
        still_failing = self._remaining_failures >= 0
        if still_failing:
            raise self._exception
        return self

    def send(self, data):
        """Return ``data`` unchanged."""
        return data


@patch("time.sleep", return_value=0)  # class-level patch: each test receives the mock as ``sleep``
class TestTransport(unittest.TestCase):
    """Tests for ``DataTransport`` with ``time.sleep`` patched out.

    unittest assertion methods are used instead of bare ``assert`` so the
    checks are not stripped under ``python -O`` and report both values on
    failure.
    """

    def test_connects_after_retries(self, sleep):
        """Two failed attempts followed by a success: one send, two sleeps."""
        data_transport = DataTransport(
            FailsAfterNTimes(2, with_exception=ConnectionError)
        )
        data_transport.send = Mock()  # replace send with a mock
        data_transport.deliver_event(Event("test"))

        # send was called exactly once, with the decoded payload
        data_transport.send.assert_called_once_with("decoded test")
        self.assertEqual(sleep.call_count, DataTransport.retry_n_times - 1)

    def test_connects_directly(self, sleep):
        """A connector that works first time is called once and never sleeps."""
        connector = Mock()
        data_transport = DataTransport(connector)
        data_transport.send = Mock()
        data_transport.deliver_event(Event("test"))

        connector.connect.assert_called_once()  # verifies the composed connector
        self.assertEqual(sleep.call_count, 0)

    def test_connection_error(self, sleep):
        """A connector that always fails ends in ConnectionError after all retries."""
        # The connector is a Mock whose connect raises through side_effect.
        data_transport = DataTransport(
            Mock(connect=Mock(side_effect=ConnectionError))
        )

        self.assertRaisesRegex(
            ConnectionError,
            # Raw string: "\d" in a plain literal is an invalid escape sequence.
            r"Couldn't connect after \d+ times",
            data_transport.deliver_event,
            Event("connection error"),
        )
        self.assertEqual(sleep.call_count, DataTransport.retry_n_times)

    def test_error_in_event(self, sleep):
        """A ValueError raised while decoding propagates and never sleeps."""
        data_transport = DataTransport(Mock())
        event = Mock(decode=Mock(side_effect=ValueError))  # decode raises ValueError
        with patch("__main__.logger.error"):  # silence the error log for this test
            self.assertRaises(ValueError, data_transport.deliver_event, event)

        self.assertFalse(sleep.called)


if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

# Chapter 5 Using Decorators to Improve Our Code 

## Decorate functions 

In [None]:
"""
Creating a decorator to be applied over a function.
"""

from functools import wraps
from unittest import TestCase, main, mock

import logging


class ControlledException(Exception):
    """Generic domain-level exception; the ``retry`` decorator retries on it."""


def retry(operation):
    """Decorator that runs ``operation`` up to three times on ``ControlledException``.

    The first successful result is returned; if every attempt fails, the last
    exception is re-raised.
    """

    @wraps(operation)  # copies the metadata of ``operation`` onto the wrapper
    def wrapped(*args, **kwargs):
        RETRIES_LIMIT = 3
        last_raised = None
        for _attempt in range(RETRIES_LIMIT):
            try:
                outcome = operation(*args, **kwargs)
            except ControlledException as error:
                logging.info("retrying %s", operation.__qualname__)
                last_raised = error
            else:
                return outcome
        raise last_raised

    return wrapped


class OperationObject:
    """A helper object to test the decorator; it counts how often ``run`` is called."""

    def __init__(self):
        self._times_called: int = 0

    def run(self) -> int:
        """Base operation for a particular action; returns the updated call count."""
        calls = self._times_called + 1
        self._times_called = calls
        return calls

    def __str__(self):
        return "{}()".format(self.__class__.__name__)

    __repr__ = __str__


class RunWithFailure:
    """Wraps a task so that its first ``fail_n_times`` runs end in an exception."""

    def __init__(
        self,
        task: "OperationObject",
        fail_n_times: int = 0,
        exception_cls=ControlledException,
    ):
        self._task = task
        self._fail_n_times = fail_n_times
        self._times_failed = 0
        self._exception_cls = exception_cls

    def run(self):
        """Run the task, then raise while failures remain; otherwise return its result."""
        called = self._task.run()  # the task runs even when this call will fail
        if self._times_failed >= self._fail_n_times:
            return called
        self._times_failed += 1
        raise self._exception_cls(f"{self._task!s} failed!")


@retry  # syntactic sugar for: run_operation = retry(run_operation)
def run_operation(task):
    """Run a particular task, simulating some failures on its execution."""
    outcome = task.run()
    return outcome


class RetryDecoratorTest(TestCase):
    """Tests for the ``retry`` decorator with ``logging.info`` patched out."""

    def setUp(self):
        # Keep the patcher, not just the mock that start() returns: stop()
        # must be called on the patcher. Calling stop() on the mock is a
        # no-op, which would leave logging.info patched after the test.
        patcher = mock.patch("logging.info")
        self.info = patcher.start()
        # addCleanup undoes the patch even when the test fails.
        self.addCleanup(patcher.stop)

    def test_fail_less_than_retry_limit(self):
        """Retry = 3, fail = 2, should work"""
        task = OperationObject()
        failing_task = RunWithFailure(task, fail_n_times=2)
        times_run = run_operation(failing_task)
        self.assertEqual(times_run, 3)
        self.assertEqual(task._times_called, 3)

    def test_fail_equal_retry_limit(self):
        """Retry = fail = 3, will fail"""
        task = OperationObject()
        failing_task = RunWithFailure(task, fail_n_times=3)
        with self.assertRaises(ControlledException):
            run_operation(failing_task)

    def test_no_failures(self):
        """Without failures the task runs exactly once."""
        task = OperationObject()
        failing_task = RunWithFailure(task, fail_n_times=0)
        times_run = run_operation(failing_task)

        self.assertEqual(times_run, 1)
        self.assertEqual(task._times_called, 1)


if __name__ == "__main__":
    main(argv=['first-arg-is-ignored'], exit=False)

## Decorate classes
當要開始擴展下面的例子時，就會有三個缺點
* 過多class，事件數量增加，就要不同的serialize來對應
* 無法reuse，譬如說，有一個新的event也需要hide password
* boilerplate(指在許多地方重複出現、只改少量code的冗餘程式碼)，serialize呼叫會在不同類出現

In [None]:
import unittest
from datetime import datetime


class LoginEventSerializer:
    """Turns a login event into a dict, hiding the password."""

    def __init__(self, event):
        self.event = event

    def serialize(self) -> dict:
        """Return username, redacted password, ip and minute-precision timestamp."""
        event = self.event
        serialized = {"username": event.username}
        serialized["password"] = "**redacted**"
        serialized["ip"] = event.ip
        serialized["timestamp"] = event.timestamp.strftime("%Y-%m-%d %H:%M")
        return serialized


class LoginEvent:
    """Login event that delegates serialization to its ``SERIALIZER`` class."""

    SERIALIZER = LoginEventSerializer

    def __init__(self, username, password, ip, timestamp):
        self.username, self.password = username, password
        self.ip, self.timestamp = ip, timestamp

    def serialize(self) -> dict:
        """Return the dict produced by ``SERIALIZER`` for this event."""
        serializer = self.SERIALIZER(self)
        return serializer.serialize()


class TestLoginEventSerialized(unittest.TestCase):
    """Checks the dict produced by ``LoginEvent.serialize``."""

    def test_serializetion(self):
        moment = datetime(2016, 7, 20, 15, 45)
        event = LoginEvent("username", "password", "127.0.0.1", moment)
        self.assertEqual(
            event.serialize(),
            {
                "username": "username",
                "password": "**redacted**",
                "ip": "127.0.0.1",
                "timestamp": "2016-07-20 15:45",
            },
        )


if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

讓我們用類別裝飾器來改善上面的例子

In [None]:
'''
Reimplement the serialization of the events by applying a class decorator.
'''
import unittest
from datetime import datetime

'''
reuse呼叫func抽出來, 給EventSerializer專門轉Event的 attributes
'''
def hide_field(field) -> str:
    """Return the fixed redaction marker, ignoring the value of ``field``."""
    redacted = "**redacted**"
    return redacted


def format_time(field_timestamp: datetime) -> str:
    """Format ``field_timestamp`` as ``YYYY-MM-DD HH:MM``."""
    pattern = "%Y-%m-%d %H:%M"
    return field_timestamp.strftime(pattern)


def show_original(event_field):
    """Identity transformation: return ``event_field`` unchanged."""
    unchanged = event_field
    return unchanged


class EventSerializer:
    """Serialize an event by applying one transformation function per field."""

    def __init__(self, serialization_fields: dict) -> None:
        """Store the mapping of field names to transformation functions.

        Example::

        >>> serialization_fields = {
        ...    "username": str.upper,
        ...    "name": str.title,
        ... }

        With an event such as::

        >>> from types import SimpleNamespace
        >>> event = SimpleNamespace(username="usr", name="name")
        >>> result = EventSerializer(serialization_fields).serialize(event)

        the serialized result satisfies::

        >>> result == {
        ...     "username": event.username.upper(),
        ...     "name": event.name.title(),
        ... }
        True

        """
        self.serialization_fields = serialization_fields

    def serialize(self, event) -> dict:
        """Return a dict mapping each configured field to its transformed value.

        Each value is read from the attribute of ``event`` with the same name
        as the field and passed through that field's transformation.
        """
        serialized = {}
        for field, transformation in self.serialization_fields.items():
            serialized[field] = transformation(getattr(event, field))
        return serialized


class Serialization:
    """Class decorator that adds a ``serialize`` method driven by per-field
    transformation functions.
    """

    def __init__(self, **transformations):
        """Build the serializer from ``field_name=transformation`` keyword arguments.

        The mapping defines how each attribute of an instance is converted at
        serialization time.
        """
        self.serializer = EventSerializer(transformations)

    def __call__(self, event_class):
        """Attach a ``serialize`` method to ``event_class`` and return the class.

        Called when the decorator is applied. Any existing ``serialize``
        attribute is replaced by a version that uses this serializer.
        """

        def serialize_method(event_instance):
            # ``event_instance`` is an instance of the decorated class.
            return self.serializer.serialize(event_instance)

        setattr(event_class, "serialize", serialize_method)
        return event_class

'''
Serialization裝飾器帶dict參數, 裝飾event_class
'''
@Serialization(
    username=str.lower,
    password=hide_field,
    ip=show_original,
    timestamp=format_time,
)
class LoginEvent:
    """Login event; its ``serialize`` method is added by the ``Serialization`` decorator."""

    def __init__(self, username, password, ip, timestamp):
        self.username, self.password = username, password
        self.ip, self.timestamp = ip, timestamp


class TestLoginEventSerialized(unittest.TestCase):
    """Checks the ``serialize`` method that the class decorator adds to ``LoginEvent``."""

    def test_serialization(self):
        moment = datetime(2016, 7, 20, 15, 45)
        event = LoginEvent("UserName", "password", "127.0.0.1", moment)
        # ``serialize`` exists only because LoginEvent was decorated.
        self.assertEqual(
            event.serialize(),
            {
                "username": "username",
                "password": "**redacted**",
                "ip": "127.0.0.1",
                "timestamp": "2016-07-20 15:45",
            },
        )

if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

用dataclass省去__init__ boilerplate

In [None]:
"""
Class decorators.

Reimplement the serialization of the events by applying a class decorator.
Use the @dataclass decorator.

This code only works in Python 3.7+
"""
import sys
import unittest
from datetime import datetime

#from decorator_class_2 import (
#    Serialization,
#    format_time,
#    hide_field,
#    show_original,
#)

try:
    from dataclasses import dataclass
except ImportError:
    # The dataclasses module is unavailable (this cell's docstring says it
    # needs Python 3.7+): define a stand-in so the decorator line still works.

    def dataclass(cls):
        """Fallback no-op decorator: return ``cls`` unchanged."""
        return cls


# Decorators apply bottom-up: @dataclass runs first, then @Serialization
# attaches the ``serialize`` method to the resulting class.
@Serialization(
    username=show_original,
    password=hide_field,
    ip=show_original,
    timestamp=format_time,
)
@dataclass # generates __init__ from the annotated fields, removing the boilerplate
class LoginEvent:
    """Login event declared as a dataclass with four annotated fields."""
    username: str
    password: str
    ip: str
    timestamp: datetime


class TestLoginEventSerialized(unittest.TestCase):
    """Checks serialization of the dataclass-based ``LoginEvent``."""

    @unittest.skipIf(
        sys.version_info < (3, 7), reason="Requires Python 3.7+ to run"
    )
    def test_serializetion(self):
        moment = datetime(2016, 7, 20, 15, 45)
        event = LoginEvent("username", "password", "127.0.0.1", moment)
        self.assertEqual(
            event.serialize(),
            {
                "username": "username",
                "password": "**redacted**",
                "ip": "127.0.0.1",
                "timestamp": "2016-07-20 15:45",
            },
        )

if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

## Passing arguments to decorators
 

In [None]:
from functools import wraps

RETRIES_LIMIT = 3


# A decorator function that takes arguments, so the retry count is configurable.
def with_retry(retries_limit=RETRIES_LIMIT, allowed_exceptions=None):
    """Return a decorator that retries the wrapped callable on selected exceptions.

    Args:
        retries_limit: maximum number of attempts; must be at least 1.
        allowed_exceptions: tuple of exception types that trigger a retry.
            Defaults to ``(ControlledException,)``.

    Returns:
        A decorator. The decorated callable returns the first successful
        result and re-raises the last exception when every attempt fails.

    Raises:
        ValueError: if ``retries_limit`` is smaller than 1. Without this check
            the loop body would never run and the wrapper would end in
            ``raise None``, a confusing TypeError.
    """
    if retries_limit < 1:
        raise ValueError(
            f"retries_limit must be at least 1, got {retries_limit!r}"
        )
    allowed_exceptions = allowed_exceptions or (ControlledException,)

    def retry(operation):
        @wraps(operation)
        def wrapped(*args, **kwargs):
            last_raised = None
            for _ in range(retries_limit):
                try:
                    return operation(*args, **kwargs)
                except allowed_exceptions as e:
                    logging.warning(
                        "retrying %s due to %s", operation.__qualname__, e
                    )
                    last_raised = e
            raise last_raised

        return wrapped

    return retry


@with_retry()
def run_operation(task):
    """Run ``task`` with the default retry settings."""
    outcome = task.run()
    return outcome


@with_retry(retries_limit=5)
def run_with_custom_retries_limit(task):
    """Run ``task`` with up to five attempts."""
    outcome = task.run()
    return outcome


@with_retry(allowed_exceptions=(AttributeError,))
def run_with_custom_exceptions(task):
    """Run ``task``, retrying only when AttributeError is raised."""
    outcome = task.run()
    return outcome


@with_retry(
    retries_limit=4, allowed_exceptions=(ZeroDivisionError, AttributeError)
)
def run_with_custom_parameters(task):
    """Run ``task`` with up to four attempts on ZeroDivisionError or AttributeError."""
    outcome = task.run()
    return outcome


In [None]:
RETRIES_LIMIT = 3


# The same parameterised retry, implemented as a class-based decorator.
# A function-based version needs more nesting levels; with a class,
# __init__ receives the decorator arguments and __call__ wraps the function.
class WithRetry:
    """Class-based decorator that retries the wrapped callable on selected exceptions."""

    def __init__(self, retries_limit=RETRIES_LIMIT, allowed_exceptions=None):
        """Store the decorator parameters.

        Args:
            retries_limit: maximum number of attempts; must be at least 1.
            allowed_exceptions: tuple of exception types that trigger a retry.
                Defaults to ``(ControlledException,)``.

        Raises:
            ValueError: if ``retries_limit`` is smaller than 1. Without this
                check the wrapper would end in ``raise None``, a confusing
                TypeError.
        """
        if retries_limit < 1:
            raise ValueError(
                f"retries_limit must be at least 1, got {retries_limit!r}"
            )
        self.retries_limit = retries_limit
        self.allowed_exceptions = allowed_exceptions or (ControlledException,)

    def __call__(self, operation):
        """Wrap ``operation`` so that it is retried; return the wrapper."""

        @wraps(operation)
        def wrapped(*args, **kwargs):
            last_raised = None

            for _ in range(self.retries_limit):
                try:
                    return operation(*args, **kwargs)
                except self.allowed_exceptions as e:
                    # NOTE(review): ``logger`` is not defined in this cell; it
                    # presumably comes from an earlier cell. Confirm.
                    logger.info(
                        "retrying %s due to %s", operation.__qualname__, e
                    )
                    last_raised = e
            raise last_raised

        return wrapped


@WithRetry()
def run_operation(task):
    """Return ``task.run()``, retried with the decorator's default settings."""
    return task.run()


@WithRetry(retries_limit=5)
def run_with_custom_retries_limit(task):
    """Return ``task.run()``, allowing up to five attempts."""
    return task.run()


@WithRetry(allowed_exceptions=(AttributeError,))
def run_with_custom_exceptions(task):
    """Return ``task.run()``, retrying only when ``AttributeError`` is raised."""
    return task.run()


@WithRetry(
    retries_limit=4, allowed_exceptions=(ZeroDivisionError, AttributeError)
)
def run_with_custom_parameters(task):
    """Return ``task.run()``; up to four attempts on the two listed errors."""
    return task.run()

## Good uses for decorators
* Transforming parameters: 對傳入的參數做一些前處理 
* Tracing code: log執行過的函式足跡
* Validate parameters
* Implement retry operations
* Simplify classes by moving some (repetitive) logic into decorators: 抽象出不變的部份給decorator

## Effective decorators – avoiding common mistakes
 

In [None]:
def trace_decorator(function):
    """Wrap ``function`` so that each call is logged before being forwarded.

    ``functools.wraps`` is not applied here, so the returned callable keeps
    the inner function's own name and (missing) docstring instead of those
    of ``function``.
    """

    def wrapped(*args, **kwargs):
        target_name = function.__qualname__
        logging.info("running %s", target_name)
        outcome = function(*args, **kwargs)
        return outcome

    return wrapped


@trace_decorator
def process_account(account_id):
    """Process an account by Id."""
    logging.info("processing account %s", account_id)

help(process_account)  # asking about process_account actually describes ``wrapped``
process_account.__qualname__ 
# __qualname__ is the function's qualified name.
# Both outputs refer to ``wrapped`` rather than process_account, which makes
# tracing which function actually ran very awkward.

# The fix is simple: apply functools.wraps to the inner function.
def trace_decorator(function):
    """Wrap ``function`` so that each call is logged before being forwarded.

    ``functools.wraps`` copies the name, docstring and other metadata of
    ``function`` onto the returned wrapper.
    """

    @wraps(function)
    def wrapped(*args, **kwargs):
        target_name = function.__qualname__
        logging.info("running %s", target_name)
        outcome = function(*args, **kwargs)
        return outcome

    return wrapped


@trace_decorator
def process_account(account_id):
    """Process an account by Id."""
    logging.info("processing account %s", account_id)

# With functools.wraps applied in the decorator, both lines below report
# process_account's own docstring and qualified name.
help(process_account)
process_account.__qualname__

In [None]:
import time
from functools import wraps

# Demonstrates a misuse when measuring elapsed time: the clock is read when
# the decorator is applied (for example at import time), not when the
# decorated function is called.
def traced_function_wrong(function):
    """An example of a badly defined decorator.

    The start time is captured once, while decorating, so every later call
    reports the time elapsed since decoration rather than its own duration.
    """
    decorated_at = time.time()

    @wraps(function)
    def wrapped(*args, **kwargs):
        print("started execution of %s" % function)
        outcome = function(*args, **kwargs)
        elapsed = time.time() - decorated_at
        print("function %s took %.2fs"% (function, elapsed))
        return outcome

    return wrapped


@traced_function_wrong
def process_with_delay(callback, delay=0):
    """Print a ``sleep(<delay>)`` message and return ``callback`` unchanged.

    The body neither sleeps nor calls ``callback``; it only prints.
    """
    print("sleep(%d)"% delay)
    return callback

# Correct version: every clock read happens inside the wrapped function, so
# each call measures only its own duration.
def traced_function(function):
    """Wrap ``function`` to print when it starts and how long the call took.

    Args:
        function: The callable to trace.

    Returns:
        A wrapper with the metadata of ``function`` that forwards all
        arguments and returns the wrapped call's result.
    """

    @wraps(function)
    def wrapped(*args, **kwargs):
        print("started execution of %s" % function)
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) if the system clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = function(*args, **kwargs)
        print(
            "function %s took %.2fs"
            % (function, time.perf_counter() - start_time)
        )
        return result

    return wrapped


@traced_function
def call_with_delay(callback, delay=0):
    """Print a ``sleep(<delay>)`` message and return ``callback`` unchanged.

    The body neither sleeps nor calls ``callback``; it only prints.
    """
    print("sleep(%d)" % delay)
    return callback

def a_callback():
    """Placeholder callback passed to the delay demos; returns ``None``."""
    # NOTE(review): this f-string is evaluated and discarded, so the function
    # has no effect. It may have been meant as ``return 'cb'`` — confirm.
    f'cb'

def main():
    """Call both decorated demo functions, two seconds after starting."""
    wrong_timer, right_timer = process_with_delay, call_with_delay
    # Pause first so that a duration measured from decoration time visibly
    # differs from one measured from the call itself.
    time.sleep(2)
    wrong_timer(a_callback)
    print("------------------------------")
    right_timer(a_callback)

# Run the demo only when this code is executed as the main module.
if __name__ == "__main__":
    main()

如果想利用decorator的side effect，例如在import時自動把class註冊到registry，可以這樣做：


In [None]:
# Example session:
#
#     >>> from decorator_side_effects_2 import EVENTS_REGISTRY
#     >>> EVENTS_REGISTRY
#     {'UserLoginEvent': decorator_side_effects_2.UserLoginEvent,
#      'UserLogoutEvent': decorator_side_effects_2.UserLogoutEvent}
#
# Once the import has finished, the registry shows which events exist.
EVENTS_REGISTRY = {}


def register_event(event_cls):
    """Record ``event_cls`` in ``EVENTS_REGISTRY`` under its class name.

    Meant to be used as a class decorator: the class is returned unchanged,
    and the registration is the side effect.
    """
    EVENTS_REGISTRY.update({event_cls.__name__: event_cls})
    return event_cls


class Event:
    """A base event object"""
    # NOTE(review): no class visible in this cell inherits from Event;
    # UserEvent is declared without a base class — confirm this is intended.


class UserEvent:
    """Parent class of the user-related event classes."""

    # Class attribute inherited by every subclass.
    TYPE = "user"


# Decorating the class stores it in EVENTS_REGISTRY as soon as it is defined.
@register_event
class UserLoginEvent(UserEvent):
    """Represents the event of a user when it has just accessed the system."""


# Decorating the class stores it in EVENTS_REGISTRY as soon as it is defined.
@register_event
class UserLogoutEvent(UserEvent):
    """Event triggered right after a user abandoned the system."""


def test():
    """Doctest holder: the registry keys are exactly the two decorated classes.

    >>> sorted(EVENTS_REGISTRY.keys()) == sorted(('UserLoginEvent', 'UserLogoutEvent'))
    True
    """

## Creating decorators that will always work
先舉一個例子是要query db，給一個字串

In [None]:
from functools import wraps


class DBDriver:
    """Toy database driver that only remembers its connection string."""

    def __init__(self, dbstring):
        """Keep ``dbstring`` for use in later ``execute`` calls."""
        self.dbstring = dbstring

    def execute(self, query):
        """Return a text describing ``query`` being run at this driver's database."""
        description = f"query {query} at {self.dbstring}"
        return description


def inject_db_driver(function):
    """Decorate ``function`` so that callers pass a database string instead
    of a driver: the wrapper builds a ``DBDriver`` from that string and hands
    it to ``function`` as its only argument.
    """

    @wraps(function)
    def wrapped(dbstring):
        driver = DBDriver(dbstring)
        return function(driver)

    return wrapped


@inject_db_driver
def run_query(driver):
    """Execute a fixed query through the injected ``driver``."""
    return driver.execute("test_function")


class DataHandler:
    """The decorator will not work for methods as it is defined."""
    # inject_db_driver cannot be reused here: its wrapper accepts a single
    # argument, but a method also receives ``self``.
    @inject_db_driver
    def run_query(self, driver):
        return driver.execute(self.__class__.__name__)

def main():
    """Run the function-based example; the method call is left disabled."""
    print(run_query('test okay'))
    #print(DataHandler().run_query("test fails")) #TypeError

# Run the demo only when this code is executed as the main module.
if __name__ == "__main__":
    main()

In [None]:
from functools import wraps
from types import MethodType


class DBDriver:
    """Toy database driver that only remembers its connection string."""

    def __init__(self, dbstring):
        """Keep ``dbstring`` for use in later ``execute`` calls."""
        self.dbstring = dbstring

    def execute(self, query):
        """Return a text describing ``query`` being run at this driver's database."""
        description = f"query {query} at {self.dbstring}"
        return description


class inject_db_driver:
    """Decorator object that turns a database string into a ``DBDriver`` and
    passes that driver to the wrapped callable.

    Because it also implements ``__get__``, it can decorate methods as well
    as plain functions.
    """

    def __init__(self, function):
        self.function = function
        # Copy the wrapped callable's metadata onto this instance.
        wraps(self.function)(self)

    def __call__(self, dbstring):
        driver = DBDriver(dbstring)
        return self.function(driver)

    def __get__(self, instance, owner):
        # Descriptor protocol (covered in chapter 06). Accessed through an
        # instance: bind the wrapped function to it and re-wrap the bound
        # method, so ``self`` is supplied automatically on the call.
        if instance is not None:
            bound_method = MethodType(self.function, instance)
            return self.__class__(bound_method)
        # Accessed through the class itself: hand back the decorator object.
        return self


@inject_db_driver
def run_query(driver):
    """Execute a fixed query through the injected ``driver``."""
    return driver.execute("test_function_2")


class DataHandler:
    """Holds a method decorated with the descriptor-aware ``inject_db_driver``."""

    @inject_db_driver
    def run_query(self, driver):
        """Execute a fixed query through the injected ``driver``."""
        return driver.execute("test_method_2")

def main():
    """Exercise the decorator on both a plain function and a bound method."""
    print(run_query('test okay'))
    print(DataHandler().run_query("fix test fails")) 

# Run the demo only when this code is executed as the main module.
if __name__ == "__main__":
    main()

## The DRY principle with decorators
DRY Don't repeat yourself 
* Do not create the decorator in the first place from scratch. Wait until the pattern emerges and the abstraction for the decorator becomes clear, and then refactor. 不要從零開始就亂做decorator，等設計模式跟可以抽象再來重構decorator
* Consider that the decorator has to be applied several times (at least three times)
before implementing it. 可以重覆利用decorator三次以上
* Keep the code in the decorators to a minimum. 保持decorator的程式輕量簡潔

## Decorators and separation of concerns 

In [None]:
'''
這個例子的decorator要log也要量時間
'''
import functools
import time


def traced_function(function):
    """Wrap ``function`` to log when it starts and how long the call took.

    This decorator carries two responsibilities at once: logging the start of
    the call and timing it.

    Args:
        function: The callable to trace.

    Returns:
        A wrapper with the metadata of ``function`` that forwards all
        arguments and returns the wrapped call's result.
    """

    @functools.wraps(function)
    def wrapped(*args, **kwargs):
        logging.info("started execution of %s", function.__qualname__)
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) if the system clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = function(*args, **kwargs)
        logging.info(
            "function %s took %.2fs",
            function.__qualname__,
            time.perf_counter() - start_time,
        )
        return result

    return wrapped


@traced_function
def operation1():
    """Sleep for two seconds, log a message, and return 2."""
    time.sleep(2)
    logging.info("running operation 1")
    return 2

In [None]:
import time
from functools import wraps
import logging

# Logging and timing are split into two separate decorators.
def log_execution(function):
    """Wrap ``function`` so that the start of every call is logged.

    Args:
        function: The callable to decorate.

    Returns:
        A wrapper with the metadata of ``function`` that forwards all
        positional and keyword arguments and returns the call's result.
    """

    @wraps(function)
    def wrapped(*args, **kwargs):
        logging.info("started execution of %s", function.__qualname__)
        # Forward the positional arguments themselves; unpacking ``*kwargs``
        # here would pass the keyword names as positional values instead.
        return function(*args, **kwargs)

    return wrapped


def measure_time(function):
    """Wrap ``function`` so that the duration of every call is logged.

    Args:
        function: The callable to decorate.

    Returns:
        A wrapper with the metadata of ``function`` that forwards all
        arguments and returns the wrapped call's result.
    """

    @wraps(function)
    def wrapped(*args, **kwargs):
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) if the system clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = function(*args, **kwargs)

        # The trailing "s" states the unit, matching the message that the
        # combined traced_function decorator logs.
        logging.info(
            "function %s took %.2fs",
            function.__qualname__,
            time.perf_counter() - start_time,
        )
        return result

    return wrapped


# Decorators apply bottom-up: log_execution wraps the function first, and
# measure_time then wraps that result, so the timing includes the logging.
@measure_time
@log_execution
def operation():
    """Sleep for three seconds, log a message, and return 33."""
    time.sleep(3)
    logging.info("running operation...")
    return 33