# Human Parsing Model Training Pipeline

This notebook implements a complete human parsing system using a DeepLab-style architecture with self-correction for virtual fashion try-on applications. The model segments each image into 18 classes covering the background, clothing, body parts, and accessories.

## 1. Imports and Dependencies

Import all necessary libraries for deep learning, computer vision, data processing, and visualization.

In [None]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from datasets import load_dataset
import matplotlib.pyplot as plt
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
import warnings
warnings.filterwarnings('ignore')

## 2. Configuration Settings

Central configuration class containing all hyperparameters, paths, and model settings for training the human parsing model.

In [None]:
class Config:
    """Single source of truth for dataset, model, training and path settings.

    All values are class attributes, so the class is used directly
    (``Config.NUM_CLASSES``) and never needs to be instantiated.
    """

    # ---- Dataset ----
    DATASET_NAME: str = "mattmdjaga/human_parsing_dataset"
    SPLIT: str = "train"
    NUM_CLASSES: int = 18
    IGNORE_INDEX: int = 255  # label value skipped by the losses and metrics

    # ---- Model ----
    INPUT_SIZE: Tuple[int, int] = (512, 512)  # passed to A.Resize as (height, width)
    BACKBONE: str = "resnet101"

    # ---- Training ----
    BATCH_SIZE: int = 10
    EPOCHS: int = 5
    LEARNING_RATE_BACKBONE: float = 1e-4
    LEARNING_RATE_HEAD: float = 5e-4
    WEIGHT_DECAY: float = 5e-4
    GRADIENT_CLIP: float = 1.0  # max gradient norm

    # ---- Loss ----
    EDGE_WEIGHT: float = 0.4  # multiplier applied to the edge loss term

    # ---- Paths ----
    OUTPUT_DIR: str = "content"
    MODEL_PATH: str = "content/best_model.pth"
    CONFIG_PATH: str = "content/model_config.json"
    CHECKPOINT_DIR: str = "content/checkpoints"

    # ---- System ----
    SEED: int = 42
    NUM_WORKERS: int = 10
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Resume ----
    # Checkpoint to continue from; the trainer ignores it when the file is missing.
    RESUME_FROM: Optional[str] = "/content/content/checkpoints/checkpoint_epoch_1.pth"

    # ---- Early stopping ----
    EARLY_STOPPING_PATIENCE: int = 5

    @classmethod
    def create_directories(cls):
        """Ensure the output and checkpoint directories exist."""
        for directory in (cls.OUTPUT_DIR, cls.CHECKPOINT_DIR):
            os.makedirs(directory, exist_ok=True)

## 3. Class Definitions and Labels

Define the 18 human parsing classes and their corresponding colors for visualization. These include body parts, clothing items, and accessories.

In [None]:
# One row per class id (row index == label value): display name and RGB colour.
_CLASS_TABLE = [
    ("Background", (0, 0, 0)),
    ("Hat", (128, 0, 0)),
    ("Hair", (255, 0, 0)),
    ("Sunglasses", (0, 85, 0)),
    ("Upper-clothes", (170, 0, 51)),
    ("Skirt", (255, 85, 0)),
    ("Pants", (0, 0, 85)),
    ("Dress", (0, 119, 221)),
    ("Belt", (85, 85, 0)),
    ("Left-shoe", (0, 85, 85)),
    ("Right-shoe", (85, 51, 0)),
    ("Face", (52, 86, 128)),
    ("Left-leg", (0, 128, 0)),
    ("Right-leg", (0, 0, 255)),
    ("Left-arm", (51, 170, 221)),
    ("Right-arm", (0, 255, 255)),
    ("Bag", (85, 255, 170)),
    ("Scarf", (170, 255, 85)),
]

# Class id -> human-readable name.
CLASS_NAMES = [name for name, _ in _CLASS_TABLE]

# Class id -> RGB colour, as an (18, 3) uint8 lookup table.
CLASS_COLORS = np.array([rgb for _, rgb in _CLASS_TABLE], dtype=np.uint8)

## 4. Data Processing and Augmentation

Data transformation classes for training and validation, including augmentation techniques like horizontal flip, brightness/contrast changes, and normalization.

