# `Compose`
----

## `Compose Multiple`

In [4]:
from typing import Sequence
import numpy as np
from numpy.typing import NDArray
from PIL import Image

from typing import TypeVar  # the cell's typing import above only brings in Sequence

# Scalar type variable for the element type of the transformation matrices
# accepted by `_compose_all` (Python floats or NumPy floating scalars).
F = TypeVar("F", float, np.floating)

def _compose_all(frame: Image.Image, images: Sequence[Image.Image], transformations: Sequence[NDArray[F]]) -> Image.Image:
    """
    Composites multiple images onto a base frame using affine transformations.

    Each image is warped into the frame's pixel grid with its 3x3 matrix and then
    pasted over the frame wherever the warped image has a non-zero pixel.

    Parameters:
    -----------
    frame : Image.Image
        The base image onto which the other images will be composited.
    images : Sequence[Image.Image]
        Images to warp and composite, in order. Each is converted to ``frame.mode``
        first, because ``Image.composite`` requires both inputs to share a mode.
    transformations : Sequence[NDArray[F]]
        One (3, 3) matrix per image, mapping image coordinates to frame coordinates.
        Only the top two rows of its inverse are handed to PIL, so the bottom row is
        effectively treated as (0, 0, 1).

    Returns:
    --------
    Image.Image
        The final composed image with all transformed images applied.

    Raises:
    -------
    ValueError
        If ``images`` and ``transformations`` differ in length, a matrix is not
        (3, 3), or a matrix is singular.

    Notes:
    ------
    - Uses `Image.AFFINE` transformation mode with `Image.BICUBIC` interpolation.
    - The compositing order follows the sequence order; later images overwrite
      earlier ones where their non-zero pixels overlap.
    - Pure-black (all-zero) pixels of a warped image are treated as transparent.

    Complexity:
    -----------
    - O(N · W · H) for N images composited onto a W × H frame.
    """
    # zip() would silently drop the surplus of the longer sequence.
    if len(images) != len(transformations):
        raise ValueError("images and transformations must have the same length.")

    width, height = frame.size

    for image, transformation in zip(images, transformations):
        # float64 keeps the inverse accurate for large translations.
        transformation = np.asarray(transformation, dtype=np.float64)

        if transformation.shape != (3, 3):
            raise ValueError("Each transformation matrix must have shape (3, 3).")

        # PIL's transform() maps *output* pixels back to *input* pixels, so it
        # needs the inverse of the forward image->frame matrix.
        try:
            t = np.linalg.inv(transformation)
        except np.linalg.LinAlgError as err:
            raise ValueError("Transformation matrix must be invertible.") from err

        # Affine coefficients (a, b, c, d, e, f) = top two rows of the inverse.
        coefficients = tuple(float(value) for value in t[:2].ravel())

        if image.mode != frame.mode:
            image = image.convert(frame.mode)

        transformed_image = image.transform(
            (width, height), Image.AFFINE, coefficients, Image.BICUBIC
        )

        # A pixel belongs to the warped image if any of its bands is non-zero.
        # Reshaping to (H, W, bands) handles single-band (2-D) images too.
        pixels = np.asarray(transformed_image).reshape(height, width, -1)
        covered = pixels.any(axis=-1)
        mask = Image.fromarray(covered.astype(np.uint8) * 255).convert("1")

        frame = Image.composite(transformed_image, frame, mask)

    return frame


## `Compose [unstable]`

