In [1]:
from ExoRIM.model import RIM, CostFunction
from ExoRIM.simulated_data import CenteredImagesv1
from preprocessing.simulate_data import create_and_save_data
import json
import numpy as np
import os
import pickle
import time
from datetime import datetime
import tensorflow as tf

In [2]:
# Project root = parent directory of the current working directory.
root = os.path.dirname(os.getcwd())
root

'/home/aadam/Desktop/Projects/ExoRIM'

In [21]:
# Run identifier. The cells below use it to name the data, test-data, model and
# results directories of this run.
# It is pinned to an existing run here; uncomment the next line instead to stamp
# a fresh run with the current date and time.
# NOTE: this name shadows the built-in id() for the rest of the session.
#id = datetime.now().strftime("%y-%m-%d_%H-%M-%S")
id = "20-06-21_10-24-04"
id

'20-06-21_10-24-04'

In [16]:
# Per-run directories, all keyed on the run identifier `id`.
data_dir = os.path.join(root, "data", id)              # training data
test_dir = os.path.join(root, "data", id + "_test")    # test data
checkpoint_dir = os.path.join(root, "models", id)      # model weights / checkpoints
output_dir = os.path.join(root, "results", id)         # training outputs and history

# os.makedirs(..., exist_ok=True) is idempotent and, unlike os.mkdir, also creates
# missing parents ("data", "models", "results"), so this cell works on a fresh checkout.
for run_dir in (data_dir, test_dir, checkpoint_dir, output_dir):
    os.makedirs(run_dir, exist_ok=True)

# Input directory holding the mask coordinates and projector arrays that are read
# further down; it is expected to exist already, so it is not created here.
projector_dir = os.path.join(root, "data", "projector_arrays")

In [17]:
# Load the model hyperparameters from the JSON file at the project root.
hparams_path = os.path.join(root, "hyperparameters.json")
with open(hparams_path, "r") as f:
    hparams = json.load(f)
hparams

{'steps': 12,
 'pixels': 32,
 'channels': 1,
 'state_size': 16,
 'state_depth': 32,
 'Regularizer Amplitude': {'kernel': 0.01, 'bias': 0.01},
 'Physical Model': {'Visibility Noise': 0.0001, 'Closure Phase Noise': 1e-05},
 'Downsampling Block': [{'Conv_Downsample': {'kernel_size': [5, 5],
    'filters': 1,
    'strides': [2, 2]}}],
 'Convolution Block': [{'Conv_1': {'kernel_size': [3, 3],
    'filters': 8,
    'strides': [1, 1]}},
  {'Conv_2': {'kernel_size': [3, 3], 'filters': 16, 'strides': [1, 1]}}],
 'Recurrent Block': {'GRU_1': {'kernel_size': [3, 3], 'filters': 16},
  'Hidden_Conv_1': {'kernel_size': [3, 3], 'filters': 16},
  'GRU_2': {'kernel_size': [3, 3], 'filters': 16}},
 'Upsampling Block': [{'Conv_Fraction_Stride': {'kernel_size': [3, 3],
    'filters': 16,
    'strides': [2, 2]}}],
 'Transposed Convolution Block': [{'TConv_1': {'kernel_size': [3, 3],
    'filters': 8,
    'strides': [1, 1]}},
  {'TConv_2': {'kernel_size': [3, 3], 'filters': 1, 'strides': [1, 1]}}]}

In [18]:
AUTOTUNE = tf.data.experimental.AUTOTUNE


def create_dataset(meta_data, rim, dirname, batch_size=None):
    """Build a tf.data pipeline of (simulated observation, ground-truth image) pairs.

    The ground-truth images come from ``create_and_save_data(dirname, meta_data)`` and
    are converted to a float32 tensor; the inputs are obtained by passing them through
    ``rim.physical_model.simulate_noisy_image``.

    Args:
        meta_data: forwarded unchanged to ``create_and_save_data``.
        rim: object exposing ``physical_model.simulate_noisy_image``.
        dirname: forwarded unchanged to ``create_and_save_data``.
        batch_size: if given (training set), examples are grouped in batches of this
            size, each batch is paired with its index via ``enumerate``, and the
            result is cached and prefetched. If ``None`` (test set), all examples
            are put in one single cached batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(batch_index, (X, Y))`` when ``batch_size``
        is given, and ``(X, Y)`` otherwise. Incomplete batches are dropped.
    """
    ground_truth = tf.convert_to_tensor(create_and_save_data(dirname, meta_data), dtype=tf.float32)
    observations = rim.physical_model.simulate_noisy_image(ground_truth)

    # Slice both tensors along their first (batch) dimension and pair them up.
    pairs = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(observations),
        tf.data.Dataset.from_tensor_slices(ground_truth),
    ))

    if batch_size is None:
        # Test set: one batch holding every example.
        return pairs.batch(ground_truth.shape[0], drop_remainder=True).cache()

    # Training set.
    return (
        pairs
        .batch(batch_size, drop_remainder=True)
        .enumerate(start=0)
        .cache()             # speeds up the second and subsequent passes over the data
        .prefetch(AUTOTUNE)  # prepare the next batch while the current one is being used
    )

