# <center> Learned primal-dual on Wavelet coefficients</center>

In [1]:
import os
import adler

# Index of the GPU handed to adler's `setup_one_gpu` (the cell output reports
# "Picking GPU 1"); change this constant to train on a different device.
GPU_INDEX = 1
adler.util.gpu.setup_one_gpu(GPU_INDEX)

Picking GPU 1


In [2]:
from adler.odl.phantom import random_phantom
from adler.tensorflow import prelu, cosine_decay

In [3]:
import tensorflow as tf
import numpy as np
import odl
import odl.contrib.tensorflow

In [27]:
np.random.seed(0)
# Base path for checkpoints: a `checkpoints` entry in the current working
# directory. (`__file__` is typically undefined in a notebook kernel, and
# applying `os.path.splitext` to a directory path would wrongly strip any
# '.suffix' from its last component, e.g. 'exp.v2' -> 'exp'.)
name = os.path.join(os.getcwd(), 'checkpoints')

In [5]:
# Creating checkpoints directory (with any missing parents). An existing
# directory is reported rather than treated as a failure, so re-running the
# cell (or the whole notebook) does not print a spurious error.
if os.path.isdir(name):
    print ("Directory %s already exists" % name)
else:
    try:
        os.makedirs(name)
    except OSError as err:
        # Best effort: report the actual reason and keep going.
        print ("Creation of the directory %s failed: %s" % (name, err))
    else:
        print ("Successfully created the directory %s " % name)

Creation of the directory /store/kepler/datastore/andrade/GitHub_repos/tfshearlab/Paper_experiments/Python/checkpoints failed


In [7]:
# TF1 interactive session; later cells initialise variables and run the
# training / evaluation ops through `sess.run(...)`.
sess = tf.InteractiveSession()

In [8]:
# Display the checkpoint path adler derives from `name` (the value is not stored).
# NOTE(review): the saved output below does not match the `name` assigned
# earlier in this notebook; it appears to come from a previous run. Re-run to refresh.
adler.tensorflow.util.default_checkpoint_path(name)

'/store/kepler/datastore/andrade/GitHub_repos/tfshearlab/Paper_experiments/Python_wavelets.ckpt'

**Define the image space**

In [9]:
# Create ODL data structures: the image (reconstruction) space is a
# size x size grid over [-64, 64] x [-64, 64], stored as float32.
size = 128
space_image = odl.uniform_discr(min_pt=[-64, -64],
                                max_pt=[64, 64],
                                shape=[size, size],
                                dtype='float32')

In [10]:
import odl.trafos.wavelet as wavelet

**Wavelet analysis and synthesis operator**

In [11]:
# Wavelet analysis operator: 4-level 'db2' wavelet decomposition of images in space_image
wavean = wavelet.WaveletTransform(space_image, wavelet='db2', nlevels=4)

In [12]:
# Wavelet synthesis operator: the inverse of `wavean`, mapping wavelet
# coefficients back to images (its range is used as the image space below)
wavesyn = wavean.inverse

In [13]:
# Wavelet coefficients space: domain of the synthesis operator. Only
# `space.shape[0]` is used below, i.e. the coefficients are handled as a
# single 1-D vector per sample.
space = wavesyn.domain

**Ray operator**

In [14]:
# Parallel-beam geometry with 30 projection angles over the image space
# (wavesyn.range), and the corresponding ray transform.
# NOTE(review): no `impl` is passed to RayTransform, so the backend is
# presumably ODL's default choice rather than necessarily scikit-image; pass
# impl='skimage' explicitly if that backend is required.
geometry = odl.tomo.parallel_beam_geometry(wavesyn.range, num_angles=30)
ray_operator= odl.tomo.RayTransform(wavesyn.range, geometry)

**Forward operator = Ray(Wavesyn)**: the ray transform applied to the wavelet synthesis of the coefficients

In [15]:
# Forward operator: the composition ray_operator o wavesyn, mapping wavelet
# coefficients c to projection data ray_operator(wavesyn(c)).
operator = ray_operator * wavesyn

In [16]:
# Ensure operator has fixed operator norm for scale invariance:
# estimate ||operator|| with ODL's power method and divide it out, so the
# rescaled operator has norm of approximately 1.
opnorm = odl.power_method_opnorm(operator)
operator = (1 / opnorm) * operator

In [17]:
# Wrap the normalised forward operator and its adjoint as TensorFlow layers so
# the unrolled network below can call them on tensors.
odl_op_layer = odl.contrib.tensorflow.as_tensorflow_layer(
    operator, name='RayWaveTransform')
