In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Overview

With this notebook, we learn how to do distributed ML inference (image classification) using Dataproc Spark Serverless interactively.

The following steps are performed:
1. Create a [Vertex AI Workbench](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled) instance
2. Connect to a remote-notebook using [serverless sessions](https://cloud.google.com/dataproc-serverless/docs/quickstarts/jupyterlab-sessions)
3. Write code in the above notebook which runs on multiple Spark executors
4. Create a Spark DataFrame of the URLs of the images we want to classify. A pre-trained ResNet-50 model is downloaded from torchvision on the driver, and its weights are broadcast to all the executors. Inference is written as a pandas UDF that runs on each partition of the URLs.

Note: You should first create the notebook mentioned in step 2 above, then import this entire notebook there.

## Get started

## Notebook tutorial

### Import libraries

In [None]:
# All libraries needed in this notebook (torch, torchvision, google-cloud-storage) are already installed in the serverless runtime. If you need anything extra, install it with `%pip install <library>`.

### List images to classify in a GCS bucket

In [None]:
from google.cloud import storage
from google.cloud.storage.blob import Blob


def get_blob_uri(blob):
    """Return the ``gs://<bucket>/<object>`` URI of a Cloud Storage blob.

    The URI is assembled from the bucket and object names instead of slicing
    the generation suffix off ``blob.id``, so it does not depend on ``id`` and
    ``generation`` both being populated on the blob.
    """
    return f"gs://{blob.bucket.name}/{blob.name}"


client = storage.Client()
BUCKET_NAME = 'cloud-samples-data'
bucket = client.get_bucket(BUCKET_NAME)

# Limiting to only 50 images for sample.
# The result is materialised as a list: a lazy `filter` object can be consumed
# only once, which would leave later cells with nothing to read if they are
# re-run. `get_blob_uri` is defined above so this cell no longer relies on
# lazy evaluation to find it.
blobs = [
    blob
    for blob in bucket.list_blobs(prefix="generative-ai/image", max_results=50)
    if get_blob_uri(blob).endswith("jpg")
]

### Set Spark configurations, create Dataframe and broadcast model

In [None]:
import pandas as pd
import torch
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType, StringType
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API

spark = SparkSession.builder.appName("classficationDemo").getOrCreate()
sc = spark.sparkContext

# Set to True for GPU enabled serverless sessions/dataproc clusters
cuda = False

# Enable Arrow support.
# `spark.sql.execution.arrow.enabled` is deprecated since Spark 3.0; the
# PySpark-specific key below is its replacement.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")

# Only use the GPU when it was requested *and* is actually available.
use_cuda = cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# One row per image URI, spread over 10 partitions so inference runs in parallel.
image_uris = [get_blob_uri(blob) for blob in blobs]
files_df = spark.createDataFrame(image_uris, StringType()).repartition(10)

# Downloads the pre-trained weights once on the driver and broadcasts them to
# all the workers.
model_state = models.resnet50(pretrained=True).state_dict()
bc_model_state = sc.broadcast(model_state)

### Wrapper class and define Pandas UDF

In [None]:
class ImageDataset(Dataset):
    """Torch dataset that downloads images from Cloud Storage on demand.

    Args:
        paths: Sequence (e.g. list or pandas Series) of ``gs://bucket/object``
            URIs.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, paths, transform=None):
        # Copy into a list so __getitem__ indexes by position even when
        # `paths` is a pandas Series whose index labels are not 0..n-1.
        self.paths = list(paths)
        self.transform = transform
        # Storage client, created lazily on first access and then reused
        # instead of being rebuilt for every image.
        self._client = None

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Download the image at position ``index`` and return it (transformed if requested)."""
        import os
        import tempfile

        if self._client is None:
            self._client = storage.Client()
        path = self.paths[index]
        blob = Blob.from_string(path, client=self._client)
        # Download into a private temporary directory: concurrent loaders (or
        # two objects sharing a base name) cannot overwrite each other's file,
        # the file is closed before it is read back, and it is removed
        # afterwards instead of accumulating under /tmp.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = os.path.join(tmp_dir, path.split("/")[-1])
            blob.download_to_filename(local_file)
            image = default_loader(local_file)
        if self.transform is not None:
            image = self.transform(image)
        return image

def get_model_for_eval():
    """Build a ResNet-50 in eval mode from the broadcast weights.

    The network is created without pre-trained weights because every
    parameter is immediately overwritten by the broadcast state dict.
    Requesting the pre-trained weights here would make each Python worker
    download them again, only to discard them.

    Returns:
        The ResNet-50 model with the broadcast weights loaded, in eval mode.
    """
    # Retained from the original so any torch.hub cache use stays under /tmp.
    torch.hub.set_dir("/tmp/models")
    model = models.resnet50()
    model.load_state_dict(bc_model_state.value)
    model.eval()
    return model

from typing import Iterator


# Using Pandas UDF for parallel run on each partition
@pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(path_batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Classify images, yielding one Series of model outputs per input batch.

    Written as an iterator-of-Series pandas UDF so that the model is built and
    moved to the device once per invocation rather than once per Arrow batch.

    Args:
        path_batches: Iterator of Series, each holding ``gs://`` image URIs.

    Yields:
        For every input Series, a Series of the same length whose elements are
        the model's output vector for the corresponding image.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # One-time setup shared by all batches handled by this invocation.
    model = get_model_for_eval()
    model.to(device)

    for paths in path_batches:
        images = ImageDataset(paths, transform=transform)
        loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)
        all_predictions = []
        # no_grad wraps only the forward passes so the disabled-gradient state
        # is never held across a `yield`.
        with torch.no_grad():
            for batch in loader:
                # Iterating the 2-D output array adds one row per image.
                all_predictions.extend(model(batch.to(device)).cpu().numpy())
        yield pd.Series(all_predictions)

### Call UDF on the DataFrame and write output to file, then read

In [None]:
output_file_path = "/tmp/results"

# The alias goes on the UDF column. Previously it was applied to the DataFrame
# returned by select(), which left the prediction column with its
# auto-generated name instead of "predictions".
predictions_df = files_df.select(
    col("value"),
    predict_batch_udf(col("value")).alias("predictions"),
)
predictions_df.write.mode("overwrite").parquet(output_file_path)

# Read the results back and show a small sample as a sanity check.
spark.read.parquet(output_file_path).limit(5).show()