# Leveraging Spark to distribute computer vision

In computer vision, a common task is to bulk-process videos and run detection algorithms on them.

In [None]:
%pip install -U timm transformers ffmpeg torchcodec
%restart_python

## Setup

First, let's configure the location of the video files and the UC catalog / schema.

In [None]:
import os
from transformers import DetrFeatureExtractor, DetrForObjectDetection, DetrImageProcessor
from PIL import Image
import torch
import numpy as np
from itertools import chain

# UC catalog / schema names and the volume folder used throughout the notebook
db_catalog = 'brian_ml_dev'
db_schema = 'image_processing'
processed_videos = 'processed_video'
data_table = 'silver_detr_results'

# Directory that contains the videos to process
video_path = f'/Volumes/{db_catalog}/{db_schema}/{processed_videos}'
print(video_path)

# quick review videos
# os.listdir returns entries in arbitrary order; sorting keeps the file list
# (and everything derived from it) stable from one run to the next.
video_files = sorted(os.listdir(video_path))
full_path = [os.path.join(video_path, name) for name in video_files]
# video_files

To distribute the processing of the video files, we need to create a Spark DataFrame that holds every file path, so that each worker can retrieve its share of the videos.

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# A single nullable string column holding the path of each video
src_field = StructField("src", StringType(), True)
schema = StructType([src_field])

# One single-element tuple per file, i.e. one DataFrame row per video
source_rows = [(path,) for path in full_path]
sourcing_df = spark.createDataFrame(source_rows, schema=schema)
display(sourcing_df)

In [None]:
# Instantiate Models for testing functions
# The pre/post-processing helpers and the detector share one pretrained checkpoint.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")

# eval() returns the module itself, so loading and switching to inference mode
# can be chained in a single statement.
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval()

# Use GPU if available
model.to("cuda" if torch.cuda.is_available() else "cpu")

### Video Processing - Torchcodec

PyTorch is currently the standard for deep learning. Torch has its own TorchCodec library, which can be more efficient than OpenCV if a torch model is being used for inference.

In [None]:
import torch
import torchvision
from torchcodec.decoders import VideoDecoder
from torchvision.transforms.functional import convert_image_dtype
from torchvision.transforms import Resize

# With this code, we do the preprocessing via torch modules so we don't need the detr feature extractor

def _detect_batch(video_path, batch_frames, batch_indices, model, processor,
                  image_size, device, threshold):
    """Run the detector on one batch of frames and build one result dict per frame."""
    batch_tensor = torch.stack(batch_frames).to(device)

    with torch.no_grad():
        outputs = model(batch_tensor)
        detections = processor.post_process_object_detection(
            outputs,
            threshold=threshold,  # Score threshold
            target_sizes=torch.tensor([image_size] * len(batch_frames)),
        )

    batch_results = []
    # Zipping keeps every detection paired with the frame and the video-level
    # index it was computed from, regardless of which batch it belongs to.
    for frame_tensor, frame_index, detection in zip(batch_frames, batch_indices, detections):
        annotations = [
            {
                'frame_index': frame_index,
                'score': float(score),
                'label': int(label),
                'box': [float(v) for v in box],
            }
            # tolist() yields plain Python numbers for every batch
            for score, label, box in zip(detection['scores'].tolist(),
                                         detection['labels'].tolist(),
                                         detection['boxes'].tolist())
        ]

        batch_results.append({
            'video_path': video_path,
            'frame': frame_tensor.cpu().numpy(),
            'frame_index': frame_index,
            'annotations': annotations,
        })

    return batch_results


def process_file_w_torchcodec(video_path: str, model, processor, batch_size=8, device='cuda',
                              max_frames=30, threshold=0.5, resize_to=(480, 854)):
    """Decode a video with torchcodec and run object detection on its frames.

    Frames are converted to float32, resized to ``resize_to`` and sent through
    ``model`` in batches of ``batch_size``.

    Args:
        video_path: Path of the video file to decode.
        model: Detection model; called with the stacked batch tensor.
        processor: Object exposing ``post_process_object_detection``.
        batch_size: Number of frames per forward pass.
        device: Device the batches are moved to. A CUDA device silently falls
            back to CPU when CUDA is not available.
        max_frames: Stop after this many frames (the default keeps the
            original 30-frame debug cap); pass ``None`` to process the whole video.
        threshold: Score threshold forwarded to the post-processing step.
        resize_to: ``(height, width)`` every frame is resized to.

    Returns:
        A list with one dict per processed frame, in frame order, with keys
        ``video_path``, ``frame`` (the resized float32 frame as a numpy array),
        ``frame_index`` (position of the frame in the video) and ``annotations``
        (list of dicts with ``frame_index``, ``score``, ``label`` and ``box``).
    """
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    reader = VideoDecoder(video_path)

    # Size of the decoded frames, read once from the first frame.
    # NOTE(review): boxes are rescaled to this original size while the returned
    # `frame` is the resized one - confirm which coordinate system downstream
    # consumers expect.
    first_frame = reader[0]
    image_size = (first_frame.shape[1], first_frame.shape[2])

    # Resize if needed (e.g., model requires specific input size)
    resize = Resize(resize_to)

    results = []
    batch_frames = []
    batch_indices = []

    for idx, frame in enumerate(reader):
        if max_frames is not None and idx >= max_frames:
            break

        # NOTE(review): frames are only scaled to [0, 1]; the mean/std
        # normalization an image processor would normally apply is not
        # performed here - confirm this is intended.
        frame_tensor = convert_image_dtype(frame, dtype=torch.float32)  # Normalize [0,1]
        frame_tensor = resize(frame_tensor)

        batch_frames.append(frame_tensor)
        batch_indices.append(idx)

        if len(batch_frames) == batch_size:
            results.extend(_detect_batch(video_path, batch_frames, batch_indices, model,
                                         processor, image_size, device, threshold))
            # Start fresh lists so only one batch of frames is held at a time
            batch_frames = []
            batch_indices = []

    # Process remaining frames
    if batch_frames:
        results.extend(_detect_batch(video_path, batch_frames, batch_indices, model,
                                     processor, image_size, device, threshold))

    return results



