In [6]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [7]:
import cv2, os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential

from datetime import datetime
from IPython.display import clear_output
from typing import Tuple, List, Union, Optional

from rich import pretty
from rich.console import Console
from rich.traceback import install
from rich.progress import Progress
from rich import inspect

from models.MobileNetDecoder import MobileNetDecoder
from utils.image_handler import *

console = Console()
pretty.install()
install() # install rich traceback
print = console.print
tf.config.list_physical_devices('GPU')

In [8]:
BATCH_SIZE = 16
INPUT_SHAPE = (128, 128, 3)
EPOCHS = 100
SAVE_PERIOD = 10
LR = 1e-3,
GAMMA = .7 #scheduler decay rate
DATA_PATH = 'datasets/subflickr/'
MODEL_SAVE_PATH = 'saved_models/weights/'
LOG_PATH = 'runs/MOBILENET_DECODER_100EP_1'

data_gen_args = dict(
    brightness_range=[0.5, 1.2],
    horizontal_flip=True,
    rescale=1/255,
    fill_mode='reflect',
    data_format='channels_last'
)
data_flow_args = dict(
    target_size=INPUT_SHAPE[:-1],
    batch_size=BATCH_SIZE,
    class_mode='input') # Since we want to reconstruct the input

In [9]:
train_datagen = ImageDataGenerator(**data_gen_args)
val_datagen = ImageDataGenerator(**data_gen_args)
test_datagen = ImageDataGenerator(**data_gen_args)

train_batches = train_datagen.flow_from_directory(
    os.path.join(os.path.abspath(DATA_PATH), 'train'),
    **data_flow_args)

val_batches = val_datagen.flow_from_directory(
    os.path.join(os.path.abspath(DATA_PATH), 'val'),
    **data_flow_args)

test_batches = val_datagen.flow_from_directory(
    os.path.join(os.path.abspath(DATA_PATH), 'test'),
    **data_flow_args)

train_gen_batches = generator(train_batches, noise_sd=0)
val_gen_batches = generator(val_batches, noise_sd=0)
test_gen_batches = generator(test_batches, noise_sd=0)

Found 3520 images belonging to 1 classes.
Found 880 images belonging to 1 classes.
Found 1100 images belonging to 1 classes.


In [10]:
model = MobileNetDecoder(
    shape=INPUT_SHAPE,
    dropout=.2
)
model.build((None, *INPUT_SHAPE))
model.summary()
model.encode_decode_summary()
#scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
#    LR,
#    decay_steps=len(train_batches),
#    decay_rate=GAMMA)
#model.compile(optimizer=Adam(scheduler), loss='mse', metrics=['accuracy'])
model.compile(optimizer=Adam(lr=1e-3), loss='mse', metrics=['accuracy'])

Model: "mobile_net_decoder_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Sequential)         (None, 5120)              2257984   
_________________________________________________________________
decoder (Sequential)         (None, 128, 128, 3)       3048476   
Total params: 5,306,460
Trainable params: 5,272,100
Non-trainable params: 34,360
_________________________________________________________________
Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
mobilenetv2_1.00_128 (Functi (None, 4, 4, 1280)        2257984   
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 2, 2, 1280)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 5120)              0         
Total params: 

In [6]:
# Callbacks
saved_weight = os.path.join(MODEL_SAVE_PATH, 'weights.{epoch:02d}-{val_accuracy:.2f}.hdf5')
modelchk = tf.keras.callbacks.ModelCheckpoint(saved_weight, 
                                      monitor='val_accuracy', 
                                      verbose=1,
                                      save_best_only=True, 
                                      save_weights_only=False,
                                      mode='auto',
                                      save_freq='epoch') # save models every epoch
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=LOG_PATH,
                                          histogram_freq=0,
                                          write_graph=True,
                                          write_images=True)

In [None]:
hist = model.fit(train_gen_batches,
            steps_per_epoch = train_batches.samples // BATCH_SIZE,
            epochs=EPOCHS,
            verbose=1, 
            validation_data=val_gen_batches,
            validation_steps = train_batches.samples // BATCH_SIZE,
            callbacks=[modelchk, tensorboard],
            use_multiprocessing=False).history

Epoch 1/100

Epoch 00001: val_accuracy improved from -inf to 0.37180, saving model to saved_models/weights/weights.01-0.37.hdf5
Epoch 2/100

Epoch 00002: val_accuracy improved from 0.37180 to 0.37468, saving model to saved_models/weights/weights.02-0.37.hdf5
Epoch 3/100

Epoch 00003: val_accuracy improved from 0.37468 to 0.37858, saving model to saved_models/weights/weights.03-0.38.hdf5
Epoch 4/100

Epoch 00004: val_accuracy improved from 0.37858 to 0.38718, saving model to saved_models/weights/weights.04-0.39.hdf5
Epoch 5/100

Epoch 00005: val_accuracy improved from 0.38718 to 0.38762, saving model to saved_models/weights/weights.05-0.39.hdf5
Epoch 6/100

Epoch 00006: val_accuracy did not improve from 0.38762
Epoch 7/100

Epoch 00007: val_accuracy did not improve from 0.38762
Epoch 8/100

Epoch 00008: val_accuracy did not improve from 0.38762
Epoch 9/100

Epoch 00009: val_accuracy did not improve from 0.38762
Epoch 10/100

Epoch 00010: val_accuracy did not improve from 0.38762
Epoch 1