In [34]:
import tensorflow as tf
import numpy as np

In [51]:
class MNIST():
    def __init__(self, input_shape):
        super(MNIST, self).__init__()
        self.feature_extractor = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=32, kernel_size=5,
                                   strides=1, input_shape=input_shape),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
            tf.keras.layers.Conv2D(filters=48, kernel_size=5, strides=1),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
            tf.keras.layers.Flatten()            
        ])
        
        self.label_predictor = tf.keras.models.Sequential([
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        self.domain_predictor = tf.keras.models.Sequential([
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(1),
            tf.keras.layers.Activation('sigmoid')          
        ])
        self.path_1 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.label_predictor
        ])
        self.path_2 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.label_predictor
        ])
        
        
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_2 = tf.keras.losses.SparseCategoricalCrossentropy()
        
        self.optimizer = tf.keras.optimizers.Adam()
        
        self.train_loss = tf.keras.metrics.Mean()
        self.train_loss_2 = tf.keras.metrics.Mean()
        
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.train_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        
        
        self.test_loss = tf.keras.metrics.Mean()
        self.test_loss_2 = tf.keras.metrics.Mean()
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()

    @tf.function
    def train(self, x_train, y_train):
        with tf.GradientTape() as tape:
            y_pred = self.path_1(x_train)
            loss = self.loss(y_train, y_pred)
        gradients = tape.gradient(loss, self.path_1.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.path_1.trainable_variables))

        self.train_loss(loss)
        self.train_accuracy(y_train, y_pred)
        
        return
    
    @tf.function
    def train_2(self, x_train, y_train):
        with tf.GradientTape() as tape:
            y_pred = self.path_2(x_train)
            loss = self.loss_2(y_train, y_pred)
        gradients = tape.gradient(loss, self.path_2.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.path_2.trainable_variables))

        self.train_loss_2(loss)
        self.train_accuracy_2(y_train, y_pred)
        
        return
    
    @tf.function
    def test(self, x_test, y_test):
        y_pred = self.path_1(x_test)
        loss = self.loss(y_test, y_pred)

        self.test_loss(loss)
        self.test_accuracy(y_test, y_pred)
        
    @tf.function
    def test_2(self, x_test, y_test):
        y_pred = self.path_2(x_test)
        loss = self.loss_2(y_test, y_pred)

        self.test_loss_2(loss)
        self.test_accuracy_2(y_test, y_pred)

In [3]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

In [4]:
x_train = x_train - 0.5
x_test = x_test - 0.5

In [5]:
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
y_train = tf.cast(y_train, tf.float32)
y_test = tf.cast(y_test, tf.float32)

In [32]:
x_train_2 = tf.concat([x_train, x_train + tf.random.normal(x_train.shape, mean=0.0, stddev=.1)], axis=0)
x_test_2 = tf.concat([x_test, x_test + tf.random.normal(x_test.shape, mean=0.0, stddev=.1)], axis=0)

In [40]:
y_train_2 = tf.concat([np.zeros(len(x_train)), np.ones(len(x_train))], axis=0)
y_test_2 = tf.concat([np.zeros(len(x_test)), np.ones(len(x_test))], axis=0)

In [9]:
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(len(x_train)).batch(1000)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(1000)

In [43]:
train_ds_2 = tf.data.Dataset.from_tensor_slices(
    (x_train_2, y_train_2)).shuffle(len(x_train_2)).batch(2000)
test_ds_2 = tf.data.Dataset.from_tensor_slices((x_test_2, y_test_2)).batch(2000)

In [52]:
model = MNIST(input_shape=(28, 28, 1))

In [10]:
EPOCHS = 5

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        model.train(images, labels)

    for test_images, test_labels in test_ds:
        model.test(test_images, test_labels)

    template = '에포크: {}, 손실: {}, 정확도: {}, 테스트 손실: {}, 테스트 정확도: {}'
    print(template.format(epoch+1,
                         model.train_loss.result(),
                         model.train_accuracy.result()*100,
                         model.test_loss.result(),
                         model.test_accuracy.result()*100))

