In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Dense, GRU, Conv2D, ZeroPadding2D, Conv2DTranspose, SeparableConv2D
from tensorflow.keras.layers import BatchNormalization, LayerNormalization
from tensorflow.keras.activations import elu, relu, sigmoid, tanh
from tensorflow.keras.constraints import Constraint
# from tensorflow.nn import relu, sigmoid, tanh
from tensorflow.keras.layers import Layer
from keras.regularizers import l2
from keras.initializers import HeUniform

import numpy as np
import random
from math import gcd

In [None]:
import import_ipynb

In [None]:
reg = l2(1e-6)


class WeightClip(Constraint):
    '''Weight constraint that clamps every element of a tensor into [-c, c].
    '''
    def __init__(self, c=2):
        # Symmetric clipping bound.
        self.c = c

    def __call__(self, p):
        bound = self.c
        return K.clip(p, -bound, bound)

    def get_config(self):
        return dict(name=type(self).__name__, c=self.c)


# Shared clipping constraint (bound 0.499) used by the GRU and Dense layers in this notebook.
constraint = WeightClip(0.499)

In [None]:
def _grouped_conv(layer_cls, inputs, groups, convkwargs):
    """Split channels into `groups` slices, convolve each with its own `layer_cls`, and concatenate.

    Args:
        layer_cls: Keras layer class to instantiate per group (e.g. Conv2D, Conv2DTranspose).
        inputs: tensor whose last (channel) axis is split.
        groups: passed to `tf.split` as `num_or_size_splits`; with an int the channel
            count must be evenly divisible by it.
        convkwargs: keyword arguments for every per-group layer.

    Returns:
        The per-group outputs concatenated along the last axis.
    """
    splits = tf.split(inputs, groups, axis=-1)
    base_name = convkwargs.get("name")
    convolved_splits = []
    for idx, split in enumerate(splits):
        layer_kwargs = dict(convkwargs)
        # Keras requires unique layer names within a model, so each group's layer gets
        # its index appended when a name is supplied and there is more than one split.
        if base_name is not None and len(splits) > 1:
            layer_kwargs["name"] = f"{base_name}_{idx}"
        convolved_splits.append(layer_cls(**layer_kwargs)(split))
    return tf.concat(convolved_splits, -1)


def grouped_conv2d_transpose(inputs, groups, convkwargs):
    """Performs grouped transposed convolution.

    Args:
        inputs: A `Tensor` of shape `[batch_size, h, w, c]`.
        groups: The number of groups to use in the grouped convolution step.
            The input channel count needs to be evenly divisible by `groups`.
        convkwargs: Keyword arguments forwarded to each per-group `Conv2DTranspose`
            (e.g. `filters` per group, `kernel_size`, `strides`, `padding`). If it
            contains `name`, the group index is appended to keep layer names unique.
    Returns:
        A `Tensor` of shape `[batch_size, new_h, new_w, groups * convkwargs['filters']]`.
    """
    return _grouped_conv(Conv2DTranspose, inputs, groups, convkwargs)


def grouped_conv2d(inputs, groups, convkwargs):
    """Performs grouped convolution by splitting channels and applying one `Conv2D` per group.

    Args:
        inputs: A `Tensor` of shape `[batch_size, h, w, c]`.
        groups: The number of groups; `c` must be evenly divisible by it.
        convkwargs: Keyword arguments forwarded to each per-group `Conv2D`. If it
            contains `name`, the group index is appended to keep layer names unique.
    Returns:
        A `Tensor` of shape `[batch_size, new_h, new_w, groups * convkwargs['filters']]`.
    """
    return _grouped_conv(Conv2D, inputs, groups, convkwargs)

In [None]:
# def RELU(x, quant=False):
#     if quant:
#         return relu(x)
#     else:
#         return tf.where(x<0.0, 0.0, x)

In [None]:
# def convkxf_old(inputs,
#             out_ch: int,
#             k: int = 1,
#             f: int = 3,
#             fstride: int = 2,
#             lookahead: int = 0,
#             batch_norm: bool = False,
#             act = 'relu',
#             mode = "normal",
#             depthwise: bool = True,
#             complex_in: bool = False,
#             reshape: bool = False,
#             name: str = 'conv',
#             training=True
#            ):
#     in_ch = inputs.get_shape()[-1]
#     bias = batch_norm is False
#     stride = 1 if f == 1 else (1, fstride)
#     fpad = (f - 1) // 2
# #     convpad = (0, fpad)
        
