### Init

In [1]:
import os

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from datasets import load_dataset
from scripts.sam_results import SAMResults
from utils import (get_coco_style_polygons, pad_to_fixed_size,
                   resize_preserve_aspect_ratio)

In [2]:
from grounded_sam import (
    run_grounded_sam_batch,
    transform_image_dino,
    transform_image_sam,
)

final text_encoder_type: bert-base-uncased
_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])


In [3]:
# Default to the CPU and switch to the GPU only when torch reports CUDA is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

### Config

In [4]:
import yaml


def load_yaml(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    Args:
        path: Filesystem path of the YAML document to read.

    Returns:
        Whatever ``yaml.load`` produces for the document.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(path, "r", encoding="utf-8") as file:
        # NOTE(review): FullLoader can construct more than plain data types; if a
        # config could ever come from an untrusted source, consider yaml.safe_load.
        data = yaml.load(file, Loader=yaml.FullLoader)
    return data

def get_labels_dict(config_path):
    """Return the inverse of the ``names`` mapping stored in a YAML config.

    Args:
        config_path: Path of the YAML config; it is parsed with ``load_yaml``.

    Returns:
        A dict whose keys are the values of the config's ``names`` mapping and
        whose values are the corresponding original keys.
    """
    names = load_yaml(config_path).get("names")
    inverted = {}
    for original_key, original_value in names.items():
        inverted[original_value] = original_key
    return inverted

# Path of the YAML config used by this run.
config_path = "configs/fashion_people.yml"
# Full parsed config. The same file is parsed a second time inside get_labels_dict.
data = load_yaml(config_path)
# Inverted "names" mapping from the config (the mapping's values become the keys).
labels_dict = get_labels_dict(config_path)

In [16]:
def get_masks_md(results):
    """Swap each formatted result's ``mask`` entry for a ``polygons`` entry.

    Every item of ``results.formatted_results`` is modified in place: its mask is
    passed to ``get_coco_style_polygons``, the outcome is stored under
    ``"polygons"`` and the ``"mask"`` key is removed.

    Args:
        results: Object exposing a ``formatted_results`` iterable of dict-like
            records that each carry a ``"mask"`` key.

    Returns:
        A list of the same (mutated) records, in iteration order.
    """
    rows = []
    for entry in results.formatted_results:
        # Convert first, then drop the raw mask so only the polygons remain.
        entry["polygons"] = get_coco_style_polygons(entry.get("mask"))
        del entry["mask"]
        rows.append(entry)
    return rows

In [6]:
# The keys of labels_dict are the label names; join them with " . " to form the
# text prompt handed to run_grounded_sam_batch later in the notebook.
labels = list(labels_dict)
text_prompt = " . ".join(labels)
text_prompt

'hair . face . neck . arm . hand . back . leg . foot . outfit . person . phone'

### Dataset & Dataloader

In [7]:
# Cache directory passed to `load_dataset` by Segmentation.__init__.
# A plain module-level assignment already creates a global name, so the former
# `global cache_dir` statement was a no-op and has been dropped.
cache_dir = "../hf_datasets"

In [8]:
def resize_image_pil(image_pil, size=1024):
    """Convert an image to RGB and fit it to a ``size`` x ``size`` canvas.

    The image is converted to RGB when it is in any other mode, passed through
    ``resize_preserve_aspect_ratio`` with ``size`` as the target, and then through
    ``pad_to_fixed_size`` with ``(size, size)``.

    Args:
        image_pil: PIL image to process.
        size: Target edge length. Defaults to 1024, the value that used to be
            hard-coded, so existing single-argument callers behave as before.

    Returns:
        The value returned by ``pad_to_fixed_size`` (presumably a PIL image of
        ``size`` x ``size`` pixels; confirm against ``utils``).
    """
    if image_pil.mode != "RGB":
        image_pil = image_pil.convert("RGB")
    image_pil = resize_preserve_aspect_ratio(image_pil, size)
    image_pil = pad_to_fixed_size(image_pil, (size, size))
    return image_pil


class Segmentation(Dataset):
    """Torch ``Dataset`` over the ``train`` split of a Hugging Face dataset.

    Each item provides the image prepared by ``transform_image_dino`` and by
    ``transform_image_sam`` together with the PIL image those were computed from.
    """

    def __init__(self, dataset_id=None, image_col="image", image_id_col=None):
        """Load the dataset and remember which columns to read.

        Args:
            dataset_id: Identifier forwarded to ``load_dataset``.
            image_col: Column that holds the PIL image.
            image_id_col: Column that holds the image id; when a row has no such
                key, the row index is used as the id.
        """
        # Edge length an image must already have on both sides to skip resizing.
        self.imgsz = 1024
        self.image_id_col = image_id_col
        self.image_col = image_col
        # `cache_dir` is the module-level setting defined earlier in the notebook.
        self.ds = load_dataset(
            dataset_id,
            split="train",
            trust_remote_code=True,
            cache_dir=cache_dir,
            num_proc=os.cpu_count(),
        )

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the id, both model inputs and the PIL image for row ``idx``."""
        record = self.ds[idx]

        image_pil = record[self.image_col]
        width, height = image_pil.size[0], image_pil.size[1]
        # Only images that are not already imgsz x imgsz go through the resize helper.
        if not (width == self.imgsz and height == self.imgsz):
            image_pil = resize_image_pil(image_pil)

        return {
            # Falls back to the row index when the id column is absent.
            "image_id": record.get(self.image_id_col, idx),
            "dino_image": transform_image_dino(image_pil),
            "sam_image": transform_image_sam(image_pil),
            "image_pil": image_pil,
        }

In [9]:
# Enter the dataset ID and load it as a torch dataset
dataset_id = "jordandavis/fashion_num_people"
# With image_id_col=None, Segmentation.__getitem__ falls back to the row index as the image id.
ds = Segmentation(dataset_id=dataset_id, image_col="image", image_id_col=None)

Resolving data files:   0%|          | 0/383 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/383 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/374 [00:00<?, ?it/s]

In [10]:
# ds.ds = ds.ds.filter(lambda x: x['num_people'] > 0)
# The steps below keep the same rows as the filter call above (num_people > 0),
# but do it by collecting matching row indices and selecting them.
a = ds.ds['num_people']  # the whole `num_people` column
b = [i for i, x in enumerate(a) if x > 0]  # indices of rows with at least one person
c = ds.ds.select(b)  # dataset restricted to those indices
ds.ds = c  # replace the wrapped dataset in place

In [11]:
# Shuffle the wrapped dataset with a fixed seed.
ds.ds = ds.ds.shuffle(seed=42)

In [12]:
# Truncate the working set to 2000 rows.
ds.ds = ds.ds.take(2000)

In [13]:
def resize_image(examples):
    """Apply ``resize_image_pil`` to the ``image`` field of a ``map`` input.

    Args:
        examples: Mapping whose ``'image'`` entry is either a list of images
            (batched call) or a single image.

    Returns:
        The same mapping with ``'image'`` replaced by the processed image(s).
    """
    images = examples['image']
    examples['image'] = (
        [resize_image_pil(single) for single in images]  # batched: process each image
        if isinstance(images, list)
        else resize_image_pil(images)  # unbatched: one image
    )
    return examples

# Pre-process every image in batches.
ds.ds = ds.ds.map(resize_image, batched=True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [None]:
# NOTE(review): leftover, never-executed cell. `a` is the `num_people` column list
# built in the filtering cell above, so running this displays the entire column;
# consider deleting the cell or showing only a slice.
a

In [14]:
# Dataloader
def collate_fn(ex):
    """Merge per-sample dicts from ``Segmentation`` into one batch dict.

    Args:
        ex: List of samples, each with ``image_id``, ``dino_image``,
            ``sam_image`` and ``image_pil`` entries.

    Returns:
        A dict with ``image_ids`` and ``pil_images`` as plain lists and
        ``dino_images`` / ``sam_images`` as tensors stacked along a new first axis.
    """
    ids, dino_list, sam_list, pils = [], [], [], []
    for sample in ex:
        dino_list.append(sample["dino_image"])
        sam_list.append(sample["sam_image"])
        ids.append(sample["image_id"])
        pils.append(sample["image_pil"])

    # Only the tensors are stacked; ids and PIL images stay as Python lists.
    return {
        "image_ids": ids,
        "dino_images": torch.stack(dino_list),
        "sam_images": torch.stack(sam_list),
        "pil_images": pils,
    }


batch_size = 8
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an integer; fall back to 0 (load in the main process).
num_workers = os.cpu_count() or 0

# shuffle=False keeps batches in dataset order, which the inference cell relies on
# when it later attaches one metadata entry per dataset row.
dataloader = DataLoader(
    ds,
    collate_fn=collate_fn,
    batch_size=batch_size,
    num_workers=num_workers,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only hosts.
    pin_memory=torch.cuda.is_available(),
    shuffle=False,
)

### Run Inference

In [18]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
mask_metadata = {}
# One entry per image, in dataloader order: a list of detection records, or None
# when the image produced no masks or no 'person' phrase.
masks_md = []
with tqdm(total=len(dataloader)) as pbar:
    for batch in dataloader:
        image_ids, images = batch.get("image_ids"), batch.get("pil_images")

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            raw_results = run_grounded_sam_batch(
                batch.get("dino_images").to(device),
                batch.get("sam_images").to(device),
                text_prompt,
            )

        for image_id, image, raw_result in zip(image_ids, images, raw_results):
            # Keep an image only when masks exist and one of the phrases is 'person'.
            has_person = raw_result.get('masks') is not None and 'person' in raw_result.get("phrases")
            if has_person:
                detections = SAMResults(image, labels_dict, **raw_result)
                masks_md.append(get_masks_md(detections))
            else:
                masks_md.append(None)
        pbar.update(1)

  0%|          | 0/250 [00:00<?, ?it/s]

In [19]:
import gc
# Run a garbage-collection pass first, then ask torch to release its cached CUDA memory.
gc.collect() 
torch.cuda.empty_cache()

In [20]:
import json
from datasets import Value
from huggingface_hub import create_repo
# Attach the per-row detection metadata (masks_md is in dataset order) and store
# the width/height columns as 16-bit integers.
updated_ds = (
    ds.ds
    .add_column("mask_metadata", masks_md)
    .cast_column("width", Value("int16"))
    .cast_column("height", Value("int16"))
)

Casting the dataset:   0%|          | 0/2000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [21]:
# Filter out rows with no mask_metadata
# (rows where the inference cell stored None: no masks or no 'person' phrase).
a = updated_ds['mask_metadata']  # whole metadata column; rebinds the earlier `a`
b = [i for i, x in enumerate(a) if x is not None]  # indices of rows that have metadata
updated_ds= updated_ds.select(b)
# NOTE(review): the saved output of this cell (456873) cannot come from the
# 2000-row subset built above and disagrees with the upload log further down;
# it looks stale from an out-of-order run. Re-run the notebook top to bottom.
len(updated_ds)

456873

In [23]:
# Hub dataset repository that receives the annotated dataset.
new_repo_id = "jordandavis/fashion_people_detections"
# Explicit repository creation is currently disabled; the call is kept for reference.
# create_repo(
#     repo_id=new_repo_id,
#     repo_type="dataset",
#     exist_ok=True,
# )

In [24]:
# Upload the filtered dataset to the Hub repository named in new_repo_id.
updated_ds.push_to_hub(new_repo_id, commit_message="md")

Uploading the dataset shards:   0%|          | 0/4 [00:00<?, ?it/s]

Map:   0%|          | 0/467 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Map:   0%|          | 0/466 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Map:   0%|          | 0/466 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Map:   0%|          | 0/466 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/jordandavis/fashion_people_detections/commit/f6e7005ab0520b514c7304f165b7b6ec8290c072', commit_message='md', commit_description='', oid='f6e7005ab0520b514c7304f165b7b6ec8290c072', pr_url=None, pr_revision=None, pr_num=None)