In [None]:
def compose(source: np.ndarray, target: np.ndarray, transform: np.ndarray) -> np.ndarray:
    """
    Per-pixel reference implementation: warp ``source`` into a copy of ``target``.

    Backward mapping: every target pixel (x, y) is lifted to homogeneous
    coordinates (x, y, 1), multiplied by the inverse of ``transform`` and divided
    by its third coordinate to find where it came from in ``source``. If that
    position lies inside ``source``, the target pixel is replaced by a bilinear
    sample from ``interpolate``; otherwise it keeps its original value.

    This loops over pixels in Python and is slow on large images. The "Warp"
    section of this notebook defines ``compose`` again and overrides this one.

    Parameters
    ----------
    source : np.ndarray
        (H, W, C) image to warp.
    target : np.ndarray
        (H', W', C) image painted onto; it is copied, not modified.
    transform : np.ndarray
        Invertible 3x3 matrix mapping source coordinates to target coordinates.

    Returns
    -------
    np.ndarray
        Copy of ``target`` with the warped ``source`` drawn on top.
    """
    # dimension grabbing
    h, w, _ = source.shape
    target_h, target_w, _ = target.shape

    result = np.copy(target)

    # invert transformation matrix for backward mapping
    transform_inv = np.linalg.inv(transform)

    for y in range(target_h):
        for x in range(target_w):
            # Backward-map the homogeneous target pixel (x, y, 1) into source
            # space; ravel() flattens the product to a 1-D array of 3 values.
            mapped = np.asarray(transform_inv @ np.array([x, y, 1])).ravel()

            # A zero third coordinate is a point at infinity: no source pixel.
            if mapped[2] == 0:
                continue

            # homogeneous divide to get source coordinates
            x_source = mapped[0] / mapped[2]
            y_source = mapped[1] / mapped[2]

            # Pixels whose pre-image falls outside the source keep the target value.
            # Otherwise sample the source at the (fractional) pre-image position
            # and write that colour into the target pixel.
            if 0 <= x_source < w and 0 <= y_source < h:
                result[y, x] = interpolate(source, x_source, y_source)

    return result

# `Stitching`
----

## `1. SIFT`

### `Key Points`

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Load the three overlapping views; OpenCV returns them as BGR arrays.
left = cv2.imread("im1_left.jpg")
center = cv2.imread("im1_center.jpg")
right = cv2.imread("im1_right.jpg")

# Preview side by side. Reversing the last axis turns BGR into RGB for matplotlib.
fig, axes = plt.subplots(1, 3, figsize=(10, 10))
for _panel, _view in zip(axes, (left, center, right)):
    _panel.imshow(_view[:, :, ::-1])
plt.show()

In [None]:
def getKeypoints(image):
    """
    Detect SIFT keypoints on ``image`` and compute their descriptors.

    ``image`` is converted with ``cv2.COLOR_BGR2GRAY`` first, so a BGR colour
    image is expected. SIFT runs on the whole grayscale image (no mask).

    Returns
    -------
    tuple
        ``(keypoints, descriptors, gray)``:
        - each keypoint carries its (x, y) location, scale, orientation, response
          (feature strength) and octave (scale-space layer it was found in);
        - each descriptor row is a 128-element vector describing the neighbourhood
          of the matching keypoint (OpenCV returns ``None`` when nothing is found);
        - ``gray`` is the grayscale image SIFT was run on.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    detector = cv2.SIFT_create()

    found_keypoints, found_descriptors = detector.detectAndCompute(image = grayscale, mask = None)
    return found_keypoints, found_descriptors, grayscale



def draw_kp(images, titles):
    """
    Show the SIFT keypoints of each image in a single row of subplots.

    Keypoints are drawn in "rich" mode (circle size = scale, line = orientation)
    on top of the grayscale version of each image. ``titles`` labels the panels
    and is paired with ``images`` element-wise.
    """
    plt.figure(figsize = (10, 10))

    for column, (img, label) in enumerate(zip(images, titles), start = 1):
        keypoints, _descriptors, grayscale = getKeypoints(img)

        overlay = cv2.drawKeypoints(
            image = grayscale,
            keypoints = keypoints,
            outImage = None,
            flags = cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
        )

        # One row, one column per image, 1-based panel index.
        plt.subplot(1, len(images), column)
        plt.imshow(overlay, cmap = 'gray')
        plt.title(f'{label}')

    plt.show()




# Visualise the SIFT keypoints detected in each of the three views.
images, titles = [left, center, right], ['left', 'center', 'right']
draw_kp(images, titles)

### `Matches`

In [None]:
def getMatches(query_des, train_des, ratio = .75):
    """
    Match descriptors with a brute-force L2 k-nearest-neighbour search (k = 2)
    and keep only matches that pass Lowe's ratio test.

    Parameters
    ----------
    query_des, train_des :
        Descriptor arrays as returned by ``getKeypoints``; either may be ``None``
        when an image produced no keypoints.
    ratio : float
        A best match is kept only if its distance is below ``ratio`` times the
        distance of the second-best candidate.

    Returns
    -------
    list
        The best match of every query descriptor that passes the ratio test.
        Empty if either descriptor set is missing.
    """
    if query_des is None or train_des is None:
        return []

    # init brute force matcher
    bf = cv2.BFMatcher(normType = cv2.NORM_L2, crossCheck = False)
    knn_matches = bf.knnMatch(queryDescriptors = query_des, trainDescriptors = train_des, k = 2)

    # knnMatch can return fewer than k candidates for a query (e.g. when the
    # train set has fewer than k descriptors). Such queries cannot be
    # ratio-tested, so skip them instead of failing on tuple unpacking.
    good = []
    for candidates in knn_matches:
        if len(candidates) < 2:
            continue
        best, second = candidates[0], candidates[1]
        if best.distance < ratio * second.distance:
            good.append(best)
    return good

def draw_matches(query_image, kpq, train_image, kpt, matches, title):
    """
    Display ``matches`` between two images placed side by side.

    Keypoints that take part in no match are not drawn. The OpenCV rendering is
    BGR, so its channel axis is reversed before showing it with matplotlib.
    """
    side_by_side = cv2.drawMatches(
        img1 = query_image,
        keypoints1 = kpq,
        img2 = train_image,
        keypoints2 = kpt,
        matches1to2 = matches,
        outImg = None,
        flags = cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )
    as_rgb = side_by_side[..., ::-1]

    plt.figure(figsize = (5, 5))
    plt.imshow(as_rgb)
    plt.title(title)
    plt.show()


# SIFT keypoints and descriptors for every view (the grayscale copies are unused).
(kp_l, des_l, _), (kp_c, des_c, _), (kp_r, des_r, _) = (
    getKeypoints(view) for view in (left, center, right)
)

# Ratio-test matches of each side view (query) against the center view (train).
matches_lc = getMatches(des_l, des_c)
matches_rc = getMatches(des_r, des_c)

# Visualise both pairings.
draw_matches(left, kp_l, center, kp_c, matches_lc, 'Left-Center')
draw_matches(right, kp_r, center, kp_c, matches_rc, 'Right-Center')

### `Point Class`

In [None]:
class Point:
    """
    Minimal 2-D point with ``x`` and ``y`` attributes.

    ``_ransac`` builds these from keypoint ``.pt`` tuples and hands them to
    ``_homography``, which only reads ``.x`` and ``.y``.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # Makes sampled correspondences readable when debugging RANSAC.
        return f"{type(self).__name__}(x={self.x!r}, y={self.y!r})"

