In [1]:
import tensorflow as tf
from twinmnist import TwinMnist
from conf import X_DIM, C_DIM, S_DIM, MODEL_FILE
from conf import e, es, ec, ec_tail, d
from conf import sdc, cdc
from conf import rs, rc

  from ._conv import register_converters as _register_converters


In [2]:
# Graph inputs. Every placeholder is float32 with an open batch dimension;
# the second dimension is one of the sizes imported from conf.
def _batch_input(width, name):
    """Return a float32 placeholder of shape [None, width] named `name`."""
    return tf.placeholder(tf.float32, [None, width], name=name)


# Three X_DIM-wide inputs; the training loop feeds them with one triplet batch.
x1, x2, xc = (_batch_input(X_DIM, tag) for tag in ('x1', 'x2', 'xc'))

# Inputs for the samples that the training loop draws with rs() / rc().
s = _batch_input(S_DIM, 's')
c = _batch_input(C_DIM, 'c')

In [3]:
# Push all three inputs through the same chain of conf functions:
# e -> ec -> ec_tail. Each stage is applied to x1, x2, xc in that order.
x1_z, x2_z, xc_z = [e(inp) for inp in (x1, x2, xc)]
x1_ct, x2_ct, xc_ct = [ec(z) for z in (x1_z, x2_z, xc_z)]
x1_c, x2_c, xc_c = [ec_tail(ct) for ct in (x1_ct, x2_ct, xc_ct)]

# Only x1 is sent through the `es` branch.
# (presumably e/ec/es are the shared, content and style encoders -- confirm in conf)
x1_s = es(x1_z)

In [4]:
# Reconstruction of x1: concatenate its `c` and `s` codes along axis 1 and
# pass the joined code to `d` (imported from conf; presumably the decoder).
x1_code = tf.concat([x1_c, x1_s], axis=1)
x1_r = d(x1_code)

In [5]:
# `sdc` (from conf) is applied twice: to the fed `s` placeholder ("real")
# and to the `es` output computed from x1 ("fake").
sdc_real, sdc_fake = sdc(s), sdc(x1_s)

In [6]:
# `cdc` (from conf) is applied twice: to the fed `c` placeholder ("real")
# and to the `ec_tail` output computed from x1 ("fake").
cdc_real, cdc_fake = cdc(c), cdc(x1_c)

In [7]:
# Reconstruction objective: squared difference between x1 and its
# reconstruction x1_r, averaged over every element of the batch.
recon_residual = x1 - x1_r
ae_loss = tf.reduce_mean(tf.square(recon_residual))

In [8]:
# Hinge-style triplet objective computed on the `ec` outputs (the *_ct
# tensors, i.e. before ec_tail). Minimising it shrinks the x1-x2 distance
# relative to the x1-xc distance until the gap exceeds the margin `h`.
h = .2  # margin

# Per-example squared Euclidean distances (sum over axis 1).
dist_x1_x2 = tf.reduce_sum(tf.square(x1_ct - x2_ct), 1)
dist_x1_xc = tf.reduce_sum(tf.square(x1_ct - xc_ct), 1)

loss_term = dist_x1_x2 - dist_x1_xc + h
# Clamp negative terms to zero, then average over the batch.
triplet_loss = tf.reduce_mean(tf.maximum(0., loss_term))

In [9]:
# Losses for `sdc`, all sigmoid cross-entropy on its raw outputs (logits).
# sdc_loss: target 1 for sdc(s) and target 0 for sdc(x1_s).
sdc_real_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(sdc_real), logits=sdc_real)
sdc_real_loss = tf.reduce_mean(sdc_real_xent)

sdc_fake_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.zeros_like(sdc_fake), logits=sdc_fake)
sdc_fake_loss = tf.reduce_mean(sdc_fake_xent)

sdc_loss = sdc_real_loss + sdc_fake_loss

# sdc_e_loss: the opposite target (1) for sdc(x1_s). The optimizer cell
# minimises it over e_vars instead of sdc_vars, giving an adversarial pair.
sdc_e_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(sdc_fake), logits=sdc_fake)
sdc_e_loss = tf.reduce_mean(sdc_e_xent)

