In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
@tf.custom_gradient
def GradientReversalOperator(x):
    def grad(dy):
        return -1 * dy
    return x, grad

class GradientReversalLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(GradientReversalLayer, self).__init__()
        
    def call(self, inputs):
        return GradientReversalOperator(inputs)

In [3]:
class MNIST():
    def __init__(self, input_shape):
        super(MNIST, self).__init__()
        self.feature_extractor = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=32, kernel_size=5,
                                   strides=1, input_shape=input_shape),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
            tf.keras.layers.Conv2D(filters=48, kernel_size=5, strides=1),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
            tf.keras.layers.Flatten()            
        ])
        
        self.label_predictor = tf.keras.models.Sequential([
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        
        self.domain_predictor = tf.keras.models.Sequential([
            GradientReversalLayer(),
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(2),
            tf.keras.layers.Activation('sigmoid')          
        ])
        self.path_1 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.label_predictor
        ])
        self.path_2 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.domain_predictor
        ])
        
        
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_2 = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_3 = tf.keras.losses.SparseCategoricalCrossentropy()
        
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.optimizer_2 = tf.keras.optimizers.Adam(learning_rate=0.0001)
        
        self.train_loss = tf.keras.metrics.Mean()
        self.train_loss_2 = tf.keras.metrics.Mean()
        
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.train_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        
        
        self.test_loss = tf.keras.metrics.Mean()
        self.test_loss_2 = tf.keras.metrics.Mean()
        self.test_loss_3 = tf.keras.metrics.Mean()
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_3 = tf.keras.metrics.SparseCategoricalAccuracy()
    
    @tf.function
    def train_both(self, x_class, y_class, x_domain):
        
        domain_labels = np.concatenate([np.zeros(len(x_class)), np.ones(len(x_domain))])
        
        x_both = tf.concat([x_class, x_domain], axis = 0)
        
        with tf.GradientTape() as tape:
            y_class_pred = self.path_1(x_class)
            loss_1 = self.loss(y_class, y_class_pred)   
        grad_1 = tape.gradient(loss_1, self.path_1.trainable_variables)
        
        with tf.GradientTape() as tape:
            y_domain_pred = self.path_2(x_both)
            loss_2 = self.loss_2(domain_labels, y_domain_pred) 
        grad_2 = tape.gradient(loss_2, self.path_2.trainable_variables)
        
        self.optimizer.apply_gradients(zip(grad_1, self.path_1.trainable_variables))
        self.optimizer_2.apply_gradients(zip(grad_2, self.path_2.trainable_variables))
        self.train_loss(loss_1)
        self.train_accuracy(y_class, y_class_pred)
        
        self.train_loss_2(loss_2)
        self.train_accuracy_2(domain_labels, y_domain_pred)
        
        return
    
    @tf.function
    def test_both(self, x_class, y_class, x_domain, y_domain):
        
        domain_labels = np.concatenate([np.zeros(len(x_class)), np.ones(len(x_domain))])
        
        x_both = tf.concat([x_class, x_domain], axis = 0)
        
        with tf.GradientTape() as tape:
            y_class_pred = self.path_1(x_class)
            y_domain_pred = self.path_2(x_both)
            y_target_class_pred = self.path_1(x_domain)
            
            loss_1 = self.loss(y_class, y_class_pred)
            loss_2 = self.loss_2(domain_labels, y_domain_pred)
            loss_3 = self.loss_3(y_domain, y_target_class_pred)
            
        self.test_loss(loss_1)
        self.test_accuracy(y_class, y_class_pred)
        
        self.test_loss_2(loss_2)
        self.test_accuracy_2(domain_labels, y_domain_pred)
        
        self.test_loss_3(loss_3)
        self.test_accuracy_3(y_domain, y_target_class_pred)
        
        return


