# ResNet50 feature extraction
---

In [None]:
import os
import time
import numpy as np
import pandas as pd
from PIL import Image
from io import BytesIO
import base64

In [None]:
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# Allocate GPU memory on demand instead of reserving it all up front.
# Apply the setting to every visible GPU: indexing [0] crashed on CPU-only
# hosts, and TensorFlow requires the same memory-growth setting on all GPUs.
for gpu_device in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)
# Cap TensorFlow's op-level thread pools on the driver.
tf.config.threading.set_intra_op_parallelism_threads(4)
tf.config.threading.set_inter_op_parallelism_threads(4)

from tensorflow.keras.models import model_from_json
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.imagenet_utils import preprocess_input

In [None]:
import findspark
# Make the Spark installation under /usr/lib/spark2 importable (adds its
# Python libraries to sys.path); this must run before the pyspark import below.
findspark.init('/usr/lib/spark2')
from pyspark.sql import SparkSession

In [None]:
# Ship the packed Python environment to the YARN containers (unpacked as ./venv)
# and run the Python workers with its interpreter. Both variables are read when
# the Spark JVM is launched, so they are set before getOrCreate() below.
os.environ['PYSPARK_SUBMIT_ARGS'] = ' --archives tf-env-2.4.zip#venv pyspark-shell'
os.environ['PYSPARK_PYTHON'] = 'venv/bin/python'

# JVM system properties that route the driver's HTTP(S) traffic through the web
# proxy (presumably so spark.jars.packages can be resolved from Maven).
_proxy_properties = {
    'http.proxyHost': 'webproxy.eqiad.wmnet',
    'http.proxyPort': '8080',
    'https.proxyHost': 'webproxy.eqiad.wmnet',
    'https.proxyPort': '8080',
}
_driver_java_options = ' '.join(
    f'-D{prop}={setting}' for prop, setting in _proxy_properties.items()
)

# Session settings, applied to the builder in this order.
_session_settings = {
    'spark.driver.extraJavaOptions': _driver_java_options,
    'spark.jars.packages': 'org.apache.spark:spark-avro_2.11:2.4.4',
    'spark.driver.memory': '4g',
    'spark.dynamicAllocation.maxExecutors': 128,
    'spark.executor.memory': '8g',
    'spark.executor.cores': 4,
    'spark.sql.shuffle.partitions': 512,
    'spark.sql.execution.arrow.maxRecordsPerBatch': 1024,
}

_builder = SparkSession.builder.appName('ResNet50').master('yarn')
for _setting_key, _setting_value in _session_settings.items():
    _builder = _builder.config(_setting_key, _setting_value)
spark = _builder.getOrCreate()

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType
from pyspark import SparkContext
# The SparkContext backing the session above (used for sc.broadcast later).
sc = spark.sparkContext

#### Load the image data into a Spark DataFrame

In [None]:
# Read the Avro export of base64-encoded images; the 'avro' source comes from
# the spark-avro package listed in spark.jars.packages.
pixels = spark.read.load('image.avro', format='avro')

In [None]:
# Show the loaded schema (printed below): the base64-encoded image payload is
# the nested field image.image_bytes_b64.
pixels.printSchema()

