In [None]:
# To use TF 2.0 (on EC2 instance running Deep Learning AMI):
# source activate tensorflow_p36
# pip uninstall tensorflow-gpu
# pip install tensorflow-gpu==2.0.0-alpha0

# Later, to get an even newer GPU build, ran the following in that same conda env:
# pip uninstall tensorflow-gpu
# pip install --upgrade pip
# pip install wrapt --ignore-installed # ran this because had an error
# pip install  tf-nightly-gpu-2.0-preview

# Result: Successfully installed tf-nightly-gpu-2.0-preview-2.0.0.dev20190531

In [None]:
import os
from datetime import datetime
import numpy as np
import tensorflow as tf

In [None]:
from packaging import version

# Report the installed TensorFlow version, then require a 2.x (or newer) release.
# The notebook was developed against >= 2.0.0-dev20190527; only the major version is enforced.
print("TensorFlow version: ", tf.__version__)
tf_major_version = version.parse(tf.__version__).release[0]
assert tf_major_version >= 2, "This notebook requires TensorFlow 2.0 or above."

In [None]:
import cs230_project_utilities as utils

In [None]:
# Device placement configuration (TF 2.0+)

# Turn soft device placement on.
tf.config.set_soft_device_placement(True)
# Keep per-op device placement logging off; pass True here when debugging GPU usage.
tf.debugging.set_log_device_placement(False)

# Prepare CIFAR-100 dataset

In [None]:
# # Load cifar 100 (should already be shuffled)
# (x_train, labels_train), (x_test, labels_test) = tf.keras.datasets.cifar100.load_data(label_mode='fine')

# # Convert x_train to float32, grayscale
# x_train = x_train.astype('float32')
# x_train =(0.299 * x_train[..., 0] + 0.587 * x_train[..., 1] +  0.114 * x_train[..., 2]) / 255.0

# # Convert x_test to float32, grayscale 
# x_test = x_test.astype('float32')
# x_test =(0.299 * x_test[..., 0] + 0.587 * x_test[..., 1] +  0.114 * x_test[..., 2]) / 255.0

# # Split x_test to create x_dev
# x_dev, x_test = x_test[:len(x_test) // 2], x_test[len(x_test) // 2:]

# # Show stats of images
# print('Shape of x_train: ' + str(x_train.shape))
# print('Shape of x_test: ' + str(x_test.shape))
# print('Shape of x_dev: ' + str(x_dev.shape))

In [None]:
# def centered_2d_fft(tensor):
#     ''' 
#     Input
#         tensor: 2D tensor of shape [height, width]
#     Returns
#         fft: centered_2d_fft(tensor) / sqrt(product(tensor.shape)) as dtype tf.complex64
    
#     Inverse: ifft = tf.signal.ifft2d(fft) (if image, use abs(ifft) to view)
    
#     Notes:
    
#     The inverse ffts aren't perfect but pretty close (suspect this is due to casting of dtypes).
#     Difference b/w image and ifft is imperceptible (visually).
    
#     y = an image
#     np.allclose(tf.math.real(centered_2d_fft(y0)).numpy(), x0[..., 0], atol=5e-3) -> True
#     np.allclose(tf.math.imag(centered_2d_fft(y0)).numpy(), x0[..., 1], atol=5e-2) -> True
    
#     '''
    
#     tensor = tf.cast(tensor, tf.complex64)
#     fft_unshifted = tf.signal.fft2d(tensor)
#     fft = tf.signal.fftshift(fft_unshifted)
#     return fft

# def cifar_parser(sample):
#     # Returns: (fft, image reconstruction) pairs for automap model
    
#     # Image must be 3-dim
#     sample = tf.expand_dims(sample, -1)
#     resized = tf.image.resize(sample, [256, 256])
#     fft = centered_2d_fft(tf.squeeze(resized))
#     fft = tf.expand_dims(fft, -1) # tf.signal.fft2d expects 2D input, so we undo the squeeze() from before
    
