# Train

In [None]:
import pandas as pd
import os
import tensorflow as tf


import scipy
from scipy.io import loadmat
import re
import copy as cp
import string
import imageio
import numpy as np
import matplotlib.pyplot as plt
from utils import *
import random
import time
import nltk

import warnings
import model


# ---- Hyper-parameters -------------------------------------------------------
batch_size = 64     # examples per optimisation step (also the placeholder batch dim)
img_size = 64       # forwarded to the model as 'img_size'
z_dim = 512         # length of the noise vector fed to 't_z'
t_dim = 256         # forwarded to the model as 't_dim'
df_dim = 64         # forwarded to the model as 'df_dim'
gf_dim = 128        # forwarded to the model as 'gf_dim'


lr = 0.0002         # initial learning rate
lr_decay = 0.5      # multiplicative decay factor
decay_every = 100   # apply the decay every this many epochs
save_dir = './checkpoint'
restore_model = '/VBN700/modelVBN.ckpt'   # appended to save_dir when restoring
n_caps = 5          # captions kept per training image
warnings.filterwarnings('ignore')

# Side length of the square grid used when tiling the generated samples.
ni = int(np.ceil(np.sqrt(batch_size)))
sample_size = batch_size
# Fixed noise so the periodic sample grids are comparable across epochs.
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)

# Captions used for the periodic sample grid. Exactly `sample_size` captions
# are taken (previously a hard-coded 64) so the feed always matches the
# placeholder batch dimension; an IndexError is still raised if test.csv
# holds fewer rows than that.
test_sentence = pd.read_csv('test.csv',dtype={'ID':'str'})
_sample_captions = test_sentence['Captions'].values
sample_sentence = [sent2IdList(_sample_captions[i]) for i in range(sample_size)]

dictionary_path = './dictionary'
vocab = np.load(dictionary_path+'/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# Each mapping is loaded once (the original loaded both files twice).
word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))

# `encoding='latin1'` only matters for arrays pickled under Python 2, so these
# files are presumably pickled object arrays. NumPy >= 1.16.3 refuses to
# unpickle unless allow_pickle=True is passed explicitly.
# NOTE(review): unpickling executes arbitrary code - only load trusted files.
train_img = np.load('train_images.npy', encoding='latin1', allow_pickle=True)
train_cap = np.load('train_captions.npy', encoding='latin1', allow_pickle=True)

# Keep at most `n_caps` captions per image and flatten them into one array.
# The training loop maps caption index -> image index with `idx // n_caps`,
# which assumes every image contributes exactly n_caps captions - TODO confirm.
cap_lst = [caps[:n_caps] for caps in train_cap]
train_cap = np.concatenate(cap_lst, axis=0)



# Configuration dictionary consumed by model.GAN and the networks it builds.
model_options = dict(
    z_dim=z_dim,
    batch_size=batch_size,
    img_size=img_size,
    t_dim=t_dim,
    df_dim=df_dim,
    gf_dim=gf_dim,
    vocab_size=len(vocab),
)

gan = model.GAN(model_options)

# build_model() hands back three dictionaries: the input placeholders, the
# loss tensors ('d', 'g', 'r') and the networks used for sampling.
input_tensors, loss, net = gan.build_model()

d_loss, g_loss, rnn_loss = loss['d'], loss['g'], loss['r']
net_g, net_rnn = net['net_g'], net['net_rnn']



def _named_like(items, tag):
    """Return the entries of *items* whose ``.name`` contains *tag*."""
    return [item for item in items if tag in item.name]


# Split the trainable variables by sub-network using substrings of their names.
# NOTE(review): 'rnn' may also match the generator's 'rnn_fc' layer, depending
# on how DenseLayer names its variables - confirm against layer_utils.
_trainables = tf.trainable_variables()
rnn_vars = _named_like(_trainables, 'rnn')
g_vars = _named_like(_trainables, 'generator')
d_vars = _named_like(_trainables, 'discrim')
cnn_vars = _named_like(_trainables, 'cnn')

# Ops registered in UPDATE_OPS, grouped the same way so that each train op
# can be made to depend on the updates of its own sub-network.
_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_ops_D = _named_like(_update_ops, 'discrim')
update_ops_G = _named_like(_update_ops, 'generator')
update_ops_CNN = _named_like(_update_ops, 'cnn')


