In [None]:
# Read in an image
import numpy as np
from PIL import Image

# You'll need to put an image here; this example uses a picture of apples.
IMAGE_PATH = "../../assets/test_imgs/apples.png"

# convert("L") produces a single-channel greyscale image and reads the pixel
# data, so the file can be closed as soon as the `with` block exits.
with Image.open(IMAGE_PATH) as pil_img:
    img = np.asarray(pil_img.convert("L"))

In [None]:
# First let's try a left-to-right join: take two vertical strips of the image
# (full height, `overlap_cols` columns wide) to stand in for the overlapping
# regions of two patches.
overlap_cols = 10  # width of each overlap strip, in columns
img1 = img[:, 10 : 10 + overlap_cols]
img2 = img[:, 20 : 20 + overlap_cols]

# Squared pixel difference between the strips. Cast to float first: `img` is
# uint8, so subtracting (and squaring) in uint8 would wrap around modulo 256
# (e.g. 3 - 5 -> 254) instead of producing the true difference.
cost = (img1.astype(np.float64) - img2.astype(np.float64)) ** 2

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(2, 6))

# Show both overlap strips next to the cost array computed from them, each
# panel titled and without axis ticks.
for axis, panel, title in zip(axes, (img1, img2, cost), ("img 1", "img 2", "cost")):
    axis.imshow(panel, cmap="gray")
    axis.set_title(title)
    axis.axis("off")

In [None]:
# Find a path through the cost array
from current_denoising.generation.quilting import seam_nodes

import numpy as np
from scipy.ndimage import binary_fill_holes
from skimage.segmentation import flood


def create_seam_mask(
    seam_coords: list[tuple[int, int]], shape: tuple[int, int]
) -> np.ndarray:
    """Rasterise a seam into a boolean mask.

    Parameters
    ----------
    seam_coords
        Seam pixels as ``(row, col)`` pairs. This matches how the plotting
        cells in this notebook unpack a seam (``y, x = zip(*seam)``).
        NOTE(review): confirm this is the order ``seam_nodes`` returns.
    shape
        ``(rows, cols)`` of the mask to build.

    Returns
    -------
    np.ndarray
        Boolean array of ``shape`` that is True exactly at the seam pixels.
    """
    mask = np.zeros(shape, dtype=bool)
    for row, col in seam_coords:
        mask[row, col] = True
    return mask


def stitch_using_seam(
    image1: np.ndarray, image2: np.ndarray, seam_coords: list[tuple[int, int]]
) -> np.ndarray:
    """Combine two equally shaped overlap patches along a seam.

    The seam pixels act as a barrier. Every non-seam pixel that can be
    reached from the top-left corner ``(0, 0)`` through edge-adjacent steps
    without crossing the seam is taken from ``image1``. All remaining pixels,
    including the seam itself, are taken from ``image2``.

    Parameters
    ----------
    image1, image2
        Overlap patches of identical shape. The seam mask is built over the
        first two axes, so any trailing (e.g. channel) axes are carried along.
    seam_coords
        Seam pixel coordinates, rasterised with ``create_seam_mask``.

    Returns
    -------
    np.ndarray
        Stitched patch with the shape and dtype of ``image1``.

    Raises
    ------
    ValueError
        If ``image1`` and ``image2`` do not have the same shape.
    """
    if image1.shape != image2.shape:
        raise ValueError(
            "image1 and image2 must have the same shape, "
            f"got {image1.shape} and {image2.shape}"
        )

    h, w = image1.shape[:2]
    seam_mask = create_seam_mask(seam_coords, (h, w))

    # Flood-fill the non-seam area starting from the top-left corner.
    # connectivity=1 (edge neighbours only) is essential. skimage's default is
    # full connectivity, which would let the fill step diagonally between two
    # seam pixels that only touch at a corner, leaking onto image2's side.
    from_image1 = flood(~seam_mask, (0, 0), connectivity=1)

    # flood() selects pixels equal to the seed value. If (0, 0) lies on the
    # seam, that means it returns seam pixels, so remove them explicitly.
    # NOTE(review): in that case image1 contributes nothing; a different seed
    # would be needed for seams that pass through the top-left corner.
    from_image1 &= ~seam_mask

    stitched = np.zeros_like(image1)
    stitched[from_image1] = image1[from_image1]
    stitched[~from_image1] = image2[~from_image1]
    return stitched


# Vertical seam for the left-to-right join: the path presumably runs between
# the bottom and top edges of `cost`, separating the left and right parts.
seam = seam_nodes(cost, "bottom", "top")
stitched = stitch_using_seam(image1=img1, image2=img2, seam_coords=seam)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(3, 9))

# Title every panel and hide ticks on all of them, including the stitched result.
for axis, panel, title in zip(
    axes, (img1, img2, stitched), ("img 1", "img 2", "stitched")
):
    axis.imshow(panel, cmap="gray")
    axis.set_title(title)
    axis.axis("off")

# The seam is unpacked as (row, col) pairs; plot() takes (x, y) = (col, row),
# so the order is swapped. Overlay it on the two source patches only.
y, x = zip(*seam)
for axis in axes[:2]:
    axis.plot(x, y, "r")

In [None]:
# We might also want to stitch images vertically (one above the other).
# That needs a seam running between the left and right edges of the overlap.
# NOTE: this reuses the cost array from the left-to-right strips above,
# purely to demonstrate a horizontal seam.
seam = seam_nodes(cost, "left", "right")
stitched = stitch_using_seam(img1, img2, seam)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(3, 9))

# Title every panel and hide ticks on all of them, including the stitched result.
for axis, panel, title in zip(
    axes, (img1, img2, stitched), ("img 1", "img 2", "stitched")
):
    axis.imshow(panel, cmap="gray")
    axis.set_title(title)
    axis.axis("off")

# The seam is unpacked as (row, col) pairs; plot() takes (x, y) = (col, row),
# so the order is swapped. Overlay it on the two source patches only.
y, x = zip(*seam)
for axis in axes[:2]:
    axis.plot(x, y, "r")