In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [19]:
@tf.custom_gradient
def GradientReversalOperator(x):
    def grad(dy):
        return -1 * dy
    return x, grad

class GradientReversalLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(GradientReversalLayer, self).__init__()
        
    def call(self, inputs):
        return GradientReversalOperator(inputs)

In [2]:
class MNIST():
    def __init__(self, input_shape):
        super(MNIST, self).__init__()
        self.feature_extractor = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=32, kernel_size=5,
                                   strides=1, input_shape=input_shape),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
            tf.keras.layers.Conv2D(filters=48, kernel_size=5, strides=1),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
            tf.keras.layers.Flatten()            
        ])
        
        self.label_predictor = tf.keras.models.Sequential([
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        self.domain_predictor = tf.keras.models.Sequential([
            GradientReversalLayerientTape(),
            tf.keras.layers.Dense(100),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(1),
            tf.keras.layers.Activation('sigmoid')          
        ])
        self.path_1 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.label_predictor
        ])
        self.path_2 = tf.keras.models.Sequential([
            self.feature_extractor,
            self.label_predictor
        ])
        
        
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_2 = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_3 = tf.keras.losses.SparseCategoricalCrossentropy()
        
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.optimizer_2 = tf.keras.optimizers.Adam(learning_rate=0.001)
        
        self.train_loss = tf.keras.metrics.Mean()
        self.train_loss_2 = tf.keras.metrics.Mean()
        
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.train_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        
        
        self.test_loss = tf.keras.metrics.Mean()
        self.test_loss_2 = tf.keras.metrics.Mean()
        self.test_loss_3 = tf.keras.metrics.Mean()
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_2 = tf.keras.metrics.SparseCategoricalAccuracy()
        self.test_accuracy_3 = tf.keras.metrics.SparseCategoricalAccuracy()

    @tf.function
    def train(self, x_train, y_train):
        with tf.GradientTape() as tape:
            y_pred = self.path_1(x_train)
            loss = self.loss(y_train, y_pred)
        gradients = tape.gradient(loss, self.path_1.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.path_1.trainable_variables))

        self.train_loss(loss)
        self.train_accuracy(y_train, y_pred)
        
        return
    
    @tf.function
    def train_2(self, x_train, y_train):
        with tf.GradientTape() as tape:
            y_pred = self.path_2(x_train)
            loss = self.loss_2(y_train, y_pred)
        gradients = tape.gradient(loss, self.path_2.trainable_variables)
        self.optimizer_2.apply_gradients(zip(gradients, self.path_2.trainable_variables))

        self.train_loss_2(loss)
        self.train_accuracy_2(y_train, y_pred)
        
        return
    
    @tf.function
    def train_both(self, x_class, y_class, x_domain):
        
        domain_labels = np.concatenate([np.zeros(len(x_class)), np.ones(len(x_domain))])
        
        x_both = tf.concat([x_class, x_domain], axis = 0)
        
        with tf.GradientTape() as tape:
            y_class_pred = self.path_1(x_class)
            loss_1 = self.loss(y_class, y_class_pred)   
        grad_1 = tape.gradient(loss_1, self.path_1.trainable_variables)
        
        with tf.GradientTape() as tape:
            y_domain_pred = self.path_2(x_both)
            loss_2 = self.loss_2(domain_labels, y_domain_pred) 
        grad_2 = tape.gradient(loss_2, self.path_2.trainable_variables)
        
        self.optimizer.apply_gradients(zip(grad_1, self.path_1.trainable_variables))
        self.optimizer_2.apply_gradients(zip(grad_2, self.path_2.trainable_variables))
        self.train_loss(loss_1)
        self.train_accuracy(y_class, y_class_pred)
        
        self.train_loss_2(loss_2)
        self.train_accuracy_2(domain_labels, y_domain_pred)
        
        return
    
    @tf.function
    def test_both(self, x_class, y_class, x_domain, y_domain):
        
        domain_labels = np.concatenate([np.zeros(len(x_class)), np.ones(len(x_domain))])
        
        x_both = tf.concat([x_class, x_domain], axis = 0)
        
        with tf.GradientTape() as tape:
            y_class_pred = self.path_1(x_class)
            y_domain_pred = self.path_2(x_both)
            y_target_class_pred = self.path_1(x_domain)
            
            loss_1 = self.loss(y_class, y_class_pred)
            loss_2 = self.loss_2(domain_labels, y_domain_pred)
            loss_3 = self.loss_3(y_domain, y_target_class_pred)
            
        self.test_loss(loss_1)
        self.test_accuracy(y_class, y_class_pred)
        
        self.test_loss_2(loss_2)
        self.test_accuracy_2(domain_labels, y_domain_pred)
        
        self.test_loss_3(loss_3)
        self.test_accuracy_3(y_domain, y_target_class_pred)
        
        return

    
    @tf.function
    def test(self, x_test, y_test):
        y_pred = self.path_1(x_test)
        loss = self.loss(y_test, y_pred)

        self.test_loss(loss)
        self.test_accuracy(y_test, y_pred)
        
    @tf.function
    def test_2(self, x_test, y_test):
        y_pred = self.path_2(x_test)
        loss = self.loss_2(y_test, y_pred)

        self.test_loss_2(loss)
        self.test_accuracy_2(y_test, y_pred)

In [3]:
x_train_mnist = np.load('../data/mnist/x_train.npy')
y_train_mnist = np.load('../data/mnist/y_train.npy')

