In [1]:
#| default_exp preprocessing.pt_patching

# Patch Whole imageto number of patches
> Patch whole image into number of patches

In [1]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
from fastcore.all import *

In [3]:
#| export
class ImageToPatchLayer(nn.Module):
    """Custom layer to convert images into patches.
    Maintains ONNX compatibility by using standard torch operations."""
    
    def __init__(self, patch_size=256, stride=None, padding_mode='reflect'):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride if stride is not None else patch_size
        self.padding_mode = padding_mode
        
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (B, C, H, W)
        Returns:
            Patches tensor of shape (B, N, C, patch_size, patch_size)
            where N is the number of patches
        """
        B, C, H, W = x.shape
        
        # Calculate padding
        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
        
        # Apply padding if needed
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
        
        # Unfold into patches using standard torch operations
        patches = F.unfold(x, 
                          kernel_size=(self.patch_size, self.patch_size),
                          stride=self.stride)
        
        # Reshape to (B, N, C, patch_size, patch_size)
        patches = patches.view(B, C, self.patch_size, self.patch_size, -1)
        patches = patches.permute(0, 4, 1, 2, 3)
        
        return patches

In [4]:
#| export
class PatchToImageLayer(nn.Module):
    """Custom layer to reconstruct image from patches.
    Maintains ONNX compatibility by using standard torch operations."""
    
    def __init__(self, output_size=None, patch_size=256, stride=None):
        super().__init__()
        self.output_size = output_size
        self.patch_size = patch_size
        self.stride = stride if stride is not None else patch_size
        
    def forward(self, x, original_size=None):
        """
        Args:
            x: Input tensor of patches (B, N, C, patch_size, patch_size)
            original_size: Optional tuple of (H, W) for output size
        Returns:
            Reconstructed image tensor of shape (B, C, H, W)
        """
        B, N, C, H, W = x.shape
        
        # Reshape patches for fold operation
        x = x.permute(0, 2, 3, 4, 1)
        x = x.reshape(B, C * H * W, N)
        
        # Calculate output size
        if original_size is not None:
            output_h, output_w = original_size
        elif self.output_size is not None:
            output_h, output_w = self.output_size
        else:
            # Calculate based on number of patches and stride
            output_h = int(math.sqrt(N)) * self.stride
            output_w = output_h
        
        # Use fold operation to reconstruct image
        output = F.fold(x,
                       output_size=(output_h, output_w),
                       kernel_size=(self.patch_size, self.patch_size),
                       stride=self.stride)
        
        return output

In [12]:
input_image = torch.randn(1, 1, 1152, 1632)  
patch_layer = ImageToPatchLayer(patch_size=256, stride=None, padding_mode='reflect')
patches = patch_layer(input_image)
print(patches.shape)
patch_to_image_layer = PatchToImageLayer(output_size=(1152, 1632), patch_size=256, stride=None)
reconstructed_image = patch_to_image_layer(patches)
print(reconstructed_image.shape)



torch.Size([1, 35, 1, 256, 256])


RuntimeError: Given output_size=(1152, 1632), kernel_size=(256, 256), dilation=(1, 1), padding=(0, 0), stride=(256, 256), expected size of input's dimension 2 to match the calculated number of sliding blocks 4 * 6 = 24, but got input.size(2)=35.

In [25]:
class RobustImageToPatchLayer(nn.Module):
    """
    Robust implementation of image to patch conversion that handles arbitrary image sizes.
    Ensures ONNX compatibility and precise patch creation.
    """
    def __init__(self, patch_size=256, stride=None, padding_mode='reflect'):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride if stride is not None else patch_size
        self.padding_mode = padding_mode

    def calculate_padding(self, height, width):
        """Calculate required padding for arbitrary image sizes."""
        if self.stride is None:
            self.stride = self.patch_size
            
        # Calculate how many patches we'll need
        n_patches_h = math.ceil((height - self.patch_size) / self.stride) + 1
        n_patches_w = math.ceil((width - self.patch_size) / self.stride) + 1
        
        # Calculate required image size to fit these patches
        required_h = (n_patches_h - 1) * self.stride + self.patch_size
        required_w = (n_patches_w - 1) * self.stride + self.patch_size
        
        # Calculate padding on each side
        pad_h = max(required_h - height, 0)
        pad_w = max(required_w - width, 0)
        
        # Split padding for both sides
        pad_h_before = pad_h // 2
        pad_h_after = pad_h - pad_h_before
        pad_w_before = pad_w // 2
        pad_w_after = pad_w - pad_w_before
        
        return pad_h_before, pad_h_after, pad_w_before, pad_w_after, n_patches_h, n_patches_w

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (B, C, H, W)
        Returns:
            Tuple of (patches, padding_info)
            - patches: tensor of shape (B, N, C, patch_size, patch_size)
            - padding_info: tuple containing padding information for reconstruction
        """
        B, C, H, W = x.shape
        
        # Calculate padding
        pad_h_before, pad_h_after, pad_w_before, pad_w_after, n_patches_h, n_patches_w = \
        	self.calculate_padding(H, W)
            
        # Store original size and padding for reconstruction
        padding_info = (H, W, pad_h_before, pad_h_after, pad_w_before, pad_w_after)
        
        # Apply padding
        x = F.pad(x, (pad_w_before, pad_w_after, pad_h_before, pad_h_after),
                	mode=self.padding_mode)
        
    	# Extract patches using unfold
        patches = x.unfold(2, self.patch_size, self.stride)\
                    .unfold(3, self.patch_size, self.stride)
        
    	# Reshape to (B, N, C, patch_size, patch_size)
        patches = patches.permute(0, 2, 3, 1, 4, 5)\
                        .reshape(B, -1, C, self.patch_size, self.patch_size)
        
        return patches, padding_info