In [19]:
holes = 20


def _ms_ssim(power_factors):
    """Return a metric computing multi-scale SSIM with the given power factors."""
    def metric(Y_pred, Y_true):
        return tf.image.ssim_multiscale(
            Y_pred, Y_true, max_val=1.0,
            filter_size=11,
            power_factors=list(power_factors))
    return metric


# The metrics only support grey-scale images.
#
# Author's notes on multi-scale SSIM (observed with tf 2.0.0):
#   * the filter size must be small enough that H/2**4 and W/2**4 >= filter_size;
#   * alternatively (H/2**4 = 1 in our case) use fewer power factors, so that
#     H/(2**(len(power_factors)-1)) > filter_size;
#   * hence 3 power factors work with filter_size=2, and 2 power factors work with
#     filter_size <= 8; after testing, filter_size=11 also seems to work with
#     2 power factors and a 32 pixel image.
# The power factors of the paper are [0.0448, 0.2856, 0.3001, 0.2363, 0.1333].
#
# NOTE(review): the TensorFlow documentation describes power_factors[0] as the weight
# of the unscaled resolution, so the "_23" and "_34" entries appear to apply the
# paper's weights of scales 2-3 / 3-4 to the first two scales rather than evaluating
# those coarser scales -- confirm this is what is intended.
metrics = {
    "ssim": lambda Y_pred, Y_true: tf.image.ssim(Y_pred, Y_true, max_val=1.0),
    "ssim_multiscale_01": _ms_ssim((0.0448, 0.2856)),
    "ssim_multiscale_23": _ms_ssim((0.3001, 0.2363)),
    "ssim_multiscale_34": _ms_ssim((0.2363, 0.1333)),
}

meta_data = CenteredImagesv1(total_items=1000, pixels=32)
test_meta = CenteredImagesv1(total_items=200, pixels=32)
cost_function = CostFunction()

# Mask coordinates and projector arrays for the chosen number of holes.
mask_path = os.path.join(projector_dir, f"mask_{holes}_holes.txt")
mask_coordinates = np.loadtxt(mask_path)
projector_path = os.path.join(projector_dir, f"projectors_{holes}_holes.pickle")
# Unpickling can execute arbitrary code: only load projector files you trust.
with open(projector_path, "rb") as fb:
    arrays = pickle.load(fb)

In [20]:
# Weight file stored in this run's checkpoint directory; it is handed to the RIM
# constructor (presumably to restore previously trained weights -- confirm in
# ExoRIM.model.RIM).
weight_file = os.path.join(checkpoint_dir, "rim_005_115.67728.h5")
rim = RIM(
    mask_coordinates=mask_coordinates,
    hyperparameters=hparams,
    arrays=arrays,
    weight_file=weight_file,
)

In [63]:
# NOTE(review): train_dataset and test_dataset are not defined in any cell visible
# in this notebook; they presumably come from create_dataset(...) calls in a cell
# that was removed -- restore that cell so the notebook survives Restart & Run All.
start = time.time()
fit_kwargs = dict(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    max_time=0.2,
    cost_function=cost_function,
    min_delta=0,
    patience=10,
    checkpoints=5,
    output_dir=output_dir,
    checkpoint_dir=checkpoint_dir,
    max_epochs=20,
    # save first and last step images
    output_save_mod={"index_mod": 300, "epoch_mod": 1, "step_mod": 1},
    metrics=metrics,
    name="rim",
)
history = rim.fit(**fit_kwargs)
end = time.time() - start  # elapsed wall-clock time of the fit, in seconds

# Store the hyperparameters next to the checkpoints and the history next to the outputs.
hparams_out = os.path.join(checkpoint_dir, "hyperparameters.json")
with open(hparams_out, "w") as f:
    json.dump(rim.hyperparameters, f)

history_out = os.path.join(output_dir, "history.json")
with open(history_out, "w") as f:
    json.dump(history, f)

0