# The learning rate lives in a non-trainable variable so it can be reassigned
# while training runs.
with tf.variable_scope('learning_rate'):
    lr_v = tf.Variable(lr, trainable=False)

with tf.control_dependencies(update_ops_D):
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=0.5).minimize(d_loss, var_list=d_vars)

with tf.control_dependencies(update_ops_G):
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=0.5).minimize(g_loss, var_list=g_vars)

with tf.control_dependencies(update_ops_CNN):
    # The text and image encoders are trained jointly on rnn_loss, with the
    # gradients clipped to a global norm of 10 before being applied.
    _encoder_vars = rnn_vars + cnn_vars
    _raw_grads = tf.gradients(rnn_loss, _encoder_vars)
    grads, _ = tf.clip_by_global_norm(_raw_grads, 10)
    optimizer = tf.train.AdamOptimizer(lr_v, beta1=0.5)
    rnn_optim = optimizer.apply_gradients(zip(grads, _encoder_vars))


# Let TensorFlow grow its GPU memory use on demand.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
init = tf.global_variables_initializer()
sess.run(init)

saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

# Try to resume from an earlier checkpoint; training starts from the freshly
# initialised variables when that is not possible. Only the errors a failed
# restore is expected to raise are caught (an invalid path raises ValueError,
# read/shape problems raise TensorFlow OpError subclasses), so that
# KeyboardInterrupt and genuine programming errors are no longer swallowed.
ckpt_path = save_dir + restore_model
try:
    saver.restore(sess, ckpt_path)
except (ValueError, tf.errors.OpError) as restore_err:
    print('no model found.')
    print(' ** could not restore', ckpt_path, ':', restore_err)
else:
    print("Restore model", ckpt_path)

n_epoch = 700
n_batch_epoch = len(train_img) // batch_size

# Three flipped variants of the training images, built once up front:
# left-right mirrored, upside-down, and both flips combined.
train_img_lr = cp.copy(train_img)
train_img_ud = cp.copy(train_img)
train_img_udlr = cp.copy(train_img)

for slot, picture in enumerate(train_img):
    mirrored = picture[:, ::-1]                 # reverse the second axis
    train_img_lr[slot] = mirrored
    train_img_ud[slot] = np.flipud(picture)     # reverse the first axis
    train_img_udlr[slot] = np.flipud(mirrored)

# The four flip variants of the training set; one is picked at random per step.
_image_variants = (train_img, train_img_lr, train_img_ud, train_img_udlr)

# Make sure the directory for the periodic sample grids exists.
os.makedirs('train_samples', exist_ok=True)

