# Image Quality Inference
---

In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from io import BytesIO

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import model_from_json

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. Iterate over the device list rather than indexing element 0: the list
# is empty on a CPU-only host, and memory growth should be configured the same
# way on every visible GPU.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

# Limit TensorFlow's CPU thread pools to 4 threads each.
tf.config.threading.set_intra_op_parallelism_threads(4)
tf.config.threading.set_inter_op_parallelism_threads(4)

In [3]:
import findspark
# Point findspark at the Spark installation under /usr/lib/spark2. This runs
# before the pyspark import below; presumably pyspark is not importable from
# the kernel's default path otherwise — confirm before reordering these lines.
findspark.init('/usr/lib/spark2')
from pyspark.sql import SparkSession

In [4]:
# Submit-time settings: ship the packaged Python environment archive under the
# alias `venv` and have PySpark use the interpreter inside it.
os.environ.update({
    'PYSPARK_SUBMIT_ARGS': '--archives tf-env-2.4.zip#venv pyspark-shell',
    'PYSPARK_PYTHON': 'venv/bin/python',
})


def _create_spark_session():
    """Build (or fetch) the YARN-backed Spark session used by this notebook.

    Returns:
        The SparkSession returned by ``getOrCreate()`` after applying the
        settings listed below.
    """
    # JVM system properties that route the driver's HTTP(S) traffic through
    # the web proxy; rendered as "-Dkey=value" pairs separated by spaces.
    proxy_properties = (
        ('http.proxyHost', 'webproxy.eqiad.wmnet'),
        ('http.proxyPort', '8080'),
        ('https.proxyHost', 'webproxy.eqiad.wmnet'),
        ('https.proxyPort', '8080'),
    )
    driver_java_options = ' '.join(
        '-D{}={}'.format(name, value) for name, value in proxy_properties
    )

    # Applied in this order, one builder.config() call per entry.
    settings = (
        ('spark.driver.extraJavaOptions', driver_java_options),
        # Connector that provides the "tfrecord" data source.
        ('spark.jars.packages', 'com.linkedin.sparktfrecord:spark-tfrecord_2.11:0.2.4'),
        ("spark.driver.memory", "4g"),
        ('spark.dynamicAllocation.maxExecutors', 128),
        ("spark.executor.memory", "8g"),
        ("spark.executor.cores", 4),
        ("spark.sql.shuffle.partitions", 512),
        # Upper bound on the number of rows handed to a pandas UDF per Arrow batch.
        ("spark.sql.execution.arrow.maxRecordsPerBatch", 1024),
    )

    builder = SparkSession.builder.appName('Image pipeline').master('yarn')
    for key, value in settings:
        builder = builder.config(key, value)
    return builder.getOrCreate()


spark = _create_spark_session()

In [43]:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType, FloatType, IntegerType
from pyspark import SparkContext
# Handle on the SparkContext (existing one if already running, otherwise a new
# one); used further down to broadcast the model weights.
sc = SparkContext.getOrCreate()

#### Prepare trained model

In [6]:
# Location of the saved Keras model. Defaults to the original author's path;
# set the IMAGE_QUALITY_MODEL_DIR environment variable to run the notebook
# from another account or host without editing this cell.
MODEL_DIR = os.environ.get(
    'IMAGE_QUALITY_MODEL_DIR',
    '/home/aikochou/ImagePipeline/image_quality_model',
)

model = keras.models.load_model(MODEL_DIR)
# The architecture travels to the executors as a JSON string and the weights
# as a Spark broadcast variable; predict_batch_udf rebuilds the model from both.
model_json = model.to_json()
bc_model_weights = sc.broadcast(model.get_weights())

#### Load the data into Spark DataFrames

In [58]:
# Read the TFRecord file (record type "Example") with an explicit schema:
# file name, encoded image bytes, and an integer label.
df = (
    spark.read
    .format("tfrecord")
    .option("recordType", "Example")
    .schema('image_file_name string, image_bytes binary, label int')
    .load('image.tfrecords')
)

In [59]:
# Preview the loaded rows to confirm the three schema columns were populated.
df.show()