### `Homography`

In [2]:
def _homography(s0, s1, s2, s3, t0, t1, t2, t3):

    x0s = s0.x
    y0s = s0.y
    x0t = t0.x
    y0t = t0.y

    x1s = s1.x
    y1s = s1.y
    x1t = t1.x
    y1t = t1.y

    x2s = s2.x
    y2s = s2.y
    x2t = t2.x
    y2t = t2.y

    x3s = s3.x
    y3s = s3.y
    x3t = t3.x
    y3t = t3.y

    # linear constraints betwixt source and target
    A = np.matrix([
            [x0s, y0s, 1, 0, 0, 0, -x0t * x0s, -x0t * y0s],
            [0, 0, 0, x0s, y0s, 1, -y0t * x0s, -y0t * y0s],
            [x1s, y1s, 1, 0, 0, 0, -x1t * x1s, -x1t * y1s],
            [0, 0, 0, x1s, y1s, 1, -y1t * x1s, -y1t * y1s],
            [x2s, y2s, 1, 0, 0, 0, -x2t * x2s, -x2t * y2s],
            [0, 0, 0, x2s, y2s, 1, -y2t * x2s, -y2t * y2s],
            [x3s, y3s, 1, 0, 0, 0, -x3t * x3s, -x3t * y3s],
            [0, 0, 0, x3s, y3s, 1, -y3t * x3s, -y3t * y3s]
        ])

    # column vector of target coords
    b = np.array([
            [x0t],
            [y0t],
            [x1t],
            [y1t],
            [x2t],
            [y2t],
            [x3t],
            [y3t]
        ])
    
    try:
        # A^-1 @ b
        solutions = np.linalg.solve(A, b)
        
    except np.linalg.LinAlgError:
        # the matrix isn't invertible
        A += np.eye(A.shape[0]) * 1e-10 # add SMALL value to matrix to make it invertible
        solutions = np.linalg.solve(A, b)

    # homogenize the matrix
    solutions = np.append(solutions, [[1.0]], axis = 0)

    # reshape homography to 3x3 matrix
    homography = np.reshape(solutions, (3,3))
    
    return homography