odl_op_layer_adjoint = odl.contrib.tensorflow.as_tensorflow_layer(
    operator.adjoint, name='RayWaveTransformAdjoint')

In [18]:
# User-selected parameters:
#   n_data   -- number of training samples generated per batch
#   n_iter   -- number of unrolled primal-dual iterations
#   n_primal -- number of channels in the primal memory
#   n_dual   -- number of channels in the dual memory
n_data, n_iter, n_primal, n_dual = 5, 10, 5, 5

In [19]:
def generate_data(validation=False, noise_level=0.05):
    """Generate a set of random data.

    Each sample is a phantom expressed in wavelet coefficients (``wavean`` of
    an image in ``space_image``) together with its noisy projection data.

    Parameters
    ----------
    validation : bool, optional
        If True, generate a single sample from ``odl.phantom.shepp_logan``;
        otherwise generate ``n_data`` samples from ``random_phantom``.
    noise_level : float, optional
        Amplitude of the additive white noise relative to the mean absolute
        value of the noise-free data. The default 0.05 reproduces the
        previously hard-coded 5 % noise.

    Returns
    -------
    y_arr : numpy.ndarray
        float32, shape ``(n, operator.range.shape[0], operator.range.shape[1], 1)``;
        noisy projection data.
    x_true_arr : numpy.ndarray
        float32, shape ``(n, space.shape[0], 1)``; ground-truth wavelet
        coefficients.
    """
    n_generate = 1 if validation else n_data

    y_arr = np.empty((n_generate, operator.range.shape[0], operator.range.shape[1], 1), dtype='float32')
    x_true_arr = np.empty((n_generate, space.shape[0], 1), dtype='float32')

    for i in range(n_generate):
        if validation:
            phantom = wavean(odl.phantom.shepp_logan(space_image, True))
        else:
            phantom = wavean(random_phantom(space_image))

        data = operator(phantom)
        # White noise scaled to the signal level, so its amplitude is a fixed
        # fraction (noise_level) of the mean |data|.
        noisy_data = data + odl.phantom.white_noise(operator.range) * np.mean(np.abs(data)) * noise_level

        x_true_arr[i, ..., 0] = phantom
        y_arr[i, ..., 0] = noisy_data

    return y_arr, x_true_arr

In [21]:
with tf.name_scope('placeholders'):
    # Ground-truth wavelet coefficients: (batch, wavean.range.shape[0], 1)
    x_true = tf.placeholder(tf.float32, shape=[None, wavean.range.shape[0],1], name="x_true")
    # Noisy projection data: (batch, operator.range.shape[0], operator.range.shape[1], 1)
    y_rt = tf.placeholder(tf.float32, shape=[None, operator.range.shape[0], operator.range.shape[1], 1], name="y_rt")
    # Scalar train/eval flag. NOTE(review): the training loop feeds it, but no
    # op built in this notebook reads it.
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

In [22]:
def apply_conv_primal(x, filters=32):
    """Apply one 1-D convolution layer of the primal (coefficient-domain) CNN.

    Uses kernel size 3, 'SAME' padding and Xavier-initialised weights. The
    output has `filters` channels (32 by default, `n_primal` for the last
    layer of each primal update).
    """
    layer_options = {
        'kernel_size': 3,
        'padding': 'SAME',
        'kernel_initializer': tf.contrib.layers.xavier_initializer(),
    }
    return tf.layers.conv1d(x, filters=filters, **layer_options)

In [23]:
def apply_conv_dual(x, filters=32):
    """Apply one 2-D convolution layer of the dual (projection-data) CNN.

    Uses a 3x3 kernel, 'SAME' padding and Xavier-initialised weights. The
    output has `filters` channels (32 by default, `n_dual` for the last layer
    of each dual update).
    """
    layer_options = {
        'kernel_size': 3,
        'padding': 'SAME',
        'kernel_initializer': tf.contrib.layers.xavier_initializer(),
    }
    return tf.layers.conv2d(x, filters=filters, **layer_options)