In [None]:
# Try the single-file pipeline on the first entry found in the video directory
first_video_name = os.listdir(video_path)[0]
file_to_check = os.path.join(video_path, first_video_name)
print(file_to_check)

results = process_file_w_torchcodec(video_path=file_to_check, model=model, processor=processor)


In [None]:
# Show a compact summary rather than echoing `results` itself: every entry
# carries a full frame array, which would flood the cell output.
[
    {
        'frame_index': r['frame_index'],
        'frame_shape': r['frame'].shape,
        'annotations': r['annotations'],
    }
    for r in results[:5]
]

# Distributing on Spark Cluster

Now that we have tested our functions, we can distribute them across a full Spark cluster.

In [None]:
# Load Model Function - to start the process

# Module-level slots filled lazily by load_model(). All three are reset here so
# the `is None` checks below never depend on a name left over from another cell.
# Note that this also clears the driver-side `model`, so the single-file test
# above needs the models re-instantiated before it can be re-run.
model = None
feature_extractor = None
processor = None

def load_model():
    """Load model per worker process (lazy initialization).

    Populates the module-level ``model``, ``feature_extractor`` and
    ``processor`` from the ``facebook/detr-resnet-50`` checkpoint the first
    time it is called; later calls are no-ops once all three are set.
    The model is put in eval mode and moved to CUDA when available.
    """
    global model, feature_extractor, processor
    if model is None or feature_extractor is None or processor is None:
        # NOTE(review): DetrFeatureExtractor appears to be a deprecated alias of
        # DetrImageProcessor in recent transformers releases - confirm it is still needed.
        feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
        model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
        processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

        model.eval()
        model.to("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

To run a Python function across a Spark cluster, we wrap it in a function that consumes and produces pandas DataFrames and apply it with `mapInPandas`, which gives good performance and distribution.

In [None]:
# udf transformers
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, StructType, StructField, FloatType, IntegerType, StringType, BinaryType
import io

def _frame_to_png_bytes(frame):
    """Encode one (channels, height, width) frame array as PNG bytes."""
    # Step 1: Transpose dimensions to (height, width, channels)
    frame_hwc = np.transpose(frame, (1, 2, 0))

    # Step 2: Convert to uint8 (float frames are taken to be in 0-1 and are
    # scaled to 0-255; clipping guards against wrap-around on out-of-range values)
    if np.issubdtype(frame_hwc.dtype, np.floating):
        frame_hwc = np.clip(frame_hwc * 255, 0, 255).astype(np.uint8)
    else:
        frame_hwc = frame_hwc.astype(np.uint8)

    byte_stream = io.BytesIO()
    Image.fromarray(frame_hwc).save(byte_stream, format="PNG")
    return byte_stream.getvalue()


def python_batch_process(file_iter: "Iterator[pd.DataFrame]") -> "Iterator[pd.DataFrame]":
    """Run detection on every video listed in the incoming pandas batches.

    Intended for ``DataFrame.mapInPandas``: ``file_iter`` yields pandas
    DataFrames with a ``src`` column of video paths, and one output DataFrame
    is yielded per video with the columns ``video_path``, ``frame`` (PNG
    bytes), ``frame_index`` (int) and ``annotations``.

    Frames that cannot be encoded are reported and skipped; the remaining
    frames of the video are still emitted.
    """
    columns = ["video_path", "frame", "frame_index", "annotations"]

    # run the model load
    load_model()

    for file_batch in file_iter:
        for video_path in file_batch["src"]:
            video_results = process_file_w_torchcodec(video_path, model, processor)

            rows = []
            for frame_result in video_results:
                try:
                    frame_bytes = _frame_to_png_bytes(frame_result["frame"])
                except (TypeError, ValueError) as exc:
                    print(f"failed to encode frame {frame_result.get('frame_index')} "
                          f"of {video_path}: {exc}")
                    continue

                rows.append({
                    **frame_result,
                    "frame": frame_bytes,
                    # The output schema declares frame_index as an integer column
                    "frame_index": int(frame_result["frame_index"]),
                })

            # Yield inside the per-video loop so every video is emitted, not
            # only the last one of each incoming batch.
            if rows:
                yield pd.DataFrame(rows, columns=columns)

In [None]:
# Layout of a single detection inside the `annotations` array
annotation_struct = StructType([
    StructField("frame_index", IntegerType(), True),
    StructField("score", FloatType(), True),
    StructField("label", IntegerType(), True),
    StructField("box", ArrayType(FloatType()), True),
])

# One output row per processed frame
result_schema = StructType([
    StructField("video_path", StringType(), True),
    StructField("frame", BinaryType(), True),  # Encoded images
    StructField("frame_index", IntegerType(), True),
    StructField("annotations", ArrayType(annotation_struct), True),
])

result = sourcing_df.mapInPandas(python_batch_process, schema=result_schema)

In [None]:
# Preview the per-frame detections.
# NOTE(review): `result` comes from mapInPandas and is presumably evaluated
# lazily, so this preview and the later table write may each run the video
# processing - confirm the cost before using it on a full dataset.
display(result)

TODO: Issue with the frame indexing

In [None]:
# Persist the per-frame detections as a table, replacing the output of any previous run
results_table_name = f"`{db_catalog}`.`{db_schema}`.silver_detr_results_w_frame"
result.write.mode('overwrite').saveAsTable(results_table_name)

# Convert to COCO Output for Finetuning

We can also convert the output to COCO format so that it can be used for fine-tuning.

In [None]:
# Read the persisted detections back. The catalog and schema are backtick-quoted
# exactly as in the write, so names that need quoting resolve the same way.
result = spark.sql(f"SELECT * FROM `{db_catalog}`.`{db_schema}`.silver_detr_results_w_frame")

In [None]:
from transformers import AutoConfig

# Load the config of the pretrained DETR checkpoint used for inference
config = AutoConfig.from_pretrained("facebook/detr-resnet-50")

# Get category mappings
id2label = config.id2label
label2id = config.label2id

# Convert to COCO categories format: one entry per class id known to the model
base_coco_categories = []
for class_id, class_name in id2label.items():
    base_coco_categories.append(
        {"id": int(class_id), "name": class_name, "supercategory": "none"}
    )

In [None]:
import os
import io
import json
import uuid
from PIL import Image

storage_location = 'coco_dataset'

# Target layout: <output_root>/images/*.jpg plus <output_root>/annotations.json
output_root = f'/Volumes/{db_catalog}/{db_schema}/{storage_location}'
images_dir = os.path.join(output_root, "images")
os.makedirs(images_dir, exist_ok=True)

coco = {
    "images": [],
    "annotations": [],
    "categories": []
}

used_category_ids = set()
annotation_id = 1  # annotation ids are 1-based and unique across all images

# Stream rows to the driver instead of collect()-ing every encoded frame at
# once, which would need the whole image set to fit in driver memory.
frame_rows = result.select("video_path", "frame", "annotations").toLocalIterator()

for idx, row in enumerate(frame_rows):
    image_id = idx
    filename = f"image_{idx}.jpg"
    file_path = os.path.join(images_dir, filename)

    # Save image (PIL picks the output format from the .jpg extension)
    with Image.open(io.BytesIO(row["frame"])) as img:
        img.save(file_path)
        width, height = img.size

    # Image entry
    coco["images"].append({
        "id": image_id,
        "file_name": f"images/{filename}",
        "width": width,
        "height": height
    })

    # Annotations (a frame without detections may carry an empty list or None)
    for ann in row["annotations"] or []:
        category_id = int(ann["label"])
        used_category_ids.add(category_id)

        # Stored boxes are corner-style (x1, y1, x2, y2); COCO wants (x, y, w, h).
        # NOTE(review): process_file_w_torchcodec scales boxes to the original
        # video size while the stored frame is the resized one - confirm the
        # boxes line up with the saved images.
        x, y, x2, y2 = ann["box"]
        w, h = x2 - x, y2 - y

        coco["annotations"].append({
            "id": annotation_id,
            "image_id": image_id,
            "category_id": category_id,
            "bbox": [x, y, w, h],
            "area": w * h,
            "iscrowd": 0
        })
        annotation_id += 1

# Add categories: only the classes that actually occur in the annotations
coco["categories"] = [
    cat for cat in base_coco_categories if cat["id"] in used_category_ids
]

# Save to JSON
with open(os.path.join(output_root, "annotations.json"), "w") as f:
    json.dump(coco, f)