In [None]:
import os
import random

from grounded_sam import get_results
from PIL import Image
import numpy as np
from diffusers.utils import make_image_grid
from utils import (
    display_image_with_masks,
    unload_mask,
    unload_box,
    format_results,
    convert_mask_to_coco_polygons,
    resize_preserve_aspect_ratio,
    get_coco_style_polygons,
    convert_coco_to_yolo_polygons,
)
from sam_results import SAMResults
from tqdm.notebook import tqdm
from datasets import load_dataset

## Final

In [None]:
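# NOTE: kept for reference only; `detection_labels` and `labels_dict` are now
# loaded from the YAML config in the next cell.
# detection_labels = [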
#     "hair",
#     "face",
#     "neck",
#     "arm",
#     "hand",
#     "back",
#     "leg",
#     "foot",
#     "outfit",
#     "person",
#     "phone",
# ]

# # make an enumerated dictionary
# labels_dict = {label: i for i, label in enumerate(detection_labels)}
# labels_dict

In [None]:
import yaml

# Load the label configuration used by every detection call in this notebook.
config_path = "configs/fashion_people.yml"
with open(config_path, 'r') as f:
    # Pass the open handle `f` (the previous code referenced an undefined name `file`).
    # FullLoader is acceptable for a trusted, local config; prefer yaml.safe_load
    # if this path could ever point at untrusted input.
    data = yaml.load(f, Loader=yaml.FullLoader)

# `names` is read as a mapping; its values become the detection prompts.
# NOTE(review): presumably {class_id: label_name} as in YOLO dataset YAMLs — confirm.
labels_dict = data.get('names')
detection_labels = list(labels_dict.values())

In [None]:
ds_id = "jordandavis/fashion"
# Non-streaming load of the train split.
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to run locally — confirm this dataset is trusted.
ds = load_dataset(ds_id, streaming=False, split="train", trust_remote_code=True)
# Keep only rows whose `width` (cast with int() before comparing) is at most 1024.
ds = ds.filter(lambda example: int(example["width"]) <= 1024)

# Shared row iterator: later cells pull samples with `next(iterable)`, but the
# name was never defined anywhere in the notebook. Re-running this cell restarts
# iteration from the first row.
iterable = iter(ds)

In [None]:
def get_attributes(result):
    """Collect the label and polygon attributes for one formatted result.

    Args:
        result: Mapping with the keys ``'label'``, ``'mask'`` and ``'label_id'``.
            The mask must expose ``mask.image.size`` as ``(width, height)``.

    Returns:
        dict with ``label``, ``label_id``, ``coco_polygons`` and ``yolo_polygons``.
        (The previous version computed these values and discarded them,
        implicitly returning ``None``.)
    """
    label = result.get('label')
    mask = result.get('mask')
    label_id = result.get('label_id')

    coco_polygons = get_coco_style_polygons(mask)
    # Image dimensions are handed to the COCO -> YOLO conversion helper.
    image_width, image_height = mask.image.size
    yolo_polygons = convert_coco_to_yolo_polygons(coco_polygons, image_width, image_height)

    return {
        'label': label,
        'label_id': label_id,
        'coco_polygons': coco_polygons,
        'yolo_polygons': yolo_polygons,
    }

In [None]:
# Fetch a sample in this cell: `row` was previously only assigned by a cell
# further down, so this cell failed with NameError on a fresh kernel.
row = next(iterable)
image = row.get('image')
# get_results is called on the image as loaded (no resizing happens in this cell).
results = get_results(image, detection_labels, iou_threshold=0.9)



In [None]:
# Advance the shared row iterator and keep the sample's image for the cells below.
row = next(iterable)
image = row.get("image")
# Preview only: the resize result is not assigned, so `image` is not rebound here;
# as the last expression of the cell, the value is what the notebook displays.
# NOTE(review): assumes `image` is a PIL image, whose resize() returns a new image — confirm.
image.resize((256, 256))

In [None]:
# Rebinds `image` to the resized version, so re-running this cell resizes the
# already-resized image again rather than the original sample.
# NOTE(review): 1024 is presumably the target size in pixels — confirm against
# resize_preserve_aspect_ratio in utils.
image = resize_preserve_aspect_ratio(image, 1024)
# The same labels and iou_threshold are used by the export loop later in the notebook.
results = get_results(image, detection_labels, iou_threshold=0.9)

In [None]:
# Wrap the detection output for the current image; `results` is unpacked as
# keyword arguments, so its keys must match SAMResults' parameters.
r = SAMResults(image.convert("RGB"), labels_dict, **results)

# Visual sanity check of the results for this single sample.
r.display_results()

In [None]:
# Number of successfully exported images to write into each split directory.
max_num = 100

# NOTE: this cell needs train_dir / val_dir, get_lines and
# write_image_and_text_file, which are defined in cells further down the
# notebook — run those cells first.
dataset_exhausted = False

# `split_dir` instead of `dir`: avoids shadowing the builtin and clobbering the
# dataset root path that is also stored in a notebook-level `dir`.
for split_dir in tqdm([train_dir, val_dir]):
    if dataset_exhausted:
        break

    processed = 0  # Successful exports for this split
    while processed < max_num:  # Keep pulling rows until max_num exports succeed
        try:
            row = next(iterable)
        except StopIteration:
            # Handled separately from the broad handler below: otherwise an
            # exhausted iterator would be "logged and retried" forever.
            dataset_exhausted = True
            print(f"Dataset exhausted after {processed} image(s) for {split_dir}")
            break

        try:
            # write_image_and_text_file is not passed the image, so `image`
            # must stay a notebook-level name here.
            image = row.get("image")
            image = resize_preserve_aspect_ratio(image, 1024)
            results = get_results(image, detection_labels, iou_threshold=0.9)

            r = SAMResults(image.convert("RGB"), labels_dict, **results)

            image_name = row.get("image_id")
            lines = get_lines(r)
            write_image_and_text_file(image_name, lines, split_dir)

            processed += 1  # Increment only if the whole row was exported
        except Exception as e:
            print(e)  # Log the failure and move on to the next row
            continue

In [None]:
# Root of the dataset tree; images and labels live in parallel
# <root>/images/<split> and <root>/labels/<split> folders.
dir = "/home/jordan/jd_segment_anything/datasets/person_seg"

train_dir, val_dir = (os.path.join(dir, 'images', split) for split in ('train', 'val'))
train_labels, val_labels = (os.path.join(dir, 'labels', split) for split in ('train', 'val'))

# Create all four folders up front; existing folders are left untouched.
for _path in (train_dir, val_dir, train_labels, val_labels):
    os.makedirs(_path, exist_ok=True)

In [None]:
# Inspection only: evaluating the bare name displays the imported helper's repr;
# it does not call the function or change any notebook state.
get_coco_style_polygons

In [None]:
def get_lines(r):
    """Build the label-file lines for every formatted result in ``r``.

    Each line has the form ``"<label_id> <v1> <v2> ..."`` where the values come
    from ``convert_coco_to_yolo_polygons``.

    The previous version computed ``yolo_polygons`` but then called
    ``convert_mask_to_yolo_polygons``, a name that is neither imported nor
    defined in this notebook, so every call raised ``NameError``.

    Args:
        r: Object exposing ``formatted_results``, an iterable of mappings with
            the keys ``'mask'`` and ``'label_id'``. Each mask must expose
            ``mask.image.size`` as ``(width, height)``.

    Returns:
        list[str]: One line per non-empty polygon.
    """
    lines = []

    for result in r.formatted_results:
        mask = result.get('mask')
        label_id = result.get('label_id')

        coco_polygons = get_coco_style_polygons(mask)
        image_width, image_height = mask.image.size
        yolo_polygons = convert_coco_to_yolo_polygons(coco_polygons, image_width, image_height)

        if len(yolo_polygons) == 0:
            # Nothing to write for this result; avoids emitting a bare "<label_id> " line.
            continue

        # NOTE(review): the helper's return shape is not visible here. Accept either
        # a list of polygons or a single flat coordinate list — confirm against utils.
        if not hasattr(yolo_polygons[0], "__len__"):
            yolo_polygons = [yolo_polygons]

        for polygon in yolo_polygons:
            if len(polygon) == 0:
                continue
            polygon_string = " ".join(str(p) for p in polygon)
            lines.append(f"{label_id} {polygon_string}")

    return lines

In [None]:
def _labels_dir_for(images_dir):
    """Return the labels directory that mirrors ``images_dir``.

    Only the last path component that is exactly ``images`` is swapped, so an
    unrelated ``images`` substring higher up the path is left alone. If no such
    component exists, fall back to the original plain substring replacement.
    """
    parts = images_dir.split(os.sep)
    for idx in range(len(parts) - 1, -1, -1):
        if parts[idx] == 'images':
            parts[idx] = 'labels'
            return os.sep.join(parts)
    return images_dir.replace('images', 'labels')


def write_image_and_text_file(image_name, lines, output_dir, image=None):
    """Save an image and its label file into parallel images/labels folders.

    Args:
        image_name: File name for the image inside ``output_dir``.
        lines: Label lines; written joined by newlines.
        output_dir: Images directory for the split; the label file goes to the
            matching ``labels`` directory.
        image: Image to save; must support ``.convert('RGB').save(path)``.
            When omitted, the notebook-level ``image`` is used, which is what
            the original implementation always did.

    Raises:
        NameError: If ``image`` is omitted and no notebook-level ``image`` exists.
    """
    if image is None:
        # Backward compatibility with existing three-argument calls.
        image = globals().get('image')
        if image is None:
            raise NameError("name 'image' is not defined; pass image=... explicitly")

    image_path = os.path.join(output_dir, image_name)

    # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), so the label
    # file name matches the image file name.
    image_stem = os.path.splitext(image_name)[0]
    text_output_dir = _labels_dir_for(output_dir)
    text_path = os.path.join(text_output_dir, f"{image_stem}.txt")

    if text_output_dir:
        os.makedirs(text_output_dir, exist_ok=True)

    image.convert('RGB').save(image_path)

    with open(text_path, "w") as f:
        f.write("\n".join(lines))