In [None]:
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

## OOP （面向对象） API

这一节记录一些在 Jupyter Notebook 中进行面向对象编程时常用的 API。

### 添加属性 / 方法

`add_to_class` 允许我们在类定义完成之后，再把函数作为方法添加到该类中。

这样就可以把一个 `class` 的定义拆分到多个代码单元（cell）中，便于边写边讲解。

In [None]:
def add_to_class(Class):
    """Build a decorator that attaches the decorated function to `Class`.

    The function is stored on `Class` under its own `__name__`. The decorator
    returns None, so after `@add_to_class(C)` the module-level name of the
    decorated function is bound to None; the function remains reachable as
    an attribute of `C`.
    """
    def register(func):
        # Register under the function's own name, e.g. C.func_name = func.
        setattr(Class, func.__name__, func)
    return register

In [None]:
class Cat:
    """A tiny demo class: a named cat whose age starts at 0."""

    def __init__(self, cat_name: str):
        self.name, self.age = cat_name, 0

    def grow(self):
        """Make the cat one year older."""
        self.age = self.age + 1

    def __repr__(self):
        template = "Cat {}, age {}"
        return template.format(self.name, self.age)

neko = Cat('Neko')
neko.grow()
neko

下面用装饰器语法定义函数，并把它添加到 `Cat` 类中。这一写法等价于：

```python
add_to_cat = add_to_class(Cat)
self_introduction = add_to_cat(self_introduction)
```

In [None]:
@add_to_class(Cat)
def self_introduction(self) -> None:
    """Print a greeting containing the cat's name and age."""
    greeting = "Hello! My name is {} nya~ I am {} years old nya~"
    print(greeting.format(self.name, self.age))

# The decorator's wrapper returns None, so the module-level name
# `self_introduction` is now None; the method lives on as Cat.self_introduction.
self_introduction, neko.self_introduction()

### 保存超参数

原理：通过 `inspect` 取得调用者（通常是 `__init__`）的栈帧，读取其中的局部变量，再把它们逐一设置为实例的属性。

因此，`__init__` 接收的形参（即调用时的局部变量）都会被保存为实例属性，并汇总在 `self.hparams` 字典中。`self`、`ignore` 中列出的名字以及以下划线开头的名字除外。

In [None]:
import inspect

class HyperParameters:
    """Mixin that stores a constructor's arguments as instance attributes."""

    def save_hyperparameters(self, ignore=()):
        """Copy the caller's local variables onto ``self``.

        Reads the local variables of the direct caller (normally ``__init__``)
        at the moment of the call. Each one is stored in ``self.hparams`` and
        also set as an attribute of ``self``. The following are skipped:
        ``self``, any name listed in ``ignore``, and names starting with ``_``.

        Args:
            ignore: iterable of local-variable names to skip (list, tuple,
                set, ...). Defaults to an empty tuple, so there is no shared
                mutable default.

        Raises:
            RuntimeError: if the interpreter provides no stack-frame support
                (``inspect.currentframe()`` returns None).
        """
        current = inspect.currentframe()
        if current is None:
            raise RuntimeError(
                "save_hyperparameters requires Python stack frame support")
        # currentframe() is this method's frame; f_back is the caller's frame.
        frame = current.f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        # Build the exclusion set once. set() accepts any iterable, so tuples
        # and sets work too (the old `ignore + ['self']` only accepted lists).
        excluded = set(ignore) | {'self'}
        self.hparams = {
            k: v for k, v in local_vars.items()
            if k not in excluded and not k.startswith('_')
        }
        for k, v in self.hparams.items():
            setattr(self, k, v)

In [None]:
class B(HyperParameters):
    """Demo: `a` and `b` become attributes; `c` is deliberately excluded."""

    def __init__(self, a, b, c):
        # `c` is listed in `ignore`, so only `a` and `b` are saved as attributes.
        self.save_hyperparameters(ignore=['c'])

b = B(1, 'qwq', 3)
(b.a, b.b, hasattr(b, 'c'))

### 动态绘图