In [10]:
# Losses for `cdc`, all sigmoid cross-entropy on its raw outputs (logits).
# cdc_loss: target 1 for cdc(c) and target 0 for cdc(x1_c).
cdc_real_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(cdc_real), logits=cdc_real)
cdc_real_loss = tf.reduce_mean(cdc_real_xent)

cdc_fake_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.zeros_like(cdc_fake), logits=cdc_fake)
cdc_fake_loss = tf.reduce_mean(cdc_fake_xent)

cdc_loss = cdc_real_loss + cdc_fake_loss

# cdc_e_loss: the opposite target (1) for cdc(x1_c). The optimizer cell
# minimises it over ec_vars instead of cdc_vars, giving an adversarial pair.
cdc_e_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(cdc_fake), logits=cdc_fake)
cdc_e_loss = tf.reduce_mean(cdc_e_xent)

In [11]:
# Split the trainable variables into groups by name prefix so that each
# optimizer can be restricted to its own subset.
all_variables = tf.trainable_variables()


def _select_by_prefix(variables, *prefixes):
    """Return the variables whose name starts with any of `prefixes`, order kept."""
    return [var for var in variables if var.name.startswith(prefixes)]


# NOTE: despite its name, e_vars holds the 'es_' variables.
e_vars = _select_by_prefix(all_variables, 'es_')
ec_vars = _select_by_prefix(all_variables, 'e_c_')
# NOTE: the 'e_' prefix also matches every 'e_c_' name, so ae_vars is a
# superset of ec_vars.
ae_vars = _select_by_prefix(all_variables, 'e_', 'd_', 'es_')
# Earlier variant, kept for reference: ec_vars selected by the 'ec_' prefix.
sdc_vars = _select_by_prefix(all_variables, 'sdc_')
cdc_vars = _select_by_prefix(all_variables, 'cdc_')

In [12]:
# Ideas still to explore:
# - use a different learning rate per objective (separate optimizers)
# - use a greater learning rate for the triplet loss (or pre-train on batches)
# - make the style part of the encoder shorter (to prevent ignoring content)

learning_rate = 0.001  # equals the AdamOptimizer default


def _adam_step(loss, var_list=None):
    """Build a train op: one Adam update of `loss`, optionally limited to `var_list`."""
    return tf.train.AdamOptimizer(learning_rate).minimize(loss, var_list=var_list)


# Every objective gets its own Adam instance.
ae_opt = _adam_step(ae_loss, ae_vars)

# No var_list here, so minimize() falls back to its default variable set
# (restricting it to ec_vars was tried and is currently disabled).
triplet_opt = _adam_step(triplet_loss)

# Adversarial pair around `sdc`: sdc_vars learn sdc_loss, e_vars learn sdc_e_loss.
sdc_opt = _adam_step(sdc_loss, sdc_vars)
sdc_e_opt = _adam_step(sdc_e_loss, e_vars)

# Adversarial pair around `cdc`: cdc_vars learn cdc_loss, ec_vars learn cdc_e_loss.
cdc_opt = _adam_step(cdc_loss, cdc_vars)
cdc_e_opt = _adam_step(cdc_e_loss, ec_vars)

In [13]:
# TensorBoard summaries: one scalar per loss, registered in this order.
for _tag, _tensor in (
    ('ae_loss', ae_loss),
    ('triplet_loss', triplet_loss),
    ('sdc_loss', sdc_loss),
    ('sdc_e_loss', sdc_e_loss),
    ('cdc_loss', cdc_loss),
    ('cdc_e_loss', cdc_e_loss),
):
    tf.summary.scalar(_tag, _tensor)

# Histograms of the two codes computed for x1.
for _tag, _tensor in (('x1_c', x1_c), ('x1_s', x1_s)):
    tf.summary.histogram(_tag, _tensor)