+--------------------+--------------------+-----+
|     image_file_name|         image_bytes|label|
+--------------------+--------------------+-----+
|Star_Magnolia_Mag...|[FF D8 FF E0 00 1...|    1|
|Heppenheim_BW_201...|[FF D8 FF E0 00 1...|    1|
|Lloyd's_Building_...|[FF D8 FF E0 00 1...|    1|
|B-Spandau_Okt12_R...|[FF D8 FF E0 00 1...|    1|
|Rösrath_Germany_W...|[FF D8 FF E0 00 1...|    1|
|Marmoutier_eglise...|[FF D8 FF E0 00 1...|    1|
|15-11-25-Maribor-...|[FF D8 FF E0 00 1...|    1|
|Mägiste_peatuskoh...|[FF D8 FF E0 00 1...|    1|
|USMC-111209-M-XR0...|[FF D8 FF E0 00 1...|    0|
|Eglise_Saint-Séve...|[FF D8 FF E0 00 1...|    0|
|Церковь_-_panoram...|[FF D8 FF E0 00 1...|    0|
|Exterieur_zuidgev...|[FF D8 FF E0 00 1...|    0|
|Поле_-_panoramio_...|[FF D8 FF E0 00 1...|    0|
|Père-Lachaise_-_D...|[FF D8 FF E0 00 1...|    1|
|Rosa_'Aachener_Do...|[FF D8 FF E0 00 1...|    0|
|Window_of_the_Sai...|[FF D8 FF E0 00 1...|    1|
|St_Petka_church_-...|[FF D8 FF E0 00 1...|    1|


#### Run model inference via pandas UDF

In [60]:
# Side length, in pixels, that every image is resized to in process_image and
# reshaped to (with 3 channels) in parse_image.
image_size = 180
# Number of images per batch in the tf.data pipeline fed to model.predict.
batch_size = 64

@F.pandas_udf(ArrayType(FloatType()))
def process_image(image_bytes):
    """Decode encoded images into fixed-length flat pixel arrays.

    Args:
        image_bytes: pandas Series holding the encoded bytes of one image
            per row.

    Returns:
        pandas Series with one list of ``image_size * image_size * 3`` floats
        per row: the RGB pixel values of the image after resizing it to
        ``image_size`` x ``image_size``, flattened in row-major order.

    Raises:
        Whatever ``PIL.Image.open`` raises when a row cannot be decoded.
    """
    ret = []
    for image in image_bytes:
        # The context manager closes the decoder once the pixels are extracted.
        with Image.open(BytesIO(image)) as raw:
            # Force three channels. Grayscale, palette, CMYK or RGBA files
            # would otherwise flatten to a different length and break the
            # [image_size, image_size, 3] reshape done in parse_image.
            rgb = raw.convert('RGB').resize((image_size, image_size))
        # tolist() yields Python floats without a per-pixel Python-level loop.
        ret.append(np.asarray(rgb, dtype='float32').ravel().tolist())
    return pd.Series(ret)

def parse_image(image_data):
    """Turn one flat pixel vector into a normalized image tensor.

    Args:
        image_data: tensor with ``image_size * image_size * 3`` pixel values.

    Returns:
        float32 tensor of shape ``[image_size, image_size, 3]`` whose values
        are ``pixel * 2 / 255 - 1`` (so 0 maps to -1 and 255 maps to 1).
    """
    scale = 2. / 255
    as_float = tf.image.convert_image_dtype(image_data, dtype=tf.float32)
    normalized = as_float * scale - 1  # normalization
    return tf.reshape(normalized, [image_size, image_size, 3])

@F.pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(image_batch):
    """Score a batch of flattened images with the broadcast Keras model.

    Args:
        image_batch: pandas Series whose rows are flat pixel arrays of length
            ``image_size * image_size * 3``.

    Returns:
        pandas Series with one array of model outputs per input row.
    """
    if len(image_batch) == 0:
        # np.vstack raises on an empty sequence, and there is nothing to
        # score, so skip building the model altogether.
        return pd.Series([], dtype=object)

    # NOTE(review): the model is rebuilt for every batch; caching it per
    # worker would save time but is left unchanged here.
    model = model_from_json(model_json) # load the model graph
    model.set_weights(bc_model_weights.value) # set the weights from the broadcasted variables
    images = np.vstack(image_batch)
    dataset = tf.data.Dataset.from_tensor_slices(images)
    # Batch first, then prefetch: the buffer then holds one ready batch
    # instead of up to 5000 individual images on top of `images` itself.
    dataset = (
        dataset
        .map(parse_image, num_parallel_calls=8)
        .batch(batch_size)
        .prefetch(1)
    )
    preds = model.predict(dataset)
    return pd.Series(list(preds))

def _score_to_label(y):
    """Return "Quality" when the first model output exceeds 0.5, else "Random"."""
    if y[0] > 0.5:
        return "Quality"
    return "Random"


# Row-at-a-time Spark UDF that turns a prediction array into a label string.
predict_label_udf = F.udf(_score_to_label, StringType())

In [61]:
# Decode the images, score them, map the scores to labels, and write only the
# file name, the original label and the predicted label, replacing any
# previous output.
(
    df
    .withColumn('image_arr', process_image('image_bytes'))
    .withColumn('prediction', predict_batch_udf('image_arr'))
    .withColumn('pred_label', predict_label_udf('prediction'))
    .select('image_file_name', 'label', 'pred_label')
    .write
    .parquet('output.parquet', mode='overwrite')
)

#### Load and check the prediction results

In [62]:
# Read the results back explicitly as Parquet. A bare load() uses whatever
# spark.sql.sources.default is set to, which only happens to be Parquet when
# the session or cluster has not overridden it.
result_df = spark.read.parquet('output.parquet')

In [63]:
# Display the stored rows so the `label` and `pred_label` columns can be
# compared side by side.
result_df.show()

+--------------------+-----+----------+
|     image_file_name|label|pred_label|
+--------------------+-----+----------+
|Star_Magnolia_Mag...|    1|   Quality|
|Heppenheim_BW_201...|    1|    Random|
|Lloyd's_Building_...|    1|    Random|
|B-Spandau_Okt12_R...|    1|   Quality|
|Rösrath_Germany_W...|    1|   Quality|
|Marmoutier_eglise...|    1|   Quality|
|15-11-25-Maribor-...|    1|   Quality|
|Mägiste_peatuskoh...|    1|    Random|
|USMC-111209-M-XR0...|    0|    Random|
|Eglise_Saint-Séve...|    0|   Quality|
|Церковь_-_panoram...|    0|    Random|
|Exterieur_zuidgev...|    0|    Random|
|Поле_-_panoramio_...|    0|    Random|
|Père-Lachaise_-_D...|    1|   Quality|
|Rosa_'Aachener_Do...|    0|   Quality|
|Window_of_the_Sai...|    1|   Quality|
|St_Petka_church_-...|    1|   Quality|
|Ribadavia_-_Galiz...|    1|    Random|
|1991_Volkswagen_T...|    0|    Random|
|Beach_(4706533028...|    0|    Random|
+--------------------+-----+----------+
only showing top 20 rows