#     if depthwise: groups = min(in_ch, out_ch)
#     else: groups = 1
        
#     if in_ch % groups != 0 or out_ch % groups != 0: groups = 1    
#     if complex_in and groups % 2 == 0: groups //= 2
        
#     convkwargs = {
#         "filters": out_ch//groups,
#         "kernel_size": (k, f),
#         "strides": stride,
#         "use_bias": bias,
#         "padding": 'same',
# #         "name": name,
#     }
    
#     if mode == "normal":
#         if not training:
#             convkwargs = {"filters": out_ch//groups,
#                           "kernel_size": (k, f),
#                           "strides": stride,
#                           "use_bias": bias,
#                           "padding": 'valid'}
#             if fpad>0:
#                 paddings = tf.constant([[0,0],[0,0],[fpad,fpad],[0,0]])
#                 inputs = tf.pad(inputs, paddings, "CONSTANT", constant_values=0)
        
#         if groups>1: conv_out = grouped_conv2d(inputs, groups, convkwargs)
#         else: conv_out = grouped_conv2d(inputs, groups, convkwargs)
        
#     elif mode == "transposed":
#         if groups>1: conv_out = grouped_conv2d_transpose(inputs, groups, convkwargs)
#         else: conv_out = Conv2DTranspose(**convkwargs)(inputs)        
#     else:
#         raise NotImplementedError()
        
#     print(convkwargs)
    
#     if groups>1: 
#         conv_out = Conv2D(out_ch, kernel_size=1, trainable = training,
#                           use_bias=False, name=name + '_1x1')(conv_out)
#     if batch_norm: 
#         conv_out = BatchNormalization()(conv_out)
        
#     if act == 'relu': 
#         conv_out = RELU(conv_out)
#     elif act == 'sigmoid': conv_out = sigmoid(conv_out)
#     else: pass
    
#     if reshape: 
#         conv_out = tf.squeeze(conv_out,-1)
        
#     return conv_out

