In [1]:
import pandas as pd

import numpy as np
import os
import requests
from PIL import Image

import torch
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import concatenate_datasets
from datasets import load_dataset
from datasets import Dataset as HF_Dataset
from torchvision.transforms import functional as F

from tqdm.notebook import tqdm
from ultralytics import YOLO
from utils import resize_preserve_aspect_ratio

In [2]:
def get_num_people(result, person_label="person"):
    """Count the detections in one YOLO result that are labelled as a person.

    Args:
        result: A single prediction result exposing ``names`` (a mapping from
            class id to class name) and ``boxes.cls`` (an iterable of class-id
            scalars that support ``.item()``).
        person_label: Class name that identifies a person in ``result.names``.

    Returns:
        int: Number of detections whose class name equals ``person_label``;
        ``0`` when the result is falsy or carries no boxes.
    """
    if not result:
        return 0

    boxes = result.boxes
    if boxes is None:
        return 0

    labels_dict = result.names
    # Class ids are stored as scalar tensors, so convert them to plain ints
    # before looking up the human-readable label.
    det_labels = (labels_dict[int(cls_id.item())] for cls_id in boxes.cls)

    # Only count detections that really are people; a prediction made without
    # a class filter would otherwise count every detected object.
    return sum(1 for label in det_labels if label == person_label)

In [3]:
def pad_to_fixed_size(img, size=(640, 640)):
    """Pad ``img`` on all four sides so that it sits centred in a ``size`` canvas.

    Args:
        img: Image exposing ``.size`` as ``(width, height)`` and accepted by
            ``torchvision.transforms.functional.pad``.
        size: Target ``(width, height)`` of the padded image.

    Returns:
        The padded image returned by ``F.pad``.
    """
    target_w, target_h = size[0], size[1]
    img_w, img_h = img.size

    # Slack per axis. NOTE(review): this is negative when the image is larger
    # than `size`; what F.pad does with negative padding should be confirmed.
    slack_w = target_w - img_w
    slack_h = target_h - img_h

    # Split the slack in two; any odd leftover pixel goes to the right/bottom.
    pad_left = slack_w // 2
    pad_top = slack_h // 2
    pad_right = slack_w - pad_left
    pad_bottom = slack_h - pad_top

    return F.pad(img, padding=(pad_left, pad_top, pad_right, pad_bottom))

