# Objective
Set-up a first big data architecture using AWS products (Mobile application with a fruit pictures classifier engine)

# Data
Link to upload data: https://www.kaggle.com/moltean/fruits

# Table of contents <a class="anchor" id="chapter0"></a> 
* [Imports and declarations](#chapter1)
    * [Import packages](#sub1_1)
    * [Declare constants](#sub1_2)
* [Exploration of the full dataset](#chapter2)
    * [Get picture information from local full dataset](#sub2_1)
    * [Explore picture information](#sub2_2)
    * [Get class information](#sub2_3) 
    * [Target label encoding](#sub2_4)
* [Preparation of the local sampled picture set](#chapter3)
    * [Create a local sampled picture set](#sub3_1)
    * [Get picture information from local sampled picture set](#sub3_2)
* [Create and configure a Spark Session](#chapter4)
* [Load data](#chapter5)
    * [Load pictures](#sub5_1)
    * [Distinguish Target](#sub5_2)
* [Dimension reduction](#chapter6)
    * [Instantiate a ResNet50 model with pre-trained weights](#sub6_1)
    * [Functions](#sub6_2)
    * [Features extraction](#sub6_3)
* [Train a new model using pre-computed features](#chapter7)
    * [Prepare my new model](#sub7_1)
    * [Train my new model](#sub7_2)
* [Go to End](#chapter100)

# Imports and declarations <a class="anchor" id="chapter1"></a>

## Import packages <a class="anchor" id="sub1_1"></a>

In [1]:
import P8_02_module as MyMod

import numpy as np
import pandas as pd 

import os
from os import path
import glob
import shutil
import time

import matplotlib.pyplot as plt
#from matplotlib.image import imread

#from cv2 import cv2
import PIL
from PIL import Image
#, ImageDraw, ImageOps, ImageFilter

#from sklearn import cluster
#from sklearn import decomposition
#from sklearn import metrics
#from sklearn.neighbors import KNeighborsClassifier
#from sklearn.model_selection import learning_curve
#from sklearn.utils import shuffle

#import boto3

# Pyspark
import findspark
findspark.init("C:\spark\spark-3.2.1-bin-hadoop3.2") # Path to hadoop 
import pyspark
print("PySpark version:{}".format(pyspark.__version__)) # Verify PySpark version
from pyspark import SparkContext
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import element_at, split
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
#import pyarrow
#print("PyArrow version:{}".format(pyarrow.__version__)) # Verify PySpark version

#from pyspark.sql.types import *
#import pyspark.sql.functions as F
#from pyspark.ml.image import ImageSchema # RDD (Resilient Distributed Dataset) from .jpg file
#from pyspark.ml.linalg import DenseVector, VectorUDT

import keras

# Tensorflow
#import tensorflow as tf

from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras import optimizers
#from tensorflow.keras.preprocessing.image import ImageDataGenerator
#from tensorflow.keras.optimizers import Adam
#from tensorflow.python.keras.callbacks import EarlyStopping, ModelCheckpoint


#, decode_predictions
#from tensorflow.keras.preprocessing import image
#from tensorflow.keras.utils import image_dataset_from_directory
#from tensorflow.keras import layers,Dense,Flatten
#from tensorflow.keras.models import Sequential
#from tensorflow.keras.losses import BinaryCrossentropy
#from tensorflow.keras.metrics import BinaryAccuracy

PySpark version:3.2.1


## Declare constants <a class="anchor" id="sub1_2"></a>

In [2]:
# Sample limitations
GET_PICTURES_NB_PER_CLASS = 2

# Local repositories
LOCAL_SRC_PATH = '../fruits-360-original-size/'
LOCAL_DEST_PATH = 'C:/fruits-360-sample/'

# Image size
IMAGE_RESIZE = 224

# CNN model
BATCH_SIZE = 16
RESNET50_POOLING_AVERAGE = 'avg'
DENSE_LAYER_ACTIVATION = 'softmax'
OBJECTIVE_FUNCTION = 'categorical_crossentropy'
LOSS_METRICS = ['accuracy']
NUM_EPOCHS = 30
EARLY_STOP_PATIENCE = 3  # EARLY_STOP_PATIENCE must be < NUM_EPOCHS

# S3 Bucket
BUCKET_NAME = "moncompartimentamoi"

# KMeans hyper-param
#KMEANS_N_CLUSTERS = 90

# Exploration of the full dataset <a class="anchor" id="chapter2"></a>

## Get picture information from local full dataset <a class="anchor" id="sub2_1"></a>

In [None]:
def rep_2_picture_info(path):

    # Initiate Dataframe with Dataset names, Target class names and Picture names
    df = pd.DataFrame(columns = ['FullFileName', 'Dataset', 'Target', 'Picture', 'FileSize (in KB)']) 

    for file in glob.iglob(path+'**/*.jpg', recursive = True):

        lst = file.split('\\')
        
        # update DataFrame
        lst.append(os.path.getsize(file) / 1024)  # in KBytes    
        lst[0] = lst[0] + "/" + lst[1] + "/" + lst[2] + "/" + lst[3]
        df.loc[len(df)] = lst

    return df

In [None]:
df_main = rep_2_picture_info(LOCAL_SRC_PATH)
df_main

## Explore picture information <a class="anchor" id="sub2_2"></a>

### Assess volumes and modalities

In [None]:
df_main.describe()

 > 12 455 pictures, 24 target classes, 3 datasets  
 > 958 picture names mean that some pictures have the same name and are not classified in the same repository

### Count pictures by target class and dataset

In [None]:
pd.DataFrame(df_main.groupby(['Target', 'Dataset'])['Picture'].count())

### Count pictures by target class

In [None]:
pd.DataFrame(df_main.groupby(['Target'])['Picture'].count())

### Count dataset modality by picture name

In [None]:
df_dataset_mod = pd.DataFrame(df_main.groupby(['Picture'])['Dataset'].nunique())
df_dataset_mod.rename(columns={'Dataset':'Dataset_mod'}, inplace=True)
len(df_dataset_mod[df_dataset_mod['Dataset_mod'] > 1])

> No file with the same name in the different datasets

### Count target class modality by picture name

In [None]:
df_target_mod = pd.DataFrame(df_main.groupby(['Picture'])['Target'].nunique())
df_target_mod.rename(columns={'Target':'Target_mod'}, inplace=True)
df_target_mod[df_target_mod['Target_mod'] > 1]

In [None]:
df_target_mod = df_target_mod.reset_index(drop=False)

In [None]:
pd.DataFrame(df_target_mod.groupby(['Target_mod'])['Picture'].count())

 > many files with the same name in the different target classes   
 > for instance, 156 files have the same name and appear in 24 different target class

In [None]:
df_main[df_main['Picture'] == 'r0_0.jpg'][['Picture', 'Target', 'Dataset']]

In [None]:
pict = Image.open(df_main['FullFileName'].iloc[60])
plt.imshow(pict)
plt.show()

In [None]:
pict = Image.open(df_main['FullFileName'].iloc[40])
plt.imshow(pict)
plt.show()

 > File name format: r?_image_index.jpg (e.g. r0_31.jpg or r1_12.jpg)  
 > "r?" stands for rotation axis (first one is r0)

### Distinguish rotation axis and index

In [None]:
# Laurence: supprimler les FutureWarning
df_main["Rotation"], df_main["Index"] = df_main["Picture"].str.split("_", 1).str
df_main["Rotation"] = df_main["Rotation"].str.replace('r','')
df_main["Index"] = df_main["Index"].str.replace('.jpg','')
df_main

In [None]:
pd.DataFrame(df_main["Rotation"].unique())

In [None]:
pict = Image.open(df_main['FullFileName'].iloc[df_main[df_main['Rotation'] == '0'].head(1).index[0]])
plt.imshow(pict)
plt.show()

In [None]:
pict = Image.open(df_main['FullFileName'].iloc[df_main[df_main['Rotation'] == '1'].head(1).index[0]])
plt.imshow(pict)
plt.show()

In [None]:
pict = Image.open(df_main['FullFileName'].iloc[df_main[df_main['Rotation'] == '2'].head(1).index[0]])
plt.imshow(pict)
plt.show()

 > 0 - queue top or down > rotation around the z-axis  
 > 1 - queue behind or ahead > rotation around the x-axis  
 > 2 - queue left or right > rotation around the y-axis  

### Target class count distribution

In [None]:
def distribution(df_in, dataset):
    
    df = df_in.copy()
    
    if dataset in ['Training', 'Test', 'Validation']:
        df.drop(df[df['Dataset'] != dataset].index, inplace=True)
    elif dataset != '*':
        print("dataset argument should be 'Training', 'Test', 'Validation' or '*'")
        return -1
        
    df_distrib = pd.DataFrame(df.groupby(['Target'])['Picture'].count())
    df_distrib.reset_index(drop=False, inplace=True)
    df_distrib.rename(columns={'Picture':'Picture count', 'Target':'Class'}, inplace=True)
    df_distrib = df_distrib.sort_values(by='Class', ascending=False)    

    if len(df_distrib) == 0: return -1
    
    df_distrib.plot.barh(x='Class', y='Picture count', figsize=(12, 10))    
    
    return 1

In [None]:
ret = distribution(df_main, 'Training')
ret = distribution(df_main, 'Test')
ret = distribution(df_main, 'Validation')
ret = distribution(df_main, '*')

### Target class filesize average distribution  
Logitech C920 camera and dedicated algorithm which extract the fruit from the background

In [None]:
pd.DataFrame(df_main.groupby('Target')['FileSize (in KB)'].mean())

In [None]:
pd.DataFrame(df_main.groupby('Dataset')['FileSize (in KB)'].mean())

In [None]:
df_main['FileSize (in KB)'].mean(), df_main['FileSize (in KB)'].sum()/1024**2

## Get class information <a class="anchor" id="sub2_3"></a>

In [None]:
path = '../fruits-360-original-size/Meta/'

df_class_add = pd.DataFrame(columns = ['PathName', 'Target', 'TxtName'])
df_meta = pd.DataFrame(columns = ['Flag', 'Value'])

for file in glob.iglob(path+'**/info.txt', recursive = True):
    
    df_meta_add = pd.read_csv(file, sep="=", names=['Flag', 'Value'])
        
    df_class_add.loc[0] = file.split('\\')   
        
    df_meta = pd.concat([df_class_add.join(df_meta_add, how='cross'), df_meta])

del df_class_add, df_meta_add

df_meta.drop(columns=['PathName', 'TxtName'], inplace=True)
df_meta = df_meta.sort_values(['Target', 'Flag'], ascending=True)
df_meta.reset_index(drop=True, inplace=True)
df_meta

In [None]:
df_meta['Flag'].nunique()

In [None]:
pd.DataFrame(df_meta['Flag'].unique(), columns=['Flag']).sort_values(by='Flag')

## Target label encoding  <a class="anchor" id="sub2_4"></a>

In [None]:
df_main, df_target_mapping = MyMod.encode_LabelEncoder(df_main, 'Target')
df_main.head(5)

In [None]:
df_meta, df_target_mapping = MyMod.encode_LabelEncoder(df_meta, 'Target')
df_meta.head(5)

In [None]:
df_target_mapping

# Preparation of the local sampled picture set <a class="anchor" id="chapter3"></a>

## Create a local sampled picture set <a class="anchor" id="sub3_1"></a>

In [None]:
target_prec = ""
for file in glob.iglob(LOCAL_SRC_PATH+'**/*.jpg', recursive = True):
    
    lst = file.split('\\')
    
    # Limit number of pictures per class(GET_PICTURES_NB_PER_CLASS)
    if lst[2] == target_prec: 
        i += 1
    else:
        i = 0
    target_prec = lst[2]
    
    if (lst[1] == 'Training' and i < 2*GET_PICTURES_NB_PER_CLASS) or \
                    (lst[1] != 'Training' and i < GET_PICTURES_NB_PER_CLASS):     
                
        # create the destination repository if necessary
        if not os.path.exists(LOCAL_DEST_PATH+lst[1]+"/"+lst[2]):
             os.makedirs(LOCAL_DEST_PATH+lst[1]+"/"+lst[2])

        # copy the file to the destination repository
        shutil.copyfile(LOCAL_SRC_PATH+lst[1]+"/"+lst[2]+"/"+lst[3], LOCAL_DEST_PATH+lst[1]+"/"+lst[2]+"/"+lst[3])

## Get picture information from local sampled picture set <a class="anchor" id="sub3_2"></a>

In [None]:
df_main = rep_2_picture_info(LOCAL_DEST_PATH)
df_main

# Create and configure a Spark Session <a class="anchor" id="chapter4"></a> 
- builder()     generator to create the session
- master()      master name (yarn, mesos or local[number of cores to use]. Number of partitions for distributed objects.
- appName()     name the application
- getOrCreate() create a new session or return the existing one
- config()      configure session
    - spark.sql.repl.eagerEval.enabled: PySpark DataFrame quick assessment in Jupyter  
    - spark.sql.repl.eagerEval.maxNumRows: number of lines to show      
    - spark.sql.execution.arrow.pyspark.enabled: to use Arrow optimiser for Spark to Pandas DataFrame conversions (toPandas or createDataFrame)  

In [3]:
# Instantiate SparkSession
spark = SparkSession.builder\
        .master("local[*]")\
        .appName('P8_03')\
        .getOrCreate()

spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
spark.conf.set('spark.sql.repl.eagerEval.maxNumRows', 5)
spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)

# Pandas UDFs on large records (e.g., very large images) can run into Out Of Memory (OOM) errors.
# If you hit such errors in the cell below, try reducing the Arrow batch size via `maxRecordsPerBatch`.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")

In [4]:
# Get the SparkContext
sc = spark.sparkContext

In [5]:
# User web interface
spark

# Load data <a class="anchor" id="chapter5"></a> 

## Load pictures <a class="anchor" id="sub5_1"></a> 

### Get pictures in a Spark DataFrame in "image" format

In [6]:
# Laurence: do not lauch: never ends !!
#df_pictures = spark.read.format('image').load(LOCAL_DEST_PATH + 'Training/*', inferschema=True)
df_pictures = spark.read.format('image').load(LOCAL_DEST_PATH + 'Training/test')
#Py4JJavaError: An error occurred while calling o33.load.
#: java.lang.UnsatisfiedLinkError: org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(Ljava/lang/String;I)

Py4JJavaError: An error occurred while calling o33.load.
: java.lang.UnsatisfiedLinkError: org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(Ljava/lang/String;I)Z
	at org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(Native Method)
	at org.apache.hadoop.io.nativeio.NativeIO$Windows.access(NativeIO.java:793)
	at org.apache.hadoop.fs.FileUtil.canRead(FileUtil.java:1215)
	at org.apache.hadoop.fs.FileUtil.list(FileUtil.java:1420)
	at org.apache.hadoop.fs.RawLocalFileSystem.listStatus(RawLocalFileSystem.java:601)
	at org.apache.hadoop.fs.FileSystem.listStatus(FileSystem.java:1972)
	at org.apache.hadoop.fs.FileSystem.listStatus(FileSystem.java:2014)
	at org.apache.hadoop.fs.ChecksumFileSystem.listStatus(ChecksumFileSystem.java:761)
	at org.apache.spark.util.HadoopFSUtils$.listLeafFiles(HadoopFSUtils.scala:225)
	at org.apache.spark.util.HadoopFSUtils$.$anonfun$parallelListLeafFilesInternal$1(HadoopFSUtils.scala:95)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.util.HadoopFSUtils$.parallelListLeafFilesInternal(HadoopFSUtils.scala:85)
	at org.apache.spark.util.HadoopFSUtils$.parallelListLeafFiles(HadoopFSUtils.scala:69)
	at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.bulkListLeafFiles(InMemoryFileIndex.scala:158)
	at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.listLeafFiles(InMemoryFileIndex.scala:131)
	at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.refresh0(InMemoryFileIndex.scala:94)
	at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.<init>(InMemoryFileIndex.scala:66)
	at org.apache.spark.sql.execution.datasources.DataSource.createInMemoryFileIndex(DataSource.scala:565)
	at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:409)
	at org.apache.spark.sql.DataFrameReader.loadV1Source(DataFrameReader.scala:274)
	at org.apache.spark.sql.DataFrameReader.$anonfun$load$3(DataFrameReader.scala:245)
	at scala.Option.getOrElse(Option.scala:189)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:245)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:188)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Unknown Source)


In [7]:
# Fetch paths to each jpg file
lst_path = [item.replace('\\', '/') for item in list(glob.iglob(LOCAL_DEST_PATH + 'Training/**/*.jpg', recursive = True))]

In [8]:
# inferSchema: True to infer columns types from data
df_pictures = spark.read.format('image').load(lst_path, inferschema=True)

In [9]:
# Spark DataFrame visualisation
df_pictures.show(5)

+--------------------+
|               image|
+--------------------+
|{file:/C:/fruits-...|
|{file:/C:/fruits-...|
|{file:/C:/fruits-...|
|{file:/C:/fruits-...|
|{file:/C:/fruits-...|
+--------------------+
only showing top 5 rows



In [10]:
# Spark DataFrame scheme
df_pictures.printSchema()

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)



In [11]:
df_pictures.select('image.origin', 'image.height', 'image.width', 'image.nChannels', 'image.mode').show(2, False, True)

-RECORD 0---------------------------------------------------------------
 origin    | file:/C:/fruits-360-sample/Training/apple_hit_1/r0_102.jpg 
 height    | 891                                                        
 width     | 840                                                        
 nChannels | 3                                                          
 mode      | 16                                                         
-RECORD 1---------------------------------------------------------------
 origin    | file:/C:/fruits-360-sample/Training/apple_hit_1/r0_100.jpg 
 height    | 885                                                        
 width     | 844                                                        
 nChannels | 3                                                          
 mode      | 16                                                         
only showing top 2 rows



### Get pictures in a Spark DataFrame in "binaryFile" format

In [12]:
# Valentin: df_pictures = spark.read.format("binaryFile") \
#                                   .option("pathGlobFilter", "*.jpg") \
#                                   .option("recursiveFileLookup", "true").load(LOCAL_DEST_PATH + 'Training/*')

In [13]:
df_pictures = spark.read.format("binaryFile") \
            .option("pathGlobFilter", "*.jpg") \
            .option("recursiveFileLookup", "true") \
            .load(lst_path)

In [14]:
# Count sample size
df_pictures.count()

97

In [15]:
# Spark DataFrame visualisation
df_pictures.show(5)

+--------------------+--------------------+------+--------------------+
|                path|    modificationTime|length|             content|
+--------------------+--------------------+------+--------------------+
|file:/C:/fruits-3...|2022-03-21 15:13:...|122525|[FF D8 FF E0 00 1...|
|file:/C:/fruits-3...|2022-03-21 15:13:...|119068|[FF D8 FF E0 00 1...|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 96281|[FF D8 FF E0 00 1...|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 95790|[FF D8 FF E0 00 1...|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 95486|[FF D8 FF E0 00 1...|
+--------------------+--------------------+------+--------------------+
only showing top 5 rows



In [16]:
# Spark DataFrame scheme
df_pictures.printSchema()

root
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- content: binary (nullable = true)



## Distinguish Target  <a class="anchor" id="sub5_2"></a> 

In [17]:
# Extract Target from path
df_pictures = df_pictures.withColumn('target', element_at(split(df_pictures['path'], "/"), -2))

In [18]:
df_pictures.show(5)

+--------------------+--------------------+------+--------------------+---------------+
|                path|    modificationTime|length|             content|         target|
+--------------------+--------------------+------+--------------------+---------------+
|file:/C:/fruits-3...|2022-03-21 15:13:...|122525|[FF D8 FF E0 00 1...|    apple_hit_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...|119068|[FF D8 FF E0 00 1...|    apple_hit_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 96281|[FF D8 FF E0 00 1...|cabbage_white_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 95790|[FF D8 FF E0 00 1...|cabbage_white_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 95486|[FF D8 FF E0 00 1...|cabbage_white_1|
+--------------------+--------------------+------+--------------------+---------------+
only showing top 5 rows



In [19]:
# Spark DataFrame scheme
df_pictures.printSchema()

root
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- content: binary (nullable = true)
 |-- target: string (nullable = true)



# Dimension reduction <a class="anchor" id="chapter6"></a> 

## Instantiate a ResNet50 model with pre-trained weights  <a class="anchor" id="sub6_1"></a> 
"Residual Network" with 50 layers. Convolutional Neural Network

In [20]:
'''model = ResNet50(
    # Weights pre-trained on ImageNet: resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 
    # in cache directory (~/.keras/models)
    weights='imagenet',                                 
    input_shape=(IMAGE_RESIZE, IMAGE_RESIZE, 3),
    # Do not include the ImageNet fully-connected layer at the top of the network
    include_top=False,
    classes=24)'''

import tensorflow as tf
from tensorflow.keras.applications.xception import Xception

print(tf.__version__)

model = Xception(
        include_top=False, # top layer supprimé
        weights=None,
        input_shape=(100,100,3),
        pooling='max')

2.8.0


In [21]:
model.summary()

Model: "xception"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 100, 100, 3  0           []                               
                                )]                                                                
                                                                                                  
 block1_conv1 (Conv2D)          (None, 49, 49, 32)   864         ['input_1[0][0]']                
                                                                                                  
 block1_conv1_bn (BatchNormaliz  (None, 49, 49, 32)  128         ['block1_conv1[0][0]']           
 ation)                                                                                           
                                                                                           

 n)                                                                                               
                                                                                                  
 block4_sepconv2 (SeparableConv  (None, 12, 12, 728)  536536     ['block4_sepconv2_act[0][0]']    
 2D)                                                                                              
                                                                                                  
 block4_sepconv2_bn (BatchNorma  (None, 12, 12, 728)  2912       ['block4_sepconv2[0][0]']        
 lization)                                                                                        
                                                                                                  
 conv2d_2 (Conv2D)              (None, 6, 6, 728)    186368      ['add_1[0][0]']                  
                                                                                                  
 block4_po

                                                                                                  
 block7_sepconv1_bn (BatchNorma  (None, 6, 6, 728)   2912        ['block7_sepconv1[0][0]']        
 lization)                                                                                        
                                                                                                  
 block7_sepconv2_act (Activatio  (None, 6, 6, 728)   0           ['block7_sepconv1_bn[0][0]']     
 n)                                                                                               
                                                                                                  
 block7_sepconv2 (SeparableConv  (None, 6, 6, 728)   536536      ['block7_sepconv2_act[0][0]']    
 2D)                                                                                              
                                                                                                  
 block7_se

                                                                  'add_6[0][0]']                  
                                                                                                  
 block10_sepconv1_act (Activati  (None, 6, 6, 728)   0           ['add_7[0][0]']                  
 on)                                                                                              
                                                                                                  
 block10_sepconv1 (SeparableCon  (None, 6, 6, 728)   536536      ['block10_sepconv1_act[0][0]']   
 v2D)                                                                                             
                                                                                                  
 block10_sepconv1_bn (BatchNorm  (None, 6, 6, 728)   2912        ['block10_sepconv1[0][0]']       
 alization)                                                                                       
          

 block12_sepconv3 (SeparableCon  (None, 6, 6, 728)   536536      ['block12_sepconv3_act[0][0]']   
 v2D)                                                                                             
                                                                                                  
 block12_sepconv3_bn (BatchNorm  (None, 6, 6, 728)   2912        ['block12_sepconv3[0][0]']       
 alization)                                                                                       
                                                                                                  
 add_10 (Add)                   (None, 6, 6, 728)    0           ['block12_sepconv3_bn[0][0]',    
                                                                  'add_9[0][0]']                  
                                                                                                  
 block13_sepconv1_act (Activati  (None, 6, 6, 728)   0           ['add_10[0][0]']                 
 on)      

In [22]:
# Broadcast the model weights in the SparkContext
bc_model_weights = sc.broadcast(model.get_weights())

## Functions  <a class="anchor" id="sub6_2"></a> 
https://docs.databricks.com/_static/notebooks/deep-learning/deep-learning-transfer-learning-keras.html

In [23]:
def model_fn():
    """
    Returns a ResNet50 model with top layer removed and broadcasted pretrained weights.
    """
    
    '''    model = ResNet50(
        weights=None,                                 
        input_shape=(IMAGE_RESIZE, IMAGE_RESIZE, 3),
        include_top=False,         # Do not include the ImageNet fully-connected layer at the top of the network
        classes=24)'''

    model = Xception(
                    include_top=False, # top layer supprimé
                    weights=None,
                    pooling='max')
    
    # set model weights to previously broadcasted weights
    model.set_weights(bc_model_weights.value)
    
    return model

### Preprocess one image

In [24]:
def preprocess(content):
    """
    Preprocesses raw image bytes for prediction.
    """
    # Redimension of pictures
    img = Image.open(io.BytesIO(content)).resize([299, 299])
    
    arr = img_to_array(img)
    
    return preprocess_input(arr)

### Featurize a pd.Series of images

In [25]:
def featurize_series(model, content_series):
    """
    Featurize a pd.Series of raw images using the input model.
    :return: a pd.Series of image features
    """
    input = np.stack(content_series.map(preprocess))
    
    preds = model.predict(input)

    # For some layers, output features will be multi-dimensional tensors.
    # We flatten the feature tensors to vectors for easier storage in Spark DataFrames.
    output = [p.flatten() for p in preds]
    
    return pd.Series(output)

###  Featurize all images
PandasUDFType.SCALAR_ITER used to amortize the cost of loading large models on workers  

In [26]:
#def featurize_udf(content_series_iter: PandasUDFType.SCALAR_ITER) -> 'array<float>':
#specify type hints for pandas UDF instead of specifying pandas UDF type

@pandas_udf('array<float>', PandasUDFType.SCALAR_ITER)
def featurize_udf(content_series_iter):
    '''
    This method is a Scalar Iterator pandas UDF wrapping our featurization function.
    The decorator specifies that this returns a Spark DataFrame column of type ArrayType(FloatType).

    :param content_series_iter: This argument is an iterator over batches of data, where each batch
                              is a pandas Series of image data.
    '''
    # With Scalar Iterator pandas UDFs, we can load the model once and then re-use it
    # for multiple data batches. This amortizes the overhead of loading big models.    
    model = model_fn()
    
    for content_series in content_series_iter:
        yield featurize_series(model, content_series)



## Features extraction  <a class="anchor" id="sub6_3"></a> 

In [27]:
df_pictures.show(5)

+--------------------+--------------------+------+--------------------+---------------+
|                path|    modificationTime|length|             content|         target|
+--------------------+--------------------+------+--------------------+---------------+
|file:/C:/fruits-3...|2022-03-21 15:13:...|122525|[FF D8 FF E0 00 1...|    apple_hit_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...|119068|[FF D8 FF E0 00 1...|    apple_hit_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 96281|[FF D8 FF E0 00 1...|cabbage_white_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 95790|[FF D8 FF E0 00 1...|cabbage_white_1|
|file:/C:/fruits-3...|2022-03-21 15:13:...| 95486|[FF D8 FF E0 00 1...|cabbage_white_1|
+--------------------+--------------------+------+--------------------+---------------+
only showing top 5 rows



In [28]:
# Save starting time
time_start = time.time()

# Large model to the full dataset
features_df = df_pictures.repartition(16) \
    .select(col("path"), col('target'), featurize_udf("content").alias("X_features"))

# Compute time elapse
elapse_s = time.time()-time_start
elapse_m = int(elapse_s / 60)
print('Feature extraction done! Time elapsed: {} seconds ({} minutes)'.format(elapse_s, elapse_m))

Feature extraction done! Time elapsed: 0.055121421813964844 seconds (0 minutes)


In [29]:
features_df.count()

97

In [30]:
features_df.persist()

Py4JJavaError: An error occurred while calling o88.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 17.0 failed 1 times, most recent failure: Lost task 0.0 in stage 17.0 (TID 241) (DESKTOP-4AFITJT executor driver): org.apache.spark.SparkException: Python worker exited unexpectedly (crashed)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:595)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:577)
	at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:38)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:107)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:508)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.hasNext(InMemoryRelation.scala:118)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:223)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:302)
	at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1481)
	at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1408)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1472)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1295)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:384)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:335)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	at java.lang.Thread.run(Unknown Source)
Caused by: java.io.EOFException
	at java.io.DataInputStream.readInt(Unknown Source)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:88)
	... 41 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2454)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2403)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2402)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2402)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1160)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2642)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2584)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2573)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2214)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2235)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2254)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:476)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:429)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:48)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3715)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2728)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3706)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3704)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2728)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2935)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:287)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:326)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Unknown Source)
