# Standard Object Detection

In [None]:
import requests
from PIL import Image

# Sample image from the AnimalDetection example dataset.
url = "https://s3-us-west-2.amazonaws.com/air-example-data/AnimalDetection/JPEGImages/2007_000063.jpg"

# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors here instead of as a confusing
# "cannot identify image file" error from PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Make urllib3 undo any Content-Encoding (gzip/deflate) before PIL reads the bytes.
response.raw.decode_content = True
img = Image.open(response.raw)
display(img)

In [None]:
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

# Pretrained weights; `weights` is reused below for its preprocessing
# transforms and its category-name metadata.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
# eval() returns the module itself, so chaining it into the assignment puts the
# model in inference mode without making the (very long) module repr the
# cell's output.
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9).eval()

In [3]:
# Convert the PIL image to a uint8 tensor. The guard makes this cell safe to
# re-run: once `img` has already been converted, PILToTensor would raise on it.
if isinstance(img, Image.Image):
    img = transforms.PILToTensor()(img)
# Preprocessing pipeline that ships with the pretrained weights.
preprocess = weights.transforms()
# The detection model takes a list of image tensors.
batch = [preprocess(img)]

In [4]:
import torch

# Inference only: disabling autograd avoids building a graph for the forward
# pass, which saves memory and time. The predictions are unchanged.
with torch.no_grad():
    prediction = model(batch)[0]

In [None]:
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

# Translate predicted class indices into human-readable category names.
categories = weights.meta["categories"]
labels = [categories[i] for i in prediction["labels"]]

# Overlay the predicted boxes and their labels on the image tensor.
box = draw_bounding_boxes(
    img,
    boxes=prediction["boxes"],
    labels=labels,
    colors="red",
    width=4,
)

# Convert back to a PIL image for notebook display.
im = to_pil_image(box.detach())
display(im)

# Scaling with Ray

## Preprocessing

In [None]:
import os

import ray

# Ray's temp/log directory. Override it by setting the RAY_TMPDIR environment
# variable before starting the kernel; the fallback is the original
# machine-specific location, so existing setups behave as before.
ray_temp_dir = os.environ.get("RAY_TMPDIR", "/taiga/mohanar2/raytmp")
ray.init(_temp_dir=ray_temp_dir, ignore_reinit_error=True)

# Read the images under ./AnimalDetection into a Ray Dataset and show its schema.
ds = ray.data.read_images("./AnimalDetection")
display(ds.schema())

In [14]:
# TODO: Use stateful transforms to load model weights only once

import numpy as np
import torch
from torchvision import transforms
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)
from typing import Dict


def preprocess_image(data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Apply the pretrained weights' preprocessing to one dataset row.

    Args:
        data: A row with an ``"image"`` entry holding the image array
            (presumably HWC uint8 as produced by ``ray.data.read_images`` --
            TODO confirm).

    Returns:
        A dict with the untouched ``"image"`` and a ``"transformed"`` entry
        holding the preprocessed image as a NumPy array.
    """
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    preprocessor = transforms.Compose(
        [transforms.ToTensor(), weights.transforms()]
    )
    transformed = preprocessor(data["image"])
    return {
        "image": data["image"],
        # The transform pipeline yields a torch.Tensor; convert it explicitly so
        # the value matches the declared return type and what
        # ObjectDetectionModel.__call__ passes to torch.from_numpy.
        "transformed": transformed.numpy(),
    }

In [15]:
# Run preprocess_image over the dataset's rows.
ds = ds.map(fn=preprocess_image)

In [None]:
# Pull a small batch through the preprocessing step to inspect its output.
preview_rows = 3
single_batch = ds.take_batch(batch_size=preview_rows)
display(single_batch)

## Model Inference

In [17]:
class ObjectDetectionModel:
    """Callable wrapper around a pretrained Faster R-CNN for batch inference.

    The model is built once in ``__init__`` and reused for every batch passed
    to ``__call__``, which is the shape ``Dataset.map_batches`` expects for a
    callable class.
    """

    def __init__(self):
        """Load the pretrained model and place it on the GPU when one is available."""
        self.weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        # Resolve the target device once instead of re-checking CUDA per batch.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=self.weights,
            box_score_thresh=0.9,
        ).to(self.device)
        self.model.eval()

    def __call__(self, input_batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run detection on one batch.

        Args:
            input_batch: Batch with ``"image"`` (original images) and
                ``"transformed"`` (preprocessed images as NumPy arrays).

        Returns:
            A dict with the original ``"image"`` entry plus per-image
            ``"labels"`` and ``"boxes"`` as NumPy arrays.
        """
        batch = [
            torch.from_numpy(image).to(self.device)
            for image in input_batch["transformed"]
        ]
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            predictions = self.model(batch)
        return {
            "image": input_batch["image"],
            "labels": [pred["labels"].detach().cpu().numpy() for pred in predictions],
            "boxes": [pred["boxes"].detach().cpu().numpy() for pred in predictions],
        }

In [18]:
# Run detection over the dataset in batches using the ObjectDetectionModel class.
ds = ds.map_batches(
    ObjectDetectionModel,
    # Use the largest batch size that can fit in GPU memory.
    batch_size=4,
    # One model replica, i.e. one GPU in use. Change this number based on the
    # number of GPUs in your cluster.
    concurrency=1,
    # Reserve 1 GPU per model replica. Remove this if you are doing CPU inference.
    num_gpus=1,
)

In [None]:
from torchvision.transforms.functional import convert_image_dtype, to_tensor
# Look up the category names once, straight from the weights enum, so this cell
# does not depend on the `weights` variable from the single-image section.
categories = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT.meta["categories"]

count = 0  # stays 0 if the batch comes back empty
batch = ds.take_batch(batch_size=30)
for count, (image, label_ids, boxes) in enumerate(
    zip(batch["image"], batch["labels"], batch["boxes"]), start=1
):
    # draw_bounding_boxes is given a uint8 image tensor here.
    image_tensor = convert_image_dtype(to_tensor(image), torch.uint8)
    annotated = draw_bounding_boxes(
        image_tensor,
        torch.from_numpy(boxes),
        labels=[categories[i] for i in label_ids],
        colors="red",
        width=4,
    )
    # Displayed directly rather than assigned to `img`, so the tensor created
    # in the single-image section is not clobbered.
    display(to_pil_image(annotated))
# Number of images rendered.
print(count)