In [1]:
#| default_exp preprocessing.pt_patching

# Patch Whole imageto number of patches
> Patch whole image into number of patches

In [1]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [17]:
#| export
from fastcore.all import *

In [18]:
#| export
class SizePreservingPatchLayer(nn.Module):
    """
    Patch conversion layer that guarantees exact input size preservation
    """
    def __init__(self, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_size = patch_size
        self.min_overlap = min_overlap
        self.register_buffer('weight_mask', self._create_weight_mask())
        
    def _create_weight_mask(self):
        """Creates gaussian weight mask for edge effect reduction"""
        x = torch.linspace(-1, 1, self.patch_size)
        y = torch.linspace(-1, 1, self.patch_size)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        gaussian = torch.exp(-(xx**2 + yy**2) / 1.5)
        return gaussian

    def _calculate_grid(self, H, W):
        """
        Calculate grid configuration ensuring full coverage of input size
        """
        # Calculate number of patches needed
        n_patches_h = math.ceil(H / (self.patch_size - self.min_overlap))
        n_patches_w = math.ceil(W / (self.patch_size - self.min_overlap))
        
        # Calculate actual strides to exactly cover the image
        stride_h = (H - self.patch_size) / (n_patches_h - 1) if n_patches_h > 1 else 0
        stride_w = (W - self.patch_size) / (n_patches_w - 1) if n_patches_w > 1 else 0
        
        return {
            'n_patches_h': n_patches_h,
            'n_patches_w': n_patches_w,
            'stride_h': stride_h,
            'stride_w': stride_w
        }
    def visualize_patch_grid(self, image_tensor):
        """Visualizes patch grid overlay on image"""
        image = image_tensor[0].cpu().numpy().transpose(1, 2, 0)
        if image.shape[2] == 1:
            image = image.squeeze(-1)
            
        H, W = image.shape[:2]
        grid = self._calculate_grid(H, W)
        
        plt.figure(figsize=(15, 10))
        plt.imshow(image, cmap='gray' if len(image.shape) == 2 else None)
        
        # Draw actual patch locations
        for i in range(grid['n_patches_h']):
            y = i * grid['stride_h']
            plt.axhline(y=y, color='r', linestyle='--', alpha=0.5)
            if i * grid['stride_h'] + self.patch_size <= H:
                plt.axhline(y=y + self.patch_size, color='g', linestyle=':', alpha=0.3)
                
        for j in range(grid['n_patches_w']):
            x = j * grid['stride_w']
            plt.axvline(x=x, color='r', linestyle='--', alpha=0.5)
            if j * grid['stride_w'] + self.patch_size <= W:
                plt.axvline(x=x + self.patch_size, color='g', linestyle=':', alpha=0.3)
                
        plt.title(f'Patch Grid ({grid["n_patches_h"]}x{grid["n_patches_w"]} patches)\n'
                 f'Image Size: {H}x{W}, Patch Size: {self.patch_size}')
        plt.show()
	
    def forward(self, x):
        B, C, H, W = x.shape
        grid = self._calculate_grid(H, W)
        
        patches = []
        locations = []
        
        # Extract patches with exact positioning
        for i in range(grid['n_patches_h']):
            for j in range(grid['n_patches_w']):
                # Calculate exact patch location
                h_start = int(i * grid['stride_h'])
                w_start = int(j * grid['stride_w'])
                
                # Handle edge cases for last patches
                h_start = min(h_start, H - self.patch_size)
                w_start = min(w_start, W - self.patch_size)
                
                patch = x[:, :,
                         h_start:h_start + self.patch_size,
                         w_start:w_start + self.patch_size]
                
                patches.append(patch)
                locations.append((h_start, w_start))
                
        patches = torch.stack(patches, dim=1)
        
        # Apply weight mask for edge effect reduction
        patches = patches * self.weight_mask.view(1, 1, 1, self.patch_size, self.patch_size)
        
        return patches, (locations, (H, W))

In [20]:
#| export
class SizePreservingPatchMerger(nn.Module):
    """
    Patch merging layer that guarantees exact size preservation
    """
    def __init__(self, patch_size=256):
        super().__init__()
        self.patch_size = patch_size
        
    def forward(self, patches, info):
        locations, (H, W) = info
        B, N, C, H_patch, W_patch = patches.shape
        
        # Initialize output and weight accumulator with exact input size
        output = torch.zeros((B, C, H, W), device=patches.device)
        weights = torch.zeros((B, 1, H, W), device=patches.device)
        
        # Reconstruct image using exact patch locations
        for idx, (h_start, w_start) in enumerate(locations):
            patch = patches[:, idx]
            h_end = min(h_start + self.patch_size, H)
            w_end = min(w_start + self.patch_size, W)
            
            output[:, :, h_start:h_end, w_start:w_end] += patch[:, :, :(h_end-h_start), :(w_end-w_start)]
            weights[:, :, h_start:h_end, w_start:w_end] += 1
            
        # Average overlapping regions
        output = output / (weights + 1e-8)
        return output

In [21]:
#| export
class ExactSizePatchNetwork(nn.Module):
    """
    Network that guarantees exact size preservation
    """
    def __init__(self, base_model, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_maker = SizePreservingPatchLayer(patch_size, min_overlap)
        self.base_model = base_model
        self.patch_merger = SizePreservingPatchMerger(patch_size)
        
    def visualize_patches(self, x):
        self.patch_maker.visualize_patch_grid(x)
        
    def forward(self, x):
        # Store original size for verification
        original_size = x.shape
        
        # Convert to patches
        patches, info = self.patch_maker(x)
        
        # Process patches
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Reconstruct image
        output = self.patch_merger(processed_patches, info)
        
        # Verify size match
        assert output.shape == original_size, \
            f"Size mismatch: Input {original_size}, Output {output.shape}"
            
        return output

In [22]:
#| exporti
from segmentation_test.pytorch_model_development import UNet

In [9]:
#output = base_model(input_image)
#print(output.shape)


In [7]:
input_image = torch.randn(1, 1, 300, 300)
size_preserving_patch_layer = SizePreservingPatchLayer(patch_size=256, min_overlap=32)
patches, info = size_preserving_patch_layer(input_image)

print(patches.shape)
print(info)
base_model = UNet(in_channels=1, out_channels=1, features=[64, 128, 256], near_size=256)
B, N = patches.shape[:2]
patches = patches.reshape(B * N, *patches.shape[2:])
print(patches.shape)
base_model.eval()
with torch.no_grad():
    output = base_model(patches)
print(output.shape)

torch.Size([1, 4, 1, 256, 256])
([(0, 0), (0, 44), (44, 0), (44, 44)], (300, 300))
torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256])