In [26]:
robust_patch_layer = RobustImageToPatchLayer(patch_size=256, stride=None, padding_mode='reflect')
patches, padding_info = robust_patch_layer(input_image)
print(patches.shape)
print(padding_info)


torch.Size([1, 35, 1, 256, 256])
(1152, 1632, 64, 64, 80, 80)


In [27]:
class RobustPatchToImageLayer(nn.Module):
    """
    Robust implementation of patch to image reconstruction that handles arbitrary image sizes.
    Ensures ONNX compatibility and precise reconstruction.
    """
    def __init__(self, patch_size=256, stride=None):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride if stride is not None else patch_size
        
    def forward(self, patches, padding_info):
        """
        Args:
            patches: Tensor of shape (B, N, C, patch_size, patch_size)
            padding_info: Tuple containing (orig_h, orig_w, pad_h_before, pad_h_after, pad_w_before, pad_w_after)
        Returns:
            Reconstructed image tensor of shape (B, C, H, W)
        """
        orig_h, orig_w, pad_h_before, pad_h_after, pad_w_before, pad_w_after = padding_info
        B, N, C, H_patch, W_patch = patches.shape
        
        # Calculate padded image size
        H = orig_h + pad_h_before + pad_h_after
        W = orig_w + pad_w_before + pad_w_after
        
        # Calculate number of patches in each dimension
        n_patches_h = (H - self.patch_size) // self.stride + 1
        n_patches_w = (W - self.patch_size) // self.stride + 1
        
        # Reshape patches for reconstruction
        patches = patches.reshape(B, n_patches_h, n_patches_w, C, H_patch, W_patch)
        patches = patches.permute(0, 3, 1, 4, 2, 5)
        
        # Initialize output tensor
        output = torch.zeros(B, C, H, W, device=patches.device)
        count = torch.zeros(B, C, H, W, device=patches.device)
        
        # Reconstruct image
        for i in range(n_patches_h):
            for j in range(n_patches_w):
                h_start = i * self.stride
                w_start = j * self.stride
                output[:, :, h_start:h_start + H_patch, 
                      w_start:w_start + W_patch] += patches[:, :, i, :, j, :]
                count[:, :, h_start:h_start + H_patch,
                      w_start:w_start + W_patch] += 1
        
        # Average overlapping regions
        output = output / (count + 1e-8)
        
        # Remove padding to get back to original size
        output = output[:, :, 
                       pad_h_before:H - pad_h_after,
                       pad_w_before:W - pad_w_after]
        
        return output


