In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

from __future__ import print_function
from collections import OrderedDict

import os
import sys
import timeit
import pickle

import scipy.io as sio
import numpy as np
import tensorflow as tf

sys.path.insert(0, "/home/jay/DSN/tensorflow/ADSA/Model/")
import cnnet as nn
import util
import read_data as read_data
import ADSA_DNN as ADSA
import DataPackage as dp
import AmazonReviewsFeaturePlot as fp

In [2]:
# Dataset selection: read_data.read returns the source and target domain data
# (train / valid / test splits, each a (features, labels) pair) plus the input
# feature dimensionality.
option = 'Amazon'
source_data, target_data, max_feature = read_data.read(option)

(
    (train_ftd_source, train_labeld_source),
    (valid_ftd_source, valid_labeld_source),
    (test_ftd_source, test_labeld_source),
) = (source_data[0], source_data[1], source_data[2])

(
    (train_ftd_target, train_labeld_target),
    (valid_ftd_target, valid_labeld_target),
    (test_ftd_target, test_labeld_target),
) = (target_data[0], target_data[1], target_data[2])

batch_size = 128
x_dim = max_feature
y_dim = 2

# Model construction. The three feature extractors (shared / source-private /
# target-private, per the struct names) share one max_feature -> 400 -> 100 ReLU
# layout; the three heads (task, domain-adaptation, domain-separation) share a
# 100 -> 50 -> k layout with a linear output layer.
struct = ADSA.ADSA_struct()
for encoder in (struct.FE_struct, struct.PSFE_struct, struct.PTFE_struct):
    encoder.layer_struct = [max_feature, 400, 100]
    encoder.activation = [tf.nn.relu, tf.nn.relu]
for head, n_out in ((struct.TC_struct, y_dim), (struct.DC_struct, 2), (struct.Sep_struct, 3)):
    head.layer_struct = [100, 50, n_out]
    head.activation = [tf.nn.relu, 'linear']
del encoder, head, n_out  # keep loop variables out of the notebook namespace
struct.batch_size = batch_size

description = 'Amazon_ADSA'

In [3]:
# Build the model graph
graph=tf.get_default_graph()
with graph.as_default():
    model = ADSA.ADSAModel(struct, x_dim, y_dim)
    learning_rate= tf.placeholder(tf.float32, [])
    
    pred_loss = tf.reduce_mean(model.pred_loss)
    domain_loss = tf.reduce_mean(model.domain_adapt_loss)
    sep_loss=tf.reduce_mean(model.domain_sep_loss)
    total_loss = pred_loss - domain_loss+0.5*sep_loss
    # optimizer
    dann_train_op = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(total_loss)
    
    domain_adapt_classifier_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'domain_predict')
    adapt_opt=tf.train.MomentumOptimizer(learning_rate, 0.009).minimize(domain_loss, var_list=domain_adapt_classifier_vars)
    
    target_FE_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'shared_feature_extractor')
    gen_opt=tf.train.MomentumOptimizer(learning_rate,0.009).minimize(-domain_loss, var_list=target_FE_vars)
    
    domain_sep_classifier_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'domain_sep_classifier')
    sep_opt=tf.train.MomentumOptimizer(learning_rate,0.009).minimize(sep_loss,var_list=domain_sep_classifier_vars)
    
    source_private_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'source_private_feature_extractor')
    target_private_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'target_private_feature_extractor')
    sep_encoder_vars=target_FE_vars+source_private_vars+target_private_vars
    sep_gen_opt=tf.train.MomentumOptimizer(learning_rate,0.009).minimize(sep_loss,var_list=sep_encoder_vars)
    
    # Accuracy evaluation
    correct_label_pred = tf.equal(tf.argmax(model.classify_labels, 1), tf.argmax(model.pred, 1))
    label_acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))
    correct_domain_pred = tf.equal(tf.argmax(model.domain_labels, 1), tf.argmax(model.domain_adapt_pred, 1))
    domain_acc = tf.reduce_mean(tf.cast(correct_domain_pred, tf.float32))
    correct_sep_pred=tf.equal(tf.argmax(model.domain_sep_labels, 1), tf.argmax(model.domain_sep_pred, 1))
    sep_acc=tf.reduce_mean(tf.cast(correct_sep_pred, tf.float32))