#     # Separate real and imaginary components into separate channels (models operate on floats)
#     real = tf.math.real(fft)
#     imaginary = tf.math.imag(fft)
#     fft = tf.concat([real, imaginary], axis=-1)
    
#     return fft, resized

In [None]:
def load_dataset(data_locations, batch_size, buffer_size, include_all_parsed_features):
    '''
    Builds a shuffled, batched `tf.data.Dataset` from the TFRecord files matching
    `data_locations`.

    Args:
        data_locations: A string, a list of strings, or a `tf.Tensor` of string type
            (scalar or vector), representing the filename glob (i.e. shell wildcard)
            pattern(s) that will be matched.
        batch_size: Number of consecutive examples combined into each batch.
        buffer_size: Shuffle buffer size applied to the parsed examples, before batching.
        include_all_parsed_features: When truthy, each element is the full dict of
            parsed features; otherwise each element is the `(fft, image)` pair.

    Returns:
        The parsed, shuffled and batched dataset.

    NOTE(review): `list_files` is called with its default arguments, which per the
    TF docs shuffle the matched file names - confirm whether a deterministic file
    order matters for the dev/test datasets.
    '''
    # Schema of one serialized `tf.Example`; each feature is a scalar with a default.
    feature_spec = {
        "path": tf.io.FixedLenFeature((), tf.string, ""),
        "sequence_index": tf.io.FixedLenFeature((), tf.int64, -1),
        "fft": tf.io.FixedLenFeature((), tf.string, ''),
        "image": tf.io.FixedLenFeature((), tf.string, ''),
        "dimension": tf.io.FixedLenFeature((), tf.int64, -1),
        "class": tf.io.FixedLenFeature((), tf.int64, -1)
    }

    def _decode_example(serialized):
        # Parse one record, then turn the raw byte strings into float32 tensors.
        example = tf.io.parse_single_example(serialized, feature_spec)
        side = example['dimension']

        flat_fft = tf.io.decode_raw(example['fft'], out_type=tf.float32)
        flat_image = tf.io.decode_raw(example['image'], out_type=tf.float32)

        # fft carries 2 channels, image carries 1; both are `dimension` x `dimension`.
        example['fft'] = tf.reshape(flat_fft, [side, side, 2])
        example['image'] = tf.reshape(flat_image, [side, side, 1])

        # Training only needs the (input, expected output) pair.
        return example if include_all_parsed_features else (example['fft'], example['image'])

    matching_files = tf.data.TFRecordDataset.list_files(data_locations)
    return (
        tf.data.TFRecordDataset(matching_files)
        .map(_decode_example)
        .shuffle(buffer_size=buffer_size)
        .batch(batch_size)
    )

In [None]:
# Use tf.data.Datasets to preprocess and iterate data efficiently

# include class_0?
# Test and dev read only class_0 ("bad") records. Train reads class_1 ("good") records
# from every split directory (the leading * spans the split directories): train on all
# class_1 examples, evaluate on the rest.
test_data_locations = ['/home/ubuntu/cs230/data/tfrecords/test/class_0/*.tfrecord']
dev_data_locations = ['/home/ubuntu/cs230/data/tfrecords/dev/class_0/*.tfrecord']
train_data_locations = ['/home/ubuntu/cs230/data/tfrecords/*/class_1/*.tfrecord']

batch_size = 16

# Test and dev use a shuffle buffer of 1 (examples are not shuffled); train uses 512.
eval_options = dict(batch_size=batch_size, buffer_size=1, include_all_parsed_features=False)

test_dataset = load_dataset(test_data_locations, **eval_options)
dev_dataset = load_dataset(dev_data_locations, **eval_options)
train_dataset = load_dataset(
    train_data_locations,
    batch_size=batch_size,
    buffer_size=512,
    include_all_parsed_features=False,
)

In [None]:
# Cache the first batch of each dataset; these fixed batches are plotted to
# TensorBoard periodically during training.
first_test_batch, first_dev_batch, first_train_batch = (
    next(iter(split)) for split in (test_dataset, dev_dataset, train_dataset)
)