In [28]:
robust_patch_to_image_layer = RobustPatchToImageLayer(patch_size=256, stride=None)
reconstructed_image = robust_patch_to_image_layer(patches, padding_info)
print(reconstructed_image.shape)


torch.Size([1, 1, 1152, 1632])


In [29]:
#| export
class RobustPatchProcessingNetwork(nn.Module):
    """Complete network with robust patch processing."""
    def __init__(self, base_model, patch_size=256, stride=None):
        super().__init__()
        self.patch_maker = RobustImageToPatchLayer(patch_size, stride)
        self.base_model = base_model
        self.patch_merger = RobustPatchToImageLayer(patch_size, stride)
    
    def forward(self, x):
        # Convert to patches
        patches, padding_info = self.patch_maker(x)
        
        # Process patches
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Reconstruct image
        output = self.patch_merger(processed_patches, padding_info)
        
        return output

In [30]:
#| export
class RobustPatchProcessingNetwork(nn.Module):
    """Complete network with robust patch processing."""
    def __init__(self, base_model, patch_size=256, stride=None):
        super().__init__()
        self.patch_maker = RobustImageToPatchLayer(patch_size, stride)
        self.base_model = base_model
        self.patch_merger = RobustPatchToImageLayer(patch_size, stride)
    
    def forward(self, x):
        # Convert to patches
        patches, padding_info = self.patch_maker(x)
        
        # Process patches
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Reconstruct image
        output = self.patch_merger(processed_patches, padding_info)
        
        return output

In [None]:
#| export
class OptimizedImageToPatchLayer(nn.Module):
    """
    Optimized patch conversion for 1152x1632 images with edge effect handling
    """
    def __init__(self, patch_size=256, overlap=32, input_size=(1152, 1632)):
        super().__init__()
        self.patch_size = patch_size
        self.overlap = overlap
        self.input_size = input_size
        
        # Pre-compute optimal grid for given input size
        self.stride = patch_size - overlap
        self.grid_h = (input_size[0] - overlap) // (patch_size - overlap)
        self.grid_w = (input_size[1] - overlap) // (patch_size - overlap)
        
        # Calculate exact padding needed
        self.total_h = (self.grid_h - 1) * (patch_size - overlap) + patch_size
        self.total_w = (self.grid_w - 1) * (patch_size - overlap) + patch_size
        
        self.pad_h = max(0, self.total_h - input_size[0])
        self.pad_w = max(0, self.total_w - input_size[1])
        
        # Create gaussian weight mask for edge effect reduction
        self.register_buffer('weight_mask', self._create_weight_mask())
        

In [3]:
class DynamicPatchOptimizer:
    """
    Utility class to calculate optimal patch configuration for any input size
    """
    @staticmethod
    def calculate_optimal_grid(image_size, patch_size, min_overlap=32):
        """
        Calculate optimal grid configuration for any image size
        Returns optimal stride and grid dimensions
        """
        H, W = image_size
        
        # Calculate number of patches needed with minimum overlap
        n_patches_h = math.ceil((H - min_overlap) / (patch_size - min_overlap))
        n_patches_w = math.ceil((W - min_overlap) / (patch_size - min_overlap))
        
        # Calculate optimal stride to evenly distribute patches
        stride_h = math.floor((H - patch_size) / (n_patches_h - 1)) if n_patches_h > 1 else H - patch_size
        stride_w = math.floor((W - patch_size) / (n_patches_w - 1)) if n_patches_w > 1 else W - patch_size
        
        # Calculate actual overlap
        overlap_h = patch_size - stride_h
        overlap_w = patch_size - stride_w
        
        return {
            'grid_h': n_patches_h,
            'grid_w': n_patches_w,
            'stride_h': stride_h,
            'stride_w': stride_w,
            'overlap_h': overlap_h,
            'overlap_w': overlap_w
        }

