In [None]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard 
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import onnx
import keras2onnx
import os
import random
import tensorflow as tf
from keras.callbacks import EarlyStopping,ModelCheckpoint
import pandas as pd



In [None]:
# Dataset selection: images are read from images/{Train,Valid}/<category>_all/<subcategory>.
category, subcategory = 'Plant', 'pineapple'

def load_images_from_directory(directory_path, target_size=None):
    """Load every image file in a directory into one float array scaled to [0, 1].

    Files are read in sorted filename order so the resulting array is the same
    on every run and platform (``os.listdir`` order is arbitrary). Hidden
    entries (names starting with '.') and sub-directories are skipped so stray
    files such as ``.DS_Store`` do not make ``load_img`` fail.

    Args:
        directory_path: Directory containing the image files.
        target_size: Optional ``(height, width)`` forwarded to ``load_img`` to
            resize every image. ``None`` (default) keeps each image's original
            size, in which case all images must already share one size so they
            can be stacked into a single array.

    Returns:
        float64 ``numpy`` array with one entry per image (presumably
        ``(n_images, height, width, channels)`` under Keras' default
        channels_last format), pixel values divided by 255.

    Raises:
        ValueError: if the directory contains no candidate image files.
    """
    filenames = sorted(
        name for name in os.listdir(directory_path)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory_path, name))
    )
    if not filenames:
        # Fail here with a clear message instead of returning an empty array
        # that only breaks much later inside training.
        raise ValueError('No image files found in %r' % (directory_path,))

    image_data = [
        img_to_array(load_img(os.path.join(directory_path, name), target_size=target_size))
        for name in filenames
    ]
    return 1.0 / 255 * np.array(image_data, dtype=float)

# Dataset directories, e.g. images/Train/Plant_all/pineapple. os.path.join
# produces a single, platform-appropriate separator between components.
train_images_path = os.path.join('images', 'Train', category + '_all', subcategory)
valid_images_path = os.path.join('images', 'Valid', category + '_all', subcategory)
Xtrain = load_images_from_directory(train_images_path)
Xvalid = load_images_from_directory(valid_images_path)


# Load InceptionResNetV2 with ImageNet weights. include_top=True keeps the
# 1000-way classifier head, whose output is the global embedding fed into
# the fusion layer (embed_input has shape (1000,)).
inception = InceptionResNetV2(weights='imagenet', include_top=True)
# Keep a handle on the graph the model was built in so that later predictions
# (made from the training data generator) run inside the same graph
# (tf.get_default_graph is the TF 1.x API).
inception.graph = tf.get_default_graph()

In [None]:
# Model inputs: the 1000-value InceptionResNetV2 embedding and the
# 256x256 single-channel lightness image (created in this order).
embed_input = Input(shape=(1000,))
encoder_input = Input(shape=(256, 256, 1))

# Encoder: one 3x3 ReLU 'same' convolution per (filters, stride) pair.
# The three stride-2 layers reduce 256x256 to a 32x32 feature map.
encoder_layer_specs = [
    (64, 2), (128, 1), (128, 2), (256, 1),
    (256, 2), (512, 1), (512, 1), (256, 1),
]
encoder_output = encoder_input
for filter_count, stride in encoder_layer_specs:
    encoder_output = Conv2D(filter_count, (3, 3), activation='relu',
                            padding='same', strides=stride)(encoder_output)

# Fusion: tile the embedding over every cell of the 32x32 grid, append it to
# the encoder features along the channel axis, then mix with a 1x1 conv.
tiled_embedding = RepeatVector(32 * 32)(embed_input)
tiled_embedding = Reshape((32, 32, 1000))(tiled_embedding)
fused_features = concatenate([encoder_output, tiled_embedding], axis=3)
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fused_features)

# Decoder: applied in list order; three 2x upsamplings restore 32x32 -> 256x256
# and the final tanh conv emits the 2 colour channels.
decoder_layers = [
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    Conv2D(16, (3, 3), activation='relu', padding='same'),
    Conv2D(2, (3, 3), activation='tanh', padding='same'),
    UpSampling2D((2, 2)),
]
decoder_output = fusion_output
for decoder_layer in decoder_layers:
    decoder_output = decoder_layer(decoder_output)

model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)
#model = keras.models.load_model(os.path.join(os.getcwd(), 'models/models_h5/with_fusion/best_model_' + subcategory + '.h5'))

