In [16]:
#| default_exp preprocessing.pt_patching

# Patch Whole imageto number of patches
> Patch whole image into number of patches

In [17]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt 
import math
import numpy as np

In [18]:
#| export
from fastcore.all import *

In [19]:
#| export
class SizePreservingPatchLayer(nn.Module):
    """
    Patch conversion layer that guarantees exact input size preservation
    """
    def __init__(self, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_size = patch_size
        self.min_overlap = min_overlap
        self.register_buffer('weight_mask', self._create_weight_mask())
        
    def _create_weight_mask(self):
        """Creates gaussian weight mask for edge effect reduction"""
        x = torch.linspace(-1, 1, self.patch_size)
        y = torch.linspace(-1, 1, self.patch_size)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        gaussian = torch.exp(-(xx**2 + yy**2) / 1.5)
        return gaussian

    def _calculate_grid(self, H, W):
        """
        Calculate grid configuration ensuring full coverage of input size
        """
        # Calculate number of patches needed
        n_patches_h = math.ceil(H / (self.patch_size - self.min_overlap))
        n_patches_w = math.ceil(W / (self.patch_size - self.min_overlap))
        
        # Calculate actual strides to exactly cover the image
        stride_h = (H - self.patch_size) / (n_patches_h - 1) if n_patches_h > 1 else 0
        stride_w = (W - self.patch_size) / (n_patches_w - 1) if n_patches_w > 1 else 0
        
        return {
            'n_patches_h': n_patches_h,
            'n_patches_w': n_patches_w,
            'stride_h': stride_h,
            'stride_w': stride_w
        }
    def visualize_patch_grid(self, image_tensor):
        """Visualizes patch grid overlay on image"""
        image = image_tensor[0].cpu().numpy().transpose(1, 2, 0)
        if image.shape[2] == 1:
            image = image.squeeze(-1)
            
        H, W = image.shape[:2]
        grid = self._calculate_grid(H, W)
        
        plt.figure(figsize=(15, 10))
        plt.imshow(image, cmap='gray' if len(image.shape) == 2 else None)
        
        # Draw actual patch locations
        for i in range(grid['n_patches_h']):
            y = i * grid['stride_h']
            plt.axhline(y=y, color='r', linestyle='--', alpha=0.5)
            if i * grid['stride_h'] + self.patch_size <= H:
                plt.axhline(y=y + self.patch_size, color='g', linestyle=':', alpha=0.3)
                
        for j in range(grid['n_patches_w']):
            x = j * grid['stride_w']
            plt.axvline(x=x, color='r', linestyle='--', alpha=0.5)
            if j * grid['stride_w'] + self.patch_size <= W:
                plt.axvline(x=x + self.patch_size, color='g', linestyle=':', alpha=0.3)
                
        plt.title(f'Patch Grid ({grid["n_patches_h"]}x{grid["n_patches_w"]} patches)\n'
                 f'Image Size: {H}x{W}, Patch Size: {self.patch_size}')
        plt.show()
	
    def forward(self, x):
        B, C, H, W = x.shape
        grid = self._calculate_grid(H, W)
        
        patches = []
        locations = []
        
        # Extract patches with exact positioning
        for i in range(grid['n_patches_h']):
            for j in range(grid['n_patches_w']):
                # Calculate exact patch location
                h_start = int(i * grid['stride_h'])
                w_start = int(j * grid['stride_w'])
                
                # Handle edge cases for last patches
                h_start = min(h_start, H - self.patch_size)
                w_start = min(w_start, W - self.patch_size)
                
                patch = x[:, :,
                         h_start:h_start + self.patch_size,
                         w_start:w_start + self.patch_size]
                
                patches.append(patch)
                locations.append((h_start, w_start))
                
        patches = torch.stack(patches, dim=1)
        
        # Apply weight mask for edge effect reduction
        patches = patches * self.weight_mask.view(1, 1, 1, self.patch_size, self.patch_size)
        
        return patches, (locations, (H, W))

In [20]:
#| export

class SizePreservingPatchLayerONNX(nn.Module):
    """
    ONNX-compatible patch conversion layer with exact numerical consistency
    """
    def __init__(self, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_size = patch_size
        self.min_overlap = min_overlap
        self.register_buffer('weight_mask', self._create_weight_mask())
        
    def _create_weight_mask(self):
        """Creates gaussian weight mask with controlled precision"""
        # Use float32 explicitly for consistent precision
        x = torch.linspace(-1, 1, self.patch_size, dtype=torch.float32)
        y = torch.linspace(-1, 1, self.patch_size, dtype=torch.float32)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        # Use a more ONNX-friendly gaussian formula
        gaussian = torch.exp(-(xx.pow(2) + yy.pow(2)) / 1.5)
        return gaussian.to(torch.float32)

    def _calculate_grid(self, H, W):
        """Calculate grid with explicit float32 calculations"""
        # Convert to float32 for consistent precision
        H, W = float(H), float(W)
        patch_size = float(self.patch_size)
        min_overlap = float(self.min_overlap)
        
        # Calculate patches needed with explicit float32
        n_patches_h = math.ceil((H - min_overlap) / (patch_size - min_overlap))
        n_patches_w = math.ceil((W - min_overlap) / (patch_size - min_overlap))
        
        # Calculate strides with controlled precision
        stride_h = torch.tensor((H - patch_size) / max(n_patches_h - 1, 1), dtype=torch.float32)
        stride_w = torch.tensor((W - patch_size) / max(n_patches_w - 1, 1), dtype=torch.float32)
        
        return {
            'n_patches_h': int(n_patches_h),
            'n_patches_w': int(n_patches_w),
            'stride_h': stride_h.item(),
            'stride_w': stride_w.item()
        }

    def forward(self, x):
        B, C, H, W = x.shape
        grid = self._calculate_grid(H, W)
        
        patches = []
        locations = []
        
        for i in range(grid['n_patches_h']):
            for j in range(grid['n_patches_w']):
                # Use explicit float32 calculations
                h_start = int(round(i * grid['stride_h']))
                w_start = int(round(j * grid['stride_w']))
                
                # Ensure exact boundary handling
                h_start = min(h_start, H - self.patch_size)
                w_start = min(w_start, W - self.patch_size)
                
                patch = x[:, :,
                         h_start:h_start + self.patch_size,
                         w_start:w_start + self.patch_size]
                
                patches.append(patch)
                locations.append((h_start, w_start))
        
        patches = torch.stack(patches, dim=1)
        # Apply weight mask with controlled precision
        patches = patches * self.weight_mask.to(x.dtype).view(1, 1, 1, self.patch_size, self.patch_size)
        
        return patches, (locations, (H, W))


In [21]:
#| export
class ExactSizePatchNetworkONNX(nn.Module):
    """
    ONNX-compatible network ensuring numerical consistency
    """
    def __init__(self, base_model, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_maker = SizePreservingPatchLayerONNX(patch_size, min_overlap)
        self.base_model = base_model
        self.patch_merger = SizePreservingPatchMergerONNX(patch_size)
        
    def forward(self, x):
        # Ensure input is float32 for consistency
        x = x.to(torch.float32)
        original_size = x.shape
        
        patches, info = self.patch_maker(x)
        
        # Process patches with controlled batch reshaping
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Merge patches with numerical stability
        output = self.patch_merger(processed_patches, info)
        
        assert output.shape == original_size, f"Size mismatch: Input {original_size}, Output {output.shape}"
        return output


In [8]:
from segmentation_test.pytorch_model_development import UNet

  check_for_updates()


In [9]:
input_image = torch.randn(1, 1, 300, 300)
base_model = UNet(in_channels=1, out_channels=1, features=[64, 128, 256], near_size=256)
network = ExactSizePatchNetworkONNX(
    base_model=base_model, 
    patch_size=256, 
    min_overlap=32)
output = network(input_image)
output.shape
#| export

torch.Size([1, 1, 300, 300])

In [23]:
compare_pytorch_onnx(network, input_image, "optimized_patch_network.onnx")

  H, W = float(H), float(W)
  stride_h = torch.tensor((H - patch_size) / max(n_patches_h - 1, 1), dtype=torch.float32)
  stride_w = torch.tensor((W - patch_size) / max(n_patches_w - 1, 1), dtype=torch.float32)
  'stride_h': stride_h.item(),
  'stride_w': stride_w.item()
  h_start = min(h_start, H - self.patch_size)
  w_start = min(w_start, W - self.patch_size)
  if h != self.size or w!=self.size:
  h_end = min(h_start + self.patch_size, H)
  w_end = min(w_start + self.patch_size, W)
  assert output.shape == original_size, f"Size mismatch: Input {original_size}, Output {output.shape}"


{'max_diff': 4.0978193e-08,
 'mean_diff': 5.308063e-09,
 'pytorch_output': array([[[[0.05229837, 0.05197935, 0.0531798 , ..., 0.05216144,
           0.05292663, 0.05627326],
          [0.04138024, 0.04260495, 0.04600115, ..., 0.04618616,
           0.04291147, 0.05262633],
          [0.04580892, 0.05131192, 0.0478221 , ..., 0.04899705,
           0.04629134, 0.05292745],
          ...,
          [0.04628309, 0.04958029, 0.05422394, ..., 0.04823392,
           0.04898105, 0.04917352],
          [0.05093627, 0.04550996, 0.04198207, ..., 0.05080977,
           0.0491537 , 0.05600167],
          [0.05410715, 0.05408612, 0.05009517, ..., 0.05257736,
           0.05256517, 0.05227383]]]], dtype=float32),
 'onnx_output': array([[[[0.05229836, 0.05197934, 0.05317979, ..., 0.05216144,
           0.05292663, 0.05627327],
          [0.04138024, 0.04260498, 0.04600116, ..., 0.04618617,
           0.04291149, 0.05262632],
          [0.04580891, 0.05131193, 0.0478221 , ..., 0.04899706,
           0.

In [22]:
#| export
def compare_pytorch_onnx(model, input_tensor, onnx_path="model.onnx"):
    """
    Utility function to compare PyTorch and ONNX outputs
    """
    # PyTorch forward pass
    model.eval()
    with torch.no_grad():
        pytorch_output = model(input_tensor)

    # Export to ONNX
    torch.onnx.export(model, input_tensor, onnx_path, 
                     opset_version=12,
                     input_names=['input'],
                     output_names=['output'],
                     dynamic_axes={'input': {0: 'batch_size'},
                                 'output': {0: 'batch_size'}})

    # ONNX inference
    import onnxruntime
    ort_session = onnxruntime.InferenceSession(onnx_path)
    ort_inputs = {ort_session.get_inputs()[0].name: input_tensor.cpu().numpy()}
    ort_output = ort_session.run(None, ort_inputs)[0]

    # Compare results
    pytorch_numpy = pytorch_output.cpu().numpy()
    max_diff = np.max(np.abs(pytorch_numpy - ort_output))
    mean_diff = np.mean(np.abs(pytorch_numpy - ort_output))
    np.testing.assert_allclose(ort_output, pytorch_output.detach().numpy(), rtol=1e-03, atol=1e-05)
    
    return {
        'max_diff': max_diff,
        'mean_diff': mean_diff,
        'pytorch_output': pytorch_numpy,
        'onnx_output': ort_output
    } 

In [20]:
#| export
class SizePreservingPatchMerger(nn.Module):
    """
    Patch merging layer that guarantees exact size preservation
    """
    def __init__(self, patch_size=256):
        super().__init__()
        self.patch_size = patch_size
        
    def forward(self, patches, info):
        locations, (H, W) = info
        B, N, C, H_patch, W_patch = patches.shape
        
        # Initialize output and weight accumulator with exact input size
        output = torch.zeros((B, C, H, W), device=patches.device)
        weights = torch.zeros((B, 1, H, W), device=patches.device)
        
        # Reconstruct image using exact patch locations
        for idx, (h_start, w_start) in enumerate(locations):
            patch = patches[:, idx]
            h_end = min(h_start + self.patch_size, H)
            w_end = min(w_start + self.patch_size, W)
            
            output[:, :, h_start:h_end, w_start:w_end] += patch[:, :, :(h_end-h_start), :(w_end-w_start)]
            weights[:, :, h_start:h_end, w_start:w_end] += 1
            
        # Average overlapping regions
        output = output / (weights + 1e-8)
        return output

In [21]:
#| export
class ExactSizePatchNetwork(nn.Module):
    """
    Network that guarantees exact size preservation
    """
    def __init__(self, base_model, patch_size=256, min_overlap=32):
        super().__init__()
        self.patch_maker = SizePreservingPatchLayer(patch_size, min_overlap)
        self.base_model = base_model
        self.patch_merger = SizePreservingPatchMerger(patch_size)
        
    def visualize_patches(self, x):
        self.patch_maker.visualize_patch_grid(x)
        
    def forward(self, x):
        # Store original size for verification
        original_size = x.shape
        
        # Convert to patches
        patches, info = self.patch_maker(x)
        
        # Process patches
        B, N = patches.shape[:2]
        patches = patches.reshape(B * N, *patches.shape[2:])
        processed_patches = self.base_model(patches)
        processed_patches = processed_patches.reshape(B, N, *processed_patches.shape[1:])
        
        # Reconstruct image
        output = self.patch_merger(processed_patches, info)
        
        # Verify size match
        assert output.shape == original_size, \
            f"Size mismatch: Input {original_size}, Output {output.shape}"
            
        return output

In [22]:
#| exporti
from segmentation_test.pytorch_model_development import UNet

In [9]:
#output = base_model(input_image)
#print(output.shape)


In [7]:
input_image = torch.randn(1, 1, 300, 300)
size_preserving_patch_layer = SizePreservingPatchLayer(patch_size=256, min_overlap=32)
patches, info = size_preserving_patch_layer(input_image)

print(patches.shape)
print(info)
base_model = UNet(in_channels=1, out_channels=1, features=[64, 128, 256], near_size=256)
B, N = patches.shape[:2]
patches = patches.reshape(B * N, *patches.shape[2:])
print(patches.shape)
base_model.eval()
with torch.no_grad():
    output = base_model(patches)
print(output.shape)

torch.Size([1, 4, 1, 256, 256])
([(0, 0), (0, 44), (44, 0), (44, 44)], (300, 300))
torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256])


In [8]:
processed_patches = output.reshape(B, N, *output.shape[1:])
print(processed_patches.shape)


torch.Size([1, 4, 1, 256, 256])


In [9]:
output = SizePreservingPatchMerger(patch_size=256)(processed_patches, info)
print(output.shape)

torch.Size([1, 1, 300, 300])


In [10]:
network = ExactSizePatchNetwork(base_model=base_model, patch_size=256, min_overlap=32)

In [11]:
output = network(input_image)

In [12]:
print(output.shape)

torch.Size([1, 1, 300, 300])


In [13]:
# Export to ONNX
torch.onnx.export(network,
                 input_image,
                 "optimized_patch_network.onnx",
                 opset_version=12,
                  input_names=['input'],
                  output_names=['output'])

  n_patches_h = math.ceil(H / (self.patch_size - self.min_overlap))
  n_patches_w = math.ceil(W / (self.patch_size - self.min_overlap))
  h_start = int(i * grid['stride_h'])
  w_start = int(j * grid['stride_w'])
  h_start = min(h_start, H - self.patch_size)
  w_start = min(w_start, W - self.patch_size)
  if h != self.size or w!=self.size:
  h_end = min(h_start + self.patch_size, H)
  w_end = min(w_start + self.patch_size, W)
  assert output.shape == original_size, \


In [14]:
import onnx
import onnxruntime
import numpy as np

# Load the ONNX model
onnx_model = onnx.load("optimized_patch_network.onnx")

# Check the model
onnx.checker.check_model(onnx_model)


In [15]:
# Print a human readable representation of the graph
onnx.helper.printable_graph(onnx_model.graph)


'graph main_graph (\n  %input[FLOAT, 1x1x300x300]\n) initializers (\n  %base_model.ups.0.up.weight[FLOAT, 512x256x2x2]\n  %base_model.ups.0.up.bias[FLOAT, 256]\n  %base_model.ups.2.up.weight[FLOAT, 256x128x2x2]\n  %base_model.ups.2.up.bias[FLOAT, 128]\n  %base_model.ups.4.up.weight[FLOAT, 128x64x2x2]\n  %base_model.ups.4.up.bias[FLOAT, 64]\n  %base_model.final_conv.weight[FLOAT, 1x64x1x1]\n  %base_model.final_conv.bias[FLOAT, 1]\n  %onnx::Conv_1324[FLOAT, 64x1x3x3]\n  %onnx::Conv_1325[FLOAT, 64]\n  %onnx::Conv_1327[FLOAT, 64x64x3x3]\n  %onnx::Conv_1328[FLOAT, 64]\n  %onnx::Conv_1330[FLOAT, 128x64x3x3]\n  %onnx::Conv_1331[FLOAT, 128]\n  %onnx::Conv_1333[FLOAT, 128x128x3x3]\n  %onnx::Conv_1334[FLOAT, 128]\n  %onnx::Conv_1336[FLOAT, 256x128x3x3]\n  %onnx::Conv_1337[FLOAT, 256]\n  %onnx::Conv_1339[FLOAT, 256x256x3x3]\n  %onnx::Conv_1340[FLOAT, 256]\n  %onnx::Conv_1342[FLOAT, 512x256x3x3]\n  %onnx::Conv_1343[FLOAT, 512]\n  %onnx::Conv_1345[FLOAT, 512x512x3x3]\n  %onnx::Conv_1346[FLOAT, 512]

In [26]:

# Load the ONNX model as a PyTorch model
ort_session = onnxruntime.InferenceSession("optimized_patch_network.onnx")

# Prepare the input data
input_data = np.random.random(size=(1, 1, 300, 300)).astype(np.float32)

# Run the model with ONNX Runtime
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outs = ort_session.run(None, ort_inputs)
ort_outs[0].shape


(1, 1, 300, 300)

In [27]:

# Compare the ONNX and PyTorch results
np.testing.assert_allclose(ort_outs[0], output.detach().numpy(), rtol=1e-03, atol=1e-05)


AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 88537 / 90000 (98.4%)
Max absolute difference: 0.03038901
Max relative difference: 0.2821542
 x: array([[[[-0.099564, -0.098427, -0.097701, ..., -0.097652, -0.095919,
          -0.095454],
         [-0.094623, -0.095435, -0.094947, ..., -0.093829, -0.092482,...
 y: array([[[[-0.099845, -0.098138, -0.097527, ..., -0.098395, -0.097126,
          -0.096604],
         [-0.092851, -0.09231 , -0.092299, ..., -0.093995, -0.093602,...

In [23]:
#| hide
import nbdev; nbdev.nbdev_export('110_preprocessing.pt_patching')