# ImageNet Partial Convolution
----
Adapted from the PConv-Keras ImageNet training notebook (Step 4): https://github.com/MathiasGruber/PConv-Keras/blob/master/notebooks/Step4%20-%20Imagenet%20Training.ipynb

In [None]:
import os
import gc
import datetime
import numpy as np
import pandas as pd
import cv2

from copy import deepcopy
from tqdm import tqdm

from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import TensorBoard, ModelCheckpoint, LambdaCallback
from keras import backend as K
from keras.utils import Sequence
from keras_tqdm import TQDMNotebookCallback

import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
from IPython.display import clear_output

In [None]:
# # Change to root path
# if os.path.basename(os.getcwd()) != 'PConv-Keras':
#     os.chdir(r"/mnt/data/PConv-Keras")    
# import sys
# sys.path.append(r"/mnt/data/PConv-Keras")

from libs.pconv_model import PConvUnet
from libs.util import MaskGenerator

In [None]:
# SETTINGS
# Root of the ImageNet CLS-LOC image tree. Set the IMAGENET_DIR environment
# variable to run on another machine without editing the notebook.
IMAGENET_DIR = os.environ.get(
    "IMAGENET_DIR", r"/mnt/data/data/imagenet/ILSVRC/Data/CLS-LOC"
)

# Training images live in <root>/train; Keras infers classes from its sub-folders.
TRAIN_DIR = os.path.join(IMAGENET_DIR, "train")
# Validation / test generators point at the root itself and select their split
# with `classes=['val']` / `classes=['test']`, so both intentionally equal the root.
VAL_DIR = os.path.join(IMAGENET_DIR, "")
TEST_DIR = os.path.join(IMAGENET_DIR, "")

# Number of images per batch, shared by all generators.
BATCH_SIZE = 4

## Creating the train, validation & test data generators

In [None]:
class AugmentingDataGenerator(ImageDataGenerator):
    """ImageDataGenerator that yields masked-image inpainting batches.

    Wraps Keras' ``flow_from_directory`` so that, instead of (image, label)
    pairs, each step yields ``([masked, mask], ori)``:

    * ``ori``    -- the (augmented) image batch, also used as the target;
    * ``mask``   -- one mask per image, stacked along a new leading axis;
    * ``masked`` -- a copy of ``ori`` with every element where ``mask == 0``
      set to 1.
    """

    def flow_from_directory(self, directory, mask_generator, *args, **kwargs):
        """Yield ``([masked, mask], ori)`` batches indefinitely.

        Args:
            directory: image root, forwarded to Keras' ``flow_from_directory``.
            mask_generator: object whose ``sample(seed)`` returns a single mask.
                Presumably the mask has the same shape as one image, since it
                is used to index ``ori`` element-wise -- TODO confirm.
            *args, **kwargs: forwarded to Keras' ``flow_from_directory``.
                ``class_mode`` is always forced to ``None`` (no labels are
                needed: the target is the image itself).
        """
        generator = super().flow_from_directory(
            directory, *args, class_mode=None, **kwargs
        )
        seed = kwargs.get('seed')
        while True:
            # Get augmented image samples
            ori = next(generator)

            # Draw one mask per image in the batch.
            # NOTE(review): the same `seed` is passed for every sample; if
            # MaskGenerator.sample re-seeds its RNG on each call, every mask
            # in a seeded (val/test) stream is identical -- confirm intended.
            mask = np.stack(
                [mask_generator.sample(seed) for _ in range(ori.shape[0])],
                axis=0,
            )

            # Fill the masked-out regions (mask == 0) with 1 on a copy, so
            # `ori` stays untouched as the training target. With rescale=1/255
            # this paints the holes white.
            masked = ori.copy()
            masked[mask == 0] = 1

            # Explicit GC pass every batch, presumably to limit memory growth
            # from the large 512x512 batches.
            gc.collect()
            yield [masked, mask], ori

In [None]:
# Training generator: images are scaled to [0, 1] and randomly rotated (up to
# 10 degrees), shifted (up to 10%) and horizontally flipped; each image is
# paired with a mask drawn from a 512x512x3 MaskGenerator.
train_datagen = AugmentingDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    mask_generator=MaskGenerator(512, 512, 3),
    batch_size=BATCH_SIZE,
    target_size=(512, 512),
)

In [None]:
# Validation generator: no augmentation, only [0, 1] scaling. Reads the 'val'
# sub-folder of VAL_DIR with a fixed seed so runs are comparable.
val_datagen = AugmentingDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory(
    VAL_DIR,
    mask_generator=MaskGenerator(512, 512, 3),
    classes=['val'],
    seed=42,
    batch_size=BATCH_SIZE,
    target_size=(512, 512),
)

