In [1]:
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
import os
from typing import List, Dict, Tuple

# Install required packages:
# pip install torch torchvision opencv-python pillow matplotlib detectron2

class DeepFashion2Segmentor:
    def __init__(self, model_path: str = None):
        """
        Initialize DeepFashion2 segmentation model
        
        Args:
            model_path: Path to pre-trained DeepFashion2 model
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.categories = self._get_fashion_categories()
        
        # Load model if path provided
        if model_path:
            self.load_model(model_path)
    
    def _get_fashion_categories(self) -> Dict[int, str]:
        """Define DeepFashion2 clothing categories"""
        return {
            1: 'short_sleeve_top',
            2: 'long_sleeve_top', 
            3: 'short_sleeve_outwear',
            4: 'long_sleeve_outwear',
            5: 'vest',
            6: 'sling',
            7: 'shorts',
            8: 'trousers',
            9: 'skirt',
            10: 'short_sleeve_dress',
            11: 'long_sleeve_dress',
            12: 'vest_dress',
            13: 'sling_dress'
        }
    
    def setup_detectron2_model(self):
        """
        Setup DeepFashion2 using Detectron2 framework
        Note: You'll need to download the pre-trained weights
        """
        try:
            from detectron2.engine import DefaultPredictor
            from detectron2.config import get_cfg
            from detectron2.model_zoo import model_zoo
            
            cfg = get_cfg()
            # Use a base model and adapt for fashion
            cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
            
            # Update for DeepFashion2 - you'll need to download the actual weights
            cfg.MODEL.WEIGHTS = "path_to_deepfashion2_weights.pth"  # Update this path
            cfg.MODEL.ROI_HEADS.NUM_CLASSES = 13  # DeepFashion2 has 13 clothing categories
            cfg.MODEL.DEVICE = str(self.device)
            
            self.model = DefaultPredictor(cfg)
            return True
            
        except ImportError:
            print("Detectron2 not installed. Please install it for full functionality.")
            return False
    
    def load_model(self, model_path: str):
        """Load pre-trained model"""
        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            return False
            
        try:
            self.model = torch.load(model_path, map_location=self.device)
            self.model.eval()
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
    
    def preprocess_image(self, image_path: str) -> np.ndarray:
        """
        Preprocess image for model input
        
        Args:
            image_path: Path to input image
            
        Returns:
            Preprocessed image array
        """
        # Load image
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Resize while maintaining aspect ratio
        height, width = image.shape[:2]
        max_size = 800
        
        if max(height, width) > max_size:
            scale = max_size / max(height, width)
            new_width = int(width * scale)
            new_height = int(height * scale)
            image = cv2.resize(image, (new_width, new_height))
        
        return image
    
    def segment_outfit(self, image_path: str) -> Dict:
        """
        Segment clothing items in the outfit
        
        Args:
            image_path: Path to input image
            
        Returns:
            Dictionary containing detected items and their masks
        """
        if not self.model:
            print("Model not loaded. Please load a model first.")
            return {}
        
        # Preprocess image
        image = self.preprocess_image(image_path)
        
        # For demonstration, we'll simulate the segmentation
        # In a real implementation, you'd run inference here
        results = self._mock_segmentation(image)
        
        return results
    
    def _mock_segmentation(self, image: np.ndarray) -> Dict:
        """
        Mock segmentation for demonstration purposes
        Replace this with actual model inference
        """
        height, width = image.shape[:2]
        
        # Simulate detection results
        mock_results = {
            'detections': [
                {
                    'category_id': 1,  # short_sleeve_top
                    'category_name': 'short_sleeve_top',
                    'bbox': [width//4, height//8, width//2, height//3],  # [x, y, w, h]
                    'score': 0.95,
                    'mask': np.zeros((height, width), dtype=np.uint8)  # Placeholder mask
                },
                {
                    'category_id': 8,  # trousers
                    'category_name': 'trousers', 
                    'bbox': [width//3, height//2, width//3, height//2],
                    'score': 0.92,
                    'mask': np.zeros((height, width), dtype=np.uint8)
                }
            ],
            'image_shape': (height, width, 3)
        }
        
        return mock_results
    
    def extract_clothing_items(self, image_path: str, output_dir: str = "extracted_items") -> List[str]:
        """
        Extract individual clothing items from outfit image
        
        Args:
            image_path: Path to input outfit image
            output_dir: Directory to save extracted items
            
        Returns:
            List of paths to extracted item images
        """
        # Create output directory
        os.makedirs(output_dir, exist_ok=True)
        
        # Load original image
        original_image = cv2.imread(image_path)
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        
        # Get segmentation results
        results = self.segment_outfit(image_path)
        
        extracted_items = []
        
        if 'detections' in results:
            for i, detection in enumerate(results['detections']):
                # Extract bounding box coordinates
                x, y, w, h = detection['bbox']
                
                # Crop item from original image
                item_image = original_image[y:y+h, x:x+w]
                
                # Save extracted item
                item_name = f"{detection['category_name']}_{i+1}.jpg"
                item_path = os.path.join(output_dir, item_name)
                
                item_pil = Image.fromarray(item_image)
                item_pil.save(item_path)
                
                extracted_items.append(item_path)
                
                print(f"Extracted {detection['category_name']} with confidence {detection['score']:.2f}")
        
        return extracted_items
    
    def visualize_segmentation(self, image_path: str, save_path: str = None):
        """
        Visualize segmentation results
        
        Args:
            image_path: Path to input image
            save_path: Path to save visualization (optional)
        """
        # Load image
        image = self.preprocess_image(image_path)
        results = self.segment_outfit(image_path)
        
        # Create visualization
        plt.figure(figsize=(12, 8))
        plt.imshow(image)
        
        # Draw bounding boxes and labels
        if 'detections' in results:
            for detection in results['detections']:
                x, y, w, h = detection['bbox']
                
                # Draw bounding box
                rect = plt.Rectangle((x, y), w, h, fill=False, color='red', linewidth=2)
                plt.gca().add_patch(rect)
                
                # Add label
                label = f"{detection['category_name']}\n{detection['score']:.2f}"
                plt.text(x, y-10, label, color='red', fontsize=10, 
                        bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))
        
        plt.title("DeepFashion2 Outfit Segmentation")
        plt.axis('off')
        
        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.show()


def download_deepfashion2_model():
    """
    Helper function to download DeepFashion2 model
    Note: You'll need to implement the actual download logic
    """
    print("Please download the DeepFashion2 model from:")
    print("https://github.com/switchablenorms/DeepFashion2")
    print("Or use the official model zoo links provided in their repository")


# Example usage
def main():
    # Initialize segmentor
    segmentor = DeepFashion2Segmentor()
    
    # Setup model (you'll need to download weights first)
    model_loaded = segmentor.setup_detectron2_model()
    
    if not model_loaded:
        print("Setting up mock segmentation for demonstration...")
    
    # Example image path (replace with your image)
    image_path = "path_to_your_outfit_image.jpg"
    
    # Extract clothing items
    try:
        extracted_items = segmentor.extract_clothing_items(image_path)
        print(f"Extracted {len(extracted_items)} clothing items:")
        for item in extracted_items:
            print(f"  - {item}")
        
        # Visualize results
        segmentor.visualize_segmentation(image_path, "segmentation_result.jpg")
        
    except Exception as e:
        print(f"Error processing image: {e}")
        print("Make sure the image path is correct and the image exists")


if __name__ == "__main__":
    main()

Detectron2 not installed. Please install it for full functionality.
Setting up mock segmentation for demonstration...
Error processing image: OpenCV(4.12.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'

Make sure the image path is correct and the image exists


In [None]:
# Initialize
segmentor = DeepFashion2Segmentor()
segmentor.setup_detectron2_model()

# Process outfit image
items = segmentor.extract_clothing_items("my_outfit.jpg")
segmentor.visualize_segmentation("my_outfit.jpg")