In [None]:
class DataTransforms:
    """Factory for the albumentations pipelines used by the datasets."""

    @staticmethod
    def get_train_transforms() -> A.Compose:
        """Build the training pipeline: resize, random augmentation, normalize, to-tensor."""
        height, width = Config.INPUT_SIZE
        steps = [
            A.Resize(height, width),
            # Random geometric / photometric augmentations.
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=0.3),
            # ImageNet statistics, matching the pretrained backbone.
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        return A.Compose(steps)

    @staticmethod
    def get_val_transforms() -> A.Compose:
        """Build the deterministic validation pipeline: resize, normalize, to-tensor."""
        height, width = Config.INPUT_SIZE
        steps = [
            A.Resize(height, width),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        return A.Compose(steps)

## 5. Dataset Class

Custom PyTorch dataset class for loading and preprocessing human parsing data, handling both images and segmentation masks.

In [None]:
class HumanParsingDataset(Dataset):
    """Adapts an indexable collection of ``{"image", "mask"}`` records to PyTorch.

    Args:
        dataset: Indexable, sized collection whose items expose ``"image"`` and
            ``"mask"`` entries (NumPy arrays, or objects convertible via
            ``np.array``; non-array images must provide ``.convert("RGB")``).
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a dict with ``"image"`` and ``"mask"`` tensors.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(image, mask)`` tensor pair for sample ``idx``."""
        record = self.dataset[idx]

        # Image: anything that is not already an array is converted to RGB first.
        image = record["image"]
        if not isinstance(image, np.ndarray):
            image = np.array(image.convert("RGB"))

        # Mask: plain array conversion, no colour handling.
        mask = record["mask"]
        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        # Any label outside [0, NUM_CLASSES) becomes IGNORE_INDEX. The widening
        # to int32 lets the range test see the full value before the final
        # narrowing back to uint8.
        labels = mask.astype(np.int32)
        out_of_range = (labels < 0) | (labels >= Config.NUM_CLASSES)
        labels[out_of_range] = Config.IGNORE_INDEX
        mask = labels.astype(np.uint8)

        if not self.transform:
            # No pipeline: channels-first image scaled to [0, 1], raw labels.
            image_tensor = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
            mask_tensor = torch.from_numpy(mask).long()
            return image_tensor, mask_tensor

        augmented = self.transform(image=image, mask=mask)
        return augmented["image"].float(), augmented["mask"].long()

## 6. ASPP Module (Atrous Spatial Pyramid Pooling)

Key component of DeepLab architecture that captures multi-scale contextual information using dilated convolutions with different rates.

In [None]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs the input through a 1x1 conv, one dilated 3x3 conv per entry in
    ``rates`` and an image-level pooling branch, then fuses the concatenated
    results with a 1x1 projection followed by dropout.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every branch and by the projection.
        rates: Dilation rate of each 3x3 branch.
    """

    def __init__(self, in_channels: int, out_channels: int, rates: Tuple[int, ...] = (6, 12, 18)):
        super().__init__()

        # Point-wise branch.
        self.conv1x1 = self._make_branch(in_channels, out_channels, kernel_size=1)

        # One dilated 3x3 branch per rate.
        dilated = [
            self._make_branch(in_channels, out_channels, kernel_size=3, dilation=rate)
            for rate in rates
        ]
        self.atrous_branches = nn.ModuleList(dilated)

        # Image-level context: pool to 1x1, then conv + BN + ReLU.
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Fuse every branch (dilated branches + 1x1 + pooled) back to out_channels.
        fused_channels = out_channels * (len(rates) + 2)
        self.projection = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def _make_branch(self, in_channels: int, out_channels: int,
                     kernel_size: int, dilation: int = 1) -> nn.Sequential:
        """Return a conv -> batch-norm -> ReLU stack that preserves spatial size."""
        # A dilated 3x3 kernel needs padding == dilation to keep H and W;
        # a 1x1 kernel needs none.
        padding = dilation if kernel_size != 1 else 0
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         padding=padding, dilation=dilation, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fused multi-scale features at the spatial size of ``x``."""
        spatial_size = x.shape[-2:]

        branch_outputs = [self.conv1x1(x)]
        for branch in self.atrous_branches:
            branch_outputs.append(branch(x))

        # The pooled branch is 1x1 spatially; stretch it back before concatenation.
        pooled = self.global_pool(x)
        branch_outputs.append(
            F.interpolate(pooled, size=spatial_size, mode="bilinear", align_corners=False)
        )

        return self.projection(torch.cat(branch_outputs, dim=1))

## 7. Self-Correction Module

Novel component that improves segmentation accuracy by detecting edges and using them to refine the segmentation predictions.

In [None]:
class SelfCorrectionModule(nn.Module):
    """Predicts an edge map and uses it to produce refined class logits.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of segmentation classes to predict.
    """

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        # Edge head: a single-channel logit map.
        self.edge_branch = self._conv_head(in_channels, 1)
        # Refinement head: sees the features plus the one edge channel.
        self.refinement_branch = self._conv_head(in_channels + 1, num_classes)

    @staticmethod
    def _conv_head(in_channels: int, out_channels: int) -> nn.Sequential:
        """Return a 3x3 conv -> BN -> ReLU -> 1x1 conv prediction head."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, out_channels, 1)
        )

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(refined_logits, edge_logits)`` for ``features``."""
        edge_logits = self.edge_branch(features)
        # The raw edge logits (no sigmoid) are appended as an extra channel.
        refined_logits = self.refinement_branch(
            torch.cat([features, edge_logits], dim=1)
        )
        return refined_logits, edge_logits

## 8. Main Human Parsing Network

Complete model architecture combining a ResNet101 backbone, the ASPP module, a decoder, and the self-correction mechanism for accurate human parsing.

In [None]:
class HumanParsingNet(nn.Module):
    """Human parsing network: ResNet101 encoder, ASPP, decoder and self-correction.

    Args:
        num_classes: Number of segmentation classes predicted by both heads.
        pretrained_backbone: Load ImageNet weights into the ResNet101 encoder.
            Defaults to ``True`` (the historical behaviour); pass ``False``
            when the weights are about to be overwritten from a checkpoint, so
            nothing is downloaded.

    In training mode ``forward`` returns ``(coarse_logits, refined_logits,
    edge_logits)``; in eval mode it returns only ``refined_logits``. All
    outputs are resized to the spatial size of the input.
    """

    def __init__(self, num_classes: int = 18, pretrained_backbone: bool = True):
        super().__init__()

        import torchvision.models as models

        # `pretrained=` is deprecated since torchvision 0.13 in favour of
        # `weights=`; IMAGENET1K_V1 is what `pretrained=True` used to load.
        if hasattr(models, "ResNet101_Weights"):
            weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained_backbone else None
            backbone = models.resnet101(weights=weights)
        else:  # torchvision < 0.13 has no weights enum
            backbone = models.resnet101(pretrained=pretrained_backbone)

        # Encoder stages. The sub-module names and their ordering define the
        # state_dict keys and the optimizer's parameter order, so they stay
        # exactly as before (initial_layers.0 .. .4, layer2.N, ...).
        stages = list(backbone.children())
        self.initial_layers = nn.Sequential(*stages[:5])  # conv1, bn1, relu, maxpool, layer1
        self.layer2 = nn.Sequential(*stages[5])
        self.layer3 = nn.Sequential(*stages[6])
        self.layer4 = nn.Sequential(*stages[7])

        # Squeeze the 256-channel layer1 output to 48 channels for the skip path.
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(256, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Multi-scale context on top of the 2048-channel layer4 output.
        self.aspp = ASPP(2048, 256)

        # Decoder over the concatenated ASPP (256) + low-level (48) features.
        self.decoder = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Output heads.
        self.coarse_head = nn.Conv2d(256, num_classes, 1)
        self.self_correction = SelfCorrectionModule(256, num_classes)

    def forward(self, x: torch.Tensor) -> Any:
        """Run the network on a batch of images.

        Args:
            x: Image batch with 3 channels, shaped ``(N, 3, H, W)``.

        Returns:
            Training mode: ``(coarse_logits, refined_logits, edge_logits)``.
            Eval mode: ``refined_logits`` only.
            Every tensor has spatial size ``(H, W)``.
        """
        input_shape = x.shape[-2:]

        # Encoder.
        low_level = self.initial_layers(x)  # 256 channels
        x = self.layer2(low_level)          # 512 channels
        x = self.layer3(x)                  # 1024 channels
        x = self.layer4(x)                  # 2048 channels

        # Skip connection from the first stage.
        low_level_features = self.low_level_conv(low_level)

        # Context features, brought up to the resolution of the skip features.
        aspp_features = self.aspp(x)
        aspp_features = F.interpolate(aspp_features, size=low_level_features.shape[-2:],
                                      mode="bilinear", align_corners=False)
        decoder_input = torch.cat([aspp_features, low_level_features], dim=1)

        decoder_features = self.decoder(decoder_input)

        # Both heads read the same decoder features.
        coarse_logits = self.coarse_head(decoder_features)
        refined_logits, edge_logits = self.self_correction(decoder_features)

        # Bring every output back to the input resolution.
        coarse_logits = F.interpolate(coarse_logits, size=input_shape,
                                      mode="bilinear", align_corners=False)
        refined_logits = F.interpolate(refined_logits, size=input_shape,
                                       mode="bilinear", align_corners=False)
        edge_logits = F.interpolate(edge_logits, size=input_shape,
                                    mode="bilinear", align_corners=False)

        if self.training:
            return coarse_logits, refined_logits, edge_logits
        return refined_logits

## 9. Loss Functions

Edge-aware loss function that combines segmentation loss with edge detection loss to improve boundary accuracy.

In [None]:
class EdgeAwareLoss(nn.Module):
    """Cross-entropy on the coarse and refined logits plus a weighted edge BCE term.

    ``total = CE(coarse) + CE(refined) + edge_weight * BCE(edge)``

    Args:
        edge_weight: Multiplier applied to the edge loss.
        ignore_index: Label value excluded from every loss term.
    """

    def __init__(self, edge_weight: float = 0.4, ignore_index: int = 255):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # Per-pixel losses are required so ignored pixels can be dropped from
        # the average (see forward); the reduction is done by hand there.
        self.bce_logits_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.edge_weight = edge_weight
        self.ignore_index = ignore_index

    @staticmethod
    def _changes_along(labels: torch.Tensor, dim: int) -> torch.Tensor:
        """Return a bool mask of where ``np.gradient`` along ``dim`` would be non-zero.

        Mirrors NumPy's scheme: central differences (``x[i+1] - x[i-1]``) in
        the interior and one-sided differences on the first and last entries.
        """
        moved = labels.movedim(dim, -1)
        changed = torch.zeros_like(moved, dtype=torch.bool)
        changed[..., 1:-1] = moved[..., 2:] != moved[..., :-2]
        changed[..., 0] = moved[..., 1] != moved[..., 0]
        changed[..., -1] = moved[..., -1] != moved[..., -2]
        return changed.movedim(-1, dim)

    @torch.no_grad()
    def compute_edge_targets(self, masks: torch.Tensor) -> torch.Tensor:
        """Derive binary edge targets from label masks.

        A pixel is an edge when the label gradient along either spatial axis
        is non-zero. The result is identical to the former per-sample
        ``np.gradient`` computation, but stays on ``masks.device`` and handles
        the whole batch at once (no GPU -> CPU round trip per step).

        Args:
            masks: Integer label masks shaped ``(N, H, W)`` with ``H, W >= 2``.

        Returns:
            Float32 tensor shaped ``(N, H, W)`` holding 0.0 / 1.0, on the same
            device as ``masks``.
        """
        labels = masks.long()
        # Ignored pixels get their own pseudo-label, as before.
        labels = labels.masked_fill(labels == self.ignore_index, -1)
        edges = self._changes_along(labels, 1) | self._changes_along(labels, 2)
        return edges.float()

    def forward(self, coarse_logits: torch.Tensor, refined_logits: torch.Tensor,
                edge_logits: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Compute the combined loss.

        Args:
            coarse_logits: ``(N, C, H, W)`` logits of the coarse head.
            refined_logits: ``(N, C, H, W)`` logits of the refinement head.
            edge_logits: ``(N, 1, H, W)`` edge logits.
            targets: ``(N, H, W)`` integer labels; ``ignore_index`` marks pixels
                to skip.

        Returns:
            ``(total_loss, loss_dict)`` where ``loss_dict`` holds the Python
            float values of the ``coarse``, ``refined``, ``edge`` and ``total``
            terms for logging.
        """
        # Segmentation losses.
        coarse_loss = self.ce_loss(coarse_logits, targets)
        refined_loss = self.ce_loss(refined_logits, targets)

        # Edge loss, averaged over valid pixels only. Zeroing logits and targets
        # at ignored pixels (the previous approach) made each of them add a
        # constant log(2) to the sum and shrank the contribution of valid
        # pixels, because the mean was taken over all pixels.
        edge_targets = self.compute_edge_targets(targets)
        valid_mask = (targets != self.ignore_index).float()
        per_pixel = self.bce_logits_loss(edge_logits.squeeze(1), edge_targets)
        # clamp guards the division when a batch has no valid pixel at all.
        edge_loss = (per_pixel * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)

        total_loss = coarse_loss + refined_loss + self.edge_weight * edge_loss

        # Loss components for logging.
        loss_dict = {
            "coarse": coarse_loss.item(),
            "refined": refined_loss.item(),
            "edge": edge_loss.item(),
            "total": total_loss.item()
        }

        return total_loss, loss_dict

## 10. Evaluation Metrics

Implementation of the mean Intersection over Union (mIoU) metric for evaluating segmentation performance.

In [None]:
class Metrics:
    """Metrics computation for segmentation tasks."""

    @staticmethod
    def compute_miou(predictions: np.ndarray, targets: np.ndarray,
                     num_classes: int, ignore_index: int = 255) -> Tuple[float, List[float]]:
        """Compute the mean Intersection over Union.

        Pixels whose target equals ``ignore_index`` are dropped first. For each
        class, IoU is ``|pred AND target| / |pred OR target|``. A class that
        occurs in neither the predictions nor the targets has no defined IoU
        and is left out of the mean. A class that is predicted but absent from
        the targets scores 0, so false positives are penalised. (Previously
        every class missing from the targets was credited with a perfect 1.0,
        which inflated the score.)

        Args:
            predictions: Predicted class ids, any shape.
            targets: Ground-truth class ids, same number of elements.
            num_classes: Classes are ``0 .. num_classes - 1``.
            ignore_index: Target value to exclude from the evaluation.

        Returns:
            ``(miou, ious)``: ``ious`` has one entry per class, ``nan`` where
            the class is absent from both inputs. ``miou`` is the mean over the
            defined entries, or ``0.0`` when there is none (for example when
            every pixel is ignored).
        """
        predictions = np.asarray(predictions).ravel()
        targets = np.asarray(targets).ravel()

        # Filter out ignored pixels.
        keep = targets != ignore_index
        predictions = predictions[keep]
        targets = targets[keep]

        ious: List[float] = []
        for class_id in range(num_classes):
            pred_mask = predictions == class_id
            target_mask = targets == class_id

            union = int(np.logical_or(pred_mask, target_mask).sum())
            if union == 0:
                # Class absent from both inputs: undefined, excluded from the mean.
                ious.append(float("nan"))
                continue

            intersection = int(np.logical_and(pred_mask, target_mask).sum())
            ious.append(intersection / union)

        defined = [value for value in ious if not np.isnan(value)]
        miou = float(np.mean(defined)) if defined else 0.0
        return miou, ious

## 11. Training Utilities

Utilities for checkpoint management and early stopping to optimize the training process.

In [None]:
class CheckpointManager:
    """Serialises and restores the full training state."""

    @staticmethod
    def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
                       scaler: GradScaler, epoch: int, best_miou: float,
                       loss: float, filepath: str):
        """Write model, optimizer and scaler state plus bookkeeping to ``filepath``.

        Args:
            model: Network whose weights are stored.
            optimizer: Optimizer whose state is stored.
            scaler: AMP gradient scaler whose state is stored.
            epoch: Zero-based index of the epoch that just finished.
            best_miou: Best validation mIoU to record.
            loss: Loss value to record.
            filepath: Destination file.
        """
        # Snapshot of the settings the run was started with.
        run_settings = {
            'num_classes': Config.NUM_CLASSES,
            'input_size': Config.INPUT_SIZE,
            'learning_rates': {
                'backbone': Config.LEARNING_RATE_BACKBONE,
                'head': Config.LEARNING_RATE_HEAD
            }
        }
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_miou': best_miou,
                'loss': loss,
                'config': run_settings
            },
            filepath
        )

    @staticmethod
    def load_checkpoint(filepath: str, model: nn.Module,
                       optimizer: Optional[torch.optim.Optimizer] = None,
                       scaler: Optional[GradScaler] = None) -> Tuple[int, float]:
        """Restore state from ``filepath``.

        Model weights are always loaded. Optimizer and scaler state are loaded
        only when the object is supplied and the checkpoint contains it.

        Returns:
            ``(start_epoch, best_miou)`` where ``start_epoch`` is the stored
            epoch plus one, i.e. the next epoch to run.
        """
        state = torch.load(filepath, map_location=Config.DEVICE)

        model.load_state_dict(state['model_state_dict'])

        optional_parts = (
            (optimizer, 'optimizer_state_dict'),
            (scaler, 'scaler_state_dict'),
        )
        for target, key in optional_parts:
            if target is None or key not in state:
                continue
            target.load_state_dict(state[key])

        start_epoch = state['epoch'] + 1
        best_miou = state['best_miou']

        print(f"Resumed from epoch {start_epoch}, best mIoU: {best_miou:.4f}")
        return start_epoch, best_miou


class EarlyStopping:
    """Signals a stop after ``patience`` consecutive non-improving scores.

    A score counts as an improvement only when it reaches at least
    ``best_score + min_delta``. Higher scores are better. Once the stop flag
    is raised it is never cleared.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_score: float) -> bool:
        """Record ``val_score`` and return whether training should stop."""
        # The first observation only seeds the baseline.
        if self.best_score is None:
            self.best_score = val_score
            return self.early_stop

        stalled = val_score < self.best_score + self.min_delta
        if not stalled:
            # New best: remember it and restart the patience window.
            self.best_score = val_score
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

## 12. Visualization Tools

Tools for visualizing model predictions, converting masks to colored images, and creating comparison plots.

In [None]:
class Visualizer:
    """Visualization utilities for model predictions."""

    @staticmethod
    def mask_to_color(mask: np.ndarray) -> np.ndarray:
        """Convert a 2-D label mask to an RGB image using ``CLASS_COLORS``.

        Labels without an entry in ``CLASS_COLORS`` (for example the ignore
        index) stay black.
        """
        h, w = mask.shape
        colored = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in enumerate(CLASS_COLORS):
            colored[mask == class_id] = color

        return colored

    @staticmethod
    def denormalize_image(tensor: torch.Tensor) -> torch.Tensor:
        """Undo the ImageNet normalization of a CPU ``(3, H, W)`` tensor.

        Returns a new tensor clamped to ``[0, 1]``; the input is not modified.
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        # The arithmetic already allocates a fresh tensor, so no clone is needed.
        return torch.clamp(tensor * std + mean, 0, 1)

    @classmethod
    def visualize_predictions(cls, model: nn.Module, dataloader: DataLoader,
                            device: torch.device, num_samples: int = 4,
                            save_path: str = "outputs/predictions.png"):
        """Save a grid of input / ground truth / prediction / overlay panels.

        The model is switched to eval mode and is left in eval mode.

        Args:
            model: Network returning class logits in eval mode.
            dataloader: Yields ``(images, masks)`` batches.
            device: Device used for inference.
            num_samples: Number of rows to draw; rows stay empty when the
                loader provides fewer samples.
            save_path: Output image file. Its parent directory is created when
                missing.
        """
        model.eval()

        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples),
                                 squeeze=False)

        samples_shown = 0

        with torch.no_grad():
            for images, masks in dataloader:
                if samples_shown >= num_samples:
                    break

                images = images.to(device)
                masks = masks.to(device)

                predictions = model(images)
                pred_masks = predictions.argmax(dim=1)

                batch_size = min(images.size(0), num_samples - samples_shown)

                for i in range(batch_size):
                    # Image back to HWC in [0, 1] for plotting.
                    img_np = cls.denormalize_image(images[i].cpu())
                    img_np = img_np.permute(1, 2, 0).numpy()

                    mask_np = masks[i].cpu().numpy()
                    pred_np = pred_masks[i].cpu().numpy()

                    gt_colored = cls.mask_to_color(mask_np)
                    pred_colored = cls.mask_to_color(pred_np)

                    # Per-sample mIoU shown in the panel title.
                    miou, _ = Metrics.compute_miou(pred_np, mask_np,
                                                   Config.NUM_CLASSES, Config.IGNORE_INDEX)

                    # Blend the prediction over the input image.
                    overlay = cv2.addWeighted(
                        (img_np * 255).astype(np.uint8), 0.6,
                        pred_colored, 0.4, 0
                    )

                    panels = (
                        (img_np, 'Input Image'),
                        (gt_colored, 'Ground Truth'),
                        (pred_colored, f'Prediction (mIoU: {miou:.3f})'),
                        (overlay, 'Overlay'),
                    )
                    for col, (panel, title) in enumerate(panels):
                        ax = axes[samples_shown, col]
                        ax.imshow(panel)
                        ax.set_title(title)
                        ax.axis('off')

                    samples_shown += 1

                # Stop here rather than pulling one more batch just to break.
                if samples_shown >= num_samples:
                    break

        # The default path points into "outputs/", which nothing else creates.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved visualization to {save_path}")

## 13. Training Pipeline

Main trainer class that orchestrates the entire training process including data loading, loss computation, optimization, and validation.

In [None]:
class Trainer:
    """Runs training and validation for the human parsing model.

    Args:
        model: Network returning ``(coarse, refined, edge)`` logits in training
            mode and refined logits in eval mode. It is expected to already be
            on ``config.DEVICE``.
        train_loader: Yields ``(images, masks)`` training batches.
        val_loader: Yields ``(images, masks)`` validation batches.
        config: Settings object (``Config``) providing hyperparameters, paths
            and the device.
    """

    def __init__(self, model: nn.Module, train_loader: DataLoader,
                 val_loader: DataLoader, config: Config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Loss function
        self.criterion = EdgeAwareLoss(
            edge_weight=config.EDGE_WEIGHT,
            ignore_index=config.IGNORE_INDEX
        )

        # Optimizer with different learning rates
        self.optimizer = self._create_optimizer()

        # Mixed precision is only active on CUDA.
        self.scaler = GradScaler(enabled=config.DEVICE.type == "cuda")

        # Training state
        self.start_epoch = 0
        self.best_miou = 0.0

        # Early stopping (its counter is not part of the checkpoint, so it
        # starts fresh after a resume).
        self.early_stopping = EarlyStopping(patience=config.EARLY_STOPPING_PATIENCE)

        # Resume from checkpoint if one is configured and present on disk.
        if config.RESUME_FROM and os.path.exists(config.RESUME_FROM):
            self.start_epoch, self.best_miou = CheckpointManager.load_checkpoint(
                config.RESUME_FROM, self.model, self.optimizer, self.scaler
            )

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW with separate learning rates for backbone and head.

        Parameters whose name contains ``"layer"`` or ``"initial"`` form the
        backbone group; everything else goes to the head group.
        """
        backbone_params = []
        head_params = []

        for name, param in self.model.named_parameters():
            if "layer" in name or "initial" in name:
                backbone_params.append(param)
            else:
                head_params.append(param)

        return torch.optim.AdamW([
            {"params": backbone_params, "lr": self.config.LEARNING_RATE_BACKBONE},
            {"params": head_params, "lr": self.config.LEARNING_RATE_HEAD}
        ], weight_decay=self.config.WEIGHT_DECAY)

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch and return the mean batch loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        use_amp = self.config.DEVICE.type == "cuda"

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config.EPOCHS} [Train]")

        for images, masks in pbar:
            images = images.to(self.config.DEVICE)
            masks = masks.to(self.config.DEVICE)

            self.optimizer.zero_grad()

            with autocast(enabled=use_amp):
                coarse, refined, edges = self.model(images)
                loss, loss_dict = self.criterion(coarse, refined, edges, masks)

            self.scaler.scale(loss).backward()
            # Unscale first so clipping operates on the true gradient norm.
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRADIENT_CLIP)

            self.scaler.step(self.optimizer)
            self.scaler.update()

            # loss_dict["total"] is loss.item() already taken by the criterion;
            # reusing it avoids extra device synchronisations per step.
            batch_loss = loss_dict["total"]
            total_loss += batch_loss
            num_batches += 1

            pbar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'coarse': f"{loss_dict['coarse']:.3f}",
                'refined': f"{loss_dict['refined']:.3f}",
                'edge': f"{loss_dict['edge']:.3f}"
            })

        return total_loss / max(1, num_batches)

    def validate(self, epoch: int) -> Tuple[float, float]:
        """Evaluate on the validation loader.

        Returns:
            ``(avg_loss, avg_miou)``: the mean per-batch cross-entropy of the
            refined logits and the mean of the per-image mIoU values. Both are
            ``0.0`` when the loader yields nothing.
        """
        self.model.eval()
        total_loss = 0.0
        all_mious = []
        use_amp = self.config.DEVICE.type == "cuda"

        with torch.no_grad():
            for images, masks in tqdm(self.val_loader, desc=f"Epoch {epoch+1}/{self.config.EPOCHS} [Val]"):
                images = images.to(self.config.DEVICE)
                masks = masks.to(self.config.DEVICE)

                with autocast(enabled=use_amp):
                    predictions = self.model(images)
                    loss = F.cross_entropy(predictions, masks, ignore_index=self.config.IGNORE_INDEX)

                total_loss += loss.item()

                # Move each batch to the CPU once, then score image by image.
                pred_masks = predictions.argmax(dim=1).cpu().numpy()
                target_masks = masks.cpu().numpy()
                for pred_mask, target_mask in zip(pred_masks, target_masks):
                    miou, _ = Metrics.compute_miou(
                        pred_mask,
                        target_mask,
                        self.config.NUM_CLASSES,
                        self.config.IGNORE_INDEX
                    )
                    all_mious.append(miou)

        avg_loss = total_loss / max(1, len(self.val_loader))
        # np.mean([]) would return NaN (with a warning) and poison best-model tracking.
        avg_miou = float(np.mean(all_mious)) if all_mious else 0.0

        return avg_loss, avg_miou

    def train(self):
        """Run the training loop and return the best validation mIoU.

        Each epoch trains, validates, saves the best model when validation
        mIoU improves, writes a per-epoch checkpoint and checks early stopping.
        """
        print(f"Starting training on {self.config.DEVICE}")
        print(f"Training from epoch {self.start_epoch} to {self.config.EPOCHS}")

        for epoch in range(self.start_epoch, self.config.EPOCHS):
            # Training phase
            train_loss = self.train_epoch(epoch)

            # Validation phase
            val_loss, val_miou = self.validate(epoch)

            print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, "
                  f"val_loss={val_loss:.4f}, val_mIoU={val_miou:.4f}")

            # Update the best score BEFORE writing the checkpoint. Otherwise the
            # checkpoint records the previous best, and a resumed run could
            # overwrite the best model with a worse one.
            if val_miou > self.best_miou:
                self.best_miou = val_miou
                torch.save({
                    "model": self.model.state_dict(),
                    "best_miou": self.best_miou,
                    "epoch": epoch
                }, self.config.MODEL_PATH)
                print(f"Saved best model with mIoU: {self.best_miou:.4f}")

            # Save checkpoint
            checkpoint_path = os.path.join(
                self.config.CHECKPOINT_DIR,
                f"checkpoint_epoch_{epoch+1}.pth"
            )
            CheckpointManager.save_checkpoint(
                self.model, self.optimizer, self.scaler,
                epoch, self.best_miou, train_loss, checkpoint_path
            )

            # Early stopping check
            if self.early_stopping(val_miou):
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        print(f"Training completed. Best mIoU: {self.best_miou:.4f}")
        return self.best_miou

