In [1]:
import numpy as np
from matplotlib import pyplot as plt
import os
import tensorflow as tf
from tqdm import tqdm
from scipy.ndimage import rotate, zoom
import math

In [2]:
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
root_data_dir = "maps/"
runs_dir = "Runs/"
model_name = "tmp"
resume_model_name = ""
resume_epoch = 0
model_dir = runs_dir + model_name + "/"
if os.listdir(runs_dir).count(model_name) == 0:
    os.mkdir(runs_dir + model_name)
maps = os.listdir(root_data_dir)
dmaps = [m for m in maps if m[0] == "d"]
gmaps = [m for m in maps if m[0] == "g"]
resume_from_last = False
resume = False
train_val_split_rate = 0.9
n_crumples = len(dmaps)
idxs = np.arange(n_crumples)
np.random.shuffle(idxs)
e_crumple_idxs = idxs[:int(n_crumples * train_val_split_rate)]
f_crumple_idxs = idxs[int(n_crumples * train_val_split_rate):]
Batch_size = 20
Epochs = 150
Val_Batch_size = 40
total_chunks = 153
images_per_chunk = 1000
steps_per_epoch = int(images_per_chunk * train_val_split_rate) * total_chunks // Batch_size
steps_per_val_epoch = int(images_per_chunk * (1 - train_val_split_rate)) * total_chunks // Val_Batch_size
saved_models_list = os.listdir(runs_dir + model_name)
key_func = lambda x:int(x.split("_")[-1][:-3]) if x[-3:] == ".h5" else 0
saved_models_list.sort(key=key_func)
if resume_from_last:
    last_epoch = int(saved_models_list[-1].split("_")[-1][:-3])
else:
    last_epoch = -1

In [4]:
print(steps_per_epoch, steps_per_val_epoch)

6885 378


In [5]:
def quick_deformation(crumple_id, img):
    global dmaps, gmaps, root_data_dir
    g_map = np.load(root_data_dir + gmaps[crumple_id])
    d_map = np.load(root_data_dir + dmaps[crumple_id])
    ret = np.zeros_like(img)
    ret[...] = img[:, d_map[..., 0], d_map[..., 1], :] * np.expand_dims(g_map, 2)
    return ret

In [36]:
def mean_elastic_distance(y_true, y_pred):
    return tf.keras.losses.MSE(y_true, y_pred) + 80 * tf.keras.losses.MAE(y_true, y_pred)

In [7]:
def zoom_by_shape(img, z):
    res = np.zeros((1000,180,180))
    for i in range(1000):    
        x1 = int(np.random.uniform(0,z))
        x2 = z- x1
        y1 = int(np.random.uniform(0,z))
        y2 = z- y1
        cropped = img[i,x1:-x2,y1:-y2]
        res[i] = zoom(cropped, zoom = 180/(220-z))
        if res[i].shape != (180, 180):
            print(info)
            print(i, cropped.shape, res[i].shape, x1, x2, y1, y2, z)
    
    return res
    

In [8]:
def rotate_by_shape(img):
    res = np.zeros((img.shape[0], 220, 220, 1))
    for i in range(img.shape[0]):
        res[i] = rotate(img[i], angle=np.random.uniform(-5, 5), reshape=False)
    return res

In [4]:
def data_gen(Batch_size):
    while 1:    
        for i in range(140):
            first = int(np.random.uniform(140))
            first = np.load(f"Images/{first}.npy", mmap_mode="r")
            second = np.load(f"Images/{i}.npy", mmap_mode="r")
            # plt.imshow(second[0])
            second_fliped = np.flip(second, axis = 2)
            # plt.imshow(first_fliped[0])

            combined = first * .9+ second_fliped * .1
            for j in range(0,1000-Batch_size, Batch_size):
                yield combined[j:j+Batch_size],first[j:j+Batch_size]
                            

In [5]:
def val_data_gen(Batch_size):
    while 1:    
        for i in range(140,153):
            first = int(np.random.uniform(140,153))
            first = np.load(f"Images/{first}.npy", mmap_mode="r")
            second = np.load(f"Images/{i}.npy", mmap_mode="r")
            # plt.imshow(second[0])
            second_fliped = np.flip(second, axis = 2)
            # plt.imshow(first_fliped[0])

            combined = first * .9+ second_fliped * .1
            for j in range(0,1000-Batch_size, Batch_size):
                yield combined[j:j+Batch_size],first[j:j+Batch_size]
                            

