# OOP Design for implementation

In [49]:
import time
import sys
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
import collections


### Using Decorator for better code implementation

## First:

Utility function allows us to register function as method in class after the class has been created.we can do so even after we have created instances of the class! It allows us to split the implementation of a class into multiple code blocks.

In [50]:
def add_to_class(Class):
    """Register function as method in created class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

Let see what `add_to_class()` will do.

In [51]:
class A:
    def __init__(self) -> None:
        self.b = 1

a = A()

Class A is created and instantiated, now we code method for the class and add it via decorator `add_to_class()`.

In [52]:
def do(self):
    print("Class attribute 'b' is", self.b)

a.do()

AttributeError: 'A' object has no attribute 'do'

we can see that it give am AttributeError: `'A' object has no attribute 'do'`

that means there is no Attribute name `do`

Now we decorate the function `do()` with `add_to_class()` which will add `do()` as a method to the `class A`

lets do it!!

In [None]:
@add_to_class(A)
def do(self):
    print("Class attribute 'b' is", self.b)

a.do()

Class attribute 'b' is 1


Voila!! 😍

## Second: 
Utility class that saves all arguments is a class's `__init__()` method as class attribute.This allows us to extend constructor call signatures implicitly without additional code

In [None]:
import inspect
class HyperParameters:
    """ The base class of Hyperparameters"""
    def save_hypers(self, ignore=[]):
        raise NotImplemented

In [None]:
@add_to_class(HyperParameters)
def save_hyperparameters(self, ignore=[]):
        """Save function arguments into class attributes.
    
        Defined in :numref:`sec_utils`"""
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        self.hparams = {k:v for k, v in local_vars.items()
                        if k not in set(ignore+['self']) and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

In [None]:
class B(HyperParameters):
    def __init__(self, a,b,c) -> None:
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print("There is no self.c =", not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)

self.a = 1 self.b = 2
There is no self.c = True


In [None]:
class ProgressBoard(HyperParameters):
    """The board that plots data points in animation.

    Defined in :numref:`sec_oo-design`"""
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplemented

    def draw(self, x, y, label, every_n=1):
        """Defined in :numref:`sec_utils`"""
        Point = collections.namedtuple('Point', ['x', 'y'])
        if not hasattr(self, 'raw_points'):
            self.raw_points = collections.OrderedDict()
            self.data = collections.OrderedDict()
        if label not in self.raw_points:
            self.raw_points[label] = []
            self.data[label] = []
        points = self.raw_points[label]
        line = self.data[label]
        points.append(Point(x, y))
        if len(points) != every_n:
            return
        mean = lambda x: sum(x) / len(x)
        line.append(Point(mean([p.x for p in points]),
                          mean([p.y for p in points])))
        points.clear()
        if not self.display:
            return
        d2l.use_svg_display()
        if self.fig is None:
            self.fig = d2l.plt.figure(figsize=self.figsize)
        plt_lines, labels = [], []
        for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
            plt_lines.append(d2l.plt.plot([p.x for p in v], [p.y for p in v],
                                          linestyle=ls, color=color)[0])
            labels.append(k)
        axes = self.axes if self.axes else d2l.plt.gca()
        if self.xlim: axes.set_xlim(self.xlim)
        if self.ylim: axes.set_ylim(self.ylim)
        if not self.xlabel: self.xlabel = self.x
        axes.set_xlabel(self.xlabel)
        axes.set_ylabel(self.ylabel)
        axes.set_xscale(self.xscale)
        axes.set_yscale(self.yscale)
        axes.legend(plt_lines, labels)
        display.display(self.fig)
        display.clear_output(wait=True)

In [None]:
class ProgressBoard(HyperParameters):  #@save
    """The board that plots data points in animation."""
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplemented

In [54]:
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.01):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)

AttributeError: module 'd2l.torch' has no attribute 'ProgressBoard'