In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
%matplotlib inline

In [2]:
np.random.seed(123)

In [3]:
mnist = keras.datasets.mnist
(X_train_full, y_train_full), (X_test, y_test) = mnist.load_data()

In [21]:
X_valid, X_train = X_train_full[:10000] / 255.0, X_train_full[10000:] / 255.0
y_valid, y_train = y_train_full[:10000], y_train_full[10000:]

In [49]:
class Denoising(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden_bn = keras.layers.BatchNormalization()
        
    def build(self, batch_input_shape):
        print("Denoising.build(): batch_input_shape = ", batch_input_shape)
        z_tilda, _u = batch_input_shape
        units = z_tilda[-1]
        print("Denoising.build(): batch_input_shape = ", batch_input_shape, "units = ", units)
        self.a_1 = self.add_weight(name = "a_1", shape = [units], initializer = "lecun_normal")
        self.a_2 = self.add_weight(name = "a_2", shape = [units], initializer = "lecun_normal")
        self.a_3 = self.add_weight(name = "a_3", shape = [units], initializer = "lecun_normal")
        self.a_4 = self.add_weight(name = "a_4", shape = [units], initializer = "lecun_normal")
        self.a_5 = self.add_weight(name = "a_5", shape = [units], initializer = "lecun_normal")
        self.a_6 = self.add_weight(name = "a_6", shape = [units], initializer = "lecun_normal")
        self.a_7 = self.add_weight(name = "a_7", shape = [units], initializer = "lecun_normal")
        self.a_8 = self.add_weight(name = "a_8", shape = [units], initializer = "lecun_normal")
        self.a_9 = self.add_weight(name = "a_9", shape = [units], initializer = "lecun_normal")
        self.a_10 = self.add_weight(name = "a_10", shape = [units], initializer = "lecun_normal")        
        super().build(batch_input_shape)
        
    def call(self, inputs):
        print("Denoising.call(): inputs = ", inputs, "self.a_1 = ", self.a_1)
        z_tilda, _u = inputs
        u = self.hidden_bn(_u)
        mu = tf.math.multiply(self.a_1, tf.math.multiply(self.a_2, u) + self.a_3) + tf.math.multiply(self.a_4, u) + self.a_5
        v = tf.math.multiply(self.a_6, tf.math.multiply(self.a_7, u) + self.a_8) + tf.math.multiply(self.a_9, u) + self.a_10
        z_hat = tf.math.multiply(z_tilda - mu, v) + mu
        
        return z_hat

In [144]:
class MyClassifier4(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden = [
            keras.layers.Flatten(input_shape = [28, 28]),
            keras.layers.Dense(1000, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(500, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
        ]
        self.bn = [keras.layers.BatchNormalization() for _ in range(len(self.hidden))]
        self.out =  keras.layers.Dense(10, kernel_initializer = "lecun_normal")

        self.hidden_corrupted = [
            keras.layers.Flatten(input_shape = [28, 28]),
            keras.layers.Dense(1000, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(500, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
        ]
        self.bn_corrupted = [keras.layers.BatchNormalization() for _ in range(len(self.hidden))]
        self.out_corrupted =  keras.layers.Dense(10, kernel_initializer = "lecun_normal")
        
        self.noise = keras.layers.GaussianNoise(stddev = 0.3)
        
        self.hidden_decoder = [
            keras.layers.Dense(784, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(1000, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(500, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal"),
            keras.layers.Dense(250, kernel_initializer = "lecun_normal")
        ]
        self.hidden_denoising = [Denoising() for _ in range(len(self.hidden_decoder))]
        self.in_decoder = Denoising()
        
    def build(self, batch_input_shape):
        super().build(batch_input_shape)
        
    def call(self, inputs):
        # Corrupted encoder and classifier
        z_tilda = []
        h_tilda = []    
        for i, layer in enumerate(self.hidden_corrupted):
            if i > 0:
                bn = self.bn_corrupted[i - 1]
                z_tilda.append(self.noise(bn(layer(h_tilda[-1]))))
                h_tilda.append(tf.nn.relu(z_tilda[-1]))
            else:
                z_tilda.append(self.noise(layer(inputs)))
                h_tilda.append(z_tilda[-1])
                
        layer = self.out_corrupted
        bn = self.bn_corrupted[-1]
        z_tilda.append(bn(layer(h_tilda[-1])))
        h_tilda.append(tf.nn.softmax(z_tilda[-1]))
        
        # Clean encoder (for denoising targets)
        z = []
        h = []
        mu = []
        sigma = []
        for i, layer in enumerate(self.hidden):
            if i > 0:
                bn = self.bn[i - 1]
                z.append(bn(layer(h[-1])))
                h.append(tf.nn.relu(z[-1]))
                mu.append(bn.weights[2])
                sigma.append(bn.weights[3])
            else:
                z.append(layer(inputs))
                h.append(z[-1])
                mu.append(np.zeros(h[-1].shape[-1]))
                sigma.append(np.ones(h[-1].shape[-1]))
        
        layer = self.out
        bn = self.bn[-1]
        z.append(bn(layer(h[-1])))
        h.append(tf.nn.softmax(z[-1]))
        mu.append(bn.weights[2])
        sigma.append(bn.weights[3])        

        # Decoder and denoising
        z_hat = [None for _ in range(len(h))]
        z_hat_BN = [None for _ in range(len(h))]
        
        for l in reversed(range(len(h))):
            if l == len(h) - 1:
                z_hat[l] = self.in_decoder((z_tilda[l], h[l]))           
            else:
                _u = self.hidden_decoder[l](z_hat[l + 1])
                z_hat[l] = self.hidden_denoising[l]((z_tilda[l], _u))
        #    
            z_hat_BN[l] = (z_hat[l] - mu[l]) / sigma[l]
        
       
        # Cost function C for training
        #print("z = ", z, "z_hat = ", z_hat)
        #print("z - z_hat = ", tf.math.subtract(z, z_hat_BN))
        #for i in range(len(z)):
            #print("i = ", i, "z[i] = ", z[i], "z_hat_BN = ", z_hat_BN[i])
            #print("z - z_hat = ", tf.reduce_sum(tf.math.square(z[i] - z_hat_BN[i]))) # tf.math.subtract(z[i], z_hat_BN[i]))
            #fn = lambda a, b: tf.reduce_sum(tf.math.square(a - b))
            #self.add_loss(fn(z[i], z_hat_BN[i]))
            
        #self.add_loss(lambda: 0.1)
        #self.add_loss(tf.reduce_mean(err)) #[1000, 10, 0.1, 0.1, 0.1, 0.1, 0.1]))) #lambda_lはベット定義
        
        
        #return h[-1]
        return [h[-1]] + z + z_hat_BN

In [145]:
def lossfunc_ladder_supervised(y_true, y_pred):
    print("lossfunc_ladder_supervised: y_true = ", y_true, ", y_pred = ", y_pred)
    h, z, z_hat_BN = y_pred
    loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    return error

In [146]:
my4_model = MyClassifier4()
#my4_model.compile(loss = lossfunc_ladder_supervised, optimizer = "Adam", metrics = ["accuracy"])
my4_model.compile(loss = "sparse_categorical_crossentropy", optimizer = "Adam", metrics = ["accuracy"])
history = my4_model.fit(X_train, y_train, epochs = 1, validation_data = (X_valid, y_valid))

Denoising.build(): batch_input_shape =  (TensorShape([None, 10]), TensorShape([None, 10]))
Denoising.build(): batch_input_shape =  (TensorShape([None, 10]), TensorShape([None, 10])) units =  10
Denoising.call(): inputs =  (<tf.Tensor 'batch_normalization_1035/batchnorm/add_1:0' shape=(None, 10) dtype=float32>, <tf.Tensor 'Softmax_1:0' shape=(None, 10) dtype=float32>) self.a_1 =  <tf.Variable 'denoising_377/a_1:0' shape=(10,) dtype=float32>
Denoising.build(): batch_input_shape =  (TensorShape([None, 250]), TensorShape([None, 250]))
Denoising.build(): batch_input_shape =  (TensorShape([None, 250]), TensorShape([None, 250])) units =  250
Denoising.call(): inputs =  (<tf.Tensor 'gaussian_noise_55/add_5:0' shape=(None, 250) dtype=float32>, <tf.Tensor 'dense_1007/BiasAdd:0' shape=(None, 250) dtype=float32>) self.a_1 =  <tf.Variable 'denoising_376/a_1:0' shape=(250,) dtype=float32>
Denoising.build(): batch_input_shape =  (TensorShape([None, 250]), TensorShape([None, 250]))
Denoising.build(): 

Denoising.call(): inputs =  (<tf.Tensor 'my_classifier4_55/batch_normalization_1034/batchnorm/add_1:0' shape=(None, 250) dtype=float32>, <tf.Tensor 'my_classifier4_55/dense_1007/BiasAdd:0' shape=(None, 250) dtype=float32>) self.a_1 =  <tf.Variable 'denoising_376/a_1:0' shape=(250,) dtype=float32>
Denoising.call(): inputs =  (<tf.Tensor 'my_classifier4_55/batch_normalization_1033/batchnorm/add_1:0' shape=(None, 250) dtype=float32>, <tf.Tensor 'my_classifier4_55/dense_1006/BiasAdd:0' shape=(None, 250) dtype=float32>) self.a_1 =  <tf.Variable 'denoising_375/a_1:0' shape=(250,) dtype=float32>
Denoising.call(): inputs =  (<tf.Tensor 'my_classifier4_55/batch_normalization_1032/batchnorm/add_1:0' shape=(None, 250) dtype=float32>, <tf.Tensor 'my_classifier4_55/dense_1005/BiasAdd:0' shape=(None, 250) dtype=float32>) self.a_1 =  <tf.Variable 'denoising_374/a_1:0' shape=(250,) dtype=float32>
Denoising.call(): inputs =  (<tf.Tensor 'my_classifier4_55/batch_normalization_1031/batchnorm/add_1:0' sha