In [1]:
import tensorflow as tf
import tensorflow_io as tfio
import keras
from tensorflow.keras import layers

from keras import Model
from keras.utils import plot_model
import PIL
from PIL import Image, ImageFilter
import numpy as np
import os
import pathlib
import glob

from models import *
from training_utils import AEMonitor

from tfrecord_utils import _parse_function, parse_tfrecord_fn, parse_tfrecord_fn_yuv, load_dataset, get_dataset

In [2]:
# Sanity check before training: show the GPU devices TensorFlow has detected.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Color space switch: set `yuv` to False for RGB, True for YCbCr.
yuv = False

# Image mode string matching the selected color space.
mode = "YCbCr" if yuv else "RGB"

In [4]:
# Get Dataset
# Forward the notebook-level `yuv` flag instead of a hard-coded False, so this
# call follows the same switch that selects `mode`.
ds = get_dataset('/PATH/TO/TFRECORD/training_data.tfrecord', shuffle=256, batch=4, yuv=yuv)

In [5]:
# Checkpoint to produce frames to monitor progress
# Optional
#
# Builds one fixed input of five consecutive frames, stacked along the channel
# axis (5 frames x 3 channels = 15), and hands it to AEMonitor together with
# the directory it should use.
sampled_frames = sorted(glob.glob('/PATH/TO/FRAMES/*'))
sampled_frames = sampled_frames[2500:]
prediction_array = np.zeros([512, 512, 15])
for i in range(5):
    # The context manager closes the source file as soon as the converted,
    # resized copy exists, so no file handles are left open.
    with Image.open(sampled_frames[10 + i]) as frame:
        img = frame.convert(mode).resize((512, 512))
    # Scale 8-bit pixel values down by 255 and write them into this frame's
    # three-channel slot.
    prediction_array[:, :, i * 3:(i + 1) * 3] = np.array(img) / 255

monitor_dir = 'monitoring/monitor_dir'
# makedirs also creates the missing 'monitoring' parent and tolerates the
# directory already existing (os.mkdir fails on a missing parent).
os.makedirs(monitor_dir, exist_ok=True)

# NOTE(review): the meaning of the trailing `60, True` arguments is defined in
# training_utils.AEMonitor, which is not visible here — confirm before changing.
monitor = AEMonitor(np.expand_dims(prediction_array, 0), monitor_dir, 60, True)

In [6]:
# Weight checkpoint
# Saves weights only, every `checkpoint_every_n_batches` batches. The
# `{epoch:04d}` placeholder gives each save its own file name and nothing here
# removes older files, so checkpoints accumulate on disk for the whole run
# (the run recorded in this notebook stopped with "No space left on device").
checkpoint_every_n_batches = 256 * 10

weight_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='model_weights/model_name{epoch:04d}',
    monitor="val_loss",
    verbose=1,
    save_best_only=False,
    save_weights_only=True,
    mode="auto",
    save_freq=checkpoint_every_n_batches,
    options=None,
)

In [7]:
# Custom Loss Functions
def custom_loss(y_true, y_pred):
    """Return mean squared error plus categorical cross-entropy.

    Both terms are computed by the corresponding `keras.losses` functions on
    the same `(y_true, y_pred)` pair and then added together.
    """
    mse_term = keras.losses.mean_squared_error(y_true, y_pred)
    cross_entropy_term = keras.losses.categorical_crossentropy(y_true, y_pred)
    return mse_term + cross_entropy_term

def flood_loss(y_true, y_pred):
    """Return the MSE reflected around a fixed level of 0.01.

    Computes `|mse - level| + level`: values of the MSE above the level pass
    through unchanged, values below it are mirrored back above it.
    """
    flood_level = 0.01
    mse = keras.losses.mean_squared_error(y_true, y_pred)
    distance_from_level = keras.backend.abs(mse - flood_level)
    return distance_from_level + flood_level

def sin_loss(y_true, y_pred):
    """Return the mean squared error plus the sine of that same error."""
    base_error = keras.losses.mean_squared_error(y_true, y_pred)
    sine_term = keras.backend.sin(base_error)
    return base_error + sine_term

# Exponentially decaying learning-rate schedule.
# NOTE(review): the optimizer built later in this notebook is given a constant
# learning rate, so this schedule appears to be unused — confirm intent.
lr = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.000002,
    decay_steps=500,
    decay_rate=0.96,
    staircase=False,
    name=None,
)


In [8]:
# Build a throwaway instance of the model just to print its layer summary; the
# instance is not kept. `make_deep_model_4` is presumably provided by the
# star-import from `models` — verify.
make_deep_model_4().summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              [(None, 512, 512, 15 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 8)  3008        image[0][0]                      
__________________________________________________________________________________________________
p_re_lu (PReLU)                 (None, 512, 512, 8)  2097152     conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 8)  32          p_re_lu[0][0]                    
_______________________________________________________________________________________

In [9]:
import random
def scheduler(epoch, lr):
    """Pick the learning rate for an epoch at random from three fixed values.

    Both arguments are ignored; they are accepted only so the function matches
    the `(epoch, lr)` signature it is called with.
    """
    candidate_rates = [2e-5, 2e-6, 2e-7]
    return random.choice(candidate_rates)
    
# Wrap `scheduler` in a Keras learning-rate callback.
# NOTE(review): the `fit` call in this notebook passes only `monitor` and
# `weight_checkpoint`, so this callback appears to be unused — confirm intent.
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

In [11]:
# Build, compile and train the autoencoder with the custom loss.
opt = keras.optimizers.Adam(learning_rate=0.000001)#, beta_1=0.95, beta_2=0.85)
loss_fn = custom_loss

ae = make_deep_model_4()
ae.compile(opt, loss_fn)

# summary() prints the table itself and returns None, so wrapping it in
# print() only appends a stray "None" line.
ae.summary()

# Plotting is best-effort (it needs optional graphing dependencies), but only
# ordinary errors are swallowed: a bare `except:` would also trap
# KeyboardInterrupt/SystemExit and hide why plotting failed.
try:
    plot_model(ae, to_file="model4.png")
except Exception as err:
    print("Could not plot:", err)

# Resume from previously saved weights; this raises if the checkpoint is absent.
ae.load_weights('model_weights/model_name')

print("Fitting Model")
ae.fit(ds, epochs=1000, callbacks=[monitor, weight_checkpoint], steps_per_epoch=256)

Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              [(None, 512, 512, 15 0                                            
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 512, 512, 8)  3008        image[0][0]                      
__________________________________________________________________________________________________
p_re_lu_50 (PReLU)              (None, 512, 512, 8)  2097152     conv2d_38[0][0]                  
__________________________________________________________________________________________________
batch_normalization_66 (BatchNo (None, 512, 512, 8)  32          p_re_lu_50[0][0]                 
_______________________________________________________________________________________

Fitting Model
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 00005: saving model to model_weights/deep_model_4_road_0005
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 00015: saving model to model_weights/deep_model_4_road_0015
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 00025: saving model to model_weights/deep_model_4_road_0025
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 00035: saving model to model_weights/deep_model_4_road_0035
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 00045: saving model to model_weights/deep_model_4_road_0045
Epoch 46/1000
Epoch 47/1000
Epoch 4

ResourceExhaustedError: model_weights/deep_model_4_road_0115_temp_056b9384bfb44660a394d4941fe03cff/part-00000-of-00001.data-00000-of-00001.tempstate499389358356234605; No space left on device [Op:SaveV2]