# Model

In [None]:
# Training metric: batch-averaged peak signal-to-noise ratio
def mean_PSNR(y_true, y_pred):
    '''
    Returns the mean PSNR (in dB) over a batch, assuming a peak value of 1.0.

    The MSE is reduced over axes 1, 2 and 3 of each example, converted to
    10 * log10(peak^2 / MSE), and the per-example values are then averaged.
    '''
    peak = 1.0
    per_example_mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])
    # log10 is expressed as a ratio of natural logs.
    ln_ratio = tf.math.log(tf.divide(peak ** 2, per_example_mse))
    ln_ten = tf.math.log(tf.constant(10, dtype=y_pred.dtype))
    return tf.reduce_mean(10 * ln_ratio / ln_ten)

In [None]:
def load_uncompiled_automap_model():
    '''
    Builds the uncompiled Keras model mapping a (256, 256, 2) input to a
    (256, 256, 1) output.

    Stages: multiplicative noise -> two conv + max-pool stages (256 -> 128 -> 64)
    -> flatten -> three tanh dense layers -> reshape to (64, 64, 1) -> two
    conv + stride-2 transposed-conv stages (64 -> 128 -> 256) -> single-channel output.
    '''
    layers = tf.keras.layers
    side = 256
    inputs = layers.Input(shape=(side, side, 2))

    # Paper says 1% multiplicative gaussian noise. GaussianDropout multiplies by a
    # 1-centered gaussian with stddev = sqrt(rate / (1 - rate)) (here, 0.00032...).
    # Note: (we could corrupt when training with cifar, but maybe not other dataset?)
    features = layers.GaussianDropout(rate=1e-7)(inputs)  # spatial dimension: 256

    # Downsampling stage 1: 256 filters, then pool to spatial dimension 128.
    features = layers.Conv2D(256, (3, 3), strides=(1, 1), activation='relu', padding='same')(features)
    features = layers.MaxPool2D(pool_size=2)(features)

    # Downsampling stage 2: 2 filters, then pool to spatial dimension 64.
    features = layers.Conv2D(2, (3, 3), strides=(1, 1), activation='relu', padding='same')(features)
    features = layers.MaxPool2D(pool_size=2)(features)

    # Fully connected core: 64 * 64 * 2 = 8192 inputs, ending at 4096 = 64 * 64 units.
    hidden = layers.Flatten()(features)
    for units in (8192, 4096, 4096):
        hidden = layers.Dense(units, activation='tanh')(hidden)

    decoded = layers.Reshape([64, 64, 1])(hidden)

    # Upsampling stage 1: to spatial dimension 128.
    decoded = layers.Conv2D(256, (5, 5), strides=(1, 1), activation='relu', padding='same')(decoded)
    decoded = layers.Conv2DTranspose(256, (5, 5), strides=2, activation='relu', padding='same')(decoded)

    # Upsampling stage 2: to spatial dimension 256, with L1 regularization to
    # encourage sparsity.
    decoded = layers.Conv2D(256, (5, 5), strides=(1, 1), activation='relu', padding='same',
                            kernel_regularizer=tf.keras.regularizers.l1(1e-4))(decoded)
    decoded = layers.Conv2DTranspose(256, (5, 5), strides=2, activation='relu', padding='same',
                                     kernel_regularizer=tf.keras.regularizers.l1(1e-4))(decoded)

    # Final stride-1 layer collapses to one channel; spatial dimension stays 256.
    outputs = layers.Conv2DTranspose(1, (7, 7), strides=1, activation='relu', padding='same')(decoded)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