In [24]:
with tf.name_scope('tomography'):
    # Unrolled learned primal-dual network: n_iter alternating dual/primal updates.
    with tf.name_scope('initial_values'):
        # Zero-initialised memories: n_primal channels shaped like x_true
        # (wavelet coefficients), n_dual channels shaped like y_rt (projection data).
        primal = tf.concat([tf.zeros_like(x_true)] * n_primal, axis=-1)
        dual = tf.concat([tf.zeros_like(y_rt)] * n_dual, axis=-1)

    for i in range(n_iter):
        # Each iteration builds its layers in its own variable scope, so the
        # convolution weights are separate per iteration.
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            # Forward-project primal channel 1 into the data domain.
            evalop = odl_op_layer(primal[..., 1:2])
            # Dual CNN input: current dual memory, projected primal, measured data.
            update = tf.concat([dual, evalop, y_rt], axis=-1)

            update = prelu(apply_conv_dual(update), name='prelu_1')
            update = prelu(apply_conv_dual(update), name='prelu_2')
            # Final layer maps back to n_dual channels; applied as a residual update.
            update = apply_conv_dual(update, filters=n_dual)
            dual = dual + update

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            # Back-project dual channel 0 into the coefficient domain via the adjoint.
            evalop = odl_op_layer_adjoint(dual[..., 0:1])
            # Primal CNN input: current primal memory and back-projected dual.
            update = tf.concat([primal, evalop], axis=-1)

            update = prelu(apply_conv_primal(update), name='prelu_1')
            update = prelu(apply_conv_primal(update), name='prelu_2')
            # Final layer maps back to n_primal channels; applied as a residual update.
            update = apply_conv_primal(update, filters=n_primal)
            primal = primal + update

    # Network output: primal channel 0 (same shape as x_true).
    x_result = primal[..., 0:1]

In [26]:
# Display the current checkpoint base path.
# NOTE(review): the saved output below does not match the value assigned to
# `name` earlier in this notebook (it appears to be from a previous run).
name

'/store/kepler/datastore/andrade/GitHub_repos/tfshearlab/Paper_experiments/Python_wavelets'

In [25]:
with tf.name_scope('loss'):
    # Mean squared error, averaged over batch and coefficients, between the
    # network output and the ground truth. The loss is measured in the
    # wavelet-coefficient domain, not on images.
    residual = x_result - x_true
    squared_error = residual ** 2
    loss = tf.reduce_mean(squared_error)

In [36]:
with tf.name_scope('optimizer'):
    # Learning rate: adler's cosine_decay schedule starting at 1e-3 over
    # maximum_steps steps (maximum_steps also bounds the training loop).
    global_step = tf.Variable(0, trainable=False)
    maximum_steps = 100001
    starter_learning_rate = 1e-3
    learning_rate = cosine_decay(starter_learning_rate,
                                 global_step,
                                 maximum_steps,
                                 name='learning_rate')

    # Make the optimizer step depend on any ops registered in UPDATE_OPS.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_func = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta2=0.99)

        # Clip all gradients jointly to a global norm of 1 before applying them;
        # passing global_step makes each optimizer run advance the step counter.
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 1)
        optimizer = opt_func.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

In [37]:
# Experiment-specific location used for checkpoints and TensorBoard logs.
ckp_name = '{}_lpd_wavelets/checkpoints'.format(name)

In [38]:
# Display the TensorBoard log directory adler derives from ckp_name
# (display only; the summary writers are created from the same call later).
adler.tensorflow.util.default_tensorboard_dir(ckp_name)

'/store/kepler/datastore/andrade/GitHub_repos/tfshearlab/Paper_experiments/Python/checkpoints_lpd_wavelets/checkpoints'

In [39]:
# Summaries
# View with: tensorboard --logdir=<adler.tensorflow.util.default_tensorboard_dir(ckp_name)>

with tf.name_scope('summaries'):
    # Scalar curves. PSNR is -10 * log10(MSE), i.e. it assumes a peak value of 1.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('psnr', -10 * tf.log(loss) / tf.log(10.0))

    # No image summaries: x_result / x_true are rank-3
    # (batch, coefficients, 1) wavelet-coefficient tensors, whereas
    # tf.summary.image requires rank-4 (batch, height, width, channels) input.

    merged_summary = tf.summary.merge_all()
    # Separate writers for validation ('/test', which also stores the graph)
    # and training ('/train') curves.
    test_summary_writer = tf.summary.FileWriter(adler.tensorflow.util.default_tensorboard_dir(ckp_name) + '/test',
                                                sess.graph)
    train_summary_writer = tf.summary.FileWriter(adler.tensorflow.util.default_tensorboard_dir(ckp_name) + '/train')

In [40]:
# Initialize all TF variables created so far. This must run after the
# optimizer cell, which creates variables of its own (e.g. `global_step`).
sess.run(tf.global_variables_initializer())