In [None]:
def create_inception_embedding(grayscaled_rgb, value_scale=255.0):
    """Compute InceptionResNetV2 output embeddings for a batch of images.

    Each image is resized to 299x299x3, mapped onto the [0, 255] range that
    ``preprocess_input`` expects, normalised by ``preprocess_input`` and then
    run through the pretrained ``inception`` model.

    Args:
        grayscaled_rgb: Iterable of 3-channel images. In this notebook they come
            from ``gray2rgb(rgb2gray(batch))`` with ``batch`` scaled to [0, 1].
        value_scale: Factor mapping the input values onto [0, 255]. The default
            255.0 matches [0, 1] inputs; pass 1.0 for images already in [0, 255].

    Returns:
        The array returned by ``inception.predict``, one row per input image.
    """
    resized = np.array([
        resize(img, (299, 299, 3), mode='constant') for img in grayscaled_rgb
    ])
    # InceptionResNetV2's preprocess_input computes x / 127.5 - 1, i.e. it
    # expects pixels in [0, 255]. Passing [0, 1] values unscaled would squash
    # every image to roughly -1 and make the embedding nearly independent of
    # the image content.
    resized = preprocess_input(resized * value_scale)
    # Predict inside the graph the model was built in (set up when loading it).
    with inception.graph.as_default():
        embed = inception.predict(resized)
    return embed

# Random augmentation applied on the fly to training images.
datagen = ImageDataGenerator(
        shear_range=0.4,
        zoom_range=0.4,
        rotation_range=40,
        horizontal_flip=True
)

# Number of images per generated batch.
batch_size = 50

# Stop when val_loss has not improved for 125 epochs. With
# restore_best_weights=True the best-val_loss weights are put back into
# `model` when early stopping triggers, so the model saved/exported after
# training is the best one rather than the weights from 125 epochs later.
early_stopping = EarlyStopping(monitor='val_loss', patience=125, restore_best_weights=True)

# Keep the best-val_loss model on disk while training. Create the directory
# up front: writing the .h5 file into a missing directory fails.
checkpoint_dir = os.path.join('models', 'models_h5', 'with_fusion')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = ModelCheckpoint(os.path.join(checkpoint_dir, 'best_model_' + subcategory + '.h5'),
                             monitor='val_loss',
                             save_best_only=True)


def image_a_b_gen(X, batch_size, augment=True):
    """Endlessly yield ``([L, embedding], ab)`` batches built from RGB images.

    Args:
        X: Array of RGB images as accepted by ``ImageDataGenerator.flow``;
            assumed scaled to [0, 1] by ``load_images_from_directory``.
        batch_size: Number of images per batch.
        augment: If True (default), draw shuffled batches through the random
            ``datagen`` augmentation. If False, use an un-augmented,
            un-shuffled stream. Use this for validation so that val_loss
            (monitored by EarlyStopping/ModelCheckpoint) is measured on the
            same images every epoch.

    Yields:
        ``([X_batch, embed], Y_batch)``:
            - ``X_batch`` is the Lab L channel with a trailing axis of size 1.
            - ``embed`` is the Inception embedding of the grayscale image.
            - ``Y_batch`` is the Lab a/b channels divided by 128.
    """
    source = datagen if augment else ImageDataGenerator()
    for batch in source.flow(X, batch_size=batch_size, shuffle=augment):
        grayscaled_rgb = gray2rgb(rgb2gray(batch))
        embed = create_inception_embedding(grayscaled_rgb)
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:, :, :, 0]
        X_batch = X_batch.reshape(X_batch.shape + (1,))
        Y_batch = lab_batch[:, :, :, 1:] / 128
        yield [X_batch, embed], Y_batch


def _steps_for(num_samples, batch_size):
    """Batches needed to visit each of num_samples once (ceil division, >= 1)."""
    return max(1, -(-num_samples // batch_size))


# Train model with validation data. Step counts round up: `len(X) // batch_size`
# gives 0 steps for subsets smaller than batch_size, which fit_generator cannot
# use, and silently skips the remainder images every epoch.
model.compile(optimizer='rmsprop', loss='mse')

hist = model.fit_generator(
        image_a_b_gen(Xtrain, batch_size),
        epochs=10000,
        steps_per_epoch=_steps_for(len(Xtrain), batch_size),
        validation_data=image_a_b_gen(Xvalid, batch_size, augment=False),
        validation_steps=_steps_for(len(Xvalid), batch_size),
        callbacks=[early_stopping, checkpoint])  # ModelCheckpoint added to the callbacks list



In [None]:
# Output directories for the trained model; create them so saving does not
# fail when they are missing.
h5_dir = os.path.join(os.getcwd(), 'models', 'models_h5', 'with_fusion')
onnx_dir = os.path.join(os.getcwd(), 'models', 'models_onnx', 'with_fusion')
os.makedirs(h5_dir, exist_ok=True)
os.makedirs(onnx_dir, exist_ok=True)

# Save the trained Keras model in HDF5 format.
model.save(os.path.join(h5_dir, subcategory + '.h5'))

# Convert the Keras model to ONNX format.
onnx_model = keras2onnx.convert_keras(model)

# Write the ONNX model to a file.
onnx.save_model(onnx_model, os.path.join(onnx_dir, subcategory + '.onnx'))


In [None]:
# Per-epoch training history (loss and val_loss) as a table, one row per epoch.
hist_df = pd.DataFrame(hist.history)

hist_csv_file = 'history_' + subcategory + '.csv'
# Let pandas open the file itself. A text handle opened without newline=''
# can produce doubled line endings (blank rows) on Windows.
hist_df.to_csv(hist_csv_file)