In [None]:
#| default_exp patching.first_patching

# Patchify Image
> First-pass patching: split images into overlapping tiles

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from cv_tools.core import *
from cv_tools.imports import *
from cv_tools.data_processing.smb_tools import *


In [None]:
#| export
import torch
from torch import nn


In [None]:
# Scratch check of the tiling arithmetic on a dummy single-channel image.
img = torch.randn(1, 1, 1152, 1632)
H, W = img.shape[-2:]
tile_size, overlap = 256, 32
# Number of tile rows that fit when consecutive tiles share `overlap` pixels.
h_steps = max((H - overlap) // (tile_size - overlap), 1)

In [None]:
#| export
class TileProcessor(nn.Module):
    """Cut a batch of images into square, overlapping tiles.

    Args:
        tile_size: Side length of every tile in pixels.
        overlap: Pixels shared by two neighbouring tiles; the tiling stride is
            ``tile_size - overlap``.

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is not in
            ``[0, tile_size)`` (the stride would be zero or negative).
    """
    def __init__(self, tile_size=256, overlap=32):
        super().__init__()
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        if not 0 <= overlap < tile_size:
            raise ValueError(
                "overlap must satisfy 0 <= overlap < tile_size, "
                f"got overlap={overlap}, tile_size={tile_size}"
            )
        self.tile_size = tile_size
        self.overlap = overlap

    def forward(self, x):
        """Split image into overlapping tiles and return tiles with their positions.

        Args:
            x: Tensor of shape (B, C, H, W) with H and W at least ``tile_size``.

        Returns:
            A ``(tiles, positions)`` pair. ``tiles`` is the concatenation along
            dim 0 of every (B, C, tile_size, tile_size) crop, so for B > 1 the
            k-th position describes rows ``k*B:(k+1)*B`` of ``tiles``.
            ``positions`` is an integer tensor of shape (num_tiles, 4) holding
            ``[h_start, w_start, h_end, w_end]`` per crop, on ``x``'s device.

        Raises:
            ValueError: If the image is smaller than one tile in either
                dimension (start offsets would become negative).
        """
        _, _, H, W = x.shape
        tile_size = self.tile_size
        overlap = self.overlap
        stride = tile_size - overlap

        if H < tile_size or W < tile_size:
            raise ValueError(
                f"input of size {H}x{W} is smaller than tile_size={tile_size}"
            )

        # Regular grid: as many full strides as fit inside the image.
        h_steps = max(1, (H - overlap) // stride)
        w_steps = max(1, (W - overlap) // stride)
        row_starts = [i * stride for i in range(h_steps)]
        col_starts = [j * stride for j in range(w_steps)]

        # A remainder at the bottom/right is covered by one extra row/column of
        # tiles aligned flush with the image border.
        needs_last_row = h_steps * stride + overlap < H
        needs_last_col = w_steps * stride + overlap < W

        # Order is: regular grid (row-major), extra bottom row, extra right
        # column, then the bottom-right corner tile.
        starts = [(h, w) for h in row_starts for w in col_starts]
        if needs_last_row:
            starts += [(H - tile_size, w) for w in col_starts]
        if needs_last_col:
            starts += [(h, W - tile_size) for h in row_starts]
        if needs_last_row and needs_last_col:
            starts.append((H - tile_size, W - tile_size))

        # Stack tiles into a batch
        tiles = torch.cat(
            [x[:, :, h:h + tile_size, w:w + tile_size] for h, w in starts],
            dim=0,
        )
        positions_tensor = torch.tensor(
            [[h, w, h + tile_size, w + tile_size] for h, w in starts],
            device=x.device,
        )

        return tiles, positions_tensor

In [None]:
import torch
from torch import nn

# Scratch check: derive the pixel overlap from a fractional overlap setting.
img = torch.randn(1, 1, 1152, 1632)
H, W = img.shape[-2:]
tile_size = 256
overlap_percentage = 0.3
# int() truncates, so 256 * 0.3 = 76.8 becomes 76 pixels.
overlap = int(overlap_percentage * tile_size)
print(f'overlap: {overlap}')
print(f'tile_size: {tile_size}')
print(f'H: {H}')
print(f'W: {W}')


overlap: 76
tile_size: 256
H: 1152
W: 1632


In [None]:
#| export
class CLAHELayer(nn.Module):
    """Differentiable stand-in for CLAHE based on per-cell contrast normalisation.

    The image is split into a grid of non-overlapping cells; every cell is
    standardised with its own mean and standard deviation, and the result is
    squashed with a sigmoid. This is not a histogram equalisation.

    Args:
        clip_limit: Per-cell std values are capped at ``clip_limit`` times the
            mean std over all cells.
        grid_size: ``(rows, cols)`` of the cell grid.
    """
    def __init__(self, clip_limit=4.0, grid_size=(8, 8)):
        super().__init__()
        self.clip_limit = clip_limit
        self.grid_size = grid_size

    def forward(self, x):
        """Normalise ``x`` cell by cell.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W) with values in (0, 1). Rows/columns
            left over when H or W is not a multiple of the grid are covered by
            no cell and come out as 0.5.

        Raises:
            ValueError: If H or W is smaller than the grid, which would give a
                zero-sized cell.
        """
        B, C, H, W = x.shape
        kh, kw = H // self.grid_size[0], W // self.grid_size[1]
        if kh < 1 or kw < 1:
            raise ValueError(
                f"input of size {H}x{W} is smaller than grid_size={tuple(self.grid_size)}"
            )

        # Non-overlapping cells: kernel and stride are both the cell size.
        unfold = nn.Unfold(kernel_size=(kh, kw), stride=(kh, kw))
        fold = nn.Fold(output_size=(H, W), kernel_size=(kh, kw), stride=(kh, kw))

        # (B, C*kh*kw, L) -> (B, C, pixels-per-cell, L): stats run over dim 2.
        patches = unfold(x).view(B, C, kh * kw, -1)
        mean = patches.mean(dim=2, keepdim=True)
        std = patches.std(dim=2, keepdim=True, unbiased=False) + 1e-6

        # Cap the per-cell std relative to the average std of all cells.
        global_std = std.mean()
        std = torch.clamp(std, max=self.clip_limit * global_std)

        normed = ((patches - mean) / std).view(B, C * kh * kw, -1)
        out = fold(normed)

        # Coverage count per pixel; uncovered border pixels stay at zero.
        norm_map = fold(unfold(torch.ones_like(x)))
        out = out / (norm_map + 1e-6)

        return torch.sigmoid(out)

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_gaussian_kernel(kernel_size=5, sigma=1.0, channels=1):
    """Build a normalised 2D Gaussian kernel replicated for depthwise convolution.

    Args:
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian.
        channels: Number of copies of the kernel, one per channel.

    Returns:
        Contiguous tensor of shape (channels, 1, kernel_size, kernel_size)
        whose every (kernel_size, kernel_size) slice sums to 1.
    """
    # Integer sample positions relative to index kernel_size // 2.
    offsets = torch.arange(kernel_size) - kernel_size // 2
    profile = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    profile = profile / profile.sum()
    # Separable Gaussian: the 2D kernel is the outer product of the 1D profile.
    window = profile.view(-1, 1) * profile.view(1, -1)
    window = window / window.sum()
    # One identical copy per channel, shaped as a grouped-conv weight.
    return window.expand(channels, 1, kernel_size, kernel_size).contiguous()

class GaussianDenoiseLayer(nn.Module):
    """
    Fixed Gaussian blur applied to each channel independently.

    The kernel is stored as a buffer (not a parameter) and applied with a
    single grouped ``conv2d``; the original author describes the layer as
    ONNX-compatible.
    Args:
        kernel_size: Size of the Gaussian kernel (odd integer).
        sigma: Standard deviation of the Gaussian.
        channels: Number of input channels.
    """
    def __init__(self, kernel_size=5, sigma=1.0, channels=1):
        super().__init__()
        # One group per channel makes the convolution depthwise.
        self.groups = channels
        # Half the kernel size keeps H and W unchanged for odd kernels.
        self.padding = kernel_size // 2
        self.register_buffer('weight', get_gaussian_kernel(kernel_size, sigma, channels))

    def forward(self, x):
        """Blur ``x`` of shape (B, C, H, W) with the stored Gaussian kernel."""
        return F.conv2d(x, weight=self.weight, padding=self.padding, groups=self.groups)

In [None]:
#| export
class BiasFieldCorrectionLayer(nn.Module):
    """
    Learnable multiplicative bias-field correction.

    A low-resolution field is learned, upsampled bilinearly to the input size
    and divided out of the input; the original author describes the layer as
    ONNX-compatible.
    Args:
        in_channels: Number of input channels.
        reduction: Side length of the learnable low-resolution field, which has
            shape (1, in_channels, reduction, reduction).
    """
    def __init__(self, in_channels=1, reduction=16):
        super().__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        # Initialised to ones, so the untrained layer is (almost) an identity.
        field_shape = (1, in_channels, reduction, reduction)
        self.bias_field = nn.Parameter(torch.ones(*field_shape))

    def forward(self, x):
        """Divide ``x`` of shape (B, C, H, W) by the upsampled bias field."""
        _, _, height, width = x.shape
        field = F.interpolate(
            self.bias_field, size=(height, width), mode='bilinear', align_corners=False
        )
        # A small constant is added to the field (no clamping is performed).
        return x / (field + 1e-6)

In [None]:
# Image processing pipeline
# 1. Denoising - Removes random noise from images
#    Example: Noisy image [⚫⚪⚫⚪] → Denoised [⚫⚫⚫⚫]
#
# 2. Bias field correction - Fixes uneven illumination across the image
#    Example: Darker on left [⚫⚫⚪⚪] → Balanced [⚫⚫⚫⚫]
# 3. Normalization - Scales pixel values to a range of [0, 1]
#    Example: Pixel values [0, 255] → Normalized [0, 1]
# 4. CLAHE (Contrast Limited Adaptive Histogram Equalization) - Enhances local contrast
#    Example: Low contrast [⚫⚫⚪⚪] → Better contrast [⚫⚪⚫⚪]
# 5. Edge enhancement - Highlights boundaries between structures
#    Example: Blurry edge [⚫⚫⚪⚪] → Sharp edge [⚫⚪⚪⚪]
# 6. Tile - Divides large image into smaller overlapping patches
#    Example: Large image [⚫⚫⚫⚫|⚪⚪⚪⚪] → Tiles [⚫⚫⚫⚫], [⚪⚪⚪⚪]
# 7. Segmentation - Identifies and labels regions of interest
#    Example: Raw image [⚫⚫⚪⚪] → Segmented [1️⃣1️⃣2️⃣2️⃣]
# 8. Merge tiles - Combines segmented tiles back into full image
#    Example: Tiles [1️⃣1️⃣], [2️⃣2️⃣] → Merged [1️⃣1️⃣2️⃣2️⃣]
#
# 9. Post-processing - Refines segmentation with morphological operations
#    Example: Rough mask [1️⃣⚪1️⃣1️⃣] → Cleaned [1️⃣1️⃣1️⃣1️⃣]

In [None]:
#| export
class MorphologyLayer(nn.Module):
    """Morphological closing (dilation then erosion) built from max pooling."""
    def __init__(self, kernel_size=5):
        super().__init__()
        self.kernel_size = kernel_size

        # Frozen all-ones structuring elements; forward() does not read them,
        # but they are registered parameters and appear in the state dict.
        for name in ('dilation_kernel', 'erosion_kernel'):
            setattr(
                self,
                name,
                nn.Parameter(torch.ones(1, 1, kernel_size, kernel_size), requires_grad=False),
            )

    def forward(self, x):
        """Apply a closing with a square ``kernel_size`` window to ``x``."""
        pool_args = dict(
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
        )
        # Dilation: sliding-window maximum.
        grown = F.max_pool2d(x, **pool_args)
        # Erosion: sliding-window minimum, written as -max(-x).
        return -F.max_pool2d(-grown, **pool_args)


In [None]:
class ETPinSegmentationPipeline(nn.Module):
    """Tiled segmentation pipeline: preprocess, tile, segment, merge, refine.

    Only the sub-modules are built here; the forward pass is attached to the
    class separately in this notebook.

    NOTE(review): EfficientNetEncoder, UNetDecoder and TileMerger are defined
    in a later cell of this notebook, so that cell has to run before this
    class can be instantiated.

    Args:
        in_channels: Number of channels of the input image.
        tile_size: Side length of the square tiles fed to the encoder.
        overlap: Pixel overlap between neighbouring tiles.
    """
    def __init__(self, in_channels=1, tile_size=256, overlap=32):
        super().__init__()
        self.in_channels = in_channels

        # 1. Preprocessing
        self.normalize = nn.BatchNorm2d(in_channels, affine=False)
        self.clahe = CLAHELayer(clip_limit=4.0, grid_size=(8, 8))

        # Edge enhancement through a learnable filter, initialised to a scaled
        # Laplacian acting on each channel separately. The weight must have
        # shape (out_channels, in_channels, 3, 3).
        self.edge_enhance = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
        laplacian = torch.tensor(
            [[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]]
        ) * 0.1
        edge_weight = torch.zeros(in_channels, in_channels, 3, 3)
        for ch in range(in_channels):
            # Diagonal entries only: channel ch is filtered into channel ch.
            edge_weight[ch, ch] = laplacian
        with torch.no_grad():
            self.edge_enhance.weight.copy_(edge_weight)

        # 2. Tiling mechanism
        self.tile_processor = TileProcessor(tile_size=tile_size, overlap=overlap)

        # 3. Segmentation model
        self.encoder = EfficientNetEncoder(in_channels=in_channels)
        self.decoder = UNetDecoder()

        # 4. Post-processing
        self.tile_merger = TileMerger()
        self.boundary_refinement = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.refine_activation = nn.ReLU(inplace=True)
        self.final_conv = nn.Conv2d(16, 1, kernel_size=3, padding=1)
        self.morphology = MorphologyLayer(kernel_size=5)
        self.final_activation = nn.Sigmoid()

In [None]:
 @patch
 def forward(self:ETPinSegmentationPipeline, x):
        """Run the full pipeline on ``x`` of shape (B, C, H, W) and return the final mask.

        Steps: normalise, local-contrast layer, residual edge enhancement,
        tiling, per-tile encoder/decoder, tile merging, then convolutional and
        morphological refinement followed by a sigmoid.
        """
        _, _, H, W = x.shape

        # 1. Preprocessing, with the edge response added back as a residual.
        x = self.clahe(self.normalize(x))
        x = x + self.edge_enhance(x)

        # 2. Generate tiles
        tiles, positions = self.tile_processor(x)

        # 3. Segment the tiles in small chunks to bound peak memory.
        batch_size = 4  # Adjust based on GPU memory
        tile_masks = [
            self.decoder(self.encoder(tiles[start:start + batch_size]))
            for start in range(0, tiles.shape[0], batch_size)
        ]

        # 4. Merge tiles back to full image
        merged_mask = self.tile_merger(torch.cat(tile_masks, dim=0), positions, H, W)

        # 5. Post-processing
        refined = self.refine_activation(self.boundary_refinement(merged_mask))
        final_mask = self.morphology(self.final_conv(refined))
        return self.final_activation(final_mask)


In [None]:
# Quick look at what BatchNorm2d (without affine parameters) does to the dummy image.
in_channels = 1
normalize = nn.BatchNorm2d(in_channels, affine=False)
# The labels now match the values: min is printed as min, max as max.
print(f' img has a min value = {img.min()} and a max value = {img.max()}')
norm_img = normalize(img)
print(f' After normalizing img has a max of {norm_img.max()} and min value {norm_img.min()}')
print(norm_img.shape)
clahe_layer = CLAHELayer(clip_limit=4.0, grid_size=(8, 8))
morphology_layer = MorphologyLayer(kernel_size=5)

 img has a min value = 4.740786075592041 and a max value = -5.191678047180176
 After normalizing img has a max of 4.738675117492676 and min value -5.189996719360352
torch.Size([1, 1, 1152, 1632])


In [None]:
import torch
import torch.nn.functional as F

def sobel_filter():
    """Return the pair of 3x3 Sobel kernels as float32 tensors of shape (1, 1, 3, 3).

    Returns:
        ``(sobel_x, sobel_y)``: the first kernel differences along the width,
        the second along the height.
    """
    horizontal = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
    vertical = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
    sobel_x, sobel_y = (
        torch.tensor(rows, dtype=torch.float32).reshape(1, 1, 3, 3)
        for rows in (horizontal, vertical)
    )
    return sobel_x, sobel_y

def edge_loss(pred, target):
    """L1 distance between the Sobel gradient magnitudes of ``pred`` and ``target``.

    Args:
        pred: Tensor of shape (B, 1, H, W), values in [0, 1].
        target: Tensor of shape (B, 1, H, W); cast to ``pred``'s dtype so that
            integer or boolean masks can be passed directly.

    Returns:
        Scalar tensor with the mean absolute difference of the two gradient
        magnitude maps.
    """
    sobel_x, sobel_y = sobel_filter()
    # Match the kernels to the prediction so half/double inputs and
    # non-CPU devices work with conv2d.
    sobel_x = sobel_x.to(device=pred.device, dtype=pred.dtype)
    sobel_y = sobel_y.to(device=pred.device, dtype=pred.dtype)
    target = target.to(dtype=pred.dtype)

    def _gradient_magnitude(image):
        # One-line purpose: Sobel magnitude with a small constant under the
        # square root so the gradient stays finite where the image is flat.
        grad_x = F.conv2d(image, sobel_x, padding=1)
        grad_y = F.conv2d(image, sobel_y, padding=1)
        return torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)

    return F.l1_loss(_gradient_magnitude(pred), _gradient_magnitude(target))

In [None]:
def total_variation_loss(pred):
    """Mean absolute difference between neighbouring pixels, summed over both axes.

    Args:
        pred: Tensor of shape (B, 1, H, W).

    Returns:
        Scalar tensor: mean vertical step size plus mean horizontal step size.
    """
    vertical_steps = pred.diff(dim=2).abs()
    horizontal_steps = pred.diff(dim=3).abs()
    return vertical_steps.mean() + horizontal_steps.mean()

In [None]:
def combined_loss(pred, target, alpha=1.0, beta=0.1, gamma=0.05):
    """Weighted sum of BCE-with-logits, edge and total-variation losses.

    Args:
        pred: Tensor of shape (B, 1, H, W). It is passed to
            ``binary_cross_entropy_with_logits`` and through a sigmoid, so it
            is treated as logits.
        target: Tensor of shape (B, 1, H, W), binary mask.
        alpha: Weight of the segmentation (BCE) term.
        beta: Weight of the edge term.
        gamma: Weight of the total-variation term.

    Returns:
        Scalar tensor with the combined loss.
    """
    probabilities = torch.sigmoid(pred)
    weighted_bce = alpha * F.binary_cross_entropy_with_logits(pred, target)
    weighted_edge = beta * edge_loss(probabilities, target)
    weighted_tv = gamma * total_variation_loss(probabilities)
    # TODO: add focal loss + dice loss, or only focal loss.
    return weighted_bce + weighted_edge + weighted_tv

In [None]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader
from efficientnet_pytorch import EfficientNet
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from tqdm import tqdm
import onnx
import onnxruntime as ort

# =============================================
# 1. MODEL ARCHITECTURE
# =============================================

class ConvBlock(nn.Module):
    """Two consecutive 3x3 conv -> BatchNorm -> ReLU stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both stages.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in -> out channels, the second out -> out.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages; padding=1 keeps H and W unchanged."""
        return self.conv(x)

class AttentionGate(nn.Module):
    """Additive attention gate that rescales skip features with a gating signal.

    Args:
        F_g: Channels of the gating tensor ``g``.
        F_l: Channels of the skip tensor ``x``.
        F_int: Channels of the intermediate projection shared by both.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch, *tail):
            # 1x1 convolution followed by BatchNorm, plus optional extra layers.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
                *tail,
            )

        self.W_g = pointwise(F_g, F_int)
        self.W_x = pointwise(F_l, F_int)
        self.psi = pointwise(F_int, 1, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by a single-channel attention map in (0, 1).

        ``g`` and ``x`` need matching batch and spatial sizes, since their
        projections are added element-wise.
        """
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention

class EfficientNetEncoder(nn.Module):
    """Feature extractor built on a pre-trained EfficientNet-B0.

    Args:
        in_channels: Channels of the input image. The pre-trained stem is
            replaced by a freshly initialised convolution with this many input
            channels, so the stem weights are not pre-trained.
    """
    def __init__(self, in_channels=1):
        super().__init__()
        # Load pre-trained EfficientNet-B0
        self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')

        # Modify first layer to accept grayscale input. padding=1 makes the
        # stride-2 3x3 convolution halve even input sizes exactly.
        self.efficient_net._conv_stem = nn.Conv2d(
            in_channels, 32, kernel_size=3, stride=2, padding=1, bias=False
        )

        # Block indices after which the activation is collected.
        # NOTE(review): EfficientNet-B0 in efficientnet_pytorch appears to have
        # 16 blocks (indices 0-15); if so, index 16 never matches and forward()
        # returns 4 tensors, while UNetDecoder reads features[4] - confirm.
        self.feature_indices = [3, 5, 9, 16]

    def forward(self, x):
        """Return the list of intermediate feature maps followed by the head output.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            List with one tensor per matched index in ``feature_indices`` plus
            the output of the network head as the last element.
        """
        net = self.efficient_net

        # Run the stem first: the blocks expect its 32-channel output, not the
        # raw image.
        x = net._swish(net._bn0(net._conv_stem(x)))

        features = []
        for idx, layer in enumerate(net._blocks):
            x = layer(x)
            if idx in self.feature_indices:
                features.append(x)

        # Add the final features
        # NOTE(review): the head output presumably has more channels than the
        # 320 that UNetDecoder's default `channels[0]` expects - verify.
        x = net._conv_head(x)
        x = net._bn1(x)
        x = net._swish(x)
        features.append(x)

        return features

class UNetDecoder(nn.Module):
    """U-Net style decoder with attention-gated skip connections.

    Args:
        channels: Channel counts from the deepest feature map to the
            shallowest skip connection that is used.
    """
    def __init__(self, channels=[320, 112, 40, 24]):
        super().__init__()
        self.channels = channels
        halves = [count // 2 for count in channels]

        # Upsampling blocks: each doubles H and W and halves the channel count.
        for idx in range(1, 5):
            setattr(
                self,
                f'up_conv{idx}',
                nn.ConvTranspose2d(channels[idx - 1], halves[idx - 1], kernel_size=2, stride=2),
            )

        # Attention gates: the upsampled tensor gates the next skip connection.
        for idx in range(1, 4):
            setattr(
                self,
                f'attention{idx}',
                AttentionGate(F_g=halves[idx - 1], F_l=channels[idx], F_int=halves[idx]),
            )

        # Convolutional blocks applied after concatenating upsampled + skip.
        for idx in range(1, 4):
            setattr(self, f'conv{idx}', ConvBlock(halves[idx - 1] + channels[idx], channels[idx]))
        self.conv4 = ConvBlock(halves[3], halves[3])

        # Final convolution
        self.final_conv = nn.Conv2d(halves[3], 1, kernel_size=1)

    def forward(self, features):
        """Decode ``features`` into a single-channel map.

        ``features[4]`` is the starting tensor and ``features[3]``,
        ``features[2]``, ``features[1]`` are the skip connections, in that
        order; ``features[0]`` is not read.
        """
        stages = (
            (self.up_conv1, self.attention1, self.conv1, features[3]),
            (self.up_conv2, self.attention2, self.conv2, features[2]),
            (self.up_conv3, self.attention3, self.conv3, features[1]),
        )

        x = features[4]
        for upsample, gate, fuse, skip in stages:
            x = upsample(x)
            # Concatenate the upsampled tensor with the gated skip features.
            x = fuse(torch.cat([x, gate(x, skip)], dim=1))

        # Last upsampling stage has no skip connection.
        x = self.conv4(self.up_conv4(x))

        # Final 1x1 convolution to get segmentation map
        return self.final_conv(x)

class TileProcessor(nn.Module):
    """Cut a batch of images into square, overlapping tiles.

    NOTE(review): a class with this name is also defined in an exported cell
    earlier in this notebook; whichever cell runs last wins in the kernel.

    Args:
        tile_size: Side length of every tile in pixels.
        overlap: Pixels shared by two neighbouring tiles; the tiling stride is
            ``tile_size - overlap``.

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is not in
            ``[0, tile_size)`` (the stride would be zero or negative).
    """
    def __init__(self, tile_size=256, overlap=32):
        super().__init__()
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        if not 0 <= overlap < tile_size:
            raise ValueError(
                "overlap must satisfy 0 <= overlap < tile_size, "
                f"got overlap={overlap}, tile_size={tile_size}"
            )
        self.tile_size = tile_size
        self.overlap = overlap

    def forward(self, x):
        """Split image into overlapping tiles and return tiles with their positions.

        Args:
            x: Tensor of shape (B, C, H, W) with H and W at least ``tile_size``.

        Returns:
            A ``(tiles, positions)`` pair. ``tiles`` is the concatenation along
            dim 0 of every (B, C, tile_size, tile_size) crop, so for B > 1 the
            k-th position describes rows ``k*B:(k+1)*B`` of ``tiles``.
            ``positions`` is an integer tensor of shape (num_tiles, 4) holding
            ``[h_start, w_start, h_end, w_end]`` per crop, on ``x``'s device.

        Raises:
            ValueError: If the image is smaller than one tile in either
                dimension (start offsets would become negative).
        """
        _, _, H, W = x.shape
        tile_size = self.tile_size
        overlap = self.overlap
        stride = tile_size - overlap

        if H < tile_size or W < tile_size:
            raise ValueError(
                f"input of size {H}x{W} is smaller than tile_size={tile_size}"
            )

        # Regular grid: as many full strides as fit inside the image.
        h_steps = max(1, (H - overlap) // stride)
        w_steps = max(1, (W - overlap) // stride)
        row_starts = [i * stride for i in range(h_steps)]
        col_starts = [j * stride for j in range(w_steps)]

        # A remainder at the bottom/right is covered by one extra row/column of
        # tiles aligned flush with the image border.
        needs_last_row = h_steps * stride + overlap < H
        needs_last_col = w_steps * stride + overlap < W

        # Order is: regular grid (row-major), extra bottom row, extra right
        # column, then the bottom-right corner tile.
        starts = [(h, w) for h in row_starts for w in col_starts]
        if needs_last_row:
            starts += [(H - tile_size, w) for w in col_starts]
        if needs_last_col:
            starts += [(h, W - tile_size) for h in row_starts]
        if needs_last_row and needs_last_col:
            starts.append((H - tile_size, W - tile_size))

        # Stack tiles into a batch
        tiles = torch.cat(
            [x[:, :, h:h + tile_size, w:w + tile_size] for h, w in starts],
            dim=0,
        )
        positions_tensor = torch.tensor(
            [[h, w, h + tile_size, w + tile_size] for h, w in starts],
            device=x.device,
        )

        return tiles, positions_tensor

class TileMerger(nn.Module):
    """Blend processed tiles back into one full-size image.

    Overlapping regions are averaged with a separable cosine window, so every
    tile contributes most at its centre and least at its borders.
    """
    def __init__(self):
        super().__init__()

    @staticmethod
    def _cosine_window(length, device, dtype):
        """Return a 1D cosine window of ``length`` samples peaking at the centre."""
        import math

        idx = torch.arange(length, device=device, dtype=dtype)
        # Half-sample offset keeps both end samples strictly positive, so
        # pixels covered by a single tile are still normalised correctly.
        return torch.cos(math.pi * (idx - (length - 1) / 2) / length)

    def forward(self, tiles, positions, height, width):
        """Merge processed tiles back into a full image.

        Args:
            tiles: Tensor of shape (N, C, tile_h, tile_w).
            positions: N rows of ``[h_start, w_start, h_end, w_end]`` as a
                tensor or a nested sequence; row i places ``tiles[i]``.
            height: Height of the merged image.
            width: Width of the merged image.

        Returns:
            Tensor of shape (1, C, height, width). Pixels covered by no tile
            are zero.

        Raises:
            ValueError: If the number of positions differs from the number of
                tiles.
        """
        device = tiles.device
        dtype = tiles.dtype
        num_tiles, C, tile_h, tile_w = tiles.shape

        coords = positions.tolist() if torch.is_tensor(positions) else list(positions)
        if len(coords) != num_tiles:
            raise ValueError(
                f"got {num_tiles} tiles but {len(coords)} positions"
            )

        # Create empty output and weight matrices
        output = torch.zeros((1, C, height, width), device=device, dtype=dtype)
        weights = torch.zeros_like(output)

        # Separable 2D blending window, shared by all channels.
        win_h = self._cosine_window(tile_h, device, dtype)
        win_w = self._cosine_window(tile_w, device, dtype)
        weight_map = (win_h[:, None] * win_w[None, :]).expand(1, C, tile_h, tile_w)

        # Place tiles back into the full image with weighted blending
        for i, (h_start, w_start, h_end, w_end) in enumerate(coords):
            h_start, w_start, h_end, w_end = int(h_start), int(w_start), int(h_end), int(w_end)
            output[:, :, h_start:h_end, w_start:w_end] += tiles[i:i + 1] * weight_map
            weights[:, :, h_start:h_end, w_start:w_end] += weight_map

        # Normalize by the weights to get the final output
        output = output / (weights + 1e-8)

        return output

class CLAHELayer(nn.Module):
    """Differentiable stand-in for CLAHE based on per-cell contrast normalisation.

    The image is split into a grid of non-overlapping cells, every cell is
    standardised with its own mean and standard deviation, and the result is
    added to the input and squashed with a sigmoid.

    NOTE(review): a class with this name is also defined in an exported cell
    earlier in this notebook; whichever cell runs last wins in the kernel.

    Args:
        clip_limit: Per-cell std values are capped at ``clip_limit`` times the
            mean std over all cells.
        grid_size: ``(rows, cols)`` of the cell grid.
    """
    def __init__(self, clip_limit=4.0, grid_size=(8, 8)):
        super().__init__()
        self.clip_limit = clip_limit
        self.grid_size = grid_size

    def forward(self, x):
        """Normalise ``x`` cell by cell and add the result to ``x``.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W) with values in (0, 1).

        Raises:
            ValueError: If H or W is smaller than the grid, which would give a
                zero-sized cell.
        """
        B, C, H, W = x.shape
        kh, kw = H // self.grid_size[0], W // self.grid_size[1]
        if kh < 1 or kw < 1:
            raise ValueError(
                f"input of size {H}x{W} is smaller than grid_size={tuple(self.grid_size)}"
            )

        # Non-overlapping cells: kernel and stride are both the cell size.
        unfold = nn.Unfold(kernel_size=(kh, kw), stride=(kh, kw))
        fold = nn.Fold(output_size=(H, W), kernel_size=(kh, kw), stride=(kh, kw))

        # Unfold yields (B, C*kh*kw, L); split out the pixels-per-cell axis so
        # the statistics are taken within each cell (dim 2), not across cells.
        patches = unfold(x).reshape(B, C, kh * kw, -1)

        # Local mean and std. The biased estimator avoids NaN for 1-pixel cells.
        mean = patches.mean(dim=2, keepdim=True)
        std = patches.std(dim=2, keepdim=True, unbiased=False) + 1e-6

        # Clip contrast (approximation of clip limit)
        std = torch.clamp(std, max=self.clip_limit * torch.mean(std))

        # Normalize patches and restore the (B, C*kh*kw, L) layout Fold expects.
        normalized = ((patches - mean) / std).reshape(B, C * kh * kw, -1)
        output = fold(normalized)

        # Scale to proper range and add back to original for residual effect
        return torch.sigmoid(output + x)

ModuleNotFoundError: No module named 'efficientnet_pytorch'

In [None]:
#| hide
import nbdev; nbdev.nbdev_export('21_patching.first_patching.ipynb')