# Disabled summaries, kept for reference:
#tf.summary.histogram('s', s)
#tf.summary.image('x1', tf.reshape(x1, [-1, 28, 28, 1]), 1)
#tf.summary.image('x1_r', tf.reshape(x1_r, [-1, 28, 28, 1]), 1)
# The two swap reconstructions below need x2_s, which no visible cell defines.
#tf.summary.image('x_c1s2_r', tf.reshape(d(tf.concat([x1_c, x2_s], axis=1)), [-1, 28, 28, 1]), 1)
#tf.summary.image('x_c2s1_r', tf.reshape(d(tf.concat([x2_c, x1_s], axis=1)), [-1, 28, 28, 1]), 1)

# Single op that evaluates every summary registered above.
all_summary = tf.summary.merge_all()

In [14]:
# Open a session on the graph built above and initialise all variables.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Event files (graph + summaries) are written to the 'summary' directory.
writer = tf.summary.FileWriter(logdir='summary', graph=sess.graph)
# Saver used by the final cell to write the checkpoint.
saver = tf.train.Saver()

In [15]:
# Data source: TwinMnist reads from the 'MNIST_data' directory with no
# validation split requested, then loads the triplets stored under
# 'mnist_triplets_k100' (presumably precomputed -- see twinmnist for the format).
tm = TwinMnist(validation_size=0, train_dir='MNIST_data')
tm.load_triplets('mnist_triplets_k100')

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz


In [16]:
# Global step counter for the summaries. The training loop increments it once
# per batch and passes it to writer.add_summary. It lives in its own cell, so
# re-running the training cell alone continues the count instead of resetting it.
batch_step = 0

In [17]:
# Training schedule.
epochs, batch_size = 50, 100

# Whole batches per epoch; a trailing partial batch is not counted.
train_count = tm.mnist.train.images.shape[0]
batches = train_count // batch_size

# Total number of batch steps in a full run (not referenced by the visible cells).
batch_end = batches * epochs

In [18]:
# Training loop: for every batch, run the optimizers in a fixed order and
# then log all summaries for that batch.
#
# The epoch counter is named `epoch`, not `e`: `e` is the function imported
# from conf that builds the graph. Rebinding it to an int here would break
# any later re-run of the graph-construction cell in the same kernel.
for epoch in range(1, epochs + 1):
    for b in range(1, batches + 1):
        x1_batch, x2_batch, xc_batch = tm.next_batch_triplets(batch_size)
        s_batch = rs(batch_size)
        c_batch = rc(batch_size)

        # Alternative kept for reference: run triplet_opt on its own first.
        #sess.run(triplet_opt, {x1: x1_batch, x2: x2_batch, xc: xc_batch})

        # Reconstruction and triplet updates share one session call.
        sess.run([ae_opt, triplet_opt], {x1: x1_batch, x2: x2_batch, xc: xc_batch})

        # Adversarial pair around `sdc`: update sdc_vars, then e_vars.
        sess.run(sdc_opt, {x1: x1_batch, s: s_batch})
        sess.run(sdc_e_opt, {x1: x1_batch})

        # Adversarial pair around `cdc`: update cdc_vars, then ec_vars.
        sess.run(cdc_opt, {x1: x1_batch, c: c_batch})
        sess.run(cdc_e_opt, {x1: x1_batch})

        batch_step += 1
        # '\r' rewrites the same output line instead of adding a line per batch.
        print('\repoch {0} {1:3.0f} %'.format(epoch, b / batches * 100), end='', flush=True)

        # The merged summary depends on every placeholder, so feed them all.
        summary_str = sess.run(all_summary, {
            x1: x1_batch,
            x2: x2_batch,
            xc: xc_batch,
            s: s_batch,
            c: c_batch
        })
        writer.add_summary(summary_str, batch_step)
print('\rDone', ' '*25, flush=True)

epoch 35  88 %

KeyboardInterrupt: 

In [19]:
# Write the checkpoint to MODEL_FILE first, while the session is still open.
saver.save(sess, MODEL_FILE)

# Then release the summary writer and the session, in that order.
for _closable in (writer, sess):
    _closable.close()