## 14. Inference Module

Predictor class for running inference on new images using trained model weights.

In [None]:
class Predictor:
    """Inference wrapper around a trained ``HumanParsingNet``.

    Args:
        model_path: File holding the weights. Accepted layouts: a dict with a
            ``"model"`` entry (best-model file), a dict with a
            ``"model_state_dict"`` entry (training checkpoint), or a bare
            state dict.
        config_path: Optional JSON file with ``"model"`` and ``"preprocessing"``
            sections. Missing files, sections or keys fall back to the values
            of ``_default_config``.
    """

    def __init__(self, model_path: str, config_path: Optional[str] = None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration, layered over the defaults so a partial file
        # cannot cause a KeyError further down.
        loaded: Dict = {}
        if config_path and os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)

        self.config = dict(loaded)
        for section, defaults in self._default_config().items():
            merged = dict(defaults)
            overrides = loaded.get(section)
            if isinstance(overrides, dict):
                merged.update(overrides)
            self.config[section] = merged

        # Initialize model
        self.model = HumanParsingNet(num_classes=self.config['model']['num_classes'])

        # Load weights from any of the formats written by the training code.
        checkpoint = torch.load(model_path, map_location=self.device)
        state_dict = checkpoint
        for key in ('model', 'model_state_dict'):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
        self.model.load_state_dict(state_dict)

        self.model.to(self.device)
        self.model.eval()

        # Preprocessing: resize to the network input size, normalize, to-tensor.
        preprocessing = self.config['preprocessing']
        self.transform = A.Compose([
            A.Resize(preprocessing['size'][0], preprocessing['size'][1]),
            A.Normalize(mean=preprocessing['mean'], std=preprocessing['std']),
            ToTensorV2()
        ])

    def _default_config(self) -> Dict:
        """Return the configuration used when no config file is available."""
        return {
            'model': {
                'num_classes': 18,
                'input_size': [512, 512]
            },
            'preprocessing': {
                'mean': [0.485, 0.456, 0.406],
                'std': [0.229, 0.224, 0.225],
                'size': [512, 512]
            }
        }

    def predict(self, image: np.ndarray) -> np.ndarray:
        """Predict the segmentation mask for one image.

        Args:
            image: RGB image array shaped ``(H, W, 3)``.

        Returns:
            2-D array of class ids at the preprocessing resolution
            (``config['preprocessing']['size']``), not the original ``(H, W)``.
        """
        # Preprocess
        transformed = self.transform(image=image)
        input_tensor = transformed['image'].unsqueeze(0).to(self.device)

        # Predict
        with torch.no_grad():
            output = self.model(input_tensor)
            prediction = output.argmax(dim=1)[0].cpu().numpy()

        return prediction

    def predict_from_path(self, image_path: str) -> np.ndarray:
        """Load the image at ``image_path`` as RGB and run ``predict`` on it."""
        image = np.array(Image.open(image_path).convert('RGB'))
        return self.predict(image)