## `2. RANSAC`

In [4]:
def _ransac(kp_q, kp_t, matches, max = 1000, threshold = 5.0, rng = None):
    """
    Robustly estimate the homography taking query keypoints onto train keypoints.

    Each of ``max`` iterations does the following:
    - draws 4 random matches and fits an exact homography to them with
      ``_homography``;
    - projects every query keypoint through that homography;
    - counts the matches whose reprojection error (Euclidean distance in pixels
      to the train keypoint) is below ``threshold``.

    The model with the most inliers wins; on a tie the earlier model is kept.

    Parameters
    ----------
    kp_q, kp_t :
        Query / train keypoints (objects with an ``(x, y)`` ``.pt``), indexed
        by ``m.queryIdx`` / ``m.trainIdx``.
    matches :
        Candidate correspondences (objects with ``queryIdx`` and ``trainIdx``).
    max : int
        Number of RANSAC iterations. The name shadows the builtin but is kept
        for existing callers.
    threshold : float
        Inlier reprojection-error threshold in pixels.
    rng : optional
        Object providing ``sample(population, k)``, e.g. ``random.Random(seed)``,
        for reproducible runs. Defaults to the module-level ``random`` functions.

    Returns
    -------
    tuple
        ``(best_H, best_inliers)``: the 3x3 homography with the most inliers
        (``None`` if no model ever had an inlier) and its inlier matches, in
        input order.

    Raises
    ------
    ValueError
        If fewer than 4 matches are supplied (a homography needs 4 points).
    """
    import random

    if len(matches) < 4:
        raise ValueError(f"RANSAC needs at least 4 matches, got {len(matches)}.")

    sampler = random if rng is None else rng

    # Gather every correspondence once: homogeneous query points (3, N) and
    # train points (2, N), so each hypothesis is scored with one matrix product.
    query_h = np.array([[*kp_q[m.queryIdx].pt, 1.0] for m in matches], dtype=np.float64).T
    train_xy = np.array([kp_t[m.trainIdx].pt for m in matches], dtype=np.float64).T

    best_H = None
    max_inliers = 0
    best_inliers = []

    for _ in range(max):
        samples = sampler.sample(matches, 4)

        query_points = [Point(*kp_q[m.queryIdx].pt) for m in samples]
        train_points = [Point(*kp_t[m.trainIdx].pt) for m in samples]

        H = _homography(*query_points, *train_points)

        # Project all query points and normalise by the homogeneous coordinate.
        projected = H @ query_h
        with np.errstate(divide="ignore", invalid="ignore"):
            projected_xy = projected[:2] / projected[2]

        # Euclidean reprojection error; NaN/inf errors compare False (not inliers).
        errors = np.linalg.norm(train_xy - projected_xy, axis=0)
        is_inlier = errors < threshold

        # update best homography once per hypothesis
        count = int(np.count_nonzero(is_inlier))
        if count > max_inliers:
            max_inliers = count
            best_H = H
            best_inliers = [m for m, keep in zip(matches, is_inlier) if keep]

    return best_H, best_inliers

### `Interpolate`

In [8]:
def interpolate(image: np.ndarray, x: float, y: float) -> float | int:
    
    h, w, _ = image.shape

    # get neighboring points
    # clip to w/in bounds if necessary
    x0 = max(int(np.floor(x)), 0)
    y0 = max(int(np.floor(y)), 0)
    x1 = min(x0 + 1, w - 1) 
    y1 = min(y0 + 1, h - 1)

    # distances/weights between TL pixel and orginal pixel
    xw = x - x0
    yw = y - y0

    # neighboring pixel values
    p00 = image[y0, x0].astype(np.float32) # TL
    p01 = image[y0, x1].astype(np.float32) # TR
    p10 = image[y1, x0].astype(np.float32) # BL
    p11 = image[y1, x1].astype(np.float32) # BR

    return ( 
        p00 * (1 - xw) * (1 - yw) + 
        p01 * xw * (1 - yw) + 
        p10 * (1 - xw) * yw + 
        p11 * xw * yw
    )

