In [None]:
# Make the project-local modules (presumably generate_data, operators and
# optimization, imported below) importable from ./src.
import sys
sys.path.insert(0, './src')
import os
# Expose only GPU 1 to TensorFlow; this must be set before TensorFlow is
# imported/initialized to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Fixed seeds: NumPy's global RNG and TensorFlow's graph-level seed.
import numpy as np
np.random.seed(42)
import tensorflow as tf
tf.set_random_seed(42)
import odl
from odl.contrib.tensorflow import as_tensorflow_layer
from odl.contrib import fom
import matplotlib.pyplot as plt

import datetime
import time

from generate_data import generate_data
from operators import operators_smooth, operators_nonsmooth
from optimization import optimize

# Start a tensorflow session. InteractiveSession installs itself as the default
# session (so `.run()` works without an explicit session) and may claim up to
# 99% of the visible GPU's memory.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.99)
session = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

In [None]:
# Problem type: "smooth" or "nonsmooth".
problem = "smooth"

# Algorithm to run. Options by problem type:
#   smooth:    "learned_smooth", "nesterov", "steep_desc"
#   nonsmooth: "learned_nonsmooth", "ista", "fista"
algorithm = "learned_smooth"

# Checkpoint to restore; "" keeps the freshly initialized variables.
param_filename = ""

# Batch size, number of optimization iterations, and the val_ratio passed to
# generate_data (presumably the fraction of data held out for validation).
batch_size, n_iter, val_ratio = 1, 10, 0.01

In [None]:
# Create the operators for the chosen problem. T_odl maps the reconstruction
# space (T_odl.domain) to the data space (T_odl.range); W_odl is passed through
# to optimize().
if problem == "smooth":
    T_odl, W_odl = operators_smooth()
elif problem == "nonsmooth":
    T_odl, W_odl = operators_nonsmooth()
else:
    # Fail here instead of with a NameError on T_odl several cells later.
    raise ValueError('problem must be "smooth" or "nonsmooth", got {!r}'.format(problem))

In [None]:
# Draw one test batch and show its first ground-truth image (x_test) and the
# corresponding data (y_test) that the reconstruction is computed from.
x_test, y_test = next(generate_data(T_odl, 'test', batch_size))

plt.imshow(x_test[0].squeeze(), cmap='gray')
plt.title('x_test[0] (ground truth)')
plt.show()

plt.imshow(y_test[0].squeeze(), cmap='gray')
plt.title('y_test[0] (data)')
plt.show()

In [None]:
# Placeholder shapes are (batch, rows, cols, 1): the first two axes of the
# operator's domain for x and of its range for y.
x_shape = (batch_size,) + (T_odl.domain.shape[0], T_odl.domain.shape[1]) + (1,)
y_shape = (batch_size,) + (T_odl.range.shape[0], T_odl.range.shape[1]) + (1,)
x, y = (tf.placeholder(tf.float32, shape) for shape in (x_shape, y_shape))

### Testing

In [None]:
# Load the model: build the graph for a fixed number of iterations (n_iter),
# initialize every variable, then optionally overwrite them from a checkpoint.
reconstruction, loss, loss_parts = optimize(algorithm, x, y, T_odl, W_odl, n_iter)
tf.global_variables_initializer().run(session = session)
if param_filename != "":
    # The Adam optimizer and `train` op are not used by the testing cells.
    # Building them adds Adam's variables to the graph after the initializer
    # has run, so they are only given values by the restore below (the Saver
    # is created afterwards and covers all global variables).
    # NOTE(review): presumably this mirrors the variable set saved by the
    # training cell; confirm it is actually required for restore.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train = optimizer.minimize(loss)
    saver = tf.train.Saver()
    saver.restore(session, param_filename)

In [None]:
# Reconstruct the single test sample (the x placeholder is fed zeros, y gets
# y_test) and time that one session.run.
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
# NOTE(review): the first run of a freshly built graph may include one-time
# setup cost, which inflates the timing below.
start = time.perf_counter()
reconstructed, final_loss, it_loss = session.run(
    [reconstruction, loss, loss_parts],
    feed_dict={x: np.zeros(x_shape), y: y_test})
elapsed_time = time.perf_counter() - start

plt.figure(figsize=(10, 10))
plt.imshow(x_test[0].squeeze(), clim=[0.8, 1.2], cmap='bone')
plt.title('Ground truth')
plt.figure(figsize=(10, 10))
plt.imshow(reconstructed[0].squeeze(), clim=[0.8, 1.2], cmap='bone')
plt.title('Reconstruction ({}, {} iterations)'.format(algorithm, n_iter))
print("Final loss ", final_loss)
# Average is per optimization iteration of this single run.
print("Elapsed time ", elapsed_time, ", average time ", elapsed_time / n_iter)
print("PSNR ", fom.psnr(reconstructed, x_test))
print("PSNR ", fom.psnr(reconstructed, x_test))

In [None]:
# Objective f + g per iteration: elementwise sum of the first two loss parts.
# Plot against 1-based iteration numbers, because the implicit x=0 of the first
# value cannot be placed on a log x-axis and would not be shown.
f_plus_g = it_loss[0] + it_loss[1]
plt.loglog(np.arange(1, len(f_plus_g) + 1), f_plus_g)
plt.xlabel('iteration')
plt.ylabel('f + g')
plt.title('Objective value ({})'.format(algorithm));