Shared_Feature_extractor is constructed with hidden layer number 1
Source_private_feature_extractor is constructed with hidden layer number 1
Target_private_feature_extractor is constructed with hidden layer number 1
Task_classifier is constructed with hidden layer number 1
Domain_adapt_classifier is constructed with hidden layer number 1
Domain_sep_classifier is constructed with hidden layer number 1


In [4]:
# Number of outer training iterations. Each iteration runs 20 discriminator
# steps, 10 encoder steps and one joint step, so despite the name this is not a
# count of full passes over the data.
num_epoch = 2000


def _annealing_schedule(itr, total):
    """Return (flip_scale, lr) for iteration `itr` out of `total`."""
    p = float(itr) / total
    # NOTE(review): the usual DANN schedule is 2/(1+exp(-10p)) - 1, which ramps
    # 0 -> 1. This form ramps from 1 at p=0 toward 2. Confirm intent.
    flip_scale = 2. / (1. + np.exp(-10. * p))
    # Learning rate decays as 0.01 / (1 + 10p)^0.75.
    lr = 0.01 / (1. + 10 * p) ** 0.75
    return flip_scale, lr


def _next_joint_batch(source_batches, target_batches):
    """Draw one half-batch from each domain and stack them, source rows first."""
    X_src, y_src = next(source_batches)
    X_tgt, y_tgt = next(target_batches)
    return np.vstack([X_src, X_tgt]), np.vstack([y_src, y_tgt])


def _train_feed(model, X, y, domain_labels, domain_sep_labels, flip_scale, lr):
    """Build the feed dict shared by every training-time sess.run call."""
    # TF1 sessions accept feeds for placeholders that a fetch does not depend
    # on, so one complete dict can drive every optimizer op.
    return {model.X: X, model.y: y,
            model.domain_labels: domain_labels,
            model.domain_sep_labels: domain_sep_labels,
            model.train_flag: True,
            model.flip_scale: flip_scale,
            learning_rate: lr}


