# How we create patches from tiles

In [1]:
import lovely_tensors as lt
import matplotlib.patches as mpl_patches
import matplotlib.pyplot as plt
import torch
from darts_segmentation.utils import patch_coords, predict_in_patches

In [2]:
# Toy parameters: a small tile keeps the patch grid easy to inspect by eye
h = w = 30
patch_size = 8
overlap = 3

# Three example tiles of shape (1, h, w) with values in [0, 0.2), created directly as a torch tensor
tensor_tiles = 0.2 * torch.rand(3, 1, h, w)

## Patching

In [None]:
# Draw the first tile and overlay every patch yielded by patch_coords as a labelled, coloured square
fig, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.imshow(tensor_tiles[0, 0], vmin=0, vmax=1, cmap="gray")
colors = ["red", "orange", "grey", "brown", "yellow", "purple", "teal"]
for n, coords in enumerate(patch_coords(h, w, patch_size, overlap)):
    # patch_coords yields (y, x, patch_idx_y, patch_idx_x) per patch
    y, x, patch_idx_y, patch_idx_x = coords
    color = colors[n % len(colors)]
    # imshow centres pixel i at coordinate i, so shift by half a pixel to align with pixel borders
    ax.add_patch(
        mpl_patches.Rectangle(
            (x - 0.5, y - 0.5),
            width=patch_size,
            height=patch_size,
            linewidth=3,
            edgecolor=color,
            facecolor=color,
            alpha=0.5,
        )
    )
    label = f"{n}: {patch_idx_x}-{patch_idx_y} ({x}-{y})"
    ax.text(x, y, label, bbox={"facecolor": "white"})

## Blending weights in overlapping regions

In [None]:
# Larger example parameters: a big tile split into 1024 px patches with 128 px overlap
h, w = 8000, 8000
patch_size = 1024
overlap = 128

# Create example tiles (already as a torch tensor) with values in [0, 0.2).
# Scale in place: one (3, 1, 8000, 8000) float32 tensor is ~768 MB, so the
# out-of-place `torch.rand(...) * 0.2` would briefly hold two of them in memory.
tensor_tiles = torch.rand((3, 1, h, w)).mul_(0.2)


def mock_model(x: torch.Tensor) -> torch.Tensor:
    """Stand-in for a real model: multiply the input element-wise by 3.

    Its output is trivially computable on the full tiles, which lets the
    patched prediction be compared against a closed-form expectation.
    """
    factor = 3
    return x * factor


# Run the mock model patch by patch over the tiles and also request the blending weights
res, weights = predict_in_patches(
    mock_model,
    tensor_tiles,
    patch_size,
    overlap,
    batch_size=1,
    device="cpu",
    return_weights=True,
)
# Reference computed on the full tiles at once; squeeze(1) drops the size-1 dim 1 of tensor_tiles.
# NOTE(review): assumes predict_in_patches applies a sigmoid to the model output — confirm.
expected = torch.sigmoid(tensor_tiles * 3).squeeze(1)
diff = (res - expected).abs()

for name, value in (("expected", expected), ("res", res), ("diff", diff), ("weights", weights)):
    print(f"{name: <20}{lt.lovely(value)}")


In [None]:
# Side-by-side view of the patched prediction, the full-tile reference and their absolute difference
fig, axs = plt.subplots(1, 3, figsize=(20, 10))
axs[0].imshow(res[0], vmin=0, vmax=1, cmap="gray")
axs[0].set_title("Result")
axs[1].imshow(expected[0], vmin=0, vmax=1, cmap="gray")
# This panel shows the reference output `expected`, not the raw input tile
axs[1].set_title("Expected")
# No fixed vmin/vmax here so the colour scale adapts to the error magnitude
im = axs[2].imshow(diff[0], cmap="gray")
axs[2].set_title("Difference")
# Attach the colorbar explicitly to the difference panel it describes
fig.colorbar(im, ax=axs[2])

In [None]:
# Difference map with the colour scale clamped to [0, 1e-8]
diff_img = plt.imshow(diff[0], cmap="viridis", vmin=0, vmax=1e-8)
plt.colorbar(diff_img)

In [None]:
# NOTE(review): this cell repeats the preceding cell verbatim; likely a leftover duplicate
plt.colorbar(plt.imshow(diff[0], cmap="viridis", vmin=0, vmax=1e-8))

In [None]:
# Build a separable soft-margin window for one patch: a linear 0->1 ramp over
# `overlap` pixels, a plateau of ones, then a linear 1->0 ramp.
# NOTE(review): presumably mirrors the weighting inside predict_in_patches — not verified here.
ramp_up = torch.linspace(0, 1, overlap)
plateau = torch.ones(patch_size - 2 * overlap)
ramp_down = torch.linspace(1, 0, overlap)
margin_ramp = torch.cat([ramp_up, plateau, ramp_down])

# Outer product of the 1-D ramp with itself gives the 2-D window, shaped (1, patch_size, patch_size)
soft_margin = torch.outer(margin_ramp, margin_ramp).unsqueeze(0)
plt.imshow(soft_margin[0], cmap="gray")
plt.title("Soft margin")
plt.colorbar()

In [None]:
# Show the first entry of the weights returned by predict_in_patches, with a colour scale
weights_img = plt.imshow(weights[0], cmap="hot")
plt.colorbar(weights_img)