In [0]:
import time
import torch
torch.__version__
torch.cuda.is_available()

Out[1]: True

In [0]:
#print("Profiling enabled: ", spark.conf.get("spark.python.profile"))
print("Executor memory: ", spark.conf.get("spark.executor.memory"))

Executor memory:  148728m


In [0]:
import pandas as pd
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights



In [0]:
# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

In [0]:
# Create and broadcast model state. Equivalent to AIR Checkpoint
model_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()
# sc is already initialized by Databricks. Broadcast the model state to all executors.
bc_model_state = sc.broadcast(model_state)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [0]:
dbutils.fs.ls("/preprocessed_data/")

Out[8]: [FileInfo(path='dbfs:/preprocessed_data/_SUCCESS', name='_SUCCESS', size=0, modificationTime=1674513911000),
 FileInfo(path='dbfs:/preprocessed_data/_committed_7519231348214228538', name='_committed_7519231348214228538', size=1599, modificationTime=1674513911000),
 FileInfo(path='dbfs:/preprocessed_data/_started_7519231348214228538', name='_started_7519231348214228538', size=0, modificationTime=1674513904000),
 FileInfo(path='dbfs:/preprocessed_data/part-00000-tid-7519231348214228538-7e2ad37e-a18d-40be-b39c-d6894b3d3059-1-1.c000.snappy.parquet', name='part-00000-tid-7519231348214228538-7e2ad37e-a18d-40be-b39c-d6894b3d3059-1-1.c000.snappy.parquet', size=191342118, modificationTime=1674513908000),
 FileInfo(path='dbfs:/preprocessed_data/part-00001-tid-7519231348214228538-7e2ad37e-a18d-40be-b39c-d6894b3d3059-2-1.c000.snappy.parquet', name='part-00001-tid-7519231348214228538-7e2ad37e-a18d-40be-b39c-d6894b3d3059-2-1.c000.snappy.parquet', size=191342118, modificationTime=167451390700

In [0]:
input_data = spark.read.format("parquet").load("/preprocessed_data/")

In [0]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

In [0]:
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType

import numpy as np

import torch

@pandas_udf(ArrayType(FloatType()))
def predict(preprocessed_images: pd.Series) -> pd.Series:
    with torch.inference_mode():
        model = resnet50()
        model.load_state_dict(bc_model_state.value)
        model = model.to(torch.device("cuda")) # Move model to GPU
        model.eval()
        
        batch = preprocessed_images
        batch_dim = len(batch)
        numpy_batch = np.stack(batch.values)
        # Spark has no tensor support, so it flattens the image tensor to a single array during read.
        # Each image is represented as a flattened numpy array.
        # We have to reshape back to the original number of dimensions.
        reshaped_images = numpy_batch.reshape(batch_dim, 3, 224, 224)
        gpu_batch = torch.Tensor(reshaped_images).to(torch.device("cuda"))
        predictions = list(model(gpu_batch).cpu().numpy())
        assert len(predictions) == batch_dim
        
        return pd.Series(predictions)

In [0]:
predictions = input_data.select(predict(col("preprocess(image)")))

In [0]:
start_time = time.time()
predictions.write.mode("overwrite").format("noop").save()
end_time = time.time()
print(f"Prediction took: {end_time-start_time} seconds")

assert predictions.count() == 16232

Prediction took: 99.5347626209259 seconds