def load_compiled_automap_model(multi_gpu=False):
    '''
    Builds the AUTOMAP model and compiles it with MSE loss, the Adam optimizer
    and the `mean_PSNR` metric.

    Args:
        multi_gpu: When True, the model is created and compiled inside a
            `tf.distribute.MirroredStrategy` scope so training is distributed
            across GPUs (each GPU receives identical weight updates but different
            batches). This restricts the callbacks that can be used. Defaults to
            False, which keeps the previous single-device behavior.

    Returns:
        The compiled `tf.keras.Model`.
    '''
    def _build_and_compile():
        # Shared by both paths so the two configurations cannot drift apart.
        net = load_uncompiled_automap_model()
        net.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(), metrics=[mean_PSNR])
        return net

    if not multi_gpu:
        return _build_and_compile()

    # Model creation and compilation both happen under the strategy scope.
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        return _build_and_compile()

In [None]:
# Custom learning rate schedule


def lr_schedule(epoch):
    """
    Returns the learning rate for `epoch` from a step schedule and logs it.

    The rate starts at 2e-4. Each `(threshold, rate)` step below takes effect once
    `epoch` is strictly greater than its threshold, so the rate paired with the
    largest exceeded threshold is returned.

    The previous implementation chained every step after 150 as an `elif` of
    `if epoch > 150`, which made those steps unreachable: the rate stayed at
    1e-5 for all epochs above 150.

    NOTE(review): the rates are kept exactly as written and are not monotonically
    decreasing (e.g. 5e-5 after epoch 200 follows 1e-5 after epoch 150) - confirm
    these are the intended values.

    Side effect: writes the chosen rate as the 'learning rate' scalar to the
    notebook-level `file_writer`, using `epoch` as the step.
    """
    # Ascending thresholds; the loop below relies on this ordering.
    steps = (
        (100, 1e-4),
        (150, 1e-5),
        (200, 5e-5),
        (250, 1e-5),
        (300, 1e-6),
        (350, 5e-5),
        (400, 1e-6),
        (600, 1e-7),
        (800, 5e-6),
        (900, 1e-7),
    )

    learning_rate = 2e-4
    for threshold, rate in steps:
        if epoch <= threshold:
            break
        learning_rate = rate

    with file_writer.as_default():
        tf.summary.scalar('learning rate', data=learning_rate, step=epoch)

    return learning_rate

# Show reconstructions during training

def plot_fft_reconstructions(batch, logs):
    '''
    Writes prediction / ground-truth image summaries for the cached first test,
    dev and train batches. Does nothing unless `batch` is a multiple of 100.

    Reads the notebook-level `model`, `file_writer` and `first_*_batch` variables.
    `logs` is accepted but not used.

    NOTE(review): `batch` is used as the summary step. This function is registered
    as an `on_batch_end` callback, so the index presumably restarts every epoch and
    steps would repeat across epochs - confirm.
    '''
    plot_frequency = 100
    images_per_split = 8

    if batch % plot_frequency:
        return

    labelled_batches = (
        ('Test', first_test_batch),
        ('Dev', first_dev_batch),
        ('Train', first_train_batch),
    )

    for name, (inputs, targets) in labelled_batches:
        targets = targets.numpy()
        predictions = model.predict(inputs)

        with file_writer.as_default():
            for i in range(min(len(targets), images_per_split)):
                # Keep the leading batch axis: one image per summary call.
                single = slice(i, i + 1)
                tf.summary.image("{} Image {} (Prediction)".format(name, i), predictions[single, ...], max_outputs=1, step=batch)
                tf.summary.image("{} Image {} (Ground Truth)".format(name, i), targets[single, ...], max_outputs=1, step=batch)


In [None]:
# Build and compile the model; later cells and the training callbacks read this global.
model = load_compiled_automap_model()

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

# Training

In [None]:
# Define where logs will be saved: one timestamped directory per run, with the
# custom scalar and image summaries written under its "metrics" subdirectory.

run_timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
logdir = "logs/scalars/" + run_timestamp
file_writer = tf.summary.create_file_writer(logdir + "/metrics")

In [None]:
# Create callbacks to use in various stages of training