`ProgressBoard` 可以根据不断产生的结果实时绘制曲线。请慎用：每次重绘都会重新渲染并输出整幅图像，开销较大。

In [None]:
import collections
from matplotlib_inline import backend_inline
import matplotlib.pyplot as plt
from IPython import display

class ProgressBoard(HyperParameters):
    """Live-updating line plot shown in the notebook cell output.

    Each `draw` call buffers one raw (x, y) point for a named series. Once
    `every_n` points are buffered, they are averaged into one plotted point
    and the figure is redrawn in place via IPython display.

    All constructor arguments are stored as attributes (and in
    `self.hparams`) by `save_hyperparameters`. Series are paired with
    `ls`/`colors` by position, so series beyond
    min(len(ls), len(colors)) are not drawn.
    """

    # Defined once at class level so `draw` does not build a new namedtuple
    # type on every call.
    _Point = collections.namedtuple('Point', ['x', 'y'])

    def __init__(
        self, xlabel=None, ylabel=None, xlim=None, ylim=None,
        xscale='linear', yscale='linear',
        ls=['-', '--', '-.', ':'],
        colors=['C0', 'C1', 'C2', 'C3'],
        fig=None, axes=None, figsize=(3.5, 2.5), display=True
    ):
        # No other locals may be created before this call: every local of
        # __init__ becomes an attribute.
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        """Add point (x, y) to series `label`; redraw every `every_n` points.

        Args:
            x, y: coordinates of the new raw point.
            label: series name, also used as its legend entry.
            every_n: number of raw points averaged into one plotted point.
        """
        Point = self._Point
        if not hasattr(self, 'raw_points'):
            self.raw_points = collections.OrderedDict()  # label -> buffered raw points
            self.data = collections.OrderedDict()        # label -> averaged plotted points
        if label not in self.raw_points:
            self.raw_points[label] = []
            self.data[label] = []
        points = self.raw_points[label]
        line = self.data[label]
        points.append(Point(x, y))
        # `<` rather than `!=`: if every_n shrinks between calls, a buffer that
        # is already longer than every_n would otherwise never be flushed.
        if len(points) < every_n:
            return
        mean = lambda vals: sum(vals) / len(vals)
        line.append(Point(
            mean([p.x for p in points]),
            mean([p.y for p in points])
        ))
        points.clear()
        if not self.display:
            return
        backend_inline.set_matplotlib_formats('svg')  # render inline output as SVG
        if self.fig is None:
            # Use the figure that owns user-supplied axes, otherwise create one.
            if self.axes is not None:
                self.fig = self.axes.figure
            else:
                self.fig = plt.figure(figsize=self.figsize)
        # Use the board's own figure rather than pyplot's "current" axes, which
        # may belong to another figure.
        axes = self.axes if self.axes is not None else self.fig.gca()
        # Remove the lines from the previous redraw before plotting the updated
        # series; otherwise every redraw stacks another full copy of each line.
        for old_line in getattr(self, '_plt_lines', []):
            old_line.remove()
        plt_lines, labels = [], []
        for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
            # Plot on `axes` so data, labels, limits and legend share one Axes.
            plt_lines.append(axes.plot(
                [p.x for p in v], [p.y for p in v],
                linestyle=ls, color=color
            )[0])
            labels.append(k)
        self._plt_lines = plt_lines
        if self.xlim: axes.set_xlim(self.xlim)
        if self.ylim: axes.set_ylim(self.ylim)
        if not self.xlabel:
            # No x label given: fall back to a generic 'x'.
            self.xlabel = 'x'
        axes.set_xlabel(self.xlabel)
        axes.set_ylabel(self.ylabel)
        axes.set_xscale(self.xscale)
        axes.set_yscale(self.yscale)
        axes.legend(plt_lines, labels)  # legend entries follow series insertion order
        display.display(self.fig)
        display.clear_output(wait=True)  # next display replaces this one in place

In [None]:
# Demo: 'sin' gets a plotted point on every call, while 'cos' averages
# every 5 raw points into one plotted point.
board = ProgressBoard(xlabel='x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), label='sin')
    board.draw(x, np.cos(x), label='cos', every_n=5)