In [1]:
import os
import sys
import h5py
from tqdm.notebook import tqdm

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
from tensorflow.keras.losses import CategoricalCrossentropy as CCELoss
from tensorflow.keras.losses import BinaryCrossentropy as BCELoss
from tensorflow.keras.optimizers import Adam
from tensorflow.data import Dataset

import numpy as np
import matplotlib.pyplot as plt

# Backbone

There are two feature-extractor versions. One follows the original paper, but I do not see the purpose of applying a Conv1D operation to image data, so I also implemented a very basic multi-layer Conv2D network. The Conv2D version is the active one; the Conv1D version is kept below it, commented out.

In [2]:
# Shared dropout rate for the feature extractor.
dropout = 0.5
# Activations: conv_activation is only used by the commented-out Conv1D extractor.
conv_activation = 'sigmoid'
clf_activation = 'relu'
disc_activation = 'relu'
# The data pipeline repeats MNIST to 3 channels, so model inputs are 28x28x3.
input_shape = (28, 28, 3)
# hidden_len, hidden_depth, kernel_size and num_conv_layers are only used by the
# commented-out Conv1D extractor.
hidden_len = 512
feature_len = 256   # width of the extracted feature vector / classifier hidden layer
disc_len = 1024     # width of each hidden layer of the discriminator
hidden_depth = 10
kernel_size = 3
num_conv_layers = 3
num_classes = 10