x_test_mnist = np.load('../data/mnist/x_test.npy')
y_test_mnist = np.load('../data/mnist/y_test.npy')

In [4]:
x_train_svhn = np.load('../data/svhn/x_train.npy')
y_train_svhn = np.load('../data/svhn/y_train.npy')

x_test_svhn = np.load('../data/svhn/x_test.npy')
y_test_svhn = np.load('../data/svhn/y_test.npy')

In [5]:
x_train_mnist, x_test_mnist = x_train_mnist / 255.0, x_test_mnist / 255.0
x_train_svhn, x_test_svhn = x_train_svhn / 255.0, x_test_svhn / 255.0

In [6]:
x_train_mnist = tf.cast(x_train_mnist, tf.float32)
x_test_mnist = tf.cast(x_test_mnist, tf.float32)
x_train_svhn = tf.cast(x_train_svhn, tf.float32)
x_test_svhn = tf.cast(x_test_svhn, tf.float32)

In [7]:
def pad_image(x, y):
    
    paddings = tf.constant([[2, 2,], [2, 2]])
    
    new_x = tf.pad(x, paddings, "CONSTANT")
    
    return (new_x, y)

def duplicate_channel(x, y):

    new_x = tf.stack([x, x, x], axis = -1)
    
    return (new_x, y)

In [8]:
mnist_train_dataset = tf.data.Dataset.from_tensor_slices((x_train_mnist, y_train_mnist))
mnist_train_dataset = mnist_train_dataset.map(pad_image)
source_train_dataset = mnist_train_dataset.map(duplicate_channel)

source_train_dataset = source_train_dataset.shuffle(len(y_train_mnist))
source_train_dataset = source_train_dataset.batch(1000, drop_remainder=True)
source_train_dataset = source_train_dataset.prefetch(5)


svhn_train_dataset = tf.data.Dataset.from_tensor_slices((x_train_svhn, y_train_svhn))

target_train_dataset = svhn_train_dataset.shuffle(len(y_train_svhn))
target_train_dataset = target_train_dataset.batch(1000, drop_remainder=True)
target_train_dataset = target_train_dataset.prefetch(5)

In [9]:
model = MNIST(input_shape=(32, 32, 3))

In [10]:
EPOCHS = 10

for epoch in range(EPOCHS):
    for (source_images, class_labels), (target_images, _) in zip(source_train_dataset, target_train_dataset):
        model.train_both(source_images, class_labels, target_images)

    for (test_images, test_labels), (target_images, target_labels) in zip(source_train_dataset, target_train_dataset):
        model.test_both(test_images, test_labels, target_images, target_labels)

    template = 'Epoch: {}\n' + \
    'L1: {:.4f}, Acc1: {:.2f}, L1 Test: {:.4f}, Acc1 Test: {:.2f}\n'+ \
    'L2: {:.4f}, Acc2: {:.2f}, L2 Test: {:.4f}, Acc2 Test: {:.2f}\n'+ \
    'L3 Test: {:.4f}, Acc3 Test: {:.2f}\n'
    
    
    print(template.format(epoch+1,
                         model.train_loss.result(),
                         model.train_accuracy.result()*100,
                         model.test_loss.result(),
                         model.test_accuracy.result()*100,
                         model.train_loss_2.result(),
                         model.train_accuracy_2.result()*100,
                         model.test_loss_2.result(),
                         model.test_accuracy_2.result()*100,
                         model.test_loss_3.result(),
                         model.test_accuracy_3.result()*100))

Epoch: 1
L1: 59.8112, Acc1: 9.81, L1 Test: 87.1789, Acc1 Test: 9.87
L2: 0.1386, Acc2: 96.35, L2 Test: 0.0001, Acc2 Test: 100.00
L3 Test: 95.4398, Acc3 Test: 18.91

Epoch: 2
L1: 73.4325, Acc1: 9.83, L1 Test: 88.1163, Acc1 Test: 9.87
L2: 0.0696, Acc2: 98.16, L2 Test: 0.0003, Acc2 Test: 99.99
L3 Test: 91.5169, Acc3 Test: 18.90

Epoch: 3
L1: 76.9243, Acc1: 9.85, L1 Test: 86.3621, Acc1 Test: 9.87
L2: 0.0464, Acc2: 98.77, L2 Test: 0.0002, Acc2 Test: 99.99
L3 Test: 90.6967, Acc3 Test: 18.90

Epoch: 4
L1: 78.4236, Acc1: 9.85, L1 Test: 85.0088, Acc1 Test: 9.87
L2: 0.0349, Acc2: 99.08, L2 Test: 0.0002, Acc2 Test: 99.99
L3 Test: 90.3088, Acc3 Test: 18.90

Epoch: 5
L1: 79.1191, Acc1: 9.86, L1 Test: 84.1450, Acc1 Test: 9.87
L2: 0.0279, Acc2: 99.26, L2 Test: 0.0001, Acc2 Test: 100.00
L3 Test: 89.9081, Acc3 Test: 18.91

Epoch: 6
L1: 79.4448, Acc1: 9.86, L1 Test: 83.3351, Acc1 Test: 9.87
L2: 0.0233, Acc2: 99.39, L2 Test: 0.0001, Acc2 Test: 100.00
L3 Test: 89.6105, Acc3 Test: 18.91

Epoch: 7
L1: 79.560