Using YOLO Clothing Classification Model

In [1]:
!pip install gradio
!pip install ultralytics
!pip install segment-anything

Collecting gradio
  Downloading gradio-5.23.1-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.8.0 (from gradio)
  Downloading gradio_client-1.8.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 

In [2]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-03-31 01:43:06--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.171.22.33, 3.171.22.68, 3.171.22.118, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.171.22.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: ‘sam_vit_h_4b8939.pth’


2025-03-31 01:43:27 (117 MB/s) - ‘sam_vit_h_4b8939.pth’ saved [2564550879/2564550879]



In [3]:
import torch
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry
from transformers import YolosForObjectDetection, YolosImageProcessor
import gradio as gr
import os
import urllib.request

class GarmentMaskingPipeline:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.yolo_model, self.sam_predictor, self.classification_model = self.load_models()

        self.clothing_to_body_parts = {
            'shirt': ['torso', 'arms'],
            't-shirt': ['torso', 'upper_arms'],
            'blouse': ['torso', 'arms'],
            'dress': ['torso', 'legs'],
            'skirt': ['lower_torso', 'legs'],
            'pants': ['legs'],
            'shorts': ['upper_legs'],
            'jacket': ['torso', 'arms'],
            'coat': ['torso', 'arms']
        }

        self.body_parts_positions = {
            'face': (0.0, 0.2),
            'torso': (0.2, 0.5),
            'arms': (0.2, 0.5),
            'upper_arms': (0.2, 0.35),
            'lower_torso': (0.4, 0.6),
            'legs': (0.5, 0.9),
            'upper_legs': (0.5, 0.7),
            'feet': (0.9, 1.0)
        }

    def load_models(self):
        print("Loading models...")
        # Download models if they don't exist
        self.download_models()

        # Load YOLO model
        yolo_model = YOLO('yolov8n.pt')

        # Load SAM model
        sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        sam.to(self.device)
        predictor = SamPredictor(sam)

        # Load YOLOS-Fashionpedia model for clothing classification
        print("Loading YOLOS-Fashionpedia model...")
        model_name = "valentinafeve/yolos-fashionpedia"
        processor = YolosImageProcessor.from_pretrained(model_name)
        classification_model = YolosForObjectDetection.from_pretrained(model_name)
        classification_model.to(self.device)
        classification_model.eval()

        print("Models loaded successfully!")
        return yolo_model, predictor, classification_model

    def download_models(self):
        """Download required model files if they don't exist"""
        models = {
            "yolov8n.pt": "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt",
            "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        }

        for filename, url in models.items():
            if not os.path.exists(filename):
                print(f"Downloading {filename}...")
                urllib.request.urlretrieve(url, filename)
                print(f"Downloaded {filename}")
            else:
                print(f"{filename} already exists")

        # The YOLOS-Fashionpedia model will be downloaded automatically by transformers

    def classify_clothing(self, clothing_image):
        if not isinstance(clothing_image, Image.Image):
            clothing_image = Image.fromarray(clothing_image)

        # Process image with YOLOS processor
        processor = YolosImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
        inputs = processor(images=clothing_image, return_tensors="pt").to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.classification_model(**inputs)

        # Process results
        target_sizes = torch.tensor([clothing_image.size[::-1]]).to(self.device)
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.1
        )[0]

        # Extract detected labels and confidence scores
        labels = results["labels"]
        scores = results["scores"]

        # Get class names from model config
        id2label = self.classification_model.config.id2label

        # Define Fashionpedia to our category mapping
        fashionpedia_to_clothing = {
            'shirt': 'shirt',
            'blouse': 'shirt',
            'top': 't-shirt',
            't-shirt': 't-shirt',
            'sweater': 'shirt',
            'jacket': 'jacket',
            'cardigan': 'jacket',
            'coat': 'coat',
            'jumper': 'shirt',
            'dress': 'dress',
            'skirt': 'skirt',
            'shorts': 'shorts',
            'pants': 'pants',
            'jeans': 'pants',
            'leggings': 'pants',
            'jumpsuit': 'dress'
        }

        # Find the garment with highest confidence
        if len(labels) > 0:
            detections = [(id2label[label.item()].lower(), score.item())
                         for label, score in zip(labels, scores)]
            detections.sort(key=lambda x: x[1], reverse=True)

            for label, score in detections:
                # Look for clothing keywords in the label
                for keyword, category in fashionpedia_to_clothing.items():
                    if keyword in label:
                        return category

            # If no mapping found, use the first detection as is
            return 't-shirt'

        # Default to t-shirt if nothing detected
        return 't-shirt'

    def create_garment_mask(self, person_image, garment_image):
        clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])

        results = self.yolo_model(person_image, classes=[0])
        mask = np.zeros(person_image.shape[:2], dtype=np.uint8)

        if results and len(results[0].boxes.data) > 0:
            person_boxes = results[0].boxes.data
            person_areas = [(box[2] - box[0]) * (box[3] - box[1]) for box in person_boxes]
            largest_person_index = np.argmax(person_areas)
            person_box = person_boxes[largest_person_index][:4].cpu().numpy().astype(int)

            self.sam_predictor.set_image(person_image)
            masks, _, _ = self.sam_predictor.predict(box=person_box, multimask_output=False)
            person_mask = masks[0].astype(np.uint8)

            h, w = person_mask.shape
            for part in parts_to_mask:
                if part in self.body_parts_positions:
                    top_ratio, bottom_ratio = self.body_parts_positions[part]
                    top_px, bottom_px = int(h * top_ratio), int(h * bottom_ratio)

                    part_mask = np.zeros_like(person_mask)
                    part_mask[top_px:bottom_px, :] = 1
                    part_mask = np.logical_and(part_mask, person_mask).astype(np.uint8)

                    mask = np.logical_or(mask, part_mask).astype(np.uint8)

            # Remove face from the mask
            face_top_px, face_bottom_px = int(h * 0.0), int(h * 0.2)
            face_mask = np.zeros_like(person_mask)
            face_mask[face_top_px:face_bottom_px, :] = 1
            face_mask = np.logical_and(face_mask, person_mask).astype(np.uint8)
            mask = np.logical_and(mask, np.logical_not(face_mask)).astype(np.uint8)

            # Remove feet from the mask
            feet_top_px, feet_bottom_px = int(h * 0.9), int(h * 1.0)
            feet_mask = np.zeros_like(person_mask)
            feet_mask[feet_top_px:feet_bottom_px, :] = 1
            feet_mask = np.logical_and(feet_mask, person_mask).astype(np.uint8)
            mask = np.logical_and(mask, np.logical_not(feet_mask)).astype(np.uint8)

        return mask * 255

    def process(self, person_image_pil, garment_image_pil, mask_color_hex="#00FF00", opacity=0.5):
        """Process the input images and return the masked result"""
        # Convert PIL to numpy array
        person_image = np.array(person_image_pil)
        garment_image = np.array(garment_image_pil)

        # Convert to RGB if needed
        if person_image.shape[2] == 4:  # RGBA
            person_image = person_image[:, :, :3]
        if garment_image.shape[2] == 4:  # RGBA
            garment_image = garment_image[:, :, :3]

        # Create garment mask
        garment_mask = self.create_garment_mask(person_image, garment_image)

        # Convert hex color to RGB
        r = int(mask_color_hex[1:3], 16)
        g = int(mask_color_hex[3:5], 16)
        b = int(mask_color_hex[5:7], 16)
        color = (r, g, b)

        # Create a colored mask
        colored_mask = np.zeros_like(person_image)
        for i in range(3):
            colored_mask[:, :, i] = garment_mask * (color[i] / 255.0)

        # Create binary mask for visualization
        binary_mask = np.stack([garment_mask, garment_mask, garment_mask], axis=2)

        # Overlay mask on original image
        mask_3d = garment_mask[:, :, np.newaxis] / 255.0
        overlay = person_image * (1 - opacity * mask_3d) + colored_mask * opacity
        overlay = overlay.astype(np.uint8)

        # Get classification result
        clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])

        return overlay, binary_mask, f"Detected garment: {clothing_type}\nBody parts to mask: {', '.join(parts_to_mask)}"