class FeatureExtractor(tf.keras.layers.Layer):
    """Small Conv2D feature extractor that maps images to a flat feature vector.

    For each entry of ``filters`` one stage Conv2D(2x2, ReLU) -> Dropout ->
    MaxPool2D(2x2) is stacked, followed by Flatten -> Dense(feature_dim, ReLU)
    -> Dropout.

    Args:
        feature_dim: size of the output feature vector. ``None`` (default) uses
            the module-level ``feature_len``, read at construction time.
        rate: dropout rate. ``None`` (default) uses the module-level ``dropout``,
            read at construction time.
        filters: number of filters of each conv stage.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, feature_dim=None, rate=None, filters=(16, 32, 64), **kwargs):
        super(FeatureExtractor, self).__init__(**kwargs)
        # Resolve defaults lazily so edits to the notebook globals are picked up
        # by newly created extractors, as with the original direct references.
        feature_dim = feature_len if feature_dim is None else feature_dim
        rate = dropout if rate is None else rate

        stack = []
        for n_filters in filters:
            stack += [
                layers.Conv2D(filters=n_filters, kernel_size=(2, 2), activation='relu'),
                layers.Dropout(rate),
                layers.MaxPool2D(pool_size=(2, 2)),
            ]
        stack += [
            layers.Flatten(),
            layers.Dense(feature_dim, activation='relu'),
            layers.Dropout(rate),
        ]
        self.feature_extractor = keras.Sequential(stack)

    def call(self, x, training=None):
        # Forward `training` explicitly so the Dropout layers follow the caller's mode.
        return self.feature_extractor(x, training=training)


# class FeatureExtractor(tf.keras.layers.Layer):
#     def __init__(self):
#         super(FeatureExtractor, self).__init__()
#         feature_extractor = keras.Sequential([])
#         feature_extractor.add(layers.Reshape((-1, input_shape[0]**2, 3)))
#         feature_extractor.add(layers.Dense(hidden_len))

#         conv1d_layer = layers.Conv1D(filters=hidden_depth, kernel_size=kernel_size, activation=conv_activation)
#         dropout_layer = layers.Dropout(dropout)
#         for _ in range(num_conv_layers):
#             feature_extractor.add(conv1d_layer)
#             feature_extractor.add(dropout_layer)

#         feature_extractor.add(layers.Flatten())
#         feature_extractor.add(layers.Dense(feature_len, activation=conv_activation))
#         self.feature_extractor = feature_extractor
    
#     def call(self, x):
#         return self.feature_extractor(x)
    
    
class Classifier(tf.keras.layers.Layer):
    """Two-layer MLP head: Dense(feature_len, clf_activation) -> Dense(num_classes, softmax)."""

    def __init__(self):
        super().__init__()
        hidden = layers.Dense(feature_len, activation=clf_activation)
        probs = layers.Dense(num_classes, activation='softmax')
        self.MLP = keras.Sequential([hidden, probs])

    def call(self, x):
        """Return per-class probabilities for the feature batch ``x``."""
        return self.MLP(x)
    
    
class Discriminator(tf.keras.layers.Layer):
    """Domain discriminator: two hidden Dense(disc_len) layers and a single sigmoid output."""

    def __init__(self):
        super().__init__()
        stack = [layers.Dense(disc_len, activation=disc_activation) for _ in range(2)]
        stack.append(layers.Dense(1, activation='sigmoid'))
        self.MLP = keras.Sequential(stack)

    def call(self, x):
        """Return the sigmoid score for each feature vector in ``x``."""
        return self.MLP(x)
    
    
class Backbone(tf.keras.Model):
    """Full classifier: a FeatureExtractor (``f``) followed by a Classifier head (``clf``)."""

    def __init__(self):
        super().__init__()
        self.f = FeatureExtractor()
        self.clf = Classifier()

    def call(self, x):
        features = self.f(x)
        return self.clf(features)

# Training Algorithm

Training is implemented in two stages, as in the paper. Setting the `reverse_gradients` option uses a gradient-reversal layer as in the paper; otherwise a regular GAN-style adversarial loss is used.

In [3]:
@tf.custom_gradient
def grad_reverse(x):
    """Gradient-reversal op: identity in the forward pass, negated gradient in the backward pass."""
    def custom_grad(upstream):
        return tf.negative(upstream)
    return tf.identity(x), custom_grad

In [1]:
class Trainer(tf.keras.Model):
    """Two-stage domain-adaptation trainer.

    Stage 1 pre-trains ``model_1`` on labelled source data with ``compile``/``fit``.
    Stage 2 trains ``model_2`` and ``discriminator`` with a custom loop that combines
    a source classification loss, an L1 consistency loss against the frozen
    ``model_1`` features, and an adversarial loss on target features.

    Args:
        args: object providing ``logdir``, ``eval_every``, ``epochs_stage_1``,
            ``epochs_stage_2``, ``lr`` and ``reverse_gradients``.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.logdir = args.logdir
        self.writer = tf.summary.create_file_writer(self.logdir)
        self.eval_every = args.eval_every
        self.epochs_stage_1 = args.epochs_stage_1
        self.epochs_stage_2 = args.epochs_stage_2

        # model_1: pre-trained in stage 1 and only used for inference afterwards.
        # model_2: adapted in stage 2.
        # NOTE(review): model_2 starts from fresh weights, not a copy of model_1 —
        # confirm this matches the paper.
        self.model_1 = Backbone()
        self.model_2 = Backbone()
        self.discriminator = Discriminator()

        # Separate Adam instances so each update rule keeps its own moment estimates.
        self.lr = args.lr
        self.optimizer_pre = Adam(self.lr)
        self.optimizer_gen = Adam(self.lr)
        self.optimizer_disc = Adam(self.lr)
        self.optimizer_gen_and_clf = Adam(self.lr)

        self.reverse_gradients = args.reverse_gradients

    def train_stage_1(self, train_dataset, eval_dataset):
        """Pre-train ``model_1`` with categorical cross-entropy on (x, one-hot y) batches."""
        self.model_1.trainable = True
        self.model_2.trainable = False
        self.discriminator.trainable = False

        self.model_1.compile(
            optimizer=self.optimizer_pre,
            loss=CCELoss()
        )
        self.model_1.fit(
            train_dataset,
            validation_data=eval_dataset,
            epochs=self.epochs_stage_1
        )

    def train_stage_2(self, train_dataset, eval_dataset):
        """Adversarially adapt ``model_2`` to the target domain.

        Both datasets yield ``(x_source, y_source, x_target)`` batches; ``y_source``
        is presumably one-hot, since it is scored with CategoricalCrossentropy.
        Each training batch applies three updates, each with its own tape and optimizer:

        * ``model_2`` (features + classifier) on the source classification loss;
        * ``model_2.f`` on the adversarial loss plus the consistency loss;
        * ``discriminator`` on its BCE loss (pretrained source features -> 1,
          target features -> 0).

        Eval batches run the same computation with ``training=False`` and only
        write summaries. Eval runs every ``eval_every`` epochs (never when it is 0).
        """
        # Use `self`, not a notebook-global `model`, so this works for any Trainer instance.
        self.model_1.trainable = False
        self.model_2.trainable = True
        self.discriminator.trainable = True

        def train_one_batch(x_source, y_source, x_target, step, training):
            with tf.GradientTape() as tape_gen, tf.GradientTape() as tape_disc, tf.GradientTape() as tape_gen_and_clf:
                # step a: feature extraction (model_1 always in inference mode)
                f_source_pretrained = self.model_1.f(x_source, training=False)
                f_source_aligned = self.model_2.f(x_source, training=training)
                f_target = self.model_2.f(x_target, training=training)

                # step b: classification loss on labelled source data
                y_pred = self.model_2.clf(f_source_aligned, training=training)
                L_clf = CCELoss()(y_source, y_pred)

                # step c: per-sample L1 consistency between pretrained and adapted source features
                L_c = tf.reduce_mean(tf.abs(f_source_pretrained - f_source_aligned), axis=-1)

                # step d: adversarial alignment. In both modes the generator effectively
                # minimises -L_d_gen + L_c. Without reversal the minus sign is explicit.
                # With reversal, grad_reverse negates the gradient flowing from L_d_gen
                # into model_2.f, so the explicit sign is dropped (loss_sign = -1).
                if self.reverse_gradients:
                    f_target = grad_reverse(f_target)
                    loss_sign = -1
                else:
                    loss_sign = 1

                D_source_pretrained = self.discriminator(f_source_pretrained, training=training)
                D_target = self.discriminator(f_target, training=training)

                L_d_disc = BCELoss()(tf.ones_like(D_source_pretrained), D_source_pretrained) \
                    + BCELoss()(tf.zeros_like(D_target), D_target)
                L_d_gen = BCELoss()(tf.zeros_like(D_target), D_target)

                # step e: per-optimizer losses
                loss_gen_and_clf = tf.reduce_mean(L_clf)
                loss_gen = tf.reduce_mean(-loss_sign * L_d_gen + L_c)
                # The discriminator always minimises its own BCE loss. grad_reverse sits
                # in front of the discriminator and never changes its weight gradients,
                # so multiplying by loss_sign here would make it maximise its loss.
                loss_disc = tf.reduce_mean(L_d_disc)

                with self.writer.as_default(step=step):
                    tag = 'train' if training else 'eval'
                    tf.summary.scalar(f"{tag}/L_clf", tf.reduce_mean(L_clf))
                    tf.summary.scalar(f"{tag}/L_c", tf.reduce_mean(L_c))
                    tf.summary.scalar(f"{tag}/L_d", tf.reduce_mean(L_d_gen))
                    tf.summary.scalar(f"{tag}/p(Y=source | x_source)", tf.reduce_mean(D_source_pretrained))
                    tf.summary.scalar(f"{tag}/p(Y=source | x_target)", tf.reduce_mean(D_target))

            if training:
                weights_gen_and_clf = self.model_2.trainable_variables
                gradients_gen_and_clf = tape_gen_and_clf.gradient(loss_gen_and_clf, weights_gen_and_clf)
                self.optimizer_gen_and_clf.apply_gradients(zip(gradients_gen_and_clf, weights_gen_and_clf))

                weights_gen = self.model_2.f.trainable_variables
                gradients_gen = tape_gen.gradient(loss_gen, weights_gen)
                self.optimizer_gen.apply_gradients(zip(gradients_gen, weights_gen))

                weights_disc = self.discriminator.trainable_variables
                gradients_disc = tape_disc.gradient(loss_disc, weights_disc)
                self.optimizer_disc.apply_gradients(zip(gradients_disc, weights_disc))

        train_batch_count = 0
        eval_batch_count = 0
        for epoch in tqdm(range(self.epochs_stage_2), desc='epochs', leave=True):
            for batch in tqdm(train_dataset, desc='train batches', leave=False):
                x_source, y_source, x_target = batch
                train_one_batch(x_source, y_source, x_target, step=train_batch_count, training=True)
                train_batch_count += 1

            # Guard against eval_every == 0 (modulo by zero). The dataset is iterated
            # directly, not via enumerate, so tqdm can read its length.
            if self.eval_every and epoch % self.eval_every == 0:
                for batch in tqdm(eval_dataset, desc='eval batches', leave=False):
                    x_source, y_source, x_target = batch
                    train_one_batch(x_source, y_source, x_target, step=eval_batch_count, training=False)
                    eval_batch_count += 1