for epoch in range(n_epoch):
    start_time = time.time()

    # Step-wise learning-rate schedule: lr * lr_decay ** (epoch // decay_every).
    if epoch !=0 and (epoch % decay_every == 0):
        new_lr_decay = lr_decay ** (epoch // decay_every)
        sess.run(tf.assign(lr_v, lr * new_lr_decay))
        log = " ** new learning rate: %f" % (lr * new_lr_decay)
        print(log)

    elif epoch == 0:
        log = " ** init lr: %f  decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay)
        print(log)

    for step in range(n_batch_epoch):
        step_time = time.time()

        ## get matched text & image
        idxs = np.random.randint(low = 0,high = len(train_cap), size=batch_size)
        real_caption = train_cap[idxs]

        # Captions are stored n_caps per image, so integer division maps a
        # caption index back to the index of its image.
        r = np.random.randint(4)
        real_images = _image_variants[r][idxs // n_caps]

        ## get wrong caption & wrong image
        idxs = np.random.randint(low = 0,high = len(train_cap), size=batch_size)
        wrong_caption = train_cap[idxs]
        idxs2 = np.random.randint(low = 0,high = len(train_img), size=batch_size)
        wrong_images = train_img[idxs2]

        ## get noise
        b_z = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, z_dim)).astype(np.float32)

        real_images = threading_data(real_images, prepro_img, mode='train')   # [0, 255] --> [-1, 1] + augmentation
        wrong_images = threading_data(wrong_images, prepro_img, mode='train')

        ## update the text / image encoders
        R_loss, _ = sess.run([rnn_loss, rnn_optim], feed_dict={
                                        input_tensors['t_real_image'] : real_images,
                                        input_tensors['t_wrong_image'] : wrong_images,
                                        input_tensors['t_real_caption'] : real_caption,
                                        input_tensors['t_wrong_caption'] : wrong_caption})

        ## updates D
        D_loss, _ = sess.run([d_loss, d_optim], feed_dict={
                        input_tensors['t_real_image']  : real_images,
                        input_tensors['t_wrong_caption']  : wrong_caption,
                        input_tensors['t_real_caption']  : real_caption,
                        input_tensors['t_z']  : b_z})
        ## updates G
        G_loss, _ = sess.run([g_loss, g_optim], feed_dict={
                        input_tensors['t_real_caption'] : real_caption,
                        input_tensors['t_z'] : b_z})

        print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.8f, g_loss: %.8f, rnn_loss: %.8f" \
                    % (epoch, n_epoch, step, n_batch_epoch, time.time() - step_time, D_loss, G_loss, R_loss))

        # Dump a grid of samples for the fixed captions / noise once per epoch.
        if step == 0:
            img_gen, rnn_out = sess.run([net_g.out, net_rnn.out], feed_dict={
                                        input_tensors['t_real_caption'] : sample_sentence,
                                        input_tensors['t_z']  : sample_seed})
            print("min:",np.min(img_gen),"max:",np.max(img_gen))
            img_gen = img_gen*0.5+0.5           # tanh range [-1, 1] --> [0, 1]
            # scipy.misc.imsave no longer exists in current SciPy releases, so
            # the grid is converted to 8-bit explicitly and written with
            # imageio. Unlike the old helper this does not min-max stretch.
            grid = np.asarray(merge(img_gen, [ni, ni]))
            grid_u8 = np.clip(np.rint(grid * 255.0), 0, 255).astype(np.uint8)
            imageio.imwrite('train_samples/train_{:03d}.png'.format(epoch), grid_u8)

    # Reported after the last step so that it covers the whole epoch (it used
    # to be printed right after the first step).
    print(" ** Epoch %d took %fs" % (epoch, time.time()-start_time))

    # Checkpoint every 10 epochs, and after the final epoch so that the last
    # weights are not lost.
    if (epoch % 10) == 0 or epoch == n_epoch - 1:
        saver.save(sess,save_dir+'/modelVBN.ckpt')
        print("model save to",save_dir)

# model

In [None]:
import tensorflow as tf
import numpy as np
from utils import *

from layer_utils import *
def lrelu(x):
    """Leaky ReLU with a slope of 0.2 passed as the second argument."""
    return tf.nn.leaky_relu(x, 0.2)
def relu(x):
    """Plain ReLU; thin wrapper around ``tf.nn.relu``."""
    return tf.nn.relu(x)
