In [1]:
# setup

!sudo apt-get install openjdk-11-jdk
# !wget -q https://downloads.apache.org/spark/spark-3.1.1/spark-3.1.1-bin-hadoop3.2.tgz
# !tar xvzf spark-3.1.1-bin-hadoop3.2.tgz
!pip install pyspark
!pip install -q findspark
!pip install pyarrow
try:
  # %tensorflow_version only exists in Colab.
  !pip install  tf-estimator-nightly==2.8.0.dev2021122109
except Exception:
  pass

Reading package lists... Done
Building dependency tree       
Reading state information... Done
openjdk-11-jdk is already the newest version (11.0.14.1+1-0ubuntu1~18.04).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.


In [2]:
from PIL import Image
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from pyspark.sql.functions import col, pandas_udf, regexp_extract
import io
from tensorflow.keras.applications.imagenet_utils import decode_predictions
import pandas as pd
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
import pathlib
from pyspark.sql.functions import col, pandas_udf, regexp_extract
import io
from tensorflow.keras.applications.imagenet_utils import decode_predictions
import pandas as pd
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from pyspark.sql import SparkSession

In [3]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
# os.environ["SPARK_HOME"] = "/content/spark-3.1.1-bin-hadoop3.2"

In [4]:
import findspark
findspark.init()

spark = SparkSession.builder.master("local[*]").getOrCreate()

In [5]:
data_dir = tf.keras.utils.get_file(origin='http://pjreddie.com/media/files/cifar.tgz',fname='cifar', untar=True)

Downloading data from http://pjreddie.com/media/files/cifar.tgz


In [6]:
images = spark.read.format("binaryFile").option("recursiveFileLookup", "true").option("pathGlobFilter", "*.png").load(data_dir)

In [7]:
def extract_label(path_col):
  """Extract label from file path using built-in SQL functions."""
  return regexp_extract(path_col, "_([^/.]+)", 1)

def extract_size(content):
  """Extract image size from its raw content."""
  image = Image.open(io.BytesIO(content))
  return image.size

@pandas_udf("width: int, height: int")
def extract_size_udf(content_series):
  sizes = content_series.apply(extract_size)
  return pd.DataFrame(list(sizes))

df = images.select(
  col("path"),
  col("modificationTime"),
  extract_label(col("path")).alias("label"),
  extract_size_udf(col("content")).alias("size"),
  col("content"))

In [9]:
df.show(5)