In [41]:
# Add op to save and restore; used by the optional restore cell and by the
# periodic saves in the training loop.
saver = tf.train.Saver()

In [42]:
# Generate validation data: a fixed batch, generated once and reused for every
# evaluation in the training loop.
# NOTE(review): `validation=False` yields n_data random phantoms rather than
# the Shepp-Logan sample that generate_data's `validation=True` branch
# produces. Confirm this is intended.
y_arr_validate, x_true_arr_validate = generate_data(validation=False)

In [43]:
# Set to True to resume from the checkpoint the training loop saves at
# default_checkpoint_path(ckp_name), instead of the freshly initialised weights.
restore_checkpoint = False
if restore_checkpoint:
    saver.restore(sess,
                  adler.tensorflow.util.default_checkpoint_path(ckp_name))

Train the network

In [None]:
# Train the network
# Every 10 steps: draw a fresh training batch, and (after step 0) evaluate on
# the fixed validation batch and log train/test summaries. Every 1000 steps
# (after step 0): save a checkpoint.
for i in range(0, maximum_steps):
    if i%10 == 0:
        # New random training batch, reused for the next 10 optimizer steps.
        y_arr, x_true_arr = generate_data()

    # One optimizer step; also collects this step's training summaries.
    _, merged_summary_result_train, global_step_result = sess.run([optimizer, merged_summary, global_step],
                              feed_dict={x_true: x_true_arr,
                                         y_rt: y_arr,
                                         is_training: True})

    if i>0 and i%10 == 0:
        # Evaluate on the validation batch (the optimizer op is not run here).
        loss_result, merged_summary_result, global_step_result = sess.run([loss, merged_summary, global_step],
                              feed_dict={x_true: x_true_arr_validate,
                                         y_rt: y_arr_validate,
                                         is_training: False})

        # global_step was already advanced by the optimizer step above, which
        # is why the printed iteration numbers are i + 1 (iter=11, 21, ...).
        train_summary_writer.add_summary(merged_summary_result_train, global_step_result)
        test_summary_writer.add_summary(merged_summary_result, global_step_result)

        print('iter={}, loss={}'.format(global_step_result, loss_result))

    if i>0 and i%1000 == 0:
        # Always written to the same checkpoint path (no global_step suffix is passed).
        saver.save(sess,
                   adler.tensorflow.util.default_checkpoint_path(ckp_name))


iter=11, loss=0.009012219496071339
iter=21, loss=0.007784486748278141
iter=31, loss=0.0048680491745471954
iter=41, loss=0.005165236536413431
iter=51, loss=0.003361302660778165
iter=61, loss=0.0058276657946407795
iter=71, loss=0.004375253338366747
iter=81, loss=0.005451131146401167
iter=91, loss=0.002762345364317298
iter=101, loss=0.0021725408732891083
iter=111, loss=0.002000361680984497
iter=121, loss=0.0019093636656180024
iter=131, loss=0.0033185307402163744
iter=141, loss=0.00477041769772768
iter=151, loss=0.002937125042080879
iter=161, loss=0.0023987104650586843
iter=171, loss=0.002112820278853178
iter=181, loss=0.0020153429359197617
iter=191, loss=0.0021713459864258766
iter=201, loss=0.0017834482714533806
iter=211, loss=0.0016926759853959084
iter=221, loss=0.0016266816528514028
iter=231, loss=0.0016025702934712172
iter=241, loss=0.0015950504457578063
iter=251, loss=0.001563190366141498
iter=261, loss=0.0015857559628784657
iter=271, loss=0.001562768709845841
iter=281, loss=0.0015591

iter=2221, loss=0.001180921564809978
iter=2231, loss=0.0011620448203757405
iter=2241, loss=0.0011366549879312515
iter=2251, loss=0.0011058036470785737
iter=2261, loss=0.0011348769767209888
iter=2271, loss=0.001129043404944241
iter=2281, loss=0.001151464763097465
iter=2291, loss=0.0011433000909164548
iter=2301, loss=0.0011312979040667415
iter=2311, loss=0.0011390235740691423
iter=2321, loss=0.0011149703059345484
iter=2331, loss=0.0011134323431178927
iter=2341, loss=0.001100356923416257
iter=2351, loss=0.0011077163508161902
iter=2361, loss=0.0010979571379721165
iter=2371, loss=0.00110327557194978
iter=2381, loss=0.0011244240449741483
iter=2391, loss=0.001097469124943018
iter=2401, loss=0.0011131883366033435
iter=2411, loss=0.0011080725817009807
iter=2421, loss=0.0010907790856435895
iter=2431, loss=0.0010823693592101336
iter=2441, loss=0.0010708858026191592
iter=2451, loss=0.00108527357224375
iter=2461, loss=0.001209918293170631
iter=2471, loss=0.0011520334519445896
iter=2481, loss=0.0010