class GAN:
    """Builds the text-to-image training graph.

    The graph has two parts: a margin ranking loss that trains the image
    encoder and text encoder together, and a conditional GAN (generator and
    discriminator) conditioned on the text encoder output.
    """

    def __init__(self, options):
        # `options` is the configuration dict forwarded to every sub-network.
        self.options = options

    def build_model(self):
        """Create placeholders, losses and networks.

        Returns:
            A tuple ``(input_tensors, loss, net)`` of dictionaries: the feed
            placeholders, the loss tensors under 'g', 'd' and 'r', and the
            sampling networks under 'net_g' and 'net_rnn'.
        """
        opts = self.options
        n = opts['batch_size']
        side = opts['img_size']
        noise_len = opts['z_dim']
        image_shape = [n, side, side, 3]

        t_real_image = tf.placeholder('float32', image_shape, name = 'real_image')
        t_wrong_image = tf.placeholder('float32', image_shape, name = 'wrong_image')
        t_real_caption = tf.placeholder(dtype=tf.int64, shape=[n, None], name='real_caption_input')
        t_wrong_caption = tf.placeholder(dtype=tf.int64, shape=[n, None], name='wrong_caption_input')
        t_z = tf.placeholder('float32', [n, noise_len])

        # Encoders in training mode; the first call of each creates the
        # variables, the second one reuses them.
        x = cnn_encoder(t_real_image, option=opts, is_training=True, reuse=False).out
        v = TextEncoder(t_real_caption, option=opts, is_training=True, reuse=False).out
        x_w = cnn_encoder(t_wrong_image, option=opts, is_training=True, reuse=True).out
        v_w = TextEncoder(t_wrong_caption, option=opts, is_training=True, reuse=True).out

        alpha = 0.2 # margin alpha

        def margin_term(image_code, text_code):
            # Hinge on (margin - matched similarity + mismatched similarity).
            return tf.reduce_mean(tf.maximum(
                0., alpha - cosine_similarity(x, v) + cosine_similarity(image_code, text_code)))

        # Penalise a wrong caption for the real image and a wrong image for
        # the real caption.
        rnn_loss = margin_term(x, v_w) + margin_term(x_w, v)

        ### Training Phase - GAN
        net_rnn = TextEncoder(t_real_caption, option=opts, is_training=False, reuse=True)
        net_fake_image = Generator(t_z, net_rnn.out, option=opts, is_training=True, reuse=False)

        disc_fake = Discriminator(net_fake_image.out, net_rnn.out, is_training=True, reuse=False)
        disc_real = Discriminator(t_real_image, net_rnn.out, is_training=True, reuse=True)

        wrong_text = TextEncoder(t_wrong_caption, option=opts, is_training=False, reuse=True).out
        disc_wrong = Discriminator(t_real_image, wrong_text, is_training=True, reuse=True)

        def mean_bce(logits, labels, name):
            # Batch-averaged sigmoid cross-entropy against the given labels.
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=labels, name=name))

        # Labels are 0.9 / 0.1 rather than 1 / 0.
        d_loss1 = mean_bce(disc_real.logits, tf.ones_like(disc_real.logits)*0.9, 'd1')
        d_loss2 = mean_bce(disc_wrong.logits, tf.zeros_like(disc_wrong.logits)+0.1, 'd2')
        d_loss3 = mean_bce(disc_fake.logits, tf.zeros_like(disc_fake.logits)+0.1, 'd3')
        d_loss = d_loss1 + (d_loss2 + d_loss3) * 0.5
        g_loss = mean_bce(disc_fake.logits, tf.ones_like(disc_fake.logits)*0.9, 'g')

        # Generator in inference mode, used for visualization.
        sample_text = TextEncoder(t_real_caption, option=opts, is_training=False, reuse=True).out
        net_g = Generator(t_z, sample_text, option=opts, is_training=False, reuse=True)

        input_tensors = dict(
            t_real_image=t_real_image,
            t_wrong_image=t_wrong_image,
            t_real_caption=t_real_caption,
            t_wrong_caption=t_wrong_caption,
            t_z=t_z,
        )
        loss = dict(g=g_loss, d=d_loss, r=rnn_loss)
        net = dict(net_g=net_g, net_rnn=net_rnn)
        return input_tensors,loss,net