In [None]:
class SVHN():
    def __init__(self, input_shape):
        super(SVHN, self).__init__()
        self.feature_extractor = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=64, kernel_size=5,
                                   strides=1, input_shape=input_shape),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=3, strides=2),
            tf.keras.layers.Conv2D(filters=64, kernel_size=5, strides=1),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=3, strides=2),
            tf.keras.layers.Conv2D(filters=128, kernel_size=5, strides=1),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Flatten()            
        ])
        
        self.label_predictor = tf.keras.models.Sequential([
            tf.keras.layers.Dense(3072),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(2048),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        
        self.domain_predictor = tf.keras.models.Sequential([
            GradientReversalLayer(),
            tf.keras.layers.Dense(1024),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(1024),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(2),
            tf.keras.layers.Activation('sigmoid')          
        ])
        
        self.path_1 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.label_predictor
        ])
        
        self.path_2 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.domain_predictor
        ])
        
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_2 = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_3 = tf.keras.losses.SparseCategoricalCrossentropy()
        
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.optimizer_2 = tf.keras.optimizers.Adam(learning_rate=0.0001)
        
        self.train_loss = tf.keras.metrics.Mean()
        self.train_loss_2 = tf.keras.metrics.Mean()
        
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.train_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        
        
        self.test_loss = tf.keras.metrics.Mean()
        self.test_loss_2 = tf.keras.metrics.Mean()
        self.test_loss_3 = tf.keras.metrics.Mean()
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_3 = tf.keras.metrics.SparseCategoricalAccuracy()
    
    @tf.function
    def train_both(self, x_class, y_class, x_domain):
        
        domain_labels = np.concatenate([np.zeros(len(x_class)), np.ones(len(x_domain))])
        
        x_both = tf.concat([x_class, x_domain], axis = 0)
        
        with tf.GradientTape() as tape:
            y_class_pred = self.path_1(x_class)
            loss_1 = self.loss(y_class, y_class_pred)   
        grad_1 = tape.gradient(loss_1, self.path_1.trainable_variables)
        
        with tf.GradientTape() as tape:
            y_domain_pred = self.path_2(x_both)
            loss_2 = self.loss_2(domain_labels, y_domain_pred) 
        grad_2 = tape.gradient(loss_2, self.path_2.trainable_variables)
        
        self.optimizer.apply_gradients(zip(grad_1, self.path_1.trainable_variables))
        self.optimizer_2.apply_gradients(zip(grad_2, self.path_2.trainable_variables))
        self.train_loss(loss_1)
        self.train_accuracy(y_class, y_class_pred)
        
        self.train_loss_2(loss_2)
        self.train_accuracy_2(domain_labels, y_domain_pred)
        
        return
    
    @tf.function
    def test_both(self, x_class, y_class, x_domain, y_domain):
        
        domain_labels = np.concatenate([np.zeros(len(x_class)), np.ones(len(x_domain))])
        
        x_both = tf.concat([x_class, x_domain], axis = 0)
        
        with tf.GradientTape() as tape:
            y_class_pred = self.path_1(x_class)
            y_domain_pred = self.path_2(x_both)
            y_target_class_pred = self.path_1(x_domain)
            
            loss_1 = self.loss(y_class, y_class_pred)
            loss_2 = self.loss_2(domain_labels, y_domain_pred)
            loss_3 = self.loss_3(y_domain, y_target_class_pred)
            
        self.test_loss(loss_1)
        self.test_accuracy(y_class, y_class_pred)
        
        self.test_loss_2(loss_2)
        self.test_accuracy_2(domain_labels, y_domain_pred)
        
        self.test_loss_3(loss_3)
        self.test_accuracy_3(y_domain, y_target_class_pred)
        
        return


In [4]:
x_train_mnist = np.load('../data/mnist/x_train.npy')
y_train_mnist = np.load('../data/mnist/y_train.npy')

x_test_mnist = np.load('../data/mnist/x_test.npy')
y_test_mnist = np.load('../data/mnist/y_test.npy')

In [5]:
x_train_svhn = np.load('../data/svhn/x_train.npy')
y_train_svhn = np.load('../data/svhn/y_train.npy')

x_test_svhn = np.load('../data/svhn/x_test.npy')
y_test_svhn = np.load('../data/svhn/y_test.npy')

