Skip to content

Commit

Permalink
Fixes for Checkpoint & Reload
Browse files Browse the repository at this point in the history
  • Loading branch information
obilaniu committed Sep 1, 2017
1 parent b821d7b commit dc2b7f1
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 54 deletions.
67 changes: 37 additions & 30 deletions complexnn/bn.py
Expand Up @@ -20,6 +20,19 @@ def sqrt_init(shape, dtype=None):
return value


def sanitizedInitGet(init):
if init in ["sqrt_init"]:
return sqrt_init
else:
return initializers.get(init)
def sanitizedInitSer(init):
if init in [sqrt_init]:
return "sqrt_init"
else:
return initializers.serialize(init)



def complex_standardization(input_centred, Vrr, Vii, Vri,
layernorm=False, axis=-1):

Expand Down Expand Up @@ -246,24 +259,18 @@ def __init__(self,
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
if gamma_diag_initializer != 'sqrt_init':
self.gamma_diag_initializer = initializers.get(gamma_diag_initializer)
else:
self.gamma_diag_initializer = sqrt_init
self.gamma_off_initializer = initializers.get(gamma_off_initializer)
self.moving_mean_initializer = initializers.get(moving_mean_initializer)
if moving_variance_initializer != 'sqrt_init':
self.moving_variance_initializer = initializers.get(moving_variance_initializer)
else:
self.moving_variance_initializer = sqrt_init
self.moving_covariance_initializer = initializers.get(moving_covariance_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer)
self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_diag_constraint = constraints.get(gamma_diag_constraint)
self.gamma_off_constraint = constraints.get(gamma_off_constraint)
self.beta_initializer = sanitizedInitGet(beta_initializer)
self.gamma_diag_initializer = sanitizedInitGet(gamma_diag_initializer)
self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer)
self.moving_mean_initializer = sanitizedInitGet(moving_mean_initializer)
self.moving_variance_initializer = sanitizedInitGet(moving_variance_initializer)
self.moving_covariance_initializer = sanitizedInitGet(moving_covariance_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer)
self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer)
self.beta_constraint = constraints .get(beta_constraint)
self.gamma_diag_constraint = constraints .get(gamma_diag_constraint)
self.gamma_off_constraint = constraints .get(gamma_off_constraint)

def build(self, input_shape):