In [None]:
def convkxf(inputs,
            out_ch: int,
            k: int = 1,
            f: int = 3,
            fstride: int = 2,
            lookahead: int = 0,
            batch_norm: bool = False,
            act = 'relu',
            mode = "normal",
            depthwise: bool = True,
            complex_in: bool = False,
            reshape: bool = False,
            name: str = 'conv',
            training=False,
            infer=False,
            bias: bool = False,
            BN_type = "normal",
            folding=False
           ):
    """Build a (k x f) convolution block with optional grouping, normalization and activation.

    Args:
        inputs: 4-D tensor; kernel size `k` applies to axis 1 (presumably time, since it
            is padded with `lookahead` future frames) and `f` to axis 2 (presumably frequency).
        out_ch: number of output channels.
        k: kernel size along axis 1.
        f: kernel size along axis 2; axis 2 is zero-padded by (f - 1) // 2 on each side.
        fstride: stride along axis 2; ignored (stride 1) when `f == 1`.
        lookahead: when `training`, number of trailing zero frames padded on axis 1
            (with `k - 1 - lookahead` leading frames).
        batch_norm: append a normalization layer chosen by `BN_type`.
        act: 'relu' or 'sigmoid'; any other value applies no activation.
        mode: "normal" (Conv2D) or "transposed" (Conv2DTranspose).
        depthwise: if True use gcd(in_ch, out_ch) groups followed by a 1x1 mixing conv.
        complex_in: participates in the `use_bias` rule of the main convolution.
        reshape: drop the last axis via tf.reshape (only valid when it has size 1).
        name: base name for the created layers.
        training: if True, apply the axis-1 padding above so output length matches input.
        infer: unused; kept for call-site compatibility.
        bias: `use_bias` of the main conv when the folding rule does not force a bias.
        BN_type: "normal" (BatchNormalization) or "range" (RangeBN).
        folding: enables extra biases (presumably to absorb folded normalization params).

    Returns:
        The output tensor.

    Raises:
        NotImplementedError: if `mode` is neither "normal" nor "transposed".
    """
    in_ch = inputs.get_shape()[-1]
    stride = 1 if f == 1 else (1, fstride)
    fpad = (f - 1) // 2  # symmetric padding on axis 2

    groups = gcd(in_ch, out_ch) if depthwise else 1
    # gcd divides both counts; kept as a defensive fallback.
    if in_ch % groups != 0 or out_ch % groups != 0:
        groups = 1

    if mode == "normal":
        force_bias = folding and ((groups == 1 and not training) or complex_in)
        convkwargs = {
            "filters": out_ch,
            "groups": groups,
            "kernel_size": (k, f),
            "strides": stride,
            "use_bias": True if force_bias else bias,
            "padding": 'valid',
            "kernel_initializer": 'he_normal',
            "kernel_regularizer": reg,
            "name": name
        }
        if training:
            past, future = k - 1 - lookahead, lookahead
            if past > 0 or future > 0:
                time_pad = tf.constant([[0, 0], [past, future], [0, 0], [0, 0]])
                inputs = tf.pad(inputs, time_pad, "CONSTANT", constant_values=0)

        if fpad > 0:
            freq_pad = tf.constant([[0, 0], [0, 0], [fpad, fpad], [0, 0]])
            inputs = tf.pad(inputs, freq_pad, "CONSTANT", constant_values=0)
        conv_out = Conv2D(**convkwargs)(inputs)

    elif mode == "transposed":
        # NOTE(review): this parses as `(folding and (...)) or complex_in`, so complex_in
        # forces a bias even without folding, unlike the "normal" branch. Left unchanged
        # because altering it changes the weight layout of existing checkpoints; confirm intent.
        convkwargs = {
            "filters": out_ch//groups,
            "kernel_size": (k, f),
            "strides": stride,
            "use_bias": True if (folding and  (not groups>1 and not training) or complex_in) else bias,
            "padding": 'same',
            "kernel_initializer": 'he_normal',
            "kernel_regularizer": reg,
        }
        if groups > 1:
            conv_out = grouped_conv2d_transpose(inputs, groups, convkwargs)
        else:
            conv_out = Conv2DTranspose(**convkwargs)(inputs)

    else:
        raise NotImplementedError()

    if groups > 1:
        # 1x1 conv mixes information across the groups.
        conv_out = Conv2D(out_ch, kernel_size=1,
                          use_bias=True if folding else False, name=name + '_1x1')(conv_out)

    if batch_norm:
        if BN_type == "normal":
            conv_out = BatchNormalization(name = name + 'BatchNorm')(conv_out)
        elif BN_type == "range":
            conv_out = RangeBN(filters=out_ch)(conv_out)

    if reshape:
        dyn_shape = tf.shape(conv_out)
        conv_out = tf.reshape(conv_out, (dyn_shape[0], dyn_shape[1], dyn_shape[2]))

    if act == 'relu':
        conv_out = relu(conv_out)
    elif act == 'sigmoid':
        conv_out = sigmoid(conv_out)

    return conv_out

In [None]:
def GroupGRU(inputs, hidden, groups, return_sequences=True, name='GGRU', reshape=True, count=0, training=True, norm=False):
    """Grouped GRU: split the channel axis into `groups` slices, each with its own GRU.

    Args:
        inputs: tensor whose last axis is split into `groups` slices.
        hidden: total hidden size; each group GRU has `hidden // groups` units.
        groups: number of groups (1 means a single plain GRU).
        return_sequences: if False, return only the last step (`[:, -1]`).
        name: layer name (groups == 1) or base name suffixed with `'_<index>'`.
        reshape: for groups > 1, transpose the group axis before the unit axis prior to
            flattening, so group outputs are concatenated instead of interleaved.
        count: starting index for the per-group layer name suffix.
        training: if False, build stateful, unrolled GRUs for streaming inference.
        norm: apply BatchNormalization after every GRU.

    Returns:
        The (possibly flattened) GRU output tensor.
    """
    gru_kwargs = dict(
        units=hidden // groups,
        return_sequences=True,
        kernel_initializer='he_normal',
        kernel_constraint=constraint,
        recurrent_constraint=constraint,
        bias_constraint=constraint,
        kernel_regularizer=reg,
        recurrent_regularizer=reg,
    )
    if not training:
        gru_kwargs.update(stateful=True, unroll=True)

    def run_gru(x, layer_name):
        y = GRU(**gru_kwargs, name=layer_name)(x)
        return BatchNormalization()(y) if norm else y

    if groups == 1:
        result = run_gru(inputs, name)
    else:
        per_group = [
            run_gru(chunk, name + '_' + str(count + offset))
            for offset, chunk in enumerate(tf.split(inputs, groups, axis=-1))
        ]
        stacked = tf.stack(per_group, axis=-1, name=name + '_stack')
        if reshape:
            stacked = tf.transpose(stacked, [0, 1, 3, 2], name=name + '_transpose')

        dims = [tf.shape(stacked)[axis] for axis in range(4)]
        result = tf.reshape(stacked, shape=[dims[0], dims[1], dims[-1] * dims[-2]], name=name + '_reshape')

    return result if return_sequences else result[:, -1]