def train_and_evaluate(training_mode, graph, model, verbose=False, num_iter=None):
    """Train the model with the ADSA schedule, then report per-domain test accuracy.

    Uses the module-level optimizer ops (adapt_opt, sep_opt, gen_opt,
    sep_gen_opt, dann_train_op), loss/accuracy tensors and the
    `learning_rate` placeholder from the graph cell, plus the data arrays and
    `batch_size` from the data cell.

    Args:
        training_mode: only 'ADSA' is implemented.
        graph: the tf.Graph containing `model` and the ops above.
        model: the ADSA model whose placeholders are fed.
        verbose: print training statistics every 100 iterations.
        num_iter: number of outer iterations; defaults to the global `num_epoch`.

    Returns:
        (source_acc, target_acc): label accuracy on the source and target test sets.

    Raises:
        ValueError: if `training_mode` is not 'ADSA'.
    """
    if training_mode != 'ADSA':
        # Without this check an unknown mode would skip training entirely and
        # silently evaluate randomly initialised weights.
        raise ValueError("Unsupported training_mode %r; only 'ADSA' is implemented"
                         % (training_mode,))
    if num_iter is None:
        num_iter = num_epoch
    # Integer half-batch: `/` yields a float on Python 3, which breaks np.tile
    # and batch slicing.
    half = batch_size // 2

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        def _stream(features, labels):
            # A fresh [features, labels] list per generator, as the helper expects.
            return util.batch_generator([features, labels], half)

        # Independent batch streams for the joint, discriminator and encoder
        # phases. Each is created once and advanced across iterations, so later
        # iterations see new batches instead of restarting every stream.
        t_source_batch = _stream(train_ftd_source, train_labeld_source)
        t_target_batch = _stream(train_ftd_target, train_labeld_target)
        d_source_batch = _stream(train_ftd_source, train_labeld_source)
        d_target_batch = _stream(train_ftd_target, train_labeld_target)
        gen_source_batch = _stream(train_ftd_source, train_labeld_source)
        gen_target_batch = _stream(train_ftd_target, train_labeld_target)

        # Adaptation-discriminator targets: source rows [1, 0], target rows
        # [0, 1], matching the source-first stacking in _next_joint_batch.
        domain_labels = np.vstack([np.tile([1, 0], [half, 1]),
                                   np.tile([0, 1], [half, 1])])
        # Separation-discriminator targets: batch_size rows of class 0, then
        # `half` rows of class 1 and `half` rows of class 2.
        # NOTE(review): presumably shared features of the whole batch, then
        # source-private, then target-private features. Confirm the row order
        # against ADSA_DNN.
        domain_sep_labels = np.vstack([np.tile([1, 0, 0], [batch_size, 1]),
                                       np.tile([0, 1, 0], [half, 1]),
                                       np.tile([0, 0, 1], [half, 1])])

        for itr in range(num_iter):
            flip_scale, lr = _annealing_schedule(itr, num_iter)

            # 1) Discriminators: 20 adaptation-classifier steps; the separation
            #    classifier is stepped on every 10th of them.
            for d_iter in range(20):
                dX, dy = _next_joint_batch(d_source_batch, d_target_batch)
                feed = _train_feed(model, dX, dy, domain_labels, domain_sep_labels,
                                   flip_scale, lr)
                sess.run(adapt_opt, feed_dict=feed)
                if d_iter % 10 == 0:
                    _, s_acc = sess.run([sep_opt, sep_acc], feed_dict=feed)

            # 2) Encoders: gen_opt then sep_gen_opt on the same batch.
            for _ in range(10):
                gX, gy = _next_joint_batch(gen_source_batch, gen_target_batch)
                feed = _train_feed(model, gX, gy, domain_labels, domain_sep_labels,
                                   flip_scale, lr)
                sess.run(gen_opt, feed_dict=feed)
                sess.run(sep_gen_opt, feed_dict=feed)

            # 3) Joint step on total_loss.
            X, y = _next_joint_batch(t_source_batch, t_target_batch)
            feed = _train_feed(model, X, y, domain_labels, domain_sep_labels,
                               flip_scale, lr)
            _, batch_loss, p_acc, d_acc = sess.run(
                [dann_train_op, total_loss, label_acc, domain_acc], feed_dict=feed)

            if verbose and itr % 100 == 0:
                print('loss: %f adapt_acc: %f task_acc: %f sep acc: %f\n'
                      'flip_scale: %f  learning_rate: %f' %
                      (batch_loss, d_acc, p_acc, s_acc, flip_scale, lr))

        # Evaluate on the held-out test splits with train_flag off.
        source_acc = sess.run(label_acc,
                              feed_dict={model.X: test_ftd_source, model.y: test_labeld_source,
                                         model.train_flag: False})
        target_acc = sess.run(label_acc,
                              feed_dict={model.X: test_ftd_target, model.y: test_labeld_target,
                                         model.train_flag: False})

    return source_acc, target_acc

In [5]:
# Train with the ADSA schedule and report held-out accuracy for each domain.
# Both domains come from the dataset selected by `option` in the data cell.
print('\nDomain adaptation training')
source_acc, target_acc = train_and_evaluate('ADSA', graph, model, verbose=True)
print('Source (%s) accuracy: %f' % (option, source_acc))
print('Target (%s) accuracy: %f' % (option, target_acc))


Domain adaptation training
loss: 0.547829 adapt_acc: 0.445312 task_acc: 0.453125 sep acc: 0.203125
flip_scale: 1.000000  learning_rate: 0.010000
loss: 0.025800 adapt_acc: 0.578125 task_acc: 0.406250 sep acc: 1.000000
flip_scale: 1.244919  learning_rate: 0.007378
loss: -0.030833 adapt_acc: 0.515625 task_acc: 0.703125 sep acc: 1.000000
flip_scale: 1.462117  learning_rate: 0.005946
loss: -0.060310 adapt_acc: 0.492188 task_acc: 0.718750 sep acc: 1.000000
flip_scale: 1.635149  learning_rate: 0.005030
loss: -0.112889 adapt_acc: 0.507812 task_acc: 0.734375 sep acc: 1.000000
flip_scale: 1.761594  learning_rate: 0.004387
loss: -0.249090 adapt_acc: 0.578125 task_acc: 0.828125 sep acc: 1.000000
flip_scale: 1.848284  learning_rate: 0.003908
loss: -0.339880 adapt_acc: 0.515625 task_acc: 0.875000 sep acc: 1.000000
flip_scale: 1.905148  learning_rate: 0.003536
loss: -0.360324 adapt_acc: 0.507812 task_acc: 0.875000 sep acc: 1.000000
flip_scale: 1.941376  learning_rate: 0.003237
loss: -0.430987 adapt_