NameError: name 'tf' is not defined

# Data Preparation

The dataset preprocessing is copied from the DANN experiment. To replicate the training, download the dataset from this link: https://github.com/sghoshjr/tf-dann/releases/download/v1.0.0/mnistm.h5

Then set the variable `MNIST_M_PATH` to the location of the downloaded file.

In [5]:
MNIST_M_PATH = 'mnistm.h5'
BATCH_SIZE = 64
CHANNELS = 3
NUM_SAMPLES = 10000
VAL_SET = 0.2


def _to_unit_float32(images):
    """Scale pixel values from [0, 255] to float32 values in [0, 1]."""
    return (images / 255.0).astype('float32')


def prepare_data():
    """Load MNIST (source) and MNIST-M (target) and build batched tf.data pipelines.

    Keeps the first ``NUM_SAMPLES`` training images and the first
    ``int(NUM_SAMPLES * VAL_SET)`` test images of each domain. MNIST images get a
    channel axis, are scaled to [0, 1] and repeated to ``CHANNELS`` channels.
    Labels are one-hot encoded with depth 10.

    Returns:
        Tuple ``(ds_stage_1_train, ds_stage_1_test, ds_stage_2_train, ds_stage_2_test)``.
        Stage-1 datasets yield ``(x_source, y_source)`` batches and stage-2 datasets
        yield ``(x_source, y_source, x_target)`` batches. All use ``BATCH_SIZE`` and
        are not shuffled.
    """
    num_val = int(NUM_SAMPLES * VAL_SET)

    # Load MNIST [Source]
    (mnist_train_x, mnist_train_y), (mnist_test_x, mnist_test_y) = tf.keras.datasets.mnist.load_data()

    # Subsample, add a trailing channel axis, scale, then repeat the single grayscale
    # channel CHANNELS times (presumably to match the colour MNIST-M images).
    mnist_train_x = np.repeat(_to_unit_float32(mnist_train_x[:NUM_SAMPLES, ..., np.newaxis]), CHANNELS, axis=3)
    mnist_test_x = np.repeat(_to_unit_float32(mnist_test_x[:num_val, ..., np.newaxis]), CHANNELS, axis=3)
    mnist_train_y = tf.one_hot(mnist_train_y[:NUM_SAMPLES], depth=10)
    mnist_test_y = tf.one_hot(mnist_test_y[:num_val], depth=10)

    # Load MNIST-M [Target]. Slicing the h5py datasets reads only the rows that are
    # kept, instead of loading each full split into memory first.
    with h5py.File(MNIST_M_PATH, 'r') as mnist_m:
        mnist_m_train_x = _to_unit_float32(mnist_m['train']['X'][:NUM_SAMPLES])
        mnist_m_test_x = _to_unit_float32(mnist_m['test']['X'][:num_val])

    # Target labels are not used for training. NOTE(review): stage 2 zips x_target[i]
    # with y_source[i]; this code assumes MNIST-M labels equal the MNIST labels
    # index-for-index (the visualization cell scores target predictions against y_source).
    ds_stage_1_train = Dataset.from_tensor_slices((mnist_train_x, mnist_train_y)).batch(BATCH_SIZE)
    ds_stage_1_test = Dataset.from_tensor_slices((mnist_test_x, mnist_test_y)).batch(BATCH_SIZE)
    ds_stage_2_train = Dataset.from_tensor_slices((mnist_train_x, mnist_train_y, mnist_m_train_x)).batch(BATCH_SIZE)
    ds_stage_2_test = Dataset.from_tensor_slices((mnist_test_x, mnist_test_y, mnist_m_test_x)).batch(BATCH_SIZE)

    return ds_stage_1_train, ds_stage_1_test, ds_stage_2_train, ds_stage_2_test