In [None]:
def GroupGRULayer(inputs, hidden, groups, num_layer, name='GGRU', add_output=True, training=True, norm=False):
    """Stack `num_layer` GroupGRU layers, each fed the previous layer's output.

    Args:
        inputs: input tensor for the first layer.
        hidden, groups, training, norm: forwarded to every GroupGRU.
        num_layer: number of stacked layers (must be >= 1).
        name: base name; layer i is named `name + str(i)`.
        add_output: if True return the sum of all layer outputs, else the last one.

    Returns:
        The summed (or last) layer output.

    Raises:
        ValueError: if `num_layer < 1` (previously an UnboundLocalError).
    """
    if num_layer < 1:
        raise ValueError(f"num_layer must be >= 1, got {num_layer}")

    output = None
    for i in range(num_layer):
        # Every layer but the last uses GroupGRU's transposed (group-contiguous) flattening.
        is_last = i == num_layer - 1
        inputs = GroupGRU(inputs, hidden, groups, return_sequences=True, name=name + str(i),
                          reshape=not is_last, training=training, norm=norm)
        if add_output and i > 0:
            output += inputs
        else:
            output = inputs

    return output

In [None]:
def GroupFC(inputs, hidden, groups=8, activation=None, name='GFC', count=0, infer=False, norm=False):
    """Grouped fully-connected layer: one Dense per channel group, outputs concatenated.

    Args:
        inputs: tensor whose last axis is split into `groups` slices; the 4-axis
            transpose/reshape below implies a 3-D input (TODO confirm).
        hidden: total output width; each group Dense has `hidden // groups` units.
        groups: number of groups.
        activation: 'relu', 'sigmoid', 'tanh', or anything else for none.
        name: base name; group Dense layers are `name + '_' + <index>`.
        count: starting index for the Dense layer name suffix.
        infer: unused; kept for call-site compatibility.
        norm: apply a BatchNormalization after merging the groups.

    Returns:
        Tensor whose last axis holds the concatenated group outputs.
    """
    dense_kwargs = {
        "units": hidden // groups,
        "kernel_initializer": 'he_normal',
        "kernel_constraint": constraint,
        "bias_constraint": constraint,
    }
    projected = []
    for chunk in tf.split(inputs, groups, axis=-1):
        projected.append(Dense(**dense_kwargs, name=name + '_' + str(count))(chunk))
        count += 1

    stacked = tf.stack(projected, axis=-1, name=name + '_stack')
    # Move the group axis before the unit axis so each group's units stay contiguous.
    grouped = tf.transpose(stacked, [0, 1, 3, 2], name=name + '_transpose')
    dims = [tf.shape(grouped)[axis] for axis in range(4)]
    merged = tf.reshape(grouped, shape=[dims[0], dims[1], dims[-1] * dims[-2]], name=name + '_reshape')

    if norm:
        merged = BatchNormalization(name=name + '_' + str(count) + 'BatchNorm')(merged)

    for act_name, act_fn in (('relu', relu), ('sigmoid', sigmoid), ('tanh', tanh)):
        if activation == act_name:
            return act_fn(merged)
    return merged

