# CelebA – phase 2
---

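Phase 2: resume PConv U-Net training on CelebA from a saved checkpoint, this time without training batch normalization (`train_bn=False`).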

In [1]:
import os
import pandas as pd
import datetime
from pconv_keras.util import MaskGenerator
from pconv_keras.generator import AugmentingDataGenerator
from pconv_keras.pconv_model import PConvUnet

from keras.callbacks import TensorBoard, ModelCheckpoint, LambdaCallback

import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
from IPython.display import clear_output

# Images per batch, shared by the train / validation / test generators below.
BATCH_SIZE = 16

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


## Create train, validation & test generators

In [2]:
# Create training generator: random rotation / shift / horizontal-flip
# augmentation, with pixel values rescaled by 1/255.
train_datagen = AugmentingDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    rescale=1./255,
    horizontal_flip=True,
)
train_generator = train_datagen.flow_from_dataframe(
    pd.read_csv("/mnt/data/data/celeba/train.csv"),
    MaskGenerator(256, 256, 3),
    folder='/mnt/data/data/celeba/',
    target_size=(256, 256),
    batch_size=BATCH_SIZE,
    # Seed the iterator the same way the validation/test generators are seeded,
    # so shuffling is reproducible across runs. NOTE(review): any randomness
    # internal to MaskGenerator is not necessarily covered by this seed.
    seed=42,
)

Found 141819 validated image filenames.


In [3]:
# Validation generator: rescaling only, no augmentation.
# seed=42 is passed so the iterator is seeded (presumably for reproducible
# validation batches — confirm against AugmentingDataGenerator).
val_datagen = AugmentingDataGenerator(rescale=1./255)
val_df = pd.read_csv("/mnt/data/data/celeba/val.csv")
val_generator = val_datagen.flow_from_dataframe(
    val_df,
    MaskGenerator(256, 256, 3),
    folder='/mnt/data/data/celeba/',
    target_size=(256, 256),
    batch_size=BATCH_SIZE,
    seed=42,
)

Found 20260 validated image filenames.


In [4]:
# Test generator: rescaling only, no augmentation, seeded like the
# validation generator. Its first batch is used for the sample plots below.
test_datagen = AugmentingDataGenerator(rescale=1./255)
test_df = pd.read_csv("/mnt/data/data/celeba/test.csv")
test_generator = test_datagen.flow_from_dataframe(
    test_df,
    MaskGenerator(256, 256, 3),
    folder='/mnt/data/data/celeba/',
    target_size=(256, 256),
    batch_size=BATCH_SIZE,
    seed=42,
)

Found 40520 validated image filenames.


## Plotting

In [5]:
# Pick out one fixed batch, structured as ((masked images, masks), originals).
# plot_callback reads the module-level names masked / mask / ori.
# Chained assignment: binds test_data first, then unpacks the same object.
test_data = (masked, mask), ori = next(test_generator)

In [28]:
def plot_callback(model, folder, masked_imgs=None, masks=None, originals=None):
    """Save masked / predicted / original comparison images for one batch.

    Runs ``model.predict([masked_imgs, masks])`` and writes one PNG per sample
    into ``folder``, named ``00img_<index>_<timestamp>.png``. Figures are closed
    after saving, so nothing is shown inline. Used both directly and as an
    end-of-epoch ``LambdaCallback`` hook.

    Args:
        model: object exposing ``predict([masked_imgs, masks])``.
        folder: existing directory the PNGs are written into.
        masked_imgs, masks, originals: optional batch to visualise. Each one
            defaults to the notebook-level ``masked`` / ``mask`` / ``ori``
            (the "Pick out an example" cell), preserving the old behaviour.
    """
    # Fall back to the notebook globals only when no explicit batch is given.
    if masked_imgs is None:
        masked_imgs = masked
    if masks is None:
        masks = mask
    if originals is None:
        originals = ori

    pred_imgs = model.predict([masked_imgs, masks])
    # Second-resolution timestamp keeps files from successive calls (epochs) apart.
    pred_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    titles = ('Masked Image', 'Predicted Image', 'Original Image')

    for i in range(len(originals)):
        fig, axes = plt.subplots(1, 3, figsize=(20, 5))
        panels = (masked_imgs[i, :, :, :], pred_imgs[i, :, :, :], originals[i, :, :, :])
        for ax, img, title in zip(axes, panels, titles):
            ax.imshow(img)
            ax.set_title(title)
        fig.savefig(os.path.join(folder, f"00img_{i}_{pred_time}.png"))
        # Close this exact figure so figures don't pile up across epochs.
        plt.close(fig)


### Phase 2 - without batch normalization

In [8]:
# Rebuild the network, then restore weights from a previous run.
# vgg_weights: VGG16 weights converted from PyTorch (presumably used by the
# loss — confirm in pconv_model.PConvUnet).
model = PConvUnet(
    img_rows=256,
    img_cols=256,
    vgg_weights=r"/mnt/data/train_camp/pconv_keras_imagenet/pytorch_to_keras_vgg16.h5",
)
# Resume from the epoch-01 checkpoint of the earlier phase-2 run.
# Earlier, this cell started from the phase-1 checkpoint instead:
#   /mnt/data/train_camp/pconv_keras_celeba/imagenet_phase1_paperMasks/weights.09-1.23.h5
model.load(
    r"/mnt/data/train_camp/pconv_keras_celeba/imagenet_phase2/weights.01-1.30.h5",
    train_bn=False,  # phase 2: without batch-normalization training (see title)
    lr=5e-5,
)

In [None]:
# Output folder for this run: TensorBoard logs, checkpoints and sample images.
FOLDER = r'/mnt/data/train_camp/pconv_keras_celeba/imagenet_phase2/'
TEST_SAMPLE_FOLDER = os.path.join(FOLDER, 'test_samples')
os.makedirs(TEST_SAMPLE_FOLDER, exist_ok=True)

# Baseline samples from the freshly loaded weights, before further training.
plot_callback(model, TEST_SAMPLE_FOLDER)

In [None]:
# Per-epoch callbacks: TensorBoard logging, best-val_loss checkpointing and
# sample-image export for the fixed test batch.
training_callbacks = [
    TensorBoard(log_dir=FOLDER, write_graph=False),
    # NOTE(review): save_best_only selects on val_loss, but the filename embeds
    # the *training* loss; consider '{val_loss:.2f}' so names reflect the metric
    # actually used for selection.
    # NOTE(review): the checkpoint resumed above also lives in FOLDER, and epoch
    # numbering restarts at 01 here (no initial_epoch), so a new checkpoint can
    # overwrite it if the formatted names collide.
    ModelCheckpoint(
        os.path.join(FOLDER, 'weights.{epoch:02d}-{loss:.2f}.h5'),
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
    ),
    LambdaCallback(
        on_epoch_end=lambda epoch, logs: plot_callback(model, TEST_SAMPLE_FOLDER),
    ),
]

# Run training for a fixed number of epochs.
model.fit_generator(
    train_generator,
    steps_per_epoch=len(train_datagen.generator),
    validation_data=val_generator,
    validation_steps=len(val_datagen.generator),
    epochs=10,
    callbacks=training_callbacks,
)

In [None]:
# Completion marker, printed once the training cell above has returned
# (when the notebook is run top to bottom).
done_message = "FINISHED!!!!!"
print(done_message)