### `Warp`

In [None]:
import numpy as np

def compose(source: np.ndarray, target: np.ndarray, transform: np.ndarray) -> np.ndarray:
    """
    Warp ``source`` into a copy of ``target`` with a 3x3 projective transform.

    Uses backward mapping. Each target pixel (x, y) is sent through the inverse
    of ``transform`` and divided by its homogeneous coordinate. If the result
    lands inside ``source``, the pixel is replaced by a bilinear sample of
    ``source``. Target pixels whose pre-image falls outside ``source`` (or at
    infinity) keep their value.

    The mapping is vectorised one target row at a time. This avoids a Python loop
    per pixel while keeping memory proportional to the target width.

    Parameters
    ----------
    source : np.ndarray
        (H, W) or (H, W, C) image to warp.
    target : np.ndarray
        Image painted onto; must have the same channel layout as ``source``.
        It is copied, not modified.
    transform : np.ndarray
        Invertible 3x3 matrix mapping homogeneous source coordinates (x, y, 1)
        to target coordinates.

    Returns
    -------
    np.ndarray
        A new array with ``target``'s shape and dtype. For integer dtypes the
        interpolated values are rounded and clipped to the dtype's range.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``transform`` is singular.
    """
    h, w = source.shape[:2]
    target_h, target_w = target.shape[:2]

    result = np.copy(target)

    # invert transformation matrix for backward mapping
    transform_inv = np.linalg.inv(np.asarray(transform, dtype=np.float64))

    # Plain assignment of floats into an integer array truncates toward zero
    # (e.g. 254.9999 -> 254), so round and clip explicitly instead.
    integer_output = np.issubdtype(result.dtype, np.integer)
    limits = np.iinfo(result.dtype) if integer_output else None

    columns = np.arange(target_w, dtype=np.float64)
    ones = np.ones(target_w)

    with np.errstate(divide="ignore", invalid="ignore"):
        for y in range(target_h):
            # Homogeneous coordinates (x, y, 1) of every pixel in this target row.
            row = np.stack([columns, np.full(target_w, float(y)), ones])
            mapped = transform_inv @ row

            # Homogeneous divide. A zero third coordinate yields inf/nan, which
            # fails every comparison below, so that pixel is discarded.
            x_src = mapped[0] / mapped[2]
            y_src = mapped[1] / mapped[2]

            inside = (x_src >= 0) & (x_src < w) & (y_src >= 0) & (y_src < h)
            if not inside.any():
                continue

            values = _bilinear_sample(source, x_src[inside], y_src[inside])
            if integer_output:
                values = np.clip(np.rint(values), limits.min, limits.max)
            result[y, inside] = values

    return result


