In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to the imported scripts are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
#export
import sys
from os.path import join
import math

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from batch_norm import *

In [3]:
#export
class Optimizer():
    """Vanilla optimizer: one fixed learning rate applied to every parameter.

    The update rule itself lives on the parameter objects; this class only
    calls ``parameter.step(learning_rate)`` and ``parameter.zero_grad()`` on
    each of them.

    Args:
        parameters: iterable of objects exposing ``step(learning_rate)`` and
            ``zero_grad()``. Lists and tuples are stored as-is (no copy, so
            the caller's container stays aliased); any other iterable, such
            as a generator, is materialized into a list once.
        learning_rate: value forwarded to every ``parameter.step`` call.
    """

    def __init__(self, parameters, learning_rate):
        # step() and zero_grad() each iterate over the parameters. A one-shot
        # iterator would be exhausted by the first pass, silently turning
        # every later call into a no-op, so make it re-iterable here.
        if not isinstance(parameters, (list, tuple)):
            parameters = list(parameters)
        self.parameters = parameters
        self.learning_rate = learning_rate

    def __repr__(self):
        return f'(Optimizer) learning_rate: {self.learning_rate}'

    def step(self):
        """Call ``step(self.learning_rate)`` on every parameter."""
        for parameter in self.parameters:
            parameter.step(self.learning_rate)

    def zero_grad(self):
        """Call ``zero_grad()`` on every parameter."""
        for parameter in self.parameters:
            parameter.zero_grad()

In [4]:
# Demo of the vanilla Optimizer. get_mnist_data, get_data_bunch and
# get_conv_model are not defined in this notebook; presumably they come from
# the `batch_norm` star-import above (verify).
# Whatever get_mnist_data() returns is unpacked positionally into
# get_data_bunch, together with a batch size of 64.
data_bunch = get_data_bunch(*get_mnist_data(), batch_size=64)
model = get_conv_model(data_bunch)
# list(...) hands the optimizer a concrete, re-iterable collection: Optimizer
# iterates over its parameters in both step() and zero_grad().
optimizer = Optimizer(list(model.parameters()), learning_rate=0.1)

# Print both objects for a quick visual check of the layer stack and settings.
print(model)
print(optimizer)

(Sequential)
	(Layer1) Reshape(1, 28, 28)
	(Layer2) Conv2D(in: 1, out: 8, kernel: 5, stride: 4, pad: 2)
	(Layer3) ReLU()
	(Layer4) Conv2D(in: 8, out: 16, kernel: 3, stride: 2, pad: 1)
	(Layer5) Flatten()
	(Layer6) Linear(256, 10)
(Optimizer) learning_rate: 0.1


In [5]:
#export
class DynamicOpt():
    """Optimizer whose hyper-parameters are kept in a mutable dict.

    Intended for things like parameter scheduling or carrying several
    hyper-parameters at once: ``self.hyper_params`` can be modified between
    calls, and ``step()`` reads ``'learning_rate'`` from it each time.

    Args:
        parameters: iterable of objects exposing ``step(learning_rate)`` and
            ``zero_grad()``. Lists and tuples are stored as-is (no copy, so
            the caller's container stays aliased); any other iterable, such
            as a generator, is materialized into a list once.
        **hyper_params: arbitrary named hyper-parameters, stored in
            ``self.hyper_params``. ``step()`` requires a ``'learning_rate'``
            entry.
    """

    def __init__(self, parameters, **hyper_params):
        # step() and zero_grad() each iterate over the parameters. A one-shot
        # iterator would be exhausted by the first pass, silently turning
        # every later call into a no-op, so make it re-iterable here.
        if not isinstance(parameters, (list, tuple)):
            parameters = list(parameters)
        self.parameters = parameters
        self.hyper_params = dict(hyper_params)

    def __repr__(self):
        # Only the hyper-parameter names are shown, not their values.
        return f'(DynamicOpt) hyper_params: {list(self.hyper_params)}'

    def step(self):
        """Call ``step`` on every parameter with the current learning rate.

        Raises:
            KeyError: if ``'learning_rate'`` is missing from
                ``self.hyper_params`` and there is at least one parameter.
        """
        for parameter in self.parameters:
            # Looked up per parameter, straight from the dict, so the value
            # currently stored in hyper_params is always the one used.
            parameter.step(self.hyper_params['learning_rate'])

    def zero_grad(self):
        """Call ``zero_grad()`` on every parameter."""
        for parameter in self.parameters:
            parameter.zero_grad()

In [6]:
# Demo of DynamicOpt. This rebuilds data_bunch and model from scratch,
# rebinding the names used by the earlier Optimizer demo cell.
# get_mnist_data, get_data_bunch and get_conv_model are not defined in this
# notebook; presumably they come from the `batch_norm` star-import (verify).
data_bunch = get_data_bunch(*get_mnist_data(), batch_size=64)
model = get_conv_model(data_bunch)
# learning_rate is passed as a keyword and ends up in optimizer.hyper_params;
# list(...) gives the optimizer a concrete, re-iterable parameter collection.
optimizer = DynamicOpt(list(model.parameters()), learning_rate=0.1)

# Print both objects; DynamicOpt's repr lists hyper-parameter names only.
print(model)
print(optimizer)

(Sequential)
	(Layer1) Reshape(1, 28, 28)
	(Layer2) Conv2D(in: 1, out: 8, kernel: 5, stride: 4, pad: 2)
	(Layer3) ReLU()
	(Layer4) Conv2D(in: 8, out: 16, kernel: 3, stride: 2, pad: 1)
	(Layer5) Flatten()
	(Layer6) Linear(256, 10)
(DynamicOpt) hyper_params: ['learning_rate']
