In [1]:
import collections
import collections.abc
import glob
import json
import os
import pickle
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

from ExoRIM.model import RIM, CostFunction
from ExoRIM.simulated_data import CenteredImagesv1
from ExoRIM.utilities import load_dataset
from preprocessing.simulate_data import create_and_save_data

ImportError: cannot import name 'RIM' from 'ExoRIM.model' (/home/alexandre/Desktop/Projects/ExoRIM/ExoRIM/model.py)

In [None]:
# Project root: the parent directory of the notebook's working directory.
root = os.path.dirname(os.getcwd())
root

In [None]:
# Run identifier naming this run's data, models and results directories.
# For a brand-new run use: datetime.now().strftime("%y-%m-%d_%H-%M-%S")
# Earlier runs: "20-06-21_10-24-04", "20-06-22_13-38-14".
# NOTE(review): this shadows the builtin `id`; kept because later cells use the name.
id = "20-06-22_18-19-03"
id

In [None]:
# Per-run directories for generated data, model checkpoints and training outputs.
data_dir = os.path.join(root, "data", id)
test_dir = os.path.join(root, "data", id + "_test")
projector_dir = os.path.join(root, "data", "projector_arrays")  # read-only input, not created
checkpoint_dir = os.path.join(root, "models", id)
output_dir = os.path.join(root, "results", id)
# makedirs also creates missing parents (data/, models/, results/), which os.mkdir
# cannot, and exist_ok avoids the check-then-create race of isdir + mkdir.
for directory in (data_dir, test_dir, checkpoint_dir, output_dir):
    os.makedirs(directory, exist_ok=True)

In [None]:
#with open(os.path.join(root, "hyperparameters.json"), "r") as f:
#    hparams = json.load(f)
with open(os.path.join(checkpoint_dir, "hyperparameters.json"), "r") as f:
     hparams = json.load(f)
hparams

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE


def create_dataset(meta_data, rim, dirname, batch_size=None):
    """Generate images, simulate their noisy observations and wrap them in a tf.data pipeline.

    Args:
        meta_data: generator description passed to ``create_and_save_data`` (which
            also writes the images under ``dirname``).
        rim: model whose ``physical_model.simulate_noisy_image`` produces the inputs.
        dirname: directory where the generated images are saved.
        batch_size: if given (train set), examples are batched, enumerated, cached and
            prefetched; if None (test set), every example goes into one single batch.

    Returns:
        A ``tf.data.Dataset`` of ``(noisy_observation, ground_truth)`` pairs, or of
        ``(batch_index, (noisy_observation, ground_truth))`` when ``batch_size`` is set.
    """
    ground_truth = tf.convert_to_tensor(create_and_save_data(dirname, meta_data), dtype=tf.float32)
    observations = rim.physical_model.simulate_noisy_image(ground_truth)
    # from_tensor_slices splits along the leading (batch) dimension.
    pairs = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(observations),
        tf.data.Dataset.from_tensor_slices(ground_truth),
    ))
    if batch_size is None:
        # Test set: a single batch holding all examples.
        return pairs.batch(ground_truth.shape[0], drop_remainder=True).cache()
    # Train set: cache speeds up every epoch after the first; prefetch lets the CPU
    # prepare the next batch while the current one trains.
    return (
        pairs.batch(batch_size, drop_remainder=True)
        .enumerate(start=0)
        .cache()
        .prefetch(AUTOTUNE)
    )

In [None]:
holes = 20


def _ssim_multiscale(power_factors):
    """Return an MS-SSIM metric ``f(Y_pred, Y_true)`` using only the given scale weights.

    Uses max_val=1.0 and filter_size=11 for every returned metric.
    """
    def metric(Y_pred, Y_true):
        return tf.image.ssim_multiscale(
            Y_pred, Y_true, max_val=1.0, filter_size=11, power_factors=power_factors
        )
    return metric


# Metrics only support grey-scale images.
# Bug in tf 2.0.0: ssim_multiscale fails unless the image is still at least
# filter_size wide at the coarsest scale, i.e. H / 2**(len(power_factors) - 1) >= filter_size.
# The paper's 5 power factors [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] leave only
# 32 / 2**4 = 2 pixels for our 32-pixel images, so MS-SSIM is split into pairs of
# scales; filter_size=11 was found to work with 2 power factors.
metrics = {
    "ssim": lambda Y_pred, Y_true: tf.image.ssim(Y_pred, Y_true, max_val=1.0),
    "ssim_multiscale_01": _ssim_multiscale([0.0448, 0.2856]),
    "ssim_multiscale_23": _ssim_multiscale([0.3001, 0.2363]),
    "ssim_multiscale_34": _ssim_multiscale([0.2363, 0.1333]),
}
cost_function = CostFunction()

# Aperture-mask geometry and precomputed projector arrays for this number of holes.
mask_coordinates = np.loadtxt(os.path.join(projector_dir, f"mask_{holes}_holes.txt"))
projector_file = os.path.join(projector_dir, f"projectors_{holes}_holes.pickle")
# NOTE: pickle.load executes arbitrary code; only load project-generated files.
with open(projector_file, "rb") as fb:
    arrays = pickle.load(fb)

