### Init

In [None]:
import os

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from datasets import load_dataset
from scripts.sam_results import SAMResults
from utils import (get_coco_style_polygons, pad_to_fixed_size,
                   resize_preserve_aspect_ratio)

In [None]:
from grounded_sam import (
    run_grounded_sam_batch,
    transform_image_dino,
    transform_image_sam,
)

In [None]:
# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Config

In [None]:
import yaml


def load_yaml(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    Args:
        path: Filesystem path of the YAML document to read.

    Returns:
        Whatever ``yaml.load`` produces for the document.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(path, "r", encoding="utf-8") as file:
        # NOTE(review): FullLoader is more permissive than yaml.safe_load;
        # switch to safe_load if this is ever pointed at untrusted files.
        data = yaml.load(file, Loader=yaml.FullLoader)
    return data


# Path of the YAML config used by this notebook; the parsed document is kept
# in `data`.
config_path = "configs/fashion_people.yml"
data = load_yaml(config_path)


def get_labels_dict(config_path):
    """Build a label lookup from the ``names`` section of a YAML config.

    The config's ``names`` mapping is inverted, so every value in the file
    becomes a key of the result and every key becomes its value.

    Args:
        config_path: Path to the YAML config, passed through to ``load_yaml``.

    Returns:
        dict: The inverted ``names`` mapping.

    Raises:
        KeyError: If the config is empty or has no ``names`` section.
    """
    # An empty YAML document loads as None; treat it like a config with no keys
    # so the failure below is explicit instead of an AttributeError on None.
    config = load_yaml(config_path) or {}
    names = config.get("names")
    if names is None:
        raise KeyError(f"'names' section not found in config: {config_path}")
    return {value: key for key, value in names.items()}


# `config_path` and `data` are already assigned earlier in this cell (right
# after load_yaml is defined), so they are not re-assigned and the file is not
# parsed an extra time here. Only the label lookup is built.
labels_dict = get_labels_dict(config_path)

In [None]:
def get_masks_md(results):
    """Swap each formatted result's raw mask for COCO-style polygons.

    Every dict in ``results.formatted_results`` is edited in place: a
    ``"polygons"`` entry computed by ``get_coco_style_polygons`` is added and
    the ``"mask"`` entry is removed.

    Args:
        results: Object exposing a ``formatted_results`` iterable of dicts.

    Returns:
        list: The same dict objects, in their original order.
    """

    def _mask_to_polygons(entry):
        # Compute the polygons before dropping the mask they are derived from.
        polygons = get_coco_style_polygons(entry.get("mask"))
        entry.update({"polygons": polygons})
        entry.pop("mask")
        return entry

    return [_mask_to_polygons(entry) for entry in results.formatted_results]

In [None]:
# The keys of labels_dict are the labels; join them with " . " into one prompt.
labels = list(labels_dict)
text_prompt = " . ".join(labels)
# Bare expression so the notebook displays the prompt as the cell output.
text_prompt

### Dataset & Dataloader

In [None]:
def pil_image(image_pil):
    """Convert an image to RGB, then resize and pad it to 1024 x 1024.

    Args:
        image_pil: Image object exposing ``mode`` and ``convert``.

    Returns:
        The result of ``pad_to_fixed_size`` applied to the resized RGB image.
    """
    # Only convert when needed; RGB inputs are passed through untouched.
    rgb_image = image_pil if image_pil.mode == "RGB" else image_pil.convert("RGB")
    resized_image = resize_preserve_aspect_ratio(rgb_image, 1024)
    return pad_to_fixed_size(resized_image, (1024, 1024))


class Segmentation(Dataset):
    """Torch dataset that serves images prepared for the DINO and SAM transforms.

    Each item carries the image id, the DINO- and SAM-transformed versions of
    the image, and the PIL image those transforms were applied to.

    Args:
        dataset_id: Identifier handed to ``load_dataset``; the ``train`` split
            is loaded.
        image_col: Column that holds the PIL image.
        image_id_col: Optional column holding an image id. When it is ``None``
            or absent from a row, the row index is used instead.
    """

    def __init__(self, dataset_id=None, image_col="image", image_id_col=None):
        # NOTE(review): trust_remote_code=True is passed to load_dataset; only
        # point this at dataset repos you trust.
        self.ds = load_dataset(
            dataset_id, split="train", trust_remote_code=True, num_proc=os.cpu_count()
        )
        self.image_col = image_col
        self.image_id_col = image_id_col
        # Side length (pixels) an image must already have to skip pil_image.
        self.imgsz = 1024

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the model-ready sample for row ``idx``.

        Returns:
            dict: ``image_id``, ``dino_image``, ``sam_image`` and ``image_pil``.
        """
        item = self.ds[idx]

        # The image id falls back to the row index when no id column is available.
        image_id = item.get(self.image_id_col, idx)

        image_pil = item[self.image_col]

        # Always normalise the colour mode. Previously the RGB conversion only
        # happened inside pil_image, so an image that already had the target
        # size but a different mode (RGBA, L, P, ...) skipped it entirely.
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")

        # Resize/pad only when the image is not already imgsz x imgsz.
        if tuple(image_pil.size) != (self.imgsz, self.imgsz):
            image_pil = pil_image(image_pil)

        # Both model inputs are derived from the same prepared PIL image.
        dino_image = transform_image_dino(image_pil)
        sam_image = transform_image_sam(image_pil)

        return {"image_id": image_id, "dino_image": dino_image, "sam_image": sam_image, "image_pil": image_pil}

In [None]:
# Enter the dataset ID and load it as a torch dataset.
# With image_id_col=None, Segmentation uses each row's index as its image id.
dataset_id = "jordandavis/fashion_test"
ds = Segmentation(dataset_id=dataset_id, image_col="image", image_id_col=None)

In [None]:
# Resize all images in the dataset using the pil_image function.
def resize_images(dataset, size=(1024, 1024)):
    """Return a copy of ``dataset`` whose images went through ``pil_image``.

    Assigning into ``dataset[i][...]`` only mutates the temporary dict that
    indexing returns, so nothing is stored back into a Hugging Face dataset.
    The rows are therefore rewritten with ``dataset.map`` and the new dataset
    is returned.

    Args:
        dataset: Dataset exposing ``map``; each row needs an ``"image"`` column.
        size: ``(width, height)`` recorded in the ``"width"`` and ``"height"``
            columns. NOTE(review): it is only recorded, not applied;
            ``pil_image`` uses its own fixed 1024 target, so keep the two in
            sync.

    Returns:
        The dataset produced by ``dataset.map``.
    """
    width, height = size[0], size[1]

    def _resized_columns(row):
        # Return only the columns that change; map merges them into the row.
        return {
            "image": pil_image(row["image"]),
            "width": width,
            "height": height,
        }

    return dataset.map(_resized_columns)

# Rebind the wrapper's underlying dataset to the one returned by resize_images.
ds.ds = resize_images(ds.ds)

In [None]:
# Dataloader


def collate_fn(ex):
    """Merge a list of dataset samples into a single batch dict.

    Args:
        ex: Sequence of sample dicts with ``dino_image``, ``sam_image``,
            ``image_id`` and ``image_pil`` entries.

    Returns:
        dict: ``image_ids`` and ``pil_images`` as plain lists, ``dino_images``
        and ``sam_images`` as tensors stacked along a new first dimension.
    """

    def _column(key):
        return [sample[key] for sample in ex]

    # Tensors are stacked; ids and PIL images stay as Python lists.
    dino_batch = torch.stack(_column("dino_image"))
    sam_batch = torch.stack(_column("sam_image"))
    return {
        "image_ids": _column("image_id"),
        "dino_images": dino_batch,
        "sam_images": sam_batch,
        "pil_images": _column("image_pil"),
    }


batch_size = 3
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to a single worker instead of failing on the comparison below.
workers = os.cpu_count() or 1
# Never start more loader workers than there are samples in one batch.
num_workers = min(workers, batch_size)

# Batches are assembled by collate_fn; shuffle=False keeps them in dataset order.
dataloader = DataLoader(
    ds,
    collate_fn=collate_fn,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=False,
    shuffle=False,
)

### Run Inference

In [None]:
# Turn off tokenizer parallelism (presumably to avoid clashes with the
# DataLoader worker processes; confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Maps str(image_id) to that image's list of detection records, or to None
# when no masks were produced or no "person" phrase was detected. Insertion
# order follows the dataloader order.
mask_metadata = {}

with tqdm(total=len(dataloader)) as pbar:
    for batch in dataloader:
        image_ids = batch.get("image_ids")
        images = batch.get("pil_images")

        # Inference only: no gradients are needed.
        with torch.no_grad():
            dino_images = batch.get("dino_images").to(device)
            sam_images = batch.get("sam_images").to(device)
            raw_results = run_grounded_sam_batch(dino_images, sam_images, text_prompt)

        for image_id, image, raw_result in zip(image_ids, images, raw_results):
            phrases = raw_result.get("phrases")
            if (
                raw_result.get("masks") is None
                or phrases is None
                or "person" not in phrases
            ):
                mask_md = None
            else:
                result = SAMResults(image, labels_dict, **raw_result)
                mask_md = get_masks_md(result)

            # Key by the whole id string. Zipping over str(image_id) iterated
            # its characters, so only the first character became the key and
            # ids such as 1, 10 and 11 overwrote each other.
            mask_metadata[str(image_id)] = mask_md
        pbar.update(1)

In [None]:
import gc
# Run a garbage-collection pass, then ask torch to empty its CUDA cache.
gc.collect()
torch.cuda.empty_cache()

In [21]:
import json
from datasets import Value
from huggingface_hub import create_repo
# The metadata values are attached positionally, so there must be exactly one
# entry per dataset row; fail early with a clear message otherwise.
if len(mask_metadata) != len(ds.ds):
    raise ValueError(
        f"mask_metadata has {len(mask_metadata)} entries but the dataset has "
        f"{len(ds.ds)} rows"
    )

updated_ds = ds.ds
# add_column is documented to take a list or array, so materialise the dict
# view before handing it over.
updated_ds = updated_ds.add_column("mask_metadata", list(mask_metadata.values()))
updated_ds = updated_ds.cast_column("width", Value("int16"))
updated_ds = updated_ds.cast_column("height", Value("int16"))

Casting the dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

In [25]:
# Keep only the rows that actually received mask metadata.
filtered_ds = updated_ds.filter(lambda row: row["mask_metadata"] is not None)

# Destination dataset repo; exist_ok=True is passed so an existing repo is accepted.
new_repo_id = "jordandavis/fashion_people_detections"
create_repo(repo_id=new_repo_id, repo_type="dataset", exist_ok=True)

Filter:   0%|          | 0/5 [00:00<?, ? examples/s]

In [None]:
# Upload the filtered dataset to the repo named by new_repo_id, using "md" as
# the commit message.
filtered_ds.push_to_hub(new_repo_id, commit_message="md")