Commit bd54cd4
Add batch normalization layer
CyberZHG committed Mar 2, 2019
1 parent c02af7f commit bd54cd4
Showing 12 changed files with 149 additions and 4 deletions.
1 change: 1 addition & 0 deletions auto_diff/layers/__init__.py
@@ -4,3 +4,4 @@
from .conv import Conv2D
from .recurrent import LSTM, GRU
from .dropout import Dropout
+from .batch_norm import BatchNorm
78 changes: 78 additions & 0 deletions auto_diff/layers/batch_norm.py
@@ -0,0 +1,78 @@
import auto_diff as ad
from .layer import Layer


class BatchNorm(Layer):

    def __init__(self,
                 momentum=0.99,
                 epsilon=1e-3,
                 scale=True,
                 center=True,
                 beta_initializer=ad.inits.zeros,
                 gamma_initializer=ad.inits.ones,
                 **kwargs):
        super(BatchNorm, self).__init__(**kwargs)
        self.momentum = momentum
        self.epsilon = epsilon
        self.scale = scale
        self.center = center
        self.beta_initializer = beta_initializer
        self.gamma_initializer = gamma_initializer
        self.gamma, self.beta = None, None
        self.moving_mean, self.moving_var = None, None

    def build(self, input_shape):
        if not self._built:
            if self.scale:
                self.gamma = self.add_weight(
                    name='gamma',
                    shape=(input_shape[-1],),
                    initializer=self.gamma_initializer,
                    trainable=True,
                )
            if self.center:
                self.beta = self.add_weight(
                    name='beta',
                    shape=(input_shape[-1],),
                    initializer=self.beta_initializer,
                    trainable=True,
                )
            # Moving statistics start at the identity transform:
            # zero mean, unit variance.
            self.moving_mean = self.add_weight(
                name='moving_mean',
                shape=(input_shape[-1],),
                initializer=ad.inits.zeros,
                trainable=False,
            )
            self.moving_var = self.add_weight(
                name='moving_var',
                shape=(input_shape[-1],),
                initializer=ad.inits.ones,
                trainable=False,
            )
        super(BatchNorm, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape

    def _normalize(self, inputs, mean, var):
        normal = (inputs - mean) / ad.sqrt(var + self.epsilon)
        if self.scale:
            normal *= self.gamma
        if self.center:
            normal += self.beta
        return normal

    def call(self, inputs, **kwargs):
        # Per-feature statistics over every axis except the last.
        sum_axis = tuple(range(len(inputs.shape) - 1))
        mean = ad.mean(inputs, axis=sum_axis, keepdims=True)
        var = ad.mean(ad.square(inputs - mean), axis=sum_axis, keepdims=True)
        # Queue exponential-moving-average updates; the model applies
        # them once per training batch.
        moving_mean = self.momentum * self.moving_mean + (1.0 - self.momentum) * mean
        moving_var = self.momentum * self.moving_var + (1.0 - self.momentum) * var
        self.add_update(self.moving_mean, ad.squeeze(moving_mean, axis=sum_axis))
        self.add_update(self.moving_var, ad.squeeze(moving_var, axis=sum_axis))
        # Normalize with batch statistics while training and with the
        # stored moving statistics at inference time.
        return ad.where(
            ad.in_train_phase(),
            self._normalize(inputs, mean, var),
            self._normalize(inputs, self.moving_mean, self.moving_var),
        )
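
For reference, here is a minimal NumPy sketch of what the layer above computes (an illustration under the same defaults, not this library's API; `x` is a channels-last batch):

import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var,
               momentum=0.99, epsilon=1e-3, training=True):
    axes = tuple(range(x.ndim - 1))     # every axis except the feature axis
    if training:
        mean, var = x.mean(axis=axes), x.var(axis=axes)
        # Exponential moving averages, consumed later at inference time.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    else:
        mean, var = moving_mean, moving_var
    normal = (x - mean) / np.sqrt(var + epsilon)
    return gamma * normal + beta, moving_mean, moving_var

With `training=False` the result depends only on the stored statistics, which is why the repeated predictions in `test_no_moving` below must match exactly.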
3 changes: 1 addition & 2 deletions auto_diff/layers/dropout.py
@@ -11,7 +11,6 @@ def __init__(self,
        super(Dropout, self).__init__(**kwargs)
        self.rate = rate
        self.noise_shape = noise_shape
-        self.in_train_phase = ad.in_train_phase()

    def compute_output_shape(self, input_shape):
        return input_shape
@@ -22,7 +21,7 @@ def call(self, inputs, **kwargs):
            noise_shape = self.noise_shape
        else:
            noise_shape = ad.shape(inputs)
-        return ad.where(self.in_train_phase,
+        return ad.where(ad.in_train_phase(),
                        inputs * (ad.random(noise_shape) > self.rate),
                        inputs)

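
For reference, a minimal NumPy sketch of the masking above. Note that no `1 / (1 - rate)` rescaling is applied at training time, unlike "inverted" dropout, so activation scales differ between the train and predict phases:

import numpy as np

def dropout(inputs, rate=0.1, noise_shape=None, training=True):
    if not training:
        return inputs                               # identity at inference
    if noise_shape is None:
        noise_shape = inputs.shape
    keep = np.random.random(noise_shape) > rate     # keep with probability 1 - rate
    return inputs * keep                            # mask broadcasts over noise_shape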
8 changes: 8 additions & 0 deletions auto_diff/layers/layer.py
@@ -12,6 +12,7 @@ def __init__(self, **kwargs):
        self._outputs = None
        self._input_shapes = None
        self._output_shapes = None
+        self._updates = []

    def build(self, input_shape):
        self._built = True
@@ -65,3 +66,10 @@ def input_shapes(self):
    @property
    def output_shapes(self):
        return self._output_shapes
+
+    @property
+    def updates(self):
+        return self._updates
+
+    def add_update(self, var: ad.OpVariable, update: ad.Operation):
+        self.updates.append((var, update))
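
The `(variable, operation)` pairs queued by `add_update` sit outside backpropagation; the model evaluates and applies them once per training batch (see `fit_on_batch` below). A plain-Python sketch of the pattern, with hypothetical stand-ins for `ad.OpVariable` and `ad.Operation`:

class Var:                                  # stand-in for ad.OpVariable
    def __init__(self, value):
        self.value = value

    def update(self, value):                # out-of-band assignment, no gradient
        self.value = value


class TrackedLayer:                         # stand-in for Layer
    def __init__(self):
        self._updates = []

    @property
    def updates(self):
        return self._updates

    def add_update(self, var, update):      # schedule: apply `update` to `var`
        self._updates.append((var, update))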
4 changes: 4 additions & 0 deletions auto_diff/models/model.py
@@ -16,6 +16,7 @@ def __init__(self,
        self._losses = None
        self._loss = None
        self._layers = None
+        self._updates = []
        self._output_placeholders = None
        self._session = ad.sess.Session()

@@ -51,6 +52,7 @@ def _collect_all_layers(layer):
        for layer in self._layers.values():
            self._trainable_weights += layer.trainable_weights
            self._non_trainable_weights += layer.non_trainable_weights
+            self._updates += layer.updates

        self._loss = 0.0
        if isinstance(self.outputs, list):
@@ -99,6 +101,8 @@ def fit_on_batch(self,
        self._session.prepare()
        self._session.run(self._loss, feed_dict=feed_dict)
        self._loss.backward()
+        for var, update in self._updates:
+            var.update(update.forward(feed_dict=feed_dict))
        self._optimizer.update(self.trainable_weights, self._session)

    def predict_on_batch(self, x: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
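
Putting it together, a training step now has four phases; a sketch mirroring `fit_on_batch` above (the session, loss, and optimizer objects are the model's own, elided here):

def train_step(model, feed_dict):
    model._session.prepare()
    model._session.run(model._loss, feed_dict=feed_dict)  # 1. forward pass
    model._loss.backward()                                # 2. gradients
    for var, update in model._updates:                    # 3. non-gradient state,
        var.update(update.forward(feed_dict=feed_dict))   #    e.g. BatchNorm EMAs
    model._optimizer.update(model.trainable_weights, model._session)  # 4. weights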
39 changes: 39 additions & 0 deletions tests/layers/test_batch_norm.py
@@ -0,0 +1,39 @@
from unittest import TestCase
import numpy as np
import auto_diff as ad


class TestBatchNorm(TestCase):

    def test_no_moving(self):
        input_layer = ad.layers.Input(shape=(None, 5))
        normal_layer = ad.layers.BatchNorm()(input_layer)
        dense_layer = ad.layers.Dense(output_dim=2, activation=ad.acts.softmax)(normal_layer)
        model = ad.models.Model(inputs=input_layer, outputs=dense_layer)
        model.build(
            optimizer=ad.optims.Adam(),
            losses=ad.losses.cross_entropy,
        )

        input_vals = np.random.random((2, 5))
        first = model.predict_on_batch(input_vals)
        second = model.predict_on_batch(input_vals)
        self.assertTrue(np.allclose(first, second))

    def test_fit(self):
        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, 5))
        normal_layer = ad.layers.BatchNorm()(input_layer)
        dense_layer = ad.layers.Dense(output_dim=2, activation=ad.acts.softmax)(normal_layer)
        model = ad.models.Model(inputs=input_layer, outputs=dense_layer)
        model.build(
            optimizer=ad.optims.Adam(),
            losses=ad.losses.cross_entropy,
        )

        input_vals = np.random.random((2, 5))
        output_vals = np.array([[0.0, 1.0], [1.0, 0.0]])
        for _ in range(5000):
            model.fit_on_batch(input_vals, output_vals)
        actual = np.argmax(model.predict_on_batch(input_vals), axis=-1).tolist()
        self.assertEqual([1, 0], actual)
1 change: 1 addition & 0 deletions tests/layers/test_conv.py
@@ -69,6 +69,7 @@ def test_dilation_valid_shape(self):
        self.assertEqual((2, 3, 3, 4), output.shape)

    def test_fit(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, None, None, 2))
        conv_layer = ad.layers.Conv2D(kernel_size=3, filters=2, padding='same', activation=ad.acts.relu)(input_layer)
        model = ad.models.Model(inputs=input_layer, outputs=conv_layer)
1 change: 1 addition & 0 deletions tests/layers/test_dense.py
@@ -14,6 +14,7 @@ def test_output(self):
        self.assertEqual((3, 3), output.shape)

    def test_fit(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, 5))
        dense_layer = ad.layers.Dense(output_dim=2, activation=ad.acts.softmax)(input_layer)
        model = ad.models.Model(inputs=input_layer, outputs=dense_layer)
7 changes: 5 additions & 2 deletions tests/layers/test_dropout.py
@@ -13,9 +13,10 @@ def test_predict_phase(self):
        y = model.predict_on_batch(x)
        self.assertTrue(np.allclose(x, y))

-    def test_fit_half(self):
+    def test_fit(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, 5))
-        drop_layer = ad.layers.Dropout(rate=0.5)(input_layer)
+        drop_layer = ad.layers.Dropout(rate=0.1)(input_layer)
        dense_layer = ad.layers.Dense(output_dim=2, activation=ad.acts.softmax)(drop_layer)
        model = ad.models.Model(inputs=input_layer, outputs=dense_layer)
        model.build(
@@ -31,6 +32,7 @@ def test_fit_half(self):
        self.assertEqual([1.0, 0.0], actual)

    def test_fit_zero(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, 5))
        drop_layer = ad.layers.Dropout(rate=0.0)(input_layer)
        dense_layer = ad.layers.Dense(output_dim=2, activation=ad.acts.softmax)(drop_layer)
@@ -48,6 +50,7 @@ def test_fit_zero(self):
        self.assertEqual([1.0, 0.0], actual)

    def test_fit_noise_shape(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, 5))
        drop_layer = ad.layers.Dropout(rate=0.5, noise_shape=(1, 5))(input_layer)
        dense_layer = ad.layers.Dense(output_dim=2, activation=ad.acts.softmax)(drop_layer)
2 changes: 2 additions & 0 deletions tests/layers/test_recurrent.py
@@ -22,6 +22,7 @@ def test_output_seq(self):
        self.assertEqual((2, 7, 5), output.shape)

    def test_fit(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, None, 3))
        lstm_layer = ad.layers.LSTM(units=7, return_sequences=True)(input_layer)
        lstm_layer = ad.layers.LSTM(units=2)(lstm_layer)
@@ -63,6 +64,7 @@ def test_output_seq(self):
        self.assertEqual((2, 7, 5), output.shape)

    def test_fit(self):
+        np.random.seed(0xcafe)
        input_layer = ad.layers.Input(shape=(None, None, 3))
        lstm_layer = ad.layers.GRU(units=7, return_sequences=True)(input_layer)
        lstm_layer = ad.layers.GRU(units=2)(lstm_layer)
4 changes: 4 additions & 0 deletions tests/optims/test_adam.py
@@ -21,6 +21,7 @@ def _test_fitting(self, model):
        self.assertEqual([1.0, 0.0], actual)

    def test_default(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.Adam(lr=1e-3),
@@ -29,6 +30,7 @@ def test_decay(self):
        self._test_fitting(model)

    def test_decay(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.Adam(lr=1e-3, decay=1e-3),
@@ -37,6 +39,7 @@ def test_amsgrad(self):
        self._test_fitting(model)

    def test_amsgrad(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.Adam(lr=1e-3, amsgrad=True),
@@ -45,6 +48,7 @@ def test_all(self):
        self._test_fitting(model)

    def test_all(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.Adam(lr=1e-3, decay=1e-3, amsgrad=True),
5 changes: 5 additions & 0 deletions tests/optims/test_sgd.py
@@ -21,6 +21,7 @@ def _test_fitting(self, model):
        self.assertEqual([1.0, 0.0], actual)

    def test_default(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.SGD(lr=1e-3),
@@ -29,6 +30,7 @@ def test_momentum(self):
        self._test_fitting(model)

    def test_momentum(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.SGD(momentum=0.9, lr=1e-3),
@@ -37,6 +39,7 @@ def test_decay(self):
        self._test_fitting(model)

    def test_decay(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.SGD(decay=1e-3, lr=1e-3),
@@ -45,6 +48,7 @@ def test_nesterov(self):
        self._test_fitting(model)

    def test_nesterov(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.SGD(lr=1e-3, nesterov=True),
@@ -53,6 +57,7 @@ def test_all(self):
        self._test_fitting(model)

    def test_all(self):
+        np.random.seed(0xcafe)
        model = self._create_model()
        model.build(
            optimizer=ad.optims.SGD(momentum=0.9, decay=1e-3, lr=1e-3, nesterov=True),