In [4]:
class PinterestDataset(Dataset):
    """Torch dataset that serves images from the ``train`` split of a Hugging Face dataset.

    Each item is converted to RGB, resized with ``resize_preserve_aspect_ratio``,
    padded to a square ``imgsz`` x ``imgsz`` canvas and turned into a tensor
    with a leading batch dimension of 1.

    Args:
        dataset_id: Identifier passed to ``datasets.load_dataset``. Required.
        image_col: Name of the column holding the PIL image.
        image_id_col: Optional column holding an image identifier. When it is
            falsy, the row index is used as the identifier instead.
        imgsz: Side length, in pixels, of the square canvas images are resized
            and padded to.

    Raises:
        ValueError: If ``dataset_id`` is ``None``.
    """

    def __init__(self, dataset_id=None, image_col="image", image_id_col=None, imgsz=640):
        if dataset_id is None:
            raise ValueError("dataset_id must be provided")

        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading script locally; only use it with trusted datasets.
        self.ds = load_dataset(dataset_id, split="train", trust_remote_code=True)
        self.image_col = image_col
        self.image_id_col = image_id_col
        self.imgsz = imgsz
        self.half = False

        # Resize and pad share a single size so the two can never drift apart.
        pad_size = (imgsz, imgsz)
        self.transform = transforms.Compose(
            [
                transforms.Lambda(lambda img: pad_to_fixed_size(img, pad_size)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of rows in the underlying dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the id, the model-ready tensor and the original RGB image for row ``idx``."""
        item = self.ds[idx]

        image_pil = item[self.image_col].convert("RGB")

        # NOTE(review): presumably scales the image to fit within imgsz while
        # keeping its aspect ratio -- confirm against utils.
        image = resize_preserve_aspect_ratio(image_pil, self.imgsz)

        # Add a leading batch axis so samples can be joined with torch.cat.
        image = self.transform(image).unsqueeze(0)

        if self.image_id_col:
            image_id = item[self.image_id_col]
        else:
            image_id = idx

        return {
            "image_id": image_id,
            "image": image,
            "image_pil": image_pil
        }

In [5]:
# Load the YOLO Model
# `path` is a relative location of the weights file handed to YOLO.
# NOTE(review): the file name suggests a YOLOv8-nano segmentation checkpoint -- confirm.
path = "weights/yolov8n-seg.pt"
model = YOLO(path, verbose=False)

In [6]:
# Enter the dataset ID and load it as a torch dataset
# image_id_col=None makes PinterestDataset use the row index as each image's id.
dataset_id = "amyeroberts/filtered-coco-panoptic-val2017-people"
ds = PinterestDataset(dataset_id=dataset_id, image_col="image", image_id_col=None)

In [7]:
# Dataloader
def collate_fn(ex):
    """Combine a list of per-sample dicts into a single batch dict.

    Args:
        ex: Sequence of samples, each providing an ``"image"`` tensor and an
            ``"image_id"``. Any other keys are ignored.

    Returns:
        dict: ``"images"`` holds the sample tensors concatenated along
        dimension 0, and ``"image_ids"`` holds the ids in sample order.
    """
    image_tensors = []
    image_ids = []
    for sample in ex:
        image_tensors.append(sample["image"])
        image_ids.append(sample["image_id"])

    return {
        "images": torch.cat(image_tensors, dim=0),
        "image_ids": image_ids,
    }

# os.cpu_count() returns None when the CPU count cannot be determined, and
# None is not a valid num_workers value. Fall back to 0, which loads batches
# in the main process.
workers = os.cpu_count() or 0

# shuffle=False keeps batches in dataset order; the later column-wise merge of
# the results with the original dataset relies on that ordering.
dataloader = DataLoader(
    ds,
    collate_fn=collate_fn,
    batch_size=16,
    num_workers=workers,
    pin_memory=False,
    shuffle=False,
)

In [8]:
# Maps image id -> number of people detected in that image.
num_people_results = {}

# Use the GPU when one is present, otherwise run on the CPU instead of
# crashing on a hard-coded .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

max_people = 0
with tqdm(total=len(dataloader)) as pbar:
    for batch in dataloader:
        image_ids = batch["image_ids"]

        # Inference only, so gradients are not tracked.
        with torch.no_grad():
            images = batch["images"].to(device)
            # classes=0 restricts the predictions to class id 0.
            results = model(images, classes=0, verbose=False)

        num_people = [get_num_people(result) for result in results]

        # Report the running maximum only when a batch raises it.
        max_people_batch = max(num_people, default=0)
        if max_people_batch > max_people:
            max_people = max_people_batch
            print(f"Max people detected: {max_people}", end="\r")

        num_people_results.update(zip(image_ids, num_people))
        pbar.update(1)

  0%|          | 0/103 [00:00<?, ?it/s]

Max people detected: 30

In [9]:
# Load the results into a pandas dataframe
df_people = pd.DataFrame(
    list(num_people_results.items()),
    columns=["image_id_processed", "num_people"],
)

# concatenate_datasets(axis=1) pairs rows purely by position, so make sure the
# processed ids line up with the dataset rows before gluing the columns on.
expected_ids = ds.ds[ds.image_id_col] if ds.image_id_col else range(len(ds.ds))
assert df_people["image_id_processed"].tolist() == list(expected_ids), (
    "Processed image ids do not match the dataset row order"
)

# Convert that to a HF Dataset
ds_people = HF_Dataset.from_pandas(df_people)

# Concatenate the two datasets
new_ds = concatenate_datasets([ds.ds, ds_people], axis=1)

# View the results as a pandas dataframe
df = new_ds.to_pandas()
df

Unnamed: 0,label,segments_info,image_id,image,image_id_processed,num_people
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 3226956, 'category_id': 1, 'iscrowd': ...",139,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,0,1
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 12234672, 'category_id': 1, 'iscrowd':...",872,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,1,2
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 2170142, 'category_id': 1, 'iscrowd': ...",885,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,2,3
3,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 3421582, 'category_id': 1, 'iscrowd': ...",1000,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,3,13
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 1447707, 'category_id': 1, 'iscrowd': ...",1268,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,4,4
...,...,...,...,...,...,...
1643,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 4936566, 'category_id': 1, 'iscrowd': ...",579070,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,1643,7
1644,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 6775163, 'category_id': 1, 'iscrowd': ...",579307,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,1644,2
1645,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 2364953, 'category_id': 1, 'iscrowd': ...",579818,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,1645,2
1646,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,"[{'id': 3750203, 'category_id': 1, 'iscrowd': ...",580197,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,1646,3