def process_images(person_img, garment_img, mask_color, opacity):
    """Gradio processing function"""
    try:
        pipeline = GarmentMaskingPipeline()
        result = pipeline.process(person_img, garment_img, mask_color, opacity)
        return result
    except Exception as e:
        import traceback
        error_msg = f"Error processing images: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, error_msg

def create_gradio_interface():
    """Create and launch the Gradio interface"""
    with gr.Blocks(title="VTON SAM Garment Masking Pipeline") as interface:
        gr.Markdown("""
        # Virtual Try-On Garment Masking Pipeline with SAM and YOLOS-Fashionpedia

        Upload a person image and a garment image to generate a mask for a virtual try-on application.
        The system will:
        1. Detect the person using YOLO
        2. Create a high-quality segmentation using SAM (Segment Anything Model)
        3. Classify the garment type using YOLOS-Fashionpedia
        4. Generate a mask of the area where the garment should be placed

        **Note**: This system uses state-of-the-art AI segmentation and fashion detection models for accurate results.
        """)

        with gr.Row():
            with gr.Column():
                person_input = gr.Image(label="Person Image (Image A)", type="pil")
                garment_input = gr.Image(label="Garment Image (Image B)", type="pil")

                with gr.Row():
                    mask_color = gr.ColorPicker(label="Mask Color", value="#00FF00")
                    opacity = gr.Slider(label="Mask Opacity", minimum=0.1, maximum=0.9, value=0.5, step=0.1)

                submit_btn = gr.Button("Generate Mask")

            with gr.Column():
                masked_output = gr.Image(label="Person with Masked Region")
                mask_output = gr.Image(label="Standalone Mask")
                result_text = gr.Textbox(label="Detection Results", lines=3)

        # Set up the processing flow
        submit_btn.click(
            fn=process_images,
            inputs=[person_input, garment_input, mask_color, opacity],
            outputs=[masked_output, mask_output, result_text]
        )

        gr.Markdown("""
        ## How It Works

        1. **Person Detection**: Uses YOLO to detect and locate the person in the image
        2. **Segmentation**: Uses SAM (Segment Anything Model) to create a high-quality segmentation mask
        3. **Garment Classification**: Uses YOLOS-Fashionpedia to identify the garment type with fashion-specific detection
        4. **Mask Generation**: Creates a mask based on the garment type and body part mapping

        ## Supported Garment Types

        - Shirts, Blouses, Tops, and T-shirts
        - Sweaters and Cardigans
        - Dresses and Jumpsuits
        - Skirts
        - Pants, Jeans, and Leggings
        - Shorts
        - Jackets and Coats

        """)

    return interface

if __name__ == "__main__":
    # Create and launch the Gradio interface
    interface = create_gradio_interface()
    interface.launch(debug=True)

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.
Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://1989db4bcc0ea01106.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Using device: cpu
Loading models...
Downloading yolov8n.pt...
Downloaded yolov8n.pt
sam_vit_h_4b8939.pth already exists
Loading YOLOS-Fashionpedia model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.


config.json:   0%|          | 0.00/2.69k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/123M [00:00<?, ?B/s]

Models loaded successfully!


model.safetensors:   0%|          | 0.00/123M [00:00<?, ?B/s]


0: 640x448 1 person, 997.0ms
Speed: 52.1ms preprocess, 997.0ms inference, 63.2ms postprocess per image at shape (1, 3, 640, 448)
Using device: cpu
Loading models...
yolov8n.pt already exists
sam_vit_h_4b8939.pth already exists
Loading YOLOS-Fashionpedia model...
Models loaded successfully!

0: 640x448 1 person, 271.0ms
Speed: 8.5ms preprocess, 271.0ms inference, 3.5ms postprocess per image at shape (1, 3, 640, 448)
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://1989db4bcc0ea01106.gradio.live