Caused by: org.apache.spark.SparkException: Python worker exited unexpectedly (crashed)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:595)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:577)
	at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:38)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:107)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:508)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.hasNext(InMemoryRelation.scala:118)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:223)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:302)
	at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1481)
	at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1408)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1472)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1295)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:384)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:335)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	... 1 more
Caused by: java.io.EOFException
	at java.io.DataInputStream.readInt(Unknown Source)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:88)
	... 41 more


Py4JJavaError: An error occurred while calling o88.getRowsToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 19.0 failed 1 times, most recent failure: Lost task 0.0 in stage 19.0 (TID 242) (DESKTOP-4AFITJT executor driver): org.apache.spark.SparkException: Python worker exited unexpectedly (crashed)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:595)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:577)
	at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:38)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:107)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:508)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.hasNext(InMemoryRelation.scala:118)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:223)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:302)
	at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1481)
	at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1408)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1472)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1295)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:384)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:335)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	at java.lang.Thread.run(Unknown Source)
Caused by: java.io.EOFException
	at java.io.DataInputStream.readInt(Unknown Source)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:88)
	... 41 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2454)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2403)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2402)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2402)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1160)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2642)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2584)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2573)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2214)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2235)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2254)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:476)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:429)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:48)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3715)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2728)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3706)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3704)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2728)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2935)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:287)
	at org.apache.spark.sql.Dataset.getRowsToPython(Dataset.scala:3558)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Unknown Source)