+--------------------+-------------------+-----+--------+--------------------+
|                path|   modificationTime|label|    size|             content|
+--------------------+-------------------+-----+--------+--------------------+
|file:/root/.keras...|2016-11-18 20:24:13| frog|{32, 32}|[89 50 4E 47 0D 0...|
|file:/root/.keras...|2016-11-18 20:24:13| bird|{32, 32}|[89 50 4E 47 0D 0...|
|file:/root/.keras...|2016-11-18 20:24:12| frog|{32, 32}|[89 50 4E 47 0D 0...|
|file:/root/.keras...|2016-11-18 20:24:13| frog|{32, 32}|[89 50 4E 47 0D 0...|
|file:/root/.keras...|2016-11-18 20:24:12| deer|{32, 32}|[89 50 4E 47 0D 0...|
+--------------------+-------------------+-----+--------+--------------------+
only showing top 5 rows



In [10]:
df.select(col('label')).distinct().show()

+----------+
|     label|
+----------+
|      deer|
|      bird|
|      frog|
|     truck|
|       dog|
|       cat|
|  airplane|
|automobile|
|     horse|
|      ship|
+----------+



In [11]:
class ImageNetDataset(Dataset):
  """
  Converts image contents into a PyTorch Dataset with standard ImageNet preprocessing.
  """
  def __init__(self, contents):
    self.contents = contents

  def __len__(self):
    return len(self.contents)

  def __getitem__(self, index):
    return self._preprocess(self.contents[index])

  def _preprocess(self, content):
    """
    Preprocesses the input image content using standard ImageNet normalization.
    
    See https://pytorch.org/docs/stable/torchvision/models.html.
    """
    image = Image.open(io.BytesIO(content))
    transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)

In [12]:
def imagenet_model_udf(model_fn):
  """
  Wraps an ImageNet model into a Pandas UDF that makes predictions.
  
  You might consider the following customizations for your own use case:
    - Tune DataLoader's batch_size and num_workers for better performance.
    - Use GPU for acceleration.
    - Change prediction types.
  """
  def predict(content_series_iter):
    model = model_fn()
    model.eval()
    for content_series in content_series_iter:
      dataset = ImageNetDataset(list(content_series))
      loader = DataLoader(dataset, batch_size=64)
      with torch.no_grad():
        for image_batch in loader:
          predictions = model(image_batch).numpy()
          predicted_labels = [x[0] for x in decode_predictions(predictions, top=1)]
          # print(predicted_labels)
          # break
          yield pd.DataFrame(predicted_labels)
  return_type = "class: string, desc: string, score:float"
  return pandas_udf(return_type, PandasUDFType.SCALAR_ITER)(predict)

In [13]:
mobilenet_v2_udf = imagenet_model_udf(lambda: models.mobilenet_v2(pretrained=True))



In [14]:
#images = table("ml_tmp.flowers")
#predictions = images.withColumn("prediction", mobilenet_v2_udf(col("content")))

predictions = df.withColumn("prediction", mobilenet_v2_udf(col("content")))
display(predictions.select(col("path"), col("prediction")).limit(5))

DataFrame[path: string, prediction: struct<class:string,desc:string,score:float>]

In [16]:
# predictions.select(col("prediction")).show(10, truncate = False)
predictions.select(col("label"),col("prediction")).show(10, truncate = False)

+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n02128925, jaguar, 13.927444}          |
|bird |{n07745940, strawberry, 9.118509}       |
|frog |{n01630670, common_newt, 11.725792}     |
|frog |{n02457408, three-toed_sloth, 9.893128} |
|deer |{n02114712, red_wolf, 7.620323}         |
|frog |{n02130308, cheetah, 12.235772}         |
|bird |{n02457408, three-toed_sloth, 11.331468}|
|frog |{n01744401, rock_python, 9.915939}      |
|frog |{n01644900, tailed_frog, 10.826723}     |
|deer |{n02356798, fox_squirrel, 10.129302}    |
+-----+----------------------------------------+
only showing top 10 rows



## Trying other models

In [17]:
# Resnet 

resnet18_udf = imagenet_model_udf(lambda: models.resnet18(pretrained=True))
predictions = df.withColumn("prediction", resnet18_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n02130308, cheetah, 11.560522}         |
|bird |{n01443537, goldfish, 10.951969}        |
|frog |{n01744401, rock_python, 7.9980874}     |
|frog |{n01871265, tusker, 8.802372}           |
|deer |{n02114712, red_wolf, 7.421795}         |
|frog |{n02356798, fox_squirrel, 7.0292416}    |
|bird |{n02457408, three-toed_sloth, 11.415656}|
|frog |{n01644900, tailed_frog, 12.168472}     |
|frog |{n02115913, dhole, 10.988573}           |
|deer |{n03017168, chime, 8.844654}            |
+-----+----------------------------------------+
only showing top 10 rows



In [18]:
# Alexnet

alexnet_udf = imagenet_model_udf(lambda: models.alexnet(pretrained=True))
predictions = df.withColumn("prediction", alexnet_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+---------------------------------------+
|label|prediction                             |
+-----+---------------------------------------+
|frog |{n02128925, jaguar, 12.137697}         |
|bird |{n01443537, goldfish, 7.694492}        |
|frog |{n01644900, tailed_frog, 7.793912}     |
|frog |{n02487347, macaque, 7.6259594}        |
|deer |{n02356798, fox_squirrel, 8.7287}      |
|frog |{n02606052, rock_beauty, 6.159985}     |
|bird |{n02457408, three-toed_sloth, 8.934298}|
|frog |{n01688243, frilled_lizard, 10.808053} |
|frog |{n02356798, fox_squirrel, 7.9078245}   |
|deer |{n01843065, jacamar, 6.974002}         |
+-----+---------------------------------------+
only showing top 10 rows



In [19]:
# Squeezenet

squeezenet_udf = imagenet_model_udf(lambda: models.squeezenet1_0(pretrained=True))
predictions = df.withColumn("prediction", squeezenet_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n02356798, fox_squirrel, 16.09594}     |
|bird |{n01443537, goldfish, 13.846475}        |
|frog |{n07760859, custard_apple, 18.216074}   |
|frog |{n13044778, earthstar, 16.680613}       |
|deer |{n02089973, English_foxhound, 20.55805} |
|frog |{n01756291, sidewinder, 13.225813}      |
|bird |{n02457408, three-toed_sloth, 11.665099}|
|frog |{n01776313, tick, 19.134926}            |
|frog |{n02356798, fox_squirrel, 20.144127}    |
|deer |{n02119022, red_fox, 16.427446}         |
+-----+----------------------------------------+
only showing top 10 rows



In [20]:
# Vgg16

vgg16_udf = imagenet_model_udf(lambda: models.vgg16(pretrained=True))
predictions = df.withColumn("prediction", vgg16_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n13037406, gyromitra, 11.249056}       |
|bird |{n02002724, black_stork, 7.4124026}     |
|frog |{n02128925, jaguar, 10.290902}          |
|frog |{n02457408, three-toed_sloth, 8.371567} |
|deer |{n02099601, golden_retriever, 7.6250567}|
|frog |{n02128925, jaguar, 8.039926}           |
|bird |{n02457408, three-toed_sloth, 7.6756473}|
|frog |{n01688243, frilled_lizard, 8.463073}   |
|frog |{n13037406, gyromitra, 8.195274}        |
|deer |{n02013706, limpkin, 9.639962}          |
+-----+----------------------------------------+
only showing top 10 rows



In [21]:
# Densenet

densenet161_udf = imagenet_model_udf(lambda: models.densenet161(pretrained=True))
predictions = df.withColumn("prediction", densenet161_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+------------------------------------+
|label|prediction                          |
+-----+------------------------------------+
|frog |{n13037406, gyromitra, 11.50078}    |
|bird |{n01443537, goldfish, 10.488806}    |
|frog |{n01630670, common_newt, 14.277996} |
|frog |{n13037406, gyromitra, 11.299743}   |
|deer |{n02114855, coyote, 10.284748}      |
|frog |{n01744401, rock_python, 7.383606}  |
|bird |{n02500267, indri, 7.679103}        |
|frog |{n02356798, fox_squirrel, 11.697643}|
|frog |{n13037406, gyromitra, 11.137905}   |
|deer |{n02356798, fox_squirrel, 11.487185}|
+-----+------------------------------------+
only showing top 10 rows



In [22]:
# Inception

inception_v3_udf = imagenet_model_udf(lambda: models.inception_v3(pretrained=True))
predictions = df.withColumn("prediction", inception_v3_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+--------------------------------------+
|label|prediction                            |
+-----+--------------------------------------+
|frog |{n02127052, lynx, 15.658775}          |
|bird |{n01631663, eft, 14.187774}           |
|frog |{n01644900, tailed_frog, 12.118631}   |
|frog |{n01744401, rock_python, 9.564246}    |
|deer |{n04525305, vending_machine, 8.474264}|
|frog |{n02128385, leopard, 15.748716}       |
|bird |{n01770081, harvestman, 9.068718}     |
|frog |{n01688243, frilled_lizard, 13.252224}|
|frog |{n09256479, coral_reef, 8.501441}     |
|deer |{n02356798, fox_squirrel, 10.647777}  |
+-----+--------------------------------------+
only showing top 10 rows



In [23]:
# Googlenet

googlenet_udf = imagenet_model_udf(lambda: models.googlenet(pretrained=True))
predictions = df.withColumn("prediction", googlenet_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+------------------------------------+
|label|prediction                          |
+-----+------------------------------------+
|frog |{n02128925, jaguar, 8.538764}       |
|bird |{n02606052, rock_beauty, 5.672103}  |
|frog |{n02128385, leopard, 6.2645426}     |
|frog |{n07730033, cardoon, 3.9538639}     |
|deer |{n03016953, chiffonier, 5.7657547}  |
|frog |{n03998194, prayer_rug, 6.065628}   |
|bird |{n02356798, fox_squirrel, 6.559846} |
|frog |{n01644900, tailed_frog, 6.832458}  |
|frog |{n02356798, fox_squirrel, 7.4379053}|
|deer |{n07745940, strawberry, 6.2945814}  |
+-----+------------------------------------+
only showing top 10 rows



In [24]:
# Shufflenet

shufflenet_udf = imagenet_model_udf(lambda: models.shufflenet_v2_x1_0(pretrained=True))
predictions = df.withColumn("prediction", shufflenet_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n02129165, lion, 13.891481}            |
|bird |{n07720875, bell_pepper, 11.219691}     |
|frog |{n02356798, fox_squirrel, 9.561252}     |
|frog |{n02457408, three-toed_sloth, 10.987317}|
|deer |{n02119789, kit_fox, 7.5422497}         |
|frog |{n01644900, tailed_frog, 8.782939}      |
|bird |{n02457408, three-toed_sloth, 12.465017}|
|frog |{n01688243, frilled_lizard, 12.1711035} |
|frog |{n02129165, lion, 12.077578}            |
|deer |{n02356798, fox_squirrel, 13.9859705}   |
+-----+----------------------------------------+
only showing top 10 rows



In [25]:
# Mobilenet

mobilenet_v2_udf = imagenet_model_udf(lambda: models.mobilenet_v2(pretrained=True))
predictions = df.withColumn("prediction", mobilenet_v2_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n02128925, jaguar, 13.927444}          |
|bird |{n07745940, strawberry, 9.118509}       |
|frog |{n01630670, common_newt, 11.725792}     |
|frog |{n02457408, three-toed_sloth, 9.893128} |
|deer |{n02114712, red_wolf, 7.620323}         |
|frog |{n02130308, cheetah, 12.235772}         |
|bird |{n02457408, three-toed_sloth, 11.331468}|
|frog |{n01744401, rock_python, 9.915939}      |
|frog |{n01644900, tailed_frog, 10.826723}     |
|deer |{n02356798, fox_squirrel, 10.129302}    |
+-----+----------------------------------------+
only showing top 10 rows



In [26]:
mobilenet_v3_large_udf = imagenet_model_udf(lambda: models.mobilenet_v3_large(pretrained=True))
predictions = df.withColumn("prediction", mobilenet_v3_large_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------+
|label|prediction                        |
+-----+----------------------------------+
|frog |{n01644900, tailed_frog, 8.541575}|
|bird |{n01818515, macaw, 7.593059}      |
|frog |{n01644900, tailed_frog, 8.550882}|
|frog |{n13037406, gyromitra, 8.226665}  |
|deer |{n02423022, gazelle, 7.046612}    |
|frog |{n02128925, jaguar, 7.5521393}    |
|bird |{n02500267, indri, 7.035239}      |
|frog |{n01644900, tailed_frog, 8.152968}|
|frog |{n13037406, gyromitra, 7.8628855} |
|deer |{n02422106, hartebeest, 9.383434} |
+-----+----------------------------------+
only showing top 10 rows



In [27]:
mobilenet_v3_small_udf = imagenet_model_udf(lambda: models.mobilenet_v3_small(pretrained=True))
predictions = df.withColumn("prediction", mobilenet_v3_small_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+-----------------------------------+
|label|prediction                         |
+-----+-----------------------------------+
|frog |{n02130308, cheetah, 8.274257}     |
|bird |{n02002724, black_stork, 6.107225} |
|frog |{n01630670, common_newt, 8.539049} |
|frog |{n13037406, gyromitra, 7.852108}   |
|deer |{n02114712, red_wolf, 10.001156}   |
|frog |{n01756291, sidewinder, 7.2776165} |
|bird |{n04604644, worm_fence, 5.3525352} |
|frog |{n01644900, tailed_frog, 9.493502} |
|frog |{n02356798, fox_squirrel, 6.365085}|
|deer |{n02389026, sorrel, 4.5488844}     |
+-----+-----------------------------------+
only showing top 10 rows



In [28]:
# Resnext

resnext50_32x4d_udf = imagenet_model_udf(lambda: models.resnext50_32x4d(pretrained=True))
predictions = df.withColumn("prediction",resnext50_32x4d_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+--------------------------------------+
|label|prediction                            |
+-----+--------------------------------------+
|frog |{n02128925, jaguar, 15.668928}        |
|bird |{n02002724, black_stork, 9.183935}    |
|frog |{n01644900, tailed_frog, 7.7040114}   |
|frog |{n01644900, tailed_frog, 12.080752}   |
|deer |{n02090379, redbone, 8.180999}        |
|frog |{n01744401, rock_python, 9.788459}    |
|bird |{n02492660, howler_monkey, 11.933609} |
|frog |{n01688243, frilled_lizard, 18.348505}|
|frog |{n13037406, gyromitra, 10.3714695}    |
|deer |{n02115913, dhole, 11.735558}         |
+-----+--------------------------------------+
only showing top 10 rows



In [29]:
# WideResnet

wide_resnet50_2_udf = imagenet_model_udf(lambda: models.wide_resnet50_2(pretrained=True))
predictions = df.withColumn("prediction", wide_resnet50_2_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+---------------------------------------+
|label|prediction                             |
+-----+---------------------------------------+
|frog |{n02128925, jaguar, 13.496078}         |
|bird |{n07714990, broccoli, 7.3184166}       |
|frog |{n01630670, common_newt, 8.448229}     |
|frog |{n01688243, frilled_lizard, 9.081802}  |
|deer |{n02129604, tiger, 9.687185}           |
|frog |{n02128385, leopard, 8.887554}         |
|bird |{n02492660, howler_monkey, 10.001827}  |
|frog |{n01688243, frilled_lizard, 12.0437355}|
|frog |{n13037406, gyromitra, 8.195}          |
|deer |{n02356798, fox_squirrel, 10.605209}   |
+-----+---------------------------------------+
only showing top 10 rows



In [30]:
# Mnasnet

mnasnet1_0_udf = imagenet_model_udf(lambda: models.mnasnet1_0(pretrained=True))
predictions = df.withColumn("prediction", mnasnet1_0_udf(col("content")))
predictions.select(col("label"),col("prediction")).show(10, truncate = False)



+-----+----------------------------------------+
|label|prediction                              |
+-----+----------------------------------------+
|frog |{n02128925, jaguar, 15.212853}          |
|bird |{n01443537, goldfish, 10.728087}        |
|frog |{n07760859, custard_apple, 12.752817}   |
|frog |{n02457408, three-toed_sloth, 12.623458}|
|deer |{n02090379, redbone, 8.531106}          |
|frog |{n02128925, jaguar, 11.006413}          |
|bird |{n02493793, spider_monkey, 12.954729}   |
|frog |{n01688243, frilled_lizard, 11.320785}  |
|frog |{n01990800, isopod, 10.630776}          |
|deer |{n02087046, toy_terrier, 9.972003}      |
+-----+----------------------------------------+
only showing top 10 rows

