In [1]:
import os
import numpy as np
import tensorflow as tf
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType, ArrayType, FloatType
from tensorboard.notebook import display
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import shutil
import time
import pandas as pd
from PIL import Image
import uuid
from tensorflow.keras.applications.resnet50 import ResNet50
from pyspark.sql.functions import  pandas_udf, PandasUDFType

In [2]:
from pyspark.sql import SparkSession
# JAVA_HOME must correspond to the JDK used by the runtime (originally noted for Colab).
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
# SPARK_HOME must correspond to the directory of the downloaded Spark distribution.
os.environ["SPARK_HOME"] = "/opt/spark/"

# Both environment variables are set before the session is requested.
spark = (
    SparkSession.builder
    .master("spark://192.168.1.38:7077")
    .appName("testTrain")
    .enableHiveSupport()
    .getOrCreate()
)
sc = spark.sparkContext
sc.setLogLevel("Error")

In [3]:
def get_model(input_shape=(320, 320, 1)):
    """Build and compile the convolutional binary classifier.

    The network is three stages of (Conv2D -> BatchNormalization) x 2 followed
    by max-pooling, with 32, 64 and 128 filters respectively, then a
    128-unit dense layer with dropout and a single sigmoid output unit.

    Args:
        input_shape: Shape of one input sample, passed to the first
            convolution only. Defaults to ``(320, 320, 1)``, the shape that
            ``parse_image`` produces in this notebook.

    Returns:
        A compiled ``Sequential`` model (binary cross-entropy loss, Adam
        optimizer, accuracy metric). No weights are loaded or trained here.
    """
    # Keras is imported inside the function; presumably this keeps the
    # function self-contained when it runs on a Spark executor
    # (predict_batch_udf calls it there) -- TODO confirm.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Dropout, Flatten, BatchNormalization

    model = Sequential()

    # Only the very first layer of a Sequential model consumes `input_shape`;
    # every later layer infers its input from the previous layer's output.
    needs_input_shape = True
    for filters in (32, 64, 128):
        for _ in range(2):
            conv_kwargs = dict(filters=filters, kernel_size=(3, 3), activation='relu')
            if needs_input_shape:
                conv_kwargs['input_shape'] = input_shape
                needs_input_shape = False
            model.add(Conv2D(**conv_kwargs))
            model.add(BatchNormalization())
        model.add(MaxPool2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.2))

    # Single sigmoid unit: the output is one probability-like score per sample.
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model

In [4]:
# Build the network on the driver. No training or weight-loading call is
# visible in this notebook, so the model carries freshly initialised weights.
model = get_model()

In [5]:
# Broadcast the driver model's weight arrays through the SparkContext;
# predict_batch_udf rebuilds the architecture with get_model() and loads
# this value with set_weights().
bc_model_weights = sc.broadcast(model.get_weights())

In [6]:
def fullpath(path, files):
    """Return a new list with ``path`` prepended to every entry of ``files``.

    The prefix is joined by plain string concatenation, so ``path`` must
    already end with a separator if one is wanted.
    """
    return [path + entry for entry in files]
val_dir = "../chest_xray/chest_xray/val"
pneumonia_dir = os.path.join(val_dir, 'PNEUMONIA')
# os.listdir() returns entries in arbitrary order, so deleting index 0 would
# drop an unpredictable entry. Instead, explicitly skip hidden entries and
# anything that is not a regular file, and sort for a reproducible order.
# NOTE(review): the original removed files[0]; presumably that was a hidden
# metadata file such as .DS_Store -- confirm against the dataset directory.
image_names = sorted(
    name for name in os.listdir(pneumonia_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(pneumonia_dir, name))
)
files = fullpath(val_dir+'/PNEUMONIA/', image_names)
len(files)

8

In [8]:
file_name = "image_data2.parquet"
dbfs_file_path = "../chest_xray/chest_xray/dbfs/"
image_data = []
for file in files:
  # The context manager closes the underlying file handle once the pixels
  # have been copied into the numpy array.
  with Image.open(file) as raw_img:
    # Force a single 8-bit channel so the flat vector always has 320*320*1
    # elements; an RGB input would otherwise make the reshape below fail.
    gray_img = raw_img.convert("L").resize((320, 320))
    data = np.asarray(gray_img, dtype="float32").reshape([320*320*1])

  image_data.append({"data": data})

# One row per image, with the flattened pixels stored in the "data" column.
pandas_df = pd.DataFrame(image_data, columns = ['data'])
pandas_df.to_parquet(file_name)
# Create the target directory if needed; exist_ok keeps the cell re-runnable.
os.makedirs(dbfs_file_path, exist_ok=True)
shutil.copyfile(file_name, dbfs_file_path+file_name)

'../chest_xray/chest_xray/dbfs/image_data2.parquet'

In [11]:
# Read the parquet file from the location it was copied to with
# shutil.copyfile (dbfs_file_path + file_name); the bare "dbfs/" prefix used
# before does not match any path this notebook writes.
df = spark.read.parquet(dbfs_file_path+file_name)
print(df.count())

8


In [12]:
# Cap the number of rows Arrow hands to the pandas UDF in one batch.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
# head(1) returns a list of rows (empty for an empty DataFrame), whereas
# head() returns a single Row or None, so len() on it would count columns
# or raise TypeError instead of triggering this assertion.
assert len(df.head(1)) > 0, "`df` should not be empty"

In [13]:
def parse_image(image_data):
  """Reshape one image's pixel data to a [320, 320, 1] tensor.

  No intensity rescaling or dtype conversion is applied; the values are
  passed through unchanged.
  """
  target_shape = [320, 320, 1]
  return tf.reshape(image_data, target_shape)

In [14]:
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)
def predict_batch_udf(image_batch_iter):
  """Scalar-iterator pandas UDF that scores batches of flattened images.

  The model is rebuilt with get_model() and loaded with the broadcast
  weights once per invocation, then reused for every batch the iterator
  yields. For each incoming batch one pandas Series of per-row prediction
  arrays is yielded.
  """
  batch_size = 1
  net = get_model()
  net.set_weights(bc_model_weights.value)
  for image_batch in image_batch_iter:
    # Stack the per-row pixel arrays into one 2-D array (one row per image).
    stacked = np.vstack(image_batch)
    pipeline = (
      tf.data.Dataset.from_tensor_slices(stacked)
      .map(parse_image, num_parallel_calls=8)
      .prefetch(5000)
      .batch(batch_size)
    )
    scores = net.predict(pipeline)
    # list() splits the prediction array along its first axis: one entry per image.
    yield pd.Series(list(scores))



In [15]:
output_file_path = "../chest_xray/chest_xray/dbfs/predictions"
# Reference the column through the DataFrame itself: `col` is not among the
# names imported from pyspark.sql.functions in this notebook, so col("data")
# raises NameError on a fresh kernel.
predictions_df = df.select(predict_batch_udf(df["data"]).alias("prediction"))
# "overwrite" lets the cell be re-run without failing on existing output.
predictions_df.write.mode("overwrite").parquet(output_file_path)

In [16]:
# The predictions were written with .parquet(), so read them back with the
# matching explicit reader instead of relying on Spark's default data source.
result_df = spark.read.parquet(output_file_path)
result_df.show()

+-------------+
|   prediction|
+-------------+
| [0.11440846]|
| [0.41464415]|
| [0.07537913]|
|  [0.0463596]|
| [0.03866735]|
| [0.15374437]|
|[0.077885926]|
| [0.42321396]|
+-------------+