In [None]:
def GroupGRU_lite(inputs, hidden, groups, state=None, return_sequences=True, name='GGRU', reshape=False, 
                  count=0, training=True, norm=False):
    """Grouped GRU that takes explicit initial states and returns final states.

    Args:
        inputs: tensor whose last axis is split into `groups` slices.
        hidden: total hidden size; each group GRU has `hidden // groups` units.
        groups: number of groups (1 means a single GRU).
        state: initial state. For groups == 1 it is passed directly to the GRU;
            otherwise group g is seeded with `state[count + g]`.
        return_sequences: if False, return only the last step (`[:, -1]`).
        name: layer name (groups == 1) or base name suffixed with `'_<index>'`.
        reshape: for groups > 1, transpose the group axis before the unit axis prior to
            flattening, so group outputs are concatenated instead of interleaved.
        count: starting index for group layer names and for indexing `state`.
        training: unused; kept for call-site compatibility.
        norm: apply BatchNormalization to the output.

    Returns:
        (output, state_output). For groups > 1, state_output concatenates the per-group
        final states along axis 0.
    """
    grukwargs = {
        "units": hidden//groups,
        "return_sequences": True,
        "kernel_initializer": 'he_normal',
    }
    if groups == 1:
        gru_output, state_output = GRU(**grukwargs, return_state=True, unroll=True,
                                       name=name)([inputs, state])

    else:
        groups_inputs = tf.split(inputs, groups, axis=-1)
        gru_list, state_list = [], []
        for group_inputs in groups_inputs:
            gru_tmp, gru_state = GRU(**grukwargs, return_state=True, unroll=True,
                                     name=name + '_' + str(count))([group_inputs, state[count]])
            # Bug fix: these used to sit outside the loop, so every group reused the same
            # layer name and state slot, and only the last group's result was kept.
            gru_list.append(gru_tmp)
            state_list.append(gru_state)
            count += 1

        gru_output = tf.stack(gru_list, axis=-1, name=name + '_stack')
        # NOTE(review): input states are indexed per group (`state[count]`) but the outputs
        # are concatenated along axis 0 rather than stacked; confirm callers expect this layout.
        state_output = tf.concat(state_list, axis=0)

        if reshape:
            gru_output = tf.transpose(gru_output, [0, 1, 3, 2], name=name + '_transpose')

        shape = [tf.shape(gru_output)[l] for l in range(4)]
        gru_output = tf.reshape(gru_output, shape = [shape[0], shape[1], shape[-1]*shape[-2]], name=name + '_reshape')

    if norm:
        gru_output = BatchNormalization()(gru_output)

    if return_sequences:
        return gru_output, state_output
    else:
        return gru_output[:,-1], state_output
    
def GroupGRULayer_lite(inputs, hidden, groups, num_layer, state=None, name='GGRU', add_output=True, training=True, norm=False):
    """Stack `num_layer` GroupGRU_lite layers, threading explicit per-layer states.

    Args:
        inputs: input tensor for the first layer.
        hidden, groups, training, norm: forwarded to every GroupGRU_lite.
        num_layer: number of stacked layers (must be >= 1).
        state: indexable per layer; `state[i]` is the initial state of layer i.
        name: base name; layer i is named `name + str(i)`.
        add_output: if True return the sum of all layer outputs, else the last one.

    Returns:
        (output, state_list) where state_list[i] is layer i's returned state.

    Raises:
        ValueError: if `num_layer < 1` (previously an UnboundLocalError).
    """
    if num_layer < 1:
        raise ValueError(f"num_layer must be >= 1, got {num_layer}")

    state_list = []
    output = None
    for i in range(num_layer):
        # Every layer but the last uses the transposed (group-contiguous) flattening.
        is_last = i == num_layer - 1
        inputs, layer_state = GroupGRU_lite(inputs, hidden, groups, state=state[i],
                                            return_sequences=True, name=name + str(i),
                                            reshape=not is_last, training=training, norm=norm)
        state_list.append(layer_state)

        if add_output and i > 0:
            output += inputs
        else:
            output = inputs

    return output, state_list