In [4]:
class GeneralizedPatchLayer(nn.Module):
    """
    Generalized patch conversion layer that works with any image size
    """
    def __init__(self, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_size = patch_size
        self.min_overlap = min_overlap
        
        # Create gaussian weight mask
        self.register_buffer('weight_mask', self._create_weight_mask())
        
    def _create_weight_mask(self):
        """Creates gaussian weight mask for edge effect reduction"""
        x = torch.linspace(-1, 1, self.patch_size)
        y = torch.linspace(-1, 1, self.patch_size)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        # Use smoother gaussian for better blending
        gaussian = torch.exp(-(xx**2 + yy**2) / 1.5)
        return gaussian
    
    def _calculate_padding(self, H, W, grid_config):
        """Calculate required padding for the input size"""
        total_h = (grid_config['grid_h'] - 1) * grid_config['stride_h'] + self.patch_size
        total_w = (grid_config['grid_w'] - 1) * grid_config['stride_w'] + self.patch_size
        
        pad_h = max(0, total_h - H)
        pad_w = max(0, total_w - W)
        
        return pad_h, pad_w

    def visualize_patch_grid(self, image_tensor, grid_config):
        """Visualizes patch grid overlay on image"""
        image = image_tensor[0].cpu().numpy().transpose(1, 2, 0)
        if image.shape[2] == 1:
            image = image.squeeze(-1)
            
        plt.figure(figsize=(15, 10))
        plt.imshow(image, cmap='gray' if len(image.shape) == 2 else None)
        
        # Draw patch grid with actual strides
        for i in range(grid_config['grid_h']):
            y = i * grid_config['stride_h']
            plt.axhline(y=y, color='r', linestyle='--', alpha=0.5)
            plt.axhline(y=y + self.patch_size, color='g', linestyle=':', alpha=0.3)
            
        for j in range(grid_config['grid_w']):
            x = j * grid_config['stride_w']
            plt.axvline(x=x, color='r', linestyle='--', alpha=0.5)
            plt.axvline(x=x + self.patch_size, color='g', linestyle=':', alpha=0.3)
            
        plt.title(f'Patch Grid ({grid_config["grid_h"]}x{grid_config["grid_w"]} patches)\n'
                 f'Patch Size: {self.patch_size}, '
                 f'Overlap H: {grid_config["overlap_h"]:.1f}, '
                 f'Overlap W: {grid_config["overlap_w"]:.1f}')
        plt.show()
        
    def forward(self, x):
        B, C, H, W = x.shape
        
        # Calculate optimal grid configuration
        grid_config = DynamicPatchOptimizer.calculate_optimal_grid(
            (H, W), self.patch_size, self.min_overlap)
        
        # Calculate padding
        pad_h, pad_w = self._calculate_padding(H, W, grid_config)
        pad_h_before, pad_h_after = pad_h // 2, pad_h - pad_h // 2
        pad_w_before, pad_w_after = pad_w // 2, pad_w - pad_w // 2
        
        x_padded = F.pad(x, (pad_w_before, pad_w_after, pad_h_before, pad_h_after),
                        mode='reflect')
        
        # Extract patches
        patches = []
        for i in range(grid_config['grid_h']):
            for j in range(grid_config['grid_w']):
                h_start = i * grid_config['stride_h']
                w_start = j * grid_config['stride_w']
                patch = x_padded[:, :,
                               h_start:h_start + self.patch_size,
                               w_start:w_start + self.patch_size]
                patches.append(patch)
                
        patches = torch.stack(patches, dim=1)
        
        # Apply weight mask for edge effect reduction
        patches = patches * self.weight_mask.view(1, 1, 1, self.patch_size, self.patch_size)
        
        return patches, (grid_config, (pad_h_before, pad_h_after, pad_w_before, pad_w_after))

In [5]:
class GeneralizedPatchMerger(nn.Module):
    """
    Generalized patch merging layer that works with any image size
    """
    def __init__(self, patch_size=256):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, patches, info):
        grid_config, padding_info = info
        B, N, C, H_patch, W_patch = patches.shape
        pad_h_before, pad_h_after, pad_w_before, pad_w_after = padding_info
        
        # Calculate padded size
        H = (grid_config['grid_h'] - 1) * grid_config['stride_h'] + self.patch_size
        W = (grid_config['grid_w'] - 1) * grid_config['stride_w'] + self.patch_size
        
        # Initialize output and weight accumulator
        output = torch.zeros((B, C, H, W), device=patches.device)
        weights = torch.zeros((B, 1, H, W), device=patches.device)
        
        # Reconstruct image with weighted averaging
        patch_idx = 0
        for i in range(grid_config['grid_h']):
            for j in range(grid_config['grid_w']):
                h_start = i * grid_config['stride_h']
                w_start = j * grid_config['stride_w']
                
                patch = patches[:, patch_idx]
                output[:, :, h_start:h_start + self.patch_size,
                      w_start:w_start + self.patch_size] += patch
                weights[:, :, h_start:h_start + self.patch_size,
                       w_start:w_start + self.patch_size] += 1
                patch_idx += 1
        
        # Average overlapping regions
        output = output / (weights + 1e-8)
        
        # Remove padding to get back original size
        output = output[:, :,
                       pad_h_before:H - pad_h_after,
                       pad_w_before:W - pad_w_after]
        
        return output