iter=4391, loss=0.0009383125579915941
iter=4401, loss=0.0009336162474937737
iter=4411, loss=0.0009275730117224157
iter=4421, loss=0.0009302960825152695
iter=4431, loss=0.0009326639701612294
iter=4441, loss=0.0009484213660471141
iter=4451, loss=0.0009853760711848736
iter=4461, loss=0.0009473816026002169
iter=4471, loss=0.0009392376523464918
iter=4481, loss=0.0009321593679487705
iter=4491, loss=0.0009214460151270032
iter=4501, loss=0.0009308572043664753
iter=4511, loss=0.0009252945310436189
iter=4521, loss=0.0009614664595574141
iter=4531, loss=0.0009421825525350869
iter=4541, loss=0.0009394190274178982
iter=4551, loss=0.000934218056499958
iter=4561, loss=0.0009170068078674376
iter=4571, loss=0.000924690393730998
iter=4581, loss=0.0009660552605055273
iter=4591, loss=0.0009662149241194129
iter=4601, loss=0.0009518215665593743
iter=4611, loss=0.0009408151963725686
iter=4621, loss=0.0009327973239123821
iter=4631, loss=0.0009418928530067205
iter=4641, loss=0.0009253250900655985
iter=4651, los

iter=6561, loss=0.000867366383317858
iter=6571, loss=0.000914059579372406
iter=6581, loss=0.0009084202465601265
iter=6591, loss=0.0009297551587224007
iter=6601, loss=0.0009250775910913944
iter=6611, loss=0.0008847826393321157
iter=6621, loss=0.0008972108480520546
iter=6631, loss=0.0008931909687817097
iter=6641, loss=0.0008957553072832525
iter=6651, loss=0.0008738723117858171
iter=6661, loss=0.0008743811631575227
iter=6671, loss=0.0008557260734960437
iter=6681, loss=0.0008647188660688698
iter=6691, loss=0.0008832389721646905
iter=6701, loss=0.0008873352780938148
iter=6711, loss=0.0008630522643215954
iter=6721, loss=0.0008655813289806247
iter=6731, loss=0.0008672562544234097
iter=6741, loss=0.0008524828008376062
iter=6751, loss=0.0008821265655569732
iter=6761, loss=0.0009728210861794651
iter=6771, loss=0.0009123618365265429
iter=6781, loss=0.0008907938026823103
iter=6791, loss=0.0008977290708571672
iter=6801, loss=0.0009036862757056952
iter=6811, loss=0.0008816597401164472
iter=6821, los

iter=8731, loss=0.0008486213046126068
iter=8741, loss=0.0008791601867415011
iter=8751, loss=0.00083679094677791
iter=8761, loss=0.0008485678117722273
iter=8771, loss=0.0008430502493865788
iter=8781, loss=0.0008535178494639695
iter=8791, loss=0.0008711675181984901
iter=8801, loss=0.0008635572157800198
iter=8811, loss=0.0008510082261636853
iter=8821, loss=0.000877299637068063
iter=8831, loss=0.0008700986509211361
iter=8841, loss=0.0008422602550126612
iter=8851, loss=0.0008410954615101218
iter=8861, loss=0.0008289800025522709
iter=8871, loss=0.0008201810996979475
iter=8881, loss=0.0008214744157157838
iter=8891, loss=0.0008443697006441653
iter=8901, loss=0.0008612367091700435
iter=8911, loss=0.0008402333478443325
iter=8921, loss=0.000835552578791976
iter=8931, loss=0.0008327826508320868
iter=8941, loss=0.0008705832879059017
iter=8951, loss=0.0008405352709814906
iter=8961, loss=0.0008508640457876027
iter=8971, loss=0.0008658508886583149
iter=8981, loss=0.0008465969003736973
iter=8991, loss=