Caused by: org.apache.spark.SparkException: Python worker exited unexpectedly (crashed)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:595)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:577)
	at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:38)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:107)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:508)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.hasNext(InMemoryRelation.scala:118)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:223)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:302)
	at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1481)
	at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1408)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1472)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1295)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:384)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:335)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	... 1 more
Caused by: java.io.EOFException
	at java.io.DataInputStream.readInt(Unknown Source)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:88)
	... 41 more


In [None]:
#features_df.show()

# Train a new model using pre-computed features <a class="anchor" id="chapter7"></a> 
because my dataset has too little data to train a full-scale model from scratch  

Layers trainable attributes:
- weights : list of all weights variables of the layer  
- trainable_weights : list of those that are meant to be updated (via gradient descent) to minimize the loss during training  
- non_trainable_weights : list of those that aren't meant to be trained (updated by the model during the forward pass)  
- trainable : false moves all the layer's weights from trainable to non-trainable ("freezing" the layer)  
The only built-in layer that has non-trainable weights is the BatchNormalization layer  

Workflow 1 :  
1 - Instantiate a base model and load pre-trained weights into it  
2 - Freeze all layers in the base model by setting trainable = False  
3 - Create a new model on top of the output of one/several layers from the base model  
4 - Train your new model on your new dataset  (top layers to learn to turn the old features into predictions on a new dataset)   

