[View in Colaboratory](https://colab.research.google.com/github/averma12/Beginning_Data_Science/blob/master/Python_OOP_basics.ipynb)

In [0]:
import matplotlib.pyplot as plt
import os

## Basic OOP in Python
We will walk through the basic OOP concepts in Python that come up when using and building a deep learning library, namely classes, abstract classes, inheritance, etc.

In [0]:
# Abstract class. A class is abstract if it declares abstract methods, i.e. methods that are
# declared in the class but left unimplemented so that subclasses can override them.

class Callback:
    '''
    Base class for training callbacks (e.g., LossRecorder).

    Every hook below is a no-op that returns None. A subclass overrides
    only the hooks it is interested in and inherits the rest unchanged.
    '''

    def on_train_begin(self):
        '''Hook for the start of training. Does nothing by default.'''

    def on_phase_begin(self):
        '''Hook for the start of a phase. Does nothing by default.'''

    def on_batch_begin(self):
        '''Hook for the start of a batch. Does nothing by default.'''

    def on_batch_end(self, metrics):
        '''Hook for the end of a batch; `metrics` is ignored by default.'''

    def on_epoch_end(self, metrics):
        '''Hook for the end of an epoch; `metrics` is ignored by default.'''

    def on_phase_end(self):
        '''Hook for the end of a phase. Does nothing by default.'''

    def on_train_end(self):
        '''Hook for the end of training. Does nothing by default.'''

In [0]:
class LossRecorder(Callback):
    '''
    Saves and displays losses, learning rates and other metrics.

    Per the original author's note, this is the default sched when none is
    specified in a learner.

    The class uses several names that are not defined in this block and must
    be in scope wherever it is used: `np`, `plt`, `os`, `timer`, `delistify`
    and `in_ipynb`.

    Arguments:
        layer_opt: object exposing `lrs` (read once at construction), `lr`
            (read after every batch) and, when `record_mom` is true, `mom`.
        save_path: directory in which plots and losses are written when not
            running inside a notebook.
        record_mom: when true, `layer_opt.mom` is recorded after every batch.
        metrics: stored as `self.metrics`; a fresh empty list when omitted.
    '''
    def __init__(self, layer_opt, save_path='', record_mom=False, metrics=None):
        super().__init__()
        self.layer_opt = layer_opt
        self.init_lrs = np.array(layer_opt.lrs)
        self.save_path, self.record_mom = save_path, record_mom
        # A `[]` default would be one list shared by every instance created
        # without `metrics`; build a new list per instance instead.
        self.metrics = [] if metrics is None else metrics

    def on_train_begin(self):
        '''Resets all recorded series and counters and stores the current `timer()` reading.'''
        self.losses, self.lrs, self.iterations, self.epochs, self.times = [], [], [], [], []
        self.start_at = timer()
        self.val_losses, self.rec_metrics = [], []
        if self.record_mom:
            self.momentums = []
        self.iteration = 0
        self.epoch = 0

    def on_epoch_end(self, metrics):
        '''
        Records the iteration count and elapsed time at which the epoch ended,
        then passes `metrics` to `save_metrics`.
        '''
        self.epoch += 1
        self.epochs.append(self.iteration)
        self.times.append(timer() - self.start_at)
        self.save_metrics(metrics)

    def on_batch_end(self, loss):
        '''
        Records the batch loss together with the current learning rate
        (and momentum when `record_mom` is set).

        `loss` is either a single value or a list whose first element is the
        loss; any remaining elements are passed to `save_metrics`.
        '''
        self.iteration += 1
        self.lrs.append(self.layer_opt.lr)
        self.iterations.append(self.iteration)
        if isinstance(loss, list):
            self.losses.append(loss[0])
            extra = loss[1:]
            # A list holding only the loss has nothing left for save_metrics,
            # which would otherwise fail indexing an empty list.
            if extra:
                self.save_metrics(extra)
        else:
            self.losses.append(loss)
        if self.record_mom:
            self.momentums.append(self.layer_opt.mom)

    def save_metrics(self, vals):
        '''
        Appends `delistify(vals[0])` to `val_losses` and the remaining values
        to `rec_metrics`: a single remaining value is stored bare, several are
        stored as a slice, and nothing is stored when there are none.
        '''
        self.val_losses.append(delistify(vals[0]))
        if len(vals) > 2:
            self.rec_metrics.append(vals[1:])
        elif len(vals) == 2:
            self.rec_metrics.append(vals[1])

    def plot_loss(self, n_skip=10, n_skip_end=5):
        '''
        Plots the loss as a function of iterations, skipping the first
        `n_skip` and the last `n_skip_end` recorded points.

        Outside a Jupyter notebook the 'agg' backend is selected and both the
        plot ('loss_plot.png') and the losses ('losses.npy') are saved in
        `save_path`.
        '''
        headless = not in_ipynb()
        if headless:
            plt.switch_backend('agg')
        # `-0` as a slice stop means index 0 and would drop every point, so
        # use an open-ended slice when nothing is skipped at the end.
        stop = -n_skip_end if n_skip_end else None
        plt.plot(self.iterations[n_skip:stop], self.losses[n_skip:stop])
        if headless:
            plt.savefig(os.path.join(self.save_path, 'loss_plot.png'))
            # NOTE(review): the saved array always drops the first 10 losses,
            # independent of n_skip / n_skip_end -- confirm this is intended.
            np.save(os.path.join(self.save_path, 'losses.npy'), self.losses[10:])

    def plot_lr(self):
        '''
        Plots the learning rate against iterations, plus a second panel with
        momentum when `record_mom` is set.

        Outside a Jupyter notebook the 'agg' backend is selected and the
        figure is saved as 'lr_plot.png' in `save_path`.
        '''
        headless = not in_ipynb()
        if headless:
            plt.switch_backend('agg')
        if self.record_mom:
            fig, axs = plt.subplots(1, 2, figsize=(12, 4))
            for ax in axs:
                ax.set_xlabel('iterations')
            axs[0].set_ylabel('learning rate')
            axs[1].set_ylabel('momentum')
            axs[0].plot(self.iterations, self.lrs)
            axs[1].plot(self.iterations, self.momentums)
        else:
            plt.xlabel("iterations")
            plt.ylabel("learning rate")
            plt.plot(self.iterations, self.lrs)
        if headless:
            plt.savefig(os.path.join(self.save_path, 'lr_plot.png'))