In [None]:
# Plot the optional third loss part, if optimize() returned one.
if len(it_loss) > 2:
    # NOTE(review): assumes it_loss[2] is a 1-D per-iteration array like
    # it_loss[0]. Use 1-based x values so the first entry shows on the log x-axis.
    extra_part = it_loss[2]
    plt.loglog(np.arange(1, len(extra_part) + 1), extra_part)
    plt.xlabel('iteration')
    plt.ylabel('it_loss[2]')

In [None]:
# Run the model on every test batch, print the first/last values of the first
# two loss parts, and write both loss curves to res_filename (two rows per batch).
generator_test = generate_data(T_odl, 'test', batch_size)
res_filename = './results/mayo/'+str(algorithm)+"_"+ str(n_iter)+"_opt_iter.dat"
# open() does not create missing parent directories.
os.makedirs(os.path.dirname(res_filename), exist_ok=True)

# Opening once in 'wb' mode yields the same contents as the old
# delete-then-append-per-batch scheme, and `with` closes the file even if a
# batch raises. Loop variables use fresh names so the x_test/y_test shown
# earlier are not overwritten.
with open(res_filename, 'wb') as res_file:
    for i, (_, y_batch) in enumerate(generator_test):
        reconstructed, final_loss, it_loss = session.run(
            [reconstruction, loss, loss_parts],
            feed_dict={x: np.zeros(x_shape), y: y_batch})
        print("Iteration ",i,"------------------------")
        print("Final loss ", final_loss)
        # [-1] is the last recorded value regardless of how many were recorded.
        print("Loss parts", it_loss[0][0], it_loss[0][-1], it_loss[1][0], it_loss[1][-1])
        np.savetxt(res_file, (it_loss[0], it_loss[1]))

### Training

In [None]:
# train
# Checkpoint prefix; checkpoints are saved as <prefix><step>_train_iter.ckpt.
res_filename = './models/mayo/'+str(algorithm)+"_"
n_train_steps = 100000   # last training step (the loop runs steps 0..n_train_steps)
val_every = 1000         # steps between validation passes
ckpt_every = 10000       # steps between checkpoints

# Sample the number of unrolled iterations uniformly from [n_iter, 2*n_iter).
# tf.random_uniform draws a new value on every session.run, including the
# validation runs below.
n_iter_rand = tf.constant(n_iter, dtype=tf.int32) + tf.random_uniform([],minval=0, maxval=n_iter,dtype=tf.int32)
# NOTE(review): if the testing cells above already ran, this adds a second copy
# of the model to the same default graph, and tf.train.Saver() below then saves
# the variables of both.
reconstruction, loss, loss_parts = optimize(algorithm, x, y, T_odl, W_odl, n_iter_rand)
# Make every training step also run the ops registered in UPDATE_OPS (if any).
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train = optimizer.minimize(loss)
tf.global_variables_initializer().run(session=session)
saver = tf.train.Saver()
if param_filename != "":
    saver.restore(session, param_filename)
# Saver.save does not create the checkpoint directory, so create it up front.
os.makedirs(os.path.dirname(res_filename), exist_ok=True)

# Summaries: one TensorBoard run directory per launch, with separate writers
# for training and validation ("test").
time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
summary_path = './log/' + time_str
os.makedirs(summary_path, exist_ok=True)
tf.summary.scalar('loss', loss)
summary = tf.summary.merge_all()
test_log = tf.summary.FileWriter(summary_path + '/test', session.graph)
train_log = tf.summary.FileWriter(summary_path + '/train')


def _validate(step):
    """Run one pass over the validation split and return its mean loss.

    Each batch's merged summary is written to ``test_log`` at ``step``.
    NOTE(review): that puts several events at the same step; consider logging
    only the mean instead.

    Raises:
        ValueError: if the validation split yields no batches.
    """
    total_loss = 0.0
    n_batches = 0
    for _, y_validate in generate_data(T_odl, 'validate', batch_size, val_ratio=val_ratio):
        loss_batch, summ_val = session.run(
            [loss, summary], feed_dict={x: np.zeros(x_shape), y: y_validate})
        total_loss += loss_batch
        test_log.add_summary(summ_val, step)
        n_batches += 1
    if n_batches == 0:
        raise ValueError("validation split is empty (val_ratio={})".format(val_ratio))
    return total_loss / n_batches


# Use the configured val_ratio (it was hard-coded as 0.01 here) so the
# train/validation split matches the config cell.
generator_train = generate_data(T_odl, 'train', batch_size, val_ratio=val_ratio)
try:
    for i in range(n_train_steps + 1):
        _, y_train = next(generator_train)
        _, summ = session.run([train, summary], feed_dict={x: np.zeros(x_shape), y: y_train})
        train_log.add_summary(summ, i)
        if i % val_every == 0:
            print('iter={}/{}'.format(i, n_train_steps))
            val_loss = _validate(i)
            print("Validation loss ", val_loss)
        if i % ckpt_every == 0:
            saver.save(session, res_filename + str(i) + "_train_iter.ckpt")
finally:
    # Flush pending events and release the event files, also on interrupt.
    train_log.close()
    test_log.close()