Workflow 2 : feature extraction  
1 - Instantiate a base model and load pre-trained weights into it  
2 - Run your new dataset through it and record the output of one/several layers from the base model   
3 - Use that output as input data for a new, smaller model   
advantage: you only run the base model once on your data, rather than once per epoch of training > a lot faster & cheaper  
issue: doesn't allow you to dynamically modify the input data of your new model during training, which is required when doing data augmentation  

Fine-tuning (optionnal): 
unfreeze the entire/partial model you obtained and re-training it on the new data with a very low learning rate  

## Prepare my new model <a class="anchor" id="sub7_1"></a> 

### Freeze all layers in the base model by setting trainable = False

In [None]:
model.trainable = False

### Create a new model on top of the output of one layer from the model  

In [None]:
# Instantiates a Keras tensor
inputs = keras.Input(shape=(IMAGE_RESIZE, IMAGE_RESIZE, 3))

# Makes pretrained_model run in inference mode by passing training to False (necessary for fine-tuning)
x = model(inputs, training=False)

# Convert features of shape 'base_model.output_shape[1:]' to vectors
#x = keras.layers.GlobalAveragePooling2D()(x)

# A Dense classifier with 24 units et activation 'softmax'
# Node: 'categorical_crossentropy/softmax_cross_entropy_with_logits' > logits and labels must be broadcastable
outputs = keras.layers.Dense(24, activation=DENSE_LAYER_ACTIVATION)(x)