iter=10881, loss=0.0008131852373480797
iter=10891, loss=0.000806022493634373
iter=10901, loss=0.0008122195722535253
iter=10911, loss=0.0008113718940876424
iter=10921, loss=0.0008155520190484822
iter=10931, loss=0.0008298109169118106
iter=10941, loss=0.0008319426560774446
iter=10951, loss=0.0008448105654679239
iter=10961, loss=0.0008273685816675425
iter=10971, loss=0.0008372055599465966
iter=10981, loss=0.0008169793291017413
iter=10991, loss=0.000836427672766149
iter=11001, loss=0.000826552277430892
iter=11011, loss=0.000811595527920872
iter=11021, loss=0.0008240563329309225
iter=11031, loss=0.0008148625493049622
iter=11041, loss=0.0008128414046950638
iter=11051, loss=0.0008331610588356853
iter=11061, loss=0.0008787341066636145
iter=11071, loss=0.0008629215881228447
iter=11081, loss=0.0008413646719418466
iter=11091, loss=0.0009020353900268674
iter=11101, loss=0.0009976248256862164
iter=11111, loss=0.0008891608449630439
iter=11121, loss=0.0008593311067670584
iter=11131, loss=0.0008431523

iter=12991, loss=0.0008065822185017169
iter=13001, loss=0.0007898975745774806
iter=13011, loss=0.0007959124050103128
iter=13021, loss=0.0007956057088449597
iter=13031, loss=0.0008222225005738437
iter=13041, loss=0.0007925115642137825
iter=13051, loss=0.0007952918531373143
iter=13061, loss=0.0007850166293792427
iter=13071, loss=0.0007822208572179079
iter=13081, loss=0.0007901015924289823
iter=13091, loss=0.0008157793781720102
iter=13101, loss=0.0008315154118463397
iter=13111, loss=0.0008252106490544975
iter=13121, loss=0.0008349831332452595
iter=13131, loss=0.0008052530465647578
iter=13141, loss=0.0007891962304711342
iter=13151, loss=0.0007732189842499793
iter=13161, loss=0.0007764468900859356
iter=13171, loss=0.0007762283203192055
iter=13181, loss=0.000786800286732614
iter=13191, loss=0.0007989909499883652
iter=13201, loss=0.0007843615021556616
iter=13211, loss=0.0007864997605793178
iter=13221, loss=0.0007900320342741907
iter=13231, loss=0.0008342006476595998
iter=13241, loss=0.0008325

iter=15101, loss=0.0008055586367845535
iter=15111, loss=0.0008003385737538338
iter=15121, loss=0.0007863836362957954
iter=15131, loss=0.0007824157364666462
iter=15141, loss=0.0008218264556489885
iter=15151, loss=0.0008291727863252163
iter=15161, loss=0.00079112418461591
iter=15171, loss=0.0007844812935218215
iter=15181, loss=0.0007927019032649696
iter=15191, loss=0.0007674607331864536
iter=15201, loss=0.0007633325294591486
iter=15211, loss=0.0007749506621621549
iter=15221, loss=0.0007888941327109933
iter=15231, loss=0.0007845595246180892
iter=15241, loss=0.0007847375236451626
iter=15251, loss=0.0008203416946344078
iter=15261, loss=0.0008096028468571603
iter=15271, loss=0.0007893478032201529
iter=15281, loss=0.0007884580991230905
iter=15291, loss=0.0007742120651528239
iter=15301, loss=0.0007787045906297863
iter=15311, loss=0.0007726590847596526
iter=15321, loss=0.0008008487056940794
iter=15331, loss=0.0007947767153382301
iter=15341, loss=0.0008185760816559196
iter=15351, loss=0.00081959

iter=17211, loss=0.0008650504169054329
iter=17221, loss=0.0008735796436667442
iter=17231, loss=0.0008761590579524636
iter=17241, loss=0.0008602819871157408
iter=17251, loss=0.0008451910689473152
iter=17261, loss=0.0008513160282745957
iter=17271, loss=0.0008269919781014323
iter=17281, loss=0.0008128384361043572
iter=17291, loss=0.0008184361504390836
iter=17301, loss=0.0008133540977723897
iter=17311, loss=0.0008059279643930495
iter=17321, loss=0.0007961473311297596
iter=17331, loss=0.0008027300937101245
iter=17341, loss=0.0007924579549580812
iter=17351, loss=0.0007895621238276362
iter=17361, loss=0.0008018008666113019
iter=17371, loss=0.0008143652812577784
iter=17381, loss=0.0008200393640436232
iter=17391, loss=0.0007910377462394536
iter=17401, loss=0.0007871302077546716
iter=17411, loss=0.0008240158203989267
iter=17421, loss=0.0008270171238109469
iter=17431, loss=0.0008159621502272785
iter=17441, loss=0.0007938950438983738
iter=17451, loss=0.0007986229611560702
iter=17461, loss=0.000799