In [2]:
#import findspark
#findspark.init()
import os
from pyspark.sql import SparkSession
from pyspark import SparkContext

from dotenv import load_dotenv
load_dotenv()
key_id = os.getenv('key_id')
secret = os.getenv('secret')

#os.environ['PYSPARK_SUBMIT_ARGS'] = "--packages=org.apache.hadoop:hadoop-aws:2.7.3 pyspark-shell"
os.environ['PYSPARK_SUBMIT_ARGS'] = "--packages=com.amazonaws:aws-java-sdk-bundle:1.11.271,org.apache.hadoop:hadoop-aws:2.7.3 pyspark-shell"
#os.environ['PYSPARK_SUBMIT_ARGS'] = "--packages com.amazonaws:aws-java-sdk:1.7.4,org.apache.hadoop:hadoop-aws:2.7.3 pyspark-shell"



spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext.getOrCreate()
hadoop_conf=sc._jsc.hadoopConfiguration()
hadoop_conf.set("fs.s3a.impl", "org.apache.hadoop.fs.s3native.NativeS3FileSystem")
hadoop_conf.set("fs.s3a.awsAccessKeyId", key_id)
hadoop_conf.set("fs.s3a.awsSecretAccessKey", secret)

In [3]:
from typing import Iterator
import pandas as pd
from PIL import Image
import numpy as np
import io

import tensorflow as tf
from tensorflow.keras.applications.mobilenet import MobileNet, preprocess_input
#from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array

from pyspark.sql.functions import col, pandas_udf, PandasUDFType

In [4]:
sample_img_dir = 's3a://oc-av-p8/images'
text = 's3a://oc-av-p8test/text.txt'
text_df = spark.read.text(text)
print('************done********')
print(text_df)
text_df.show

#sample_img_dir = 'data/images/'
image_df = spark.read.format("binaryFile") \
            .option("pathGlobFilter", "*.jpg") \
            .option("recursiveFileLookup", "true") \
            .load(sample_img_dir)

#pkeys = sc.parallelize(keys)
    # Call the map step to handle reading in the file contents
#activation = pkeys.flatMap(map_func)

model = MobileNet(include_top=False)
bc_model_weights = sc.broadcast(model.get_weights())


"""
fonctions adaptées de la documentation databricks:
https://docs.databricks.com/applications/machine-learning/preprocess-data/transfer-learning-tensorflow.html
"""

def model_fn():
  """
  Returns a MobileNet model with top layer removed and broadcasted pretrained weights.
  """
  model = MobileNet(weights=None,
   #input_shape=(100,100,3), 
   pooling='avg',
   include_top=False)
  model.set_weights(bc_model_weights.value)
  return model

def preprocess(content):
  """
  Preprocesses raw image bytes for prediction.
  """
  img = Image.open(io.BytesIO(content)).resize([224, 224])
  arr = img_to_array(img)
  arr = tf.convert_to_tensor(arr)
  return preprocess_input(arr)

def featurize_series(model, content_series):
  """
  Featurize a pd.Series of raw images using the input model.
  :return: a pd.Series of image features
  """
  print('*************content_series*********', content_series)
  input = tf.stack(content_series.map(preprocess))
  print('input shape', input.shape)
  preds = model.predict(input)
  # For some layers, output features will be multi-dimensional tensors.
  # We flatten the feature tensors to vectors for easier storage in Spark DataFrames.
  output = [p.flatten() for p in preds]
  print('output shape:', np.shape(output))
  output = pd.Series(output)
  return output

@pandas_udf('array<float>')
def featurize_udf(content_series_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
  '''
  This method is a Scalar Iterator pandas UDF wrapping our featurization function.
  The decorator specifies that this returns a Spark DataFrame column of type ArrayType(FloatType).
  
  :param content_series_iter: This argument is an iterator over batches of data, where each batch
                              is a pandas Series of image data.
  '''
  # With Scalar Iterator pandas UDFs, we can load the model once and then re-use it
  # for multiple data batches.  This amortizes the overhead of loading big models.
  model = model_fn()
  for content_series in content_series_iter:
    yield featurize_series(model, content_series)
    
def get_filename(lines):
    separates = lines.flatMap(lambda line: line.split('/')[-2:])
    filename = separates.map(lambda sep: '/'.join(sep))
    return filename

features_df = image_df.select('path',
                              get_filename('path').alias('filename'), 
                              featurize_udf('content').alias('features'))
#features_df.write.mode("overwrite").csv("sample_file.csv")
#features_df.write.mode("overwrite").parquet('s3a://oc-av-p8/parquets')

print(image_df)
#image_df.select("image.origin", "image.width", "image.height").show(truncate=False)

print(features_df)
features_df.printSchema()
features_df.select('filename').show()

************done********
DataFrame[value: string]


AttributeError: 'str' object has no attribute 'flatMap'