## 15. Utility Functions

Helper functions for setting up reproducible training, creating data loaders, and saving model configurations.

In [None]:
def setup_seed(seed: int):
    """Seed every random number generator the pipeline relies on.

    Python's ``random`` module, NumPy and PyTorch are always seeded. When
    CUDA is available the CUDA generators are seeded too, and cuDNN is put
    into deterministic mode with benchmarking turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-related to configure on a CPU-only machine
    if not torch.cuda.is_available():
        return

    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def create_data_loaders() -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation ``DataLoader`` pair.

    The configured dataset split is loaded, its sample indices are shuffled
    with Python's ``random`` module and cut 80/20 into a training and a
    validation subset, each wrapped with its own transform pipeline.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and
        drops the last incomplete batch; the validation loader does neither.
    """
    raw_dataset = load_dataset(Config.DATASET_NAME, split=Config.SPLIT)

    # Shuffled 80/20 partition of the sample indices
    shuffled = list(range(len(raw_dataset)))
    random.shuffle(shuffled)
    cut = int(0.8 * len(shuffled))
    train_part, val_part = shuffled[:cut], shuffled[cut:]

    train_dataset = HumanParsingDataset(
        raw_dataset.select(train_part),
        transform=DataTransforms.get_train_transforms(),
    )
    val_dataset = HumanParsingDataset(
        raw_dataset.select(val_part),
        transform=DataTransforms.get_val_transforms(),
    )

    # Settings common to both loaders
    shared_kwargs = dict(
        batch_size=Config.BATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset, shuffle=True, drop_last=True, **shared_kwargs
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)

    print(f"Dataset loaded: {len(train_dataset)} training, {len(val_dataset)} validation samples")

    return train_loader, val_loader