# Callback for printing and logging the LR at the end of each epoch.
class PrintAndLogLR(tf.keras.callbacks.Callback):
    """Logs the optimizer's learning rate to TensorBoard and prints it after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Read from the model Keras attached to this callback (`self.model`) rather
        # than the notebook-level `model`, so the callback reports on whichever
        # model it is actually training with.
        current_lr = self.model.optimizer.lr.numpy()

        with file_writer.as_default():
            tf.summary.scalar('learning rate (end of epoch)', data=current_lr, step=epoch)
        # `epoch` is zero-based; print it one-based.
        print('\nLearning rate for epoch {} is {}'.format(epoch + 1, current_lr))

keras_callbacks = tf.keras.callbacks

# Image summaries of fixed batches, written from the batch-end hook.
plot_images_callback = keras_callbacks.LambdaCallback(on_batch_end=plot_fft_reconstructions)

# Epoch-indexed learning rate taken from lr_schedule.
lr_callback = keras_callbacks.LearningRateScheduler(lr_schedule)

tensorboard_callback = keras_callbacks.TensorBoard(
    log_dir=logdir,
    histogram_freq=2,
    update_freq=200,
    profile_batch=0,  # workaround for: https://github.com/tensorflow/tensorboard/issues/2084
)

reduce_lr_callback = keras_callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.25,
    patience=2,
    min_lr=1e-8,
)

# NOTE(review): LearningRateScheduler and ReduceLROnPlateau both set the optimizer's
# learning rate; the scheduler presumably overwrites any plateau reduction when the
# next epoch starts - confirm both are wanted.
# NOTE(review): PrintAndLogLR is defined in this cell but is not added to the list.
callbacks = [tensorboard_callback, lr_callback, plot_images_callback, reduce_lr_callback]

In [None]:
# Train on `train_dataset`, validating on `dev_dataset` after each epoch.
# The returned History object is used by the plotting cells below.
# NOTE(review): Keras documents `use_multiprocessing` for generator / Sequence inputs;
# it presumably has no effect with a tf.data.Dataset - confirm.
training_history = model.fit(
    train_dataset,
    validation_data=dev_dataset,
    verbose=1, # set to 0 to suppress chatty output and use Tensorboard instead
    epochs=1000,
    callbacks=callbacks,
    use_multiprocessing=True) # see if speeds things up

In [None]:
# Save the trained model (this cell is live: running it writes to `saved_model_path`).
# Earlier path: 'automap_our_dataset_original_paper_model_with_up_down_sampling' (had out 256 channels... bad)

saved_model_path = 'automap_our_dataset_original_paper_model_with_up_down_sampling_single_GPU_v5'

model.save(saved_model_path)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Loss per epoch. The validation curve comes from `dev_dataset` (passed as
# `validation_data` to `model.fit`), so it is labelled "Dev" rather than "Test".
plt.figure()
plt.plot(training_history.history["loss"], label="Train")
plt.plot(training_history.history["val_loss"], label="Dev")
plt.title("Loss per epoch")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="center right");

In [None]:
# Mean PSNR per epoch. The validation curve comes from `dev_dataset` (passed as
# `validation_data` to `model.fit`), so it is labelled "Dev" rather than "Test".
plt.figure()
plt.plot(training_history.history["mean_PSNR"], label="Train")
plt.plot(training_history.history["val_mean_PSNR"], label="Dev")
plt.title("Mean PSNR per epoch")
plt.xlabel("Epoch")
plt.ylabel("µ ( PSNR ) ")
plt.legend(loc="center right");

In [None]:
# Evaluate the model on the test dataset.
model.evaluate(test_dataset)

In [None]:
# Evaluate the model on the dev dataset.
model.evaluate(dev_dataset)

In [None]:
# Evaluate the model on the train dataset.
model.evaluate(train_dataset)

In [None]:
# Predict on a single test batch and plot every example next to its ground truth

for batch_inputs, batch_targets in test_dataset:
    batch_inputs, batch_targets = batch_inputs.numpy(), batch_targets.numpy()
    batch_predictions = model.predict(batch_inputs)

    for idx in range(len(batch_inputs)):
        # The two input channels are plotted under the titles "FFT (Magnitude)" and
        # "FFT (Phase)". NOTE(review): confirm what the stored channels actually hold.
        channel_0 = batch_inputs[idx, ..., 0]
        channel_1 = batch_inputs[idx, ..., 1]
        label = None #int(cls[i])

        # Clamp the prediction to [0, 1] in place before scoring and display.
        recon = batch_predictions[idx, ..., 0]
        np.clip(recon, 0, 1, out=recon)
        target = batch_targets[idx, ..., 0]

        print('Class: {}'.format(label))

        MSE = utils.signal_processing.mean_square_error(recon, target)
        PSNR = utils.signal_processing.PSNR(recon, target, max_value=1.0)

        # (title, plotting function, data) for each cell of the 2x2 grid, in order.
        panels = (
            ('Reconstruction (MSE: {:0.5f}, PSNR: {:0.5f})'.format(MSE, PSNR), utils.plot.imshowgray, recon),
            ('FFT (Magnitude)', utils.plot.imshowfft, channel_0),
            ('Expected reconstruction', utils.plot.imshowgray, target),
            ('FFT (Phase)', utils.plot.imshowgray, channel_1),
        )
        for position, (title, show, data) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.title(title)
            show(data)

        plt.show()

    # Only the first batch is inspected.
    break

In [None]:
# # In theory, in TF 2.0 we should be able to see Tensorboard in this notebook with magics:
# %load_ext tensorboard
# %tensorboard --logdir logs

# Clear logs if needed
# !rm -rf logs/

In [None]:
# !wget https://www.dropbox.com/s/1l4z7u062nvlhrz/MRI_Kspace.dat
# See https://github.com/kmjohnson3/ML4MI_BootCamp/blob/fe9d96cd9f68db073a44f9dc9a015533a008d0a7/ImageReconstruction/CoLab_AutoMap_Recon.ipynb

In [None]:
# ll -h MRI_Kspace.dat

# Ignore these old models (here for ref)

In [None]:
# def load_uncompiled_automap_model():
#     N = 256
#     small_N = N // 4 # after downsampling by 2 twice
    
#     X = tf.keras.layers.Input(shape=(N, N, 2))
#     noisy_X = tf.keras.layers.GaussianNoise(stddev=1e-6)(X)
#     conv_downsample1 = tf.keras.layers.Conv2D(16, (4, 4), strides=(2, 2), activation='tanh', padding='same')(noisy_X)
# #     conv_downsample2 = tf.keras.layers.Conv2D(4, (4, 4), strides=(1, 1), activation='tanh', padding='same')(conv_downsample1)
#     conv_downsample3 = tf.keras.layers.Conv2D(2, (4, 4), strides=(2, 2), activation='tanh', padding='same')(conv_downsample1)
#     X1 = tf.keras.layers.Flatten()(conv_downsample3)
    
#     # Workaround for: ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.
#     X1 = tf.keras.layers.Reshape(target_shape=((small_N ** 2) * 2,))(X1)
    
#     fc1 = tf.keras.layers.Dense((small_N ** 2) * 1, activation = 'tanh')(X1)
#     fc1_DO = tf.keras.layers.Dropout(0.1)(fc1)
    
#     fc2 = tf.keras.layers.Dense(small_N ** 2, activation = 'tanh')(fc1_DO)
#     fc2_DO = tf.keras.layers.Dropout(0.1)(fc2)

#     fc3 = tf.keras.layers.Dense(small_N ** 2, activation = 'tanh')(fc2_DO)
#     X2 = tf.keras.layers.Reshape((small_N, small_N, 1))(fc3)
#     conv1_1 = tf.keras.layers.Conv2D(small_N, 5, activation='relu', padding='same', kernel_regularizer=tf.keras.regularizers.l1(1e-4))(X2)
#     conv1_2 = tf.keras.layers.Conv2D(small_N, 5, activation='relu', padding='same', kernel_regularizer=tf.keras.regularizers.l1(1e-4))(conv1_1)
#     conv1_3a = tf.keras.layers.Conv2DTranspose(small_N, 9, activation='relu', padding='same')(conv1_2)
#     conv1_3b = tf.keras.layers.Conv2DTranspose(small_N, 9, strides=2, activation='relu', padding='same')(conv1_3a)
#     conv1_3c = tf.keras.layers.Conv2DTranspose(small_N, 9, strides=2, activation='relu', padding='same')(conv1_3b)
    
#     Y_pred = tf.keras.layers.Conv2D(1, 1, activation = 'linear', padding='same')(conv1_3c)
    
#     model = tf.keras.Model(inputs=X, outputs=Y_pred)
    
#     return model

In [None]:
# def load_uncompiled_automap_model():
    
#     # this one's solid, but I believe we'll need a few hours to train it.
    
#     N = 256
#     F = N
#     X = tf.keras.layers.Input(shape=(N, N, 2))

#     # Half-assed data augmentation
#     noisy_X = tf.keras.layers.GaussianNoise(stddev=1e-7)(X) # shape: (256, 256, 256)

#     # These layers all halve the spatial dimension (but also each output 256 channels)
#     conv1 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(noisy_X)
#     pool1 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv1) # shape: (128, 128, F)

#     conv2 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool1)
#     pool2 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv2) # shape: (64, 64, F)

#     conv3 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool2)
#     pool3 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv3) # shape: (32, 32, F)

#     conv4 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool3)
#     pool4 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv4) # shape: (16, 16, F)

#     conv5 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool4)
#     pool5 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv5) # shape: (8, 8, F)

#     conv6 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool5)
#     pool6 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv6) # shape: (4, 4, F)

#     # A "FC-like" layer for fun before we do upsampling
#     conv7 = tf.keras.layers.Conv2D(F, (1, 1), strides=(1, 1), activation='relu', padding='same')(pool6) # spatial dim: 4

#     # These transposed convolutions upsample spatial dimension by 2
#     t_conv1 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(conv7) # spatial dim: 8
#     t_conv2 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv1) # spatial dim: 16
#     t_conv3 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv2) # spatial dim: 32
#     t_conv4 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv3) # spatial dim: 64
#     t_conv5 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv4) # spatial dim: 128
    
#     Y_pred = tf.keras.layers.Conv2DTranspose(1, 4, strides=2, activation='linear', padding='same')(t_conv5) # spatial dim: 256

#     model = tf.keras.Model(inputs=X, outputs=Y_pred)

#     return model



In [None]:
# no transposed conv, just upsample (first try, guessing too much data lost in middle)

# def load_uncompiled_automap_model():

#     N = 256
#     F = N
#     X = tf.keras.layers.Input(shape=(N, N, 2))

#     # Half-assed data augmentation
#     noisy_X = tf.keras.layers.GaussianNoise(stddev=1e-7)(X) # shape: (256, 256, 256)

#     # Downsampling

#     conv1 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(noisy_X)
#     pool1 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv1) # shape: (128, 128, F)

#     conv2 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool1)
#     pool2 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv2) # shape: (64, 64, F)

#     conv3 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool2)
#     pool3 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv3) # shape: (32, 32, F)

#     conv4 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool3)
#     pool4 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv4) # shape: (16, 16, F)

#     conv5 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool4)
#     pool5 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv5) # shape: (8, 8, F)

#     conv6 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool5)
#     pool6 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv6) # shape: (4, 4, F)

#     # Some pointwise convoluations before finally upsampling

#     conv7 = tf.keras.layers.Conv2D(F, (1, 1), strides=(1, 1), activation='relu', padding='same')(pool6) # shape: (4, 4, F)
#     conv8 = tf.keras.layers.Conv2D(F, (1, 1), strides=(1, 1), activation='relu', padding='same')(conv7) # shape: (4, 4, F)
#     conv9 = tf.keras.layers.Conv2D(F, (1, 1), strides=(1, 1), activation='relu', padding='same')(conv8) # shape: (4, 4, F)

#     # Upsampling

#     conv10 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv9)
#     pool10 = tf.keras.layers.UpSampling2D(2)(conv10) # shape: (8, 8, F)

#     conv11 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool10)
#     pool11 = tf.keras.layers.UpSampling2D(2)(conv11) # shape: (16, 16, F)

#     conv12 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool11)
#     pool12 = tf.keras.layers.UpSampling2D(2)(conv12) # shape: (32, 32, F)

#     conv13 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool12)
#     pool13 = tf.keras.layers.UpSampling2D(2)(conv13) # shape: (64, 64, F)

#     conv14 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool13)
#     pool14 = tf.keras.layers.UpSampling2D(2)(conv14) # shape: (128, 128, F)

#     conv15 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool14)
#     pool15 = tf.keras.layers.UpSampling2D(2)(conv15) # shape: (256, 256, F)

#     # One more to smooth things out

#     Y_pred = tf.keras.layers.Conv2D(1, (3, 3), strides=(1, 1), activation='linear', padding='same')(pool15)

#     model = tf.keras.Model(inputs=X, outputs=Y_pred)

#     return model

In [None]:
# def load_uncompiled_automap_model():

#     N = 256
#     F = N
#     X = tf.keras.layers.Input(shape=(N, N, 2))

#     # Half-assed data augmentation
#     noisy_X = tf.keras.layers.GaussianNoise(stddev=1e-7)(X) # shape: (256, 256, 256)

#     # Downsampling

#     conv1 = tf.keras.layers.Conv2D(F, (9, 9), strides=(1, 1), activation='relu', padding='same')(noisy_X)
#     pool1 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv1) # shape: (128, 128, F)

#     conv2 = tf.keras.layers.Conv2D(F, (7, 7), strides=(1, 1), activation='relu', padding='same')(pool1)
#     pool2 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv2) # shape: (64, 64, F)

#     conv3 = tf.keras.layers.Conv2D(F, (5, 5), strides=(1, 1), activation='relu', padding='same')(pool2)
#     pool3 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv3) # shape: (32, 32, F)

#     conv4 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool3)
#     pool4 = tf.keras.layers.AveragePooling2D(pool_size=2)(conv4) # shape: (16, 16, F)

#     # Some pointwise convolutions before finally upsampling

#     conv5 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool4) # shape: (16, 16, F)
#     conv6 = tf.keras.layers.Conv2D(F, (1, 1), strides=(1, 1), activation='relu', padding='same')(conv5) # shape: (16, 16, F)
#     conv7 = tf.keras.layers.Conv2D(F, (1, 1), strides=(1, 1), activation='relu', padding='same')(conv6) # shape: (16, 16, F)

#     # Upsampling

#     t_conv8 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(conv7)
#     t_conv9 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv8)
#     t_conv10 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv9)
#     t_conv11 = tf.keras.layers.Conv2DTranspose(F, 4, strides=2, activation='relu', padding='same')(t_conv10)
    
# #     pool8 = tf.keras.layers.UpSampling2D(2)(conv7) # shape: (32, 32, F)
# #     conv8 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool8)

# #     pool9 = tf.keras.layers.UpSampling2D(2)(conv8) # shape: (64, 64, F)
# #     conv9 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool9)

# #     pool10 = tf.keras.layers.UpSampling2D(2)(conv9) # shape: (128, 128, F)
# #     conv10 = tf.keras.layers.Conv2D(F, (3, 3), strides=(1, 1), activation='relu', padding='same')(pool10)

# #     pool11 = tf.keras.layers.UpSampling2D(2)(conv10) # shape: (256, 256, F)
# #     conv11 = tf.keras.layers.Conv2D(F, (5, 5), strides=(1, 1), activation='relu', padding='same')(pool11)

#     # One more to smooth things out

#     Y_pred = tf.keras.layers.Conv2D(1, (2, 2), strides=(1, 1), activation='linear', padding='same')(t_conv11)

#     model = tf.keras.Model(inputs=X, outputs=Y_pred)

#     return model