In [None]:
from __future__ import annotations

import io
from typing import Any, Dict, List

import av
import torch
import torchvision
from PIL import Image
from pyspark.sql.functions import col, explode, pandas_udf
from pyspark.sql.types import (
    ArrayType,
    BinaryType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)
from ultralytics import YOLO

In [None]:
%%configure -f
{
  "executorCores": 1,
  "conf": {
    "spark.sql.execution.arrow.maxRecordsPerBatch": "64"
  }
}

In [None]:
# Module-level holder so the detector is constructed once and then reused.
_model_cache = {"model": None}


def get_model():
    """Return the shared YOLO detector, building it on first use.

    The first call loads the ``yolo11n.pt`` weights, moves the model to CUDA
    when ``torch.cuda.is_available()`` reports a GPU, and stores it in
    ``_model_cache``. Later calls return that stored instance.
    """
    cached = _model_cache["model"]
    if cached is not None:
        return cached

    detector = YOLO("yolo11n.pt")
    if torch.cuda.is_available():
        detector.to("cuda")
    _model_cache["model"] = detector
    return detector


@pandas_udf(ArrayType(BinaryType()))
def decode_video_udf(video_bytes_iter):
    """Decode each video payload into a list of PNG-encoded 640x640 frames.

    Args:
        video_bytes_iter: Iterable of encoded video payloads; each one is
            wrapped in ``io.BytesIO`` and opened with ``av.open``.

    Returns:
        A pandas Series with one entry per input video, holding the PNG
        bytes of every frame decoded from video stream 0.
    """
    import pandas as pd

    def _frame_to_png(frame):
        # RGB pixel array -> PIL image resized to 640x640 -> PNG bytes.
        pixels = frame.to_ndarray(format="rgb24")
        encoded = io.BytesIO()
        Image.fromarray(pixels).resize((640, 640)).save(encoded, format="PNG")
        return encoded.getvalue()

    per_video_frames = []
    for payload in video_bytes_iter:
        with av.open(io.BytesIO(payload)) as container:
            pngs = [_frame_to_png(frame) for frame in container.decode(video=0)]
        per_video_frames.append(pngs)
    return pd.Series(per_video_frames)


feature_schema = ArrayType(
    StructType(
        [
            StructField("label", StringType(), False),
            StructField("confidence", FloatType(), False),
            StructField("bbox", ArrayType(IntegerType()), False),
        ]
    )
)


def to_features(res: Any) -> List[Dict[str, Any]]:
    """Flatten one detection result into a list of plain-Python records.

    Args:
        res: Detection result exposing ``names`` (class id -> label) and
            ``boxes`` with parallel ``cls``, ``conf`` and ``xyxy`` sequences.

    Returns:
        One dict per detection with keys ``label``, ``confidence`` (float)
        and ``bbox`` (coordinates truncated to ints via ``int()``).
    """
    boxes = res.boxes
    detections: List[Dict[str, Any]] = []
    for class_id, score, corners in zip(boxes.cls, boxes.conf, boxes.xyxy):
        record = {
            "label": res.names[int(class_id)],
            "confidence": float(score),
            "bbox": [int(coord) for coord in corners.tolist()],
        }
        detections.append(record)
    return detections


@pandas_udf(feature_schema)
def extract_image_features_udf(images):
    """Run the YOLO detector on a batch of encoded frames.

    Args:
        images: Batch of encoded image bytes (presumably a pandas Series,
            as passed to a scalar pandas UDF). All images in a batch must
            decode to the same size because they are stacked into a single
            tensor before inference.

    Returns:
        A pandas Series with one entry per input image, each a list of
        ``{"label", "confidence", "bbox"}`` dicts built by ``to_features``.
        An empty batch yields an empty Series.
    """
    import pandas as pd

    if len(images) == 0:
        # A pandas UDF must return a Series (not a bare list), even when the
        # batch is empty; torch.stack would also reject an empty list.
        return pd.Series([], dtype=object)

    model = get_model()
    to_tensor = torchvision.transforms.functional.to_tensor
    # Force 3-channel RGB so grayscale/paletted/RGBA inputs still stack into
    # one (N, 3, H, W) batch instead of failing on a channel mismatch.
    tensors = [to_tensor(Image.open(io.BytesIO(img)).convert("RGB")) for img in images]
    stack = torch.stack(tensors, dim=0)
    results = model(stack)
    return pd.Series([to_features(r) for r in results])


@pandas_udf(BinaryType())
def crop_udf(frame_bytes_iter, bbox_iter):
    """Cut each bounding box out of its frame and return the crop as PNG bytes.

    Args:
        frame_bytes_iter: Iterable of encoded frame images.
        bbox_iter: Iterable of 4-element boxes ``(x1, y1, x2, y2)``, paired
            positionally with ``frame_bytes_iter``.

    Returns:
        A pandas Series of PNG bytes, one per input pair. Any pair whose
        processing raises is recorded as ``None`` instead of failing the batch.
    """
    import pandas as pd

    def _crop_to_png(frame_bytes, bbox):
        # Decode as RGB, crop to the box, and re-encode as PNG.
        frame = Image.open(io.BytesIO(frame_bytes)).convert("RGB")
        left, top, right, bottom = bbox
        encoded = io.BytesIO()
        frame.crop((left, top, right, bottom)).save(encoded, format="PNG")
        return encoded.getvalue()

    crops = []
    for frame_bytes, bbox in zip(frame_bytes_iter, bbox_iter):
        try:
            png = _crop_to_png(frame_bytes, bbox)
        except Exception:
            # Best-effort: a bad frame or box produces a null crop.
            png = None
        crops.append(png)
    return pd.Series(crops)

In [None]:
# End-to-end pipeline: read clips as raw bytes, decode them to frames, detect
# objects per frame, crop every detection, and append the results to Parquet.
# NOTE(review): DataFrame.checkpoint() needs a checkpoint directory configured
# on the SparkContext; none is set in the visible cells -- confirm it is
# provided by the cluster/session configuration.
df = (
    spark.read.format("binaryFile")
    .load("s3://daft-public-data/videos/Hollywood2-actions-videos/Hollywood2/AVIClips/")
    # One array of PNG frames per video, then one row per frame.
    .withColumn("frame", decode_video_udf(col("content")))
    .withColumn("frame", explode(col("frame")))
    .checkpoint()
    # One array of detections per frame, then one row per detection.
    .withColumn("features", extract_image_features_udf(col("frame")))
    .withColumn("feature", explode(col("features")))
    .checkpoint()
    # Crop each detection out of its frame, then drop the heavy binary inputs.
    .withColumn("object", crop_udf(col("frame"), col("feature.bbox")))
    .drop("content", "frame")
)
df.write.mode("append").parquet("s3://eventual-dev-benchmarking-results/ai-benchmark-results/video-object-detection-result")