In [9]:
# def data_gen(training=True, batch_size=1):
#     global e_crumple_idxs, f_crumple_idxs, total_chunks, train_val_split_rate, images_per_chunk
#     if training:
#         crumple_idxs = e_crumple_idxs
#         start_batch = 0
#         end_batch = int(images_per_chunk * train_val_split_rate)
#     else:
#         crumple_idxs = f_crumple_idxs
#         start_batch = int(images_per_chunk * train_val_split_rate)
#         end_batch = images_per_chunk
#     while 1:
#         crumple_i = 0
#         for chunk_num in range(total_chunks):
#             for j in range(start_batch, end_batch, batch_size):
#                 imgs = np.load("Images/" + str(chunk_num) + ".npy", mmap_mode="r")
#                 IMAGE_SHAPE = 180
#                 imgs = np.expand_dims(imgs[j:j + batch_size], -1)
#                 im = imgs[:,20:-20,20:-20]
#                 im_norm = np.subtract(im, 127.5)
#                 cr_id = crumple_idxs[crumple_i]
#                 cr_im = quick_deformation(cr_id, im)
#                 cr_im_norm = np.subtract(cr_im,  127.5)
#                 yield cr_im_norm, im_norm 
#                 crumple_i = (crumple_i + 1) % len(crumple_idxs)

