In [1]:
"""
Paper-Accurate Style Transfer Implementation
Uses exact parameters from the paper:
- r = 0.8 for robust fusion
- gaps = [28, 18, 8, 5] for subsampling
- IRLS for patch aggregation
"""

import numpy as np
import cv2
from scipy.ndimage import gaussian_filter
from sklearn.neighbors import NearestNeighbors
from typing import Tuple, Optional
import warnings
warnings.filterwarnings('ignore')


In [2]:
class PaperAccurateStyleTransfer:
    """
    Implements style transfer with exact paper parameters.
    
    Key parameters from paper:
    - patch_sizes = [33, 21, 13, 9]
    - gaps = [28, 18, 8, 5] (subsampling stride)
    - r = 0.8 (robust norm for IRLS)
    - IRLS iterations = 10
    - EM iterations per patch = 3
    - Pyramid levels = 3
    """
    
    def __init__(self,
                 patch_sizes=(33, 21, 13, 9),
                 gaps=(28, 18, 8, 5),
                 r_robust=0.8,
                 irls_iterations=10,
                 em_iterations_per_patch=3,
                 num_levels=3):
        """
        Initialize with exact paper parameters.
        
        Args:
            patch_sizes: Patch sizes from large to small
            gaps: Subsampling gaps (stride) for each patch size
            r_robust: Robust norm parameter (0.8 in paper)
            irls_iterations: IRLS inner iterations (10 in paper)
            em_iterations_per_patch: EM iterations per patch size (3 in paper)
            num_levels: Pyramid levels (3 in paper)
        """
        self.patch_sizes = patch_sizes
        self.gaps = gaps
        self.r = r_robust
        self.irls_iterations = irls_iterations
        self.em_iterations = em_iterations_per_patch
        self.num_levels = num_levels
        
        assert len(patch_sizes) == len(gaps), "Must have one gap per patch size"
    
    def load_image(self, path):
        """Load and normalize image"""
        img = cv2.imread(path)
        if img is None:
            raise ValueError(f"Could not load: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img.astype(np.float32) / 255.0
    
    def save_image(self, path, img):
        """Save image"""
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(path, img_bgr)
    
    def build_pyramid(self, image, num_levels):
        """Build Gaussian pyramid"""
        pyramid = [image]
        for _ in range(num_levels - 1):
            image = cv2.pyrDown(image)
            pyramid.append(image)
        return pyramid[::-1]  # Coarse to fine
    
    def color_transfer(self, content, style):
        """Histogram matching color transfer"""
        matched = np.zeros_like(content)
        
        for c in range(3):
            src_channel = content[:, :, c].flatten()
            tmp_channel = style[:, :, c].flatten()
            
            src_values, src_counts = np.unique(src_channel, return_counts=True)
            tmp_values, tmp_counts = np.unique(tmp_channel, return_counts=True)
            
            src_quantiles = np.cumsum(src_counts).astype(np.float64)
            src_quantiles /= src_quantiles[-1]
            
            tmp_quantiles = np.cumsum(tmp_counts).astype(np.float64)
            tmp_quantiles /= tmp_quantiles[-1]
            
            interp_tmp_values = np.interp(src_quantiles, tmp_quantiles, tmp_values)
            matched_channel = np.interp(src_channel, src_values, interp_tmp_values)
            matched[:, :, c] = matched_channel.reshape(content.shape[:2])
        
        return matched
    
    def extract_patches_with_gap(self, image, patch_size, gap):
        """
        Extract patches with specified gap (subsampling).
        
        Args:
            image: Input image
            patch_size: Size of square patches
            gap: Stride between patches (subsampling)
            
        Returns:
            patches: Flattened patches
            positions: (i, j) positions of top-left corners
        """
        h, w = image.shape[:2]
        patches = []
        positions = []
        
        # Extract on grid with spacing = gap
        for i in range(0, h - patch_size + 1, gap):
            for j in range(0, w - patch_size + 1, gap):
                patch = image[i:i+patch_size, j:j+patch_size]
                patches.append(patch.flatten())
                positions.append((i, j))
        
        return np.array(patches), positions
    
    def find_nearest_neighbors(self, content_patches, style_patches):
        """Find nearest neighbor in style for each content patch"""
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(style_patches)
        distances, indices = nbrs.kneighbors(content_patches)
        return indices.flatten()
    
    def robust_aggregate_IRLS(self, matched_patches, positions, output_shape, 
                             patch_size, content=None, W=None):
        """
        IRLS patch aggregation with robust norm r=0.8.
        
        This implements Equations (8-11) from the paper.
        
        Args:
            matched_patches: Matched style patches (flattened)
            positions: Patch positions
            output_shape: (h, w, c)
            patch_size: Size of patches
            content: Content image for blending
            W: Segmentation weight map
        """
        h, w, c = output_shape
        
        # Initialize X
        if content is not None:
            X = content.copy()
        else:
            X = np.random.rand(h, w, c).astype(np.float32) * 0.1
        
        # IRLS iterations (Equation 9-11)
        for irls_iter in range(self.irls_iterations):
            # Compute weights w_ij = ||R_ij X - z_ij||^(r-2)
            weights = []
            
            for patch_flat, (i, j) in zip(matched_patches, positions):
                patch = patch_flat.reshape(patch_size, patch_size, c)
                
                h_end = min(i + patch_size, h)
                w_end = min(j + patch_size, w)
                patch_h = h_end - i
                patch_w = w_end - j
                
                # Current region in X
                current_region = X[i:h_end, j:w_end]
                
                # Compute error
                error = np.linalg.norm(current_region - patch[:patch_h, :patch_w])
                
                # Weight: w_ij = error^(r-2)
                # For r=0.8: w_ij = error^(-1.2)
                weight = np.power(error + 1e-8, self.r - 2)
                weights.append(weight)
            
            weights = np.array(weights)
            
            # Weighted aggregation (Equation 11)
            numerator = np.zeros((h, w, c), dtype=np.float32)
            denominator = np.zeros((h, w, 1), dtype=np.float32)
            
            for weight, patch_flat, (i, j) in zip(weights, matched_patches, positions):
                patch = patch_flat.reshape(patch_size, patch_size, c)
                
                h_end = min(i + patch_size, h)
                w_end = min(j + patch_size, w)
                patch_h = h_end - i
                patch_w = w_end - j
                
                numerator[i:h_end, j:w_end] += weight * patch[:patch_h, :patch_w]
                denominator[i:h_end, j:w_end] += weight
            
            denominator = np.maximum(denominator, 1e-8)
            X_tilde = numerator / denominator
            
            # Content fusion (Equation 15)
            if content is not None and W is not None:
                # Ensure W has right shape
                if W.ndim == 2:
                    W = W[:, :, np.newaxis]
                if W.shape[:2] != (h, w):
                    W = cv2.resize(W, (w, h))
                    W = W[:, :, np.newaxis] if W.ndim == 2 else W
                
                # X = (W + I)^(-1) (X_tilde + W*C)
                # Simplified: X = (X_tilde + W*content) / (1 + W)
                X = (X_tilde + W * content) / (1 + W)
            else:
                X = X_tilde
            
            X = np.clip(X, 0, 1)
        
        return X
    
    def create_edge_segmentation(self, content):
        """Create simple edge-based segmentation mask"""
        gray = cv2.cvtColor((content * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        kernel = np.ones((11, 11), np.uint8)
        mask = cv2.dilate(edges, kernel, iterations=1)
        mask = mask.astype(np.float32) / 255.0
        mask = gaussian_filter(mask, sigma=2.0)
        return mask
    
    def process_with_patch_size(self, content_level, style_level, patch_size, gap,
                               prev_result=None, W=None):
        """
        Process one pyramid level with one patch size.
        
        Args:
            content_level: Content at this pyramid level
            style_level: Style at this pyramid level  
            patch_size: Size of patches to use
            gap: Subsampling gap for this patch size
            prev_result: Previous result to build on
            W: Segmentation weight map
        """
        h, w = content_level.shape[:2]
        
        print(f"      Patch {patch_size}×{patch_size}, gap={gap}, overlap={(patch_size-gap)/patch_size*100:.0f}%")
        
        # Initialize
        if prev_result is None:
            # Paper initialization: content + strong noise (σ=50)
            X = content_level + np.random.randn(*content_level.shape) * (50.0/255.0)
            X = np.clip(X, 0, 1).astype(np.float32)
        else:
            # Upsample previous result
            if prev_result.shape[:2] != (h, w):
                X = cv2.resize(prev_result, (w, h))
            else:
                X = prev_result.copy()
        
        # Extract all style patches once
        style_patches, _ = self.extract_patches_with_gap(style_level, patch_size, gap)
        print(f"        Extracted {len(style_patches)} style patches")
        
        # EM iterations (3 per patch size in paper)
        for em_iter in range(self.em_iterations):
            print(f"        EM iteration {em_iter + 1}/{self.em_iterations}...")
            
            # E-step: Extract patches from current result and find matches
            X_patches, positions = self.extract_patches_with_gap(X, patch_size, gap)
            
            # Find nearest neighbors
            nn_indices = self.find_nearest_neighbors(X_patches, style_patches)
            matched_patches = style_patches[nn_indices]
            
            # M-step: IRLS aggregation with r=0.8
            X = self.robust_aggregate_IRLS(
                matched_patches, positions, (h, w, 3),
                patch_size, content_level, W
            )
            
            # Color transfer (after each EM iteration)
            X = self.color_transfer(X, style_level)
            
            # Denoise (Domain Transform in paper, we use Gaussian)
            X = gaussian_filter(X, sigma=0.5)
            X = np.clip(X, 0, 1)
        
        return X
    
    def transfer(self, content_path, style_path, output_path,
                use_segmentation=True, max_size=400):
        """
        Main style transfer with exact paper parameters.
        """
        print("="*70)
        print("PAPER-ACCURATE STYLE TRANSFER")
        print(f"Parameters: r={self.r}, gaps={self.gaps}")
        print("="*70)
        
        # Load images
        print("\n1. Loading images...")
        content = self.load_image(content_path)
        style = self.load_image(style_path)
        
        # Resize to max_size (paper uses 400×400)
        # h, w = content.shape[:2]
        # if max(h, w) > max_size:
        #     scale = max_size / max(h, w)
        #     new_h, new_w = int(h * scale), int(w * scale)
        #     content = cv2.resize(content, (new_w, new_h))
        #     print(f"   Content resized to: {content.shape}")
        
        # h, w = style.shape[:2]
        # if max(h, w) > max_size:
        #     scale = max_size / max(h, w)
        #     new_h, new_w = int(h * scale), int(w * scale)
        #     style = cv2.resize(style, (new_w, new_h))
        #     print(f"   Style resized to: {style.shape}")
        content = cv2.resize(content, (max_size, max_size), interpolation=cv2.INTER_AREA)
        style = cv2.resize(style, (max_size, max_size), interpolation=cv2.INTER_AREA)
        
        # Create segmentation mask
        W = None
        if use_segmentation:
            print("\n2. Creating segmentation mask...")
            W = self.create_edge_segmentation(content)
        
        # Initial color transfer
        print("\n3. Initial color transfer...")
        content_colored = self.color_transfer(content, style)
        
        # Build pyramids
        print(f"\n4. Building pyramids ({self.num_levels} levels)...")
        content_pyramid = self.build_pyramid(content_colored, self.num_levels)
        style_pyramid = self.build_pyramid(style, self.num_levels)
        
        for i, (c, s) in enumerate(zip(content_pyramid, style_pyramid)):
            print(f"   Level {i}: content {c.shape}, style {s.shape}")
        
        # Process pyramid levels
        result = None
        
        for level in range(self.num_levels):
            print(f"\n5. PYRAMID LEVEL {level + 1}/{self.num_levels}")
            content_level = content_pyramid[level]
            style_level = style_pyramid[level]
            
            # Resize segmentation for this level
            W_level = None
            if W is not None:
                h_level, w_level = content_level.shape[:2]
                W_level = cv2.resize(W, (w_level, h_level))
            
            # Process each patch size with its corresponding gap
            for patch_size, gap in zip(self.patch_sizes, self.gaps):
                # Skip if patch too large for this level
                min_dim = min(content_level.shape[:2])
                if patch_size >= min_dim:
                    print(f"    Skipping patch {patch_size} (too large for {min_dim})")
                    continue
                
                result = self.process_with_patch_size(
                    content_level, style_level,
                    patch_size, gap,
                    prev_result=result,
                    W=W_level
                )
        
        # Final color transfer
        print("\n6. Final color transfer...")
        result = self.color_transfer(result, style)
        
        # Save
        print(f"\n7. Saving to {output_path}...")
        self.save_image(output_path, result)
        
        print("\n" + "="*70)
        print("COMPLETE!")
        print("="*70)
        
        return result

In [8]:
def main():
     
    patch_sizes=(33, 21, 13, 9)
    gaps=(28,18,8,5)
    
    content_path = "./Data/content/house3.jpg"
    style_path = "./Data/style/starry_night.jpg"
    output_path = f"V2_results/res{patch_sizes} - {gaps}.jpg"
    
    max_size = 400
    
    # Create transfer object with exact paper parameters
    st = PaperAccurateStyleTransfer(
        patch_sizes ,
        gaps,
        r_robust=0.8,
        irls_iterations=3,
        em_iterations_per_patch=3,
        num_levels=3
    )
    
    # Run transfer
    result = st.transfer(content_path, style_path, output_path, max_size=max_size)


main()

PAPER-ACCURATE STYLE TRANSFER
Parameters: r=0.8, gaps=(28, 18, 8, 5)

1. Loading images...

2. Creating segmentation mask...

3. Initial color transfer...

4. Building pyramids (3 levels)...
   Level 0: content (100, 100, 3), style (100, 100, 3)
   Level 1: content (200, 200, 3), style (200, 200, 3)
   Level 2: content (400, 400, 3), style (400, 400, 3)

5. PYRAMID LEVEL 1/3
      Patch 33×33, gap=28, overlap=15%
        Extracted 9 style patches
        EM iteration 1/3...
        EM iteration 2/3...
        EM iteration 3/3...
      Patch 21×21, gap=18, overlap=14%
        Extracted 25 style patches
        EM iteration 1/3...
        EM iteration 2/3...
        EM iteration 3/3...
      Patch 13×13, gap=8, overlap=38%
        Extracted 121 style patches
        EM iteration 1/3...
        EM iteration 2/3...
        EM iteration 3/3...
      Patch 9×9, gap=5, overlap=44%
        Extracted 361 style patches
        EM iteration 1/3...
        EM iteration 2/3...
        EM iteration 