In [None]:
def dense_gru(inputs, previous_state, units, kernel, recurrent, bias):
    """One GRU step computed from raw weight matrices.

    Gate columns of `kernel`, `recurrent` and `bias` are laid out as
    [0:units] update (z), [units:2*units] reset (r), [2*units:] candidate (h).
    `bias[:1]` is the input-side bias and `bias[1:]` the recurrent-side bias; the
    reset gate multiplies the recurrent candidate term after its bias is added
    (presumably weights exported from a Keras GRU with reset_after=True; verify).

    Returns:
        (new_state, new_state)
    """
    z_cols = slice(None, units)
    r_cols = slice(units, 2 * units)
    h_cols = slice(2 * units, None)

    def input_proj(cols):
        return tf.matmul(inputs, kernel[:, cols]) + bias[:1, cols]

    def state_proj(cols):
        return tf.matmul(previous_state, recurrent[:, cols])

    zt = tf.nn.sigmoid(input_proj(z_cols) + state_proj(z_cols) + bias[1:, z_cols])
    rt = tf.nn.sigmoid(input_proj(r_cols) + state_proj(r_cols) + bias[1:, r_cols])
    hhat = tf.nn.tanh(input_proj(h_cols) + rt * (state_proj(h_cols) + bias[1:, h_cols]))

    new_state = (1 - zt) * hhat + zt * previous_state
    return new_state, new_state

In [None]:
def Dense_gru_lite(inputs, previous_state, units, num_layer, kernel, recurrent, bias, add_output=True):
    """Unroll one or more raw-weight GRU layers step by step over axis 1 of `inputs`.

    Args:
        inputs: tensor whose axis 1 is iterated (static length required).
        previous_state: initial state; for num_layer > 1, indexable per layer.
        units, kernel, recurrent, bias: weights for `dense_gru`; for num_layer > 1 each
            is indexable per layer.
        num_layer: number of stacked layers.
        add_output: for num_layer > 1, sum all layer outputs instead of returning the last.

    Returns:
        (output, final_state): final_state is the state after the last step (a tensor
        when num_layer == 1, a list of per-layer states otherwise).
    """
    if num_layer == 1:
        state = tf.identity(previous_state)
        step_outputs = []
        for j in range(inputs.shape[1]):
            step_out, state = dense_gru(inputs[:, j:j+1], state, units, kernel, recurrent, bias)
            step_outputs.append(step_out)
        # Bug fix: this previously returned the *initial* previous_state, so callers
        # feeding it back never advanced the recurrence.
        return tf.concat(step_outputs, axis=1), state

    state = [tf.identity(previous_state[k]) for k in range(num_layer)]
    output = None
    for i in range(num_layer):
        step_outputs = []
        for j in range(inputs.shape[1]):
            step_out, state[i] = dense_gru(inputs[:, j:j+1], state[i], units[i], kernel[i], recurrent[i], bias[i])
            step_outputs.append(step_out)
        layer_out = tf.concat(step_outputs, axis=1)

        if add_output and i >= 1:
            output += layer_out
        else:
            output = layer_out
        inputs = layer_out
    return output, state