new_model = keras.Model(inputs, outputs)

In [None]:
# Compile
sgd = optimizers.SGD(learning_rate = 0.01, decay = 1e-6, momentum = 0.9, nesterov = True)
new_model.compile(optimizer=sgd, loss=OBJECTIVE_FUNCTION, metrics=LOSS_METRICS)

Spark workers need to access the model and its weights.  
For moderately sized models (< 1GB in size), a good practice is to download the model to the Spark driver and 
then broadcast the weights to the workers.  
For large models (> 1GB), it is best to load the model weights from distributed storage to workers directly.  

## Train my new model <a class="anchor" id="sub7_2"></a> 

In [None]:
def train_series(model, content_series):
    """
    Train a model a raw images
    :return: a trained model
    """
    X_train = np.stack(content_series.map(preprocess))
    y_train = content_series.col('Target')
    
    print(type(X_train))
    print(type(y_train))
    
    #cb_early_stopper = EarlyStopping(monitor = 'loss', patience = EARLY_STOP_PATIENCE)
    #cb_checkpointer = ModelCheckpoint(filepath = 'best.hdf5', monitor = 'loss', save_best_only = True, mode = 'auto')

    #preds = model.fit(X_train, epochs=NUM_EPOCHS, callbacks=[cb_checkpointer, cb_early_stopper], y_train)

    return model

In [None]:
@pandas_udf('array<float>', PandasUDFType.SCALAR_ITER)
def train_udf(content_series_iter):

    # With Scalar Iterator pandas UDFs, we can load the model once and then re-use it
    # for multiple data batches.  This amortizes the overhead of loading big models.
    
    for content_series in content_series_iter:
        yield train_series(my_model, content_series)

In [None]:
result_df = df_pictures.repartition(16) \
    .select(col("path"), col('target'), train_udf("content").alias("target_pred"))

In [None]:
result_df.persist()

In [None]:
result_df.show()

* [Go to Table des matières](#chapter0)

# End <a class="anchor" id="chapter100"></a> 