def save_model_config(model: nn.Module, best_miou: float):
    """Write the inference configuration to ``Config.CONFIG_PATH`` as JSON.

    Args:
        model: Part of the signature but not read by this function.
        best_miou: Score stored as a float under
            ``training_info.best_miou``.
    """
    model_section = {
        "architecture": "HumanParsingNet",
        "num_classes": Config.NUM_CLASSES,
        "input_size": list(Config.INPUT_SIZE),
        "model_path": Config.MODEL_PATH,
    }
    classes_section = {
        "names": CLASS_NAMES,
        "colors": CLASS_COLORS.tolist(),
    }
    preprocessing_section = {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
        "size": list(Config.INPUT_SIZE),
    }
    training_section = {
        "dataset": Config.DATASET_NAME,
        "batch_size": Config.BATCH_SIZE,
        "epochs": Config.EPOCHS,
        "best_miou": float(best_miou),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    payload = {
        "model": model_section,
        "classes": classes_section,
        "preprocessing": preprocessing_section,
        "training_info": training_section,
    }

    with open(Config.CONFIG_PATH, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"Model configuration saved to {Config.CONFIG_PATH}")

## 16. Model Evaluation

Comprehensive evaluation function that computes detailed metrics and identifies best/worst performing classes.

In [None]:
def evaluate_model(model: nn.Module, val_loader: DataLoader) -> Dict[str, Any]:
    """Evaluate ``model`` on ``val_loader`` and report per-class IoU.

    IoU is computed per image with ``Metrics.compute_miou`` and averaged
    over all images. Ground-truth pixels are also tallied per class.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Loader yielding ``(images, masks)`` batches.

    Returns:
        Dict with the keys ``overall_miou``, ``per_class_ious``,
        ``best_classes`` and ``worst_classes`` (up to three
        ``(name, iou)`` pairs each) and ``class_pixel_counts``.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()

    num_classes = Config.NUM_CLASSES
    all_ious = []
    class_pixel_counts = np.zeros(num_classes)

    print("Performing comprehensive evaluation...")

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Evaluating"):
            predictions = model(images.to(Config.DEVICE))

            # One host transfer per batch instead of one per sample. The
            # ground truth is only consumed as NumPy below, so it is not
            # copied to the compute device at all.
            preds_np = predictions.argmax(dim=1).cpu().numpy()
            masks_np = masks.cpu().numpy()

            for pred_np, mask_np in zip(preds_np, masks_np):
                _, ious = Metrics.compute_miou(
                    pred_np, mask_np,
                    num_classes,
                    Config.IGNORE_INDEX
                )
                all_ious.append(ious)

            # Count ground-truth pixels for every class in a single pass.
            # Labels outside [0, num_classes) (e.g. the ignore index) are
            # dropped first, matching the per-class equality test this
            # replaces.
            labels = masks_np.ravel()
            in_range = (labels >= 0) & (labels < num_classes)
            class_pixel_counts += np.bincount(
                labels[in_range].astype(np.int64),
                minlength=num_classes
            )

    if not all_ious:
        # Without this guard the mean over an empty list is a scalar NaN
        # and the per-class indexing below fails with an opaque IndexError.
        raise ValueError("Cannot evaluate: val_loader produced no samples")

    # Calculate statistics
    # NOTE(review): if Metrics.compute_miou reports NaN for classes absent
    # from an image, these means propagate NaN - confirm against Metrics.
    mean_ious = np.mean(all_ious, axis=0)
    overall_miou = np.mean(mean_ious)

    # Find best and worst performing classes (ascending sort, done once)
    ranking = np.argsort(mean_ious)
    best_classes = ranking[-3:][::-1]
    worst_classes = ranking[:3]

    results = {
        "overall_miou": float(overall_miou),
        "per_class_ious": mean_ious.tolist(),
        "best_classes": [(CLASS_NAMES[i], float(mean_ious[i])) for i in best_classes],
        "worst_classes": [(CLASS_NAMES[i], float(mean_ious[i])) for i in worst_classes],
        "class_pixel_counts": class_pixel_counts.tolist()
    }

    # Print results
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Overall mIoU: {overall_miou:.4f}")

    print("\nPer-Class IoU:")
    print("-"*40)
    for i, (name, iou) in enumerate(zip(CLASS_NAMES, mean_ious)):
        print(f"{i:2d} | {name:15s} | {iou:.3f}")

    print("\nBest performing classes:")
    for name, iou in results["best_classes"]:
        print(f"  - {name}: {iou:.3f}")

    print("\nClasses needing improvement:")
    for name, iou in results["worst_classes"]:
        print(f"  - {name}: {iou:.3f}")

    return results

## 17. External Image Testing

Function to test the trained model on external images downloaded from the internet to demonstrate real-world performance.

In [None]:
def test_external_images(model: nn.Module, device: torch.device,
                         test_images: Optional[List[Dict[str, str]]] = None):
    """Run the model on images downloaded from the internet and plot results.

    For every image one row of three panels is drawn (original, colored
    segmentation mask, overlay) and the classes covering more than 1% of
    the pixels are printed. The figure is saved to
    ``external_test_results.png`` inside ``Config.OUTPUT_DIR``.

    Args:
        model: Segmentation network; it is switched to eval mode.
        device: Device the input tensors are moved to.
        test_images: Optional list of dicts, each with a ``'url'`` to
            download and a ``'name'`` used in titles and messages. When
            omitted, two built-in sample images are used.
    """
    print("\n" + "="*60)
    print("TESTING ON EXTERNAL IMAGES")
    print("="*60)

    # Import requests for downloading images
    import requests
    from io import BytesIO

    if test_images is None:
        test_images = [
            {
                'url': 'https://images.unsplash.com/photo-1529626455594-4ff0802cfb7e?w=600',
                'name': 'fashion_model.jpg'
            },
            {
                'url': 'https://images.unsplash.com/photo-1506794778202-cad84cf45f1d?w=600',
                'name': 'man_portrait.jpg'
            }
        ]

    if not test_images:
        # A zero-row subplot grid is invalid, so stop before plotting.
        print("No external test images supplied; nothing to do.")
        return

    # Create preprocessing transform for external images
    transform = A.Compose([
        A.Resize(Config.INPUT_SIZE[0], Config.INPUT_SIZE[1]),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # One row per image, sized from the list itself. squeeze=False keeps
    # `axes` two-dimensional even when there is only a single image.
    num_rows = len(test_images)
    fig, axes = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows), squeeze=False)

    # Switching to eval mode is loop-invariant, so do it once up front.
    model.eval()

    for idx, img_info in enumerate(test_images):
        print(f"\nProcessing: {img_info['name']}")
        row = axes[idx]

        # Deliberately broad best-effort handling: a failed download or
        # inference for one image must not abort the remaining ones.
        try:
            # Download image
            response = requests.get(img_info['url'], timeout=10)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as downloaded:
                image_np = np.array(downloaded.convert('RGB'))

            # Preprocess image
            transformed = transform(image=image_np)
            input_tensor = transformed['image'].unsqueeze(0).to(device)

            # Run inference
            with torch.no_grad():
                output = model(input_tensor)
                pred_mask = output.argmax(dim=1)[0]

                # Resize prediction to original size; nearest-neighbour
                # keeps the values valid class ids.
                pred_mask = F.interpolate(
                    pred_mask.unsqueeze(0).unsqueeze(0).float(),
                    size=image_np.shape[:2],
                    mode='nearest'
                )[0, 0].long()

            # Convert to numpy
            pred_mask_np = pred_mask.cpu().numpy()

            # Create colored mask and overlay
            colored_mask = Visualizer.mask_to_color(pred_mask_np)
            overlay = cv2.addWeighted(image_np, 0.6, colored_mask, 0.4, 0)

            # Plot results
            panels = (
                (image_np, f'Original: {img_info["name"]}'),
                (colored_mask, 'Segmentation Mask'),
                (overlay, 'Overlay'),
            )
            for ax, (picture, title) in zip(row, panels):
                ax.imshow(picture)
                ax.set_title(title)
                ax.axis('off')

            # Print detected classes (single pass over the mask)
            class_ids, counts = np.unique(pred_mask_np, return_counts=True)
            print("  Detected classes:")
            for class_id, count in zip(class_ids, counts):
                if class_id < len(CLASS_NAMES):
                    pixel_percentage = count / pred_mask_np.size * 100
                    if pixel_percentage > 1.0:  # Only show classes with >1% pixels
                        print(f"    - {CLASS_NAMES[class_id]}: {pixel_percentage:.1f}%")

        except Exception as e:
            print(f"  Failed to process {img_info['name']}: {e}")
            # Show error message in plot; clear first so the text is not
            # drawn on top of panels rendered before the failure.
            for ax in row:
                ax.clear()
                ax.text(0.5, 0.5, f'Failed to load\n{img_info["name"]}',
                        ha='center', va='center', fontsize=12)
                ax.axis('off')

    # Make sure the target directory exists when called on its own.
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(Config.OUTPUT_DIR, 'external_test_results.png')

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"\nExternal test results saved to: {output_path}")

## 18. Main Execution Pipeline

Complete training and evaluation pipeline that orchestrates all components together.

In [None]:
def main():
    """Run the complete pipeline: train, evaluate, visualize, save config.

    Steps, in order: seed the RNGs and create output directories, build the
    data loaders and the model, train, reload the checkpoint stored at
    ``Config.MODEL_PATH``, evaluate it on the validation loader, render
    prediction figures, and write the JSON configuration including the
    evaluation results.
    """
    # Setup
    setup_seed(Config.SEED)
    Config.create_directories()

    banner = "=" * 60

    print(banner)
    print("HUMAN PARSING MODEL TRAINING")
    print(banner)
    print(f"Device: {Config.DEVICE}")
    print(f"Dataset: {Config.DATASET_NAME}")
    print(f"Batch Size: {Config.BATCH_SIZE}")
    print(f"Epochs: {Config.EPOCHS}")
    print(f"Input Size: {Config.INPUT_SIZE}")
    print(banner)

    # Create data loaders
    train_loader, val_loader = create_data_loaders()

    # Initialize model
    model = HumanParsingNet(num_classes=Config.NUM_CLASSES).to(Config.DEVICE)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model Parameters: {total_params:,} total, {trainable_params:,} trainable")
    # Size estimate at 4 bytes per parameter
    print(f"Model Size: ~{total_params * 4 / 1024 / 1024:.1f} MB")

    # Initialize trainer
    trainer = Trainer(model, train_loader, val_loader, Config)

    # Train model
    best_miou = trainer.train()

    # Load best model for final evaluation. Accept both a wrapped
    # checkpoint ({'model': state_dict, ...}) and a bare state dict, the
    # same way the inference loader in this file does.
    checkpoint = torch.load(Config.MODEL_PATH, map_location=Config.DEVICE)
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    # Comprehensive evaluation
    evaluation_results = evaluate_model(model, val_loader)

    # Visualize predictions on validation data
    Visualizer.visualize_predictions(
        model, val_loader, Config.DEVICE,
        num_samples=6,
        save_path=os.path.join(Config.OUTPUT_DIR, "predictions.png")
    )

    # Test on external images
    test_external_images(model, Config.DEVICE)

    # Save configuration
    save_model_config(model, best_miou)

    # Update config with evaluation results: save_model_config has no
    # parameter for them, so the file is read back and rewritten.
    with open(Config.CONFIG_PATH, 'r') as f:
        config = json.load(f)

    config['evaluation_results'] = evaluation_results

    with open(Config.CONFIG_PATH, 'w') as f:
        json.dump(config, f, indent=2)

    print("\n" + banner)
    print("TRAINING COMPLETED SUCCESSFULLY")
    print(banner)
    print(f"Best Validation mIoU: {best_miou:.4f}")
    # This figure comes from evaluate_model on val_loader; there is no
    # separate test split, so it is labelled as a validation-set metric.
    print(f"Final Evaluation mIoU (validation set): {evaluation_results['overall_miou']:.4f}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()