class Generator:
    """Text-conditioned image generator.

    Concatenates the noise vector with a dense projection of the text code,
    then grows a feature map from (img_size/16) to img_size through residual
    blocks and nearest-method upsampling + convolution stages.

    Args:
        input_z: noise tensor.
        input_txt: text-encoder output used as the condition.
        option: dict providing 't_dim', 'gf_dim' and 'img_size'.
        is_training: whether dropout is active.
        reuse: passed to the "generator" variable scope.

    Attributes:
        logits: output of the last convolution.
        out: ``tanh(logits)``.
    """

    def __init__(self, input_z, input_txt,option, is_training, reuse):
        self.input_z = input_z
        self.input_txt = input_txt
        self.is_training = is_training
        self.reuse = reuse
        self.t_dim = option['t_dim']
        self.gf_dim = option['gf_dim']
        self.image_size = option['img_size']
        self.c_dim = 3
        self.VBN = tf.contrib.gan.features.VBN

        self.build_model()

    def _vbn(self, tensor, name):
        """Normalise *tensor* with VBN, using the tensor itself as reference batch."""
        return self.VBN(tensor, name=name)(tensor)

    def build_model(self):
        """Build the generator graph and set ``self.logits`` / ``self.out``."""
        s = self.image_size
        s2, s4, s8, s16 = s // 2, s // 4, s // 8, s // 16

        gf_dim = self.gf_dim
        t_dim = self.t_dim
        c_dim = self.c_dim

        with tf.variable_scope("generator", reuse=self.reuse):
            net_txt = DenseLayer(inputs=self.input_txt, n_units=t_dim, act=tf.nn.leaky_relu, name='rnn_fc')

            net_in = ConcatLayer([self.input_z, net_txt], concat_dim=1, name='concat_z_txt')

            net_h0 = DenseLayer(inputs=net_in, n_units=gf_dim*8*s16*s16, name='g_h0/fc', b_init=False)
            net_h0 = self._vbn(net_h0, 'g_h0/batch_norm')

            net_h0 = tf.reshape(net_h0, [-1, s16, s16, gf_dim*8], name='g_h0/reshape')

            # tf.layers.dropout defaults to training=False, which makes it an
            # identity op; pass the flag so dropout is really applied while
            # training and switched off for the sampling network.
            net_h0 = tf.layers.dropout(net_h0, 0.5, training=self.is_training)

            # Residual block at the lowest resolution.
            net = Conv2d(net_h0,gf_dim*2, (1,1),  (1,1),padding='VALID', name='g_h1_res/conv2d')
            net = relu(self._vbn(net, 'g_h1_res/batch_norm'))
            net = Conv2d(net,gf_dim*2,  (3,3), (1,1), name='g_h1_res/conv2d2', padding='SAME')
            net = relu(self._vbn(net, 'g_h1_res/batch_norm2'))
            net = Conv2d(net, gf_dim*8,(3,3),  (1,1), name='g_h1_res/conv2d3', padding='SAME')
            net = self._vbn(net, 'g_h1_res/batch_norm3')

            net_h1 = tf.add_n([net_h0, net], name='g_h1_res/add')
            net_h1_output = relu(net_h1)

            net_h2 = UpSample(net_h1_output, size=[s8, s8], method=1, align_corners=False, name='g_h2/upsample2d')
            net_h2 = Conv2d(net_h2, gf_dim*4,(3,3), (1,1), name='g_h2/conv2d', padding='SAME')
            net_h2 = relu(self._vbn(net_h2, 'g_h2/batch_norm'))

            # Second residual block.
            net = Conv2d(net_h2,gf_dim, (1,1),  (1,1), name='g_h3_res/conv2d')
            net = relu(self._vbn(net, 'g_h3_res/batch_norm'))
            net = Conv2d(net,gf_dim, (3,3), (1,1), name='g_h3_res/conv2d2', padding='SAME')
            net = relu(self._vbn(net, 'g_h3_res/batch_norm2'))
            net = Conv2d(net,gf_dim*4,(3,3), (1,1), name='g_h3_res/conv2d3', padding='SAME')
            net = self._vbn(net, 'g_h3_res/batch_norm3')

            net_h3 = tf.add_n([net_h2, net], name='g_h3/add')
            net_h3_outputs = relu(net_h3)

            net_h4 = UpSample(net_h3_outputs, size=[s4, s4], method=1, align_corners=False, name='g_h4/upsample2d')
            net_h4 = Conv2d(net_h4,gf_dim*2, (3,3), (1,1) , name='g_h4/conv2d', padding='SAME')
            net_h4 = relu(self._vbn(net_h4, 'g_h4/batch_norm'))

            net_h5 = UpSample(net_h4, size=[s2, s2], method=1, align_corners=False, name='g_h5/upsample2d')
            net_h5 = Conv2d(net_h5, gf_dim, (3,3), (1,1), name='g_h5/conv2d', padding='SAME')
            net_h5 = relu(self._vbn(net_h5, 'g_h5/batch_norm'))

            net_ho = UpSample(net_h5, size=[s, s], method=1, align_corners=False, name='g_ho/upsample2d')
            net_ho = Conv2d(net_ho,c_dim,(3,3),  (1,1), name='g_ho/conv2d', padding='SAME', b_init=True) ## b_init = True

            self.logits = net_ho
            # Images are produced in the tanh range [-1, 1].
            self.out = tf.nn.tanh(self.logits)
            