def _bilinear_sample(image: np.ndarray, xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
    """Bilinearly sample ``image`` at 1-D float arrays of columns ``xs`` and rows ``ys``."""
    h, w = image.shape[:2]

    # Top-left neighbours; right/bottom neighbours repeat the edge pixel.
    x0 = np.clip(np.floor(xs).astype(np.intp), 0, w - 1)
    y0 = np.clip(np.floor(ys).astype(np.intp), 0, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    y1 = np.minimum(y0 + 1, h - 1)

    # Fractional offsets, reshaped to broadcast over a trailing channel axis if any.
    channel_axes = (1,) * (image.ndim - 2)
    wx = (xs - x0).reshape(-1, *channel_axes)
    wy = (ys - y0).reshape(-1, *channel_axes)

    top = image[y0, x0].astype(np.float64) * (1 - wx) + image[y0, x1].astype(np.float64) * wx
    bottom = image[y1, x0].astype(np.float64) * (1 - wx) + image[y1, x1].astype(np.float64) * wx
    return top * (1 - wy) + bottom * wy



# Left view -> center view: robust homography, then warp left onto center.
H_left, inliers = _ransac(kp_l, kp_c, matches_lc)
stitched_lc = compose(left, center, H_left)

# Right view -> center view.
H_right, inliers = _ransac(kp_r, kp_c, matches_rc)
stitched_rc = compose(right, center, H_right)

# Show both pairwise results (BGR -> RGB by reversing the channel axis).
_, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(10, 10))
axes[0].imshow(stitched_lc[..., ::-1])
axes[0].set_title('Left - Center')
axes[1].imshow(stitched_rc[..., ::-1])
axes[1].set_title('Center - Right')
plt.show()

## `3. Mosaic`

### `Canvas`

In [None]:
def crop(image):
    """
    Trim the all-zero border around the content of ``image``.

    A pixel counts as content when any of its channels is non-zero. The mosaic
    canvas starts as exact zeros, so any non-zero value was painted onto it.

    Parameters
    ----------
    image : np.ndarray
        (H, W) or (H, W, C) array.

    Returns
    -------
    np.ndarray
        The smallest sub-array (a view) containing every content pixel, with the
        last content row and column included. If ``image`` has no content, it is
        returned unchanged.
    """
    occupied = np.asarray(image) != 0
    if occupied.ndim == 3:
        occupied = occupied.any(axis=-1)

    rows = np.flatnonzero(occupied.any(axis=1))
    cols = np.flatnonzero(occupied.any(axis=0))
    if rows.size == 0:
        return image

    # +1 because slice ends are exclusive.
    return image[rows[0] : rows[-1] + 1, cols[0] : cols[-1] + 1]


def blendImages(left, center, right, H_left, H_right):
    """
    Build a mosaic around ``center`` and crop it with ``crop``.

    ``center`` is pasted into the middle of a black uint8 canvas three times its
    height and width. ``H_left`` / ``H_right`` map side-view coordinates into
    ``center`` coordinates. They are pre-multiplied by a translation of (w, h) so
    the side views land in canvas coordinates, then warped with ``compose``.
    Layers are drawn in the order left, center, right; later layers overwrite
    earlier ones where they are painted.
    """
    h, w, _ = center.shape

    canvas = np.zeros((3 * h, 3 * w, 3), dtype = np.uint8)
    top = (canvas.shape[0] - h) // 2
    left_edge = (canvas.shape[1] - w) // 2

    # Translation taking center-image coordinates to canvas coordinates.
    to_canvas = np.array([
        [1, 0, w],
        [0, 1, h],
        [0, 0, 1],
    ])
    left_on_canvas = to_canvas @ H_left
    right_on_canvas = to_canvas @ H_right

    layered = compose(left, np.copy(canvas), left_on_canvas)
    layered[top : top + h, left_edge : left_edge + w] = center
    layered = compose(right, np.copy(layered), right_on_canvas)

    return crop(layered)


panorama = blendImages(left, center, right, H_left, H_right)

# OpenCV images are BGR; reverse the channel axis so matplotlib shows RGB.
plt.imshow(panorama[:, :, ::-1])
plt.show()

## `4. Stitch`

In [None]:
def imageStitch(left, center, right):
    """
    Stitch three overlapping BGR views into one panorama centred on ``center``.

    Pipeline:
    1. SIFT features for each view (``getKeypoints``).
    2. Ratio-test matching of each side view against the center (``getMatches``).
    3. A RANSAC homography per side view (``_ransac``).
    4. Warping and merging on a shared canvas (``blendImages``).
    """
    # 1. Features; the grayscale images returned alongside are not needed.
    kp_left, des_left, _ = getKeypoints(left)
    kp_center, des_center, _ = getKeypoints(center)
    kp_right, des_right, _ = getKeypoints(right)

    # 2. Matches of each side view (query) against the center view (train).
    matches_left = getMatches(des_left, des_center)
    matches_right = getMatches(des_right, des_center)

    # 3. Homographies side view -> center view; the inlier lists are discarded.
    H_left = _ransac(kp_left, kp_center, matches_left)[0]
    H_right = _ransac(kp_right, kp_center, matches_right)[0]

    # 4. Warp both side views around the center and crop.
    return blendImages(left, center, right, H_left, H_right)


# First image set.
left = cv2.imread("im1_left.jpg")
center = cv2.imread("im1_center.jpg")
right = cv2.imread("im1_right.jpg")

result = imageStitch(left, center, right)
plt.imshow(result[:, :, ::-1])  # BGR -> RGB
plt.show()

# Second image set.
left = cv2.imread("im2_left.jpg")
center = cv2.imread("im2_center.jpg")
right = cv2.imread("im2_right.jpg")

result = imageStitch(left, center, right)
plt.imshow(result[:, :, ::-1])  # BGR -> RGB
plt.show()