에포크: 1, 손실: 0.04913996532559395, 정확도: 98.6513900756836, 테스트 손실: 0.03934966400265694, 테스트 정확도: 98.88333129882812
에포크: 2, 손실: 0.048862822353839874, 정확도: 98.82357025146484, 테스트 손실: 0.03925415873527527, 테스트 정확도: 98.9385757446289
에포크: 3, 손실: 0.04858255386352539, 정확도: 98.95770263671875, 테스트 손실: 0.039157550781965256, 테스트 정확도: 98.97875213623047
에포크: 4, 손실: 0.04830179363489151, 정확도: 99.06388854980469, 테스트 손실: 0.03906082361936569, 테스트 정확도: 99.0111083984375
에포크: 5, 손실: 0.048021625727415085, 정확도: 99.15050506591797, 테스트 손실: 0.038964904844760895, 테스트 정확도: 99.03800201416016


In [75]:
EPOCHS = 100

for epoch in range(EPOCHS):
    for (images, labels), (images_2, labels_2) in zip(train_ds, train_ds_2):
        model.train(images, labels)
        model.train_2(images_2, labels_2)

    for (test_images, test_labels), (test_images_2, test_labels_2) in zip(test_ds, test_ds_2):
        model.test(test_images, test_labels)
        model.test_2(test_images_2, test_labels_2)

    template = 'Epoch: {}\n' + \
    'L1: {:.4f}, Acc1: {:.2f}, L1 Test: {:.4f}, Acc1 Test: {:.2f}\n'+ \
    'L2: {:.4f}, Acc2: {:.2f}, L2 Test: {:.4f}, Acc2 Test: {:.2f}\n'
    
    print(template.format(epoch+1,
                         model.train_loss.result(),
                         model.train_accuracy.result()*100,
                         model.test_loss.result(),
                         model.test_accuracy.result()*100,
                         model.train_loss_2.result(),
                         model.train_accuracy_2.result()*100,
                         model.test_loss_2.result(),
                         model.test_accuracy_2.result()*100))

Epoch: 1
L1: 0.5665, Acc1: 70.70, L1 Test: 0.4853, Acc1 Test: 91.05
L2: 0.5107, Acc2: 56.74, L2 Test: 0.5309, Acc2 Test: 58.03
Epoch: 2
L1: 0.5619, Acc1: 71.26, L1 Test: 0.4815, Acc1 Test: 91.31
L2: 0.5104, Acc2: 56.69, L2 Test: 0.5292, Acc2 Test: 57.96
Epoch: 3
L1: 0.5574, Acc1: 71.80, L1 Test: 0.4780, Acc1 Test: 91.55
L2: 0.5100, Acc2: 56.65, L2 Test: 0.5278, Acc2 Test: 57.89
Epoch: 4
L1: 0.5670, Acc1: 72.21, L1 Test: 0.5172, Acc1 Test: 89.44
L2: 0.5278, Acc2: 56.32, L2 Test: 0.5650, Acc2 Test: 56.68
Epoch: 5
L1: 0.5745, Acc1: 72.36, L1 Test: 0.5260, Acc1 Test: 89.40
L2: 0.6057, Acc2: 52.62, L2 Test: 0.5958, Acc2 Test: 54.93
Epoch: 6
L1: 0.5785, Acc1: 72.70, L1 Test: 0.5329, Acc1 Test: 89.35
L2: 0.6638, Acc2: 49.42, L2 Test: 0.6159, Acc2 Test: 53.88
Epoch: 7
L1: 0.5783, Acc1: 72.97, L1 Test: 0.5309, Acc1 Test: 89.44
L2: 0.6686, Acc2: 49.70, L2 Test: 0.6156, Acc2 Test: 54.01
Epoch: 8
L1: 0.5753, Acc1: 73.34, L1 Test: 0.5270, Acc1 Test: 89.62
L2: 0.6586, Acc2: 50.33, L2 Test: 0.6120, A

KeyboardInterrupt: 

In [None]:
class SVHN():

In [None]:
class GTSRB():