root
 |-- i: integer (nullable = true)
 |-- image_url: string (nullable = true)
 |-- project: string (nullable = true)
 |-- image_file_name: string (nullable = true)
 |-- thumbnail_size: string (nullable = true)
 |-- image: struct (nullable = true)
 |    |-- image_bytes_b64: string (nullable = true)
 |    |-- format: string (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- image_bytes_sha1: string (nullable = true)
 |    |-- error: string (nullable = true)



#### Prepare the pretrained model

In [None]:
# ImageNet-pretrained ResNet50 without its classification head; pooling='max'
# applies global max pooling, so each image maps to a single feature vector.
model = ResNet50(include_top=False, weights='imagenet', pooling='max')

# Executors rebuild the model from two pieces: the architecture serialized as
# JSON and the weight arrays shared through a broadcast variable.
model_json = model.to_json()
bc_model_weights = sc.broadcast(model.get_weights())

#### Run model inference via a pandas UDF

In [None]:
image_size = 224

# Executor-side cache for the rebuilt Keras model. The dict travels inside the
# pickled UDF, so it starts empty whenever Spark deserializes the UDF on a worker
# (presumably once per task) and is then reused by every Arrow batch that follows,
# instead of rebuilding the model from JSON + weights for each batch.
_model_cache = {}


def _get_model():
    """Return the model rebuilt from ``model_json`` and the broadcast weights.

    The model is built on first use and cached in ``_model_cache``.
    """
    cached = _model_cache.get('model')
    if cached is None:
        cached = model_from_json(model_json)
        cached.set_weights(bc_model_weights.value)
        _model_cache['model'] = cached
    return cached


def _decode_image(image_b64):
    """Decode one base64 image into a float32 array of shape (image_size, image_size, 3).

    Best effort: any row that cannot be decoded (null, invalid base64, unreadable
    image data) becomes an all-zero image so the batch keeps its length. Note that
    such rows still receive the features of that blank input, not null.
    """
    try:
        with Image.open(BytesIO(base64.b64decode(image_b64))) as img:
            rgb = img.convert('RGB').resize((image_size, image_size))
    except Exception:  # narrower than bare except: let KeyboardInterrupt/SystemExit through
        return np.zeros((image_size, image_size, 3), dtype=np.float32)
    return np.asarray(rgb, dtype=np.float32)


@F.pandas_udf(returnType=ArrayType(FloatType()))
def extract_features(image_bytes):
    """Scalar pandas UDF mapping base64-encoded images to ResNet50 feature vectors.

    Args:
        image_bytes: pandas Series of base64 strings (nulls allowed).

    Returns:
        pandas Series with one list of floats per input row, in input order.
    """
    # np.stack cannot stack zero arrays; an empty batch maps to an empty result.
    if image_bytes.empty:
        return pd.Series([], dtype=object)
    batch = np.stack([_decode_image(encoded) for encoded in image_bytes])
    # Keras ImageNet preprocessing (default 'caffe' mode), applied to a fresh local
    # array, so no defensive copy is needed.
    features = _get_model().predict(preprocess_input(batch))
    return pd.Series([vector.tolist() for vector in features])

In [None]:
# One feature vector per image, keeping the identifying columns alongside it.
features_df = pixels.withColumn(
    "features", extract_features(F.col("image.image_bytes_b64"))
).select("i", "image_file_name", "features", "image_url")

# Replace any previous run's output.
features_df.write.parquet('output.parquet', mode="overwrite")

#### Load and inspect the extracted features

In [None]:
# Read back the Parquet output written above. Name the format explicitly, as the
# writer does, instead of relying on spark.sql.sources.default, which is Parquet
# only while that setting is left unchanged.
result_df = spark.read.parquet('output.parquet')

In [None]:
# Preview the written features; show() prints 20 rows by default and truncates long cells.
result_df.show()

+-------+--------------------+--------------------+--------------------+
|      i|     image_file_name|            features|           image_url|
+-------+--------------------+--------------------+--------------------+
| 662714|Seribu_Rumah_Gada...|[3.1161337, 2.381...|https://upload.wi...|
|2816800|DM_Rad_2017_Männe...|[0.97691995, 4.64...|https://upload.wi...|
|1752258|Jackie_Chan_TIFF_...|[7.088956, 4.9775...|https://upload.wi...|
|2373722|      Amblypigid.jpg|[0.44165516, 4.50...|http://upload.wik...|
|2156691|Keski-Uudenmaan_p...|[10.556867, 3.813...|https://upload.wi...|
|3267094|GuentherZ_2012-02...|[1.130087, 2.2200...|https://upload.wi...|
|1334754|RO_B_Batiste_chur...|[1.3544466, 11.27...|https://upload.wi...|
| 305959|  Biokilereaktor.png|[1.8614346, 6.087...|https://upload.wi...|
|3238121|Saint_Davids_Naas...|[4.4893565, 8.195...|https://upload.wi...|
|1928771|Waltraud_Starck_2...|[1.0994802, 5.066...|https://upload.wi...|
| 276264|    LA2_vemardet.jpg|[4.3736835, 0.595...|