In [6]:
x_train_mnist, x_test_mnist = x_train_mnist / 255.0, x_test_mnist / 255.0
x_train_svhn, x_test_svhn = x_train_svhn / 255.0, x_test_svhn / 255.0

In [7]:
x_train_mnist = tf.cast(x_train_mnist, tf.float32)
x_test_mnist = tf.cast(x_test_mnist, tf.float32)
x_train_svhn = tf.cast(x_train_svhn, tf.float32)
x_test_svhn = tf.cast(x_test_svhn, tf.float32)

In [8]:
def pad_image(x, y):
    
    paddings = tf.constant([[2, 2,], [2, 2]])
    
    new_x = tf.pad(x, paddings, "CONSTANT")
    
    return (new_x, y)

def duplicate_channel(x, y):

    new_x = tf.stack([x, x, x], axis = -1)
    
    return (new_x, y)

In [9]:
mnist_train_dataset = tf.data.Dataset.from_tensor_slices((x_train_mnist, y_train_mnist))
mnist_train_dataset = mnist_train_dataset.map(pad_image)
mnist_train_dataset = mnist_train_dataset.map(duplicate_channel)
target_train_dataset = mnist_train_dataset.shuffle(len(y_train_mnist))

svhn_train_dataset = tf.data.Dataset.from_tensor_slices((x_train_svhn, y_train_svhn))
source_train_dataset = svhn_train_dataset.shuffle(len(y_train_svhn))

source_train_dataset = source_train_dataset.batch(2000, drop_remainder=True)
source_train_dataset = source_train_dataset.prefetch(50)

target_train_dataset = target_train_dataset.batch(2000, drop_remainder=True)
target_train_dataset = target_train_dataset.prefetch(50)

In [10]:
model = MNIST(input_shape=(32, 32, 3))

In [11]:
EPOCHS = 30

for epoch in range(EPOCHS):
    i = 0
    for (source_images, class_labels), (target_images, _) in zip(source_train_dataset, target_train_dataset):
        model.train_both(source_images, class_labels, target_images)
        print(i)
        i =

    for (test_images, test_labels), (target_images, target_labels) in zip(source_train_dataset, target_train_dataset):
        model.test_both(test_images, test_labels, target_images, target_labels)

    template = 'Epoch: {}\n' + \
    'L1: {:.4f}, Acc1: {:.2f}, L1 Test: {:.4f}, Acc1 Test: {:.2f}\n'+ \
    'L2: {:.4f}, Acc2: {:.2f}, L2 Test: {:.4f}, Acc2 Test: {:.2f}\n'+ \
    'L3 Test: {:.4f}, Acc3 Test: {:.2f}\n'
    
    
    print(template.format(epoch+1,
                         model.train_loss.result(),
                         model.train_accuracy.result()*100,
                         model.test_loss.result(),
                         model.test_accuracy.result()*100,
                         model.train_loss_2.result(),
                         model.train_accuracy_2.result()*100,
                         model.test_loss_2.result(),
                         model.test_accuracy_2.result()*100,
                         model.test_loss_3.result(),
                         model.test_accuracy_3.result()*100))

W0627 18:02:12.008243 21964 deprecation.py:323] From c:\users\jw\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\math_grad.py:1220: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch: 1
L1: 2.2375, Acc1: 18.72, L1 Test: 2.2102, Acc1 Test: 19.45
L2: 0.6681, Acc2: 65.06, L2 Test: 0.6302, Acc2 Test: 67.12
L3 Test: 2.3877, Acc3 Test: 12.26

Epoch: 2
L1: 2.1624, Acc1: 22.43, L1 Test: 2.0289, Acc1 Test: 31.21
L2: 0.5962, Acc2: 71.33, L2 Test: 0.4993, Acc2 Test: 78.63
L3 Test: 2.3771, Acc3 Test: 18.23

Epoch: 3
L1: 1.9542, Acc1: 32.01, L1 Test: 1.7546, Acc1 Test: 42.22
L2: 0.4778, Acc2: 79.50, L2 Test: 0.3797, Acc2 Test: 85.15
L3 Test: 2.7375, Acc3 Test: 22.68