In [None]:
# Test generator: no augmentation, only [0, 1] scaling. Reads the 'test'
# sub-folder of TEST_DIR with a fixed seed so the sample batch is reproducible.
test_datagen = AugmentingDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
    TEST_DIR,
    mask_generator=MaskGenerator(512, 512, 3),
    classes=['test'],
    seed=42,
    batch_size=BATCH_SIZE,
    target_size=(512, 512),
)

In [None]:
# Pick out one test batch. `masked`, `mask` and `ori` are module-level on
# purpose: `plot_callback` reuses them to visualise progress on the same images.
test_data = next(test_generator)
(masked, mask), ori = test_data

# Show masked input, mask and original side by side for every image in the batch
for i in range(len(ori)):
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))
    axes[0].imshow(masked[i, :, :, :])
    # `* 1.` presumably promotes a non-float mask to float for imshow -- confirm dtype
    axes[1].imshow(mask[i, :, :, :] * 1.)
    axes[2].imshow(ori[i, :, :, :])
    axes[0].set_title('Masked Image')
    axes[1].set_title('Mask')
    axes[2].set_title('Original Image')
    plt.show()
    # Release the figure explicitly so repeated runs do not accumulate figures
    plt.close(fig)

# Training on ImageNet

In [None]:
def plot_callback(model, folder, masked_imgs=None, masks=None, originals=None):
    """Save side-by-side inpainting results for a fixed batch to disk.

    Meant to be run at the end of each epoch (e.g. from a LambdaCallback).
    For every image in the batch it writes one PNG to ``folder``. Each PNG
    shows the masked input, the model's prediction and the original. Nothing
    is displayed inline; each figure is closed right after saving.

    Args:
        model: model whose ``predict([masked, mask])`` returns an image batch.
        folder: existing directory the PNGs are written into.
        masked_imgs, masks, originals: the batch to visualise. When omitted,
            they default to the notebook-level ``masked``, ``mask`` and ``ori``
            taken from the test generator, so ``plot_callback(model, folder)``
            keeps working as before.
    """
    if masked_imgs is None:
        masked_imgs = masked
    if masks is None:
        masks = mask
    if originals is None:
        originals = ori

    pred_img = model.predict([masked_imgs, masks])
    # One timestamp per call, so all images from the same epoch share it
    pred_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    for i in range(len(originals)):
        fig, axes = plt.subplots(1, 3, figsize=(20, 5))
        axes[0].imshow(masked_imgs[i, :, :, :])
        # Predictions may fall outside [0, 1]; clip explicitly instead of
        # relying on imshow's implicit clipping (which also emits a warning).
        axes[1].imshow(np.clip(pred_img[i, :, :, :] * 1., 0., 1.))
        axes[2].imshow(originals[i, :, :, :])
        axes[0].set_title('Masked Image')
        axes[1].set_title('Predicted Image')
        axes[2].set_title('Original Image')

        fig.savefig(os.path.join(folder, f"00img_{i}_{pred_time}.png"))
        plt.close(fig)


In [None]:
# Lightweight progress output for training with verbose=0: an "Epoch N:" header,
# one dot per batch, and a line break when the epoch ends.
batch_print_callback = LambdaCallback(
    on_batch_begin=lambda batch, logs: print(".", end='', flush=True),
    on_epoch_end=lambda epoch, logs: print(""),
    on_epoch_begin=lambda epoch, logs: print(f"Epoch {epoch}:", end='', flush=True),
)

### Phase 1 - with batch normalization

In [None]:
# Instantiate the model, passing the path of the converted VGG16 weights file
model = PConvUnet(vgg_weights=r"/mnt/data/train_camp/pconv_keras_imagenet/pytorch_to_keras_vgg16.h5")
# Output folder for TensorBoard logs and checkpoints; sample images go to a sub-folder
FOLDER = r'/mnt/data/train_camp/pconv_keras_imagenet/imagenet_phase1_paperMasks'
TEST_SAMPLE_FOLDER = os.path.join(FOLDER, 'test_samples')
# exist_ok=True replaces the check-then-create pattern: no race, and re-running
# the cell is harmless.
os.makedirs(TEST_SAMPLE_FOLDER, exist_ok=True)

In [None]:
# Run training for certain amount of epochs
model.fit_generator(
    train_generator, 
    steps_per_epoch=10000,
    validation_data=val_generator,
    validation_steps=1000,
    epochs=50,  
    verbose=0,
    callbacks=[
        TensorBoard(
            log_dir=FOLDER,
            write_graph=False
        ),
        ModelCheckpoint(
            FOLDER+'weights.{epoch:02d}-{loss:.2f}.h5',
            monitor='val_loss', 
            save_best_only=True, 
            save_weights_only=True
        ),
        batch_print_callback,
        LambdaCallback(
            on_epoch_end=lambda epoch, logs: plot_callback(model, TEST_SAMPLE_FOLDER)
        )#,TQDMNotebookCallback()
    ]
)

In [None]:
# Completion marker, printed once the training cell has returned
print("FINISHED!!!!!", end="\n")