# 🎭 Virtual Fashion Try-On System

This notebook demonstrates a complete virtual try-on system that combines human parsing with AI-based garment generation. The system segments human body parts and replaces clothing regions with AI-generated alternatives based on text descriptions.

## 1. Imports and Dependencies

Loading all necessary libraries for deep learning, image processing, web requests, and visualization components needed for the virtual try-on pipeline.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from datetime import datetime
import requests
from typing import Tuple, Any

## 2. Configuration Settings

Central configuration class containing all system parameters, paths, model settings, and test prompts for the virtual try-on system.

In [None]:
class Config:
    """Static settings shared by every stage of the try-on pipeline.

    Everything is a class attribute, so the class itself is passed around
    rather than an instance. Call :meth:`setup` once before running so the
    output directory exists.
    """

    # --- Runtime device and filesystem locations ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_PATH = "/kaggle/input/human/pytorch/default/1/best_model.pth"
    OUTPUT_DIR = "virtual_tryon_results"

    # --- Input image ---
    IMAGE_SIZE = (512, 512)
    TEST_IMAGE_URL = "https://images.unsplash.com/photo-1506794778202-cad84cf45f1d?q=80&w=687&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"

    # --- Human parsing model ---
    NUM_CLASSES = 18
    # Label ids that get replaced: Upper-clothes, Skirt, Dress
    CLOTHING_CLASSES = [4, 5, 7]

    # --- Inpainting diffusion model ---
    DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"
    NUM_INFERENCE_STEPS = 25
    GUIDANCE_SCALE = 8.0

    # --- Garment descriptions used by the demo ---
    TEST_PROMPTS = [
        "Men's Nick Standard Fit T-Shirt",
        "casual blue denim jacket with silver buttons",
    ]

    @classmethod
    def setup(cls):
        """Make sure ``OUTPUT_DIR`` exists and return the class itself."""
        output_dir = cls.OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)
        return cls

## 3. ASPP (Atrous Spatial Pyramid Pooling) Module

Implementation of the ASPP module that captures multi-scale contextual information through parallel atrous convolutions with different dilation rates.

In [None]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs a 1x1 convolution, one dilated 3x3 convolution per entry in
    ``rates`` and an image-level pooling branch in parallel on the same
    input, then fuses the concatenated outputs with a 1x1 projection.
    """

    def __init__(self, in_channels: int, out_channels: int, rates: Tuple[int, ...] = (6, 12, 18)):
        super().__init__()

        # Pointwise branch
        self.conv1x1 = self._make_branch(in_channels, out_channels, kernel_size=1)

        # One dilated 3x3 branch per requested rate
        dilated_branches = []
        for rate in rates:
            dilated_branches.append(
                self._make_branch(in_channels, out_channels, kernel_size=3, dilation=rate)
            )
        self.atrous_branches = nn.ModuleList(dilated_branches)

        # Image-level branch: pool to 1x1, then conv + BN + ReLU
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Fusion: every branch contributes ``out_channels`` feature maps
        # (the dilated branches plus the 1x1 and the global branch).
        fused_channels = out_channels * (len(rates) + 2)
        self.projection = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

    def _make_branch(self, in_channels: int, out_channels: int,
                     kernel_size: int, dilation: int = 1) -> nn.Sequential:
        """Return a conv -> batch-norm -> ReLU stack for one parallel branch."""
        # For a 3x3 kernel, padding == dilation keeps the spatial size.
        pad = dilation if kernel_size != 1 else 0
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         padding=pad, dilation=dilation, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all branches to ``x`` and return the fused feature map."""
        branch_outputs = [self.conv1x1(x)]
        for branch in self.atrous_branches:
            branch_outputs.append(branch(x))

        # The pooled branch is 1x1 spatially; stretch it back over the input grid.
        pooled = self.global_pool(x)
        branch_outputs.append(
            F.interpolate(pooled, size=x.shape[-2:], mode="bilinear", align_corners=False)
        )

        return self.projection(torch.cat(branch_outputs, dim=1))

## 4. Self-Correction Module

Edge-aware self-correction module that refines segmentation predictions by detecting edges and using this information to improve boundary accuracy.

In [None]:
class SelfCorrectionModule(nn.Module):
    """Edge-aware refinement head.

    Predicts a single-channel edge map from the incoming features, appends it
    to those features as an extra channel, and predicts class logits from the
    combined tensor.
    """

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        hidden = 256

        # Features -> one-channel edge logits
        self.edge_branch = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, 1),
        )

        # Features + edge channel -> per-class logits
        self.refinement_branch = nn.Sequential(
            nn.Conv2d(in_channels + 1, hidden, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, num_classes, 1),
        )

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(refined_logits, edge_logits)`` for ``features``."""
        edge_logits = self.edge_branch(features)
        # The raw edge logits (not a sigmoid of them) are fed to the refiner.
        refined_logits = self.refinement_branch(
            torch.cat((features, edge_logits), dim=1)
        )
        return refined_logits, edge_logits

## 5. Human Parser Network

Complete human parsing network combining ResNet101 backbone, ASPP module, and self-correction for accurate human body part segmentation.

In [None]:
class HumanParser(nn.Module):
    """Human parsing network: ResNet-101 backbone, ASPP, decoder and self-correction head.

    In training mode ``forward`` returns ``(coarse_logits, refined_logits,
    edge_logits)``; in eval mode it returns only ``refined_logits``. Every
    output is bilinearly upsampled to the spatial size of the input.
    """

    def __init__(self, num_classes: int = None):
        """Build the network.

        Args:
            num_classes: Number of segmentation classes. When ``None`` the
                value of ``Config.NUM_CLASSES`` is used.
        """
        super().__init__()

        if num_classes is None:
            num_classes = Config.NUM_CLASSES

        backbone = self._load_pretrained_backbone()

        # List the backbone's children once and slice the stages out of it.
        # The Sequential wrappers keep the same sub-module indices as before,
        # so existing checkpoints still match by key.
        stages = list(backbone.children())
        self.initial_layers = nn.Sequential(*stages[:5])  # Conv1 -> Layer1
        self.layer2 = nn.Sequential(*stages[5])
        self.layer3 = nn.Sequential(*stages[6])
        self.layer4 = nn.Sequential(*stages[7])

        # Low-level feature processing
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(256, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # ASPP module
        self.aspp = ASPP(2048, 256)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Output heads
        self.coarse_head = nn.Conv2d(256, num_classes, 1)
        self.self_correction = SelfCorrectionModule(256, num_classes)

    @staticmethod
    def _load_pretrained_backbone() -> nn.Module:
        """Return an ImageNet-pretrained ResNet-101 on old and new torchvision."""
        import torchvision.models as models

        weights_enum = getattr(models, "ResNet101_Weights", None)
        if weights_enum is not None:
            # Newer torchvision deprecates ``pretrained=``; IMAGENET1K_V1 is
            # the checkpoint that flag used to select.
            return models.resnet101(weights=weights_enum.IMAGENET1K_V1)
        # Older torchvision without the weights enum.
        return models.resnet101(pretrained=True)

    def forward(self, x: torch.Tensor) -> Any:
        """Segment ``x``.

        Returns:
            ``(coarse_logits, refined_logits, edge_logits)`` when
            ``self.training`` is true, otherwise ``refined_logits`` alone.
        """
        input_shape = x.shape[-2:]

        # Backbone forward pass
        low_level = self.initial_layers(x)  # 256 channels
        x = self.layer2(low_level)          # 512 channels
        x = self.layer3(x)                  # 1024 channels
        x = self.layer4(x)                  # 2048 channels

        # Process low-level features
        low_level_features = self.low_level_conv(low_level)

        # ASPP
        aspp_features = self.aspp(x)

        # Bring the ASPP output up to the low-level feature grid so the two
        # can be concatenated channel-wise.
        aspp_features = F.interpolate(aspp_features, size=low_level_features.shape[-2:],
                                      mode="bilinear", align_corners=False)
        decoder_input = torch.cat([aspp_features, low_level_features], dim=1)

        # Decode
        decoder_features = self.decoder(decoder_input)

        # Generate outputs
        coarse_logits = self.coarse_head(decoder_features)
        refined_logits, edge_logits = self.self_correction(decoder_features)

        # Upsample to input resolution
        coarse_logits = F.interpolate(coarse_logits, size=input_shape,
                                      mode="bilinear", align_corners=False)
        refined_logits = F.interpolate(refined_logits, size=input_shape,
                                       mode="bilinear", align_corners=False)
        edge_logits = F.interpolate(edge_logits, size=input_shape,
                                    mode="bilinear", align_corners=False)

        if self.training:
            return coarse_logits, refined_logits, edge_logits
        return refined_logits

## 6. Simple Virtual Try-On System Class

Main system class that orchestrates the complete virtual try-on pipeline by integrating human parsing and diffusion models.

In [None]:
class SimpleVirtualTryOn:
    """Driver for the try-on pipeline; both models are loaded on construction."""

    def __init__(self, config=None):
        """Store the configuration, pick the device and load both models.

        Args:
            config: Settings object; when falsy, ``Config.setup()`` is used.
        """
        self.config = config or Config.setup()
        self.device = torch.device(self.config.DEVICE)

        divider = "=" * 50

        # Startup banner
        print("SIMPLE VIRTUAL TRY-ON SYSTEM")
        print(divider)
        print(f"Device: {self.device}")
        print(f"Output Directory: {self.config.OUTPUT_DIR}")
        print(f"Model Path: {self.config.MODEL_PATH}")

        # Step 1: human parsing model, Step 2: diffusion model
        self.load_human_parser()
        self.load_diffusion_model()

        print(divider)
        print("SYSTEM READY!")
        print(divider)

## 7. Model Loading Methods

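These methods belong to `SimpleVirtualTryOn` (section 6). They load and initialize the human parsing model and the diffusion model, record whether each load succeeded, and let the pipeline fall back to simpler behaviour when a model is unavailable.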

In [None]:
    def load_human_parser(self):
        """Build the parsing network and load trained weights when available.

        Sets ``self.parser`` (moved to ``self.device`` and put in eval mode)
        and ``self.parser_loaded``. If ``MODEL_PATH`` does not exist the
        network is kept without trained weights. Any exception is reported
        and leaves ``self.parser_loaded`` set to ``False``.
        """
        print("\nStep 1: Loading Trained Human Parsing Model...")
        try:
            # Create model
            self.parser = HumanParser(num_classes=self.config.NUM_CLASSES)

            # Load trained weights
            if os.path.exists(self.config.MODEL_PATH):
                print(f"Loading weights from: {self.config.MODEL_PATH}")
                # NOTE(security): torch.load unpickles the file, so MODEL_PATH
                # must only ever point at a trusted checkpoint.
                checkpoint = torch.load(self.config.MODEL_PATH, map_location=self.device)

                # Checkpoints may wrap the weights under 'model' or
                # 'state_dict'; otherwise treat the object as the state dict.
                state_dict = checkpoint
                if isinstance(checkpoint, dict):
                    for wrapper_key in ('model', 'state_dict'):
                        if wrapper_key in checkpoint:
                            state_dict = checkpoint[wrapper_key]
                            break

                # strict=False tolerates key mismatches, so surface them
                # instead of claiming success for a partially loaded model.
                load_report = self.parser.load_state_dict(state_dict, strict=False)
                missing = list(load_report.missing_keys)
                unexpected = list(load_report.unexpected_keys)
                if missing or unexpected:
                    print(f"Warning: checkpoint only partially matches the model "
                          f"({len(missing)} missing keys, {len(unexpected)} unexpected keys)")
                    for label, keys in (("missing", missing), ("unexpected", unexpected)):
                        if keys:
                            print(f"  First {label} keys: {keys[:5]}")
                else:
                    print("Successfully loaded trained weights!")
            else:
                print(f"Warning: Model file not found at {self.config.MODEL_PATH}")
                print("Using initialized model without trained weights")

            self.parser.to(self.device)
            self.parser.eval()
            print("✓ Human Parsing Model loaded successfully!")
            self.parser_loaded = True

        except Exception as e:
            # Deliberate best-effort: segment_human() falls back to a simple
            # mask when the parser is unavailable.
            print(f"✗ Failed to load Human Parsing Model: {e}")
            self.parser_loaded = False
    
    def load_diffusion_model(self):
        """Load Diffusion Model with confirmation"""
        print("\nStep 2: Loading Diffusion Model...")
        try:
            from diffusers import StableDiffusionInpaintPipeline
            
            print(f"Loading {self.config.DIFFUSION_MODEL}...")
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                self.config.DIFFUSION_MODEL,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)
            print("✓ Diffusion Model loaded successfully!")
            self.diffusion_loaded = True
            
        except Exception as e:
            print(f"✗ Failed to load Diffusion Model: {e}")
            print("Using fallback generation method...")
            self.diffusion_loaded = False

## 8. Image Processing Methods

Core methods for downloading test images, segmenting human body parts, and creating masks for clothing regions.

In [None]:
    def download_test_image(self):
        """Download ``config.TEST_IMAGE_URL`` into the output directory.

        Returns:
            Path of the saved file, or ``None`` if the download or the write
            failed (the error is printed, not raised).
        """
        print("\nStep 3: Downloading Test Image...")
        image_path = os.path.join(self.config.OUTPUT_DIR, "test_image.jpg")

        try:
            # The directory may be missing when a config was passed in
            # without setup() having been called.
            os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)

            print(f"Downloading from: {self.config.TEST_IMAGE_URL}")
            # Without a timeout requests can wait forever on a stalled server.
            response = requests.get(self.config.TEST_IMAGE_URL, timeout=30)
            response.raise_for_status()
            with open(image_path, 'wb') as f:
                f.write(response.content)
            print("✓ Test image downloaded successfully!")
            return image_path
        except Exception as e:
            print(f"✗ Failed to download test image: {e}")
            return None
    
    def segment_human(self, image):
        """Return a per-pixel class-id map (``uint8``) for ``image``.

        Args:
            image: A PIL image or an array indexed as (height, width, channel).

        Returns:
            A 2-D ``uint8`` array with the same height and width as the
            input. When the parser is not loaded, or inference fails, the
            result of ``create_simple_mask`` is returned instead.
        """
        if not self.parser_loaded:
            return self.create_simple_mask(image)

        # Preprocess image
        image_np = np.array(image) if isinstance(image, Image.Image) else image
        orig_h, orig_w = image_np.shape[:2]

        # cv2.resize takes its target size as (width, height)
        image_resized = cv2.resize(image_np, self.config.IMAGE_SIZE)
        image_tensor = torch.from_numpy(image_resized).permute(2, 0, 1).float() / 255.0
        image_tensor = image_tensor.unsqueeze(0).to(self.device)

        try:
            with torch.no_grad():
                output = self.parser(image_tensor)
                segmentation = output.argmax(dim=1)[0].cpu().numpy()

            # argmax yields int64, which cv2.resize does not accept; convert
            # before resizing so non-IMAGE_SIZE inputs do not fall through to
            # the fallback mask.
            segmentation = segmentation.astype(np.uint8)

            # Compare real (height, width) shapes rather than the (h, w)
            # shape against the (w, h) IMAGE_SIZE tuple.
            if segmentation.shape[:2] != (orig_h, orig_w):
                segmentation = cv2.resize(segmentation, (orig_w, orig_h),
                                          interpolation=cv2.INTER_NEAREST)

            return segmentation
        except Exception as e:
            print(f"Warning: Model inference failed: {e}")
            return self.create_simple_mask(image_np)
    def create_simple_mask(self, image):
        """Return a fallback label map with a centred rectangle of class 4.

        The rectangle spans 20%-60% of the height and 30% of the width on
        either side of the horizontal centre; everything else is 0.
        """
        # PIL reports (width, height); arrays report (height, width, ...)
        if isinstance(image, Image.Image):
            w, h = image.size
        else:
            h, w = image.shape[:2]

        top, bottom = int(h * 0.2), int(h * 0.6)
        middle = w // 2
        half_span = int(w * 0.3)

        label_map = np.zeros((h, w), dtype=np.uint8)
        # 4 is the upper-clothes class id
        label_map[top:bottom, middle - half_span:middle + half_span] = 4
        return label_map
    
    def create_clothing_mask(self, segmentation):
        """Return a blurred 0-255 ``uint8`` mask of the configured clothing classes."""
        # Accumulate a boolean "is clothing" map over all target class ids
        is_clothing = np.zeros(segmentation.shape, dtype=bool)
        for class_id in self.config.CLOTHING_CLASSES:
            is_clothing |= segmentation == class_id

        hard_mask = np.zeros_like(segmentation, dtype=np.uint8)
        hard_mask[is_clothing] = 255

        # Feather the edges with a 21x21 Gaussian kernel
        return cv2.GaussianBlur(hard_mask, (21, 21), 0)

## 9. Clothing Generation Methods

Methods for generating new clothing using diffusion models with fallback color-based generation when models are unavailable.

In [None]:
    def generate_clothing(self, image, mask, prompt):
        """Inpaint the masked region from ``prompt``.

        Uses the diffusion pipeline when it was loaded; otherwise, or when
        the pipeline raises, falls back to ``simple_color_generation``.
        """
        # No pipeline available: go straight to the colour-fill fallback.
        if not self.diffusion_loaded:
            return self.simple_color_generation(image, mask, prompt)

        try:
            enhanced_prompt = f"{prompt}, high quality, detailed clothing, fashion photography"
            pipeline_output = self.pipe(
                prompt=enhanced_prompt,
                image=image,
                mask_image=Image.fromarray(mask),
                num_inference_steps=self.config.NUM_INFERENCE_STEPS,
                guidance_scale=self.config.GUIDANCE_SCALE,
            )
            return pipeline_output.images[0]
        except Exception as e:
            print(f"Warning: Diffusion generation failed: {e}")
            return self.simple_color_generation(image, mask, prompt)
    
    def simple_color_generation(self, image, mask, prompt):
        """Fill the masked region with a flat colour named in ``prompt``.

        The first colour name (in table order) that occurs as a substring of
        the lower-cased prompt wins; with no match a default colour is used.
        Only pixels whose mask value exceeds 128 are painted.
        """
        canvas = np.array(image.copy())

        # Colour names recognised in the prompt and their RGB values
        palette = {
            'red': (200, 50, 50), 'blue': (50, 50, 200), 'green': (50, 200, 50),
            'black': (50, 50, 50), 'white': (230, 230, 230), 'yellow': (200, 200, 50),
            'purple': (150, 50, 150), 'pink': (255, 150, 150), 'brown': (139, 69, 19)
        }

        prompt_lower = prompt.lower()
        fill = next(
            (rgb for name, rgb in palette.items() if name in prompt_lower),
            (100, 100, 150),  # Default
        )

        canvas[mask > 128] = fill
        return Image.fromarray(canvas.astype(np.uint8))
    
    def blend_images(self, original, generated, mask):
        """Alpha-composite ``generated`` over ``original`` using ``mask``.

        ``mask`` is scaled from 0-255 to 0-1 and used as the per-pixel weight
        of the generated image across all three colour channels.
        """
        base = np.array(original)
        overlay = np.array(generated)

        # 0-255 mask -> 0-1 weights, replicated for the three channels
        weight = mask.astype(np.float32) / 255.0
        alpha = np.stack((weight, weight, weight), axis=2)

        blended = overlay * alpha + base * (1 - alpha)
        return Image.fromarray(blended.astype(np.uint8))

## 10. Pipeline Processing Methods

Complete pipeline methods that process individual prompts and orchestrate the entire try-on workflow from segmentation to final result.

In [None]:
    def process_single_prompt(self, original_image, prompt):
        """Run segmentation, masking, generation and blending for one prompt.

        Returns:
            A dict with keys ``original``, ``segmentation``, ``mask``,
            ``generated``, ``result`` and ``prompt``.
        """
        print(f"\n→ Processing: '{prompt}'")

        print("  - Segmenting human body...")
        label_map = self.segment_human(original_image)

        print("  - Creating target area mask...")
        target_mask = self.create_clothing_mask(label_map)

        print("  - Generating new clothing...")
        garment = self.generate_clothing(original_image, target_mask, prompt)

        # Composite so that pixels outside the mask come from the original.
        print("  - Blending final result...")
        composite = self.blend_images(original_image, garment, target_mask)

        return dict(
            original=original_image,
            segmentation=label_map,
            mask=target_mask,
            generated=garment,
            result=composite,
            prompt=prompt,
        )
    
    def run_demo(self):
        """Run the pipeline once per prompt in ``config.TEST_PROMPTS``.

        Downloads the test image, resizes it to ``config.IMAGE_SIZE``, then
        processes and visualizes each prompt.

        Returns:
            A list with one result dict per prompt, or ``None`` when the test
            image could not be downloaded.
        """
        print(f"\nRUNNING DEMO WITH {len(self.config.TEST_PROMPTS)} PROMPTS")
        print("=" * 50)

        # Download test image
        image_path = self.download_test_image()
        if not image_path:
            print("Cannot continue without test image")
            return

        # Load and resize image. The context manager closes the file handle;
        # convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as source_image:
            original_image = source_image.convert('RGB').resize(self.config.IMAGE_SIZE)

        print("\nStep 4: Processing Images...")

        all_results = []

        # Process each test prompt
        for i, prompt in enumerate(self.config.TEST_PROMPTS, 1):
            print(f"\nExample {i}/{len(self.config.TEST_PROMPTS)}:")
            result = self.process_single_prompt(original_image, prompt)
            all_results.append(result)

            # Create individual visualization
            self.visualize_single_result(result)

        print("\n✓ All demos completed successfully!")
        return all_results

## 11. Visualization Method

Method for creating comprehensive visualizations showing the complete try-on pipeline from original image to final result.

In [None]:
def visualize_single_result(self, result):
    """Plot and save a four-panel summary of one try-on result.

    Panels: original image, segmentation map, mask overlaid on the original,
    and the final composite. The figure is written to ``config.OUTPUT_DIR``
    as a PNG named from the prompt and a timestamp, then shown and closed.

    Args:
        result: Dict with keys ``original``, ``segmentation``, ``mask``,
            ``result`` and ``prompt``.
    """
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    # Original Image
    axes[0].imshow(result['original'])
    axes[0].set_title('Original Image', fontsize=12, fontweight='bold')
    axes[0].axis('off')

    # Body Segmentation
    axes[1].imshow(result['segmentation'], cmap='tab20')
    axes[1].set_title('Body Segmentation', fontsize=12, fontweight='bold')
    axes[1].axis('off')

    # Target Area: mask drawn semi-transparently over the original
    axes[2].imshow(result['original'])
    axes[2].imshow(result['mask'], alpha=0.5, cmap='Reds')
    axes[2].set_title('Target Area', fontsize=12, fontweight='bold')
    axes[2].axis('off')

    # Final Result
    axes[3].imshow(result['result'])
    axes[3].set_title('Virtual Try-On Result', fontsize=12, fontweight='bold')
    axes[3].axis('off')

    # Set main title as the prompt
    plt.suptitle(f'"{result["prompt"]}"', fontsize=16, fontweight='bold')
    plt.tight_layout()

    # Save visualization. Only letters, digits, spaces, '-' and '_' from the
    # first 30 prompt characters are kept for the file name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_prompt = "".join(c for c in result['prompt'][:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
    # The directory may be missing when setup() was never called.
    os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
    save_path = os.path.join(self.config.OUTPUT_DIR, f"result_{safe_prompt}_{timestamp}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"  ✓ Saved: {save_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# run_demo() calls this as ``self.visualize_single_result(...)``, but the
# function is defined at module level. Attach it to the class when the class
# exists and does not already provide the method.
_tryon_cls = globals().get("SimpleVirtualTryOn")
if _tryon_cls is not None and not hasattr(_tryon_cls, "visualize_single_result"):
    _tryon_cls.visualize_single_result = visualize_single_result

## 12. Main Execution Pipeline

System initialization and execution of the complete virtual try-on demo with both test prompts.

In [None]:
if __name__ == "__main__":
    # Prepare the settings; this also creates the output directory.
    config = Config.setup()

    # Constructing the system loads the parsing and diffusion models.
    tryon_system = SimpleVirtualTryOn(config)

    # Process every configured prompt and keep the per-prompt results.
    results = tryon_system.run_demo()