Epoch: 4
L1: 1.7208, Acc1: 41.30, L1 Test: 1.5332, Acc1 Test: 50.45
L2: 0.3842, Acc2: 84.31, L2 Test: 0.3077, Acc2 Test: 88.54
L3 Test: 3.1489, Acc3 Test: 26.85

Epoch: 5
L1: 1.5345, Acc1: 48.49, L1 Test: 1.3685, Acc1 Test: 56.33
L2: 0.3217, Acc2: 87.26, L2 Test: 0.2588, Acc2 Test: 90.67
L3 Test: 3.4101, Acc3 Test: 30.64

Epoch: 6
L1: 1.3904, Acc1: 53.91, L1 Test: 1.2443, Acc1 Test: 60.72
L2: 0.2768, Acc2: 89.27, L2 Test: 0.2239, Acc2 Test: 92.12
L3 Test: 3.5665, Acc3 Test: 33.66

Epoch: 7
L1: 1.2774, Acc1: 5

In [12]:
for epoch in range(EPOCHS):
    for (source_images, class_labels), (target_images, _) in zip(source_train_dataset, target_train_dataset):
        model.train_both(source_images, class_labels, target_images)

    for (test_images, test_labels), (target_images, target_labels) in zip(source_train_dataset, target_train_dataset):
        model.test_both(test_images, test_labels, target_images, target_labels)

    template = 'Epoch: {}\n' + \
    'L1: {:.4f}, Acc1: {:.2f}, L1 Test: {:.4f}, Acc1 Test: {:.2f}\n'+ \
    'L2: {:.4f}, Acc2: {:.2f}, L2 Test: {:.4f}, Acc2 Test: {:.2f}\n'+ \
    'L3 Test: {:.4f}, Acc3 Test: {:.2f}\n'
    
    
    print(template.format(epoch+1,
                         model.train_loss.result(),
                         model.train_accuracy.result()*100,
                         model.test_loss.result(),
                         model.test_accuracy.result()*100,
                         model.train_loss_2.result(),
                         model.train_accuracy_2.result()*100,
                         model.test_loss_2.result(),
                         model.test_accuracy_2.result()*100,
                         model.test_loss_3.result(),
                         model.test_accuracy_3.result()*100))

Epoch: 1
L1: 0.5726, Acc1: 82.48, L1 Test: 0.5353, Acc1 Test: 84.01
L2: 0.0991, Acc2: 96.70, L2 Test: 0.0896, Acc2 Test: 97.21
L3 Test: 3.2361, Acc3 Test: 49.47

Epoch: 2
L1: 0.5629, Acc1: 82.79, L1 Test: 0.5272, Acc1 Test: 84.25
L2: 0.0968, Acc2: 96.78, L2 Test: 0.0874, Acc2 Test: 97.28
L3 Test: 3.2425, Acc3 Test: 49.49

Epoch: 3
L1: 0.5537, Acc1: 83.08, L1 Test: 0.5192, Acc1 Test: 84.49
L2: 0.0945, Acc2: 96.87, L2 Test: 0.0853, Acc2 Test: 97.35
L3 Test: 3.2532, Acc3 Test: 49.52

Epoch: 4
L1: 0.5449, Acc1: 83.36, L1 Test: 0.5113, Acc1 Test: 84.73
L2: 0.0922, Acc2: 96.95, L2 Test: 0.0834, Acc2 Test: 97.41
L3 Test: 3.2750, Acc3 Test: 49.51

Epoch: 5
L1: 0.5366, Acc1: 83.62, L1 Test: 0.5036, Acc1 Test: 84.97
L2: 0.0901, Acc2: 97.02, L2 Test: 0.0816, Acc2 Test: 97.47
L3 Test: 3.2976, Acc3 Test: 49.52

Epoch: 6
L1: 0.5286, Acc1: 83.87, L1 Test: 0.4961, Acc1 Test: 85.20
L2: 0.0880, Acc2: 97.09, L2 Test: 0.0798, Acc2 Test: 97.53
L3 Test: 3.3233, Acc3 Test: 49.49

Epoch: 7
L1: 0.5207, Acc1: 8

KeyboardInterrupt: 