# Training 

In [6]:
class Arguments:
    """Plain attribute container standing in for a parsed command-line namespace."""
    pass


# Training configuration consumed by Trainer (insertion order kept for readability).
args = Arguments()
vars(args).update(
    lr=0.01,
    logdir='logs',
    eval_every=1,
    epochs_stage_1=3,
    epochs_stage_2=3,
    reverse_gradients=False,
)

In [7]:
# In a notebook kernel __name__ is "__main__", so this block always runs here.
if __name__ == "__main__":
    # Build the stage-1 (x, y) and stage-2 (x_source, y_source, x_target) train/test pipelines.
    ds_stage_1_train, ds_stage_1_test, ds_stage_2_train, ds_stage_2_test = prepare_data()
    # NOTE: `model` becomes a notebook global; the Visualization cell below reads it.
    model = Trainer(args)
    # Stage 1: pre-train model.model_1 on labelled source data.
    model.train_stage_1(ds_stage_1_train, ds_stage_1_test)
    # Stage 2: adapt model.model_2 using the source labels and unlabelled target images.
    model.train_stage_2(ds_stage_2_train, ds_stage_2_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3


epochs:   0%|          | 0/3 [00:00<?, ?it/s]

train batches:   0%|          | 0/157 [00:00<?, ?it/s]

eval batches: 0it [00:00, ?it/s]

train batches:   0%|          | 0/157 [00:00<?, ?it/s]

eval batches: 0it [00:00, ?it/s]

train batches:   0%|          | 0/157 [00:00<?, ?it/s]

eval batches: 0it [00:00, ?it/s]

# Metrics

In [None]:
!tensorboard --logdir logs

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.4.0 at http://localhost:6006/ (Press CTRL+C to quit)


# Visualization

In [None]:
num_samples = 6
batch = next(iter(ds_stage_2_train))
x_source, y_source, x_target = batch

# Run both models in inference mode explicitly (Dropout disabled).
y_pred_pre = model.model_1(x_target, training=False)
y_pred = model.model_2(x_target, training=False)

# Target-domain loss of each model. The MNIST labels are used as proxy target labels,
# which assumes MNIST-M images are index-aligned with MNIST (see prepare_data).
L_clf = CCELoss()(y_source, y_pred)
L_clf_pre = CCELoss()(y_source, y_pred_pre)
print(f'CCE on target batch - stage 1 model: {float(L_clf_pre):.4f}, stage 2 model: {float(L_clf):.4f}')

# Compute class indices once, outside the plotting loop.
labels = y_source.numpy().argmax(axis=-1)
pred_pre = y_pred_pre.numpy().argmax(axis=-1)
pred = y_pred.numpy().argmax(axis=-1)

fig, ax = plt.subplots(num_samples, 2, figsize=(20, 20))
for i in range(num_samples):
    ax[i, 0].imshow(x_source[i])
    ax[i, 0].set_title(f'Source (label: {labels[i]})')
    ax[i, 1].imshow(x_target[i])
    ax[i, 1].set_title(f'Prediction: {pred[i]} (stage 1 model: {pred_pre[i]})')
    ax[i, 0].axis('off')
    ax[i, 1].axis('off')
plt.tight_layout()
plt.show()