class Discriminator:
    """Text-conditioned discriminator.

    Downsamples the image with four stride-2 convolutions, applies a residual
    block, concatenates a spatially tiled projection of the text code, and
    reduces the result to a single logit per example.

    Args:
        input_image: image tensor to judge.
        input_txt: text code the image is judged against.
        is_training: forwarded to every BatchNormLayer.
        reuse: passed to the "discriminator" variable scope.

    Attributes:
        logits: output of the final convolution.
        out: ``sigmoid(logits)``.
    """

    def __init__(self, input_image, input_txt, is_training, reuse):
        self.input_image = input_image
        self.input_txt = input_txt
        self.is_training = is_training
        self.reuse = reuse
        # Fixed architecture sizes (this class does not read the option dict).
        self.df_dim = 64
        self.t_dim = 128
        self.image_size = 64
        self.build_model()

    def build_model(self):
        """Build the discriminator graph and set ``self.logits`` / ``self.out``."""
        s = self.image_size
        # Only s16 is needed: the side of the feature map after four stride-2 convs.
        s16 = s // 16

        df_dim = self.df_dim
        t_dim = self.t_dim

        with tf.variable_scope("discriminator", reuse=self.reuse):
            net_h0 = Conv2d(self.input_image, df_dim,(4,4),(2, 2), name='d_h0/conv2d', act=tf.nn.leaky_relu, padding='SAME', b_init=True)

            net_h1 = Conv2d(net_h0,df_dim*2, (4,4), (2, 2), name='d_h1/conv2d', padding='SAME')
            net_h1 = BatchNormLayer(net_h1, act=tf.nn.leaky_relu, is_training=self.is_training, name='d_h1/batchnorm')

            net_h2 = Conv2d(net_h1,df_dim*4, (4,4), (2, 2), name='d_h2/conv2d', padding='SAME')
            net_h2 = BatchNormLayer(net_h2, act=tf.nn.leaky_relu, is_training=self.is_training, name='d_h2/batchnorm')

            net_h3 = Conv2d(net_h2,df_dim*8, (4,4), (2, 2), name='d_h3/conv2d', padding='SAME')
            net_h3 = BatchNormLayer(net_h3, act=None, is_training=self.is_training, name='d_h3/batchnorm')

            # Residual block on top of the downsampled features.
            net = Conv2d(net_h3,df_dim*2,(1, 1),  (1, 1), name='d_h4_res/conv2d')
            net = BatchNormLayer(net, act=tf.nn.leaky_relu, is_training=self.is_training, name='d_h4_res/batchnorm')
            net = Conv2d(net, df_dim*2,(3, 3),  (1, 1), name='d_h4_res/conv2d2', padding='SAME')
            net = BatchNormLayer(net, act=tf.nn.leaky_relu, is_training=self.is_training, name='d_h4_res/batchnorm2')
            net = Conv2d(net,df_dim*8, (3, 3),  (1, 1), name='d_h4_res/conv2d3', padding='SAME')
            net = BatchNormLayer(net, act=None, is_training=self.is_training, name='d_h4_res/batchnorm3')

            net_h4 = tf.add_n([net_h3, net], name='d_h4/add')
            net_h4_outputs = tf.nn.leaky_relu(net_h4)

            # Project the text code and broadcast it over the spatial grid so
            # it can be concatenated with the image features. The tile size
            # follows s16 (4 for 64-pixel input) instead of a literal 4, so it
            # stays in step with the final s16 x s16 convolution.
            net_txt = DenseLayer(self.input_txt, n_units=t_dim, act=tf.nn.leaky_relu, name='d_reduce_txt/dense')
            net_txt = tf.expand_dims(net_txt, axis=1, name='d_txt/expanddim1')
            net_txt = tf.expand_dims(net_txt, axis=1, name='d_txt/expanddim2')
            net_txt = tf.tile(net_txt, [1, s16, s16, 1], name='d_txt/tile')

            net_h4_concat = ConcatLayer([net_h4_outputs, net_txt], concat_dim=3, name='d_h3_concat')

            net_h4 = Conv2d(net_h4_concat, df_dim*8,(1, 1),  (1, 1), name='d_h3/conv2d_2')
            net_h4 = BatchNormLayer(net_h4, act=tf.nn.leaky_relu, is_training=self.is_training, name='d_h3/batch_norm_2')

            # Kernel and stride both equal the map size: one logit per example.
            net_ho = Conv2d(net_h4, 1, (s16, s16),  (s16, s16), name='d_ho/conv2d', b_init=True) # b_init = True
            self.logits = net_ho
            self.out = tf.nn.sigmoid(net_ho)