In [None]:
class MyGRU(tf.keras.layers.Layer):
    """GRU layer built from given weight arrays and unrolled one step at a time.

    Gate columns of each kernel/recurrent/bias are [0:units] update (z),
    [units:2*units] reset (r), [2*units:] candidate (h); bias row 0 is the
    input-side bias and rows 1: the recurrent-side bias.

    Args:
        units: hidden size (int), shared by all layers.
        num_layer: number of stacked layers. When > 1, `kernel`, `recurrent`, `bias`
            (and the call-time `previous_state`) are indexed per layer.
        kernel, recurrent, bias: initial weight values.
        BN_weights: per-layer (gamma, beta) pairs; required when `layernorm` is True.
        add_output: when stacking, sum every layer's output instead of returning the last.
        layernorm: apply a per-layer BatchNormalization (despite the name) to each step output.
        trainable: `trainable` flag of the created weight variables.
    """

    def __init__(self, units, num_layer, kernel, recurrent, bias, BN_weights=None, add_output=True, layernorm=False, trainable=False):
        super(MyGRU, self).__init__()
        self.units = units
        self.num_layer = num_layer

        def make_var(value):
            return tf.Variable(initial_value=value, trainable=trainable, dtype=tf.float32)

        if num_layer > 1:
            self.kernel = [make_var(kernel[i]) for i in range(num_layer)]
            self.recurrent = [make_var(recurrent[i]) for i in range(num_layer)]
            self.bias = [make_var(bias[i]) for i in range(num_layer)]
        else:
            self.kernel = make_var(kernel)
            self.recurrent = make_var(recurrent)
            self.bias = make_var(bias)

        self.add_output = add_output
        self.layernorm = layernorm
        if self.layernorm:
            if BN_weights is None:
                raise ValueError("BN_weights must be provided when layernorm=True")
            # BN_weights[i] = (gamma, beta) for layer i.
            self.norm_layer = [BatchNormalization(beta_initializer=tf.keras.initializers.Constant(BN_weights[i][1]),
                                                  gamma_initializer=tf.keras.initializers.Constant(BN_weights[i][0])) for i in range(num_layer)]

    def _run_layer_steps(self, inputs, state, idx):
        """Run layer `idx` over every step of axis 1; return (stacked outputs, final state)."""
        steps = []
        for j in range(inputs.shape[1]):
            step_out, state = self.dense_gru(inputs[:, j:j+1], state, idx=idx)
            if self.layernorm:
                step_out = self.norm_layer[idx](step_out)
            steps.append(step_out)
        stacked = tf.concat(steps, axis=1)
        stacked = tf.reshape(stacked, (-1, inputs.shape[1], self.units))
        return stacked, state

    def call(self, inputs, previous_state):
        """Return (output, final_state); final_state is a per-layer list when num_layer > 1."""
        if self.num_layer > 1:
            state = [tf.identity(previous_state[k]) for k in range(self.num_layer)]
            output = None
            for i in range(self.num_layer):
                layer_out, state[i] = self._run_layer_steps(inputs, state[i], i)
                if self.add_output and i >= 1:
                    output += layer_out
                else:
                    output = layer_out
                inputs = layer_out
            return output, state

        state = tf.identity(previous_state)
        return self._run_layer_steps(inputs, state, 0)

    def dense_gru(self, inputs, previous_state, idx=0):
        """One GRU step with the weights of layer `idx` (ignored when num_layer == 1)."""
        if self.num_layer > 1:
            kernel, recurrent, bias = self.kernel[idx], self.recurrent[idx], self.bias[idx]
        else:
            kernel, recurrent, bias = self.kernel, self.recurrent, self.bias
        u = self.units

        kernel_z, recurrent_z = kernel[:, :u], recurrent[:, :u]
        bias_zx, bias_zh = bias[:1, :u], bias[1:, :u]

        kernel_r, recurrent_r = kernel[:, u:2*u], recurrent[:, u:2*u]
        bias_rx, bias_rh = bias[:1, u:2*u], bias[1:, u:2*u]

        kernel_h, recurrent_h = kernel[:, 2*u:], recurrent[:, 2*u:]
        bias_hx, bias_hh = bias[:1, 2*u:], bias[1:, 2*u:]

        zt = sigmoid(tf.matmul(inputs, kernel_z)+ bias_zx + 
                     tf.matmul(previous_state, recurrent_z)+ bias_zh)
        rt = sigmoid(tf.matmul(inputs, kernel_r)+ bias_rx+
                     tf.matmul(previous_state, recurrent_r)+ bias_rh)
        hhat = tanh(tf.matmul(inputs, kernel_h)+ bias_hx+ 
                    rt* (tf.matmul(previous_state, recurrent_h)+ bias_hh))

        current_state = (1-zt) * hhat + zt * previous_state
        return current_state, current_state

    def get_config(self):
        config = super(MyGRU, self).get_config()
        config['units'] = self.units
        config['num_layer'] = self.num_layer
        config['add_output'] = self.add_output

        # tf.Variable objects are not JSON-serializable; store their current values instead.
        def values_of(var):
            if self.num_layer > 1:
                return [v.numpy() for v in var]
            return var.numpy()

        config['kernel'] = values_of(self.kernel)
        config['recurrent'] = values_of(self.recurrent)
        config['bias'] = values_of(self.bias)
        return config

In [None]:
import math

