In [17]:
from google.cloud import storage
client = storage.Client()
BUCKET_NAME = 'cloud-samples-data'
bucket = client.get_bucket(BUCKET_NAME)

# Limiting to only 50 images for sample
blobs = bucket.list_blobs(prefix="generative-ai/image", max_results = 50)
blobs = filter(lambda blob: get_blob_uri(blob).endswith("jpg"), list(blobs))

In [18]:
from google.cloud.storage.blob import Blob
def get_blob_uri(blob):
    return 'gs://' + blob.id[:-(len(str(blob.generation)) + 1)]


In [19]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("classficationDemo").getOrCreate()
sc = spark.sparkContext

# Set to True for GPU enabled serverless sessions/dataproc clusters
cuda = False

# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")

import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType, StringType


use_cuda = cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

files_df = spark.createDataFrame(map(lambda file : get_blob_uri(file), blobs), StringType()).repartition(10)

# Downloads and broadcasts the model weights to all the workers
model_state = models.resnet50(pretrained=True).state_dict()
bc_model_state = sc.broadcast(model_state)

def get_model_for_eval():
  """Gets the broadcasted model to each python worker"""
  torch.hub.set_dir("/tmp/models")
  model = models.resnet50(pretrained=True)
  model.load_state_dict(bc_model_state.value)
  model.eval()
  return model

class ImageDataset(Dataset):
  def __init__(self, paths, transform=None):
    self.paths = paths
    self.transform = transform
  def __len__(self):
    return len(self.paths)
  def __getitem__(self, index):
    client = storage.Client()
    path = self.paths[index]
    blob = Blob.from_string(path, client=client)
    local_file = "/tmp/" + path.split("/")[-1]
    blob.download_to_file(open(local_file, "wb"))
    image = default_loader(local_file)
    if self.transform is not None:
      image = self.transform(image)
    return image

# Using Pandas UDF for parallel run on each partition
@pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(paths: pd.Series) -> pd.Series:
  transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
  ])
  images = ImageDataset(paths, transform=transform)
  loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)
  model = get_model_for_eval()
  model.to(device)
  all_predictions = []
  with torch.no_grad():
    for batch in loader:
      predictions = list(model(batch.to(device)).cpu().numpy())
      for prediction in predictions:
        all_predictions.append(prediction)
  return pd.Series(all_predictions)

output_file_path = "/tmp/results"
predictions_df = files_df.select(col("value"), predict_batch_udf(col("value"))).alias("predictions")
predictions_df.write.mode("overwrite").parquet(output_file_path)



24/11/11 15:56:37 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
                                                                                

In [20]:
spark.read.parquet(output_file_path).limit(5).show()

+--------------------+------------------------+
|               value|predict_batch_udf(value)|
+--------------------+------------------------+
|gs://cloud-sample...|    [-1.4993658, -1.5...|
|gs://cloud-sample...|    [-3.9951859, -0.9...|
|gs://cloud-sample...|    [-2.278611, -1.12...|
|gs://cloud-sample...|    [-2.4935248, -0.2...|
|gs://cloud-sample...|    [-2.3286686, 0.73...|
+--------------------+------------------------+