In [14]:
if not resume_from_last and not resume:
    FACTOR = 2
    x_inp = tf.keras.layers.Input(shape=(180, 180, 1))
    padding_layer = tf.keras.layers.ZeroPadding2D(padding=(2, 2))(x_inp) # 184
    c1 = tf.keras.layers.Conv2D(int(24 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(padding_layer) # 184
    c_dial_1 = tf.keras.layers.Conv2D(int(8 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(padding_layer)
    c1_out = tf.keras.layers.Concatenate(axis = 3)([c1, c_dial_1])
    c2 = tf.keras.layers.Conv2D(int(12 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c1_out) # 184
    c_dial_2 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(c1_out)
    c2_out = tf.keras.layers.Concatenate(axis = 3)([c2, c_dial_2])
    c3 = tf.keras.layers.Conv2D(int(12 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c2_out) # 184
    c_dial_3 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(c2_out)
    c3_out = tf.keras.layers.Concatenate(axis = 3)([c3, c_dial_3])
    c4 = tf.keras.layers.Conv2D(int(12 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c3_out) # 184
    c_dial_4 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(c3_out)
    c4_out = tf.keras.layers.Concatenate(axis = 3)([c4, c_dial_4])
    c5 = tf.keras.layers.Conv2D(int(12 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c4_out) # 184
    c_dial_5 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(c4_out)
    c5_out = tf.keras.layers.Concatenate(axis = 3)([c5, c_dial_5])
    p1 = tf.keras.layers.MaxPooling2D()(c5_out)
    c5 = tf.keras.layers.Conv2D(int(16 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(p1) # 92
    c6 = tf.keras.layers.Conv2D(int(16 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c5) # 92
    p2 = tf.keras.layers.MaxPooling2D()(c6)
    c7 = tf.keras.layers.Conv2D(int(32 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(p2) # 46
    c8 = tf.keras.layers.Conv2D(int(32 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c7) # 46
    p3 = tf.keras.layers.MaxPooling2D()(c8)
    c9 = tf.keras.layers.Conv2D(int(64 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(p3) # 23
    c10 = tf.keras.layers.Conv2D(int(64 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c9) # 23
    d11 = tf.keras.layers.Conv2D(int(64 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c10) # 23
    d10 = tf.keras.layers.Conv2D(int(64 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(d11) # 23
    u4 = tf.keras.layers.UpSampling2D()(d10)
    cc4 = tf.keras.layers.Concatenate()([u4, c8])
    d9 = tf.keras.layers.Conv2D(int(32 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(c8) # 46
    d8 = tf.keras.layers.Conv2D(int(32 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(d9) # 46
    u3 = tf.keras.layers.UpSampling2D()(d8)
    cc3 = tf.keras.layers.Concatenate()([u3, c6])
    d7 = tf.keras.layers.Conv2D(int(32 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(cc3) # 92
    d6 = tf.keras.layers.Conv2D(int(16 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(d7) # 92
    u2 = tf.keras.layers.UpSampling2D()(d6)
    cc2 = tf.keras.layers.Concatenate()([u2, c5_out])
    d5 = tf.keras.layers.Conv2D(int(24 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(cc2) # 184
    d_dial_5 = tf.keras.layers.Conv2D(int(8 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(cc2)
    d5_out = tf.keras.layers.Concatenate(axis = 3)([d5, d_dial_5])
    d4 = tf.keras.layers.Conv2D(int(12 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(d5_out) # 184
    d_dial_4 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(d5_out)
    d4_out = tf.keras.layers.Concatenate(axis = 3)([d4, d_dial_4])
    d3 = tf.keras.layers.Conv2D(int(12 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(d4_out) # 184
    d_dial_3 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(d4_out)
    d3_out = tf.keras.layers.Concatenate(axis = 3)([d3, d_dial_3])
    d2 = tf.keras.layers.Conv2D(int(8 * FACTOR), (3, 3), padding="same", activation=tf.nn.leaky_relu)(d3_out) # 184
    d_dial_2 = tf.keras.layers.Conv2D(int(4 * FACTOR), (3, 3), dilation_rate=1, padding="same", activation=tf.nn.leaky_relu)(d3_out)
    d2_out = tf.keras.layers.Concatenate(axis = 3)([d2, d_dial_2])
    d1 = tf.keras.layers.Conv2D(1, (3, 3), padding="same", activation=None)(d2) # 184
    crop_layer = tf.keras.layers.Cropping2D(cropping=((2, 2), (2, 2)))(sig) # 180
    model = tf.keras.models.Model(inputs=[x_inp], outputs=[crop_layer])
    model.compile(optimizer=tf.keras.optimizers.Adam(0.0003), loss=mean_elastic_distance)
elif resume and not resume_from_last:
    target = glob.glob(runs_dir + "/" + resume_model_name + "/*" + resume_epoch + "*")[0]
    model = tf.keras.models.load_model(target, custom_objects={"mean_elastic_distance":mean_elastic_distance})
else:
    model = tf.keras.models.load_model(runs_dir + "/" + model_name + "/" + os.listdir(runs_dir + model_name)[-2], custom_objects={"mean_elastic_distance":mean_elastic_distance})

In [15]:
# for i in model.layers[-17:]:
#     i.trainable = Falseresume_from_last

In [16]:
model.summary(show_trainable=True)

Model: "model_1"
_____________________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     Trainable  
 input_2 (InputLayer)           [(None, 180, 180, 1  0           []                               Y          
                                )]                                                                           
                                                                                                             
 zero_padding2d_1 (ZeroPadding2  (None, 184, 184, 1)  0          ['input_2[0][0]']                Y          
 D)                                                                                                          
                                                                                                             
 conv2d_32 (Conv2D)             (None, 184, 184, 48  480         ['zero_padding2d_1[0][0]']       Y    

In [17]:
tr_gen = data_gen(Batch_size)
va_gen = val_data_gen(Batch_size*2)
for e in range(last_epoch + 1, Epochs):
    print("Epoch", e + 1, "/", Epochs, ":")
    model.fit(x=tr_gen, 
              validation_data=va_gen, 
              batch_size=Batch_size, 
              epochs=1, 
              steps_per_epoch=steps_per_epoch, 
              validation_steps=steps_per_val_epoch)
    model.save(model_dir + model_name + "_" + str(e) + ".h5")
    dg_tr =  data_gen(1)
    dg_te =  val_data_gen(1)
    # plt.figure(figsize=(8, 3))
    fig, axs = plt.subplots(3, 8, figsize=(30, 12))
    for i in range(4):
        a, b = dg_tr.send(None)
        axs[0, i].set_title("Train Input " + str(i + 1))
        axs[0, i].imshow(a[0], cmap='gray')
        axs[0, i].axis('off')
        axs[1, i].set_title("Train GT " + str(i + 1))
        axs[1, i].imshow(b[0], cmap='gray')
        axs[1, i].axis('off')
        axs[2, i].set_title("Train Output " + str(i + 1))
        axs[2, i].imshow(model.predict(a, verbose=0)[0], cmap='gray')
        axs[2, i].axis('off')
        c, d = dg_te.send(None)
        axs[0, 4 + i].set_title("Test Input " + str(i + 1))
        axs[0, 4 + i].imshow(c[0], cmap='gray')
        axs[0, 4 + i].axis('off')
        axs[1, 4 + i].set_title("Test GT " + str(i + 1))
        axs[1, 4 + i].imshow(d[0], cmap='gray')
        axs[1, 4 + i].axis('off')
        axs[2, 4 + i].set_title("Test Output " + str(i + 1))
        axs[2, 4 + i].imshow(model.predict(c, verbose=0)[0], cmap='gray')
        axs[2, 4 + i].axis('off')
    fig.savefig(model_dir + model_name + "_" + str(e) + ".png", bbox_inches='tight')
    plt.close()

Epoch 1 / 150 :
 331/6885 [>.............................] - ETA: 38:48 - loss: 24444.9082

KeyboardInterrupt: 

In [20]:
x, y = next(iter(tr_gen))

In [32]:
x[0].squeeze().shape

(180, 180)

In [35]:
print(tf.losses.mse(x.flatten(), y.flatten()))

tf.Tensor(10657.464691358025, shape=(), dtype=float64)