#Modified from course slide
class TextEncoder:
    """Encodes a batch of word-id sequences into one vector per sequence.

    Word ids are looked up in a trainable embedding matrix and run through a
    single LSTM; ``self.out`` is the LSTM output at the last time step.

    Args:
        input_seqs: integer tensor of word ids.
        option: dict providing 't_dim', 'vocab_size' and 'batch_size'.
        is_training: stored on the instance; not used when building the graph.
        reuse: passed to the "rnnftxt" variable scope and to the LSTM cell.
    """

    def __init__(self, input_seqs,option, is_training, reuse):
        self.input_seqs = input_seqs

        self.t_dim = option['t_dim']                # number of LSTM units
        self.rnn_hidden_size = 128
        self.vocab_size = option['vocab_size']      # rows of the embedding matrix
        self.word_embedding_size = 256              # columns of the embedding matrix
        self.batch_size = option['batch_size']
        self.is_training = is_training
        self.reuse = reuse
        self.build_model()

    def build_model(self):
        """Create the embedding and LSTM ops and set ``self.out``."""
        embed_init = tf.random_normal_initializer(stddev=0.02)

        with tf.variable_scope("rnnftxt", reuse=self.reuse):
            embedding_table = tf.get_variable(
                'rnn/wordembed',
                shape=(self.vocab_size, self.word_embedding_size),
                initializer=embed_init,
                dtype=tf.float32)
            token_vectors = tf.nn.embedding_lookup(embedding_table, self.input_seqs)

            # RNN encoder, started from an all-zero state.
            cell = tf.contrib.rnn.BasicLSTMCell(self.t_dim, reuse=self.reuse)
            start_state = cell.zero_state(self.batch_size, dtype=tf.float32)
            step_outputs, _final_state = tf.nn.dynamic_rnn(
                cell=cell,
                inputs=token_vectors,
                initial_state=start_state,
                dtype=np.float32,
                time_major=False,
                scope='rnn/dynamic')
        # Batch-major outputs: keep only the last time step of every sequence.
        self.out = step_outputs[:, -1, :]

class cnn_encoder:
    """Image encoder producing a ``t_dim``-wide embedding in ``self.out``.

    Four stride-2 convolutions (the last three followed by batch norm with
    leaky ReLU) are flattened and projected by a dense layer.

    Args:
        inputs: image tensor.
        option: dict providing 'df_dim' and 't_dim'.
        is_training: forwarded to every BatchNormLayer.
        reuse: passed to the 'cnnftxt' variable scope.
    """

    def __init__(self, inputs,option,is_training=True, reuse=False):
        self.inputs = inputs
        self.df_dim = option['df_dim']
        self.t_dim = option['t_dim']
        self.is_training = is_training
        self.reuse = reuse
        self.build_model()

    def build_model(self):
        """Build the encoder graph and set ``self.out``."""
        base = self.df_dim

        # (channel multiplier, conv layer name, batch-norm layer name)
        stages = (
            (2, 'cnnf/h1/conv2d', 'cnnf/h1/batch_norm'),
            (4, 'cnnf/h2/conv2d', 'cnnf/h2/batch_norm'),
            (8, 'cnnf/h3/conv2d', 'cnnf/h3/batch_norm'),
        )

        with tf.variable_scope('cnnftxt', reuse=self.reuse):
            # The first stage has an activation and bias but no batch norm.
            feat = Conv2d(self.inputs, base, (4,4), (2,2), name='cnnf/h0/conv2d', act=tf.nn.leaky_relu, padding='SAME', b_init=True)

            for width_mult, conv_name, bn_name in stages:
                feat = Conv2d(feat, base*width_mult, (4,4), (2,2), name=conv_name, padding='SAME')
                feat = BatchNormLayer(feat, act=tf.nn.leaky_relu, is_training=self.is_training, name=bn_name)

            feat = flatten(feat, name='cnnf/h4/flatten')
            embedding = DenseLayer(feat, n_units=self.t_dim, name='cnnf/h4/embed', b_init=False)

        self.out = embedding