In [None]:
class RangeBN(tf.keras.layers.Layer):
    """Range batch normalization for 4-D (B, H, W, C) inputs.

    Follows "Scalable Methods for 8-bit Training of Neural Networks"
    (https://arxiv.org/abs/1805.11046): the standard deviation is replaced by a
    scaled range of the input distribution, which is more tolerant to quantization.

    Formula: x_hat = (x - mu) / (C(n) * range(x) + eps), where range(x) is the mean
    over `num_chunks` chunks of (max - min) per channel and C(n) is `scale_fix` below.
    The result is then multiplied by a learned weight and shifted by a learned bias.

    Note: `moving_variance` stores the running *inverse-range scale*, not a variance.
    Running statistics are updated as `moving = moving * momentum + batch * (1 - momentum)`.

    Args:
        filters: number of channels (last axis).
        momentum: weight of the previous running value in the update above.
        num_chunks: number of chunks the per-channel values are split into.
        eps: added to the denominator for stability.
        name: layer name (now forwarded to the base Layer).
    """
    def __init__(self,
                 filters,
                 momentum=0.1,
                 num_chunks=16,
                 eps=1e-6,
                 name=None,
                 **kwargs):
        # Bug fix: `name` used to be accepted but silently dropped.
        super(RangeBN, self).__init__(name=name, **kwargs)
        self.filters = filters
        self.momentum = momentum
        self.num_chunks = num_chunks
        self.eps = eps

    def build(self, input_shape):
        initializer = tf.keras.initializers.RandomUniform(minval=0., maxval=1)
        input_shape = tf.TensorShape(input_shape)
        self.weight = tf.Variable(initial_value=initializer(
            shape=(input_shape[-1], ), dtype=tf.dtypes.float32),
                                  name='weights',
                                  trainable=True)
        self.bias = self.add_weight(
            shape=(input_shape[-1], ),
            initializer=tf.keras.initializers.Constant(0.),
            name='bias',
            trainable=True)
        self.moving_mean = tf.Variable(initial_value=tf.constant(
            0., shape=(self.filters, )),
                                       shape=self.filters,
                                       dtype=tf.dtypes.float32,
                                       name='moving_mean',
                                       trainable=False)

        self.moving_variance = tf.Variable(initial_value=tf.constant(
            0., shape=(self.filters, )),
                                           shape=(self.filters, ),
                                           dtype=tf.dtypes.float32,
                                           name='moving_variance',
                                           trainable=False)
        super(RangeBN, self).build(input_shape)

    def call(self, inputs, training=False):
        if training:
            B, H, W, C = [tf.shape(inputs)[i] for i in range(4)]
            # (B, H, W, C) -> (C, num_chunks, chunk_size); B*H*W must divide by num_chunks.
            y = tf.transpose(inputs, (3, 0, 1, 2))
            y = tf.reshape(y,
                           [C, self.num_chunks, B * H * W // self.num_chunks])
            mean_max = tf.math.reduce_max(y, [-1])
            mean_max = tf.math.reduce_mean(mean_max, axis=-1)  # C
            mean_min = tf.math.reduce_min(y, [-1])
            mean_min = tf.math.reduce_mean(mean_min, axis=-1)  # C

            mean = tf.math.reduce_mean(tf.reshape(y, (C, -1)), axis=-1)  # C
            # Static shape for the Python-side constant C(n); unknown batch is treated as 1.
            B, H, W, C = inputs.get_shape().as_list()
            if B is None:
                B = 1
            upper = (0.5 * 0.35) * (1. + (math.pi * math.log(4))**0.5)
            lower = ((2. * math.log(B * H * W // self.num_chunks))**0.5)
            scale_fix = upper / lower
            scale = 1 / ((mean_max - mean_min) * scale_fix + self.eps)

            self.moving_mean.assign(self.moving_mean * self.momentum +
                                    (mean * (1 - self.momentum)))
            self.moving_variance.assign(self.moving_variance * self.momentum +
                                        (scale * (1 - self.momentum)))
        else:
            mean = self.moving_mean
            scale = self.moving_variance
        out = (inputs - tf.reshape(mean, (1, 1, 1, self.filters))) * \
            tf.reshape(scale, (1, 1, 1,self.filters))
        if self.weight is not None:
            out = out * tf.reshape(self.weight, (1, 1, 1, self.filters))
        if self.bias is not None:
            out = out + tf.reshape(self.bias, (1, 1, 1, self.filters))
        return out

    def get_config(self):
        """Serialize constructor arguments so the layer can be saved and re-created."""
        config = super(RangeBN, self).get_config()
        config.update({
            'filters': self.filters,
            'momentum': self.momentum,
            'num_chunks': self.num_chunks,
            'eps': self.eps,
        })
        return config