In [None]:
# Resume from a saved checkpoint; set weight_file = None to start from fresh weights.
# Other checkpoints of this run: "rim_005_115.67728.h5", "rim_035_11.35379.h5".
weight_file = os.path.join(checkpoint_dir, "rim_098_15.75180.h5")
rim = RIM(
    mask_coordinates=mask_coordinates,
    hyperparameters=hparams,
    arrays=arrays,
    weight_file=weight_file,
)

In [None]:
# Simulate 1000 training and 200 test images of 32x32 pixels.
# To reuse images already on disk instead, use:
#   train_dataset = load_dataset(data_dir, rim, batch_size=50)
#   test_dataset = load_dataset(test_dir, rim)
meta_data = CenteredImagesv1(pixels=32, total_items=1000)
test_meta = CenteredImagesv1(pixels=32, total_items=200)
train_dataset = create_dataset(meta_data, rim, data_dir, batch_size=50)
test_dataset = create_dataset(test_meta, rim, test_dir)  # single batch with every example

In [None]:
history_path = os.path.join(output_dir, "history.pickle")
if os.path.isfile(history_path):
    # Resuming: keep appending to the curves recorded by earlier sessions of this run.
    with open(history_path, "rb") as f:
        history = pickle.load(f)
else:
    history = {f"{key}_{split}": [] for split in ("train", "test") for key in metrics}
    history.update(train_loss=[], test_loss=[])

training_start = time.time()
_history = rim.fit(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    max_time=0.8,
    cost_function=cost_function,
    min_delta=0,
    patience=10,
    checkpoints=5,
    output_dir=output_dir,
    checkpoint_dir=checkpoint_dir,
    max_epochs=70,
    # save first and last step images
    output_save_mod={"index_mod": 300, "epoch_mod": 1, "step_mod": 11},
    metrics=metrics,
    name="rim",
)
elapsed = time.time() - training_start
print(f"Training took {elapsed/60:.02f} minute")

In [None]:
# Persist the hyperparameters the model was trained with, next to its checkpoints.
hparams_path = os.path.join(checkpoint_dir, "hyperparameters.json")
with open(hparams_path, "w") as f:
    json.dump(rim.hyperparameters, f)

In [None]:
# Append this session's curves to the accumulated history and save it.
# setdefault keeps this working when a resumed history.pickle predates a metric or
# loss key that rim.fit now reports (plain indexing raised KeyError).
# NOTE: re-running this cell appends the same session a second time.
for key, values in _history.items():
    history.setdefault(key, []).extend(values)
with open(os.path.join(output_dir, "history.pickle"), "wb") as f:
    pickle.dump(history, f)

In [None]:
train_loss = history["train_loss"]
train_loss

In [None]:
# update of nested dictionaries
def update(d, u):
    """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

    Mapping values in ``u`` are merged into the matching entry of ``d`` instead of
    replacing it; any other value in ``u`` overwrites ``d[k]``. When ``d[k]`` exists
    but is not a mapping while ``u[k]`` is, it is replaced by a new dict holding the merge.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            current = d.get(k)
            base = current if isinstance(current, collections.abc.Mapping) else {}
            d[k] = update(base, v)
        else:
            d[k] = v
    return d

# Collect saved reconstructions as images[index][epoch][step] -> 2-D pixel array.
# File names are assumed to look like output_EEE_IIII_SS* (3-digit epoch, 4-digit
# example index, 2-digit step) to match the fixed slices below — TODO confirm against rim.fit.
images = {}
for file in glob.glob(os.path.join(output_dir, "output*")):
    name = os.path.basename(file)
    epoch = int(name[7:10])
    index = int(name[11:15])
    step = int(name[16:18])
    with Image.open(file) as image:
        # PIL reports size as (width, height) while numpy arrays are (rows, cols):
        # the row-major pixel sequence must be reshaped to (height, width).
        width, height = image.size
        im = np.array(image.getdata()).reshape([height, width])
        update(images, {index: {epoch: {step: im}}})

In [None]:
# Example to inspect; with index_mod=300 outputs are presumably saved for 0, 300, 600, 900 — verify below.
index = 900
images.keys()

In [None]:
# Latest epoch saved for the selected example (the next cells index images[index][epoch],
# so the epoch must come from that same example rather than from example 0).
epoch = max(images[index])
print(epoch)
images[index].keys()

In [None]:
# Last saved unrolled step of the selected example at that epoch.
steps_saved = images[index][epoch]
step = max(steps_saved)
steps_saved.keys()

In [None]:
prediction = images[index][epoch][step]
plt.imshow(prediction, cmap="gray")

In [None]:
# Ground truth for the same example, drawn with the same grey colormap as the
# reconstruction above so the two plots are directly comparable.
gt_file = os.path.join(data_dir, f"image{index}.png")
with Image.open(gt_file) as image:
    # Use the file's own dimensions instead of a hard-coded 32x32;
    # PIL reports (width, height), numpy wants (rows, cols).
    width, height = image.size
    im = np.array(image.getdata()).reshape([height, width])
plt.imshow(im, cmap="gray")