# Image Analysis with the Discrete Fourier Transform

In [26]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple
import os
from PIL import Image

### Helper Functions

In [27]:
def save_as_png(array, output_path):
    """Write ``array`` to ``output_path`` as an image using PIL.

    Arrays that are not already ``uint8`` are multiplied by 255 and cast to
    ``uint8`` before saving. Presumably callers pass intensities in [0, 1];
    values outside that range are not clipped here.
    """
    already_uint8 = array.dtype == np.uint8
    pixels = array if already_uint8 else (array * 255).astype(np.uint8)
    Image.fromarray(pixels).save(output_path)


# 8x8 all-zero test image with a single pixel set at row 1, column 0.
matrix = np.zeros((8, 8))
matrix[1,0] = 1
# Uncomment to write the test image to disk:
# save_as_png(matrix, 'img.png')

In [40]:
def pretty_print_matrix(matrix):
    """Print a 2D matrix row by row with four decimals per entry.

    Complex entries are printed as ``a+bj`` or ``a-bj`` (the sign of the
    imaginary part is emitted exactly once); real entries are printed as a
    plain decimal. Each value is followed by one space and each row ends with
    a newline.

    Args:
        matrix: 2D array-like of real or complex numbers.
    """
    matrix = np.asarray(matrix)
    N, M = matrix.shape
    for i in range(N):
        for j in range(M):
            val = matrix[i, j]
            # np.complexfloating also covers numpy complex scalars (e.g.
            # complex64) that are not subclasses of the builtin ``complex``.
            if isinstance(val, (complex, np.complexfloating)):
                # "+.4f" always writes the sign, so a negative imaginary part
                # renders as "1.0000-2.0000j" rather than "1.0000+-2.0000j".
                print(f"{val.real:.4f}{val.imag:+.4f}j", end=" ")
            else:
                print(f"{val:.4f}", end=" ")
        print()

def display(a: np.ndarray, title="") -> None:
    """Show ``a`` as a single grayscale image with hidden axes and an optional title."""
    plt.imshow(a, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()

def display_in_grid(matrix_of_images_2d, titles=None):
    """Show a nested list of images as a grid of grayscale subplots.

    Args:
        matrix_of_images_2d: list of rows, each row a list of 2D arrays. The
            number of columns is taken from the first row.
        titles: optional nested list laid out like ``matrix_of_images_2d``;
            ``titles[i][j]`` becomes the title of the image at row i, column j.

    An empty grid (no rows, or an empty first row) draws nothing.
    """
    rows = len(matrix_of_images_2d)
    cols = len(matrix_of_images_2d[0]) if rows else 0
    if rows == 0 or cols == 0:
        return

    # squeeze=False makes plt.subplots always return a 2D array of Axes, so
    # 1xN, Nx1 and 1x1 grids can all be indexed uniformly as axes[i, j].
    # (Without it a 1x1 grid yields a bare Axes object that cannot be indexed.)
    _, axes = plt.subplots(rows, cols, figsize=(cols*3, rows*3), squeeze=False)

    for i in range(rows):
        for j in range(cols):
            ax = axes[i, j]
            ax.imshow(matrix_of_images_2d[i][j], cmap='gray')
            ax.axis('off')
            if titles:
                ax.set_title(titles[i][j])

    plt.tight_layout()
    plt.show()

## Discretized Sine and Cosine in 2D

In 2D the sine and cosine waves can be represented by the functions:

$s_{u,v}[n,m] = A \sin(2\pi(\frac{un}{N} + \frac{vm}{M}))$ 

$c_{u,v}[n,m] = A \cos(2\pi(\frac{un}{N} + \frac{vm}{M}))$
 
- $u$, $v$ modulate the frequency in the rows and columns
- $A$ is the amplitude of the wave
- $N$, $M$ represent the number of rows, columns in the image

In [None]:
def sin_2d(N,M, u=1, v=1):
    """Sample sin(2*pi*(u*n/N + v*m/M)) on an N x M grid.

    ``n`` indexes rows (0..N-1) and ``m`` indexes columns (0..M-1); ``u`` and
    ``v`` are the frequencies along the rows and columns. Returns an (N, M) array.
    """
    row_idx = np.arange(N).reshape(-1, 1)
    col_idx = np.arange(M).reshape(1, -1)
    # Broadcasting the column vector against the row vector gives the full grid.
    phase = u*row_idx/N + v*col_idx/M
    return np.sin(2 * np.pi * phase)

def cos_2d(N,M, u=1, v=1):
    """Sample cos(2*pi*(u*n/N + v*m/M)) on an N x M grid.

    ``n`` indexes rows (0..N-1) and ``m`` indexes columns (0..M-1); ``u`` and
    ``v`` are the frequencies along the rows and columns. Returns an (N, M) array.
    """
    row_idx = np.arange(N).reshape(-1, 1)
    col_idx = np.arange(M).reshape(1, -1)
    # Broadcasting the column vector against the row vector gives the full grid.
    phase = u*row_idx/N + v*col_idx/M
    return np.cos(2*np.pi*phase)

# Side-by-side view of the default (u = v = 1) sine and cosine waves on a 1000x1000 grid.
display_in_grid([[sin_2d(1000,1000), cos_2d(1000,1000)]], [["Sin (1000x1000)", "Cos (1000x1000)"]])

Here we visualize the sine wave in 3D, with the Z axis representing the amplitude (pixel intensity). The blue and red lines show cross-sections of the wave at y=0 and x=0 respectively, demonstrating how the 1D sine waves along each axis combine to create the full 2D surface.

In [None]:
N = 50
# Sample positions n/N for n = 0..N-1: the same grid sin_2d evaluates on.
# (np.linspace(0, 1, N) would include the endpoint 1 and therefore place each
# sample of Z slightly away from the coordinate it was computed for.)
x = np.arange(N) / N
y = np.arange(N) / N
X, Y = np.meshgrid(x, y)

Z = sin_2d(N, N)

fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')

surf = ax.plot_surface(X, Y, Z, cmap='gray')
fig.colorbar(surf)

# Cross-sections through the origin, drawn on the same sample grid as the surface.
x_line = x
y_line = y
xz_line = np.sin(2*np.pi*x_line)
yz_line = np.sin(2*np.pi*y_line)
# Blue: slice along the x axis at y = 0. Red: slice along the y axis at x = 0.
ax.plot(x_line, np.zeros_like(x_line), xz_line, color='blue', linewidth=3)
ax.plot(np.zeros_like(y_line), y_line, yz_line, color='red', linewidth=3)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('2D Sine Wave')

plt.tight_layout()
plt.show()

We can also vary the size of the image to see how the wave visually stretches and contracts.

In [None]:
interval = range(25,100,25)
# Grid rows vary N (number of image rows); grid columns vary M (number of image columns).
image_grid = [[sin_2d(n_rows, n_cols) for n_cols in interval] for n_rows in interval]
plot_titles = [[f"sin(N={n_rows},M={n_cols})" for n_cols in interval] for n_rows in interval]
display_in_grid(image_grid, plot_titles)

Next we can vary the frequencies $u$, $v$ to see how the wave rotates. We do this with both linear and logarithmic spacing so you can observe how these values affect the function near zero and as they grow large.

In [None]:
interval = np.linspace(-5, 5, num=5)
# Each grid row fixes v; u varies across the columns of that row.
image_grid = [[sin_2d(100,100, freq_u, freq_v) for freq_u in interval] for freq_v in interval]
plot_titles = [[f"u,v = ({freq_u:.3f},{freq_v:.3f})" for freq_u in interval] for freq_v in interval]
display_in_grid(image_grid, plot_titles)

In [None]:
interval = np.logspace(-2, 3, base=10, num=10)
# Ten frequencies from 10^-2 to 10^3; each grid row fixes v and u varies across the columns.
image_grid = [[sin_2d(100,100, freq_u, freq_v) for freq_u in interval] for freq_v in interval]
plot_titles = [[f"u,v = ({freq_u:.3f},{freq_v:.3f})" for freq_u in interval] for freq_v in interval]
display_in_grid(image_grid, plot_titles)

## Discrete Fourier Transform in 2D

Let’s begin with Euler’s identity, which connects trigonometric functions with complex exponentials:
$$ e^{j\theta} = \cos(\theta) + j\sin(\theta) $$

Given an image $l[n,m]$, of size N×M, we can define the basis function:
 
$$ 
e_{u,v}[n,m] = e^{2\pi j(\frac{un}{N}+\frac{vm}{M})}
$$
This represents a 2D complex sinusoid at spatial frequency $(u,v)$. Any discrete image can be decomposed as a linear combination of these basis functions. In this formula, the factor $2\pi$ maps each axis onto one full turn of the unit circle, and $(\frac{un}{N}+\frac{vm}{M})$ measures how far along each dimension of the discrete image we are, scaled by the frequencies $u$ and $v$.

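We will not prove that the basis functions are orthogonal, but it can be shown that the inner product of any two distinct basis functions vanishes:

$$\langle e_{u,v}, e_{u',v'} \rangle = \sum_{n=0}^{N-1}\sum_{m=0}^{M-1} e_{u,v}[n,m]\,\overline{e_{u',v'}[n,m]} = 0 \quad \text{whenever } (u,v)\neq(u',v') $$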


### Discrete Fourier Transform: 

$$L[u,v] = \sum_{n=0}^{N-1} \sum_{m=0}^{M-1} l[n,m]e^{-2\pi j (\frac{un}{N}+\frac{vm}{M})}$$

Each coefficient in $L[u,v]$ represents the contribution of a 2D sinusoidal pattern at spatial frequency $(u,v)$. Each term is computed by the inner product between the image and the basis function $e_{u,v}[n,m]$, which intuitively tells us how much of each particular frequency exists in the image.

The DFT decomposes the image into a weighted sum of these frequency components, separating structure (phase) from texture/detail (magnitude).

- $|L[u,v]|$ is the magnitude spectrum: how much of the frequency exists in the image
- $\arg{(L[u,v])}$ is the phase spectrum: the spatial offset of the sinusoid

### Inverse:

$$l[n,m] = \frac{1}{NM} \sum_{u=0}^{N-1} \sum_{v=0}^{M-1} L[u,v]e^{+2\pi j (\frac{un}{N}+\frac{vm}{M})}$$

The inverse DFT reconstructs the image as a weighted sum of complex sinusoids.
Each term contributes a wave pattern whose amplitude is $|L[u,v]|$ and whose shift is determined by $\arg{(L[u,v])}$.

### Computing the DFT

Below I lay out the discrete Fourier transform (and its inverse) manually. Computing the DFT this way is extremely inefficient, and there are much faster algorithms. In later examples I will use numpy's `fft2()` function, which computes the transform of an $N \times N$ image in $O(N^2 \log N)$ time instead of $O(N^4)$.

In [None]:
def compute_fourier_tfm_slow(image: np.ndarray) -> np.ndarray:
    """Evaluate the 2D DFT of ``image`` directly from its definition.

    For every output frequency (u, v) the sum of
    ``image[n, m] * exp(-2*pi*j*(u*n/N + v*m/M))`` is taken over all pixels,
    so the cost grows with (N*M)^2. Intended for tiny inputs only.

    Returns:
        Complex (N, M) array of DFT coefficients.
    """
    n_rows, n_cols = image.shape
    spectrum = np.zeros((n_rows, n_cols), dtype=complex)
    # np.ndindex walks the index pairs in row-major order, i.e. the same
    # order as two nested range() loops.
    for u, v in np.ndindex(n_rows, n_cols):
        for n, m in np.ndindex(n_rows, n_cols):
            phase = (u*n/n_rows) + (v*m/n_cols)
            spectrum[u, v] += image[n, m] * np.exp(-2j*np.pi*phase)
    return spectrum

def compute_inverse_fourier_tfm_slow(fourier_image:np.ndarray) -> np.ndarray:
    """Evaluate the inverse 2D DFT of ``fourier_image`` directly from its definition.

    Every output pixel (n, m) is the sum of
    ``fourier_image[u, v] * exp(+2*pi*j*(u*n/N + v*m/M))`` over all
    frequencies, scaled by 1/(N*M). Intended for tiny inputs only.

    Returns:
        Complex (N, M) array; for the spectrum of a real image the imaginary
        part is only rounding noise.
    """
    n_rows, n_cols = fourier_image.shape
    image = np.zeros((n_rows, n_cols), dtype=complex)
    for n, m in np.ndindex(n_rows, n_cols):
        for u, v in np.ndindex(n_rows, n_cols):
            phase = (u*n/n_rows) + (v*m/n_cols)
            image[n, m] += fourier_image[u, v]*np.exp(2j*np.pi*phase)
    image *= 1/(n_rows*n_cols)
    return image

# The FFT shift reorders the 0 frequency components to the center of the matrix. 
# This just improves our ability to interpret the fourier matrix when displayed as an image
# low frequency data is closer to the center, and high frequency data is on the edges
# In 2D this amounts to swapping the first and third quadrants, and the second and fourth quadrants
def fft_shift(mtx: np.ndarray) -> np.ndarray:
    """Move the zero-frequency entry of a 2D spectrum to the center.

    Each axis is rolled forward by half its length (rounded down). In
    quadrant terms, with n_mid = (N+1)//2 and m_mid = (M+1)//2:
      - mtx[n_mid:, m_mid:] (bottom-right) moves to the top-left,
      - mtx[n_mid:, :m_mid] (bottom-left) moves to the top-right,
      - mtx[:n_mid, m_mid:] (top-right) moves to the bottom-left,
      - mtx[:n_mid, :m_mid] (top-left) moves to the bottom-right.
    """
    n_rows, n_cols = mtx.shape
    return np.roll(mtx, shift=(n_rows//2, n_cols//2), axis=(0, 1))

# Self-check against numpy on a matrix with one even and one odd axis length.
matrix = np.random.randn(10,9)
assert np.allclose(fft_shift(matrix), np.fft.fftshift(matrix)), "FT Shift Failed"

def inv_fft_shift(mtx: np.ndarray) -> np.ndarray:
    """Undo ``fft_shift``: move the centered zero-frequency entry back to [0, 0].

    Each axis is rolled backward by half its length (rounded down), which is
    exactly the opposite of the forward shift for both even and odd lengths.

    Args:
        mtx: 2D array in centered (shifted) layout.

    Returns:
        A new array in the standard FFT layout.
    """
    N, M = mtx.shape
    # The parentheses matter: -N//2 parses as (-N)//2 == floor(-N/2), which
    # rolls one step too far whenever N is odd. -(N//2) is the true inverse.
    return np.roll(np.roll(mtx, -(N//2), axis=0), -(M//2), axis=1)

# Test case for the inverse fft shift (one even and one odd axis length)
matrix = np.random.randn(10,9)
assert(np.allclose(inv_fft_shift(matrix), np.fft.ifftshift(matrix))), "Inverse FT Shift Failed"


def compute_magnitude_phase(fourier_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split a 2D spectrum into displayable magnitude and phase images.

    The spectrum is first centered with ``fft_shift``. The magnitude is
    log-compressed with ``log1p`` and divided by its maximum; the phase is
    mapped from [-pi, pi] to [0, 1].

    Args:
        fourier_image: complex 2D spectrum in standard FFT layout.

    Returns:
        ``(magnitude, phase)``, both real arrays shaped like the input.
    """
    shifted_fourier = fft_shift(fourier_image)
    magnitude = np.log1p(np.abs(shifted_fourier))
    peak = magnitude.max()
    # An all-zero spectrum has peak 0; skip the division so the result stays
    # all-zero instead of becoming NaN from 0/0.
    if peak > 0:
        magnitude = magnitude / peak
    phase = np.angle(shifted_fourier)
    phase = (phase + np.pi) / (2*np.pi)
    return magnitude, phase


# 9x9 test image: a vertical bar four pixels tall in the center column.
matrix = np.zeros((9,9))
center_x = matrix.shape[0] // 2
matrix[center_x-2:center_x+2,center_x] = 1

# Forward transform, display-ready magnitude/phase, and the round trip back.
fourier_matrix = compute_fourier_tfm_slow(matrix)
mag,phase = compute_magnitude_phase(fourier_matrix)
inverted_matrix = compute_inverse_fourier_tfm_slow(fourier_matrix)

# Test case for the slow fourier computations
assert(np.allclose(fourier_matrix, np.fft.fft2(matrix))), "Computing the FT failed"
assert(np.allclose(matrix,inverted_matrix)), "FT inversion failed"

# Same magnitude/phase computed from numpy's FFT (not displayed below).
np_mag, np_phase = compute_magnitude_phase(np.fft.fft2(matrix))

display_in_grid(
    [[matrix, mag, phase]],
    [["original image", "Magnitude: |F|", "Phase: ∠F"]]
)

# inverted_matrix is complex; only its real part is shown.
display_in_grid(
    [[mag, phase, np.real(inverted_matrix) ]],
    [["Magnitude: |F|", "Phase: ∠F", "Reconstructed Image"]]
)

# Print the centered spectrum coefficient by coefficient.
pretty_print_matrix(fft_shift(fourier_matrix))


Below I will display the phase, magnitude of the fourier matrix (decomposed from the complex exponential) across a set of images in the test set. I would encourage you to take a closer look at the phase/magnitude matrices and try to build an intuition for how they relate to the input image.

Please note for images where the magnitude/phase values are all uniform the constructed images will be black due to how the images are normalized. Also, please take a look at fft_shift() to better understand how high/low frequency data is visualized in the fourier images.

In [None]:
# One grid row per image: original, magnitude, phase, and inverse-FFT reconstruction.
outputs = []
titles = []
fft_images_dir = 'fft_images'  # also read by the later cells

for image_file in sorted(os.listdir(fft_images_dir)):
    if image_file.endswith(('.png', '.jpg', '.jpeg')):
        # Load as 8-bit grayscale ('L') and rescale to floats in [0, 1].
        img = Image.open(os.path.join(fft_images_dir, image_file)).convert('L')
        img_arr = np.array(img) / 255.0
        
        fft = np.fft.fft2(img_arr)
        fft_mag, fft_phase = compute_magnitude_phase(fft)
        # ifft2 returns a complex array; keep only the real part for display.
        fft_inv = np.real(np.fft.ifft2(fft))
        
        outputs.append([img_arr, fft_mag, fft_phase, fft_inv])
        titles.append([image_file, 'magnitude', 'phase', 'reconstructed'])


display_in_grid(outputs, titles)


## Uses for the DFT

In the next section we'll go over uses for the discrete fourier transform in image analysis. The primary ones I know about are compression and image registration/alignment.

### Compression: Removing high frequency terms

This next cell demonstrates how removing the high frequency components from the fourier matrix can be used to compress an image. As more and more of these terms are zeroed out, you will notice that the features in the image will blur, because the finer grained edges can't be represented by the lower frequency complex exponentials.

One interesting thing to note is that the sin/cos images are unaffected by this form of compression. This is because they can be represented by only the lowest frequency component of the fourier matrix. 

In [None]:
# One grid row per image: the original followed by num_steps reconstructions,
# each with a wider band of high frequencies removed.
outputs = []
titles = []

for image_file in sorted(os.listdir(fft_images_dir)):
    if image_file.endswith(('.png', '.jpg', '.jpeg')): 
        # Load as 8-bit grayscale ('L') and rescale to floats in [0, 1].
        img = Image.open(os.path.join(fft_images_dir, image_file)).convert('L')
        img_arr = np.array(img) / 255.0
        
        fft = np.fft.fft2(img_arr)
        
        row_titles = [image_file]
        row_imgs = [img_arr]
        N,M = fft.shape
        num_steps = 7
        for i in range(num_steps):
            # In the shifted layout the high frequencies sit at the edges.
            reduced_fft = fft_shift(fft.copy())
            # Band width per side: (i+1)/num_steps of half the axis, at least 1.
            cutoff_n = max(1,int(N * (i+1) / num_steps / 2))
            cutoff_m = max(1,int(M * (i+1) / num_steps / 2))
            # Zero a band of that width along all four edges.
            reduced_fft[:cutoff_n, :] = 0
            reduced_fft[-cutoff_n:, :] = 0  
            reduced_fft[:, -cutoff_m:] = 0  
            reduced_fft[:, :cutoff_m] = 0
            # Back to the standard FFT layout before inverting.
            reduced_fft = inv_fft_shift(reduced_fft)
            row_titles.append(f"removed {(i+1)/num_steps:.0%}")
            row_imgs.append(np.real(np.fft.ifft2(reduced_fft)))

        outputs.append(row_imgs)
        titles.append(row_titles)

display_in_grid(outputs, titles)

### Compression: Zeroing out low magnitude factors

In [None]:
# One grid row per image: reconstructions after zeroing the p% smallest-magnitude
# FFT coefficients, for each p in `percentages`.
outputs = []
titles = []

for image_file in sorted(os.listdir(fft_images_dir)):
    if image_file.endswith(('.png', '.jpg', '.jpeg')): 
        # Load as 8-bit grayscale ('L') and rescale to floats in [0, 1].
        img = Image.open(os.path.join(fft_images_dir, image_file)).convert('L')
        img_arr = np.array(img) / 255.0

        fft = np.fft.fft2(img_arr)
        flat_fft = fft.flatten()
        mag = np.abs(flat_fft)

        # Indices of the coefficients ordered from smallest to largest magnitude.
        sorted_indices = np.argsort(mag)

        percentages = [0, 10, 25, 50, 75, 90,99,99.9,99.99]

        row_imgs = []
        row_titles = []

        for p in percentages:
            # Start with every coefficient kept, then drop the num_zero smallest.
            mask = np.ones_like(flat_fft, dtype=bool)
            num_zero = int(len(flat_fft) * (p / 100))
            mask[sorted_indices[:num_zero]] = False
            masked_fft = np.where(mask, flat_fft, 0)
            # Restore the 2D shape before inverting; keep the real part for display.
            masked_img = np.fft.ifft2(masked_fft.reshape(fft.shape)).real

            row_titles.append(f"removed {p}%")
            row_imgs.append(masked_img)

        outputs.append(row_imgs)
        titles.append(row_titles)

display_in_grid(outputs, titles)


### Compression: Towards JPEG

While implementing the above compression method, I learned that JPEG compression is based on similar frequency-domain principles. At a high level, JPEG divides the image into 8×8 blocks, applies a frequency transform, and then quantizes the resulting coefficients to achieve compression.

JPEG specifically uses the Discrete Cosine Transform (DCT) rather than the DFT. The DCT uses only real-valued cosine basis functions, which makes it more efficient for compression and better suited to human visual perception.


#### Simplified Method
I will briefly describe the method I used to perform the fourier based compression below.

1. **Block-wise Decomposition**: Chunk the image into 8x8 sections
2. **Fourier Transform**: Apply the Fourier transform to this chunk and perform an fft_shift() to move the low-frequency (zero-frequency) components to the center of the block
3. **Quantization**: Element-wise apply the quantization matrix (see in code)
    - low-frequency components (closer to the center of the block [4,4]) should be preserved more than high frequency components 
    - tunable scale parameter $Q[i,j] = Q[i,j]^{scale}$ controls the compression aggressiveness
4. **Sparsification**: Set all coefficients with magnitude < 1.0 to 0 to sparsify the fourier matrix
5. **Storage**: Convert the result to Compressed Sparse Row (CSR) format to reduce memory usage, exploiting the sparsity introduced by quantization.



In [None]:
def get_quantization_mtx(scale: float = 1):
    """Return the 8x8 quantization weights ``1 / base**scale``.

    ``base`` grows by 0.01 per step of Manhattan distance from the block
    center [4, 4] (where fft_shift places the zero frequency), from 1.01 at
    the center up to a cap of 1.06. The weights are therefore largest at the
    center and shrink towards the edges, and a larger ``scale`` shrinks the
    outer weights faster.

    Args:
        scale: exponent applied element-wise to the base matrix; may be
            fractional. ``scale=0`` gives all ones.

    Returns:
        (8, 8) float array of weights in (0, 1].
    """
    base = np.asarray([
        [1.06, 1.06, 1.06, 1.06, 1.05, 1.06, 1.06, 1.06],
        [1.06, 1.06, 1.06, 1.05, 1.04, 1.05, 1.06, 1.06],
        [1.06, 1.06, 1.05, 1.04, 1.03, 1.04, 1.05, 1.06],
        [1.06, 1.05, 1.04, 1.03, 1.02, 1.03, 1.04, 1.05],
        [1.05, 1.04, 1.03, 1.02, 1.01, 1.02, 1.03, 1.04],
        [1.06, 1.05, 1.04, 1.03, 1.02, 1.03, 1.04, 1.05],
        # Entry [6, 1] was 1.05, which broke the distance pattern and the
        # symmetry with [1, 6]; distance 5 from the center maps to 1.06.
        [1.06, 1.06, 1.05, 1.04, 1.03, 1.04, 1.05, 1.06],
        [1.06, 1.06, 1.06, 1.05, 1.04, 1.05, 1.06, 1.06],
    ])
    return np.reciprocal(np.power(base, scale))

def matrix_to_csr(matrix: np.ndarray):
    """Convert a dense 2D matrix to Compressed Sparse Row (CSR) arrays.

    Args:
        matrix: 2D array; entries that compare unequal to 0 are stored.

    Returns:
        ``(data, indices, indptr, (rows, cols))`` where
          - ``data`` holds the non-zero values as complex numbers, row by row,
          - ``indices`` holds the column index of each stored value,
          - ``indptr[i]:indptr[i + 1]`` is the slice of ``data``/``indices``
            that belongs to row ``i`` (so ``len(indptr) == rows + 1``).
    """
    rows, cols = matrix.shape

    # np.nonzero walks the matrix in row-major order, so the values come out
    # grouped by row with ascending column index, exactly the CSR ordering.
    row_idx, col_idx = np.nonzero(matrix)
    data = matrix[row_idx, col_idx].astype(complex)
    # col_idx is always an integer array, even when the matrix is all zeros
    # (an empty Python list would have turned into a float64 array).
    indices = col_idx

    # Row pointers are the running total of stored entries per row.
    per_row = np.bincount(row_idx, minlength=rows)
    indptr = np.concatenate(([0], np.cumsum(per_row)))

    return (data, indices, indptr, (rows, cols))

def csr_to_matrix(data: np.ndarray, indices: np.ndarray, indptr: np.ndarray, shape: tuple):
    """Expand CSR arrays back into a dense complex matrix.

    Args:
        data: stored values, grouped by row.
        indices: column index of each stored value.
        indptr: row pointers; row ``r`` owns ``data[indptr[r]:indptr[r + 1]]``.
        shape: ``(rows, cols)`` of the dense result.

    Returns:
        Complex array of the given shape with the stored values filled in and
        zeros everywhere else.
    """
    n_rows, n_cols = shape
    dense = np.zeros((n_rows, n_cols), dtype=complex)

    for r in range(n_rows):
        lo, hi = indptr[r], indptr[r + 1]
        values = data[lo:hi]
        columns = indices[lo:hi]
        # Rows with no stored entries stay zero.
        if len(values) and len(columns):
            dense[r, columns] = values

    return dense

def simplified_jpeg_compression(img_arr: np.ndarray, scale: int = 1):
    """Compress a grayscale image block-wise in the Fourier domain.

    The image is cut into tiles of up to 8x8 pixels. Each tile is
    transformed with the FFT, centered with ``fft_shift``, and every
    coefficient whose magnitude times its quantization weight is below 1 is
    set to zero. The sparsified spectrum is shifted back and stored in CSR form.

    Tiles on the bottom/right border may be smaller than 8x8; for those the
    quantization matrix is cropped from its top-left corner to the tile size.

    Args:
        img_arr: 2D image array.
        scale: exponent forwarded to ``get_quantization_mtx``.

    Returns:
        List of ``(row, col, csr)`` tuples, one per tile, where ``(row, col)``
        is the tile's top-left pixel and ``csr`` is the ``matrix_to_csr`` output.
    """
    height, width = img_arr.shape
    weights = get_quantization_mtx(scale)

    blocks = []
    for top in range(0, height, 8):
        for left in range(0, width, 8):
            # Slicing clamps at the image border, so edge tiles are just smaller.
            tile = img_arr[top:top+8, left:left+8]
            tile_h, tile_w = tile.shape
            spectrum = fft_shift(np.fft.fft2(tile))
            negligible = np.abs(spectrum*weights[0:tile_h, 0:tile_w]) < 1
            spectrum[negligible] = 0
            blocks.append((top, left, matrix_to_csr(inv_fft_shift(spectrum))))

    return blocks

def reconstruct_img_from_jpeg(compressed_img):
    """Rebuild the image from the output of ``simplified_jpeg_compression``.

    Args:
        compressed_img: non-empty list of ``(row, col, csr)`` tuples, where
            ``csr`` is ``(data, indices, indptr, shape)`` for one tile's
            spectrum and ``(row, col)`` is the tile's top-left pixel.

    Returns:
        Real 2D array covering exactly the area spanned by the tiles.
    """
    # Size the canvas from each tile's real shape instead of assuming 8x8, so
    # images whose sides are not multiples of 8 get no zero padding at the edges.
    height = max(i + shape[0] for i, _, (_, _, _, shape) in compressed_img)
    width = max(j + shape[1] for _, j, (_, _, _, shape) in compressed_img)

    reconstructed = np.zeros((height, width))

    for i, j, (data, indices, indptr, shape) in compressed_img:
        spectrum = csr_to_matrix(data, indices, indptr, shape)
        # The spectrum came from real pixels, so only the real part is kept.
        block = np.fft.ifft2(spectrum).real
        reconstructed[i:i+block.shape[0], j:j+block.shape[1]] = block

    return reconstructed

def calculate_compressed_img_size(compressed_img):
    """Return the total bytes held by the CSR arrays of all tiles.

    Counts ``data``, ``indices`` and ``indptr`` of every tile; the tile
    position and shape tuples are not counted.
    """
    return sum(
        data.nbytes + indices.nbytes + indptr.nbytes
        for _, _, (data, indices, indptr, _) in compressed_img
    )

# Compress a single image at several quantization scales and compare sizes.
img = Image.open(os.path.join(fft_images_dir, '10_bird.jpeg')).convert('L')
img_arr = np.array(img) / 255.0
# Size of the float array in memory, not of the JPEG file on disk.
original_bytes = img_arr.nbytes

outputs = [img_arr]
titles = [f"original {original_bytes / 1024:.2f}KB"]

# Five scales evenly spaced from 0 to 300.
for scale in np.linspace(0,300,5):
    compressed_img = simplified_jpeg_compression(img_arr,scale)
    compressed_img_bytes = calculate_compressed_img_size(compressed_img)

    original_kb = original_bytes / 1024
    compressed_kb = compressed_img_bytes / 1024
    compression_ratio = original_bytes / compressed_img_bytes
    
    outputs.append(reconstruct_img_from_jpeg(compressed_img))
    titles.append(f"{compressed_img_bytes / 1024:.2f} KB\nscale={int(scale)}\nCompression ratio: {compression_ratio:.2f}x")

# Single grid row: the original followed by one reconstruction per scale.
display_in_grid([outputs],[titles])

This compression method has limitations - especially on smaller images. I intentionally skipped certain boundary cases for simplicity, so some outputs may contain black padding at the edges. Additionally, the quantization matrix I used is a rough heuristic and not tuned for general-purpose compression. As a result, the algorithm performs poorly on some simple images. (In practice, effective quantization matrices are often highly optimized—and sometimes even patented.)

In [None]:
# Same extension filter as the other image-loading cells (previously '.jpg' was skipped here).
image_files = [f for f in os.listdir(fft_images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

all_outputs = []
all_titles = []

for image_file in sorted(image_files):
    # Load as 8-bit grayscale ('L') and rescale to floats in [0, 1].
    img = Image.open(os.path.join(fft_images_dir, image_file)).convert('L')
    img_arr = np.array(img) / 255.0
    # Size of the float array in memory, not of the image file on disk.
    original_bytes = img_arr.nbytes
    original_kb = original_bytes / 1024

    outputs = [img_arr]
    titles = [f"{image_file}\noriginal {original_kb:.2f}KB"]

    # Five scales evenly spaced from 0 to 300.
    for scale in np.linspace(0,300,5):
        compressed_img = simplified_jpeg_compression(img_arr,scale)
        # Calculate and display memory usage of compressed_img
        compressed_img_bytes = calculate_compressed_img_size(compressed_img)
        compressed_kb = compressed_img_bytes / 1024
        compression_ratio = original_bytes / compressed_img_bytes

        outputs.append(reconstruct_img_from_jpeg(compressed_img))
        titles.append(f"{compressed_kb:.2f} KB\nscale={int(scale)}\nCompression ratio: {compression_ratio:.2f}x")
    
    all_outputs.append(outputs)
    all_titles.append(titles)

# One grid row per image: the original followed by one reconstruction per scale.
display_in_grid(all_outputs, all_titles)

## Template Matching

Template matching is often computationally expensive because we need to compute the cross-correlation between a template $T$ and an image $I$ at every possible spatial offset - essentially sliding the template over the image and measuring the similarity at each location. This produces an output map $C$ where high values indicate strong matches.

Fortunately, due to the Convolution Theorem, we can compute this more efficiently in the frequency domain. Cross-correlation can be computed by applying the FFT to both $I$ and $T$, multiplying $F(I)$ with the complex conjugate of $F(T)$, and applying the inverse FFT to get back to the spatial domain. The result is a correlation map where high values correspond to locations where the template matches the image.

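**Convolution Theorem**: $f(t)\circledast g(t) = F^{-1}(F(f(t))\cdot F(g(t)))$. Cross-correlation is convolution with a flipped, conjugated second argument, which turns the second factor into its complex conjugate: $f(t)\star g(t) = F^{-1}(F(f(t))\cdot\overline{F(g(t))})$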


**Cross Correlation**: $C = F^{-1}(F(I) \cdot \overline{F(T)})$


Note: for visualization purposes, I perform non maximal suppression on the outputs.

In [None]:
import time

def match_via_convolution(image: np.ndarray, template: np.ndarray) -> np.ndarray:
    """Slide ``template`` over ``image`` and score every fully-contained offset.

    The template and each image patch are shifted to zero mean and scaled to
    peak magnitude 1 before their element-wise product is summed. The score
    map is normalized the same way before it is returned.

    Args:
        image: 2D image array of shape (N, M).
        template: 2D template of shape (Nt, Mt).

    Returns:
        Float array shaped like ``image``; entry [i, j] scores the patch whose
        top-left corner is (i, j). Offsets where the template would overhang
        the image are never scored and keep the raw value 0 before the final
        normalization.
    """
    N, M = image.shape
    Nt, Mt = template.shape

    def normalize(arr: np.ndarray) -> np.ndarray:
        # Zero mean, then peak magnitude 1. A flat input has peak 0 and is
        # returned as all zeros; dividing would give 0/0 = NaN, and a single
        # NaN patch score would poison the mean of the whole response map.
        centered = arr - np.mean(arr)
        peak = np.max(np.abs(centered))
        return centered / peak if peak > 0 else centered

    template = normalize(template)
    # Always accumulate in float so integer images do not truncate the scores.
    response = np.zeros(image.shape, dtype=float)
    for i in range(N - Nt + 1):
        for j in range(M - Mt + 1):
            patch = normalize(image[i:i+Nt, j:j+Mt])
            response[i, j] = np.sum(patch * template)

    return normalize(response)

def match_via_fft(image: np.ndarray, template: np.ndarray) -> np.ndarray:
    """Cross-correlate ``template`` with ``image`` using the FFT.

    The template is shifted to zero mean, scaled to peak magnitude 1,
    zero-padded to the image size and correlated with the image as
    ``ifft2(fft2(image) * conj(fft2(template)))``. The correlation map is
    normalized the same way before it is returned.

    Because the DFT is periodic the correlation is circular: offsets near the
    bottom/right border compare the template against pixels wrapped around
    from the opposite side.

    Args:
        image: 2D image array of shape (N, M).
        template: 2D template of shape (Nt, Mt), no larger than the image.

    Returns:
        Float array shaped like ``image``; entry [i, j] scores the template
        placed with its top-left corner at (i, j).
    """
    N, M = image.shape
    Nt, Mt = template.shape
    assert (N>=Nt and M>=Mt), "template should be smaller than image"

    def normalize(arr: np.ndarray) -> np.ndarray:
        # Zero mean, then peak magnitude 1; a flat input (peak 0) is returned
        # as all zeros instead of NaN from 0/0.
        centered = arr - np.mean(arr)
        peak = np.max(np.abs(centered))
        return centered / peak if peak > 0 else centered

    # Explicit float buffer: zeros_like(image) would truncate the normalized
    # template to integers when the image has an integer dtype.
    padded_template = np.zeros(image.shape, dtype=float)
    padded_template[:Nt, :Mt] = normalize(template)

    Ft = np.fft.fft2(padded_template)
    Fi = np.fft.fft2(image)
    corr = np.real(np.fft.ifft2(Fi * np.conj(Ft)))
    return normalize(corr)

def compute_iou(loc1, loc2, Nt, Mt):
    """Intersection-over-union of two equally sized boxes.

    Args:
        loc1, loc2: top-left corners of the two boxes.
        Nt, Mt: box extent along the first and second coordinate.

    Returns:
        0.0 when the boxes do not overlap, otherwise intersection / union.
    """
    (r1, c1), (r2, c2) = loc1, loc2
    overlap_rows = min(r1 + Nt, r2 + Nt) - max(r1, r2)
    overlap_cols = min(c1 + Mt, c2 + Mt) - max(c1, c2)
    if overlap_rows <= 0 or overlap_cols <= 0:
        return 0.0
    intersection = overlap_rows * overlap_cols
    # Both boxes have area Nt*Mt, so the union is twice that minus the overlap.
    return intersection / (2*Nt*Mt - intersection)

def perform_non_maximal_suppression(match: np.ndarray, loc: np.ndarray, Nt: int, Mt: int, iou_threshold: float) -> np.ndarray:
    """Greedy non-maximum suppression over candidate box locations.

    Candidates are visited from highest to lowest ``match`` score; one is
    kept only if its IoU with every already-kept box is at most
    ``iou_threshold``.

    Args:
        match: 2D score map indexed by box top-left corner.
        loc: (n, 2) integer array of candidate top-left corners.
        Nt, Mt: box extent along the first and second coordinate.
        iou_threshold: maximum IoU allowed between two kept boxes.

    Returns:
        (k, 2) array of kept corners, best score first. The result is always
        an ndarray, including when there are zero or one candidates.
    """
    loc = np.asarray(loc)
    if loc.shape[0] <= 1:
        # Nothing to suppress. Return the same (k, 2) array type as the
        # general path rather than a Python list.
        return loc.reshape(-1, 2)

    scores = match[loc[:, 0], loc[:, 1]]
    # Negating the scores makes argsort yield descending order.
    loc = loc[np.argsort(-scores)]

    kept_locations = []
    for candidate in loc:
        current_loc = (candidate[0], candidate[1])
        overlaps_kept = any(
            compute_iou(current_loc, kept_loc, Nt, Mt) > iou_threshold
            for kept_loc in kept_locations
        )
        if not overlaps_kept:
            kept_locations.append(current_loc)

    return np.asarray(kept_locations)

def threshold_and_display_boxes(image: np.ndarray, match: np.ndarray, Nt: int, Mt: int, threshold: float) -> np.ndarray:
    """Return a copy of ``image`` with black box outlines at the detected matches.

    Every location whose ``match`` score exceeds ``threshold`` is a
    candidate; overlapping candidates are thinned with non-maximum
    suppression (IoU threshold 0.3) and an outline with pixel value 0.0 is
    drawn for each survivor.

    Args:
        image: 2D image of shape (N, M); it is not modified.
        match: 2D score map indexed by box top-left corner.
        Nt, Mt: box extent along the first and second coordinate.
        threshold: minimum score (exclusive) for a candidate.

    Returns:
        Copy of ``image`` with the outlines drawn in.
    """
    N, M = image.shape
    out = image.copy()
    locations = np.argwhere(match > threshold)
    locations = perform_non_maximal_suppression(match, locations, Nt, Mt, .3)

    for x, y in locations:
        # Far edges of the box, clamped to the last valid row/column.
        x_end = min(x + Nt, N - 1)
        y_end = min(y + Mt, M - 1)
        # Slice ends are exclusive, hence the +1: without it the bottom-right
        # corner pixel of the outline is never drawn.
        out[x:x_end + 1, y] = 0.0
        out[x:x_end + 1, y_end] = 0.0
        out[x, y:y_end + 1] = 0.0
        out[x_end, y:y_end + 1] = 0.0

    return out



# Load the scene and the template as grayscale floats in [0, 1].
img = Image.open(os.path.join('template_matching', 'cards.jpg')).convert('L')
template = Image.open(os.path.join('template_matching', 'template.jpg')).convert('L')
img_arr = np.array(img) / 255.0
template_arr = np.array(template) / 255.0
# Time the sliding-window matcher ...
start_time = time.time()
conv_match = match_via_convolution(img_arr, template_arr)
conv_time = time.time() - start_time

# ... and the FFT-based matcher on the same inputs.
start_time = time.time()
fft_match = match_via_fft(img_arr, template_arr)
fft_time = time.time() - start_time

# Top row: inputs. Bottom row: the two response maps with their run times.
display_in_grid(
    [[img_arr, template_arr],[ conv_match, fft_match]],
    [["img", "template"],[f"conv_match ({conv_time:.2f}s)", f"fft_match ({fft_time:.2f}s)"]]
)

In [None]:
# One grid row per matcher; each column draws the detections that survive a
# given score threshold (strictest first).
outputs = []
titles = []
for matches, name in zip([conv_match, fft_match],["conv", "fft"]):
    row_outputs = []
    row_titles = []
    for threshold in [.99,.9,.7,.4,.2]:
        row_outputs.append(threshold_and_display_boxes(img_arr, matches, template_arr.shape[0], template_arr.shape[1], threshold))
        row_titles.append(f"{name} threshold={threshold}")
    outputs.append(row_outputs)
    titles.append(row_titles)

display_in_grid(outputs,titles)
