In [None]:
import io

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import StringType, ArrayType, FloatType

In [None]:
%%configure -f
{
  "conf": {
    "spark.sql.execution.arrow.maxRecordsPerBatch": "100"
  }
}

In [None]:
# Module-level cache so the model is built at most once per Python process.
_model_cache = {"model": None, "weights": None, "device": None}


def get_model():
    """Return ``(model, weights, device)``, building and caching the model on first call.

    The ResNet18 model is created with ``ResNet18_Weights.DEFAULT``, put in eval
    mode and moved to ``device``. ``device`` is ``"cuda"`` when torch can see a
    GPU and ``"cpu"`` otherwise, so the UDFs still run on CPU-only workers
    instead of failing inside ``.to("cuda")``.

    Returns:
        tuple: ``(model, weights, device)`` where ``weights`` is
        ``ResNet18_Weights.DEFAULT`` (callers use ``weights.meta["categories"]``)
        and ``device`` is the device string the model was moved to.
    """
    import os

    if _model_cache["model"] is None:
        # Point the torch hub / XDG caches at /tmp before the weights are fetched.
        # NOTE(review): presumably because the default cache directory is not
        # writable on the Spark workers — confirm.
        os.environ["TORCH_HOME"] = "/tmp/torch"
        os.environ["XDG_CACHE_HOME"] = "/tmp"

        device = "cuda" if torch.cuda.is_available() else "cpu"
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights).eval().to(device)
        _model_cache.update(model=model, weights=weights, device=device)
    return _model_cache["model"], _model_cache["weights"], _model_cache["device"]

In [None]:
# Per-image preprocessing: ToTensor turns the decoded HWC uint8 RGB array into a
# float tensor, then the preset bundled with the ResNet18 DEFAULT weights
# prepares it for the model (the prediction UDF expects a 3x224x224 result).
transform = transforms.Compose(
    [transforms.ToTensor(), ResNet18_Weights.DEFAULT.transforms()]
)

@pandas_udf(ArrayType(FloatType()))
def decode_and_preprocess_image_udf(image_data_series: pd.Series) -> pd.Series:
    """Decode raw image bytes and turn each image into a flat float vector.

    Every non-null entry is opened with PIL, converted to RGB, run through the
    module-level ``transform`` and flattened into a Python list of floats.
    Null entries, and entries that fail anywhere in that pipeline, become
    ``None`` so a downstream ``isNotNull`` filter can drop them. Failures are
    reported with ``print`` (which lands in the worker's stdout).

    Args:
        image_data_series: Encoded image bytes (the ``content`` column of a
            ``binaryFile`` read); may contain nulls.

    Returns:
        A Series with one entry per input row: a list of floats or ``None``.
    """

    def _decode_one(raw):
        # Propagate missing input unchanged.
        if raw is None:
            return None
        try:
            rgb = np.array(Image.open(io.BytesIO(raw)).convert("RGB"))
            if len(rgb.shape) != 3:
                raise ValueError(f"Invalid image shape: {rgb.shape}")
            return transform(rgb).flatten().tolist()
        except Exception as e:
            # Best-effort: one bad image must not fail the whole batch.
            print(f"Error decoding image: {e}")
            return None

    return pd.Series([_decode_one(raw) for raw in image_data_series])

@pandas_udf(StringType())
def predict_batch_udf(norm_images: pd.Series) -> pd.Series:
    """Predict a category label for every preprocessed image in the batch.

    The ``-> pd.Series`` return annotation is required. Spark infers the pandas
    UDF type from type hints, and it rejects a signature that annotates the
    argument but not the return value.

    Args:
        norm_images: One flattened float vector per row, as produced by
            ``decode_and_preprocess_image_udf``; each row must hold exactly
            3 * 224 * 224 values so it can be reshaped to (3, 224, 224).

    Returns:
        A Series of category names taken from ``weights.meta["categories"]``,
        with exactly one entry per input row. If the batch cannot be stacked or
        reshaped, every entry is ``None``.
    """
    # Nothing to classify: skip loading the model and the reshape error path.
    if len(norm_images) == 0:
        return pd.Series([], dtype=object)

    model, weights, device = get_model()
    categories = weights.meta["categories"]
    try:
        np_batch = np.vstack(norm_images.tolist())
        # Use the row count explicitly (not -1) so a row of the wrong length
        # raises here instead of producing a result whose length differs from
        # the input batch, which Spark would reject.
        np_batch_reshaped = np_batch.reshape(len(norm_images), 3, 224, 224).astype(np.float32)
    except ValueError as e:
        print(f"Error reshaping tensor: {e}")
        return pd.Series([None] * len(norm_images))

    torch_batch = torch.from_numpy(np_batch_reshaped).to(device)
    with torch.inference_mode():
        logits = model(torch_batch)
    # Plain ints for list indexing; no .detach() needed under inference_mode.
    predicted_classes = logits.argmax(dim=1).cpu().tolist()
    return pd.Series([categories[i] for i in predicted_classes])

In [None]:
MANIFEST_PATH = "s3://daft-public-datasets/imagenet/benchmark"
OUTPUT_PATH = (
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/image-classification-results"
)

# Pull only the image_url column to the driver; collecting whole rows would
# ship every manifest column just to discard it.
paths = [
    row.image_url
    for row in spark.read.parquet(MANIFEST_PATH).select("image_url").collect()
]

images_raw = spark.read.format("binaryFile").load(paths)
df = (
    images_raw.withColumn("processed_image", decode_and_preprocess_image_udf(col("content")))
    # The decoder UDF returns null for rows it could not decode; drop them.
    .filter(col("processed_image").isNotNull())
    .withColumn("label", predict_batch_udf(col("processed_image")))
    .select("path", "label")
)

# NOTE: append mode — re-running this cell writes another copy of the results.
df.write.mode("append").parquet(OUTPUT_PATH)