In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table align="left">
<td style="text-align: center">
<a href="https://console.cloud.google.com/vertex-ai/workbench/instances/create?download_url=https://raw.githubusercontent.com/GoogleCloudDataproc/cloud-dataproc/ai-ml-samples/interactive/ImageClassificationInSpark.ipynb">
<img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
</a>
</td>
<td style="text-align: center">
<a href="https://github.com/GoogleCloudDataproc/cloud-dataproc/ai-ml-samples/interactive/ImageClassificationInSpark.ipynb">
<img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
</a>
</td>
</table>

## Overview

In this tutorial, you perform distributed ML inference via image classification using Apache Spark.

## Get started

1. Create a Dataproc-enabled [Vertex AI Workbench](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled) instance or use an existing instance.
2. Enable [Private Google Access](https://cloud.google.com/dataproc-serverless/docs/concepts/network#private-google-access-requirement) on a subnet in your project.
3. Set up [Public NAT](https://cloud.google.com/nat/docs/set-up-manage-network-address-translation#create-nat-gateway) to download Torch model weights. See ["External network access"](https://cloud.google.com/dataproc-serverless/docs/concepts/network#subnetwork_requirements).
4. Create a [serverless runtime template](https://cloud.google.com/dataproc-serverless/docs/quickstarts/jupyterlab-sessions#dataproc_create_serverless_runtime_template-JupyterLab) and connect to a [remote kernel](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled#serverless-spark).

### Import libraries

All libraries needed in this notebook are installed in Dataproc Serverless versions 1.2 and 2.2+.

In [None]:
# Uncomment and run this cell if not using Dataproc Serverless version 1.2 or 2.2+.
# pip install torch torchvision google-cloud-storage

In [None]:
# Import the Google Cloud Storage client library and helper functions
from google.cloud import storage
from google.cloud.storage.blob import Blob

# Import Pandas
import pandas as pd

# Import PySpark helper functions
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType, StringType

# Import Pytorch and helper functions
import torch
from torch.utils.data import Dataset
from torchvision import  transforms
from torchvision.datasets.folder import default_loader
from torchvision.models import resnet50, ResNet50_Weights

### Get data paths

Get the list of URIs for the images to be classified.

In [None]:
# Set the bucket name and the maximum number of blobs to list. Feel free to
# experiment with more or fewer images.
bucket_name = 'cloud-samples-data'
max_results = 50

# Load the bucket
client = storage.Client()
bucket = client.get_bucket(bucket_name)

# List up to `max_results` blobs under the prefix and keep only the JPEG files.
# Note: `max_results` caps the number of blobs listed, not the number of images
# kept, so the resulting list can contain fewer than `max_results` URIs.
blobs = bucket.list_blobs(prefix="generative-ai/image", max_results=max_results)
blob_uris = [
    f"gs://{bucket_name}/{blob.name}"
    for blob in blobs
    if blob.name.endswith("jpg")
]

### Configure Spark

Create a Spark session and load the data into Spark.

In [None]:
# Start (or reuse) the Spark session for this demo and keep its context handy.
spark = SparkSession.builder.appName("classificationDemo").getOrCreate()
sc = spark.sparkContext

# Turn on Apache Arrow for PySpark and limit each Arrow batch to 64 records.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")

# Build a single-column dataframe of image URIs spread over 10 partitions.
files_df = (
    spark.createDataFrame(blob_uris, StringType())
    .withColumnRenamed("value", "inputFile")
    .repartition(10)
)

### Run the inference job

Create a [custom Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) object to manage the input data.

In [None]:
# Custom Dataset class
class ImageDataset(Dataset):
    """Dataset that loads images referenced by Cloud Storage URIs.

    Each item is downloaded to a local file under ``/tmp`` (the image loader
    needs a local path), opened with ``default_loader`` and, when a transform
    was supplied, passed through it.

    Args:
        paths: Indexable collection of ``gs://`` URIs, one per image.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform
        # Created lazily on first item access so that one client is reused for
        # every download instead of building a new one per image.
        self._client = None

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Download, load and (optionally) transform the image at ``index``."""
        if self._client is None:
            self._client = storage.Client()
        path = self.paths[index]
        # Download file from GCS as image loader needs local file
        blob = Blob.from_string(path, client=self._client)
        local_file = "/tmp/" + path.split("/")[-1]
        # Close the handle before loading so the image is fully flushed to disk
        # and no file descriptor is leaked per item.
        with open(local_file, "wb") as file_obj:
            blob.download_to_file(file_obj)
        image = default_loader(local_file)
        if self.transform is not None:
            image = self.transform(image)
        return image

Create a [Pandas UDF](https://spark.apache.org/docs/3.4.2/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html) that contains the Torch code that will run on each worker to perform model inference. This inference job uses the [ResNet50](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html) model.

In [None]:
@pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(paths: pd.Series) -> pd.Series:
    """Classify a batch of images with a pretrained ResNet50 model.

    Args:
        paths: Series of image URIs, one per row of the incoming batch.

    Returns:
        A Series holding, for each input path and in the same order, the
        array of raw model outputs produced for that image.
    """
    # Transformation needed on input by Resnet model
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Create image dataset
    images = ImageDataset(paths, transform=transform)

    # Tune batch_size/num_workers based on your workload
    loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)

    # Set local directory to hold models
    torch.hub.set_dir("/tmp/models")

    # Configure if jobs will run on CPU or GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model and load onto device. The weights are passed by keyword
    # (the positional form is a deprecated legacy interface).
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model.to(device)
    # Switch to inference mode so BatchNorm layers use their stored running
    # statistics instead of per-batch statistics.
    model.eval()

    # Run predictions
    all_predictions = []
    with torch.no_grad():
        for batch in loader:
            outputs = model(batch.to(device))
            # Move results back to the CPU first: .numpy() is not supported
            # on CUDA tensors.
            all_predictions.extend(outputs.cpu().numpy())
    return pd.Series(all_predictions)

Execute the inference job and convert the output to a Pandas dataframe.

In [None]:
# Keep the input path and attach the UDF output as a "predictions" column.
predictions_df = (
    files_df
    .select(col("inputFile"))
    .withColumn("predictions", predict_batch_udf(col("inputFile")))
)

# Run the job and collect the results as a Pandas dataframe.
predictions = predictions_df.toPandas()

### View outputs

Get labels based on the model output.

In [None]:
# The pretrained weights carry the list of category names in their metadata.
weights = ResNet50_Weights.DEFAULT


def _top_label(scores):
    """Return the category name at the position of the highest score."""
    return weights.meta["categories"][scores.argmax()]


predictions["label"] = predictions["predictions"].map(_top_label)

# Show full file paths rather than truncated ones, then preview the results.
pd.set_option('display.max_colwidth', None)
predictions.head(10)[["inputFile", "label"]]