In [14]:
%matplotlib inline
import math
import time
import numpy as np
import torch

We need some utility functions to simplify object-oriented programming in Jupyter notebooks.

In [15]:
def add_to_class(Class): #@save
    """Register the decorated function as a method of an existing class.

    Usage::

        @add_to_class(SomeClass)
        def method(self, ...):
            ...

    Args:
        Class: The already-created class to attach the function to.

    Returns:
        A decorator that sets ``obj`` on ``Class`` under ``obj.__name__``
        and returns ``obj`` unchanged.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        # Return the function so the decorated module-level name still
        # refers to it instead of being rebound to None.
        return obj
    return wrapper

We also use a utility class that saves the arguments passed to a class's constructor as attributes of the instance.

In [16]:
import inspect


class HyperParameters: #@save
    """Mixin that records the arguments of the calling method as attributes."""

    def save_hyperparameters(self, ignore=()):
        """Save the caller's arguments in ``self.hparams`` and as attributes.

        Call this from inside ``__init__`` (or another method). It inspects
        the *calling* frame and copies every local variable visible there,
        except ``self``, the names listed in ``ignore``, and names that start
        with an underscore.

        Args:
            ignore: Iterable of names that should not be saved.
        """
        frame = inspect.currentframe().f_back
        try:
            _, _, _, local_vars = inspect.getargvalues(frame)
        finally:
            # Drop the frame reference promptly to avoid a reference cycle.
            del frame
        skip = set(ignore) | {'self'}
        self.hparams = {k: v for k, v in local_vars.items()
                        if k not in skip and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

This final utility allows us to plot experiment progress interactively while it is going on.

In [17]:
class ProgressBoard(HyperParameters): #@save
    """Board for plotting data points in animation.

    The constructor arguments are recorded through ``save_hyperparameters``.
    Drawing is not implemented yet.
    """

    def __init__(self, xlabel=None, ylabel=None, xlim=None, ylim=None,
                 xscale='linear', yscale='linear', ls=None, colors=None,
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        # Build fresh lists per instance instead of sharing one mutable
        # default object across every ProgressBoard. The reassigned locals
        # are what save_hyperparameters sees in this frame.
        if ls is None:
            ls = ['-', '--', '-.', ':']
        if colors is None:
            colors = ['C0', 'C1', 'C2', 'C3']
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        """Add the point (x, y) to the curve named ``label`` (not implemented).

        Raises:
            NotImplementedError: always, until drawing is implemented.
        """
        # `NotImplemented` is not an exception; raising it gives a TypeError.
        raise NotImplementedError

Now we need a way to define an abstract representation of all our machine learning models.
The `Module` class is the base class of all models we will implement. This class needs:
*   an `__init__` method that stores the learnable parameters,
*   a `training_step` method that returns the loss value for a batch, and
*   a `configure_optimizers` method that returns the optimization method, or a list of them, used to update the model's learnable parameters.

In [18]:
class Module(torch.nn.Module, HyperParameters):
    """Base class of the models implemented in this notebook.

    Subclasses are expected to set ``self.net`` (used by ``forward``) and to
    implement ``loss`` and ``configure_optimizers``.

    Args:
        plot_train_per_epoch: Read back as ``self.plot_train_per_epoch`` in
            ``plot`` to compute how often training points are drawn.
        plot_valid_per_epoch: Read back as ``self.plot_valid_per_epoch`` in
            ``plot`` to compute how often validation points are drawn.
    """

    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        """Return the loss between predictions ``y_hat`` and targets ``y``."""
        raise NotImplementedError

    def forward(self, X):
        """Pass ``X`` through ``self.net``."""
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X)

    def plot(self, key, value, train: bool):
        """Plot a point in animation.

        Args:
            key: Metric name; prefixed with ``train_`` or ``val_``.
            value: Tensor holding the metric value.
            train: Whether the point comes from training or validation.

        Requires ``self.trainer`` to be set (``Trainer.prepare_model`` does).
        """
        assert hasattr(self, 'trainer'), 'Trainer is not defined'
        self.board.xlabel = 'epoch'
        if train:
            # Training points sit at fractional epochs.
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        # Detach from the autograd graph and move to host memory before
        # converting to NumPy.
        self.board.draw(x, value.detach().cpu().numpy(),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, batch):
        """Compute and plot the training loss for ``batch``, and return it.

        All elements of ``batch`` except the last are fed to the model; the
        last element is used as the target.
        """
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        """Compute and plot the validation loss for ``batch``."""
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        """Return the optimizer(s) used to update the parameters."""
        raise NotImplementedError

In [19]:
class DataModule(HyperParameters):
    """Base class for data; subclasses implement ``get_dataloader``.

    Args:
        root: Recorded via ``save_hyperparameters``; presumably the dataset
            directory used by subclasses.
        num_workers: Recorded via ``save_hyperparameters``; presumably the
            number of data-loading workers used by subclasses.
    """

    def __init__(self, root='../data', num_workers=4):
        self.save_hyperparameters()

    def get_dataloader(self, train: bool):
        """Return a loader for the training (``train=True``) or validation split."""
        raise NotImplementedError

    def train_dataloader(self):
        """Return the training-split loader."""
        # `self` is already bound; passing it again shifted the arguments.
        return self.get_dataloader(True)

    def val_dataloader(self):
        """Return the validation-split loader."""
        return self.get_dataloader(False)

The `Trainer` class trains the learnable parameters of a `Module` using the data provided by a `DataModule`. Its key method is `fit`, which accepts a model of type `Module` and data of type `DataModule`.
`fit` iterates over the entire dataset `max_epochs` times to train the model.

In [21]:
class Trainer(HyperParameters):
    """Train a ``Module`` with the data supplied by a ``DataModule``.

    Args:
        max_epochs: Number of calls to ``fit_epoch`` made by ``fit``.
        num_gpus: Must be 0; GPU training is not supported yet.
        gradient_clip_val: Recorded via ``save_hyperparameters``; not read by
            any code in this class.
    """

    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data: DataModule) -> None:
        """Fetch the train/validation loaders and record their batch counts."""
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        # A missing validation loader counts as zero validation batches.
        has_val = self.val_dataloader is not None
        self.num_val_batches = len(self.val_dataloader) if has_val else 0

    def prepare_model(self, model: Module) -> None:
        """Link ``model`` back to this trainer and set its board's x-range."""
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model: Module, data: DataModule) -> None:
        """Prepare data and model, then run ``fit_epoch`` ``max_epochs`` times."""
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        # Reset the counters; epoch and train_batch_idx are read by Module.plot.
        self.epoch, self.train_batch_idx, self.val_batch_idx = 0, 0, 0
        for epoch_idx in range(self.max_epochs):
            self.epoch = epoch_idx
            self.fit_epoch()

    def fit_epoch(self):
        """Run one epoch; not implemented in this base class."""
        raise NotImplementedError