In [6]:
class GeneralizedPatchNetwork(nn.Module):
    """
    Complete network with generalized patch processing
    """
    def __init__(self, base_model, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_maker = GeneralizedPatchLayer(patch_size, min_overlap)
        self.base_model = base_model
        self.patch_merger = GeneralizedPatchMerger(patch_size)
        
    def visualize_patches(self, x):
        with torch.no_grad():
            grid_config = DynamicPatchOptimizer.calculate_optimal_grid(
                (x.shape[2], x.shape[3]), self.patch_maker.patch_size, self.patch_maker.min_overlap)
            self.patch_maker.visualize_patch_grid(x, grid_config)
        
    def forward(self, x):
        # Convert to patches
        patches, info = self.patch_maker(x)
        
        # Process patches
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Reconstruct image
        output = self.patch_merger(processed_patches, info)
        
        return output

In [39]:
test_sizes = [
    (1152, 1632),  # Original size
    (800, 1200),   # Different aspect ratio
    (2048, 2048),  # Square image
    (720, 1280),   # HD size
    (3840, 2160)   # 4K size
]

In [40]:
patches_maker = GeneralizedPatchLayer(patch_size=256, min_overlap=32)
for size in test_sizes:
    patches, info = patches_maker(torch.randn(1, 3, *size))
    print(f"Size: {size}, Patches: {patches.shape}")

Size: (1152, 1632), Patches: torch.Size([1, 40, 3, 256, 256])
Size: (800, 1200), Patches: torch.Size([1, 24, 3, 256, 256])
Size: (2048, 2048), Patches: torch.Size([1, 81, 3, 256, 256])
Size: (720, 1280), Patches: torch.Size([1, 24, 3, 256, 256])
Size: (3840, 2160), Patches: torch.Size([1, 170, 3, 256, 256])


In [42]:
patches_merger = GeneralizedPatchMerger(patch_size=256)
input_size = (1152, 1632)
patches, info = patches_maker(torch.randn(1, 3, *input_size))
output = patches_merger(patches, info)
print(f"Size: {input_size}, Output: {output.shape}")

Size: (1152, 1632), Output: torch.Size([1, 3, 1152, 1628])


In [24]:
#| export
class SizePreservingPatchLayer(nn.Module):
    """
    Patch conversion layer that guarantees exact input size preservation
    """
    def __init__(self, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_size = patch_size
        self.min_overlap = min_overlap
        self.register_buffer('weight_mask', self._create_weight_mask())
        
    def _create_weight_mask(self):
        """Creates gaussian weight mask for edge effect reduction"""
        x = torch.linspace(-1, 1, self.patch_size)
        y = torch.linspace(-1, 1, self.patch_size)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        gaussian = torch.exp(-(xx**2 + yy**2) / 1.5)
        return gaussian

    def _calculate_grid(self, H, W):
        """
        Calculate grid configuration ensuring full coverage of input size
        """
        # Calculate number of patches needed
        n_patches_h = math.ceil(H / (self.patch_size - self.min_overlap))
        n_patches_w = math.ceil(W / (self.patch_size - self.min_overlap))
        
        # Calculate actual strides to exactly cover the image
        stride_h = (H - self.patch_size) / (n_patches_h - 1) if n_patches_h > 1 else 0
        stride_w = (W - self.patch_size) / (n_patches_w - 1) if n_patches_w > 1 else 0
        
        return {
            'n_patches_h': n_patches_h,
            'n_patches_w': n_patches_w,
            'stride_h': stride_h,
            'stride_w': stride_w
        }
    def visualize_patch_grid(self, image_tensor):
        """Visualizes patch grid overlay on image"""
        image = image_tensor[0].cpu().numpy().transpose(1, 2, 0)
        if image.shape[2] == 1:
            image = image.squeeze(-1)
            
        H, W = image.shape[:2]
        grid = self._calculate_grid(H, W)
        
        plt.figure(figsize=(15, 10))
        plt.imshow(image, cmap='gray' if len(image.shape) == 2 else None)
        
        # Draw actual patch locations
        for i in range(grid['n_patches_h']):
            y = i * grid['stride_h']
            plt.axhline(y=y, color='r', linestyle='--', alpha=0.5)
            if i * grid['stride_h'] + self.patch_size <= H:
                plt.axhline(y=y + self.patch_size, color='g', linestyle=':', alpha=0.3)
                
        for j in range(grid['n_patches_w']):
            x = j * grid['stride_w']
            plt.axvline(x=x, color='r', linestyle='--', alpha=0.5)
            if j * grid['stride_w'] + self.patch_size <= W:
                plt.axvline(x=x + self.patch_size, color='g', linestyle=':', alpha=0.3)
                
        plt.title(f'Patch Grid ({grid["n_patches_h"]}x{grid["n_patches_w"]} patches)\n'
                 f'Image Size: {H}x{W}, Patch Size: {self.patch_size}')
        plt.show()
	
    def forward(self, x):
        B, C, H, W = x.shape
        grid = self._calculate_grid(H, W)
        
        patches = []
        locations = []
        
        # Extract patches with exact positioning
        for i in range(grid['n_patches_h']):
            for j in range(grid['n_patches_w']):
                # Calculate exact patch location
                h_start = int(i * grid['stride_h'])
                w_start = int(j * grid['stride_w'])
                
                # Handle edge cases for last patches
                h_start = min(h_start, H - self.patch_size)
                w_start = min(w_start, W - self.patch_size)
                
                patch = x[:, :,
                         h_start:h_start + self.patch_size,
                         w_start:w_start + self.patch_size]
                
                patches.append(patch)
                locations.append((h_start, w_start))
                
        patches = torch.stack(patches, dim=1)
        
        # Apply weight mask for edge effect reduction
        patches = patches * self.weight_mask.view(1, 1, 1, self.patch_size, self.patch_size)
        
        return patches, (locations, (H, W))

In [25]:
#| export
class ExactSizePatchNetwork(nn.Module):
    """
    Network that guarantees exact size preservation
    """
    def __init__(self, base_model, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_maker = SizePreservingPatchLayer(patch_size, min_overlap)
        self.base_model = base_model
        self.patch_merger = SizePreservingPatchMerger(patch_size)
        
    def visualize_patches(self, x):
        self.patch_maker.visualize_patch_grid(x)
        
    def forward(self, x):
        # Store original size for verification
        original_size = x.shape
        
        # Convert to patches
        patches, info = self.patch_maker(x)
        
        # Process patches
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Reconstruct image
        output = self.patch_merger(processed_patches, info)
        
        # Verify size match
        assert output.shape == original_size, \
            f"Size mismatch: Input {original_size}, Output {output.shape}"
            
        return output

In [26]:
#| export
class SizePreservingPatchMerger(nn.Module):
    """
    Patch merging layer that guarantees exact size preservation
    """
    def __init__(self, patch_size=256):
        super().__init__()
        self.patch_size = patch_size
        
    def forward(self, patches, info):
        locations, (H, W) = info
        B, N, C, H_patch, W_patch = patches.shape
        
        # Initialize output and weight accumulator with exact input size
        output = torch.zeros((B, C, H, W), device=patches.device)
        weights = torch.zeros((B, 1, H, W), device=patches.device)
        
        # Reconstruct image using exact patch locations
        for idx, (h_start, w_start) in enumerate(locations):
            patch = patches[:, idx]
            h_end = min(h_start + self.patch_size, H)
            w_end = min(w_start + self.patch_size, W)
            
            output[:, :, h_start:h_end, w_start:w_end] += patch[:, :, :(h_end-h_start), :(w_end-w_start)]
            weights[:, :, h_start:h_end, w_start:w_end] += 1
            
        # Average overlapping regions
        output = output / (weights + 1e-8)
        return output

In [8]:
from segmentation_test.pytorch_model_development import UNet

  check_for_updates()


In [9]:
#output = base_model(input_image)
#print(output.shape)


torch.Size([1, 48, 1, 256, 256])
([(0, 0), (0, 196), (0, 393), (0, 589), (0, 786), (0, 982), (0, 1179), (0, 1376), (179, 0), (179, 196), (179, 393), (179, 589), (179, 786), (179, 982), (179, 1179), (179, 1376), (358, 0), (358, 196), (358, 393), (358, 589), (358, 786), (358, 982), (358, 1179), (358, 1376), (537, 0), (537, 196), (537, 393), (537, 589), (537, 786), (537, 982), (537, 1179), (537, 1376), (716, 0), (716, 196), (716, 393), (716, 589), (716, 786), (716, 982), (716, 1179), (716, 1376), (896, 0), (896, 196), (896, 393), (896, 589), (896, 786), (896, 982), (896, 1179), (896, 1376)], (1152, 1632))


In [11]:
input_image = torch.randn(1, 1, 300, 300)
size_preserving_patch_layer = SizePreservingPatchLayer(patch_size=256, min_overlap=32)
patches, info = size_preserving_patch_layer(input_image)

print(patches.shape)
print(info)
base_model = UNet(in_channels=1, out_channels=1, features=[64, 128, 256], near_size=256)
B, N = patches.shape[:2]
patches = patches.reshape(B * N, *patches.shape[2:])
print(patches.shape)
base_model.eval()
with torch.no_grad():
    output = base_model(patches)
print(output.shape)

torch.Size([1, 4, 1, 256, 256])
([(0, 0), (0, 44), (44, 0), (44, 44)], (300, 300))
torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256])


In [12]:
processed_patches = output.reshape(B, N, *output.shape[1:])
print(processed_patches.shape)


torch.Size([1, 4, 1, 256, 256])


In [13]:
output = SizePreservingPatchMerger(patch_size=256)(processed_patches, info)
print(output.shape)

torch.Size([1, 1, 300, 300])


In [15]:
network = ExactSizePatchNetwork(base_model=base_model, patch_size=256, min_overlap=32)

In [16]:
output = network(input_image)

In [17]:
print(output.shape)

torch.Size([1, 1, 300, 300])


In [18]:
# Export to ONNX
torch.onnx.export(network,
                 input_image,
                 "optimized_patch_network.onnx",
                 opset_version=12,
                  input_names=['input'],
                  output_names=['output'])

  n_patches_h = math.ceil(H / (self.patch_size - self.min_overlap))
  n_patches_w = math.ceil(W / (self.patch_size - self.min_overlap))
  h_start = int(i * grid['stride_h'])
  w_start = int(j * grid['stride_w'])
  h_start = min(h_start, H - self.patch_size)
  w_start = min(w_start, W - self.patch_size)
  if h != self.size or w!=self.size:
  h_end = min(h_start + self.patch_size, H)
  w_end = min(w_start + self.patch_size, W)
  assert output.shape == original_size, \


In [22]:
import onnx
import onnxruntime
import numpy as np

# Load the ONNX model
onnx_model = onnx.load("optimized_patch_network.onnx")

# Check the model
onnx.checker.check_model(onnx_model)


In [23]:
# Print a human readable representation of the graph
onnx.helper.printable_graph(onnx_model.graph)


'graph main_graph (\n  %input[FLOAT, 1x1x300x300]\n) initializers (\n  %base_model.ups.0.up.weight[FLOAT, 512x256x2x2]\n  %base_model.ups.0.up.bias[FLOAT, 256]\n  %base_model.ups.2.up.weight[FLOAT, 256x128x2x2]\n  %base_model.ups.2.up.bias[FLOAT, 128]\n  %base_model.ups.4.up.weight[FLOAT, 128x64x2x2]\n  %base_model.ups.4.up.bias[FLOAT, 64]\n  %base_model.final_conv.weight[FLOAT, 1x64x1x1]\n  %base_model.final_conv.bias[FLOAT, 1]\n  %onnx::Conv_1324[FLOAT, 64x1x3x3]\n  %onnx::Conv_1325[FLOAT, 64]\n  %onnx::Conv_1327[FLOAT, 64x64x3x3]\n  %onnx::Conv_1328[FLOAT, 64]\n  %onnx::Conv_1330[FLOAT, 128x64x3x3]\n  %onnx::Conv_1331[FLOAT, 128]\n  %onnx::Conv_1333[FLOAT, 128x128x3x3]\n  %onnx::Conv_1334[FLOAT, 128]\n  %onnx::Conv_1336[FLOAT, 256x128x3x3]\n  %onnx::Conv_1337[FLOAT, 256]\n  %onnx::Conv_1339[FLOAT, 256x256x3x3]\n  %onnx::Conv_1340[FLOAT, 256]\n  %onnx::Conv_1342[FLOAT, 512x256x3x3]\n  %onnx::Conv_1343[FLOAT, 512]\n  %onnx::Conv_1345[FLOAT, 512x512x3x3]\n  %onnx::Conv_1346[FLOAT, 512]

In [None]:

# Load the ONNX model as a PyTorch model
ort_session = onnxruntime.InferenceSession("optimized_patch_network.onnx")

# Prepare the input data
input_data = np.random.random(size=(1, 1, 300, 300)).astype(np.float32)

# Run the model with ONNX Runtime
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outs = ort_session.run(None, ort_inputs)
ort_outs[0].shape
# Compare the ONNX and PyTorch results
#np.testing.assert_allclose(ort_outs[0], output.detach().numpy(), rtol=1e-03, atol=1e-05)


In [29]:
#| hide
import nbdev; nbdev.nbdev_export('110_preprocessing.pt_patching')