# Face Reconstruction with EdgeConnect

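This notebook demonstrates face reconstruction in the spirit of EdgeConnect: filling in masked regions of a face, such as a white mask painted over the nose. If pre-trained EdgeConnect checkpoints are available, it tries to load the real model. However, the reconstruction examples below use a simplified OpenCV-based implementation. A synthetic sample image is generated, so the notebook runs without any input of your own.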

In [None]:
import os
import sys
import numpy as np
import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import yaml
from skimage import feature
import torchvision.transforms as transforms

# Add the src directory to Python path (only once, so re-running this cell
# does not keep appending duplicates)
if './src' not in sys.path:
    sys.path.append('./src')

# Import EdgeConnect modules. They are optional: when ./src is missing the
# notebook must still run with the OpenCV-based SimpleEdgeConnect, so an
# import failure is reported instead of aborting the whole notebook.
try:
    from edge_connect import EdgeConnect
    from utils import create_mask
except ImportError as exc:
    EdgeConnect = None
    create_mask = None
    print(f"EdgeConnect modules not importable ({exc}); only the simple implementation will be available.")

## 1. Load and Prepare the Configuration

In [None]:
# Configuration dictionary for face reconstruction. The integer codes are
# explained inline; of these settings, SimpleEdgeConnect only reads SIGMA and
# INPUT_SIZE.
config = dict(
    # Run mode and model selection
    MODE=1,    # 1: train, 2: test, 3: eval
    MODEL=3,   # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
    MASK=3,    # 1: random block, 2: half, 3: external, 4: external + random block, 5: external + random block + half
    EDGE=1,    # 1: canny, 2: external
    NMS=1,     # 0: no non-max-suppression, 1: non-max-suppression on the external edges
    SEED=10,
    GPU=[0],
    DEBUG=0,
    VERBOSE=1,

    # Training configurations
    LR=0.0001,
    D2G_LR=0.1,
    BETA1=0.0,
    BETA2=0.9,
    BATCH_SIZE=8,
    INPUT_SIZE=256,
    SIGMA=2,
    MAX_ITERS=2000000,
    EDGE_THRESHOLD=0.5,
    L1_LOSS_WEIGHT=1,
    FM_LOSS_WEIGHT=10,
    STYLE_LOSS_WEIGHT=1,
    CONTENT_LOSS_WEIGHT=1,
    INPAINT_ADV_LOSS_WEIGHT=0.01,
    GAN_LOSS='nsgan',
    GAN_POOL_SIZE=0,
    SAVE_INTERVAL=1000,
    EVAL_INTERVAL=0,
    LOG_INTERVAL=10,
    SAMPLE_INTERVAL=1000,
    SAMPLE_SIZE=12,

    # Paths: PATH is the checkpoint root; the file lists are left empty
    PATH='./checkpoints',
    TRAIN_FLIST='',
    VAL_FLIST='',
    TEST_FLIST='',
    TRAIN_EDGE_FLIST='',
    VAL_EDGE_FLIST='',
    TEST_EDGE_FLIST='',
    TRAIN_MASK_FLIST='',
    VAL_MASK_FLIST='',
    TEST_MASK_FLIST='',
)

print("Configuration loaded successfully!")

## 2. Utility Functions

In [None]:
def load_image(image_path, size=None):
    """Read an image file and return it as an RGB NumPy array.

    Args:
        image_path: path to the image file.
        size: optional (width, height) tuple; when truthy the image is
            resized to it with Lanczos resampling.

    Returns:
        np.ndarray with three colour channels (RGB).
    """
    rgb = Image.open(image_path).convert('RGB')
    resized = rgb.resize(size, Image.LANCZOS) if size else rgb
    return np.array(resized)

def create_mask_from_white_regions(image, threshold=240, dilate_iterations=0):
    """Build an inpainting mask from the near-white pixels of an image.

    Args:
        image: uint8 image, either RGB (H, W, 3) or already grayscale (H, W).
        threshold: grayscale level; pixels strictly brighter than this are
            marked as mask.
        dilate_iterations: optional number of 3x3 dilations applied after the
            clean-up, to grow the mask over soft or compression-blurred mask
            borders. The default of 0 leaves the cleaned mask unchanged.

    Returns:
        uint8 array (H, W) with 255 inside the mask and 0 elsewhere.
    """
    # Grayscale input is used as-is (like canny_edge_detection does);
    # cv2.cvtColor would reject a single-channel image.
    if len(image.shape) == 2:
        gray = image
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Create mask where white regions are
    mask = (gray > threshold).astype(np.uint8) * 255

    # Closing fills small holes inside the white region; opening then removes
    # small isolated bright specks outside it.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    if dilate_iterations > 0:
        mask = cv2.dilate(mask, kernel, iterations=dilate_iterations)

    return mask

def to_tensor(image):
    """Convert a 0..255 image array to a float32 batch tensor scaled to [0, 1].

    Args:
        image: array of shape (H, W, C) or (H, W) with values in 0..255.

    Returns:
        torch.float32 tensor of shape (1, C, H, W). A 2-D (grayscale / mask)
        input gets a singleton channel, i.e. (1, 1, H, W), so every result is
        a 4-D batch rather than 2-D inputs producing an ambiguous (1, H, W).
    """
    array = np.asarray(image)
    if array.ndim == 3:
        array = array.transpose(2, 0, 1)   # HWC -> CHW
    elif array.ndim == 2:
        array = array[np.newaxis]          # HW -> 1HW (single channel)
    return torch.from_numpy(array.astype(np.float32) / 255.0).unsqueeze(0)

def from_tensor(tensor):
    """Convert a [0, 1] image tensor back to a uint8 NumPy image.

    Accepts (1, C, H, W), (C, H, W) or 2-D tensors. Values are scaled by 255,
    rounded to the nearest integer and clipped to 0..255. A single-channel
    result is returned as (H, W) so it can be passed straight to imshow or
    PIL (both reject (H, W, 1)).
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    image = tensor.detach().cpu().numpy()
    if image.ndim == 3:
        image = image.transpose(1, 2, 0)   # CHW -> HWC
        if image.shape[2] == 1:
            image = image[:, :, 0]
    # Round instead of truncating: float32 x/255*255 can land a hair below x,
    # which a bare astype would turn into x-1; out-of-range values would wrap.
    return np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)

def canny_edge_detection(image, sigma=2, low_threshold=0.1, high_threshold=0.2):
    """Compute a Canny edge map for an RGB or grayscale image.

    The image is first Gaussian-blurred with ``sigma`` (kernel size derived
    from sigma); the two hysteresis thresholds are fractions of 255.

    Returns:
        Single-channel 8-bit edge map produced by cv2.Canny.
    """
    gray = image if len(image.shape) != 3 else cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(gray, (0, 0), sigma)
    lower = int(low_threshold * 255)
    upper = int(high_threshold * 255)
    return cv2.Canny(smoothed, lower, upper)

print("Utility functions defined successfully!")

## 3. Simple EdgeConnect Implementation

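Since pre-trained models may not be available, we start with a simple version built on traditional OpenCV inpainting: it blends the results of the Telea (fast-marching) and Navier-Stokes methods. Canny edges are extracted and plotted for inspection, but in this simplified version they do not drive the inpainting itself.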

In [None]:
class SimpleEdgeConnect:
    """OpenCV-only stand-in for EdgeConnect, for use without pre-trained models.

    The inpainted result is a weighted blend of OpenCV's Telea (fast-marching)
    and Navier-Stokes inpainting. Canny edges are computed and returned for
    visualization; they are not fed into the inpainting itself.

    Args:
        config: dict-like configuration; this class reads ``config['SIGMA']``
            (Canny blur sigma) and ``config['INPUT_SIZE']`` (square resize).
    """

    def __init__(self, config):
        self.config = config

    def edge_guided_inpainting(self, image, mask, blend_weight=0.3, inpaint_radius=3):
        """Inpaint the masked region of ``image`` and return edge maps for display.

        Args:
            image: 8-bit image as produced by ``load_image`` (RGB).
            mask: 8-bit single-channel array, non-zero where pixels must be
                reconstructed (as produced by create_mask_from_white_regions).
            blend_weight: weight of the Navier-Stokes result in the final
                blend; the Telea result gets ``1 - blend_weight``.
            inpaint_radius: neighbourhood radius passed to ``cv2.inpaint``.

        Returns:
            (result, edges, edges_dilated): the blended uint8 image, the Canny
            edges of the full input, and the masked-then-dilated edges.
        """
        # Edges of the whole image (this includes the border of the white mask).
        edges = canny_edge_detection(image, sigma=self.config['SIGMA'])

        # Drop edges inside the hole, then dilate so the surviving edges reach
        # a couple of pixels into the masked region (used for display only).
        edges_masked = edges.copy()
        edges_masked[mask > 0] = 0
        kernel = np.ones((3, 3), np.uint8)
        edges_dilated = cv2.dilate(edges_masked, kernel, iterations=2)

        # cv2.INPAINT_TELEA is Telea's fast-marching method; cv2.INPAINT_NS is
        # the Navier-Stokes based method. cv2.inpaint does not modify `mask`.
        inpainted_telea = cv2.inpaint(image, mask, inpaint_radius, cv2.INPAINT_TELEA)
        inpainted_ns = cv2.inpaint(image, mask, inpaint_radius, cv2.INPAINT_NS)

        # Blend in float, then round and clip. A bare astype(np.uint8) truncates:
        # 0.7*3 + 0.3*3 == 2.9999999999999996 would become 2, darkening pixels
        # that neither method changed.
        blended = (1 - blend_weight) * inpainted_telea + blend_weight * inpainted_ns
        final_result = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

        return final_result, edges, edges_dilated

    def process_image(self, image_path, output_path=None):
        """Load an image, detect its white mask, inpaint it and plot every stage.

        Args:
            image_path: path to the input image; it is resized to
                INPUT_SIZE x INPUT_SIZE before processing.
            output_path: optional path where the inpainted result is saved.

        Returns:
            (result, mask, edges) from the detection and inpainting steps.
        """
        side = self.config['INPUT_SIZE']
        image = load_image(image_path, size=(side, side))
        mask = create_mask_from_white_regions(image)
        result, edges, edges_dilated = self.edge_guided_inpainting(image, mask)

        # (picture, title, colormap) in row-major panel order of the 2x3 grid
        panels = [
            (image, 'Original Image with Mask', None),
            (mask, 'Detected Mask', 'gray'),
            (edges, 'Canny Edges', 'gray'),
            (edges_dilated, 'Dilated Edges', 'gray'),
            (result, 'Inpainted Result', None),
            (np.hstack([image, result]), 'Before vs After', None),
        ]
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        for ax, (picture, title, cmap) in zip(axes.ravel(), panels):
            ax.imshow(picture, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        plt.show()

        if output_path:
            Image.fromarray(result).save(output_path)
            print(f"Result saved to {output_path}")

        return result, mask, edges

print("SimpleEdgeConnect class defined successfully!")

## 4. Advanced EdgeConnect with Neural Networks

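Next, we try to load the actual EdgeConnect architecture if pre-trained checkpoints are found under `./checkpoints`. The reconstruction cells further down still use the simple implementation; this cell only reports whether a real model could be loaded.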

In [None]:
def check_pretrained_models():
    """List the pre-trained model folders present under ./checkpoints.

    Only the existence of each path is checked, not its contents. Names are
    returned in the fixed order celeba, places2, paris.
    """
    candidates = (
        ('celeba', './checkpoints/celeba'),
        ('places2', './checkpoints/places2'),
        ('paris', './checkpoints/paris'),
    )
    return [name for name, path in candidates if os.path.exists(path)]

def try_load_edge_connect(base_config=None):
    """Try to build and load the real EdgeConnect model from ./checkpoints.

    The first folder reported by check_pretrained_models() is used. Any
    failure is printed and turned into the simple-implementation fallback
    instead of being raised, so the notebook keeps running.

    Args:
        base_config: configuration dict to start from; defaults to the
            notebook-level ``config``. It is copied, never modified.

    Returns:
        (model, True) on success, otherwise (None, False).
    """
    if base_config is None:
        base_config = config
    try:
        available_models = check_pretrained_models()
        print(f"Available pre-trained models: {available_models}")

        if not available_models:
            print("No pre-trained models found. Using simple implementation.")
            return None, False

        if EdgeConnect is None:
            # The import cell could not load ./src; give a clear reason
            # instead of "'NoneType' object is not callable".
            raise ImportError("edge_connect could not be imported from ./src")

        model_path = f"./checkpoints/{available_models[0]}"
        # Point PATH at the chosen checkpoint folder; the shared config only
        # holds the parent './checkpoints', not an actual model directory.
        model_config = dict(base_config, PATH=model_path)

        # Save the effective config next to the checkpoints if none exists yet.
        config_path = os.path.join(model_path, 'config.yml')
        if not os.path.exists(config_path):
            with open(config_path, 'w') as f:
                yaml.dump(model_config, f)

        # NOTE(review): EdgeConnect is constructed from a plain dict here;
        # confirm the ./src implementation accepts a dict, not a Config object.
        model = EdgeConnect(model_config)
        model.load()
        return model, True
    except Exception as e:
        # Deliberate best-effort: any loading problem falls back to the
        # OpenCV implementation rather than stopping the notebook.
        print(f"Error loading EdgeConnect: {e}")
        print("Falling back to simple implementation.")
        return None, False

# Attempt to use the real EdgeConnect network; otherwise report the fallback
edge_connect_model, model_loaded = try_load_edge_connect()
load_message = (
    "EdgeConnect model loaded successfully!"
    if model_loaded
    else "Using simple EdgeConnect implementation."
)
print(load_message)

## 5. Create Sample Input Image

Let's create a sample image with a white mask for testing if you don't have the original image.

In [None]:
def create_sample_face_with_mask():
    """Draw a synthetic 256x256 face with a white quadrilateral over the nose.

    The background is light grey, the face a filled ellipse with two dark
    eyes and a mouth arc; the white patch is the region to be inpainted.

    Returns:
        uint8 array of shape (256, 256, 3).
    """
    side = 256
    canvas = np.full((side, side, 3), 200, dtype=np.uint8)  # light background
    cx = cy = side // 2

    # Filled face ellipse
    cv2.ellipse(canvas, (cx, cy), (80, 100), 0, 0, 360, (180, 150, 120), -1)

    # Left eye, then right eye
    for dx in (-25, 25):
        cv2.circle(canvas, (cx + dx, cy - 20), 8, (50, 50, 50), -1)

    # Mouth: half-ellipse outline, 2 px thick
    cv2.ellipse(canvas, (cx, cy + 30), (15, 8), 0, 0, 180, (100, 50, 50), 2)

    # Pure-white nose patch that plays the role of the mask
    nose = np.array(
        [[cx - 10, cy - 5], [cx + 10, cy - 5], [cx + 8, cy + 15], [cx - 8, cy + 15]],
        np.int32,
    )
    cv2.fillPoly(canvas, [nose], (255, 255, 255))

    return canvas

# Create and save the sample image. PNG is lossless; JPEG compression adds
# artifacts around the sharp edges of the white patch, which can leave
# near-white fringe pixels outside the threshold-based mask.
sample_image = create_sample_face_with_mask()
sample_path = './input_images/sample_face_with_mask.png'
# A fresh checkout may not have the input folder yet, and PIL will not create it.
os.makedirs(os.path.dirname(sample_path), exist_ok=True)
Image.fromarray(sample_image).save(sample_path)

# Display the sample image
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(sample_image)
ax.set_title('Sample Face with White Nose Mask')
ax.axis('off')
plt.show()

print(f"Sample image created and saved to {sample_path}")

## 6. Test Face Reconstruction

In [None]:
# Run the OpenCV-based reconstruction on the generated sample image
simple_model = SimpleEdgeConnect(config)

print("Testing face reconstruction with sample image...")
sample_output = './output_reconstructed.jpg'
result, mask, edges = simple_model.process_image(sample_path, output_path=sample_output)

print("Face reconstruction completed!")

## 7. Process Your Own Image

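To process your own image with a white mask, save it to the `input_images` folder and update the path below. Pixels brighter than the detection threshold (grayscale value above 240) mark the region to reconstruct. Prefer a lossless format such as PNG, since JPEG compression blurs the mask boundary.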

In [None]:
# Process your own image: replace 'your_image.jpg' with the actual filename
your_image_path = './input_images/your_image.jpg'

if not os.path.exists(your_image_path):
    print(f"Image not found at {your_image_path}")
    print("Please save your image to the input_images folder and update the path above.")
else:
    print("Processing your image...")
    result, mask, edges = simple_model.process_image(
        your_image_path,
        output_path='./your_image_reconstructed.jpg',
    )
    print("Your image reconstruction completed!")

## 8. Advanced Configuration and Tips

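For better results, it helps to tune the mask-detection threshold and to compare OpenCV's two inpainting algorithms, Telea (fast marching) and Navier-Stokes, side by side. The function below does both: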

In [None]:
def advanced_inpainting(image_path, mask_threshold=240, edge_sigma=2, size=(256, 256), inpaint_radius=3):
    """Compare OpenCV's two inpainting algorithms on one image.

    Args:
        image_path: path to the input image.
        mask_threshold: grayscale level above which pixels count as mask.
        edge_sigma: Gaussian sigma for the Canny edge map shown alongside the
            results (previously accepted but ignored).
        size: (width, height) the image is resized to before processing.
        inpaint_radius: neighbourhood radius passed to ``cv2.inpaint``.

    Returns:
        dict mapping method name ('Telea', 'Navier-Stokes') to the inpainted
        uint8 image.
    """
    image = load_image(image_path, size=size)
    mask = create_mask_from_white_regions(image, threshold=mask_threshold)
    edges = canny_edge_detection(image, sigma=edge_sigma)

    # cv2.INPAINT_TELEA is Telea's fast-marching method; cv2.INPAINT_NS is the
    # Navier-Stokes based method (not a fast-marching method).
    methods = {
        'Telea': cv2.INPAINT_TELEA,
        'Navier-Stokes': cv2.INPAINT_NS,
    }
    results = {
        name: cv2.inpaint(image, mask, inpaint_radius, flag)
        for name, flag in methods.items()
    }

    # (picture, title, colormap) in row-major order; the sixth slot stays empty
    panels = [
        (image, 'Original with Mask', None),
        (mask, 'Detected Mask', 'gray'),
        (edges, f'Canny Edges (sigma={edge_sigma})', 'gray'),
        (results['Telea'], 'Telea Method', None),
        (results['Navier-Stokes'], 'Navier-Stokes Method', None),
    ]
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.ravel()
    for ax, (picture, title, cmap) in zip(axes, panels):
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return results

# Compare the OpenCV inpainting methods on the sample image (default parameters)
print("Testing advanced inpainting methods...")
advanced_results = advanced_inpainting(image_path=sample_path)
print("Advanced inpainting completed!")

## 9. Summary and Next Steps

This notebook demonstrates:

1. **Basic EdgeConnect Implementation**: A simplified stand-in that blends two traditional OpenCV inpainting results, with Canny edges shown alongside
2. **Mask Detection**: Automatic detection of white regions to create inpainting masks
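3. **Edge-Guided Inpainting**: Canny edge detection extracts the structure around the mask. In this simplified version the edges are visualized rather than fed into the inpainting; the full EdgeConnect model predicts edges inside the hole and uses them to guide completion.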
4. **Multiple Inpainting Methods**: Comparison of OpenCV's Telea (fast marching) and Navier-Stokes inpainting algorithms

### For Better Results:

1. **Download Pre-trained Models**: Use the official EdgeConnect pre-trained models for state-of-the-art results
2. **Fine-tune Parameters**: Adjust edge detection and inpainting parameters based on your specific images
3. **Use GPU**: Enable GPU acceleration for faster processing with the full EdgeConnect model
4. **Custom Training**: Train the model on your specific dataset for domain-specific improvements

### Usage:

1. Save your image with white mask to `./input_images/`
2. Update the path in the "Process Your Own Image" section
3. Run the cells to see the reconstruction results

The results will show the original image, detected mask, edge information, and the final reconstructed image.