Expand Down Expand Up @@ -434,18 +441,18 @@ def get_config(self):
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_diag_initializer': initializers.serialize(self.gamma_diag_initializer) if self.gamma_diag_initializer != sqrt_init else 'sqrt_init',
'gamma_off_initializer': initializers.serialize(self.gamma_off_initializer),
'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer) if self.moving_variance_initializer != sqrt_init else 'sqrt_init',
'moving_covariance_initializer': initializers.serialize(self.moving_covariance_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer),
'gamma_off_regularizer': regularizers.serialize(self.gamma_off_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_diag_constraint': constraints.serialize(self.gamma_diag_constraint),
'gamma_off_constraint': constraints.serialize(self.gamma_off_constraint),
'beta_initializer': sanitizedInitSer(self.beta_initializer),
'gamma_diag_initializer': sanitizedInitSer(self.gamma_diag_initializer),
'gamma_off_initializer': sanitizedInitSer(self.gamma_off_initializer),
'moving_mean_initializer': sanitizedInitSer(self.moving_mean_initializer),
'moving_variance_initializer': sanitizedInitSer(self.moving_variance_initializer),
'moving_covariance_initializer': sanitizedInitSer(self.moving_covariance_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer),
'gamma_off_regularizer': regularizers.serialize(self.gamma_off_regularizer),
'beta_constraint': constraints .serialize(self.beta_constraint),
'gamma_diag_constraint': constraints .serialize(self.gamma_diag_constraint),
'gamma_off_constraint': constraints .serialize(self.gamma_off_constraint),
}
base_config = super(ComplexBatchNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Expand Down
48 changes: 31 additions & 17 deletions complexnn/conv.py
Expand Up @@ -21,6 +21,27 @@
from .norm import LayerNormalization, ComplexLayerNorm



def sanitizedInitGet(init):
if init in ["sqrt_init"]:
return sqrt_init
elif init in ["complex", "complex_independent",
"glorot_complex", "he_complex"]:
return init
else:
return initializers.get(init)
def sanitizedInitSer(init):
if init in [sqrt_init]:
return "sqrt_init"
elif init == "complex" or isinstance(init, ComplexInit):
return "complex"
elif init == "complex_independent" or isinstance(init, ComplexIndependentFilters):
return "complex_independent"
else:
return initializers.serialize(init)



class ComplexConv(Layer):
"""Abstract nD complex convolution layer.
This layer creates a complex convolution kernel that is convolved
Expand Down Expand Up @@ -132,13 +153,10 @@ def __init__(self, rank,
self.init_criterion = init_criterion
self.spectral_parametrization = spectral_parametrization
self.epsilon = epsilon
if kernel_initializer in ['complex', 'complex_independent']:
self.kernel_initializer = kernel_initializer
else:
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.gamma_diag_initializer = initializers.get(gamma_diag_initializer)
self.gamma_off_initializer = initializers.get(gamma_off_initializer)
self.kernel_initializer = sanitizedInitGet(kernel_initializer)
self.bias_initializer = sanitizedInitGet(bias_initializer)
self.gamma_diag_initializer = sanitizedInitGet(gamma_diag_initializer)
self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer)
Expand Down Expand Up @@ -380,10 +398,6 @@ def compute_output_shape(self, input_shape):
return (input_shape[0],) + (2 * self.filters,) + tuple(new_space)

def get_config(self):
if self.kernel_initializer in {'complex', 'complex_independent'}:
ki = self.kernel_initializer
else:
ki = initializers.serialize(self.kernel_initializer)
config = {
'rank': self.rank,
'filters': self.filters,
Expand All @@ -395,10 +409,10 @@ def get_config(self):
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'normalize_weight': self.normalize_weight,
'kernel_initializer': ki,
'bias_initializer': initializers.serialize(self.bias_initializer),
'gamma_diag_initializer': initializers.serialize(self.gamma_diag_initializer),
'gamma_off_initializer': initializers.serialize(self.gamma_off_initializer),
'kernel_initializer': sanitizedInitSer(self.kernel_initializer),
'bias_initializer': sanitizedInitSer(self.bias_initializer),
'gamma_diag_initializer': sanitizedInitSer(self.gamma_diag_initializer),
'gamma_off_initializer': sanitizedInitSer(self.gamma_off_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer),
Expand Down Expand Up @@ -821,7 +835,7 @@ def __init__(self,
super(WeightNorm_Conv, self).__init__(**kwargs)
if self.rank == 1:
self.data_format = 'channels_last'
self.gamma_initializer = initializers.get(gamma_initializer)
self.gamma_initializer = sanitizedInitGet(gamma_initializer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.gamma_constraint = constraints.get(gamma_constraint)
self.epsilon = epsilon
Expand Down Expand Up @@ -887,7 +901,7 @@ def call(self, inputs):

def get_config(self):
config = {
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'gamma_initializer': sanitizedInitSer(self.gamma_initializer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'gamma_constraint': constraints.serialize(self.gamma_constraint),
'epsilon': self.epsilon
Expand Down
22 changes: 18 additions & 4 deletions complexnn/utils.py
Expand Up @@ -5,7 +5,7 @@
# Authors: Dmitriy Serdyuk, Olexa Bilaniuk, Chiheb Trabelsi

import keras.backend as K
from keras.layers import Lambda
from keras.layers import Layer, Lambda

#
# GetReal/GetImag Lambda layer Implementation
Expand Down Expand Up @@ -69,6 +69,20 @@ def getpart_output_shape(input_shape):

return tuple(returned_shape)

GetReal = Lambda(get_realpart, output_shape=getpart_output_shape)
GetImag = Lambda(get_imagpart, output_shape=getpart_output_shape)
GetAbs = Lambda(get_abs, output_shape=getpart_output_shape)

class GetReal(Layer):
def call(self, inputs):
return get_realpart(inputs)
def compute_output_shape(self, input_shape):
return getpart_output_shape(input_shape)
class GetImag(Layer):
def call(self, inputs):
return get_imagpart(inputs)
def compute_output_shape(self, input_shape):
return getpart_output_shape(input_shape)
class GetAbs(Layer):
def call(self, inputs):
return get_abs(inputs)
def compute_output_shape(self, input_shape):
return getpart_output_shape(input_shape)

8 changes: 5 additions & 3 deletions scripts/training.py
Expand Up @@ -123,8 +123,8 @@ def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut, convArgs,
(1, 1),
**convArgs)(I)

O_real = Concatenate(channel_axis)([GetReal(X), GetReal(O)])
O_imag = Concatenate(channel_axis)([GetImag(X), GetImag(O)])
O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
O = Concatenate( 1 )([O_real, O_imag])

return O
Expand Down Expand Up @@ -590,7 +590,9 @@ def train(d):
np.random.seed(d.seed % 2**32)
model = KM.load_model(chkptFilename, custom_objects={
"ComplexConv2D": ComplexConv2D,
"ComplexBatchNormalization": ComplexBN
"ComplexBatchNormalization": ComplexBN,
"GetReal": GetReal,
"GetImag": GetImag
})
L.getLogger("entry").info("... reloading complete.")

Expand Down

0 comments on commit dc2b7f1

Please sign in to comment.