In [8]:
processed_patches = output.reshape(B, N, *output.shape[1:])
print(processed_patches.shape)


torch.Size([1, 4, 1, 256, 256])


In [9]:
output = SizePreservingPatchMerger(patch_size=256)(processed_patches, info)
print(output.shape)

torch.Size([1, 1, 300, 300])


In [10]:
network = ExactSizePatchNetwork(base_model=base_model, patch_size=256, min_overlap=32)

In [11]:
output = network(input_image)

In [12]:
print(output.shape)

torch.Size([1, 1, 300, 300])


In [13]:
# Export to ONNX
torch.onnx.export(network,
                 input_image,
                 "optimized_patch_network.onnx",
                 opset_version=12,
                  input_names=['input'],
                  output_names=['output'])

  n_patches_h = math.ceil(H / (self.patch_size - self.min_overlap))
  n_patches_w = math.ceil(W / (self.patch_size - self.min_overlap))
  h_start = int(i * grid['stride_h'])
  w_start = int(j * grid['stride_w'])
  h_start = min(h_start, H - self.patch_size)
  w_start = min(w_start, W - self.patch_size)
  if h != self.size or w!=self.size:
  h_end = min(h_start + self.patch_size, H)
  w_end = min(w_start + self.patch_size, W)
  assert output.shape == original_size, \


In [14]:
import onnx
import onnxruntime
import numpy as np

# Load the ONNX model
onnx_model = onnx.load("optimized_patch_network.onnx")

# Check the model
onnx.checker.check_model(onnx_model)


In [15]:
# Print a human readable representation of the graph
onnx.helper.printable_graph(onnx_model.graph)


'graph main_graph (\n  %input[FLOAT, 1x1x300x300]\n) initializers (\n  %base_model.ups.0.up.weight[FLOAT, 512x256x2x2]\n  %base_model.ups.0.up.bias[FLOAT, 256]\n  %base_model.ups.2.up.weight[FLOAT, 256x128x2x2]\n  %base_model.ups.2.up.bias[FLOAT, 128]\n  %base_model.ups.4.up.weight[FLOAT, 128x64x2x2]\n  %base_model.ups.4.up.bias[FLOAT, 64]\n  %base_model.final_conv.weight[FLOAT, 1x64x1x1]\n  %base_model.final_conv.bias[FLOAT, 1]\n  %onnx::Conv_1324[FLOAT, 64x1x3x3]\n  %onnx::Conv_1325[FLOAT, 64]\n  %onnx::Conv_1327[FLOAT, 64x64x3x3]\n  %onnx::Conv_1328[FLOAT, 64]\n  %onnx::Conv_1330[FLOAT, 128x64x3x3]\n  %onnx::Conv_1331[FLOAT, 128]\n  %onnx::Conv_1333[FLOAT, 128x128x3x3]\n  %onnx::Conv_1334[FLOAT, 128]\n  %onnx::Conv_1336[FLOAT, 256x128x3x3]\n  %onnx::Conv_1337[FLOAT, 256]\n  %onnx::Conv_1339[FLOAT, 256x256x3x3]\n  %onnx::Conv_1340[FLOAT, 256]\n  %onnx::Conv_1342[FLOAT, 512x256x3x3]\n  %onnx::Conv_1343[FLOAT, 512]\n  %onnx::Conv_1345[FLOAT, 512x512x3x3]\n  %onnx::Conv_1346[FLOAT, 512]

In [16]:

# Load the ONNX model as a PyTorch model
ort_session = onnxruntime.InferenceSession("optimized_patch_network.onnx")

# Prepare the input data
input_data = np.random.random(size=(1, 1, 300, 300)).astype(np.float32)

# Run the model with ONNX Runtime
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outs = ort_session.run(None, ort_inputs)
ort_outs[0].shape
# Compare the ONNX and PyTorch results
#np.testing.assert_allclose(ort_outs[0], output.detach().numpy(), rtol=1e-03, atol=1e-05)


(1, 1, 300, 300)

In [29]:
#| hide
import nbdev; nbdev.nbdev_export('110_preprocessing.pt_patching')