# Optimizer

In [1]:
# Enable the autoreload extension; mode 2 re-imports modules before each cell
# runs, so edits to imported source files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
#export
import sys
# Make the sibling `scripts` directory importable: take sys.path[0], swap its last
# path component for 'scripts', and put the result at the front of the module
# search path. NOTE: the path is split/joined on '/', so this assumes POSIX-style paths.
sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))

# Star import; presumably sub_model lives in that scripts directory — TODO confirm.
from sub_model import *

In [3]:
#export
class Optimizer():
    '''Vanilla optimizer: applies one fixed learning rate to every parameter.'''

    def __init__(self, parameters, learning_rate):
        '''Vanilla optimizer with basic methods.
            parameters: iterable of model parameters; each one must provide
                `step(learning_rate)` and `zero_grad()`
            learning_rate: step size of each training iteration
        '''
        # A one-shot iterator (e.g. a generator) would be exhausted by the first
        # `step()`/`zero_grad()` call, turning every later call into a silent no-op.
        # Materialize it once; re-iterable containers are stored as given so a
        # caller-owned list stays aliased.
        if iter(parameters) is parameters:
            parameters = list(parameters)
        self.parameters, self.learning_rate = parameters, learning_rate
        
    def step(self):
        '''Call `step(learning_rate)` on every tracked parameter.'''
        for parameter in self.parameters:
            parameter.step(self.learning_rate)
    
    def zero_grad(self):
        '''Call `zero_grad()` on every tracked parameter.'''
        for parameter in self.parameters:
            parameter.zero_grad()
    
    def __repr__(self):
        return f'(Optimizer) learning_rate: {self.learning_rate}'

In [4]:
#export
class DynamicOpt():
    '''Optimizer that keeps its hyper parameters in a dictionary so they can be changed between steps.'''

    def __init__(self, parameters, **hyper_params):
        '''Dynamic optimizer to allow multiple hyper parameters and param scheduling.
            parameters: iterable of model parameters; each one must provide
                `step(learning_rate)` and `zero_grad()`
            hyper_params: dictionary of hyper parameters optimizer keeps track of;
                `step()` requires a 'learning_rate' entry
        '''
        # A one-shot iterator (e.g. a generator) would be exhausted by the first
        # `step()`/`zero_grad()` call, turning every later call into a silent no-op.
        # Materialize it once; re-iterable containers are stored as given so a
        # caller-owned list stays aliased.
        if iter(parameters) is parameters:
            parameters = list(parameters)
        self.parameters = parameters
        self.hyper_params = dict(hyper_params)
        
    def step(self):
        '''Call `step()` on every tracked parameter with the current learning rate.

        The rate is looked up in `hyper_params` on every iteration, so updating
        `hyper_params['learning_rate']` affects subsequent steps. Raises KeyError
        if no 'learning_rate' entry exists.
        '''
        for parameter in self.parameters:
            parameter.step(self.hyper_params['learning_rate'])
    
    def zero_grad(self):
        '''Call `zero_grad()` on every tracked parameter.'''
        for parameter in self.parameters:
            parameter.zero_grad()
    
    def __repr__(self):
        # Only the hyper parameter names are shown, not their values.
        return f'(DynamicOpt) hyper_params: {list(self.hyper_params)}'

# Tests

In [5]:
# Smoke test: build the MNIST data bunch and conv model, attach the vanilla
# optimizer, then show both reprs.
data_bunch = get_data_bunch(
    *get_mnist_data(),
    batch_size=64,
)
model = get_conv_model(data_bunch)
optimizer = Optimizer(
    parameters=list(model.parameters()),
    learning_rate=0.1,
)

print(model, optimizer, sep='\n')

(Model)
    Reshape(1, 28, 28)
    Conv(1, 8, 5, 4)
    ReLU()
    Conv(8, 16, 3, 2)
    Flatten()
    Linear(256, 10)
(Optimizer) learning_rate: 0.1


In [6]:
# Smoke test: same setup as for the vanilla optimizer, but with DynamicOpt,
# which receives the learning rate as a keyword hyper parameter.
data_bunch = get_data_bunch(
    *get_mnist_data(),
    batch_size=64,
)
model = get_conv_model(data_bunch)
optimizer = DynamicOpt(
    list(model.parameters()),
    learning_rate=0.1,
)

print(model, optimizer, sep='\n')

(Model)
    Reshape(1, 28, 28)
    Conv(1, 8, 5, 4)
    ReLU()
    Conv(8, 16, 3, 2)
    Flatten()
    Linear(256, 10)
(DynamicOpt) hyper_params: ['learning_rate']
